Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| import streamlit as st | |
| from streamlit.runtime.scriptrunner import RerunException, StopException, RerunData | |
| from openai import OpenAI | |
| from pymongo import MongoClient | |
| from datetime import datetime, timedelta | |
| import time | |
| from streamlit_autorefresh import st_autorefresh | |
| from streamlit.runtime.caching import cache_data | |
| import logging | |
# Set up logging
logging.basicConfig(level=logging.INFO)

# Load environment variables (python-dotenv reads a local .env file, if present)
load_dotenv()

# Configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MONGODB_URI = os.getenv("MONGODB_URI")
if not MONGODB_URI:
    # MongoClient(None) silently falls back to its default host (localhost per
    # the PyMongo docs), which hides a missing deployment secret; surface it.
    logging.warning("MONGODB_URI is not set; MongoClient will use its default host.")

# Initialize clients
openai_client = OpenAI(api_key=OPENAI_API_KEY)
mongo_client = MongoClient(MONGODB_URI)
db = mongo_client["Wall_Street"]
conversation_history = db["conversation_history"]
trainer_feedback = db["trainer_feedback"]
trainer_instructions = db["trainer_instructions"]
global_common_memory = db["global_common_memory"]  # Shared memory items added to GPT system prompts

# Identifier of the single document that holds the global memory list
GLOBAL_MEMORY_ID = "global_common_memory_id"

# Set up Streamlit page configuration (issued before any other Streamlit command)
st.set_page_config(page_title="GPT-Driven Chat System - Operator", page_icon="🛠️", layout="wide")

# Custom CSS to improve the UI
st.markdown("""
<style>
/* Your custom CSS styles */
</style>
""", unsafe_allow_html=True)
| # --- Common Memory Functions --- | |
@st.cache_data(ttl=300)  # Cache for 5 minutes
def get_global_common_memory():
    """
    Retrieve the global common memory.

    The result is cached by Streamlit for 5 minutes. Functions that modify the
    stored memory call ``get_global_common_memory.clear()`` afterwards; that
    method only exists because of the ``st.cache_data`` decorator.

    Returns:
        list: The stored memory items, or ``[]`` when the memory document or
        its ``memory`` field is missing or empty.
    """
    memory_doc = global_common_memory.find_one({"memory_id": GLOBAL_MEMORY_ID})
    if not memory_doc:
        return []
    # `or []` also covers a document whose memory field is explicitly null.
    return memory_doc.get('memory') or []
def append_to_global_common_memory(content):
    """
    Append content to the global common memory.

    Uses a single ``$addToSet`` upsert. On first use, this creates the memory
    document with ``memory=[content]``. Afterwards it adds ``content`` only if
    it is not already present. The cached reader is invalidated, the outcome
    is reported in the UI, and the app is rerun so the new memory is shown.

    Args:
        content: Item to store. Items are later joined into the GPT system
            prompt, so callers presumably pass a string.
    """
    try:
        result = global_common_memory.update_one(
            {"memory_id": GLOBAL_MEMORY_ID},
            {"$addToSet": {"memory": content}},
            upsert=True,
        )
        # Invalidate the cache after updating (requires get_global_common_memory
        # to be a Streamlit-cached function, which exposes .clear()).
        get_global_common_memory.clear()
        # An upsert that inserts reports modified_count == 0 but sets upserted_id.
        if result.modified_count > 0 or result.upserted_id is not None:
            st.success("Memory appended successfully!")
        else:
            st.info("This memory item already exists or no changes were made.")
        st.rerun()
    except (RerunException, StopException):
        # Streamlit control-flow signals must reach the script runner.
        raise
    except Exception as e:
        st.error(f"Failed to append to global common memory: {str(e)}")
