File size: 15,643 Bytes
c607869
 
9ba2da4
c607869
6b81e1e
a9950fb
 
9ba2da4
db3d901
 
 
 
77c2d62
db3d901
93d5dc5
a9950fb
b279884
ae347c6
a9950fb
77c2d62
 
93d5dc5
 
 
a9950fb
 
c607869
 
 
 
a9950fb
9ba2da4
 
6b81e1e
9ba2da4
6b81e1e
 
 
 
 
 
 
 
9ba2da4
 
 
 
 
 
 
6b81e1e
 
 
 
 
 
 
 
 
a9950fb
 
9ba2da4
a9950fb
 
9ba2da4
a9950fb
9ba2da4
 
 
 
 
 
 
db3d901
 
 
 
 
 
 
9ba2da4
db3d901
9ba2da4
 
 
 
 
db3d901
 
a9950fb
db3d901
 
 
9ba2da4
 
 
 
 
 
 
 
a9950fb
9ba2da4
93d5dc5
9ba2da4
 
 
 
a9950fb
9ba2da4
 
 
 
 
 
12cdb17
 
 
 
 
 
a9950fb
 
9ba2da4
 
 
 
 
 
 
 
 
 
a9950fb
 
9ba2da4
 
 
 
 
 
 
 
 
b279884
9ba2da4
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db3d901
 
 
 
 
b279884
ae347c6
db3d901
 
b279884
9ba2da4
a9950fb
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9950fb
9ba2da4
 
 
 
a9950fb
9ba2da4
 
 
a9950fb
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
 
 
ae347c6
9ba2da4
 
 
 
 
 
 
a9950fb
 
9ba2da4
 
 
 
 
6b81e1e
93d5dc5
6b81e1e
 
 
 
a9950fb
 
 
 
9ba2da4
 
 
 
 
 
 
 
 
 
d8ae160
93d5dc5
 
 
9ba2da4
 
a9950fb
ae347c6
a9950fb
 
 
 
 
 
 
6b81e1e
93d5dc5
a9950fb
 
6b81e1e
 
 
 
 
9ba2da4
6b81e1e
a9950fb
 
 
93d5dc5
 
 
 
 
a9950fb
 
 
 
93d5dc5
a9950fb
6b81e1e
 
 
 
 
 
 
 
a9950fb
 
 
 
 
6b81e1e
 
93d5dc5
a9950fb
 
 
9ba2da4
6b81e1e
 
 
93d5dc5
a9950fb
 
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
a9950fb
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
ae347c6
9ba2da4
 
 
 
 
 
 
a9950fb
 
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9950fb
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9950fb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import streamlit as st

from state import ChatState, default_chat_state, reset_chat_context_state
from tabs.chat_shared import (
    generate_chat_reply_result,
    hydrate_chat_state,
    render_chat_selection,
)
from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
from utils.chat_export import save_chat_export
from utils.contrast import compute_contrast, compute_contrast_pair
from utils.helpers import format_ndif_status, persona_label, session_key, widget_key
from utils.runtime import cached_model, session_ndif_api_key

from .chat_ui import (
    GenerationConfig,
    render_chat_message,
    render_chat_window,
    render_system_prompt,
)

if TYPE_CHECKING:
    from nnterp import StandardizedTransformer
    from persona_data.synth_persona import PersonaData


@dataclass(frozen=True)
class ComparePanel:
    """Snapshot of one rendered column in the side-by-side compare view.

    Built by ``_render_compare_panel`` and consumed by the generation, contrast,
    history, reset, and export helpers. The dataclass itself is frozen, but
    ``state`` is the object stored in ``st.session_state``, so its contents
    (e.g. ``state["messages"]``) are mutated in place by those helpers.
    """

    # Column identifier: "left" or "right"; also used in widget keys and labels.
    side: str
    # Chat state held in st.session_state (reads "messages", "persona_id", "prompt_mode").
    state: ChatState
    # Streamlit container the transcript and per-panel errors are written into.
    log: Any
    # System prompt returned by render_system_prompt for this panel (may be None).
    prompt: str | None
    # Persona currently selected for this panel.
    persona: PersonaData
    # Widget key for this panel's custom system prompt input.
    prompt_key: str
    # Session-state key ("edit_idx") used by render_chat_window for message editing; cleared on reset.
    edit_key: str
    # Session-state flag key; when truthy, the next run regenerates this panel's reply.
    pending_key: str


