"""Gradio chatbot that grounds Cohere generations in a FAISS-retrieved dialog.

At import time the script restores a tokenizer, encoder model, embeddings,
FAISS index and prompt list from ``chatbot_state_v1.pkl``, builds a Cohere
client from the ``COHERE_API_KEY`` environment variable, and launches a
Gradio text-in/text-out interface.
"""
import os
import pickle

import cohere
import gradio as gr
import torch

# Number of trailing history entries (not exchanges) included in each prompt.
HISTORY_WINDOW = 6

# Step 1: Load the saved model, tokenizer, embeddings, FAISS index, and
# prompts from the pickle file.
# SECURITY: pickle.load executes arbitrary code embedded in the file. Only
# load a state file that was produced by a trusted source.
with open('chatbot_state_v1.pkl', 'rb') as f:
    data = pickle.load(f)

# Restore the saved components
tokenizer = data['tokenizer']
model = data['model']
embeddings = data['embeddings']
index = data['faiss_index']
prompts = data['prompts']

# The model is used for inference only; eval mode disables dropout so the
# same query always produces the same embedding.
model.eval()

# Fetch the API key from the environment variable
api_key = os.getenv('COHERE_API_KEY')

# Ensure the API key is set
if api_key is None:
    raise ValueError("API Key not found. Please set COHERE_API_KEY in the environment variables.")

# Set up the Cohere client using the API key
cohere_client = cohere.Client(api_key)

# Conversation history shared by the whole process.
# NOTE(review): being module-level, this list is shared by every browser
# session served by this process; per-session state would need gr.State.
conversation_history = []


def retrieve_closest_dialog(query, tokenizer, model, index, prompts):
    """Return the stored prompt whose embedding is nearest to *query*.

    Args:
        query: User text to embed.
        tokenizer: Callable that turns text into a dict of PyTorch tensors.
        model: Encoder whose output exposes ``last_hidden_state``.
        index: FAISS index searched with the query embedding.
        prompts: Sequence indexed by the positions the FAISS index returns.

    Returns:
        The entry of *prompts* at the position of the single nearest neighbor.
    """
    inputs = tokenizer(query, return_tensors='pt', padding=True, truncation=True)

    # Put the inputs on whatever device the model already lives on. Choosing
    # CUDA merely because it is available fails with a device mismatch when
    # the unpickled model is still on the CPU.
    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Use the first token's hidden state as the sentence embedding. Cast to
    # float32 because that is the dtype FAISS's Python search API expects.
    cls_embedding = outputs.last_hidden_state[:, 0, :].float().cpu().numpy()

    # Perform FAISS search to find the closest match
    _distances, neighbors = index.search(cls_embedding, k=1)

    # Retrieve the closest matching dialog from the dataset
    return prompts[neighbors[0][0]]


def chatbot(user_input):
    """Generate a reply to *user_input* and record the turn in the history.

    The retrieved dialog and the user message are appended to the shared
    ``conversation_history``; the last ``HISTORY_WINDOW`` entries form the
    prompt sent to Cohere.

    Args:
        user_input: Text typed by the user.

    Returns:
        The generated reply, or an ``"Error generating response: ..."``
        message if the Cohere call fails.
    """
    global conversation_history

    # Skip retrieval and the remote API call when there is nothing to answer.
    if not user_input or not user_input.strip():
        return "Please enter a message."

    # Step 1: Retrieve the closest dialog
    closest_dialog = retrieve_closest_dialog(user_input, tokenizer, model, index, prompts)

    # Step 2: Append user input and retrieved dialog to conversation history
    conversation_history.append(f"User: {user_input}")
    conversation_history.append(f"Previous conversation: {closest_dialog}")

    # Step 3: Build the prompt from the most recent history entries
    prompt = "\n".join(conversation_history[-HISTORY_WINDOW:])
    prompt += "\nYour response:"

    # Step 4: Generate response using Cohere
    try:
        response = cohere_client.generate(
            model='command-r-plus-04-2024',
            prompt=prompt,
            max_tokens=100
        )
        generated_response = response.generations[0].text
    except Exception as e:
        # UI boundary: report the failure to the user, but do not store the
        # error text as a "Bot:" turn, or it would leak into later prompts.
        return f"Error generating response: {str(e)}"

    # Step 5: Append the generated response to the conversation history
    conversation_history.append(f"Bot: {generated_response}")

    # Only the last HISTORY_WINDOW entries are ever read, so drop older ones
    # in place to keep memory bounded in a long-running process.
    del conversation_history[:-HISTORY_WINDOW]

    return generated_response


def gradio_interface(user_input):
    """Gradio callback: forward the textbox value to :func:`chatbot`."""
    return chatbot(user_input)


# Set up Gradio interface
interface = gr.Interface(fn=gradio_interface, inputs="text", outputs="text", title="Conversational Chatbot")

# Launch the Gradio app
interface.launch()