Spaces:
Sleeping
Sleeping
import asyncio
import hashlib
import hmac
import os
import secrets
import sqlite3
from contextlib import closing

import gradio as gr

from core import submit_query
from config import logger
# SQLite file holding users, sessions and chat history. Set CHAT_DB_PATH to
# relocate it (e.g. onto persistent storage); an unset or empty variable
# falls back to "chat_history.db" in the working directory.
DB_PATH = os.environ.get("CHAT_DB_PATH") or "chat_history.db"
def init_db(*, db_path=None):
    """Create the application's tables and indexes if they do not exist yet.

    Every statement uses ``IF NOT EXISTS``, so calling this repeatedly is safe.

    Args:
        db_path: SQLite database file to initialise. Defaults to ``DB_PATH``.
    """
    path = DB_PATH if db_path is None else db_path
    # closing() releases the connection even if a statement fails;
    # ``with conn`` commits on success and rolls back on error.
    with closing(sqlite3.connect(path)) as conn:
        with conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    username TEXT PRIMARY KEY,
                    password TEXT NOT NULL
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS history (
                    session_id INTEGER,
                    username TEXT,
                    providers TEXT,
                    message TEXT,
                    openai_response TEXT,
                    anthropic_response TEXT,
                    gemini_response TEXT,
                    FOREIGN KEY(session_id) REFERENCES sessions(session_id)
                )
            """)
            # Support the two hot queries: sessions by owner, history by session.
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_sessions_username "
                "ON sessions (username, created_at)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_history_session "
                "ON history (session_id)"
            )
# Stored password format: "scrypt$<n>$<r>$<p>$<salt hex>$<digest hex>".
# The cost parameters are stored with each hash so they can be raised later
# without invalidating existing accounts.
_PW_SCHEME = "scrypt"
_SCRYPT_N = 2 ** 14
_SCRYPT_R = 8
_SCRYPT_P = 1
_SALT_BYTES = 16


def _hash_password(password):
    """Return a salted scrypt hash of *password* in the stored format above."""
    salt = secrets.token_bytes(_SALT_BYTES)
    digest = hashlib.scrypt(
        password.encode("utf-8"), salt=salt, n=_SCRYPT_N, r=_SCRYPT_R, p=_SCRYPT_P
    )
    return "$".join(
        [_PW_SCHEME, str(_SCRYPT_N), str(_SCRYPT_R), str(_SCRYPT_P), salt.hex(), digest.hex()]
    )


def _is_hashed(stored):
    """Return True if *stored* uses the hashed format rather than legacy plaintext."""
    return stored.startswith(_PW_SCHEME + "$")


def _verify_password(password, stored):
    """Return True if *password* matches the stored value *stored*.

    Rows written before hashing was introduced hold the plaintext password;
    those are compared directly, in constant time.
    """
    candidate = password.encode("utf-8")
    if not _is_hashed(stored):
        return hmac.compare_digest(candidate, stored.encode("utf-8"))
    try:
        _, n, r, p, salt_hex, digest_hex = stored.split("$")
        expected = bytes.fromhex(digest_hex)
        actual = hashlib.scrypt(
            candidate,
            salt=bytes.fromhex(salt_hex),
            n=int(n),
            r=int(r),
            p=int(p),
            dklen=len(expected),
        )
    except (ValueError, OverflowError):
        # Malformed stored value: wrong field count, bad hex or invalid params.
        return False
    return hmac.compare_digest(actual, expected)


def register_user(username, password, *, db_path=None):
    """Create an account and return True, or return False if that is not possible.

    Returns False when *username* or *password* is empty, or when the
    username is already taken. Only a salted hash of the password is stored.

    Args:
        username: Desired account name (primary key of the users table).
        password: Plaintext password; it is hashed before storage.
        db_path: SQLite database file. Defaults to ``DB_PATH``.
    """
    if not username or not password:
        return False
    hashed = _hash_password(password)
    path = DB_PATH if db_path is None else db_path
    with closing(sqlite3.connect(path)) as conn:
        try:
            with conn:
                conn.execute(
                    "INSERT INTO users (username, password) VALUES (?, ?)",
                    (username, hashed),
                )
        except sqlite3.IntegrityError:
            # PRIMARY KEY violation: the username already exists.
            return False
    return True


def login_user(username, password, *, db_path=None):
    """Return True if *username* exists and *password* matches its stored value.

    Accounts whose password is still stored in plaintext (created before
    hashing existed) are accepted and re-saved as a hash on success.

    Args:
        username: Account name to check.
        password: Plaintext password supplied by the user.
        db_path: SQLite database file. Defaults to ``DB_PATH``.
    """
    if username is None or password is None:
        return False
    path = DB_PATH if db_path is None else db_path
    with closing(sqlite3.connect(path)) as conn:
        row = conn.execute(
            "SELECT password FROM users WHERE username=?", (username,)
        ).fetchone()
        if row is None or not _verify_password(password, row[0]):
            return False
        if not _is_hashed(row[0]):
            # Upgrade the legacy plaintext row now that we know the password.
            with conn:
                conn.execute(
                    "UPDATE users SET password=? WHERE username=?",
                    (_hash_password(password), username),
                )
    return True
def create_new_session(username, *, db_path=None):
    """Insert a new chat session owned by *username* and return its id.

    Args:
        username: Owner of the session.
        db_path: SQLite database file. Defaults to ``DB_PATH``.

    Returns:
        The AUTOINCREMENT ``session_id`` assigned to the new row.
    """
    path = DB_PATH if db_path is None else db_path
    with closing(sqlite3.connect(path)) as conn:
        with conn:
            cursor = conn.execute(
                "INSERT INTO sessions (username) VALUES (?)", (username,)
            )
        return cursor.lastrowid
def list_sessions(username, *, db_path=None):
    """Return ``(session_id, created_at)`` rows for *username*, newest first.

    ``created_at`` is filled by SQLite's CURRENT_TIMESTAMP, which has
    one-second resolution, so sessions created within the same second tie.
    ``session_id`` (AUTOINCREMENT, so increasing) breaks the tie and keeps the
    most recently created session first.

    Args:
        username: Owner whose sessions to list.
        db_path: SQLite database file. Defaults to ``DB_PATH``.
    """
    path = DB_PATH if db_path is None else db_path
    with closing(sqlite3.connect(path)) as conn:
        return conn.execute(
            "SELECT session_id, created_at FROM sessions WHERE username=? "
            "ORDER BY created_at DESC, session_id DESC",
            (username,),
        ).fetchall()
def get_chat_history(session_id, *, db_path=None):
    """Rebuild the per-provider message histories for *session_id*.

    Every stored row contributes a user message to all three histories, plus
    an assistant message to each provider whose stored response is non-empty.
    Rows are replayed in insertion (rowid) order.

    Args:
        session_id: Session whose history to load.
        db_path: SQLite database file. Defaults to ``DB_PATH``.

    Returns:
        ``(openai_history, anthropic_history, gemini_history)``, each a list of
        ``{"role": ..., "content": ...}`` dicts.
    """
    path = DB_PATH if db_path is None else db_path
    with closing(sqlite3.connect(path)) as conn:
        rows = conn.execute(
            "SELECT message, openai_response, anthropic_response, gemini_response "
            "FROM history WHERE session_id=? ORDER BY rowid",
            (session_id,),
        ).fetchall()

    histories = ([], [], [])  # openai, anthropic, gemini
    for message, *responses in rows:
        for history, response in zip(histories, responses):
            # A fresh dict per list so the three histories never alias each other.
            history.append({"role": "user", "content": message})
            if response:
                history.append({"role": "assistant", "content": response})
    return histories
def save_chat_history(session_id, username, providers, message, openai_resp, anthropic_resp, gemini_resp, *, db_path=None):
    """Append one chat turn (user message plus per-provider replies) to history.

    Args:
        session_id: Session the turn belongs to.
        username: User who sent the message.
        providers: Iterable of provider names used for this turn; stored
            comma-joined. ``None`` is stored as an empty string.
        message: The user's message text.
        openai_resp, anthropic_resp, gemini_resp: Reply text per provider
            ("" when that provider produced none).
        db_path: SQLite database file. Defaults to ``DB_PATH``.
    """
    path = DB_PATH if db_path is None else db_path
    with closing(sqlite3.connect(path)) as conn:
        with conn:
            conn.execute(
                "INSERT INTO history (session_id, username, providers, message, openai_response, anthropic_response, gemini_response) VALUES (?, ?, ?, ?, ?, ?, ?)",
                (session_id, username, ",".join(providers or []), message, openai_resp, anthropic_resp, gemini_resp),
            )
def _last_assistant_text(history):
    """Return the text of the last assistant message with content in *history*.

    Returns "" when *history* is not a list, holds no such message, or when
    that message's content is not a plain string (e.g. a file/component
    payload, which cannot be stored in the TEXT columns).
    """
    if not isinstance(history, list):
        return ""
    for entry in reversed(history):
        if isinstance(entry, dict) and entry.get("role") == "assistant" and entry.get("content"):
            content = entry["content"]
            return content if isinstance(content, str) else ""
    return ""


async def chatbot_fn(message, openai_history, anthropic_history, gemini_history, username, providers, session_id):
    """Send *message* to the selected providers, persist the turn, and return UI updates.

    Args:
        message: Text typed by the user.
        openai_history, anthropic_history, gemini_history: Current histories in
            Gradio "messages" format (lists of role/content dicts); may be None.
        username: Logged-in user ("" when logged out).
        providers: Selected provider names ("OpenAI", "Anthropic", "Gemini").
        session_id: Active chat session id ("" when none).

    Returns:
        A 7-tuple: the new message-box text ("" or an error string), the three
        histories for the chat displays, and the same three for the states.
    """
    openai_history = openai_history or []
    anthropic_history = anthropic_history or []
    gemini_history = gemini_history or []

    def outputs(text, o, a, g):
        # Displays and state holders receive the same lists.
        return text, o, a, g, o, a, g

    # Not logged in, no session, or nothing to send: leave the chats unchanged.
    if not username or not session_id or not (message or "").strip():
        return outputs("", openai_history, anthropic_history, gemini_history)

    if not providers:
        user_msg = {"role": "user", "content": message}
        warning = {"role": "assistant", "content": "Please select at least one provider."}
        return outputs(
            "",
            openai_history + [user_msg, warning],
            anthropic_history + [user_msg, warning],
            gemini_history + [user_msg, warning],
        )

    # submit_query may yield several progressive updates; only the last is kept.
    # NOTE(review): only the OpenAI history is passed as context, as before.
    final_result = None
    error_message = ""
    async for err, openai_msgs, anthropic_msgs, gemini_msgs, _ in submit_query(
        message, providers, openai_history, username
    ):
        final_result = (openai_msgs, anthropic_msgs, gemini_msgs)
        error_message = err

    if final_result is None:
        logger.error("submit_query returned no result")
        return outputs("Error: submit_query failed.", openai_history, anthropic_history, gemini_history)

    new_openai, new_anthropic, new_gemini = final_result
    selected = set(providers)
    # Only providers queried this turn can have produced a new reply; an
    # unselected provider's last assistant message belongs to an earlier turn
    # and must not be saved again.
    try:
        save_chat_history(
            session_id, username, providers, message,
            _last_assistant_text(new_openai) if "OpenAI" in selected else "",
            _last_assistant_text(new_anthropic) if "Anthropic" in selected else "",
            _last_assistant_text(new_gemini) if "Gemini" in selected else "",
        )
    except sqlite3.Error:
        # Persistence failed, but the replies are still shown to the user.
        logger.exception("Failed to save chat history for session %s", session_id)
    return outputs(error_message, new_openai, new_anthropic, new_gemini)
def session_label(session):
    """Format a ``(session_id, created_at)`` row as the dropdown label ``"<id> (<created_at>)"``."""
    (session_id, created_at) = session
    label = f"{session_id} ({created_at})"
    return label
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Model Chat")

    # --- Layout ---
    # Login/Register group
    login_group = gr.Group(visible=True)
    with login_group:
        username = gr.Textbox(label="Username")
        password = gr.Textbox(label="Password", type="password")
        login_btn = gr.Button("Login")
        register_btn = gr.Button("Register")
        login_status = gr.Markdown("")

    # Chat group (hidden until a successful login)
    chat_group = gr.Group(visible=False)
    with chat_group:
        with gr.Row():
            session_selector = gr.Dropdown(choices=[], label="Select Chat Session", interactive=True)
            new_chat_btn = gr.Button("New Chat")
            user_display = gr.Markdown("", elem_id="user_display")
            logout_btn = gr.Button("Logout")
        providers = gr.CheckboxGroup(
            choices=["OpenAI", "Anthropic", "Gemini"],
            label="Select Providers",
            value=["OpenAI"],
        )
        with gr.Row():
            openai_chat = gr.Chatbot(label="OpenAI", type="messages", height=400)
            anthropic_chat = gr.Chatbot(label="Anthropic", type="messages", height=400)
            gemini_chat = gr.Chatbot(label="Gemini", type="messages", height=400)
        msg = gr.Textbox(label="Message")
        send_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear History")

    # Server-side state: logged-in user, per-provider histories, active session.
    hidden_username = gr.State("")
    openai_history_state = gr.State([])
    anthropic_history_state = gr.State([])
    gemini_history_state = gr.State([])
    session_id_state = gr.State("")

    # Output lists shared by several events; handler return tuples follow this order.
    auth_outputs = [
        login_status, login_group, chat_group,
        openai_chat, anthropic_chat, gemini_chat,
        openai_history_state, anthropic_history_state, gemini_history_state,
        hidden_username,
        user_display,
        session_id_state,
        session_selector,
        password,
    ]
    chat_inputs = [
        msg, openai_history_state, anthropic_history_state, gemini_history_state,
        hidden_username, providers, session_id_state,
    ]
    chat_outputs = [
        msg, openai_chat, anthropic_chat, gemini_chat,
        openai_history_state, anthropic_history_state, gemini_history_state,
    ]

    # --- Function Definitions ---
    def _session_dropdown(user, current_id):
        """Return a dropdown update listing *user*'s sessions with *current_id* selected."""
        sessions = list_sessions(user)
        choices = [session_label(s) for s in sessions]
        # Select by id instead of assuming the new session is the first row.
        selected = next(
            (session_label(s) for s in sessions if s[0] == current_id),
            choices[0] if choices else "",
        )
        return gr.update(choices=choices, value=selected)

    def _signed_out(status, login_group_update, chat_group_update):
        """Return an ``auth_outputs`` tuple with every chat/session value reset."""
        return (
            status,
            login_group_update,
            chat_group_update,
            [], [], [],  # chat displays
            [], [], [],  # history states
            "",  # hidden_username
            "",  # user_display
            "",  # session_id_state
            gr.update(choices=[], value=""),  # session_selector
            gr.update(),  # password: leave as typed
        )

    def do_login(u, p):
        """Log *u* in, open a fresh session and reveal the chat UI."""
        if not login_user(u, p):
            return _signed_out("Login failed.", gr.update(), gr.update())
        session_id = create_new_session(u)
        return (
            f"Welcome, {u}!",
            gr.update(visible=False),  # Hide login group
            gr.update(visible=True),  # Show chat group
            [], [], [],  # Empty chats for new session
            [], [], [],
            u,
            f"Logged in as: **{u}**",
            session_id,
            _session_dropdown(u, session_id),
            "",  # Clear the password box once it has been used
        )

    def do_register(u, p):
        """Create account *u* and report the outcome in the status line."""
        if not u or not p:
            status = "Username and password are required."
        elif register_user(u, p):
            status = "Registration successful! Please login."
        else:
            status = "Username already exists."
        return _signed_out(status, gr.update(), gr.update())

    def do_logout():
        """Reset all per-user values and return to the login form."""
        return _signed_out("", gr.update(visible=True), gr.update(visible=False))

    def do_new_chat(user):
        """Start a new session for *user*, clearing both the displays and the history states."""
        if not user:
            return "", [], [], [], [], [], [], gr.update(choices=[], value="")
        session_id = create_new_session(user)
        return session_id, [], [], [], [], [], [], _session_dropdown(user, session_id)

    def do_select_session(label, user):
        """Load the history of the session named by dropdown *label* for *user*."""
        empty = ([], [], [], [], [], [], "")
        if not label or not user:
            return empty
        try:
            session_id = int(label.split(" ")[0])
        except ValueError:
            return empty
        # Refuse to load a session that does not belong to the logged-in user.
        if session_id not in {s[0] for s in list_sessions(user)}:
            return empty
        openai_hist, anthropic_hist, gemini_hist = get_chat_history(session_id)
        return (
            openai_hist, anthropic_hist, gemini_hist,  # chat displays
            openai_hist, anthropic_hist, gemini_hist,  # state variables
            session_id,  # session_id_state
        )

    # --- Gradio Wiring ---
    login_btn.click(do_login, [username, password], auth_outputs)
    password.submit(do_login, [username, password], auth_outputs)
    register_btn.click(do_register, [username, password], auth_outputs)
    logout_btn.click(do_logout, [], auth_outputs)

    send_btn.click(chatbot_fn, chat_inputs, chat_outputs)
    msg.submit(chatbot_fn, chat_inputs, chat_outputs)
    # Clears only the on-screen chats and in-memory context; rows already saved
    # for this session stay in the database and reappear if it is reselected.
    clear_btn.click(lambda: ("", [], [], [], [], [], []), [], chat_outputs)

    new_chat_btn.click(
        do_new_chat,
        [hidden_username],
        [
            session_id_state,
            openai_chat, anthropic_chat, gemini_chat,
            openai_history_state, anthropic_history_state, gemini_history_state,
            session_selector,
        ],
    )
    session_selector.change(
        do_select_session,
        [session_selector, hidden_username],
        [
            openai_chat, anthropic_chat, gemini_chat,  # chat displays
            openai_history_state, anthropic_history_state, gemini_history_state,  # state variables
            session_id_state,  # session id state
        ],
    )
def main():
    """Ensure the database schema exists, then start the Gradio app."""
    init_db()
    demo.launch()


if __name__ == "__main__":
    main()