File size: 3,971 Bytes
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING
import streamlit as st
from state import ChatState
from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
from utils.chat import ChatReply, generate_chat_reply
from utils.datasets import load_persona_list
from utils.helpers import session_key
if TYPE_CHECKING:
from persona_data.synth_persona import PersonaData
@dataclass(frozen=True)
class ChatSelection:
    """Immutable result of rendering the persona and prompt-mode controls."""

    # Persona picked in the persona control.
    persona: PersonaData
    # Prompt mode picked in the prompt control.
    prompt_mode: str
    # Flag reported by the controls helper; presumably True when the selection
    # differs from the previous one -- TODO confirm against tabs.chat_ui.
    changed: bool
# Session-state key holding the set of model names already requested in this session.
_LOADED_MODEL_NAMES_KEY = session_key("chat", "loaded_model_names")
def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
    """Load the personas for ``dataset_source`` and report the outcome in the UI.

    The uploaded-file entries stored under the "extract" session keys are
    forwarded to ``load_persona_list``. On success the loader's status text is
    shown as a caption.

    Returns:
        The loaded personas, or ``None`` when loading raised or the result was
        empty (an error or warning plus a hint is rendered in either case).
    """
    key_for_personas = session_key("extract", "personas_file")
    key_for_qa = session_key("extract", "qa_file")

    try:
        uploads = {
            "personas_file": st.session_state.get(key_for_personas),
            "qa_file": st.session_state.get(key_for_qa),
        }
        loaded, status_text = load_persona_list(dataset_source, **uploads)
        st.caption(status_text)
    except Exception as exc:
        # UI boundary: surface any failure to the user instead of crashing the tab.
        st.error(f"Could not load data: {exc}")
        st.info("Check the selected dataset source or upload both JSONL files.")
        return None

    if loaded:
        return loaded

    st.warning("No personas found in the selected dataset.")
    st.info("Try a different dataset source or upload a non-empty personas file.")
    return None
def hydrate_chat_state(
    state: ChatState,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    default_prompt_mode: str = "templated",
) -> None:
    """Fill an unset chat state from values saved in Streamlit session state.

    Args:
        state: Chat state mapping; its "persona_id" and "prompt_mode" entries
            are written in place.
        persisted_persona_key: Session-state key the persona id is read from.
        persisted_prompt_key: Session-state key the prompt mode is read from.
        default_prompt_mode: Prompt mode used when nothing is stored under
            ``persisted_prompt_key``.
    """
    # NOTE(review): the indentation of this block was lost in the copy under
    # review; both assignments are assumed to belong to the `if` body -- confirm.
    if state["persona_id"] is None:
        # May still be None when no persona id has been persisted yet.
        state["persona_id"] = st.session_state.get(persisted_persona_key)
        state["prompt_mode"] = st.session_state.get(
            persisted_prompt_key,
            default_prompt_mode,
        )
def render_chat_selection(
    personas: list[PersonaData],
    current_persona_id: str | None,
    current_prompt_mode: str,
    persona_key: str,
    prompt_key: str,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    column_widths: tuple[int, int] = (3, 2),
) -> ChatSelection:
    """Render the persona/prompt controls and persist what was chosen.

    The positional arguments and ``column_widths`` are passed straight to
    ``render_persona_prompt_controls``. The chosen persona's ``id`` and the
    chosen prompt mode are then stored in Streamlit session state under
    ``persisted_persona_key`` and ``persisted_prompt_key``.

    Returns:
        A ``ChatSelection`` carrying the three values reported by the controls.
    """
    controls_result = render_persona_prompt_controls(
        personas,
        current_persona_id,
        current_prompt_mode,
        persona_key,
        prompt_key,
        column_widths=column_widths,
    )
    chosen_persona, chosen_mode, was_changed = controls_result

    # Persist the choice under the caller-supplied keys for later reads.
    st.session_state[persisted_persona_key] = chosen_persona.id
    st.session_state[persisted_prompt_key] = chosen_mode

    return ChatSelection(
        persona=chosen_persona,
        prompt_mode=chosen_mode,
        changed=was_changed,
    )
def model_load_status(model_name: str) -> str:
    """Return the status label to show while ``model_name`` is being obtained.

    The label only reflects whether this session has recorded the name before;
    the tracking set is created in session state on first use.
    """
    seen_names = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
    if model_name in seen_names:
        return "Using cached model..."
    return "Loading model..."
def mark_model_loaded(model_name: str) -> None:
    """Record ``model_name`` in this session's set of requested models."""
    # setdefault creates the tracking set the first time any model is recorded.
    st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set()).add(model_name)
def generate_chat_reply_result(
    *,
    model: object,
    messages: list[dict[str, str]],
    remote: bool,
    generation: GenerationConfig,
    on_status: Callable[[str, str, str], None] | None = None,
    on_error: Callable[[Exception], None] | None = None,
    ndif_api_key: str | None = None,
) -> tuple[ChatReply | None, Exception | None]:
    """Call ``generate_chat_reply`` and return its outcome as a pair instead of raising.

    Args:
        model: Passed through to ``generate_chat_reply``.
        messages: Passed through to ``generate_chat_reply``.
        remote: Passed through to ``generate_chat_reply``.
        generation: Its ``to_generate_kwargs()`` result is expanded into the call.
        on_status: Optional callback passed through to ``generate_chat_reply``.
        on_error: Optional callback invoked with the exception when the call fails.
        ndif_api_key: Passed through to ``generate_chat_reply``.

    Returns:
        ``(reply, None)`` on success, ``(None, exc)`` when anything in the call
        raised an ``Exception``.
    """
    try:
        # Argument evaluation (including to_generate_kwargs) stays inside the
        # try so failures there are reported the same way as generation errors.
        reply = generate_chat_reply(
            model=model,
            messages=messages,
            remote=remote,
            on_status=on_status,
            ndif_api_key=ndif_api_key,
            **generation.to_generate_kwargs(),
        )
    except Exception as failure:
        if on_error is not None:
            on_error(failure)
        return None, failure

    return reply, None
|