Spaces:
Sleeping
Sleeping
| import asyncio | |
| from typing import AsyncGenerator, List, Dict, Tuple | |
| from config import logger # Make sure config.py exists with a logger | |
| from api import ask_openai, ask_anthropic, ask_gemini # Make sure api.py exists with these functions | |
| from database import Database # Make sure database.py exists with a Database class | |
| import json | |
# Module-level database handle used by register, login and clear_history below.
# Assumes database.py provides a Database class with connect, add_user, get_user,
# get_conversations, add_conversation and clear_conversation methods.
# If it does not, implement those methods or adjust the calls in this module.
# NOTE: the connection is opened as an import-time side effect of this module.
db = Database()
db.connect()
| # These functions might be redundant if login/register is handled fully in app.py | |
| # but keeping them here based on previous code structure. | |
def register(username, password, message=None):
    """Try to create an account and report the outcome as a status string.

    Args:
        username: Name passed to ``db.add_user``.
        password: Password passed to ``db.add_user``.
        message: Accepted but never read.
            NOTE(review): presumably kept so the function matches a UI
            callback signature -- confirm against the caller.

    Returns:
        "Registration successful" when ``db.add_user`` returns a truthy
        value, otherwise "Username already exists".
    """
    created = db.add_user(username, password)
    return "Registration successful" if created else "Username already exists"
def login(username, password, message=None):
    """Check credentials and, on success, load the stored conversation.

    Args:
        username: Name passed to ``db.get_user``.
        password: Password passed to ``db.get_user``.
        message: Accepted but never read.

    Returns:
        A ``(status, user_id, history)`` tuple. On failure this is
        ``("Invalid credentials", None, [])``. On success ``history`` is the
        JSON-decoded value returned by ``db.get_conversations`` (or ``[]``
        when that value is falsy).
    """
    user_id = db.get_user(username, password)

    # Guard clause: a falsy lookup result is treated as a failed login.
    if not user_id:
        return "Invalid credentials", None, []

    # Only one conversation blob is loaded here, not per-model histories.
    # Assumes the stored value is a JSON string of the history -- TODO confirm
    # against database.py.
    stored = db.get_conversations(user_id)
    history = []
    if stored:
        history = json.loads(stored)
    return "Login successful", user_id, history
def logout():
    """Return the fixed logout confirmation string; nothing else is touched."""
    status = "Logout successful"
    return status
| # This clear_history might be redundant if app.py handles clearing via Gradio states | |
def clear_history(user_id):
    """Clear the user's stored conversation and return four empty lists.

    Args:
        user_id: Identifier passed to ``db.clear_conversation``.

    Returns:
        A 4-tuple of independent empty lists.
        NOTE(review): presumably one per model history plus the context
        history -- confirm against the caller.
    """
    db.clear_conversation(user_id)
    # A fresh list per slot, so callers never share one mutable object.
    empty_histories = tuple([] for _ in range(4))
    return empty_histories
async def submit_query(query, providers, history, user_id):
    """Relay every result produced by ``query_model`` for this turn.

    Nothing is written to the database here: the only persistence call in
    the original was commented out, with a note that saving is left to
    app.py. ``user_id`` is therefore accepted but not used.

    Args:
        query: The user's message for this turn.
        providers: Provider names forwarded to ``query_model``.
        history: Context history forwarded to ``query_model``.
        user_id: Unused in this function.

    Yields:
        ``(error_msg, openai_msgs, anthropic_msgs, gemini_msgs,
        updated_context_history)`` exactly as produced by ``query_model``.
    """
    results = query_model(query, providers, history)
    async for err, openai_hist, anthropic_hist, gemini_hist, context_hist in results:
        # Pure pass-through of the five-element result.
        yield err, openai_hist, anthropic_hist, gemini_hist, context_hist
async def _collect_reply(label: str, ask, query: str, msgs: List[Dict[str, str]]) -> str:
    """Stream one provider's reply into ``msgs`` and return its error text.

    Args:
        label: Provider name used in the log and error messages
            (for example "OpenAI").
        ask: Callable invoked as ``ask(query, msgs)``; its result is consumed
            with ``async for`` and every chunk is expected to be a string.
        query: The user's message for this turn.
        msgs: This provider's history, already ending with the current user
            message. It is mutated in place.

    Returns:
        An empty string on success, otherwise ``"<label> Error: <exception>"``
        followed by a newline.
    """
    try:
        chunks = [chunk async for chunk in ask(query, msgs)]
        reply = "".join(chunks).strip()
    except Exception as e:
        # Boundary around a third-party call: record the failure in the log,
        # in the returned error text and in the history, then carry on so the
        # remaining providers still run.
        logger.error(f"Error calling {label}: {e}")
        msgs.append({"role": "assistant", "content": f"Error: {e}"})
        return f"{label} Error: {e}\n"

    # Test the stripped text so a whitespace-only reply is not stored as an
    # empty assistant message.
    if reply:
        msgs.append({"role": "assistant", "content": reply})
    return ""


async def query_model(
    query: str,
    providers: List[str],
    history: List[Dict[str, str]],
) -> AsyncGenerator[
    Tuple[
        str,
        List[Dict[str, str]],
        List[Dict[str, str]],
        List[Dict[str, str]],
        List[Dict[str, str]],
    ],
    None,
]:
    """Ask each selected provider and yield the updated per-model histories.

    Every provider starts from its own shallow copy of ``history`` with the
    current user message appended; the caller's list is never mutated.
    Providers are queried one after another in the order OpenAI, Anthropic,
    Gemini. A failure in one provider is reported in that provider's history
    and in the error string, and does not stop the others.

    Args:
        query: The user's message for this turn.
        providers: Names of the providers to query ("OpenAI", "Anthropic",
            "Gemini"). ``None`` is treated as no providers.
        history: Context history from previous turns. ``None`` or an empty
            sequence is treated as no history.

    Yields:
        Exactly one tuple ``(error_msg, openai_msgs, anthropic_msgs,
        gemini_msgs, context_msgs)``. ``error_msg`` is the stripped
        concatenation of the per-provider error lines, and ``context_msgs``
        is the same list object as ``openai_msgs``.
    """
    base = list(history) if history else []
    selected = providers or ()
    user_msg = {"role": "user", "content": query}

    # The user message is appended before the calls so the history passed to
    # each ask_* function already includes the current turn.
    openai_msgs = base + [user_msg]
    anthropic_msgs = base + [user_msg]
    gemini_msgs = base + [user_msg]

    errors = []
    # The ask_* names are referenced only inside their branch, so a provider
    # that is not selected is never looked up.
    if "OpenAI" in selected:
        errors.append(await _collect_reply("OpenAI", ask_openai, query, openai_msgs))
    if "Anthropic" in selected:
        errors.append(await _collect_reply("Anthropic", ask_anthropic, query, anthropic_msgs))
    if "Gemini" in selected:
        errors.append(await _collect_reply("Gemini", ask_gemini, query, gemini_msgs))

    # The OpenAI history doubles as the updated context history.
    yield "".join(errors).strip(), openai_msgs, anthropic_msgs, gemini_msgs, openai_msgs