Spaces:
Sleeping
Sleeping
| import os | |
| from openai import OpenAI | |
| import gradio as gr | |
| from typing import Tuple | |
| from supabase import create_client | |
| from datetime import datetime | |
class ChatManager:
    """Manage OpenAI Assistant threads and persist chat transcripts in Supabase.

    Each chat is identified by the ID of the OpenAI thread created for it. The
    transcript is stored in the Supabase ``chats`` table, keyed by
    ``(user_id, chat_id)``. The "current" chat for a UI session is tracked under
    the ``'current_chat_id'`` key of a caller-supplied mutable mapping
    (``session_state``).
    """

    def __init__(self, client: "OpenAI", assistant_id: str, supabase_url: str = None, supabase_key: str = None):
        """Store the OpenAI client/assistant ID and initialize the Supabase client.

        Args:
            client: OpenAI client; used to create assistant threads.
            assistant_id: Assistant ID. It is stored on the instance, but no
                method in this class reads it.
            supabase_url: Supabase project URL. Falls back to the
                ``SUPABASE_URL`` environment variable when not given.
            supabase_key: Supabase API key. Falls back to the
                ``SUPABASE_KEY`` environment variable when not given.
        """
        self.client = client
        self.assistant_id = assistant_id
        self.supabase = create_client(
            supabase_url or os.environ.get("SUPABASE_URL"),
            supabase_key or os.environ.get("SUPABASE_KEY"),
        )

    def get_current_chat_id(self, session_state):
        """Return the current chat (thread) ID from ``session_state``, or None if unset."""
        return session_state.get('current_chat_id')

    def set_current_chat_id(self, session_state, chat_id):
        """Record ``chat_id`` as the current chat; ``session_state`` is mutated in place."""
        session_state['current_chat_id'] = chat_id

    def create_new_chat(self, session_state):
        """Create a new OpenAI thread, make it the current chat, and return its ID."""
        thread = self.client.beta.threads.create()
        self.set_current_chat_id(session_state, thread.id)
        return thread.id

    def get_or_create_thread(self, session_state):
        """Return the current chat ID, creating a new thread first if none is set."""
        current_chat_id = self.get_current_chat_id(session_state)
        if not current_chat_id:
            current_chat_id = self.create_new_chat(session_state)
        return current_chat_id

    def reset_chat_on_load(self, username: str, session_state, init_chat=None) -> Tuple[dict, list]:
        """Start a fresh chat on page load and rebuild the chat selector.

        Args:
            username: User whose saved chats populate the selector.
            session_state: Session mapping; its current chat is replaced by a
                newly created thread.
            init_chat: Initial chat contents to return; defaults to a new
                empty list.

        Returns:
            A ``gr.update`` for the selector and ``init_chat``. The update's
            choices are ``(preview, chat_id)`` pairs, and its value is the new
            chat ID.
        """
        # None sentinel: a literal [] default would be one list shared by every
        # call, and it is handed back to the caller, who may mutate it.
        if init_chat is None:
            init_chat = []
        self.create_new_chat(session_state)
        chats = self.list_user_chats(username)
        choices = [(c["preview"], c["chat_id"]) for c in chats]
        return gr.update(choices=choices, value=self.get_current_chat_id(session_state)), init_chat

    def load_chat_history(self, user_id="default_user"):
        """Return the messages of the user's most recently updated chat, or [] if none."""
        chat_id = self.get_latest_chat_id(user_id)
        if chat_id:
            return self.load_chat(user_id, chat_id)
        return []

    @staticmethod
    def _preview(messages):
        """Return a selector label: the first message's content, or "Empty chat"."""
        if not messages:
            return "Empty chat"
        return messages[0].get("content", "")

    def list_user_chats(self, user_id="default_user"):
        """List the user's chats, most recently updated first.

        Returns:
            A list of dicts with ``chat_id``, ``last_updated`` and ``preview``.
            ``preview`` is the content of the chat's first message, or
            ``"Empty chat"`` when the chat has no messages.
        """
        response = (
            self.supabase.table("chats")
            .select("chat_id,last_updated,messages")
            .eq("user_id", user_id)
            .execute()
        )
        chats = [
            {
                "chat_id": row["chat_id"],
                "last_updated": row["last_updated"],
                "preview": self._preview(row["messages"]),
            }
            for row in response.data
        ]
        # save_chat writes last_updated as an ISO-8601 string. The string sort
        # below matches time order only if all values share one format and time
        # zone — NOTE(review): confirm against the column type.
        return sorted(chats, key=lambda c: c["last_updated"], reverse=True)

    def switch_chat(self, chat_id, user_id="default_user", session_state=None):
        """Make ``chat_id`` the current chat (when a session is given) and return its messages."""
        # Compare against None: an empty session dict is falsy but is still a
        # real session that must record the switch.
        if session_state is not None:
            self.set_current_chat_id(session_state, chat_id)
        return self.load_chat(user_id, chat_id)

    def save_chat(self, user_id, chat_id, messages, current_lesson_plan, app_name):
        """Insert or update the ``(user_id, chat_id)`` row in the ``chats`` table.

        ``last_updated`` is set to the current local time (naive, via
        ``datetime.now()``) in ISO-8601 form.
        """
        chat_data = {
            "user_id": user_id,
            "chat_id": chat_id,
            "last_updated": datetime.now().isoformat(),
            "messages": messages,
            "current_lesson_plan": current_lesson_plan,
            "app_name": app_name,
        }
        # Existence check: fetch only the key column and at most one row, not
        # the whole transcript.
        # NOTE(review): check-then-write is not atomic; concurrent saves of a new
        # chat could both insert. An upsert over a unique (user_id, chat_id)
        # constraint would close that gap, if such a constraint exists.
        existing_chat = (
            self.supabase.table("chats")
            .select("chat_id")
            .eq("user_id", user_id)
            .eq("chat_id", chat_id)
            .limit(1)
            .execute()
        )
        if existing_chat.data:
            (
                self.supabase.table("chats")
                .update(chat_data)
                .eq("user_id", user_id)
                .eq("chat_id", chat_id)
                .execute()
            )
        else:
            self.supabase.table("chats").insert(chat_data).execute()

    def load_chat(self, user_id, chat_id):
        """Return the stored messages for one chat, or [] if it is missing or empty."""
        response = (
            self.supabase.table("chats")
            .select("messages")
            .eq("user_id", user_id)
            .eq("chat_id", chat_id)
            .limit(1)
            .execute()
        )
        if not response.data:
            return []
        # A NULL messages column would otherwise come back as None.
        return response.data[0]["messages"] or []

    def get_latest_chat_id(self, user_id):
        """Return the ID of the user's most recently updated chat, or None if they have none."""
        chats = self.list_user_chats(user_id)
        return chats[0]["chat_id"] if chats else None