Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| import streamlit as st | |
| from streamlit.runtime.scriptrunner import RerunException, StopException | |
| from openai import OpenAI | |
| from pymongo import MongoClient | |
| from pinecone import Pinecone | |
| import uuid | |
| from datetime import datetime | |
| import time | |
| from streamlit.runtime.caching import cache_data | |
| from streamlit_autorefresh import st_autorefresh | |
# Load environment variables from a local .env file, if present.
load_dotenv()

# --- Configuration ---
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MONGODB_URI = os.getenv("MONGODB_URI")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.getenv("PINECONE_ENVIRONMENT")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
# memory_id of the single document that holds the shared global memory list.
GLOBAL_MEMORY_ID = "global_common_memory_id"

# Set up Streamlit page configuration. Kept ahead of every other Streamlit call
# so it is always the first Streamlit command of a run.
st.set_page_config(page_title="GPT-Driven Chat System - Tester", page_icon="🔬", layout="wide")


# Streamlit re-executes this whole script on every interaction and on every
# st_autorefresh tick. st.cache_resource turns each client into a per-process
# singleton instead of building new clients (and, for MongoClient, a new
# connection pool) on every run. show_spinner=False keeps these calls from
# rendering anything on the page.
@st.cache_resource(show_spinner=False)
def _get_openai_client(api_key):
    """Return a shared OpenAI client for ``api_key``."""
    return OpenAI(api_key=api_key)


@st.cache_resource(show_spinner=False)
def _get_mongo_client(uri):
    """Return a shared MongoClient for ``uri``."""
    return MongoClient(uri)


@st.cache_resource(show_spinner=False)
def _get_pinecone_client(api_key):
    """Return a shared Pinecone client for ``api_key``."""
    return Pinecone(api_key=api_key)


@st.cache_resource(show_spinner=False)
def _get_pinecone_index(api_key, index_name):
    """Return a shared handle to the Pinecone index ``index_name``."""
    return _get_pinecone_client(api_key).Index(index_name)


openai_client = _get_openai_client(OPENAI_API_KEY)
mongo_client = _get_mongo_client(MONGODB_URI)
db = mongo_client["Wall_Street"]
conversation_history = db["conversation_history"]
global_common_memory = db["global_common_memory"]  # shared, cross-session memory collection

# Create the global memory document if it is missing. A single upsert with
# $setOnInsert is atomic, unlike find_one() followed by insert_one(), which can
# insert duplicate documents when two sessions start at the same moment.
global_common_memory.update_one(
    {"memory_id": GLOBAL_MEMORY_ID},
    {"$setOnInsert": {"memory": []}},
    upsert=True,
)

# Initialize Pinecone
pc = _get_pinecone_client(PINECONE_API_KEY)
pinecone_index = _get_pinecone_index(PINECONE_API_KEY, PINECONE_INDEX_NAME)

# Custom CSS to improve the UI
st.markdown("""
<style>
/* Your custom CSS styles */
</style>
""", unsafe_allow_html=True)

# Initialize per-browser-session state (only on the first run of a session).
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'user_type' not in st.session_state:
    st.session_state['user_type'] = None
if 'session_id' not in st.session_state:
    st.session_state['session_id'] = str(uuid.uuid4())
# --- Common Memory Functions ---
# Cached for 5 minutes per server process. The writers in this file invalidate
# it with get_global_common_memory.clear(); that attribute only exists because
# of this decorator, so without it every append/clear reported a failure.
@st.cache_data(ttl=300)
def get_global_common_memory():
    """Retrieve the global common memory.

    Returns:
        The ``memory`` list of the global memory document, or an empty list
        when the document or the field does not exist.
    """
    memory_doc = global_common_memory.find_one(
        {"memory_id": GLOBAL_MEMORY_ID},
        {"memory": 1, "_id": 0},  # only the field we return
    )
    return memory_doc.get('memory', []) if memory_doc else []