def clear_global_common_memory():
    """
    Reset the global common memory to an empty list.

    The memory document is created if it does not exist yet. Afterwards the
    reader cache is invalidated and the outcome is shown in the Streamlit UI.
    Errors are displayed, not raised.
    """
    try:
        global_common_memory.update_one(
            filter={"memory_id": GLOBAL_MEMORY_ID},
            update={"$set": {"memory": []}},
            upsert=True,
        )
        # Invalidate the cached copy. NOTE: .clear() exists only when
        # get_global_common_memory is wrapped by a Streamlit cache decorator.
        get_global_common_memory.clear()
        st.success("Global common memory cleared successfully!")
    except Exception as exc:
        st.error(f"Failed to clear global common memory: {str(exc)}")
| # --- Takeover Functions --- | |
def activate_takeover(session_id):
    """
    Turn on operator takeover for ``session_id``.

    Upserts the session's ``takeover_status`` document with ``active=True`` and
    the current UTC time as ``activated_at``. The outcome is shown in the UI;
    errors are displayed, not raised.
    """
    status_filter = {"session_id": session_id}
    status_update = {"$set": {"active": True, "activated_at": datetime.utcnow()}}
    try:
        db["takeover_status"].update_one(status_filter, status_update, upsert=True)
        st.success(f"Takeover activated for session {session_id[:8]}...")
    except Exception as err:
        st.error(f"Failed to activate takeover: {str(err)}")
def deactivate_takeover(session_id):
    """
    Turn off operator takeover for ``session_id``.

    Sets ``active=False`` on the session's existing ``takeover_status``
    document. No upsert is performed, so a missing document is left absent.
    The success message is shown either way; errors are displayed, not raised.
    """
    status_filter = {"session_id": session_id}
    status_update = {"$set": {"active": False}}
    try:
        db["takeover_status"].update_one(status_filter, status_update)
        st.success(f"Takeover deactivated for session {session_id[:8]}...")
    except Exception as err:
        st.error(f"Failed to deactivate takeover: {str(err)}")
def send_admin_message(session_id, message):
    """
    Send an admin message directly to the user during a takeover.

    Appends a pre-approved message with role ``"admin"`` to the session's
    conversation and bumps ``last_updated``. Both fields share one timestamp.
    The outcome is reported in the Streamlit UI.

    Args:
        session_id: Session whose conversation receives the message.
        message: Text to deliver to the user.

    Returns:
        bool: True if the conversation document was updated. False if it was
        not found or the update raised (the error is displayed, not raised).
    """
    now = datetime.utcnow()
    admin_message = {
        "role": "admin",
        "content": message,
        "timestamp": now,
        "status": "approved"
    }
    try:
        # Plain update (no upsert): the message is only appended to an existing
        # conversation. If it was deleted, nothing is written.
        result = conversation_history.update_one(
            {"session_id": session_id},
            {
                "$push": {"messages": admin_message},
                "$set": {"last_updated": now}
            }
        )
    except Exception as e:
        st.error(f"Failed to send admin message: {str(e)}")
        return False
    if result.modified_count > 0:
        st.success("Admin message sent successfully!")
        return True
    st.error("Failed to send admin message.")
    return False
| # --- Admin Dashboard Functions --- | |
def handle_admin_intervention():
    """
    Render the admin-intervention tab.

    The tab shows, in order:
    - the global common memory, once;
    - every conversation that has an assistant message with status
      ``"pending"``, each with Approve / Modify / Regenerate buttons;
    - the regenerate or modify form, when one of those actions was started.
      In-progress actions are tracked in ``st.session_state`` under
      ``'regenerating'`` / ``'modifying'`` as ``(session_id, message_index)``.
    """
    st.subheader("Admin Intervention")
    # Rendered once per page; it previously sat inside the per-conversation
    # loop, repeating for every pending chat and vanishing when none were pending.
    _render_global_common_memory()

    st.subheader("Review Pending Responses")
    # $elemMatch requires role and status to match on the SAME message; the
    # dotted-path form matched any assistant message plus any pending message.
    pending_responses = conversation_history.find(
        {"messages": {"$elemMatch": {"role": "assistant", "status": "pending"}}}
    )
    for conversation in pending_responses:
        _render_pending_conversation(conversation)

    _render_regenerate_form()
    _render_modify_form()


