Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import requests | |
| import json | |
| import logging | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
# Load environment variables (e.g. from a local .env file) before reading keys.
load_dotenv()

# API keys configuration: all three providers are required at startup.
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

_missing_api_keys = [
    key_name
    for key_name, key_value in (
        ("COHERE_API_KEY", COHERE_API_KEY),
        ("MISTRAL_API_KEY", MISTRAL_API_KEY),
        ("GEMINI_API_KEY", GEMINI_API_KEY),
    )
    if not key_value
]
if _missing_api_keys:
    # Name the missing variables so a misconfigured deployment is easy to fix.
    raise ValueError(
        "Missing required API keys in environment variables: "
        + ", ".join(_missing_api_keys)
    )

# Configure the Gemini SDK with its API key (used by MODELS["Gemini"]).
genai.configure(api_key=GEMINI_API_KEY)
# API endpoints configuration
COHERE_API_URL = "https://api.cohere.ai/v1/chat"
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
VECTOR_API_URL = "https://sendthat.cc"  # vector-search service used by search_document
HISTORY_INDEX = "onramps"  # index name appended to the /search/ URL path

# Single source of truth for the Gemini model id, so the configured name and
# the instantiated GenerativeModel cannot drift apart.
GEMINI_MODEL_NAME = "gemini-1.5-pro"

# Model configurations, keyed by the names offered in the UI model selector.
MODELS = {
    "Cohere": {
        "name": "command-r-plus-08-2024",
        "api_url": COHERE_API_URL,
        "api_key": COHERE_API_KEY,
    },
    "Mistral": {
        # Fine-tuned Mistral Nemo model id.
        "name": "ft:open-mistral-nemo:ef730d29:20241022:2a0e7d46",
        "api_url": MISTRAL_API_URL,
        "api_key": MISTRAL_API_KEY,
    },
    "Gemini": {
        "name": GEMINI_MODEL_NAME,
        # Gemini is called through the SDK client rather than a raw HTTP URL.
        "model": genai.GenerativeModel(GEMINI_MODEL_NAME),
        "api_key": GEMINI_API_KEY,
    },
}
def search_document(query, k, timeout=30):
    """Query the vector-search service for the ``k`` passages closest to ``query``.

    Args:
        query: Free-text search string.
        k: Number of results to request.
        timeout: Seconds to wait for the HTTP request. Without a timeout,
            ``requests`` can block indefinitely on a stalled server.

    Returns:
        A 3-tuple ``(result, leftover_query, k)``.
        - On success, ``result`` is the decoded JSON body and
          ``leftover_query`` is ``""``.
        - On failure (network/HTTP error or an undecodable body), ``result``
          is ``{"error": <message>}`` and ``leftover_query`` is the original
          ``query``.
    """
    url = f"{VECTOR_API_URL}/search/{HISTORY_INDEX}"
    payload = {"text": query, "k": k}
    headers = {"Content-Type": "application/json"}
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json(), "", k
    # ValueError covers JSON decode failures on requests versions where
    # response.json() does not raise a RequestException subclass.
    except (requests.exceptions.RequestException, ValueError) as e:
        logging.error("Error in search: %s", e)
        return {"error": str(e)}, query, k