def append_to_global_common_memory(content):
    """Append ``content`` to the global common memory.

    A single upsert with ``$push`` both creates the memory document (if it is
    missing) and appends to its ``memory`` array in one atomic write.

    On success the cached memory is invalidated, a success message is shown and
    ``st.session_state['memory_updated']`` is set (main() deletes that flag and
    reruns). Failures are reported with ``st.error`` rather than raised.
    """
    try:
        global_common_memory.update_one(
            {"memory_id": GLOBAL_MEMORY_ID},
            {"$push": {"memory": content}},
            upsert=True,
        )
        # Invalidate the cached copy so the next read sees the new item.
        get_global_common_memory.clear()
        st.success("Memory appended successfully!")
        # Set a flag instead of calling st.rerun() from inside this helper.
        st.session_state['memory_updated'] = True
    except Exception as e:
        st.error(f"Failed to append to global common memory: {str(e)}")
def clear_global_common_memory():
    """Reset the global common memory to an empty list.

    Creates the memory document if it does not exist, drops the cached copy
    of the memory and reports the outcome with st.success / st.error.
    """
    try:
        memory_filter = {"memory_id": GLOBAL_MEMORY_ID}
        reset_to_empty = {"$set": {"memory": []}}
        global_common_memory.update_one(memory_filter, reset_to_empty, upsert=True)
        get_global_common_memory.clear()  # stale cached copy must not survive
        st.success("Global common memory cleared successfully!")
    except Exception as e:
        st.error(f"Failed to clear global common memory: {str(e)}")
# --- Relevant Context Retrieval ---
@st.cache_data(ttl=60, show_spinner=False)
def _fetch_relevant_context(query, top_k):
    """Embed ``query`` and return the ``text`` metadata of the top matches, space-joined.

    Cached for 1 minute per (query, top_k). Exceptions propagate, so a failed
    lookup is never cached; get_relevant_context() handles them.
    """
    query_embedding = openai_client.embeddings.create(
        model="text-embedding-3-large",
        input=query
    ).data[0].embedding
    results = pinecone_index.query(vector=query_embedding, top_k=top_k, include_metadata=True)
    texts = []
    for match in results['matches']:
        try:
            text = match['metadata']['text']
        except (KeyError, TypeError, AttributeError):
            # A match stored without metadata or without a 'text' field is
            # skipped instead of discarding the context from every other match.
            continue
        if text:
            texts.append(text)
    return " ".join(texts)


def get_relevant_context(query, top_k=3):
    """
    Retrieve relevant context from Pinecone based on the user query.

    Args:
        query: The user's message.
        top_k: Number of nearest matches to request from the index.

    Returns:
        The matched texts joined by single spaces, or "" when the query is
        empty or the lookup fails (the error is printed).
    """
    if not query or not str(query).strip():
        # Embedding an empty string fails anyway; skip the paid API call.
        return ""
    try:
        return _fetch_relevant_context(query, top_k)
    except Exception as e:
        print(f"Error retrieving context: {str(e)}")
        return ""
# --- GPT Response Function ---
def get_gpt_response(prompt, context="", system_message=None):
    """Generate an assistant reply with gpt-4o-mini.

    The system prompt is a fixed instruction, optionally followed by trainer
    instructions (``system_message``) and the current global common memory.
    ``context`` (retrieved text) is sent ahead of the user query.

    Returns:
        The stripped reply text, or None when the request fails or the model
        returns no text content (the problem is shown with ``st.error``).
    """
    try:
        common_memory = get_global_common_memory()
        system_msg = (
            "You are a helpful assistant. Use the following context and global common memory "
            "to inform your responses, but don't mention them explicitly unless directly relevant to the user's question."
        )
        if system_message:
            system_msg += f"\n\nTrainer Instructions:\n{system_message}"
        if common_memory:
            # str() guards against non-string items stored in the memory array.
            memory_str = "\n".join(str(item) for item in common_memory)
            system_msg += f"\n\nGlobal Common Memory:\n{memory_str}"
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": f"Context: {context}\n\nUser query: {prompt}"}
        ]
        completion = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages
        )
        content = completion.choices[0].message.content
        if content is None:
            # message.content is optional in the Chat Completions response;
            # report it clearly instead of failing on None.strip().
            st.error("Error generating response: the model returned no text content.")
            return None
        return content.strip()
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return None
# --- Send User Message ---
def send_message(message):
    """
    Store a user message and, unless an admin takeover is active, a GPT reply.

    The user message is saved with status "approved" to the conversation of
    ``st.session_state['session_id']`` (created on first use) and appended to
    ``st.session_state['chat_history']``. When no takeover is active, a GPT
    reply is generated with retrieved context and saved with status "pending"
    for trainer approval. During a takeover no GPT reply is produced.
    """
    session_id = st.session_state['session_id']
    session_filter = {"session_id": session_id}

    now = datetime.utcnow()
    user_message = {
        "role": "user",
        "content": message,
        "timestamp": now,
        "status": "approved"  # User messages are always approved
    }
    # Upsert the user message immediately
    conversation_history.update_one(
        session_filter,
        {
            "$push": {"messages": user_message},
            "$set": {"last_updated": now},
            "$setOnInsert": {"created_at": now}
        },
        upsert=True
    )
    st.session_state['chat_history'].append(user_message)

    # activate_takeover()/deactivate_takeover() persist takeover state in
    # db.takeover_status, so that is consulted in addition to the session flag.
    takeover_doc = db.takeover_status.find_one(session_filter)
    takeover_active = st.session_state.get('admin_takeover_active') or (
        takeover_doc.get("active", False) if takeover_doc else False
    )
    if takeover_active:
        return

    # Retrieval is only needed when GPT will answer, so it runs after the check.
    context = get_relevant_context(message)
    gpt_response = get_gpt_response(message, context)
    if gpt_response is None:
        # get_gpt_response() already reported the error; do not store a reply
        # whose content is None.
        return

    assistant_message = {
        "role": "assistant",
        "content": gpt_response,
        "timestamp": datetime.utcnow(),
        "status": "pending"  # Set status to pending for admin approval
    }
    conversation_history.update_one(
        session_filter,
        {
            "$push": {"messages": assistant_message},
            "$set": {"last_updated": datetime.utcnow()}
        }
    )
    st.session_state['chat_history'].append(assistant_message)
