| from __future__ import annotations |
|
|
| from collections.abc import Callable |
| from dataclasses import dataclass |
| from typing import TYPE_CHECKING |
|
|
| import streamlit as st |
|
|
| from state import ChatState |
| from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls |
| from utils.chat import ChatReply, generate_chat_reply |
| from utils.datasets import load_persona_list |
| from utils.helpers import session_key |
|
|
| if TYPE_CHECKING: |
| from persona_data.synth_persona import PersonaData |
|
|
|
|
@dataclass(frozen=True)
class ChatSelection:
    """Immutable result of rendering the persona / prompt-mode selectors.

    Instances are built by ``render_chat_selection``, which passes the fields
    positionally, so the declaration order below must not change.
    """

    # Persona returned by the persona selector widget.
    persona: PersonaData
    # Prompt mode returned alongside the persona (e.g. "templated").
    prompt_mode: str
    # Flag reported by ``render_persona_prompt_controls``; presumably True when
    # the selection differs from the previous run — confirm against that helper.
    changed: bool
|
|
|
|
# Session-state key for the set of model names this session has already
# requested; read by model_load_status and updated by mark_model_loaded.
_LOADED_MODEL_NAMES_KEY = session_key(
    "chat",
    "loaded_model_names",
)
|
|
|
|
def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
    """Load the personas available for chatting from *dataset_source*.

    Any personas / QA files stored in ``st.session_state`` under the extract
    tab's keys (``session_key("extract", "personas_file")`` and
    ``session_key("extract", "qa_file")``) are forwarded to
    ``load_persona_list``. On success the loader's status message is shown as
    a caption.

    Args:
        dataset_source: Dataset source identifier passed to ``load_persona_list``.

    Returns:
        The list of personas, or ``None`` if loading raised or produced no
        personas. In both ``None`` cases an error/warning plus a hint has
        already been rendered, so callers only need to stop.
    """
    personas_file_key = session_key("extract", "personas_file")
    qa_file_key = session_key("extract", "qa_file")
    try:
        personas, dataset_status = load_persona_list(
            dataset_source,
            personas_file=st.session_state.get(personas_file_key),
            qa_file=st.session_state.get(qa_file_key),
        )
    except Exception as exc:  # UI boundary: report any loader failure to the user.
        st.error(f"Could not load data: {exc}")
        st.info("Check the selected dataset source or upload both JSONL files.")
        return None

    # Rendered outside the ``try`` so a UI failure here is not misreported as a
    # data-loading error.
    st.caption(dataset_status)

    if not personas:
        st.warning("No personas found in the selected dataset.")
        st.info("Try a different dataset source or upload a non-empty personas file.")
        return None
    return personas
|
|
|
|
def hydrate_chat_state(
    state: ChatState,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    default_prompt_mode: str = "templated",
) -> None:
    """Seed *state* from values persisted in ``st.session_state``.

    Does nothing if ``state["persona_id"]`` is already set. Otherwise the
    persona id is read from *persisted_persona_key* (``None`` when absent) and
    the prompt mode from *persisted_prompt_key*, falling back to
    *default_prompt_mode*.
    """
    if state["persona_id"] is not None:
        # A persona is already set; leave both fields untouched.
        return

    session = st.session_state
    state["persona_id"] = session.get(persisted_persona_key)
    state["prompt_mode"] = session.get(persisted_prompt_key, default_prompt_mode)
|
|
|
|
def render_chat_selection(
    personas: list[PersonaData],
    current_persona_id: str | None,
    current_prompt_mode: str,
    persona_key: str,
    prompt_key: str,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    column_widths: tuple[int, int] = (3, 2),
) -> ChatSelection:
    """Render the persona / prompt-mode controls and persist the choice.

    Delegates widget rendering to ``render_persona_prompt_controls``, then
    stores the chosen persona's ``id`` and the prompt mode in
    ``st.session_state`` under *persisted_persona_key* / *persisted_prompt_key*.

    Returns:
        A ``ChatSelection`` bundling the persona, prompt mode, and the
        ``changed`` flag reported by the controls helper.
    """
    persona, mode, did_change = render_persona_prompt_controls(
        personas,
        current_persona_id,
        current_prompt_mode,
        persona_key,
        prompt_key,
        column_widths=column_widths,
    )

    # Mirror the selection into session state; presumably read back later by
    # hydrate_chat_state, which uses the same kind of persisted keys.
    session = st.session_state
    session[persisted_persona_key] = persona.id
    session[persisted_prompt_key] = mode

    return ChatSelection(persona=persona, prompt_mode=mode, changed=did_change)
|
|
|
|
def model_load_status(model_name: str) -> str:
    """Return a coarse status label for loading *model_name* in this session.

    The label only reflects whether ``mark_model_loaded`` was called for this
    name earlier in the session. It does not inspect whether the model is
    actually resident anywhere.
    """
    # setdefault (rather than get) also creates the tracking set on first use.
    seen_models = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
    if model_name in seen_models:
        return "Using cached model..."
    return "Loading model..."
|
|
|
|
def mark_model_loaded(model_name: str) -> None:
    """Record that the current session has requested *model_name*.

    Afterwards ``model_load_status(model_name)`` reports
    "Using cached model..." for the rest of the session.
    """
    st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set()).add(model_name)
|
|
|
|
def generate_chat_reply_result(
    *,
    model: object,
    messages: list[dict[str, str]],
    remote: bool,
    generation: GenerationConfig,
    on_status: Callable[[str, str, str], None] | None = None,
    on_error: Callable[[Exception], None] | None = None,
    ndif_api_key: str | None = None,
) -> tuple[ChatReply | None, Exception | None]:
    """Call ``generate_chat_reply`` and report failure as a value.

    Generation options come from ``generation.to_generate_kwargs()``. Any
    ``Exception`` raised while building those options or generating the reply
    is passed to *on_error* (if given) instead of propagating.

    Returns:
        ``(reply, None)`` on success, or ``(None, exc)`` on failure.
    """
    try:
        reply = generate_chat_reply(
            model=model,
            messages=messages,
            remote=remote,
            on_status=on_status,
            ndif_api_key=ndif_api_key,
            **generation.to_generate_kwargs(),
        )
    except Exception as failure:
        if on_error is not None:
            on_error(failure)
        return None, failure
    return reply, None
|
|