| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import TYPE_CHECKING, Any |
|
|
| import streamlit as st |
|
|
| from state import ChatState, default_chat_state, reset_chat_context_state |
| from tabs.chat_shared import ( |
| generate_chat_reply_result, |
| hydrate_chat_state, |
| render_chat_selection, |
| ) |
| from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt |
| from utils.chat_export import save_chat_export |
| from utils.contrast import compute_contrast, compute_contrast_pair |
| from utils.helpers import format_ndif_status, persona_label, session_key, widget_key |
| from utils.runtime import cached_model, session_ndif_api_key |
|
|
| from .chat_ui import ( |
| GenerationConfig, |
| render_chat_message, |
| render_chat_window, |
| render_system_prompt, |
| ) |
|
|
| if TYPE_CHECKING: |
| from nnterp import StandardizedTransformer |
| from persona_data.synth_persona import PersonaData |
|
|
|
|
@dataclass(frozen=True)
class ComparePanel:
    """Immutable handle for one side of the side-by-side chat comparison.

    Built by ``_render_compare_panel`` and then passed to the generation,
    rendering, export and reset helpers in this module.
    """

    # Panel identifier; the call sites in this module use "left" / "right".
    side: str
    # Session-backed chat state; this module reads its "messages",
    # "prompt_mode" and "persona_id" entries.
    state: ChatState
    # Container (from ``st.container()``) that the panel's output is drawn into.
    log: Any
    # System prompt returned by ``render_system_prompt`` for this run.
    prompt: str | None
    # Persona currently selected for this panel.
    persona: PersonaData
    # Widget/session keys: custom system prompt, in-progress edit index,
    # and the pending-regeneration flag.
    prompt_key: str
    edit_key: str
    pending_key: str
|
|
|
|
def _get_compare_state(context_key: str, side: str) -> tuple[str, ChatState]:
    """Return ``(panel_key, state)`` for one compare panel.

    The state is stored in ``st.session_state`` under a key derived from
    ``context_key`` and ``side``. The first time that key is seen, a fresh
    ``default_chat_state()`` is stored there.
    """
    key = widget_key(context_key, f"cmp_{side}")
    store = st.session_state
    if key not in store:
        store[key] = default_chat_state()
    return key, store[key]
|
|
|
|
def _reset_compare_panel(panel: ComparePanel) -> None:
    """Reset one panel's chat context and drop its in-progress edit index.

    Delegates to ``reset_chat_context_state`` with the panel's current
    persona id, its prompt mode, and its prompt and pending-regen keys.
    """
    chat_state = panel.state
    reset_chat_context_state(
        chat_state,
        panel.persona.id,
        chat_state["prompt_mode"],
        panel.prompt_key,
        panel.pending_key,
    )
    # Clear any message-edit selection for this panel.
    st.session_state.pop(panel.edit_key, None)