# --- Send Admin Message ---
def send_admin_message(message, *, session_id=None):
    """
    Store an admin message in a conversation (used during a takeover).

    Args:
        message: Text of the admin message.
        session_id: Conversation to write to. Defaults to this browser
            session's ``st.session_state['session_id']``. Keyword-only so a
            positional call can never swap the text and the session id.

    Returns:
        True if a conversation document matched, else False. There is no
        upsert, so an unknown session id writes nothing.
    """
    own_session_id = st.session_state.get('session_id')
    target_session_id = own_session_id if session_id is None else session_id
    admin_message = {
        "role": "admin",
        "content": message,
        "timestamp": datetime.utcnow(),
        "status": "approved"
    }
    result = conversation_history.update_one(
        {"session_id": target_session_id},
        {
            "$push": {"messages": admin_message},
            "$set": {"last_updated": datetime.utcnow()}
        }
    )
    # chat_history mirrors only this browser's own session; a message sent to
    # another session must not show up in it.
    if target_session_id == own_session_id:
        st.session_state['chat_history'].append(admin_message)
    return result.matched_count > 0
# --- Takeover Functions ---
def activate_takeover(session_id):
    """
    Switch takeover mode on for ``session_id``.

    Upserts ``active=True`` plus an ``activated_at`` UTC timestamp into
    ``db.takeover_status`` and reports the outcome with st.success / st.error.
    """
    try:
        switch_on = {"$set": {"active": True, "activated_at": datetime.utcnow()}}
        db.takeover_status.update_one({"session_id": session_id}, switch_on, upsert=True)
        st.success(f"Takeover activated for session {session_id[:8]}...")
    except Exception as e:
        st.error(f"Failed to activate takeover: {str(e)}")
def deactivate_takeover(session_id):
    """
    Switch takeover mode off for ``session_id``.

    Sets ``active=False`` on the existing ``db.takeover_status`` document (no
    upsert, so nothing is created for an unknown session) and reports the
    outcome with st.success / st.error.
    """
    selector = {"session_id": session_id}
    switch_off = {"$set": {"active": False}}
    try:
        db.takeover_status.update_one(selector, switch_off)
        st.success(f"Takeover deactivated for session {session_id[:8]}...")
    except Exception as e:
        st.error(f"Failed to deactivate takeover: {str(e)}")
