import os import secrets import time from typing import cast from urllib.parse import urlencode import requests import streamlit as st from gateway_client import delete_profile, ingest_and_rewrite from llm import chat, set_model from model_config import MODEL_CHOICES, MODEL_TO_PROVIDER, MODEL_DISPLAY_NAMES def _generate_session_name(base: str = "Session") -> str: existing = set(st.session_state.get("session_order", [])) idx = 1 while True: candidate = f"{base} {idx}" if candidate not in existing: return candidate idx += 1 def ensure_session_state() -> None: if "sessions" not in st.session_state: st.session_state.sessions = {} if "session_order" not in st.session_state: st.session_state.session_order = [] if ( "active_session_id" not in st.session_state or st.session_state.active_session_id not in st.session_state.sessions ): default_name = _generate_session_name() st.session_state.sessions.setdefault(default_name, {"history": []}) if default_name not in st.session_state.session_order: st.session_state.session_order.append(default_name) st.session_state.active_session_id = default_name if "session_select" not in st.session_state: st.session_state.session_select = st.session_state.active_session_id if st.session_state.session_select not in st.session_state.sessions: st.session_state.session_select = st.session_state.active_session_id st.session_state.setdefault( "rename_session_name", st.session_state.active_session_id ) st.session_state.setdefault( "rename_session_synced_to", st.session_state.active_session_id ) st.session_state.history = cast( list[dict], st.session_state.sessions[ st.session_state.active_session_id ].setdefault("history", []), ) def create_session(session_name: str | None = None) -> tuple[bool, str]: ensure_session_state() candidate = (session_name or "").strip() if not candidate: candidate = _generate_session_name() if candidate in st.session_state.sessions: return False, candidate st.session_state.sessions[candidate] = {"history": []} st.session_state.session_order.append(candidate) st.session_state.active_session_id = candidate st.session_state.session_select = candidate st.session_state.history = cast( list[dict], st.session_state.sessions[candidate]["history"] ) st.session_state.rename_session_name = candidate st.session_state.rename_session_synced_to = candidate return True, candidate def rename_session(current_name: str, new_name: str) -> bool: ensure_session_state() target = new_name.strip() if not target or target == current_name: return False if target in st.session_state.sessions: return False st.session_state.sessions[target] = st.session_state.sessions.pop(current_name) order = st.session_state.session_order order[order.index(current_name)] = target if st.session_state.active_session_id == current_name: st.session_state.active_session_id = target st.session_state.session_select = target st.session_state.history = cast( list[dict], st.session_state.sessions[st.session_state.active_session_id]["history"], ) st.session_state.rename_session_name = target st.session_state.rename_session_synced_to = target return True def delete_session(session_name: str) -> bool: ensure_session_state() if session_name not in st.session_state.sessions: return False if len(st.session_state.session_order) <= 1: return False st.session_state.sessions.pop(session_name, None) st.session_state.session_order.remove(session_name) if st.session_state.active_session_id == session_name: st.session_state.active_session_id = st.session_state.session_order[-1] st.session_state.session_select = st.session_state.active_session_id st.session_state.rename_session_name = st.session_state.active_session_id st.session_state.rename_session_synced_to = st.session_state.active_session_id st.session_state.history = cast( list[dict], st.session_state.sessions[st.session_state.active_session_id]["history"], ) return True def rewrite_message( msg: str, persona_name: str, show_rationale: bool ) -> str: if persona_name.lower() == "control": rewritten_msg = msg if show_rationale: rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: No personalization applied.'. Begin your answer on the next line." return rewritten_msg try: rewritten_msg = ingest_and_rewrite( user_id=persona_name, query=msg ) if show_rationale: rewritten_msg += " At the beginning of your response, please say the following in ITALIC: 'Persona Rationale: ' followed by 1 sentence about how your reasoning for how the persona traits influenced this response, also in italics. Begin your answer on the next line." except Exception as e: st.error(f"Failed to ingest_and_append message: {e}") raise print(rewritten_msg) return rewritten_msg # ────────────────────────────────────────────────────────────── # Page setup & CSS # ────────────────────────────────────────────────────────────── st.set_page_config(page_title="MemMachine Chatbot", layout="wide") try: with open("./styles.css") as f: st.markdown(f"", unsafe_allow_html=True) except FileNotFoundError: pass ensure_session_state() HEADER_STYLE = """ """ HEADER_HTML = """
""" st.markdown(HEADER_STYLE, unsafe_allow_html=True) st.markdown(HEADER_HTML, unsafe_allow_html=True) # ────────────────────────────────────────────────────────────── # Sidebar # ────────────────────────────────────────────────────────────── default_model = MODEL_CHOICES[0] if MODEL_CHOICES else "gpt-4.1-mini" model_id = default_model provider = MODEL_TO_PROVIDER.get(default_model, "openai") selected_persona = "Charlie" persona_name = "Charlie" skip_rewrite = False compare_personas = False show_rationale = False with st.sidebar: st.markdown("#### Sessions") session_options = st.session_state.session_order active_session = st.session_state.active_session_id if st.session_state.rename_session_synced_to != active_session: st.session_state.rename_session_name = active_session st.session_state.rename_session_synced_to = active_session for idx, session_name in enumerate(session_options, start=1): is_active = session_name == active_session button_label = f"{session_name}" row = st.container() with row: button_col, menu_col = st.columns([0.8, 0.2]) with button_col: if st.button( button_label, key=f"session_button_{session_name}", use_container_width=True, type="primary" if is_active else "secondary", ): if not is_active: st.session_state.active_session_id = session_name st.session_state.session_select = session_name st.session_state.history = cast( list[dict], st.session_state.sessions[session_name]["history"], ) st.session_state.rename_session_name = session_name st.session_state.rename_session_synced_to = session_name st.rerun() with menu_col: if hasattr(st, "popover"): menu_container = st.popover("⋯", use_container_width=True) else: menu_container = st.expander( "⋯", expanded=False, key=f"session_actions_{session_name}" ) with menu_container: st.markdown(f"**Actions for {session_name}**") rename_value = st.text_input( "Rename session", value=session_name, key=f"rename_session_input_{session_name}", ) if st.button( "Rename", use_container_width=True, key=f"rename_session_button_{session_name}", ): rename_target = rename_value.strip() if not rename_target: st.warning("Enter a session name to rename.") elif rename_target == session_name: st.info("Session name unchanged.") elif rename_target in st.session_state.sessions: st.warning(f"Session '{rename_target}' already exists.") elif rename_session(session_name, rename_target): st.success(f"Session renamed to '{rename_target}'.") st.rerun() else: st.error("Unable to rename session. Please try again.") st.divider() if st.button( "Delete session", use_container_width=True, type="secondary", key=f"delete_session_button_{session_name}", ): if delete_session(session_name): new_active = st.session_state.active_session_id st.session_state.session_select = new_active st.session_state.rename_session_name = new_active st.session_state.rename_session_synced_to = new_active st.success(f"Session '{session_name}' deleted.") st.rerun() else: st.warning("Cannot delete the last remaining session.") with st.form("create_session_form", clear_on_submit=True): new_session_name = st.text_input( "New session name", key="create_session_name", placeholder="Leave blank for automatic name", ) if st.form_submit_button("Create session", use_container_width=True): success, created_name = create_session(new_session_name) if success: st.success(f"Session '{created_name}' created.") st.rerun() else: st.warning(f"Session '{created_name}' already exists.") st.divider() st.markdown("#### Choose Model") # Create display options with categories display_options = [MODEL_DISPLAY_NAMES[model] for model in MODEL_CHOICES] selected_display = st.selectbox( "Choose Model", display_options, index=0, label_visibility="collapsed" ) # Get the actual model ID from the display name model_id = next(model for model, display in MODEL_DISPLAY_NAMES.items() if display == selected_display) provider = MODEL_TO_PROVIDER[model_id] set_model(model_id) st.markdown("#### Choose user persona") selected_persona = st.selectbox( "Choose user persona", ["Charlie", "Jing", "Charles", "Control"], label_visibility="collapsed", ) custom_persona = st.text_input("Or enter your name", "") persona_name = ( custom_persona.strip() if custom_persona.strip() else selected_persona ) compare_personas = st.checkbox("Compare with Control persona") show_rationale = st.checkbox("Show Persona Rationale") st.divider() if st.button("Clear chat", use_container_width=True): active = st.session_state.active_session_id st.session_state.sessions[active]["history"].clear() st.session_state.history = cast( list[dict], st.session_state.sessions[active]["history"], ) st.rerun() if st.button("Delete Profile", use_container_width=True): success = delete_profile(persona_name) active = st.session_state.active_session_id st.session_state.sessions[active]["history"].clear() st.session_state.history = cast( list[dict], st.session_state.sessions[active]["history"], ) if success: st.success(f"Profile for '{persona_name}' deleted.") else: st.error(f"Failed to delete profile for '{persona_name}'.") st.divider() # ────────────────────────────────────────────────────────────── # Enforce alternating roles # ────────────────────────────────────────────────────────────── def clean_history(history: list[dict], persona: str) -> list[dict]: out = [] for turn in history: if turn.get("role") == "user": out.append({"role": "user", "content": turn["content"]}) elif turn.get("role") == "assistant" and turn.get("persona") == persona: out.append({"role": "assistant", "content": turn["content"]}) cleaned = [] last_role = None for msg in out: if msg["role"] != last_role: cleaned.append(msg) last_role = msg["role"] return cleaned def append_user_turn(msgs: list[dict], new_user_msg: str) -> list[dict]: if msgs and msgs[-1]["role"] == "user": msgs[-1] = {"role": "user", "content": new_user_msg} else: msgs.append({"role": "user", "content": new_user_msg}) return msgs def typewriter_effect(text: str, speed: float = 0.02): """Generator that yields text word by word to create a typing effect.""" words = text.split(" ") for i, word in enumerate(words): if i == 0: yield word else: yield " " + word time.sleep(speed) msg = st.chat_input("Type your message…") if msg: st.session_state.history.append({"role": "user", "content": msg}) if compare_personas: all_answers = {} rewritten_msg = rewrite_message(msg, persona_name, show_rationale) msgs = clean_history(st.session_state.history, persona_name) msgs = append_user_turn(msgs, rewritten_msg) try: txt, lat, tok, tps = chat(msgs, persona_name) all_answers[persona_name] = txt except ValueError as e: st.error(f"❌ {str(e)}") st.stop() rewritten_msg_control = rewrite_message(msg, "Control", show_rationale) msgs_control = clean_history(st.session_state.history, "Control") msgs_control = append_user_turn(msgs_control, rewritten_msg_control) try: txt_control, lat, tok, tps = chat(msgs_control, "Arnold") all_answers["Control"] = txt_control except ValueError as e: st.error(f"❌ {str(e)}") st.stop() st.session_state.history.append( {"role": "assistant_all", "axis": "role", "content": all_answers, "is_new": True} ) else: rewritten_msg = rewrite_message(msg, persona_name, show_rationale) msgs = clean_history(st.session_state.history, persona_name) msgs = append_user_turn(msgs, rewritten_msg) try: txt, lat, tok, tps = chat( msgs, "Arnold" if persona_name == "Control" else persona_name ) st.session_state.history.append( {"role": "assistant", "persona": persona_name, "content": txt, "is_new": True} ) except ValueError as e: st.error(f"❌ {str(e)}") st.stop() st.rerun() # ────────────────────────────────────────────────────────────── # Chat history display # ────────────────────────────────────────────────────────────── for turn in st.session_state.history: if turn.get("role") == "user": st.chat_message("user").write(turn["content"]) elif turn.get("role") == "assistant": with st.chat_message("assistant"): # Use typing effect for new messages, normal display for old ones if turn.get("is_new", False): st.write_stream(typewriter_effect(turn["content"])) # Mark as no longer new so it displays normally on rerun turn["is_new"] = False else: st.write(turn["content"]) elif turn.get("role") == "assistant_all": content_items = list(turn["content"].items()) is_new = turn.get("is_new", False) if len(content_items) >= 2: cols = st.columns([1, 0.03, 1]) persona_label, persona_response = content_items[0] control_label, control_response = content_items[1] with cols[0]: st.markdown(f"**{persona_label}**") if is_new: st.write_stream(typewriter_effect(persona_response)) else: st.markdown( f'