def generate_answer_cohere(question, context, citations, timeout=120):
    """Answer ``question`` from ``context`` using the Cohere chat API.

    Args:
        question: The user's question.
        context: Retrieved passages to ground the answer in.
        citations: Source descriptions, listed under "Sources:" as [1], [2], ...
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The model's answer text, followed by a numbered "Sources:" section
        when ``citations`` is non-empty. On failure, a string starting with
        "An error occurred:" is returned instead of raising.
    """
    headers = {
        "Authorization": f"Bearer {MODELS['Cohere']['api_key']}",
        "Content-Type": "application/json",
    }
    prompt = (
        f"Context: {context}\n\n"
        f"Question: {question}\n\n"
        "Answer the question based on the given context. "
        "Include citations as [1], [2], etc.:"
    )
    payload = {
        "message": prompt,
        "model": MODELS["Cohere"]["name"],
        "preamble": "You are an AI-assistant chatbot. Provide thorough responses with citations.",
        "chat_history": [],
    }
    try:
        response = requests.post(
            MODELS["Cohere"]["api_url"], headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        answer = response.json()["text"]
    except requests.exceptions.RequestException as e:
        logging.error("Error in generate_answer_cohere: %s", e)
        return f"An error occurred: {str(e)}"
    except (KeyError, TypeError, ValueError) as e:
        # A 2xx response whose body lacks the expected "text" field (or is not
        # JSON) would otherwise propagate as an unhandled exception to the UI.
        logging.error("Unexpected response in generate_answer_cohere: %s", e)
        return f"An error occurred: unexpected response format ({e!r})"

    # Only add a Sources section when there is something to list.
    if citations:
        answer += "\n\nSources:" + "".join(
            f"\n[{i}] {citation}" for i, citation in enumerate(citations, 1)
        )
    return answer
def generate_answer_mistral(question, context, citations, timeout=120):
    """Answer ``question`` from ``context`` using the Mistral chat-completions API.

    Unlike the other backends, the prompt also allows the model to draw on
    its pre-trained knowledge.

    Args:
        question: The user's question.
        context: Retrieved passages to ground the answer in.
        citations: Source descriptions, listed under "Sources:" as [1], [2], ...
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The model's answer text, followed by a numbered "Sources:" section
        when ``citations`` is non-empty. On failure, a string starting with
        "An error occurred:" is returned instead of raising.
    """
    headers = {
        "Authorization": f"Bearer {MODELS['Mistral']['api_key']}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    prompt = (
        f"Context: {context}\n\n"
        f"Question: {question}\n\n"
        "Answer the question based on the given context and any pre-trained knowledge. "
        "Include citations as [1], [2], etc.:"
    )
    payload = {
        "model": MODELS["Mistral"]["name"],
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        response = requests.post(
            MODELS["Mistral"]["api_url"], headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        answer = response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        logging.error("Error in generate_answer_mistral: %s", e)
        return f"An error occurred: {str(e)}"
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # A 2xx body without the expected choices[0].message.content shape (or
        # a non-JSON body) would otherwise escape as an unhandled exception.
        logging.error("Unexpected response in generate_answer_mistral: %s", e)
        return f"An error occurred: unexpected response format ({e!r})"

    # Only add a Sources section when there is something to list.
    if citations:
        answer += "\n\nSources:" + "".join(
            f"\n[{i}] {citation}" for i, citation in enumerate(citations, 1)
        )
    return answer
def generate_answer_gemini(question, context, citations):
    """Answer ``question`` from ``context`` using the configured Gemini model.

    Args:
        question: The user's question.
        context: Retrieved passages to ground the answer in.
        citations: Source descriptions, listed under "Sources:" as [1], [2], ...

    Returns:
        The model's answer text, followed by a numbered "Sources:" section
        when ``citations`` is non-empty. On any SDK failure, a string starting
        with "An error occurred:" is returned instead of raising.
    """
    prompt = (
        f"Context: {context}\n\n"
        f"Question: {question}\n\n"
        "Answer the question based on the given context. "
        "Include citations as [1], [2], etc.:"
    )
    try:
        model = MODELS["Gemini"]["model"]
        response = model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=1.0,
                top_k=40,
                top_p=0.95,
                max_output_tokens=8192,
            ),
        )
        # Reading .text stays inside the try so any SDK error raised while
        # extracting the text is also reported to the user.
        answer = response.text
    except Exception as e:  # SDK boundary: surface any failure as a message
        # logging.exception keeps the traceback, which the broad catch would
        # otherwise hide.
        logging.exception("Error in generate_answer_gemini: %s", e)
        return f"An error occurred: {str(e)}"

    # Only add a Sources section when there is something to list.
    if citations:
        answer += "\n\nSources:" + "".join(
            f"\n[{i}] {citation}" for i, citation in enumerate(citations, 1)
        )
    return answer
def answer_question(question, model_choice, k=3):
    """Retrieve supporting passages for ``question`` and answer it with one LLM.

    Args:
        question: The user's question.
        model_choice: "Cohere" or "Mistral" select those backends; any other
            value falls through to Gemini.
        k: Number of passages to request from the vector search.

    Returns:
        The string produced by the selected ``generate_answer_*`` function,
        which may itself be an "An error occurred: ..." message.
    """
    search_results, _, _ = search_document(question, k)

    results = search_results.get("results") if isinstance(search_results, dict) else None
    contexts = []
    citations = []
    if isinstance(results, list):
        for item in results:
            metadata = item.get("metadata") if isinstance(item, dict) else None
            content = metadata.get("content") if isinstance(metadata, dict) else None
            if not content:
                # Skip malformed or empty hits instead of raising KeyError.
                # Skipping keeps the passage numbers and the citation list aligned.
                logging.warning("Skipping search result without content: %s", item)
                continue
            contexts.append(content)
            citations.append(
                f"{metadata.get('title', 'Unknown Source')} - "
                f"{metadata.get('source', 'No source provided')}"
            )
    else:
        logging.error("Error in database search or no results found: %s", search_results)

    # Label each passage with the same [n] used in the Sources list, so the
    # model's inline [1], [2] citations point at the matching source.
    combined_context = "\n\n".join(
        f"[{i}] {text}" for i, text in enumerate(contexts, 1)
    )

    # Dispatch to the selected backend; unknown choices fall back to Gemini.
    generators = {
        "Cohere": generate_answer_cohere,
        "Mistral": generate_answer_mistral,
    }
    generate = generators.get(model_choice, generate_answer_gemini)
    return generate(question, combined_context, citations)
def chatbot(message, history, model_choice):
    """Chat callback: answer ``message`` with the backend named by ``model_choice``.

    ``history`` is part of the callback signature but is ignored; every
    message is answered on its own, with the default number of passages.
    """
    return answer_question(message, model_choice)
# Example prompts shown in the chat UI. Each row is [question, model]; the
# second value pre-selects the model radio passed in as an additional input.
EXAMPLE_QUESTIONS = [
    [prompt_text, default_model]
    for prompt_text, default_model in (
        ("Why was Anne Hutchinson banished from Massachusetts?", "Cohere"),
        ("What were the major causes of World War I?", "Mistral"),
        ("Who was the first President of the United States?", "Gemini"),
        ("What was the significance of the Industrial Revolution?", "Cohere"),
    )
]
# Create Gradio interface
with gr.Blocks(theme="soft") as iface:
    gr.Markdown("# History Chatbot")
    gr.Markdown("Ask me anything about history, and I'll provide answers with citations!")
    with gr.Row():
        # Model selector. Its value is forwarded to `chatbot` as the third
        # argument via additional_inputs below.
        model_choice = gr.Radio(
            choices=["Cohere", "Mistral", "Gemini"],
            value="Cohere",
            label="Choose LLM Model",
            info="Select which AI model to use for generating responses",
        )
    # `chatbot` already takes (message, history, model_choice), so it is passed
    # directly instead of through a forwarding lambda.
    # NOTE(review): retry_btn / undo_btn / clear_btn are Gradio 4.x-era
    # ChatInterface arguments; confirm against the pinned Gradio version,
    # since newer releases may reject them.
    chatbot_interface = gr.ChatInterface(
        fn=chatbot,
        additional_inputs=[model_choice],
        chatbot=gr.Chatbot(height=300),
        textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
        examples=EXAMPLE_QUESTIONS,
        cache_examples=False,
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()