def handle_admin_takeover(session_id):
    """Render takeover controls for ``session_id`` and, while active, an admin message box.

    Takeover state is read from ``db.takeover_status``; the toggle button calls
    activate_takeover()/deactivate_takeover() and reruns the script.
    """
    st.subheader("Admin Takeover")
    takeover_doc = db.takeover_status.find_one({"session_id": session_id})
    is_active = takeover_doc.get("active", False) if takeover_doc else False

    if not is_active:
        st.warning("Takeover is not active for this session.")
        if st.button("Activate Takeover"):
            activate_takeover(session_id)
            st.success("Takeover activated.")
            st.rerun()
        return

    st.info("Takeover is currently active for this session.")
    if st.button("Deactivate Takeover"):
        deactivate_takeover(session_id)
        st.success("Takeover deactivated.")
        st.rerun()

    # Assigning st.session_state["admin_message"] after the text area exists
    # raises StreamlitAPIException; clear_on_submit empties the box instead.
    with st.form(key="admin_message_form", clear_on_submit=True):
        admin_message = st.text_area("Send Message to User", key="admin_message")
        submitted = st.form_submit_button("Send Admin Message")
    if submitted:
        text = (admin_message or "").strip()
        if text:
            # NOTE(review): no session id is passed, so the message goes to this
            # browser's st.session_state['session_id'], which may differ from
            # `session_id` above; confirm the intended target.
            send_admin_message(text)
            st.success("Admin message sent successfully!")
        else:
            st.warning("Please enter a message to send.")
# --- View Full Chat (User Perspective) ---
def _full_chat_render_message(message):
    """Render one stored message as a chat bubble; roles other than user/assistant/admin are skipped."""
    role = str(message.get('role', '')).capitalize()
    # Admin messages use the "human" bubble; user/assistant use their own.
    bubble = {"User": "user", "Assistant": "assistant", "Admin": "human"}.get(role)
    if bubble is None:
        return
    timestamp = message.get('timestamp', 'N/A')
    with st.chat_message(bubble):
        st.markdown(f"**{role}** - {timestamp}")
        # Content can be missing or None (e.g. a failed generation).
        st.markdown(str(message.get('content') or ""))


def _full_chat_back_button():
    """Render the "Back to Chat History" button that leaves the full-chat view.

    main() opens this view from either 'full_chat_view' or 'selected_chat',
    so both keys are cleared; clearing only one would reopen the view.
    """
    st.markdown("---")
    _, center, _ = st.columns([1, 1, 1])
    with center:
        if st.button("Back to Chat History", use_container_width=True):
            st.session_state.pop('full_chat_view', None)
            st.session_state.pop('selected_chat', None)
            st.rerun()


def view_full_chat(session_id):
    """Render one conversation in full, plus a box for adding global memory.

    The conversation is looked up in ``conversation_history`` (the collection
    this app writes and the dashboard lists) and then in ``db.chat_history``
    (the collection the Chat History tab lists). The back button is always
    rendered, so the view can be left even when the chat is not found.
    """
    st.title(f"Full Chat View - Session: {session_id[:8]}...")
    chat = conversation_history.find_one({"session_id": session_id})
    if not chat:
        chat = db.chat_history.find_one({"session_id": session_id})
    if not chat:
        st.error("Chat not found.")
        _full_chat_back_button()
        return

    col1, col2 = st.columns([2, 1])
    with col1:
        st.subheader(f"Session ID: {session_id}")
    with col2:
        st.write(f"Last Updated: {chat.get('last_updated', 'N/A')}")
    st.markdown("---")

    for message in chat.get('messages', []):
        _full_chat_render_message(message)
    st.markdown("---")

    # Add text box to append to global memory
    st.subheader("Add to Global Memory")
    new_memory = st.text_area("Enter new memory item", key=f"new_memory_input_{session_id}")
    if st.button("Add Memory", key=f"add_memory_button_{session_id}"):
        if new_memory.strip():
            append_to_global_common_memory(new_memory.strip())
            st.success("New memory item added to global memory!")
            # The flag survives the rerun so the confirmation below is shown.
            st.session_state[f'memory_added_{session_id}'] = True
            st.rerun()
        else:
            st.warning("Please enter a valid memory item.")

    # Display (and consume) the confirmation flag set above.
    if st.session_state.pop(f'memory_added_{session_id}', False):
        st.success("Memory item added successfully!")

    _full_chat_back_button()