|
|
|
|
def _render_compare_panel(
    *,
    context_key: str,
    side: str,
    personas: list[PersonaData],
) -> ComparePanel:
    """Render the controls for one compare panel and return its handle.

    Steps, in order:

    1. Look up (or create) the panel's ``ChatState`` and hydrate it from the
       side's persisted persona/prompt keys.
    2. Render the persona and prompt-mode selectors; if the selection changed,
       reset the chat context and clear the edit index.
    3. Create the panel's log container and render the system prompt
       (via ``render_system_prompt``) into it.

    Chat messages themselves are not rendered here.

    Args:
        context_key: Namespace for this comparison's session-state keys.
        side: Panel identifier ("left" or "right" at the call site).
        personas: Personas offered in the selector.

    Returns:
        A ``ComparePanel`` bundling the state, log container, system prompt,
        selected persona and the panel's widget keys.
    """
    panel_key, state = _get_compare_state(context_key, side)
    prompt_key, edit_key, pending_key = (
        widget_key(panel_key, suffix)
        for suffix in ("custom_prompt", "edit_idx", "pending_regen")
    )

    # Per-side keys handed to the hydrate/selection helpers as their
    # "persisted" persona and prompt keys.
    persona_memory_key, prompt_memory_key = (
        session_key("chat", f"last_cmp_{side}_{what}") for what in ("persona", "prompt")
    )
    hydrate_chat_state(
        state,
        persisted_persona_key=persona_memory_key,
        persisted_prompt_key=prompt_memory_key,
    )

    selection = render_chat_selection(
        personas,
        state["persona_id"],
        state["prompt_mode"],
        widget_key(panel_key, "persona"),
        widget_key(panel_key, "prompt_mode"),
        persisted_persona_key=persona_memory_key,
        persisted_prompt_key=prompt_memory_key,
    )
    chosen_persona = selection.persona
    prompt_mode = selection.prompt_mode

    if selection.changed:
        reset_chat_context_state(
            state,
            chosen_persona.id,
            prompt_mode,
            prompt_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    default_prompt = resolve_system_prompt(persona=chosen_persona, mode=prompt_mode)

    def _on_prompt_saved():
        # NOTE(review): unlike the other reset_chat_context_state calls in
        # this module, this one passes only pending_key (no prompt_key) —
        # presumably so the custom prompt that was just saved is not cleared.
        # Confirm against reset_chat_context_state's signature.
        return reset_chat_context_state(
            state,
            chosen_persona.id,
            prompt_mode,
            pending_key,
        )

    log = st.container()
    with log:
        system_prompt = render_system_prompt(
            prompt_key,
            prompt_mode,
            default_prompt,
            on_save=_on_prompt_saved,
        )

    return ComparePanel(
        side=side,
        state=state,
        log=log,
        prompt=system_prompt,
        persona=chosen_persona,
        prompt_key=prompt_key,
        edit_key=edit_key,
        pending_key=pending_key,
    )
|
|
|
|
def _generate_panels(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
    generation: GenerationConfig,
    spinner_label: str,
) -> list[ChatReply | Exception]:
    """Generate one reply per panel, sequentially, under a single spinner.

    A one-line status placeholder shows which panel is being generated. When
    ``remote`` is true, NDIF job status updates for the current panel are
    written into the same placeholder. The placeholder is always cleared on
    exit, including when generation raises, so no stale caption is left
    behind.

    Args:
        model: Model passed through to ``generate_chat_reply_result``.
        remote: Whether to run the generation through NDIF.
        panels: Panels to generate for. Each panel's context is its system
            prompt plus its current message history.
        generation: Sampling configuration.
        spinner_label: Text shown on the spinner.

    Returns:
        One entry per panel, in panel order: the reply when
        ``generate_chat_reply_result`` reported no error, otherwise the error
        it returned.
    """
    results: list[ChatReply | Exception] = []
    phase = "Submitting to NDIF..." if remote else "Generating locally..."
    status_box = st.empty()
    try:
        with st.spinner(spinner_label):
            for panel in panels:
                panel_label = panel.side.title()
                status_box.caption(f"{panel_label}: {phase}")

                def _show_ndif_status(
                    job_id: str,
                    status_name: str,
                    description: str,
                    *,
                    # Bound as a default so the callback keeps this panel's label.
                    label: str = panel_label,
                ) -> None:
                    status_box.caption(
                        format_ndif_status(
                            job_id,
                            status_name,
                            description,
                            prefix=label,
                            completed_detail="Downloading result...",
                        )
                    )

                reply, error = generate_chat_reply_result(
                    model=model,
                    messages=build_chat_messages(panel.prompt, panel.state["messages"]),
                    remote=remote,
                    generation=generation,
                    on_status=_show_ndif_status if remote else None,
                    ndif_api_key=session_ndif_api_key(),
                )
                results.append(reply if error is None else error)
    finally:
        # Clear the status line even if generation raised, so the UI does
        # not keep showing "Generating..." for a run that has stopped.
        status_box.empty()
    return results
|
|
|
|
| def _apply_panel_results( |
| *, |
| panels: list[ComparePanel], |
| results: list[ChatReply | Exception], |
| rollback_user_on_error: bool, |
| ) -> list[ChatReply | None]: |
| valid_results: list[ChatReply | None] = [] |
| for panel, result in zip(panels, results, strict=True): |
| if isinstance(result, Exception): |
| with panel.log: |
| st.error(f"Generation failed: {result}") |
| if rollback_user_on_error and panel.state["messages"]: |
| panel.state["messages"].pop() |
| valid_results.append(None) |
| continue |
|
|
| panel.state["messages"].append({"role": "assistant", "content": result.text}) |
| valid_results.append(result) |
| return valid_results |
|
|
|
|
| def _pending_contrast_edits(panels: list[ComparePanel]) -> list[tuple[int, int]]: |
| return [ |
| (panel_idx, msg_idx) |
| for panel_idx, panel in enumerate(panels) |
| for msg_idx, msg in enumerate(panel.state["messages"]) |
| if msg.get("_needs_contrast") and msg.get("role") == "assistant" |
| ] |
|
|
|
|
def _recompute_pending_contrast(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
) -> bool:
    """Recompute token contrast for assistant messages flagged ``_needs_contrast``.

    For a flagged message at index ``i``:

    - Each side's context is that panel's system prompt plus its first ``i``
      messages.
    - The flagged message's text is re-tokenized (without special tokens) and
      scored with ``compute_contrast`` against both contexts.
    - A non-``None`` result is stored on the message under ``"_contrast"``.

    Messages whose index does not exist in both panels are skipped. The
    ``_needs_contrast`` flag is cleared on every processed message, including
    skips and failures. Failures are reported with ``st.warning``.

    Returns:
        ``True`` if at least one flagged message was found (the caller then
        reruns), ``False`` if there was nothing to do.
    """
    todo = _pending_contrast_edits(panels)
    if not todo:
        return False

    left, right = panels
    label_a = persona_label(left.persona)
    label_b = persona_label(right.persona)
    left_history = left.state["messages"]
    right_history = right.state["messages"]
    with st.spinner("Recomputing token contrast..."):
        for panel_idx, msg_idx in todo:
            target = panels[panel_idx].state["messages"][msg_idx]
            # Only a turn that exists on both sides can be compared.
            if msg_idx < len(left_history) and msg_idx < len(right_history):
                context_a = build_chat_messages(left.prompt, left_history[:msg_idx])
                context_b = build_chat_messages(right.prompt, right_history[:msg_idx])
                try:
                    token_ids = model.tokenizer(
                        target["content"],
                        add_special_tokens=False,
                        return_tensors="pt",
                    ).input_ids[0]
                    scored = compute_contrast(
                        model=model,
                        context_a=context_a,
                        context_b=context_b,
                        response_ids=token_ids,
                        label_a=label_a,
                        label_b=label_b,
                        remote=remote,
                        ndif_api_key=session_ndif_api_key(),
                    )
                    if scored is not None:
                        target["_contrast"] = scored
                except Exception as exc:
                    st.warning(f"Token contrast recompute failed: {exc}")
            # Always clear the flag so a failing message cannot loop forever.
            target.pop("_needs_contrast", None)
    return True
|
|
|
|
def _render_compare_history(
    *,
    panels: list[ComparePanel],
    contrast_enabled: bool,
) -> None:
    """Render each panel's message history into that panel's log container."""
    window_options = {"show_contrast": contrast_enabled, "edit_column_ratio": (10, 1)}
    for panel in panels:
        render_chat_window(
            chat_log=panel.log,
            messages=panel.state["messages"],
            edit_key=panel.edit_key,
            pending_key=panel.pending_key,
            **window_options,
        )
|
|
|
|
def _render_compare_footer(
    *,
    context_key: str,
    model_name: str,
    dataset_source: str,
    panels: list[ComparePanel],
    generation: GenerationConfig,
) -> None:
    """Render the footer row: an export button and a reset menu.

    The export button saves one chat export per panel via
    ``save_chat_export`` and then shows a toast. The reset popover offers
    "Reset <side>" for each panel plus "Reset both". Any reset bumps a
    per-context nonce, which is part of the popover's key, and reruns the
    script.
    """
    nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
    if nonce_key not in st.session_state:
        st.session_state[nonce_key] = 0

    def _bump_nonce_and_rerun() -> None:
        # The nonce is baked into the popover key below, so bumping it gives
        # the next run a brand-new popover widget (presumably so the menu
        # comes back closed).
        st.session_state[nonce_key] += 1
        st.rerun()

    with st.container():
        export_col, reset_col, _ = st.columns([1, 1.25, 20], gap="xsmall")

        with export_col:
            export_clicked = st.button(
                "",
                icon=":material/download:",
                key=widget_key(context_key, "cmp_export"),
                help="Export both chats",
            )
            if export_clicked:
                for panel in panels:
                    save_chat_export(
                        model_name=model_name,
                        dataset_source=dataset_source,
                        persona_id=panel.persona.id,
                        persona_name=getattr(panel.persona, "name", None),
                        prompt_mode=panel.state["prompt_mode"],
                        system_prompt=panel.prompt,
                        messages=panel.state["messages"],
                        generation=generation.to_export_dict(),
                        panel_label=panel.side,
                    )
                st.toast("Exported", icon=":material/check:")

        with reset_col:
            menu_key = widget_key(
                context_key,
                "cmp_reset_menu",
                str(st.session_state[nonce_key]),
            )
            with st.popover(
                "",
                icon=":material/delete_sweep:",
                help="Reset chat",
                key=menu_key,
            ):
                for panel in panels:
                    reset_one = st.button(
                        f"Reset {panel.side}",
                        key=widget_key(context_key, f"cmp_reset_{panel.side}"),
                    )
                    if reset_one:
                        _reset_compare_panel(panel)
                        _bump_nonce_and_rerun()
                reset_all = st.button(
                    "Reset both",
                    key=widget_key(context_key, "cmp_reset_both"),
                    type="primary",
                )
                if reset_all:
                    for panel in panels:
                        _reset_compare_panel(panel)
                    _bump_nonce_and_rerun()
|
|
|
|
def _append_user_prompt(panels: list[ComparePanel], user_prompt: str) -> None:
    """Append ``user_prompt`` as a user turn to every panel and render it at once."""

    def _user_turn() -> dict[str, str]:
        # A fresh dict per use: the stored message and the rendered one are
        # separate objects, as before.
        return {"role": "user", "content": user_prompt}

    for panel in panels:
        panel.state["messages"].append(_user_turn())
        with panel.log:
            render_chat_message(_user_turn())
|
|
|
|
| def _compute_new_reply_contrast( |
| *, |
| model: StandardizedTransformer, |
| remote: bool, |
| panels: list[ComparePanel], |
| pre_gen_contexts: list[list[dict[str, str]]], |
| results: list[ChatReply | None], |
| ) -> None: |
| if len(results) != 2 or any( |
| result is None or result.generated_ids is None for result in results |
| ): |
| return |
|
|
| left, right = panels |
| with st.spinner("Computing token contrast..."): |
| try: |
| left_contrast, right_contrast = compute_contrast_pair( |
| model=model, |
| context_a=pre_gen_contexts[0], |
| context_b=pre_gen_contexts[1], |
| response_ids_a=results[0].generated_ids, |
| response_ids_b=results[1].generated_ids, |
| label_a=persona_label(left.persona), |
| label_b=persona_label(right.persona), |
| remote=remote, |
| ndif_api_key=session_ndif_api_key(), |
| ) |
| if left_contrast is not None: |
| left.state["messages"][-1]["_contrast"] = left_contrast |
| if right_contrast is not None: |
| right.state["messages"][-1]["_contrast"] = right_contrast |
| except Exception as exc: |
| st.warning(f"Token contrast failed: {exc}") |
|
|
|
|
def _render_compare_panels(
    *,
    context_key: str,
    personas: list[PersonaData],
) -> list[ComparePanel]:
    """Lay out two columns and render the left and right panels into them.

    Returns:
        ``[left_panel, right_panel]``.
    """
    panels: list[ComparePanel] = []
    for column, side in zip(st.columns(2), ("left", "right")):
        with column:
            panels.append(
                _render_compare_panel(
                    context_key=context_key,
                    side=side,
                    personas=personas,
                )
            )
    return panels
|
|
|
|
def _generation_error_key(context_key: str) -> str:
    """Return the session-state key holding generation failures for the next run."""
    return widget_key(context_key, "cmp_generation_errors")


def _stash_generation_errors(
    context_key: str,
    panels: list[ComparePanel],
    results: list[ChatReply | Exception],
) -> None:
    """Record this generation's failures (by panel side) so they survive ``st.rerun()``.

    Only the most recent generation's failures are kept: a fully successful
    generation clears anything previously stashed.
    """
    failures = {
        panel.side: str(result)
        for panel, result in zip(panels, results, strict=True)
        if isinstance(result, Exception)
    }
    key = _generation_error_key(context_key)
    if failures:
        st.session_state[key] = failures
    else:
        st.session_state.pop(key, None)


def _show_stashed_generation_errors(context_key: str, panels: list[ComparePanel]) -> None:
    """Display, once, any failures stashed by the previous run in their panels' logs."""
    failures = st.session_state.pop(_generation_error_key(context_key), None) or {}
    for panel in panels:
        message = failures.get(panel.side)
        if message is not None:
            with panel.log:
                st.error(f"Generation failed: {message}")


def render_compare_mode(
    remote: bool,
    model_name: str,
    context_key: str,
    dataset_source: str,
    personas: list[PersonaData],
    generation: GenerationConfig,
    *,
    contrast_enabled: bool,
) -> None:
    """Render the full side-by-side comparison UI.

    Each run:

    1. Render both panels' selectors and system prompts.
    2. Service any pending per-panel regeneration, then rerun.
    3. When contrast is enabled, recompute pending token contrasts; rerun if
       any were found.
    4. Render both chat histories and the footer.
    5. If the user submitted a prompt: send it to both panels, generate both
       replies, optionally compute token contrast for them, and rerun.

    Every generation path ends in ``st.rerun()``, which discards the
    ``st.error`` that ``_apply_panel_results`` emits. Failures are therefore
    stashed in session state before the rerun and shown once on a later run.
    Without that, a failed generation (and the rolled-back user message)
    would vanish with no explanation.
    """
    panels = _render_compare_panels(context_key=context_key, personas=personas)

    # Pending-regen flags are consumed (popped) for every panel, regenerated or not.
    regen_panels = [
        panel for panel in panels if st.session_state.pop(panel.pending_key, False)
    ]
    if regen_panels:
        results = _generate_panels(
            model=cached_model(model_name=model_name),
            remote=remote,
            panels=regen_panels,
            generation=generation,
            spinner_label="Regenerating...",
        )
        _apply_panel_results(
            panels=regen_panels,
            results=results,
            rollback_user_on_error=False,
        )
        _stash_generation_errors(context_key, regen_panels, results)
        st.rerun()

    if contrast_enabled and _recompute_pending_contrast(
        model=cached_model(model_name=model_name),
        remote=remote,
        panels=panels,
    ):
        st.rerun()

    # Shown only after the rerun-triggering steps above, so a stashed error
    # is not consumed by a run that is about to be discarded.
    _show_stashed_generation_errors(context_key, panels)
    _render_compare_history(panels=panels, contrast_enabled=contrast_enabled)
    _render_compare_footer(
        context_key=context_key,
        model_name=model_name,
        dataset_source=dataset_source,
        panels=panels,
        generation=generation,
    )

    user_prompt = st.chat_input(
        "Ask both...",
        key=widget_key(context_key, "cmp_input"),
    )
    if not user_prompt:
        return

    _append_user_prompt(panels, user_prompt)
    # Built after the user turn is appended and before any reply is appended;
    # these are the contexts the contrast pass scores the new replies against.
    pre_gen_contexts = [
        build_chat_messages(panel.prompt, panel.state["messages"]) for panel in panels
    ]
    model = cached_model(model_name=model_name)
    results = _generate_panels(
        model=model,
        remote=remote,
        panels=panels,
        generation=generation,
        spinner_label="Generating...",
    )
    valid_results = _apply_panel_results(
        panels=panels,
        results=results,
        rollback_user_on_error=True,
    )
    _stash_generation_errors(context_key, panels, results)
    if contrast_enabled:
        _compute_new_reply_contrast(
            model=model,
            remote=remote,
            panels=panels,
            pre_gen_contexts=pre_gen_contexts,
            results=valid_results,
        )

    st.rerun()
|
|