"""Streamlit "Chat" tab: single-persona chat, with hand-off to compare mode.

The tab renders persona / system-prompt selection, the chat transcript, an
export/reset footer and a chat input, and generates assistant replies either
locally or via NDIF (``remote=True``). When the compare tool is enabled in the
advanced settings, rendering is delegated to
``tabs.compare_chat.render_compare_mode``.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, cast

import streamlit as st

from state import (
    ChatState,
    PendingChatAction,
    chat_session_key,
    get_chat_state,
    reset_chat_context_state,
)
from tabs.chat_shared import (
    generate_chat_reply_result,
    hydrate_chat_state,
    load_chat_personas,
    mark_model_loaded,
    model_load_status,
    render_chat_selection,
)
from tabs.chat_ui import (
    GenerationConfig,
    render_advanced_settings,
    render_chat_window,
    render_system_prompt,
)
from utils.chat import build_chat_messages, resolve_system_prompt
from utils.chat_export import save_chat_export
from utils.helpers import format_ndif_status, session_key, widget_key
from utils.runtime import cached_model, session_ndif_api_key

if TYPE_CHECKING:
    from collections.abc import Callable

    from persona_data.synth_persona import PersonaData

# Session-state keys handed to the hydration / selection / settings helpers,
# which use them to remember the user's last choices (persona, prompt mode,
# and the compare / probe / token-contrast toggles).
_LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
_LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
_LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
_LAST_PROBE_ENABLED_KEY = session_key("chat", "last_probe_enabled")
_LAST_TOKEN_CONTRAST_KEY = session_key("chat", "last_token_contrast")


def _render_single_chat_footer(
    *,
    model_name: str,
    dataset_source: str,
    persona: PersonaData,
    prompt_mode: str,
    system_prompt: str | None,
    chat_state: ChatState,
    generation: GenerationConfig,
    export_key: str,
    reset_key: str,
    on_reset: Callable[[], None],
) -> None:
    """Render the icon-only "Export chat" and "Reset chat" buttons.

    Export passes the current transcript, together with the model, dataset,
    persona, prompt and generation settings, to ``save_chat_export`` and shows
    a confirmation toast. Reset calls ``on_reset`` and then reruns the script
    so the cleared state is rendered.
    """
    footer = st.container()
    with footer:
        # Two narrow button columns plus a wide spacer keep both icons packed
        # together at the left edge.
        exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
        with exp_col:
            if st.button(
                "",
                icon=":material/download:",
                key=export_key,
                help="Export chat",
            ):
                save_chat_export(
                    model_name=model_name,
                    dataset_source=dataset_source,
                    persona_id=persona.id,
                    persona_name=getattr(persona, "name", None),
                    prompt_mode=prompt_mode,
                    system_prompt=system_prompt,
                    messages=chat_state["messages"],
                    generation=generation.to_export_dict(),
                )
                st.toast("Exported", icon=":material/check:")
        with rst_col:
            if st.button(
                "",
                icon=":material/delete_sweep:",
                key=reset_key,
                help="Reset chat",
            ):
                on_reset()
                st.rerun()


def _handle_single_chat_generation(
    *,
    remote: bool,
    model_name: str,
    chat_state: ChatState,
    active_system_prompt: str | None,
    generation: GenerationConfig,
    pending_action: PendingChatAction,
    chat_log,
) -> None:
    """Generate the next assistant reply for the current transcript.

    Progress is shown in a transient caption: the model-load status, then
    submission, then NDIF job status when ``remote``. The caption is cleared
    on every exit path.

    On success the reply is appended to ``chat_state["messages"]`` and the
    script reruns. On a generation error the error is shown inside
    ``chat_log``. For a ``"new_user_prompt"`` action, the just-submitted user
    message is also removed again, so the history does not keep an
    unanswered prompt.
    """
    messages = build_chat_messages(active_system_prompt, chat_state["messages"])
    status_box = st.empty()

    def _show_phase(text: str) -> None:
        status_box.caption(text)

    def _show_ndif_status(job_id: str, status_name: str, description: str) -> None:
        status_box.caption(
            format_ndif_status(
                job_id,
                status_name,
                description,
                completed_detail="Downloading result...",
            )
        )

    def _show_error(exc: Exception) -> None:
        with chat_log:
            st.error(f"Could not generate a reply: {exc}")
            st.info("Try a shorter prompt, reset the chat, or switch personas.")

    try:
        with st.spinner("Generating reply..."):
            _show_phase(model_load_status(model_name))
            model = cached_model(model_name=model_name)
            mark_model_loaded(model_name)
            _show_phase("Submitting to NDIF..." if remote else "Generating locally...")
            reply, error = generate_chat_reply_result(
                model=model,
                messages=messages,
                remote=remote,
                generation=generation,
                on_status=_show_ndif_status if remote else None,
                on_error=_show_error,
                ndif_api_key=session_ndif_api_key(),
            )
    finally:
        # Clear the progress caption on every path, including exceptions raised
        # while loading the model, so no stale status is left on screen.
        status_box.empty()

    if error is not None:
        history = chat_state["messages"]
        # Roll back only the prompt that triggered this generation, and only if
        # it is still the trailing user message.
        if (
            pending_action == "new_user_prompt"
            and history
            and history[-1].get("role") == "user"
        ):
            history.pop()
        return
    if reply is None:
        return

    chat_state["messages"].append({"role": "assistant", "content": reply.text})
    st.rerun()


def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
    """Render the chat tab.

    Args:
        remote: Generate via NDIF instead of locally; also forwarded to the
            settings, compare and probe views.
        model_name: Model used for generation; part of the chat-state key.
        dataset_source: Source the personas are loaded from; part of the
            chat-state key.
    """
    st.title("Chat")
    st.caption("Chat with a persona, optionally side-by-side or with token contrast.")

    context_key = chat_session_key(model_name, dataset_source)
    chat_state = get_chat_state(model_name, dataset_source)
    hydrate_chat_state(
        chat_state,
        persisted_persona_key=_LAST_PERSONA_ID_KEY,
        persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
    )

    personas = load_chat_personas(dataset_source)
    if personas is None:
        # No persona list was loaded, so there is nothing more to render.
        return

    generation, tools = render_advanced_settings(
        context_key,
        remote,
        last_compare_mode_key=_LAST_COMPARE_MODE_KEY,
        last_probe_enabled_key=_LAST_PROBE_ENABLED_KEY,
        last_token_contrast_key=_LAST_TOKEN_CONTRAST_KEY,
    )

    if tools.compare_mode:
        # Deferred import: the compare view is only loaded when it is enabled.
        from tabs.compare_chat import render_compare_mode

        render_compare_mode(
            remote,
            model_name,
            context_key,
            dataset_source,
            personas,
            generation,
            contrast_enabled=tools.token_contrast,
        )
        return

    # Created before the selection widgets so the probe inspector, which is
    # filled in further down, is laid out above them.
    probe_container = st.container()

    persona_select_key = widget_key(context_key, "persona_select")
    prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
    prompt_key = widget_key(context_key, "custom_system_prompt")
    chat_input_key = widget_key(context_key, "chat_input")
    pending_key = widget_key(context_key, "pending_prompt")
    export_key = widget_key(context_key, "export_chat")
    reset_key = widget_key(context_key, "reset")
    edit_key = widget_key(context_key, "edit_idx")

    selection = render_chat_selection(
        personas,
        chat_state["persona_id"],
        chat_state["prompt_mode"],
        persona_select_key,
        prompt_mode_select_key,
        persisted_persona_key=_LAST_PERSONA_ID_KEY,
        persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
        column_widths=(2, 1),
    )
    selected_persona = selection.persona
    prompt_mode = selection.prompt_mode
    changed_context = selection.changed

    def _reset_active_chat_context() -> None:
        """Reset chat state for the current selection and drop any pending edit."""
        reset_chat_context_state(
            chat_state,
            selected_persona.id,
            prompt_mode,
            chat_input_key,
            prompt_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    def _reset_after_prompt_save() -> None:
        """Reset chat state after the system prompt is saved.

        This passes the same arguments as the original ``on_save`` callback,
        which omits ``prompt_key`` (presumably so the just-saved custom prompt
        is kept; confirm against ``reset_chat_context_state``). It also drops
        the message-edit index, because the reset discards the history that
        index points into. That assumes the reset clears
        ``chat_state["messages"]``, as the context-change path below implies.
        """
        reset_chat_context_state(
            chat_state,
            selected_persona.id,
            prompt_mode,
            chat_input_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    active_system_prompt = resolve_system_prompt(
        persona=selected_persona,
        mode=prompt_mode,
    )

    if changed_context:
        had_history = bool(chat_state["messages"])
        _reset_active_chat_context()
        if had_history:
            st.info("Chat history reset because the persona or system prompt changed.")

    chat_log = st.container()
    with chat_log:
        active_system_prompt = render_system_prompt(
            prompt_key,
            prompt_mode,
            active_system_prompt,
            on_save=_reset_after_prompt_save,
        )

    with probe_container:
        if tools.probe_enabled:
            from tabs.probe_ui import render_probe_inspector

            render_probe_inspector(
                context_key=context_key,
                model_name=model_name,
                remote=remote,
                active_system_prompt=active_system_prompt,
                chat_state=chat_state,
                enabled=True,
            )
        else:
            from utils.probe_overlay import clear_overlays

            # Probing is off: clear any probe overlays attached to the messages.
            clear_overlays(chat_state["messages"])

    render_chat_window(
        chat_log=chat_log,
        messages=chat_state["messages"],
        edit_key=edit_key,
        pending_key=pending_key,
    )

    _render_single_chat_footer(
        model_name=model_name,
        dataset_source=dataset_source,
        persona=selected_persona,
        prompt_mode=prompt_mode,
        system_prompt=active_system_prompt,
        chat_state=chat_state,
        generation=generation,
        export_key=export_key,
        reset_key=reset_key,
        on_reset=_reset_active_chat_context,
    )

    user_prompt = st.chat_input("Ask something...", key=chat_input_key)
    if user_prompt:
        # Record the prompt and flag a pending generation, then rerun. On the
        # next run the new message is rendered before the reply is generated.
        chat_state["messages"].append({"role": "user", "content": user_prompt})
        st.session_state[pending_key] = "new_user_prompt"
        st.rerun()

    # Pop (consume) the pending action so each request is handled once. The
    # cast target is quoted so the union is never evaluated at runtime.
    pending_action = cast(
        "PendingChatAction | None",
        st.session_state.pop(pending_key, None),
    )
    if not pending_action:
        return

    _handle_single_chat_generation(
        remote=remote,
        model_name=model_name,
        chat_state=chat_state,
        active_system_prompt=active_system_prompt,
        generation=generation,
        pending_action=pending_action,
        chat_log=chat_log,
    )