# --- Clear Global Chat Memory ---
# clear_global_common_memory() is defined once, in the "Common Memory Functions"
# section near the top of this module. A second, verbatim definition here
# silently replaced that one at import time, so edits to the first copy would
# have had no effect. The duplicate was removed.
def display_chat_history():
    """List every stored chat, newest first, with a last-message preview.

    Each entry has a "Show Full Chat" button that stores the session id in
    ``st.session_state['full_chat_view']`` and reruns.

    NOTE(review): this reads ``db.chat_history``, while the rest of this file
    writes and lists ``conversation_history``; confirm which collection is
    intended here.
    """
    st.subheader("All Chat History")
    all_chats = list(db.chat_history.find().sort("last_updated", -1))
    if not all_chats:
        st.info("No chat history found.")
        return
    for idx, chat in enumerate(all_chats):
        session_id = chat['session_id']
        last_updated = chat.get('last_updated', 'N/A')
        with st.expander(f"Session: {session_id[:8]}... - Last Updated: {last_updated}"):
            messages = chat.get('messages')
            if messages:
                last_message = messages[-1]
                role = str(last_message.get('role', '')).capitalize()
                # Content may be missing or None; never slice None.
                content = str(last_message.get('content') or "")
                # Add an ellipsis only when the preview is actually truncated.
                preview = content if len(content) <= 100 else content[:100] + "..."
                st.markdown(f"Last message ({role}):")
                st.markdown(f"> {preview}")
            if st.button("Show Full Chat", key=f"show_full_chat_{idx}"):
                st.session_state['full_chat_view'] = session_id
                st.rerun()
def trainer_intervention_tab():
    """Render the "Intervention" tab: a heading followed by the pending-response review UI."""
    st.subheader("Trainer Intervention")
    # Everything below the heading is produced by the review UI.
    return handle_admin_intervention()
def _intervention_render_conversation(conversation):
    """Render the pending assistant replies of one conversation with their action buttons."""
    session_id = conversation['session_id']
    messages = conversation.get('messages', [])
    st.write(f"Session ID: {session_id[:8]}...")
    # Takeover state is per session; look it up once instead of once per reply.
    takeover_doc = db.takeover_status.find_one({"session_id": session_id})
    takeover_active = takeover_doc.get("active", False) if takeover_doc else False

    for i, message in enumerate(messages):
        if message.get('role') != 'assistant' or message.get('status') != 'pending':
            continue
        user_message = messages[i - 1].get('content') if i > 0 else "N/A"
        st.write(f"**User:** {user_message}")
        st.write(f"**GPT:** {message.get('content')}")
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            if st.button("Approve", key=f"approve_{session_id}_{i}"):
                if approve_response(session_id, i):
                    st.success("Response approved")
                    time.sleep(0.5)
                    st.rerun()
                else:
                    st.warning("Response was not approved; it may already have changed.")
        with col2:
            if st.button("Modify", key=f"modify_{session_id}_{i}"):
                st.session_state['modifying'] = (session_id, i)
                st.rerun()
        with col3:
            if st.button("Regenerate", key=f"regenerate_{session_id}_{i}"):
                st.session_state['regenerating'] = (session_id, i)
                st.rerun()
        with col4:
            if takeover_active:
                if st.button("Deactivate Takeover", key=f"deactivate_takeover_{session_id}_{i}"):
                    deactivate_takeover(session_id)
                    st.success("Takeover deactivated.")
                    st.rerun()
            elif st.button("Activate Takeover", key=f"activate_takeover_{session_id}_{i}"):
                activate_takeover(session_id)
                st.success("Takeover activated.")
                st.rerun()
        st.divider()


def _intervention_render_regenerate_form():
    """Render the regeneration form for the reply stored in st.session_state['regenerating']."""
    session_id, message_index = st.session_state['regenerating']
    with st.form(key="regenerate_form"):
        operator_input = st.text_input("Enter additional instructions for regeneration:")
        submit_button = st.form_submit_button("Submit")
        # Without a cancel path the form could only be dismissed by regenerating.
        cancel_button = st.form_submit_button("Cancel")
        if submit_button:
            del st.session_state['regenerating']
            regenerate_response(session_id, message_index, operator_input)
            st.success("Response regenerated with operator input.")
            st.rerun()
        elif cancel_button:
            del st.session_state['regenerating']
            st.rerun()


