from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING

import streamlit as st

from state import ChatState
from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
from utils.chat import ChatReply, generate_chat_reply
from utils.datasets import load_persona_list
from utils.helpers import session_key

if TYPE_CHECKING:
    from persona_data.synth_persona import PersonaData


@dataclass(frozen=True)
class ChatSelection:
    """Immutable bundle of the values produced by the selection controls."""

    # Persona returned by ``render_persona_prompt_controls``.
    persona: PersonaData
    # Prompt mode returned by ``render_persona_prompt_controls``.
    prompt_mode: str
    # Third value returned by the controls; presumably True when the user
    # picked something different from the current selection (confirm in
    # ``tabs.chat_ui``).
    changed: bool


# Session-state slot holding the set of model names already requested.
_LOADED_MODEL_NAMES_KEY = session_key("chat", "loaded_model_names")


def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
    """Load the persona list for *dataset_source* and report status in the UI.

    The optional uploaded files are read from the ``extract`` session-state
    slots and forwarded to ``load_persona_list``.

    Returns:
        The loaded personas, or ``None`` when loading raised or produced an
        empty result. In both failure cases a message is rendered through
        Streamlit instead of raising.
    """
    personas_slot = session_key("extract", "personas_file")
    qa_slot = session_key("extract", "qa_file")

    try:
        uploads = {
            "personas_file": st.session_state.get(personas_slot),
            "qa_file": st.session_state.get(qa_slot),
        }
        loaded, status_text = load_persona_list(dataset_source, **uploads)
        st.caption(status_text)
    except Exception as exc:
        # UI boundary: surface the failure to the user rather than crash.
        st.error(f"Could not load data: {exc}")
        st.info("Check the selected dataset source or upload both JSONL files.")
        return None

    if loaded:
        return loaded

    st.warning("No personas found in the selected dataset.")
    st.info("Try a different dataset source or upload a non-empty personas file.")
    return None


def hydrate_chat_state(
    state: ChatState,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    default_prompt_mode: str = "templated",
) -> None:
    """Fill *state* from persisted session values when no persona is set.

    Nothing happens if ``state["persona_id"]`` is already populated.
    Otherwise the persona id is read from ``persisted_persona_key`` and the
    prompt mode from ``persisted_prompt_key`` (falling back to
    *default_prompt_mode*).
    """
    if state["persona_id"] is not None:
        return

    # NOTE(review): prompt_mode is restored only together with persona_id;
    # confirm that it is not meant to be refreshed on every call.
    state["persona_id"] = st.session_state.get(persisted_persona_key)
    state["prompt_mode"] = st.session_state.get(
        persisted_prompt_key,
        default_prompt_mode,
    )


def render_chat_selection(
    personas: list[PersonaData],
    current_persona_id: str | None,
    current_prompt_mode: str,
    persona_key: str,
    prompt_key: str,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    column_widths: tuple[int, int] = (3, 2),
) -> ChatSelection:
    """Render the persona / prompt-mode controls and persist the choice.

    Delegates the widgets to ``render_persona_prompt_controls``, then writes
    the selected persona id and prompt mode into session state under the
    ``persisted_*`` keys.

    Returns:
        A ``ChatSelection`` carrying the three values the controls returned.
    """
    controls_result = render_persona_prompt_controls(
        personas,
        current_persona_id,
        current_prompt_mode,
        persona_key,
        prompt_key,
        column_widths=column_widths,
    )
    chosen_persona, chosen_mode, did_change = controls_result

    st.session_state[persisted_persona_key] = chosen_persona.id
    st.session_state[persisted_prompt_key] = chosen_mode

    return ChatSelection(
        persona=chosen_persona,
        prompt_mode=chosen_mode,
        changed=did_change,
    )


def model_load_status(model_name: str) -> str:
    """Return the status label to show while *model_name* is being fetched.

    The label depends on whether the name is already in the session's set of
    recorded model names (the set is created empty on first use).
    """
    seen_names = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
    if model_name in seen_names:
        return "Using cached model..."
    return "Loading model..."


def mark_model_loaded(model_name: str) -> None:
    """Record *model_name* in the session's set of already-requested models."""
    st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set()).add(model_name)


def generate_chat_reply_result(
    *,
    model: object,
    messages: list[dict[str, str]],
    remote: bool,
    generation: GenerationConfig,
    on_status: Callable[[str, str, str], None] | None = None,
    on_error: Callable[[Exception], None] | None = None,
    ndif_api_key: str | None = None,
) -> tuple[ChatReply | None, Exception | None]:
    """Call ``generate_chat_reply`` and return the outcome as a pair.

    Returns:
        ``(reply, None)`` on success, or ``(None, exc)`` when the call raised.
        In the failure case *on_error*, if given, is invoked with the
        exception before returning.
    """
    try:
        # The kwargs expansion stays inside the try so that failures in
        # ``to_generate_kwargs`` are reported the same way as generation errors.
        reply = generate_chat_reply(
            model=model,
            messages=messages,
            remote=remote,
            on_status=on_status,
            ndif_api_key=ndif_api_key,
            **generation.to_generate_kwargs(),
        )
    except Exception as exc:
        if on_error is not None:
            on_error(exc)
        return None, exc

    return reply, None