File size: 9,523 Bytes
db3d901
 
c607869
db3d901
a89a7f1
 
db3d901
 
 
 
 
 
 
 
 
 
 
b279884
 
db3d901
 
77c2d62
 
9ba2da4
77c2d62
 
 
db3d901
a89a7f1
b279884
ae347c6
a89a7f1
c607869
 
 
db3d901
 
 
 
 
eaeaa68
a89a7f1
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db3d901
9ba2da4
 
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba2da4
 
b279884
9ba2da4
b279884
 
db3d901
 
9ba2da4
 
 
db3d901
 
 
 
 
 
b279884
db3d901
ae347c6
db3d901
 
b279884
9ba2da4
 
 
db3d901
b279884
db3d901
9ba2da4
b279884
9ba2da4
 
 
 
 
 
 
 
 
 
 
db3d901
 
 
 
 
 
a89a7f1
db3d901
9ba2da4
 
 
 
 
 
 
dc186e4
 
9ba2da4
 
99c28ab
 
a9950fb
a89a7f1
 
 
 
 
77c2d62
9ba2da4
a89a7f1
 
 
d8ae160
 
a89a7f1
 
f4259c0
 
 
 
 
eb41f91
a89a7f1
db3d901
 
 
 
 
 
 
 
 
 
 
 
 
 
a9950fb
 
 
 
 
 
 
 
a89a7f1
a9950fb
 
a89a7f1
 
 
 
 
 
 
a9950fb
a89a7f1
 
 
 
 
eb41f91
93d5dc5
eb41f91
 
 
12cdb17
 
 
 
 
 
 
eb41f91
a89a7f1
d8ae160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77c2d62
93d5dc5
a9950fb
 
 
 
 
 
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
a89a7f1
a9950fb
a89a7f1
 
 
77c2d62
a89a7f1
 
db3d901
 
 
 
77c2d62
a89a7f1
 
9ba2da4
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
from __future__ import annotations

from typing import TYPE_CHECKING, cast

import streamlit as st

from state import (
    ChatState,
    PendingChatAction,
    chat_session_key,
    get_chat_state,
    reset_chat_context_state,
)
from tabs.chat_shared import (
    generate_chat_reply_result,
    hydrate_chat_state,
    load_chat_personas,
    mark_model_loaded,
    model_load_status,
    render_chat_selection,
)
from tabs.chat_ui import (
    GenerationConfig,
    render_advanced_settings,
    render_chat_window,
    render_system_prompt,
)
from utils.chat import build_chat_messages, resolve_system_prompt
from utils.chat_export import save_chat_export
from utils.helpers import format_ndif_status, session_key, widget_key
from utils.runtime import cached_model, session_ndif_api_key

if TYPE_CHECKING:
    from persona_data.synth_persona import PersonaData

# Session-state keys (namespaced under "chat") for the most recent chat
# selections. The persona / prompt-mode keys are handed to hydrate_chat_state
# and render_chat_selection; the compare / probe / token-contrast keys go to
# render_advanced_settings. Those helpers presumably persist and restore the
# user's choices across reruns — confirm in tabs.chat_shared / tabs.chat_ui.
_LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
_LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
_LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
_LAST_PROBE_ENABLED_KEY = session_key("chat", "last_probe_enabled")
_LAST_TOKEN_CONTRAST_KEY = session_key("chat", "last_token_contrast")


def _render_single_chat_footer(
    *,
    model_name: str,
    dataset_source: str,
    persona: PersonaData,
    prompt_mode: str,
    system_prompt: str | None,
    chat_state: ChatState,
    generation: GenerationConfig,
    export_key: str,
    reset_key: str,
    on_reset,
) -> None:
    """Draw the icon-only export and reset buttons under the single chat.

    Export writes the current transcript plus its context (model, dataset,
    persona, prompt mode, system prompt, generation settings) via
    ``save_chat_export`` and shows a toast. Reset invokes ``on_reset`` (a
    zero-argument callable) and then reruns the script.

    Args:
        model_name: Model name recorded in the export.
        dataset_source: Dataset source recorded in the export.
        persona: Active persona; its ``id`` and optional ``name`` are exported.
        prompt_mode: System-prompt mode recorded in the export.
        system_prompt: System prompt text recorded in the export.
        chat_state: Chat state whose ``messages`` are exported.
        generation: Generation settings, exported via ``to_export_dict()``.
        export_key: Widget key for the export button.
        reset_key: Widget key for the reset button.
        on_reset: Callback run when the reset button is clicked.
    """
    with st.container():
        # Two narrow icon columns; the wide third column just absorbs space.
        export_col, reset_col, _ = st.columns([0.5, 0.5, 10], gap="xsmall")

        with export_col:
            export_clicked = st.button(
                "", icon=":material/download:", key=export_key, help="Export chat"
            )
            if export_clicked:
                persona_name = getattr(persona, "name", None)
                save_chat_export(
                    model_name=model_name,
                    dataset_source=dataset_source,
                    persona_id=persona.id,
                    persona_name=persona_name,
                    prompt_mode=prompt_mode,
                    system_prompt=system_prompt,
                    messages=chat_state["messages"],
                    generation=generation.to_export_dict(),
                )
                st.toast("Exported", icon=":material/check:")

        with reset_col:
            reset_clicked = st.button(
                "", icon=":material/delete_sweep:", key=reset_key, help="Reset chat"
            )
            if reset_clicked:
                on_reset()
                st.rerun()