def _render_global_common_memory():
    """Show the global common memory as a numbered list, or an info note if empty."""
    st.subheader("Global Common Memory")
    common_memory = get_global_common_memory()
    if not common_memory:
        st.info("Global common memory is currently empty.")
        return
    for idx, item in enumerate(common_memory, 1):
        st.text(f"{idx}. {item}")


def _render_pending_conversation(conversation):
    """Show each pending assistant message of one conversation with review buttons."""
    session_id = conversation['session_id']
    messages = conversation.get('messages', [])
    st.write(f"Session ID: {session_id[:8]}...")
    for i, message in enumerate(messages):
        if message.get('role') != 'assistant' or message.get('status') != 'pending':
            continue
        # The message just before the assistant reply is shown as the user's prompt.
        user_message = messages[i - 1].get('content') if i > 0 else "N/A"
        st.write(f"**User:** {user_message}")
        st.write(f"**GPT:** {message.get('content')}")
        col1, col2, col3 = st.columns(3)
        with col1:
            if st.button("Approve", key=f"approve_{session_id}_{i}"):
                if approve_response(session_id, i):
                    st.success("Response approved")
                    time.sleep(0.5)  # Short delay so the success message is visible before the rerun
                    st.rerun()
                else:
                    st.warning("Response could not be approved. It may have been changed or removed.")
        with col2:
            if st.button("Modify", key=f"modify_{session_id}_{i}"):
                st.session_state['modifying'] = (session_id, i)
                st.rerun()
        with col3:
            if st.button("Regenerate", key=f"regenerate_{session_id}_{i}"):
                st.session_state['regenerating'] = (session_id, i)
                st.rerun()
        st.divider()


def _render_regenerate_form():
    """Collect operator instructions and regenerate the selected response."""
    if 'regenerating' not in st.session_state:
        return
    try:
        session_id, message_index = st.session_state['regenerating']
    except (TypeError, ValueError):
        st.error("Invalid regenerating state. Please try again.")
        # Drop the bad state so the error does not repeat on every rerun.
        del st.session_state['regenerating']
        return
    with st.form(key="regenerate_form"):
        operator_input = st.text_input("Enter additional instructions for regeneration:")
        submit_button = st.form_submit_button("Submit")
    if submit_button:
        del st.session_state['regenerating']  # Remove the key after submission
        regenerate_response(session_id, message_index, operator_input)
        st.success("Response regenerated with operator input.")
        st.rerun()


def _render_modify_form():
    """Let the operator edit the selected response and save it as approved."""
    if 'modifying' not in st.session_state:
        return
    session_id, message_index = st.session_state['modifying']
    conversation = conversation_history.find_one({"session_id": session_id})
    messages = conversation.get('messages', []) if conversation else []
    if not 0 <= message_index < len(messages):
        # The conversation may have been deleted (e.g. by the periodic cleanup)
        # after Modify was clicked; reset instead of crashing on a None lookup.
        st.error("The response being modified no longer exists.")
        del st.session_state['modifying']
        return
    message = messages[message_index]
    modified_content = st.text_area("Modify the response:", value=message.get('content') or "")
    if st.button("Save Modification"):
        save_modified_response(session_id, message_index, modified_content)
        st.success("Response modified and approved")
        del st.session_state['modifying']
        st.rerun()
def approve_response(session_id, message_index):
    """
    Mark message ``message_index`` of the session's conversation as approved.

    Returns:
        bool: True when MongoDB reports a modification. False otherwise,
        including when the update raises (the error is then shown in the UI).
    """
    status_path = f"messages.{message_index}.status"
    try:
        outcome = conversation_history.update_one(
            {"session_id": session_id},
            {"$set": {status_path: "approved"}},
        )
        return outcome.modified_count > 0
    except Exception as exc:
        st.error(f"Failed to approve response: {str(exc)}")
        return False
