File size: 3,971 Bytes
db3d901
 
 
 
c607869
db3d901
 
 
 
 
 
 
 
 
c607869
 
 
db3d901
 
 
 
 
 
 
 
b279884
 
 
db3d901
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
db3d901
 
 
 
 
 
b279884
db3d901
ae347c6
db3d901
 
 
 
 
 
 
b279884
ae347c6
db3d901
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING

import streamlit as st

from state import ChatState
from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
from utils.chat import ChatReply, generate_chat_reply
from utils.datasets import load_persona_list
from utils.helpers import session_key

if TYPE_CHECKING:
    from persona_data.synth_persona import PersonaData


@dataclass(frozen=True)
class ChatSelection:
    """Immutable result of rendering the persona / prompt-mode controls.

    Built by ``render_chat_selection`` from the three values returned by
    ``render_persona_prompt_controls``.
    """

    # Persona object picked in the controls; its ``id`` is what gets persisted.
    persona: PersonaData
    # Prompt mode picked in the controls (the module's default is "templated").
    prompt_mode: str
    # Third value returned by the controls helper; presumably True when the
    # selection differs from the previous one -- TODO confirm in tabs.chat_ui.
    changed: bool


# Session-state key under which this module keeps the set of model names that
# mark_model_loaded() has recorded; model_load_status() reads the same set.
_LOADED_MODEL_NAMES_KEY = session_key("chat", "loaded_model_names")


def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
    """Load personas for ``dataset_source`` and report the outcome in the UI.

    The ``personas_file`` / ``qa_file`` arguments forwarded to
    ``load_persona_list`` are whatever session state holds under the
    ``"extract"`` session keys (``None`` when nothing is stored).

    Returns:
        The non-empty persona list on success. ``None`` when loading raised
        (an error and a hint are shown) or when the list came back empty
        (a warning and a hint are shown).
    """
    # Keys are resolved up front; the session-state reads stay inside the
    # guarded section so any failure there is reported like a load failure.
    file_keys = {
        "personas_file": session_key("extract", "personas_file"),
        "qa_file": session_key("extract", "qa_file"),
    }

    try:
        loaded, status_text = load_persona_list(
            dataset_source,
            **{
                argument: st.session_state.get(key)
                for argument, key in file_keys.items()
            },
        )
        st.caption(status_text)
    except Exception as error:
        st.error(f"Could not load data: {error}")
        st.info("Check the selected dataset source or upload both JSONL files.")
        return None

    if loaded:
        return loaded

    st.warning("No personas found in the selected dataset.")
    st.info("Try a different dataset source or upload a non-empty personas file.")
    return None


def hydrate_chat_state(
    state: ChatState,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    default_prompt_mode: str = "templated",
) -> None:
    """Fill an un-initialised chat state from persisted session values.

    Nothing happens when ``state["persona_id"]`` is already set. Otherwise
    ``persona_id`` is read from session state under ``persisted_persona_key``
    (``None`` if absent) and ``prompt_mode`` under ``persisted_prompt_key``,
    falling back to ``default_prompt_mode``.
    """
    if state["persona_id"] is not None:
        # A persona is already selected; leave the state untouched.
        return

    persisted = st.session_state
    state["persona_id"] = persisted.get(persisted_persona_key)
    state["prompt_mode"] = persisted.get(persisted_prompt_key, default_prompt_mode)


def render_chat_selection(
    personas: list[PersonaData],
    current_persona_id: str | None,
    current_prompt_mode: str,
    persona_key: str,
    prompt_key: str,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    column_widths: tuple[int, int] = (3, 2),
) -> ChatSelection:
    """Render the persona / prompt-mode controls and persist the choice.

    Delegates the widgets to ``render_persona_prompt_controls``, then stores
    the chosen persona's ``id`` and the chosen prompt mode in session state
    under ``persisted_persona_key`` and ``persisted_prompt_key``.

    Returns:
        A ``ChatSelection`` holding the three values the controls returned.
    """
    controls_result = render_persona_prompt_controls(
        personas,
        current_persona_id,
        current_prompt_mode,
        persona_key,
        prompt_key,
        column_widths=column_widths,
    )
    chosen_persona, chosen_mode, selection_changed = controls_result

    # Persist on every render, whether or not the selection changed.
    st.session_state[persisted_persona_key] = chosen_persona.id
    st.session_state[persisted_prompt_key] = chosen_mode

    return ChatSelection(
        persona=chosen_persona,
        prompt_mode=chosen_mode,
        changed=selection_changed,
    )


def model_load_status(model_name: str) -> str:
    """Pick the spinner label to show while ``model_name`` is being obtained.

    Returns "Using cached model..." when the name is already in this
    session's set of recorded model names, otherwise "Loading model...".
    The set is created in session state if it does not exist yet.
    """
    seen_models = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
    if model_name in seen_models:
        return "Using cached model..."
    return "Loading model..."


def mark_model_loaded(model_name: str) -> None:
    """Add ``model_name`` to this session's set of recorded model names.

    The set is created in session state on first use.
    """
    st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set()).add(model_name)


def generate_chat_reply_result(
    *,
    model: object,
    messages: list[dict[str, str]],
    remote: bool,
    generation: GenerationConfig,
    on_status: Callable[[str, str, str], None] | None = None,
    on_error: Callable[[Exception], None] | None = None,
    ndif_api_key: str | None = None,
) -> tuple[ChatReply | None, Exception | None]:
    """Call ``generate_chat_reply`` and return the outcome instead of raising.

    The generation settings are expanded from
    ``generation.to_generate_kwargs()`` and passed as keyword arguments.

    Returns:
        ``(reply, None)`` on success, or ``(None, exc)`` when anything in the
        call raised an ``Exception``. In the failure case ``on_error`` is
        invoked with the exception first, if it was provided.
    """
    try:
        # The kwargs expansion stays inside the guarded call so a failure in
        # to_generate_kwargs() is reported the same way as a generation error.
        reply = generate_chat_reply(
            model=model,
            messages=messages,
            remote=remote,
            on_status=on_status,
            ndif_api_key=ndif_api_key,
            **generation.to_generate_kwargs(),
        )
    except Exception as error:
        if on_error is not None:
            on_error(error)
        return None, error

    return reply, None