def _intervention_render_modify_editor():
    """Render the editor for the reply stored in st.session_state['modifying']."""
    session_id, message_index = st.session_state['modifying']
    conversation = conversation_history.find_one({"session_id": session_id})
    messages = conversation.get('messages', []) if conversation else []
    if not 0 <= message_index < len(messages):
        # The chat was deleted (or changed) after "Modify" was clicked.
        del st.session_state['modifying']
        st.warning("The response to modify no longer exists.")
        return
    message = messages[message_index]
    modified_content = st.text_area("Modify the response:", value=message.get('content') or "")
    if st.button("Save Modification"):
        save_modified_response(session_id, message_index, modified_content)
        st.success("Response modified and approved")
        del st.session_state['modifying']
        st.rerun()
    if st.button("Cancel Modification"):
        del st.session_state['modifying']
        st.rerun()


def handle_admin_intervention():
    """Render every pending assistant reply with approve/modify/regenerate/takeover controls.

    "Modify" and "Regenerate" store ``(session_id, message_index)`` in
    ``st.session_state['modifying']`` / ``['regenerating']``; the matching
    editor is rendered below the list until it is submitted or cancelled.
    """
    st.subheader("Review Pending Responses")
    # $elemMatch requires role and status to hold on the *same* message;
    # two separate dotted conditions could be met by two different messages.
    pending_conversations = conversation_history.find(
        {"messages": {"$elemMatch": {"role": "assistant", "status": "pending"}}}
    )
    for conversation in pending_conversations:
        _intervention_render_conversation(conversation)

    if 'regenerating' in st.session_state:
        _intervention_render_regenerate_form()
    if 'modifying' in st.session_state:
        _intervention_render_modify_editor()
def approve_response(session_id, message_index):
    """Set ``messages.<message_index>.status`` to "approved" in the conversation of ``session_id``.

    Returns:
        True when the document was actually modified; False when nothing
        changed or the update failed (failures are shown with st.error).
    """
    status_path = f"messages.{message_index}.status"
    try:
        outcome = conversation_history.update_one(
            {"session_id": session_id},
            {"$set": {status_path: "approved"}},
        )
        return outcome.modified_count > 0
    except Exception as e:
        st.error(f"Failed to approve response: {str(e)}")
        return False
def save_modified_response(session_id, message_index, modified_content):
    """Replace the content of one message with ``modified_content`` and mark it approved.

    Failures are shown with st.error; nothing is returned.
    """
    try:
        prefix = f"messages.{message_index}"
        changes = {
            f"{prefix}.content": modified_content,
            f"{prefix}.status": "approved",
        }
        conversation_history.update_one({"session_id": session_id}, {"$set": changes})
    except Exception as e:
        st.error(f"Failed to save modified response: {str(e)}")
def regenerate_response(session_id, message_index, operator_input):
    """Regenerate the reply at ``message_index`` and reset its status to "pending".

    The prompt is the nearest *user* message before ``message_index``. Context
    is retrieved with get_relevant_context(), as send_message() does for the
    first reply, and ``operator_input`` is passed as trainer instructions.
    If generation fails, the stored reply is left unchanged.
    Errors are shown with st.error.
    """
    try:
        conversation = conversation_history.find_one({"session_id": session_id})
        if not conversation:
            st.error("Failed to regenerate response: conversation not found.")
            return
        messages = conversation.get('messages', [])

        # Walk back to the user message this reply answers; the message just
        # before it may be an admin message instead.
        user_message = ""
        for previous in reversed(messages[:message_index]):
            if previous.get('role') == 'user':
                user_message = previous.get('content') or ""
                break

        context = get_relevant_context(user_message) if user_message else ""
        new_response = get_gpt_response(user_message, context, system_message=operator_input)
        if new_response is None:
            # get_gpt_response() already reported the error; keep the old reply
            # instead of overwriting it with None.
            return

        conversation_history.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    f"messages.{message_index}.content": new_response,
                    f"messages.{message_index}.status": "pending"
                }
            }
        )
    except Exception as e:
        st.error(f"Failed to regenerate response: {str(e)}")
def _trainer_render_global_memory():
    """Show the current global memory items and a button that clears them."""
    st.subheader("Current Global Memory")
    global_memory = get_global_common_memory()
    if global_memory:
        for idx, item in enumerate(global_memory, 1):
            st.text(f"{idx}. {item}")
    else:
        st.info("No global memory items found.")
    if st.button("Clear Global Memory", key="clear_global_memory"):
        clear_global_common_memory()
        st.success("Global memory cleared successfully!")
        time.sleep(1)
        st.rerun()