def save_modified_response(session_id, message_index, modified_content):
    """
    Replace the content of message ``message_index`` and mark it approved.

    Both fields are written in one ``$set``. Errors are shown in the UI rather
    than raised; nothing is returned.
    """
    prefix = f"messages.{message_index}"
    changes = {
        f"{prefix}.content": modified_content,
        f"{prefix}.status": "approved",
    }
    try:
        conversation_history.update_one({"session_id": session_id}, {"$set": changes})
    except Exception as exc:
        st.error(f"Failed to save modified response: {str(exc)}")
def regenerate_response(session_id, message_index, operator_input):
    """
    Regenerate the assistant message at ``message_index`` with GPT.

    The preceding message (if any) is re-sent as the prompt, with
    ``operator_input`` passed as operator instructions. The stored message's
    content is replaced. Its status becomes ``"pending"`` when the new reply
    is flagged uncertain, otherwise ``"approved"``.

    Returns:
        bool: True if the stored message was updated. False on any failure
        (errors are shown in the UI, not raised).
    """
    try:
        conversation = conversation_history.find_one({"session_id": session_id})
        if not conversation:
            st.error("Failed to regenerate response: conversation not found.")
            return False
        messages = conversation.get('messages', [])
        user_message = messages[message_index - 1]['content'] if message_index > 0 else ""
        new_response, is_uncertain = get_gpt_response(user_message, system_message=operator_input)
        if new_response is None:
            # get_gpt_response returns (None, True) on error and has already
            # reported it; keep the existing content instead of writing null.
            return False
        status = "pending" if is_uncertain else "approved"
        conversation_history.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    f"messages.{message_index}.content": new_response,
                    f"messages.{message_index}.status": status
                }
            }
        )
        return True
    except Exception as e:
        st.error(f"Failed to regenerate response: {str(e)}")
        return False