def _handle_single_chat_generation(
    *,
    remote: bool,
    model_name: str,
    chat_state: ChatState,
    active_system_prompt: str | None,
    generation: GenerationConfig,
    pending_action: PendingChatAction,
    chat_log,
) -> None:
    """Generate the assistant reply for the action queued by the previous run.

    Loads (or reuses) the model via ``cached_model``, runs
    ``generate_chat_reply_result`` on the current history and, on success,
    appends the reply to ``chat_state["messages"]`` and reruns the script.
    Progress is shown in a transient status caption that is always cleared,
    even if model loading or generation raises.

    Args:
        remote: When true, generation status is reported through NDIF job
            updates; otherwise a "Generating locally..." phase is shown.
        model_name: Model to load through ``cached_model``.
        chat_state: Chat state whose ``messages`` list is read and updated.
        active_system_prompt: System prompt passed to ``build_chat_messages``.
        generation: Generation settings forwarded to the reply helper.
        pending_action: Action queued by the previous run. For
            ``"new_user_prompt"`` the just-added user turn is rolled back when
            generation fails, so the user can resend it.
        chat_log: Container into which generation errors are written.

    Raises:
        Any exception escaping model loading or generation is re-raised after
        the status caption is cleared and the pending user turn is rolled back.
    """
    messages = build_chat_messages(active_system_prompt, chat_state["messages"])
    status_box = st.empty()

    def _show_phase(text: str) -> None:
        status_box.caption(text)

    def _show_ndif_status(job_id: str, status_name: str, description: str) -> None:
        status_box.caption(
            format_ndif_status(
                job_id,
                status_name,
                description,
                completed_detail="Downloading result...",
            )
        )

    def _show_error(exc: Exception) -> None:
        with chat_log:
            st.error(f"Could not generate a reply: {exc}")
            st.info("Try a shorter prompt, reset the chat, or switch personas.")

    def _rollback_user_prompt() -> None:
        # Drop the unanswered user turn; otherwise the next prompt would leave
        # two consecutive user messages in the history.
        history = chat_state["messages"]
        if (
            pending_action == "new_user_prompt"
            and history
            and history[-1].get("role") == "user"
        ):
            history.pop()

    try:
        with st.spinner("Generating reply..."):
            _show_phase(model_load_status(model_name))
            model = cached_model(model_name=model_name)
            mark_model_loaded(model_name)
            _show_phase("Submitting to NDIF..." if remote else "Generating locally...")

            reply, error = generate_chat_reply_result(
                model=model,
                messages=messages,
                remote=remote,
                generation=generation,
                on_status=_show_ndif_status if remote else None,
                on_error=_show_error,
                ndif_api_key=session_ndif_api_key(),
            )
    except Exception:
        # Unexpected failure (e.g. model loading): keep the history consistent,
        # then let Streamlit surface the exception as before.
        _rollback_user_prompt()
        raise
    finally:
        # Never leave a stale "loading"/"generating" caption on screen.
        status_box.empty()

    if error is not None:
        # The error itself was already rendered into chat_log by _show_error.
        _rollback_user_prompt()
        return
    if reply is None:
        return

    chat_state["messages"].append({"role": "assistant", "content": reply.text})
    st.rerun()


