| import gradio as gr |
| import cohere |
| import pickle |
| import os |
| import torch |
|
|
| |
# Restore the artifacts saved offline under the keys below. `index` is
# searched by `retrieve_closest_dialog`, and the ids it returns are used to
# index `prompts`. `embeddings` is loaded but not used by the code visible
# here.
# NOTE: pickle.load can execute arbitrary code; only load a trusted file.
with open('chatbot_state_v1.pkl', 'rb') as f:
    data = pickle.load(f)

tokenizer, model, embeddings, index, prompts = (
    data[key]
    for key in ('tokenizer', 'model', 'embeddings', 'faiss_index', 'prompts')
)
|
|
| |
# Read the Cohere API key from the environment and fail fast at startup
# instead of on the first chat request.
api_key = os.getenv('COHERE_API_KEY')

# An empty or whitespace-only value (e.g. `export COHERE_API_KEY=`) is as
# unusable as an unset one, so treat both as missing.
if not api_key or not api_key.strip():
    raise ValueError("API Key not found. Please set COHERE_API_KEY in the environment variables.")

cohere_client = cohere.Client(api_key)
|
|
| |
# Module-wide chat transcript appended to by `chatbot`
# (entries look like "User: ...", "Previous conversation: ...", "Bot: ...").
conversation_history = list()
|
|
| |
def retrieve_closest_dialog(query, tokenizer, model, index, prompts):
    """Return the stored prompt whose embedding is nearest to *query*.

    The query is encoded with *tokenizer* and *model*. The hidden state of
    the first token (``last_hidden_state[:, 0, :]``) is used as the query
    embedding and searched against *index* for the single nearest neighbour.

    Args:
        query: Text to look up.
        tokenizer: Callable returning a mapping of tensors when called with
            ``return_tensors='pt'``.
        model: Torch model whose output exposes ``last_hidden_state``.
        index: Nearest-neighbour index with a FAISS-style
            ``search(x, k) -> (distances, ids)`` method.
        prompts: Sequence indexed by the ids stored in *index*.

    Returns:
        ``prompts[i]`` for the nearest id ``i``.

    Raises:
        ValueError: If the index reports no neighbour (id ``-1``, e.g. an
            empty index). Previously this silently returned ``prompts[-1]``.
    """
    # Send the inputs to wherever the model's weights live. Picking the
    # device from `torch.cuda.is_available()` alone crashes on GPU hosts
    # when the unpickled model is still on the CPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    inputs = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # FAISS expects a C-contiguous float32 matrix; `.float()` also covers
    # half-precision models.
    cls_embedding = (
        outputs.last_hidden_state[:, 0, :].cpu().float().contiguous().numpy()
    )

    _, ids = index.search(cls_embedding, k=1)
    nearest = int(ids[0][0])
    if nearest < 0:
        raise ValueError("Index returned no neighbour for the query (is the index empty?)")
    return prompts[nearest]
|
|
| |
def chatbot(user_input, *, history_window=6):
    """Generate a reply to *user_input* using retrieval-augmented prompting.

    The closest stored dialog is fetched with ``retrieve_closest_dialog``.
    It is combined with the most recent entries of the module-level
    ``conversation_history`` into a prompt for Cohere's ``generate`` call.

    Args:
        user_input: The user's message.
        history_window: Number of most recent history entries placed in the
            prompt, counting this turn's ``User:`` and
            ``Previous conversation:`` lines. Must be >= 1. The default of 6
            matches the previous hard-coded value.

    Returns:
        The generated reply text, or ``"Error generating response: ..."`` if
        the Cohere call raises.

    Raises:
        ValueError: If *history_window* is less than 1.
    """
    if history_window < 1:
        raise ValueError("history_window must be >= 1")

    # NOTE(review): `conversation_history` is a single module-level list, so
    # every Gradio session shares (and sees) the same transcript.
    closest_dialog = retrieve_closest_dialog(user_input, tokenizer, model, index, prompts)

    conversation_history.append(f"User: {user_input}")
    conversation_history.append(f"Previous conversation: {closest_dialog}")

    # Only the last `history_window` entries ever reach the prompt, so drop
    # older ones in place. The prompt is unchanged, but memory no longer
    # grows for the lifetime of the server.
    del conversation_history[:-history_window]

    prompt = "\n".join(conversation_history) + "\nYour response:"

    try:
        response = cohere_client.generate(
            model='command-r-plus-04-2024',
            prompt=prompt,
            max_tokens=100,
        )
        generated_response = response.generations[0].text
    except Exception as e:  # user-facing boundary: report rather than crash
        # Do not record the failure as a "Bot:" turn. Otherwise the error
        # text would be fed back to the model in later prompts as if it
        # had said it.
        return f"Error generating response: {e}"

    conversation_history.append(f"Bot: {generated_response}")
    return generated_response
|
|
| |
def gradio_interface(user_input):
    """Gradio callback: pass the textbox contents to ``chatbot``.

    Returns whatever ``chatbot`` returns for *user_input*.
    """
    reply = chatbot(user_input)
    return reply
|
|
| |
# Single text-in / text-out UI wired to the chatbot callback.
interface = gr.Interface(
    title="Conversational Chatbot",
    fn=gradio_interface,
    inputs="text",
    outputs="text",
)

# Start the Gradio web server.
interface.launch()
|
|