def _trainer_render_active_chat(chat):
    """Render one recent conversation with view/takeover/delete/message controls.

    Widget keys come from the session id, not the list position. The list is
    re-sorted by last_updated on every (auto-)refresh, so a position-based key
    could apply a click, or a typed message, to a different chat than the one
    it was rendered for.
    """
    session_id = chat['session_id']
    with st.expander(f"Session: {session_id[:8]}... - Last Updated: {chat.get('last_updated', 'N/A')}"):
        for message in chat.get('messages', [])[-5:]:
            role = str(message.get('role', '')).capitalize()
            st.markdown(f"**{role}:** {message.get('content')}")

        takeover_doc = db.takeover_status.find_one({"session_id": session_id})
        takeover_active = takeover_doc.get("active", False) if takeover_doc else False

        col1, col2, col3, col4 = st.columns(4)
        with col1:
            if st.button("View Full Chat", key=f"active_chat_view_{session_id}"):
                st.session_state['selected_chat'] = session_id
                st.rerun()
        with col2:
            if takeover_active:
                if st.button("Deactivate Takeover", key=f"active_chat_deactivate_{session_id}"):
                    deactivate_takeover(session_id)
                    st.success("Takeover deactivated.")
                    st.rerun()
            elif st.button("Activate Takeover", key=f"active_chat_activate_{session_id}"):
                activate_takeover(session_id)
                st.success("Takeover activated.")
                st.rerun()
        with col3:
            if st.button("Delete Chat", key=f"active_chat_delete_{session_id}"):
                delete_chat(session_id)
                st.success(f"Chat {session_id[:8]}... deleted.")
                st.rerun()
        with col4:
            if takeover_active:
                message_key = f"active_chat_takeover_message_{session_id}"
                st.text_input("Send message", key=message_key)
                if st.button("Send", key=f"active_chat_send_{session_id}"):
                    text = (st.session_state.get(message_key) or "").strip()
                    if text:
                        # The message must go to *this* chat's session, not to
                        # the trainer's own browser session.
                        send_admin_message(text, session_id=session_id)
                        st.success("Message sent.")
                        st.rerun()
                    else:
                        st.warning("Please enter a message to send.")


def trainer_page():
    """Render the trainer dashboard with Current Status, Chat History and Intervention tabs."""
    st.title("Trainer Dashboard")
    # Add auto-refresh every 10 seconds (10000 milliseconds)
    st_autorefresh(interval=10000, limit=None, key="trainer_autorefresh")
    tab1, tab2, tab3 = st.tabs(["Current Status", "Chat History", "Intervention"])
    with tab1:
        _trainer_render_global_memory()
        # Display the five most recently updated chats
        st.subheader("Active Chats")
        chats = list(conversation_history.find().sort("last_updated", -1).limit(5))
        for chat in chats:
            _trainer_render_active_chat(chat)
        # Manual refresh button
        if st.button("Refresh", key="refresh_button"):
            st.rerun()
    with tab2:
        display_chat_history()
    with tab3:
        trainer_intervention_tab()
def delete_chat(session_id):
    """Delete the conversation document of ``session_id`` from ``conversation_history``.

    Shows an error when nothing was deleted or the delete raised; returns None.
    """
    try:
        outcome = conversation_history.delete_one({"session_id": session_id})
        if not outcome.deleted_count:
            st.error("Failed to delete chat. Please try again.")
    except Exception as e:
        st.error(f"Error deleting chat: {str(e)}")
# --- Main Function ---
def main():
    """Route the run to the full-chat view or the trainer dashboard.

    A pending 'memory_updated' flag is consumed first and triggers one extra
    rerun. Otherwise 'full_chat_view' takes priority over 'selected_chat';
    when neither is set, the dashboard is shown. Streamlit's rerun/stop
    control-flow exceptions are re-raised; any other exception is shown via
    st.error.
    """
    try:
        if 'memory_updated' in st.session_state:
            del st.session_state['memory_updated']
            st.rerun()
        for view_key in ('full_chat_view', 'selected_chat'):
            if view_key in st.session_state:
                view_full_chat(st.session_state[view_key])
                break
        else:
            trainer_page()
    except (RerunException, StopException):
        # These implement st.rerun()/st.stop(); they must reach Streamlit.
        raise
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")


if __name__ == "__main__":
    main()