def _get_compare_state(context_key: str, side: str) -> tuple[str, ChatState]:
    """Return this side's session-state key and chat state, seeding a default once.

    The key is derived from ``context_key`` and ``side``, so each context gets
    an independent left/right pair of chat states.
    """
    key = widget_key(context_key, f"cmp_{side}")
    session = st.session_state
    if key not in session:
        # First render of this side in this context: start from a fresh state.
        session[key] = default_chat_state()
    chat_state = session[key]
    return key, chat_state


def _reset_compare_panel(panel: ComparePanel) -> None:
    """Reset one panel's chat context, keeping its current persona and prompt mode.

    Calls ``reset_chat_context_state`` with the panel's prompt and pending-regen
    keys, then discards any in-progress message edit for the panel.
    """
    chat_state = panel.state
    current_mode = chat_state["prompt_mode"]
    reset_chat_context_state(
        chat_state,
        panel.persona.id,
        current_mode,
        panel.prompt_key,
        panel.pending_key,
    )
    st.session_state.pop(panel.edit_key, None)


def _render_compare_panel(
    *,
    context_key: str,
    side: str,
    personas: list[PersonaData],
) -> ComparePanel:
    """Render the persona/prompt selectors and system prompt for one compare column.

    Args:
        context_key: Base key namespacing this compare view's session state.
        side: Column identifier ("left" or "right"), used in widget and session keys.
        personas: Personas offered in the selector.

    Returns:
        A ``ComparePanel`` for the column. Its ``log`` container already holds
        the system-prompt widget, and the transcript is drawn into it later.
    """
    panel_key, state = _get_compare_state(context_key, side)

    prompt_key = widget_key(panel_key, "custom_prompt")
    edit_key = widget_key(panel_key, "edit_idx")
    pending_key = widget_key(panel_key, "pending_regen")

    # Keys under which this side's last persona / prompt choice is persisted and
    # read back by the hydrate and selection helpers.
    persist_persona_key = session_key("chat", f"last_cmp_{side}_persona")
    persist_prompt_key = session_key("chat", f"last_cmp_{side}_prompt")
    hydrate_chat_state(
        state,
        persisted_persona_key=persist_persona_key,
        persisted_prompt_key=persist_prompt_key,
    )

    selection = render_chat_selection(
        personas,
        state["persona_id"],
        state["prompt_mode"],
        widget_key(panel_key, "persona"),
        widget_key(panel_key, "prompt_mode"),
        persisted_persona_key=persist_persona_key,
        persisted_prompt_key=persist_prompt_key,
    )
    selected_persona = selection.persona
    prompt_mode = selection.prompt_mode

    if selection.changed:
        # Persona or prompt mode changed: reset this panel's chat context and
        # drop any in-progress edit.
        reset_chat_context_state(
            state,
            selected_persona.id,
            prompt_mode,
            prompt_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    active_system_prompt = resolve_system_prompt(
        persona=selected_persona,
        mode=prompt_mode,
    )

    def _on_prompt_saved() -> None:
        # NOTE(review): unlike the other reset calls in this module, prompt_key is
        # not passed here, presumably so the just-saved custom prompt survives.
        # Confirm against reset_chat_context_state's signature.
        reset_chat_context_state(
            state,
            selected_persona.id,
            prompt_mode,
            pending_key,
        )
        # Every other reset path in this module also clears the edit index.
        # Without this, a stale edit index survives the prompt-save reset.
        st.session_state.pop(edit_key, None)

    chat_log = st.container()
    with chat_log:
        active_system_prompt = render_system_prompt(
            prompt_key,
            prompt_mode,
            active_system_prompt,
            on_save=_on_prompt_saved,
        )

    return ComparePanel(
        side=side,
        state=state,
        log=chat_log,
        prompt=active_system_prompt,
        persona=selected_persona,
        prompt_key=prompt_key,
        edit_key=edit_key,
        pending_key=pending_key,
    )


def _generate_panels(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
    generation: GenerationConfig,
    spinner_label: str,
) -> list[ChatReply | Exception]:
    """Generate one assistant reply per panel, one panel at a time.

    Each panel's history is combined with its own system prompt via
    ``build_chat_messages`` and passed to ``generate_chat_reply_result``. A
    single status line shows which panel is in progress and, when ``remote`` is
    set, NDIF job updates.

    Args:
        model: Model used for generation.
        remote: Whether to run through NDIF instead of locally.
        panels: Panels to generate for, in order.
        generation: Sampling/generation settings.
        spinner_label: Text shown on the spinner while generating.

    Returns:
        A list aligned with ``panels``. Each entry is the ``ChatReply`` on success,
        or the error reported by ``generate_chat_reply_result`` on failure.
    """
    results: list[ChatReply | Exception] = []
    status_box = st.empty()
    try:
        with st.spinner(spinner_label):
            for panel in panels:
                panel_label = panel.side.title()
                status_box.caption(
                    f"{panel_label}: {'Submitting to NDIF...' if remote else 'Generating locally...'}"
                )

                # ``label`` is bound as a default argument so each callback keeps
                # this iteration's panel label (avoids a late-binding closure).
                def _show_ndif_status(
                    job_id: str,
                    status_name: str,
                    description: str,
                    *,
                    label: str = panel_label,
                ) -> None:
                    status_box.caption(
                        format_ndif_status(
                            job_id,
                            status_name,
                            description,
                            prefix=label,
                            completed_detail="Downloading result...",
                        )
                    )

                reply, error = generate_chat_reply_result(
                    model=model,
                    messages=build_chat_messages(panel.prompt, panel.state["messages"]),
                    remote=remote,
                    generation=generation,
                    on_status=_show_ndif_status if remote else None,
                    ndif_api_key=session_ndif_api_key(),
                )
                results.append(reply if error is None else error)
    finally:
        # Clear the status line even if generation raises unexpectedly, so a
        # stale progress caption is not left on screen.
        status_box.empty()
    return results


def _apply_panel_results(
    *,
    panels: list[ComparePanel],
    results: list[ChatReply | Exception],
    rollback_user_on_error: bool,
) -> list[ChatReply | None]:
    valid_results: list[ChatReply | None] = []
    for panel, result in zip(panels, results, strict=True):
        if isinstance(result, Exception):
            with panel.log:
                st.error(f"Generation failed: {result}")
            if rollback_user_on_error and panel.state["messages"]:
                panel.state["messages"].pop()
            valid_results.append(None)
            continue

        panel.state["messages"].append({"role": "assistant", "content": result.text})
        valid_results.append(result)
    return valid_results


def _pending_contrast_edits(panels: list[ComparePanel]) -> list[tuple[int, int]]:
    return [
        (panel_idx, msg_idx)
        for panel_idx, panel in enumerate(panels)
        for msg_idx, msg in enumerate(panel.state["messages"])
        if msg.get("_needs_contrast") and msg.get("role") == "assistant"
    ]


def _recompute_pending_contrast(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
) -> bool:
    """Recompute token contrast for assistant messages flagged ``_needs_contrast``.

    For a flagged message at index ``i`` in either panel, the message text is
    re-tokenized (no special tokens) and scored by ``compute_contrast``. Context
    A is the left panel's ``messages[:i]`` and context B is the right panel's.
    The flag is always cleared.

    No fresh contrast may be produced: the other panel has no message at ``i``,
    ``compute_contrast`` returns ``None``, or it raises. In that case any
    existing ``_contrast`` is removed rather than left attached, because it
    predates the change that flagged the message.

    Returns:
        ``True`` if any flagged message was processed (the caller reruns),
        ``False`` if there was nothing to do.
    """
    pending_edits = _pending_contrast_edits(panels)
    if not pending_edits:
        return False

    left, right = panels
    label_a = persona_label(left.persona)
    label_b = persona_label(right.persona)
    with st.spinner("Recomputing token contrast..."):
        for panel_idx, msg_idx in pending_edits:
            panel = panels[panel_idx]
            msg = panel.state["messages"][msg_idx]
            if msg_idx >= len(left.state["messages"]) or msg_idx >= len(
                right.state["messages"]
            ):
                # The other side has no turn at this position, so no contrast
                # can be computed; don't keep an outdated one either.
                msg.pop("_contrast", None)
                msg.pop("_needs_contrast", None)
                continue

            context_a = build_chat_messages(
                left.prompt,
                left.state["messages"][:msg_idx],
            )
            context_b = build_chat_messages(
                right.prompt,
                right.state["messages"][:msg_idx],
            )
            contrast = None
            try:
                response_ids = model.tokenizer(
                    msg["content"],
                    add_special_tokens=False,
                    return_tensors="pt",
                ).input_ids[0]
                contrast = compute_contrast(
                    model=model,
                    context_a=context_a,
                    context_b=context_b,
                    response_ids=response_ids,
                    label_a=label_a,
                    label_b=label_b,
                    remote=remote,
                    ndif_api_key=session_ndif_api_key(),
                )
            except Exception as exc:
                st.warning(f"Token contrast recompute failed: {exc}")

            if contrast is None:
                # Only a fresh result may stay attached to a re-flagged message.
                msg.pop("_contrast", None)
            else:
                msg["_contrast"] = contrast
            msg.pop("_needs_contrast", None)
    return True


def _render_compare_history(
    *,
    panels: list[ComparePanel],
    contrast_enabled: bool,
) -> None:
    """Draw each panel's stored transcript into that panel's log container.

    Every panel is rendered with the same options: contrast display follows
    ``contrast_enabled`` and ``edit_column_ratio`` is ``(10, 1)``.
    """
    shared_options = {
        "show_contrast": contrast_enabled,
        "edit_column_ratio": (10, 1),
    }
    for panel in panels:
        render_chat_window(
            chat_log=panel.log,
            messages=panel.state["messages"],
            edit_key=panel.edit_key,
            pending_key=panel.pending_key,
            **shared_options,
        )


def _render_compare_footer(
    *,
    context_key: str,
    model_name: str,
    dataset_source: str,
    panels: list[ComparePanel],
    generation: GenerationConfig,
) -> None:
    """Render the compare footer: an export button and a reset popover.

    Export calls ``save_chat_export`` once per panel and then shows a toast.
    The reset popover offers one button per side plus "Reset both"; any reset
    closes the menu and reruns the script.
    """
    # Popovers don't close on their own when a button inside them is clicked.
    # Each reset bumps this nonce, which changes the popover's widget key so
    # Streamlit re-mounts it closed on the next run.
    nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
    if nonce_key not in st.session_state:
        st.session_state[nonce_key] = 0

    def _close_menu_and_rerun() -> None:
        st.session_state[nonce_key] += 1
        st.rerun()

    with st.container():
        export_col, reset_col, _ = st.columns([1, 1.25, 20], gap="xsmall")

        with export_col:
            export_clicked = st.button(
                "",
                icon=":material/download:",
                key=widget_key(context_key, "cmp_export"),
                help="Export both chats",
            )
            if export_clicked:
                for panel in panels:
                    save_chat_export(
                        model_name=model_name,
                        dataset_source=dataset_source,
                        persona_id=panel.persona.id,
                        persona_name=getattr(panel.persona, "name", None),
                        prompt_mode=panel.state["prompt_mode"],
                        system_prompt=panel.prompt,
                        messages=panel.state["messages"],
                        generation=generation.to_export_dict(),
                        panel_label=panel.side,
                    )
                st.toast("Exported", icon=":material/check:")

        with reset_col:
            menu_key = widget_key(
                context_key,
                "cmp_reset_menu",
                str(st.session_state[nonce_key]),
            )
            with st.popover(
                "",
                icon=":material/delete_sweep:",
                help="Reset chat",
                key=menu_key,
            ):
                for panel in panels:
                    side_button_key = widget_key(context_key, f"cmp_reset_{panel.side}")
                    if st.button(f"Reset {panel.side}", key=side_button_key):
                        _reset_compare_panel(panel)
                        _close_menu_and_rerun()

                reset_both = st.button(
                    "Reset both",
                    key=widget_key(context_key, "cmp_reset_both"),
                    type="primary",
                )
                if reset_both:
                    for panel in panels:
                        _reset_compare_panel(panel)
                    _close_menu_and_rerun()


def _append_user_prompt(panels: list[ComparePanel], user_prompt: str) -> None:
    """Append the shared user turn to every panel and render it immediately.

    Each panel stores its own message dict. A separate, equal dict is handed to
    ``render_chat_message`` for display.
    """
    for target in panels:
        stored_turn = {"role": "user", "content": user_prompt}
        target.state["messages"].append(stored_turn)
        with target.log:
            render_chat_message({"role": "user", "content": user_prompt})


def _compute_new_reply_contrast(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
    pre_gen_contexts: list[list[dict[str, str]]],
    results: list[ChatReply | None],
) -> None:
    if len(results) != 2 or any(
        result is None or result.generated_ids is None for result in results
    ):
        return

    left, right = panels
    with st.spinner("Computing token contrast..."):
        try:
            left_contrast, right_contrast = compute_contrast_pair(
                model=model,
                context_a=pre_gen_contexts[0],
                context_b=pre_gen_contexts[1],
                response_ids_a=results[0].generated_ids,
                response_ids_b=results[1].generated_ids,
                label_a=persona_label(left.persona),
                label_b=persona_label(right.persona),
                remote=remote,
                ndif_api_key=session_ndif_api_key(),
            )
            if left_contrast is not None:
                left.state["messages"][-1]["_contrast"] = left_contrast
            if right_contrast is not None:
                right.state["messages"][-1]["_contrast"] = right_contrast
        except Exception as exc:
            st.warning(f"Token contrast failed: {exc}")


def _render_compare_panels(
    *,
    context_key: str,
    personas: list[PersonaData],
) -> list[ComparePanel]:
    """Render the left and right compare panels in two columns.

    Returns:
        ``[left_panel, right_panel]``, in column order.
    """
    rendered: list[ComparePanel] = []
    for column, side in zip(st.columns(2), ("left", "right")):
        with column:
            rendered.append(
                _render_compare_panel(
                    context_key=context_key,
                    side=side,
                    personas=personas,
                )
            )
    return rendered


def _stash_generation_errors(
    errors_key: str,
    panels: list[ComparePanel],
    results: list[ChatReply | Exception],
) -> None:
    """Remember failed generations in session state so they survive ``st.rerun()``."""
    failures = [
        (panel.side, str(result))
        for panel, result in zip(panels, results, strict=True)
        if isinstance(result, Exception)
    ]
    if failures:
        st.session_state[errors_key] = failures


def _show_stashed_generation_errors(
    errors_key: str,
    panels: list[ComparePanel],
) -> None:
    """Show, once, the generation errors stashed before the previous rerun."""
    failures = st.session_state.pop(errors_key, None)
    if not failures:
        return
    panel_by_side = {panel.side: panel for panel in panels}
    for side, message in failures:
        panel = panel_by_side.get(side)
        if panel is None:
            continue
        with panel.log:
            st.error(f"Generation failed: {message}")


def render_compare_mode(
    remote: bool,
    model_name: str,
    context_key: str,
    dataset_source: str,
    personas: list[PersonaData],
    generation: GenerationConfig,
    *,
    contrast_enabled: bool,
) -> None:
    """Render the full side-by-side comparison UI.

    Per run, this does the following:
    1. Render both panels.
    2. Serve pending regenerations, then rerun.
    3. Recompute flagged token contrasts when enabled, then rerun.
    4. Draw history, any carried-over generation errors, and the footer.
    5. On a new prompt, generate both replies, attach contrast when enabled,
       and rerun.

    Args:
        remote: Run generation/contrast through NDIF instead of locally.
        model_name: Name passed to ``cached_model`` and recorded in exports.
        context_key: Base key namespacing this view's session state and widgets.
        dataset_source: Recorded in exports.
        personas: Personas offered in each panel's selector.
        generation: Generation settings shared by both panels.
        contrast_enabled: Whether token contrast is computed and shown.
    """
    panels = _render_compare_panels(context_key=context_key, personas=personas)

    # _apply_panel_results reports failures with st.error, but every generation
    # path ends in st.rerun(), which discards that output before the user can
    # read it. Failures are therefore stashed here and re-shown after the rerun.
    errors_key = widget_key(context_key, "cmp_generation_errors")

    regen_panels = [
        panel for panel in panels if st.session_state.pop(panel.pending_key, False)
    ]
    if regen_panels:
        results = _generate_panels(
            model=cached_model(model_name=model_name),
            remote=remote,
            panels=regen_panels,
            generation=generation,
            spinner_label="Regenerating...",
        )
        _apply_panel_results(
            panels=regen_panels,
            results=results,
            rollback_user_on_error=False,
        )
        _stash_generation_errors(errors_key, regen_panels, results)
        st.rerun()

    if contrast_enabled and _recompute_pending_contrast(
        model=cached_model(model_name=model_name),
        remote=remote,
        panels=panels,
    ):
        st.rerun()

    _render_compare_history(panels=panels, contrast_enabled=contrast_enabled)
    _show_stashed_generation_errors(errors_key, panels)
    _render_compare_footer(
        context_key=context_key,
        model_name=model_name,
        dataset_source=dataset_source,
        panels=panels,
        generation=generation,
    )

    user_prompt = st.chat_input(
        "Ask both...",
        key=widget_key(context_key, "cmp_input"),
    )
    if not user_prompt:
        return

    _append_user_prompt(panels, user_prompt)
    # Contexts include the new user turn but not the replies; contrast scores the
    # replies against these.
    pre_gen_contexts = [
        build_chat_messages(panel.prompt, panel.state["messages"]) for panel in panels
    ]
    model = cached_model(model_name=model_name)
    results = _generate_panels(
        model=model,
        remote=remote,
        panels=panels,
        generation=generation,
        spinner_label="Generating...",
    )
    valid_results = _apply_panel_results(
        panels=panels,
        results=results,
        rollback_user_on_error=True,
    )
    _stash_generation_errors(errors_key, panels, results)
    if contrast_enabled:
        _compute_new_reply_contrast(
            model=model,
            remote=remote,
            panels=panels,
            pre_gen_contexts=pre_gen_contexts,
            results=valid_results,
        )

    st.rerun()