Girishug's picture
Update app.py
56b7c71 verified
Raw
History Blame Contribute Delete
2.99 kB
import gradio as gr
import cohere
import pickle
import os
import torch
# Step 1: Restore the persisted chatbot state (tokenizer, model, embeddings,
# FAISS index and prompts) from the pickle file.
# NOTE(security): unpickling can execute arbitrary code, so this file must
# only ever come from a trusted source.
with open('chatbot_state_v1.pkl', 'rb') as f:
    data = pickle.load(f)

# Unpack the saved components into the module-level names used below.
tokenizer, model, embeddings, index, prompts = (
    data[key]
    for key in ('tokenizer', 'model', 'embeddings', 'faiss_index', 'prompts')
)
# Fetch the API key from the environment variable
api_key = os.getenv('COHERE_API_KEY')

# Fail fast when the key is unset, empty, or only whitespace; otherwise the
# problem would only surface later as an authentication error on the first
# request.
if api_key is None or not api_key.strip():
    raise ValueError("API Key not found. Please set COHERE_API_KEY in the environment variables.")

# Set up the Cohere client using the API key
cohere_client = cohere.Client(api_key)

# Conversation history kept as a flat list of "User: ...",
# "Previous conversation: ..." and "Bot: ..." strings.
conversation_history = []
# Function to retrieve the closest dialog based on a user's query
def retrieve_closest_dialog(query, tokenizer, model, index, prompts):
    """Return the stored prompt whose embedding is nearest to ``query``.

    Args:
        query: Text to look up.
        tokenizer: Callable that, given the text and ``return_tensors='pt'``,
            returns a mapping of tensor inputs for ``model``.
        model: Torch module whose output exposes ``last_hidden_state``.
        index: Object with a ``search(vectors, k)`` method returning
            ``(distances, labels)`` (a FAISS index in this app).
        prompts: Sequence indexed by the labels that ``index`` returns.

    Returns:
        The element of ``prompts`` at the position of the top-1 hit.

    Raises:
        IndexError: If the index reports no neighbour (negative label) or the
            label is out of range for ``prompts``.
    """
    encoded = tokenizer(query, return_tensors='pt', padding=True, truncation=True)

    # Run on whatever device the model already lives on. Picking CUDA just
    # because it is available would move the inputs away from a CPU-resident
    # model and fail with a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**encoded)

    # Hidden state of the first token of each sequence (the [CLS] position for
    # BERT-style models -- assumed from the original naming). Cast to float32
    # so a half-precision model still yields the dtype FAISS expects.
    cls_embedding = outputs.last_hidden_state[:, 0, :].float().cpu().numpy()

    # Perform the nearest-neighbour search for the single best match
    _, neighbours = index.search(cls_embedding, k=1)
    best = int(neighbours[0][0])

    # FAISS uses label -1 for "no result"; without this guard the negative
    # index would silently return the *last* prompt.
    if best < 0:
        raise IndexError("FAISS index returned no match for the query.")

    return prompts[best]
# Function to handle the chatbot conversation
def chatbot(user_input):
    """Answer ``user_input`` using retrieval plus Cohere text generation.

    The closest stored dialog is retrieved for the input, both are appended to
    the module-level ``conversation_history``, and the most recent history
    entries are sent to Cohere as the prompt.

    NOTE(review): ``conversation_history`` is a single module-level list, so
    every caller of this function shares the same history.

    Args:
        user_input: The user's message.

    Returns:
        The generated reply, or an ``"Error generating response: ..."``
        message when the Cohere call fails.
    """
    # Nothing to retrieve or generate for an empty message; avoid the model
    # forward pass and the API call.
    if not user_input or not user_input.strip():
        return "Please enter a message."

    # Step 1: Retrieve the closest dialog
    closest_dialog = retrieve_closest_dialog(user_input, tokenizer, model, index, prompts)

    # Step 2: Append user input and retrieved dialog to conversation history
    conversation_history.append(f"User: {user_input}")
    conversation_history.append(f"Previous conversation: {closest_dialog}")

    # Only the last 6 history entries (not 6 full exchanges) ever reach the
    # prompt, so drop anything older instead of letting the list grow forever.
    del conversation_history[:-6]

    # Step 3: Build the prompt from the retained history
    prompt = "\n".join(conversation_history)
    prompt += "\nYour response:"

    # Step 4: Generate response using Cohere
    try:
        response = cohere_client.generate(
            model='command-r-plus-04-2024',
            prompt=prompt,
            max_tokens=100,
        )
        generated_response = response.generations[0].text
    except Exception as e:
        # Report the failure to the user but do not record it as a bot turn;
        # otherwise the error text would be fed back into later prompts.
        return f"Error generating response: {e}"

    # Step 5: Append the generated response to the conversation history
    conversation_history.append(f"Bot: {generated_response}")

    return generated_response
# Gradio Interface
def gradio_interface(user_input):
    """Gradio callback: pass the submitted text to ``chatbot`` and return its reply."""
    reply = chatbot(user_input)
    return reply
# Build a single text-in / text-out Gradio UI around the callback
interface = gr.Interface(
    fn=gradio_interface,
    inputs="text",
    outputs="text",
    title="Conversational Chatbot",
)

# Start serving the app
interface.launch()