| # --- Admin Page --- | |
def admin_page():
    """
    Render the operator dashboard.

    The page does the following:
    - auto-refreshes every 10 seconds;
    - deletes idle chats via ``cleanup_old_chats`` and reports the result;
    - shows two tabs: a grid of recent chats and the admin-intervention panel.

    Streamlit's rerun/stop control-flow exceptions are re-raised; any other
    error is displayed in the UI.
    """
    st.title("🛠️ Operator Dashboard")
    # Add auto-refresh every 10 seconds (10000 milliseconds)
    st_autorefresh(interval=10000, limit=None, key="operator_autorefresh")
    if st.button("๐ Reload Dashboard"):
        st.rerun()
    try:
        deleted_count = cleanup_old_chats()
        if deleted_count is None:
            st.warning("Unable to perform cleanup. Please check the database connection.")
        elif deleted_count > 0:
            st.success(f"Cleaned up {deleted_count} inactive chat(s).")
        else:
            st.info("No inactive chats to clean up.")
        tab1, tab2 = st.tabs([
            "๐ Current Chats",
            "๐ง Admin Intervention",
        ])
        with tab1:
            _render_recent_chats_grid()
        with tab2:
            handle_admin_intervention()
        st.caption(f"Last refreshed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    except (RerunException, StopException):
        # Streamlit control-flow signals must reach the script runner.
        raise
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


def _render_recent_chats_grid(cols_per_row=3):
    """Render recent chats as a grid of expanders with View/Delete buttons."""
    st.header("Current Chats")
    recent_chats = fetch_recent_chats()
    if not recent_chats:
        st.info("No recent chats found.")
        return
    for start in range(0, len(recent_chats), cols_per_row):
        cols = st.columns(cols_per_row)
        for col, chat in zip(cols, recent_chats[start:start + cols_per_row]):
            session_id = chat['session_id']
            with col:
                with st.expander(f"Session: {session_id[:8]}...", expanded=False):
                    display_chat_preview(chat)
                    view_col, delete_col = st.columns(2)
                    with view_col:
                        if st.button("View Full Chat", key=f"view_{session_id}"):
                            st.session_state['selected_chat'] = session_id
                            st.rerun()
                    with delete_col:
                        if st.button("Delete Chat", key=f"delete_{session_id}"):
                            delete_chat(session_id)
                            st.rerun()
| # --- Fetch Recent Chats --- | |
def fetch_recent_chats():
    """
    Return the 10 most recently updated conversations as a list.

    Only ``session_id``, ``last_updated`` and the first three messages are
    projected. MongoDB also includes ``_id`` by default.
    """
    projection = {"session_id": 1, "last_updated": 1, "messages": {"$slice": 3}}
    cursor = conversation_history.find({}, projection)
    cursor = cursor.sort("last_updated", -1).limit(10)
    return list(cursor)
| # --- Display Chat Preview --- | |
def _preview_text(content, limit=100):
    """
    Return ``content`` cut to ``limit`` characters.

    "..." is appended only when text was actually cut. ``None`` becomes an
    empty string, and other non-string values are converted with ``str()``.
    """
    text = "" if content is None else str(content)
    return text if len(text) <= limit else text[:limit] + "..."


def display_chat_preview(chat):
    """
    Render a compact preview of one chat.

    Shows the session header, the last-updated time (the current UTC time if
    none is stored) and up to the first three messages, each truncated to 100
    characters.
    """
    st.subheader(f"Session: {chat['session_id'][:8]}...")
    last_updated = chat.get('last_updated') or datetime.utcnow()
    st.caption(f"Last updated: {last_updated.strftime('%Y-%m-%d %H:%M:%S')}")
    for message in chat.get('messages', [])[:3]:
        with st.chat_message(message['role']):
            st.markdown(f"**{message['role'].capitalize()}**: {_preview_text(message.get('content'))}")
    st.divider()
| # --- Delete Chat --- | |
def delete_chat(session_id):
    """
    Delete the conversation stored for ``session_id``.

    Shows a success message if a document was removed, and an error if none
    matched or the delete raised. Nothing is returned.
    """
    try:
        removed = conversation_history.delete_one({"session_id": session_id}).deleted_count
        if not removed:
            st.error("Failed to delete chat. Please try again.")
        else:
            st.success(f"Chat {session_id[:8]}... deleted successfully.")
    except Exception as exc:
        st.error(f"Error deleting chat: {str(exc)}")
| # --- Cleanup Old Chats --- | |
def cleanup_old_chats(max_idle_minutes=5):
    """
    Delete conversations not updated within the last ``max_idle_minutes``.

    Args:
        max_idle_minutes: Idle window in minutes. Conversations whose
            ``last_updated`` is older than this are removed. Defaults to 5.

    Returns:
        int | None: Number of deleted conversations. None if the delete failed;
        the error is logged (with traceback) rather than raised.
    """
    try:
        cutoff_time = datetime.utcnow() - timedelta(minutes=max_idle_minutes)
        result = conversation_history.delete_many({"last_updated": {"$lt": cutoff_time}})
        return result.deleted_count
    except Exception:
        # Use the logging configured at module level instead of print().
        logging.getLogger(__name__).exception("Error during chat cleanup")
        return None
| # --- GPT Response Function --- | |
def _build_gpt_system_message(common_memory, operator_instructions=None):
    """
    Assemble the GPT system prompt.

    The prompt starts with the base instructions. Operator instructions are
    appended when given, then the memory items, one per line.
    """
    system_msg = (
        "You are a helpful assistant. Use the following context and global common memory "
        "to inform your responses, but don't mention them explicitly unless directly relevant to the user's question."
    )
    if operator_instructions:
        system_msg += f"\n\nOperator Instructions:\n{operator_instructions}"
    if common_memory:
        # str() each item: str.join raises TypeError on non-string entries.
        memory_str = "\n".join(str(item) for item in common_memory)
        system_msg += f"\n\nGlobal Common Memory:\n{memory_str}"
    return system_msg


def get_gpt_response(prompt, context="", system_message=None, model="gpt-4o-mini"):
    """
    Generate a GPT reply to ``prompt``.

    The request combines the retrieved ``context``, the global common memory
    and optional operator instructions (``system_message``).

    Args:
        prompt: The user's query.
        context: Extra context placed before the query in the user message.
        system_message: Optional operator instructions added to the system prompt.
        model: OpenAI chat model name.

    Returns:
        tuple: ``(response, is_uncertain)``. On error, the error is shown in
        the UI and ``(None, True)`` is returned.
    """
    try:
        common_memory = get_global_common_memory()
        messages = [
            {"role": "system", "content": _build_gpt_system_message(common_memory, system_message)},
            {"role": "user", "content": f"Context: {context}\n\nUser query: {prompt}"}
        ]
        completion = openai_client.chat.completions.create(
            model=model,
            messages=messages
        )
        # Debug-level log instead of printing the whole completion on every call.
        logging.getLogger(__name__).debug("OpenAI completion: %s", completion)
        # message.content may be None (e.g. a refusal), so guard before strip().
        response = (completion.choices[0].message.content or "").strip()
        # TODO: Implement real uncertainty detection. For now only an empty reply
        # is routed to operator review.
        is_uncertain = not response
        return response, is_uncertain
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return None, True  # Indicates uncertainty due to error
| # --- View Full Chat Function --- | |
def view_full_chat(session_id):
    """
    Display the full chat and provide takeover functionality.

    The view shows every message in the conversation and a button that toggles
    takeover mode (stored in the ``takeover_status`` collection). While
    takeover is active, a form lets the operator send messages via
    ``send_admin_message``. The helper functions report their own
    success/failure, so this view adds no extra status messages.
    """
    # Add a "Go to Dashboard" button at the top
    if st.button("๐ Go to Dashboard"):
        st.session_state.pop('selected_chat', None)
        st.rerun()
    conversation = conversation_history.find_one({"session_id": session_id})
    if not conversation:
        st.error("Chat not found.")
        return
    st.header(f"Full Chat - Session ID: {conversation['session_id'][:8]}...")
    last_updated = conversation.get('last_updated') or datetime.utcnow()
    st.caption(f"Last updated: {last_updated.strftime('%Y-%m-%d %H:%M:%S')}")
    for message in conversation.get('messages', []):
        with st.chat_message(message['role']):
            st.markdown(f"**{message['role'].capitalize()}**: {message['content']}")

    # Takeover functionality
    takeover_doc = db.takeover_status.find_one({"session_id": session_id})
    takeover_active = bool(takeover_doc and takeover_doc.get("active", False))
    if not takeover_active:
        if st.button("Activate Takeover"):
            activate_takeover(session_id)
            st.rerun()
        return

    if st.button("Deactivate Takeover"):
        deactivate_takeover(session_id)
        st.rerun()
    # clear_on_submit empties the box after sending so the same text is not resent.
    with st.form(key=f"admin_message_form_{session_id}", clear_on_submit=True):
        admin_message = st.text_input("Enter message to send to the user:")
        submit_button = st.form_submit_button("Send Message")
    if submit_button:
        if admin_message.strip():
            send_admin_message(session_id, admin_message)
            st.rerun()
        else:
            st.warning("Please enter a message before sending.")
| # --- Main Function --- | |
def main():
    """
    Entry point: route to the full-chat view or the dashboard.

    If ``'selected_chat'`` is present in session state, that chat is shown in
    full; otherwise the operator dashboard is rendered. Streamlit's
    rerun/stop control-flow exceptions propagate. Any other exception is
    displayed as an error.
    """
    try:
        if 'selected_chat' not in st.session_state:
            admin_page()
        else:
            view_full_chat(st.session_state['selected_chat'])
    except (RerunException, StopException):
        raise
    except Exception as exc:
        st.error(f"An unexpected error occurred: {str(exc)}")


if __name__ == "__main__":
    main()