def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
    """Render the chat tab.

    Draws the persona / system-prompt selectors, the optional probe
    inspector, the chat transcript with its export/reset footer, and the chat
    input. It then runs any generation queued by the previous script run.
    When the advanced settings enable compare mode, rendering is delegated to
    ``tabs.compare_chat.render_compare_mode`` instead.

    Args:
        remote: Forwarded to the generation and probe helpers; selects
            NDIF-style status reporting during generation.
        model_name: Model identifier; together with ``dataset_source`` it
            scopes the chat state and every widget key on this tab.
        dataset_source: Source the personas are loaded from.
    """

    st.title("Chat")
    st.caption("Chat with a persona, optionally side-by-side or with token contrast.")

    context_key = chat_session_key(model_name, dataset_source)
    chat_state = get_chat_state(model_name, dataset_source)
    hydrate_chat_state(
        chat_state,
        persisted_persona_key=_LAST_PERSONA_ID_KEY,
        persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
    )

    personas = load_chat_personas(dataset_source)
    if personas is None:
        return

    generation, tools = render_advanced_settings(
        context_key,
        remote,
        last_compare_mode_key=_LAST_COMPARE_MODE_KEY,
        last_probe_enabled_key=_LAST_PROBE_ENABLED_KEY,
        last_token_contrast_key=_LAST_TOKEN_CONTRAST_KEY,
    )
    if tools.compare_mode:
        from tabs.compare_chat import render_compare_mode

        render_compare_mode(
            remote,
            model_name,
            context_key,
            dataset_source,
            personas,
            generation,
            contrast_enabled=tools.token_contrast,
        )
        return

    # Reserve a slot above the selectors; it is filled later, once the active
    # system prompt is known.
    probe_container = st.container()

    persona_select_key = widget_key(context_key, "persona_select")
    prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
    prompt_key = widget_key(context_key, "custom_system_prompt")
    chat_input_key = widget_key(context_key, "chat_input")
    pending_key = widget_key(context_key, "pending_prompt")
    export_key = widget_key(context_key, "export_chat")
    reset_key = widget_key(context_key, "reset")
    edit_key = widget_key(context_key, "edit_idx")

    selection = render_chat_selection(
        personas,
        chat_state["persona_id"],
        chat_state["prompt_mode"],
        persona_select_key,
        prompt_mode_select_key,
        persisted_persona_key=_LAST_PERSONA_ID_KEY,
        persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
        column_widths=(2, 1),
    )
    selected_persona = selection.persona
    prompt_mode = selection.prompt_mode
    changed_context = selection.changed

    def _reset_active_chat_context() -> None:
        # Full reset: also clears the custom-prompt widget and leaves edit mode.
        reset_chat_context_state(
            chat_state,
            selected_persona.id,
            prompt_mode,
            chat_input_key,
            prompt_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    def _reset_after_prompt_save() -> None:
        # prompt_key is deliberately not passed so the freshly saved custom
        # prompt survives. Edit mode is still left, as in every other reset:
        # the chat history is discarded, so a leftover edit index would refer
        # to a message that no longer exists.
        reset_chat_context_state(
            chat_state,
            selected_persona.id,
            prompt_mode,
            chat_input_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    active_system_prompt = resolve_system_prompt(
        persona=selected_persona,
        mode=prompt_mode,
    )

    if changed_context:
        had_history = bool(chat_state["messages"])
        _reset_active_chat_context()
        if had_history:
            st.info("Chat history reset because the persona or system prompt changed.")

    chat_log = st.container()

    with chat_log:
        active_system_prompt = render_system_prompt(
            prompt_key,
            prompt_mode,
            active_system_prompt,
            on_save=_reset_after_prompt_save,
        )

    with probe_container:
        if tools.probe_enabled:
            from tabs.probe_ui import render_probe_inspector

            render_probe_inspector(
                context_key=context_key,
                model_name=model_name,
                remote=remote,
                active_system_prompt=active_system_prompt,
                chat_state=chat_state,
                enabled=True,
            )
        else:
            from utils.probe_overlay import clear_overlays

            clear_overlays(chat_state["messages"])

    render_chat_window(
        chat_log=chat_log,
        messages=chat_state["messages"],
        edit_key=edit_key,
        pending_key=pending_key,
    )

    _render_single_chat_footer(
        model_name=model_name,
        dataset_source=dataset_source,
        persona=selected_persona,
        prompt_mode=prompt_mode,
        system_prompt=active_system_prompt,
        chat_state=chat_state,
        generation=generation,
        export_key=export_key,
        reset_key=reset_key,
        on_reset=_reset_active_chat_context,
    )

    user_prompt = st.chat_input("Ask something...", key=chat_input_key)

    if user_prompt:
        # Record the turn and queue generation for the next run so the new
        # message is rendered before the (slow) reply is produced.
        chat_state["messages"].append({"role": "user", "content": user_prompt})
        st.session_state[pending_key] = "new_user_prompt"
        st.rerun()

    # String form: cast's type argument is evaluated at runtime (it is not an
    # annotation), so keep the union lazy like the rest of the file's hints.
    pending_action = cast(
        "PendingChatAction | None",
        st.session_state.pop(pending_key, None),
    )
    if not pending_action:
        return

    _handle_single_chat_generation(
        remote=remote,
        model_name=model_name,
        chat_state=chat_state,
        active_system_prompt=active_system_prompt,
        generation=generation,
        pending_action=pending_action,
        chat_log=chat_log,
    )