File size: 7,296 Bytes
af65410
734fa7e
f45d8b2
 
 
40966a6
 
f45d8b2
 
 
40966a6
 
 
f45d8b2
 
40966a6
 
 
 
 
 
 
 
 
f45d8b2
 
 
40966a6
f45d8b2
40966a6
 
 
 
 
 
 
f45d8b2
40966a6
 
f45d8b2
 
40966a6
 
f45d8b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
734fa7e
40966a6
f45d8b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231773b
 
f45d8b2
231773b
f45d8b2
 
 
 
 
 
 
231773b
f45d8b2
 
 
 
 
 
231773b
 
f45d8b2
231773b
f45d8b2
 
 
 
 
 
231773b
f45d8b2
 
 
 
 
 
 
231773b
 
f45d8b2
231773b
f45d8b2
 
 
 
 
 
 
76efecf
f45d8b2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import asyncio
from typing import AsyncGenerator, List, Dict, Tuple
from config import logger # Make sure config.py exists with a logger
from api import ask_openai, ask_anthropic, ask_gemini # Make sure api.py exists with these functions
from database import Database # Make sure database.py exists with a Database class
import json

# Assuming database.py has a Database class with connect, add_user, get_user,
# get_conversations, add_conversation, clear_conversation methods.
# If not, you'll need to implement these or adjust.
# NOTE(review): module-level singleton with a connect() side effect at import
# time — importing this module opens the DB connection. Confirm that is
# intended (e.g. no connection pooling / re-connect handling is visible here).
db = Database()
db.connect()

# These functions might be redundant if login/register is handled fully in app.py
# but keeping them here based on previous code structure.
def register(username, password, message=None):
    """Create a new account.

    Returns a human-readable status string: success, or a duplicate-username
    notice when ``db.add_user`` reports a conflict. ``message`` is accepted
    for signature compatibility with the UI wiring and is unused.
    """
    created = db.add_user(username, password)
    return "Registration successful" if created else "Username already exists"

def login(username, password, message=None):
    """Authenticate a user and load their stored conversation history.

    Returns a 3-tuple ``(status, user_id, history)``. On failure ``user_id``
    is ``None`` and ``history`` is empty. The stored conversation is assumed
    to be a single JSON-encoded history list (not per-model histories —
    app.py's get_chat_history handles per-model loading).
    """
    user_id = db.get_user(username, password)
    if not user_id:
        return "Invalid credentials", None, []

    stored = db.get_conversations(user_id)
    history = json.loads(stored) if stored else []
    return "Login successful", user_id, history

def logout():
    """Return the logout status message (no server-side state to tear down)."""
    return "Logout successful"

# This clear_history might be redundant if app.py handles clearing via Gradio states
def clear_history(user_id):
    db.clear_conversation(user_id)
    return [], [], [], [] # Assuming this returns empty histories for 3 models + context


async def submit_query(query, providers, history, user_id):
    """Relay streaming results from query_model to the caller.

    Yields the same 5-tuples query_model produces:
    ``(error_msg, openai_msgs, anthropic_msgs, gemini_msgs, context_history)``.

    Persisting responses to the database is deliberately left to app.py's
    save_chat_history (which stores individual responses per turn), so this
    coroutine performs no DB writes itself; ``user_id`` is retained in the
    signature for the UI wiring.
    """
    async for result in query_model(query, providers, history):
        yield result


async def _run_provider(label: str, ask_fn, query: str, msgs: List[Dict[str, str]]) -> str:
    """Stream one provider's response, mutating *msgs* in place.

    Accumulates the streamed chunks and, if any text arrived, appends a
    stripped assistant message to *msgs*. On exception the error is logged,
    an ``Error: ...`` assistant message is appended instead, and a
    newline-terminated error fragment (e.g. ``"OpenAI Error: ...\\n"``) is
    returned; on success the empty string is returned.
    """
    response = ""
    try:
        # msgs already includes the current user turn; ask_fn receives it as context.
        async for chunk in ask_fn(query, msgs):
            response += chunk
        if response:
            msgs.append({"role": "assistant", "content": response.strip()})
        # No response (e.g. provider returned nothing): msgs keeps just the user turn.
        return ""
    except Exception as e:
        # Preserve side-effect order: log first, then record the error in the history.
        logger.error(f"Error calling {label}: {e}")
        msgs.append({"role": "assistant", "content": f"Error: {e}"})
        return f"{label} Error: {e}\n"


async def query_model(query: str, providers: List[str], history: List[Dict[str, str]]) -> AsyncGenerator[Tuple[str, List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]], None]:
    """Query the selected providers and yield their updated chat histories.

    *history* is the shared context history (e.g. openai_history from app.py).
    Yields a single 5-tuple:
    ``(error_msg, openai_msgs, anthropic_msgs, gemini_msgs, context_history)``
    where the context history is the updated OpenAI list. Providers not in
    *providers* yield their history copy with only the new user turn appended.
    """
    user_msg_dict = {"role": "user", "content": query}  # this turn's user message

    # Each model gets its own shallow copy of the context, with the user turn
    # appended up-front so both the API call and the returned history include it.
    # (The user-message dict itself is intentionally shared across the lists.)
    openai_msgs = [*history, user_msg_dict]
    anthropic_msgs = [*history, user_msg_dict]
    gemini_msgs = [*history, user_msg_dict]

    error_msg = ""
    if "OpenAI" in providers:
        error_msg += await _run_provider("OpenAI", ask_openai, query, openai_msgs)
    if "Anthropic" in providers:
        error_msg += await _run_provider("Anthropic", ask_anthropic, query, anthropic_msgs)
    if "Gemini" in providers:
        error_msg += await _run_provider("Gemini", ask_gemini, query, gemini_msgs)

    # First element carries any accumulated provider errors; the last element
    # is the updated context history (OpenAI's list serves as the main one).
    yield error_msg.strip(), openai_msgs, anthropic_msgs, gemini_msgs, openai_msgs