persona-ui / tabs /compare_chat.py
Jac-Zac
add session-scoped NDIF execution and improve cold-load UX
ae347c6
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import streamlit as st
from state import ChatState, default_chat_state, reset_chat_context_state
from tabs.chat_shared import (
generate_chat_reply_result,
hydrate_chat_state,
render_chat_selection,
)
from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
from utils.chat_export import save_chat_export
from utils.contrast import compute_contrast, compute_contrast_pair
from utils.helpers import format_ndif_status, persona_label, session_key, widget_key
from utils.runtime import cached_model, session_ndif_api_key
from .chat_ui import (
GenerationConfig,
render_chat_message,
render_chat_window,
render_system_prompt,
)
if TYPE_CHECKING:
from nnterp import StandardizedTransformer
from persona_data.synth_persona import PersonaData
@dataclass(frozen=True)
class ComparePanel:
    """Immutable description of one compare column for the current rerun.

    Produced by ``_render_compare_panel`` and handed to the generation,
    contrast, history, and footer helpers of the compare view.
    """

    # Column identifier; this module uses "left" and "right".
    side: str
    # Session-backed chat state (this module reads "messages",
    # "persona_id" and "prompt_mode" from it).
    state: ChatState
    # Streamlit container (from st.container()) that hosts the chat log.
    log: Any
    # Effective system prompt for the column; may be None.
    prompt: str | None
    # Persona currently selected in this column.
    persona: PersonaData
    # Session-state key of the custom system-prompt widget.
    prompt_key: str
    # Session-state key holding the index of a message being edited.
    edit_key: str
    # Session-state key of the "regenerate this panel" flag.
    pending_key: str
def _get_compare_state(context_key: str, side: str) -> tuple[str, ChatState]:
    """Return ``(panel_key, state)`` for one compare column, creating it lazily.

    The chat state is kept in ``st.session_state`` under
    ``widget_key(context_key, f"cmp_{side}")``. The first time that key is
    missing, a fresh ``default_chat_state()`` is stored there.
    """
    key = widget_key(context_key, f"cmp_{side}")
    session = st.session_state
    if key not in session:
        session[key] = default_chat_state()
    return key, session[key]
def _reset_compare_panel(panel: ComparePanel) -> None:
    """Reset one panel's chat context for its current persona and prompt mode.

    Delegates to ``reset_chat_context_state`` with the panel's prompt and
    pending-regeneration keys, then forgets any in-progress edit index.
    """
    chat_state = panel.state
    persona_id = panel.persona.id
    current_mode = chat_state["prompt_mode"]
    reset_chat_context_state(
        chat_state,
        persona_id,
        current_mode,
        panel.prompt_key,
        panel.pending_key,
    )
    # An edit index from before the reset would point into the old context.
    st.session_state.pop(panel.edit_key, None)
def _render_compare_panel(
    *,
    context_key: str,
    side: str,
    personas: list[PersonaData],
) -> ComparePanel:
    """Render persona/prompt selectors and the system prompt for one column.

    Args:
        context_key: Base key scoping this compare view's session state.
        side: Column identifier ("left" or "right"); it is embedded in the
            widget keys and in the persisted-selection keys.
        personas: Personas offered by the persona selector.

    Returns:
        A ``ComparePanel`` with the column's chat state, the container that
        will hold its chat log, the effective system prompt, the selected
        persona, and the widget keys used for prompt / edit / regeneration.
    """
    panel_key, state = _get_compare_state(context_key, side)
    prompt_key = widget_key(panel_key, "custom_prompt")
    edit_key = widget_key(panel_key, "edit_idx")
    pending_key = widget_key(panel_key, "pending_regen")
    persist_persona_key = session_key("chat", f"last_cmp_{side}_persona")
    persist_prompt_key = session_key("chat", f"last_cmp_{side}_prompt")

    hydrate_chat_state(
        state,
        persisted_persona_key=persist_persona_key,
        persisted_prompt_key=persist_prompt_key,
    )
    selection = render_chat_selection(
        personas,
        state["persona_id"],
        state["prompt_mode"],
        widget_key(panel_key, "persona"),
        widget_key(panel_key, "prompt_mode"),
        persisted_persona_key=persist_persona_key,
        persisted_prompt_key=persist_prompt_key,
    )
    selected_persona = selection.persona
    prompt_mode = selection.prompt_mode
    if selection.changed:
        # Persona or prompt mode changed: start a fresh context for the column.
        reset_chat_context_state(
            state,
            selected_persona.id,
            prompt_mode,
            prompt_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    def _on_prompt_saved() -> None:
        # NOTE(review): unlike the other reset calls, prompt_key is not passed
        # here -- presumably so the just-saved custom prompt is kept; confirm.
        reset_chat_context_state(
            state,
            selected_persona.id,
            prompt_mode,
            pending_key,
        )
        # Every other reset path drops the edit index as well; without this a
        # stale index could reopen the editor on an unrelated later message.
        st.session_state.pop(edit_key, None)

    active_system_prompt = resolve_system_prompt(
        persona=selected_persona,
        mode=prompt_mode,
    )
    chat_log = st.container()
    with chat_log:
        active_system_prompt = render_system_prompt(
            prompt_key,
            prompt_mode,
            active_system_prompt,
            on_save=_on_prompt_saved,
        )
    return ComparePanel(
        side=side,
        state=state,
        log=chat_log,
        prompt=active_system_prompt,
        persona=selected_persona,
        prompt_key=prompt_key,
        edit_key=edit_key,
        pending_key=pending_key,
    )
def _generate_panels(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
    generation: GenerationConfig,
    spinner_label: str,
) -> list[ChatReply | Exception]:
    """Generate one assistant reply per panel, one after another, under a spinner.

    A caption placeholder shows which panel is running. When ``remote`` is
    true, it also shows NDIF job status updates. The placeholder is cleared
    after the last panel.

    Returns:
        One entry per panel, in order: the ``ChatReply`` on success, or the
        error reported by ``generate_chat_reply_result``.
    """
    mode_text = "Submitting to NDIF..." if remote else "Generating locally..."
    status_box = st.empty()
    outcomes: list[ChatReply | Exception] = []

    def _status_callback_for(label: str):
        # Build a fresh callback per panel so each one reports its own label.
        def _report(job_id: str, status_name: str, description: str) -> None:
            status_box.caption(
                format_ndif_status(
                    job_id,
                    status_name,
                    description,
                    prefix=label,
                    completed_detail="Downloading result...",
                )
            )

        return _report

    with st.spinner(spinner_label):
        for panel in panels:
            label = panel.side.title()
            status_box.caption(f"{label}: {mode_text}")
            reply, error = generate_chat_reply_result(
                model=model,
                messages=build_chat_messages(panel.prompt, panel.state["messages"]),
                remote=remote,
                generation=generation,
                on_status=_status_callback_for(label) if remote else None,
                ndif_api_key=session_ndif_api_key(),
            )
            outcomes.append(error if error is not None else reply)
    status_box.empty()
    return outcomes
def _apply_panel_results(
*,
panels: list[ComparePanel],
results: list[ChatReply | Exception],
rollback_user_on_error: bool,
) -> list[ChatReply | None]:
valid_results: list[ChatReply | None] = []
for panel, result in zip(panels, results, strict=True):
if isinstance(result, Exception):
with panel.log:
st.error(f"Generation failed: {result}")
if rollback_user_on_error and panel.state["messages"]:
panel.state["messages"].pop()
valid_results.append(None)
continue
panel.state["messages"].append({"role": "assistant", "content": result.text})
valid_results.append(result)
return valid_results
def _pending_contrast_edits(panels: list[ComparePanel]) -> list[tuple[int, int]]:
return [
(panel_idx, msg_idx)
for panel_idx, panel in enumerate(panels)
for msg_idx, msg in enumerate(panel.state["messages"])
if msg.get("_needs_contrast") and msg.get("role") == "assistant"
]
def _recompute_pending_contrast(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
) -> bool:
    """Recompute token contrast for assistant messages flagged ``_needs_contrast``.

    For each flagged message, its text is tokenized with ``model.tokenizer``.
    It is then scored by ``compute_contrast`` against both panels' contexts,
    truncated to the messages before it (``messages[:msg_idx]``) and labelled
    with each panel's persona.

    Every processed message ends in a consistent state. ``_needs_contrast`` is
    cleared. ``_contrast`` holds a freshly computed value, or is removed when
    none could be produced (the other panel has no message at that index, the
    computation raised, or it returned ``None``).

    Args:
        model: Model whose tokenizer encodes the response; also forwarded to
            ``compute_contrast``.
        remote: Whether to run remotely (NDIF); forwarded to ``compute_contrast``.
        panels: Exactly two panels, ``[left, right]``.

    Returns:
        ``True`` if any flagged message was processed (the caller reruns to
        redraw), ``False`` if nothing was pending.
    """
    pending_edits = _pending_contrast_edits(panels)
    if not pending_edits:
        return False
    left, right = panels
    label_a = persona_label(left.persona)
    label_b = persona_label(right.persona)
    with st.spinner("Recomputing token contrast..."):
        for panel_idx, msg_idx in pending_edits:
            msg = panels[panel_idx].state["messages"][msg_idx]
            contrast = None
            # Contrast needs an aligned turn at this index in both panels.
            if msg_idx < len(left.state["messages"]) and msg_idx < len(
                right.state["messages"]
            ):
                context_a = build_chat_messages(
                    left.prompt,
                    left.state["messages"][:msg_idx],
                )
                context_b = build_chat_messages(
                    right.prompt,
                    right.state["messages"][:msg_idx],
                )
                try:
                    response_ids = model.tokenizer(
                        msg["content"],
                        add_special_tokens=False,
                        return_tensors="pt",
                    ).input_ids[0]
                    contrast = compute_contrast(
                        model=model,
                        context_a=context_a,
                        context_b=context_b,
                        response_ids=response_ids,
                        label_a=label_a,
                        label_b=label_b,
                        remote=remote,
                        ndif_api_key=session_ndif_api_key(),
                    )
                except Exception as exc:
                    st.warning(f"Token contrast recompute failed: {exc}")
            # Any _contrast already on the message predates the flag and no
            # longer matches its content; keep only a freshly computed value
            # rather than displaying a mismatched token contrast.
            if contrast is None:
                msg.pop("_contrast", None)
            else:
                msg["_contrast"] = contrast
            # Clear the flag so a failure is not retried on every rerun.
            msg.pop("_needs_contrast", None)
    return True
def _render_compare_history(
    *,
    panels: list[ComparePanel],
    contrast_enabled: bool,
) -> None:
    """Draw each panel's stored messages into that panel's chat-log container.

    All panels share the same display options: contrast visibility follows
    ``contrast_enabled``, and the edit column uses a 10:1 width ratio.
    """
    shared_options = {
        "show_contrast": contrast_enabled,
        "edit_column_ratio": (10, 1),
    }
    for panel in panels:
        render_chat_window(
            chat_log=panel.log,
            messages=panel.state["messages"],
            edit_key=panel.edit_key,
            pending_key=panel.pending_key,
            **shared_options,
        )
def _render_compare_footer(
    *,
    context_key: str,
    model_name: str,
    dataset_source: str,
    panels: list[ComparePanel],
    generation: GenerationConfig,
) -> None:
    """Render the compact export / reset toolbar under the compare columns.

    - Export button: writes every panel's chat via ``save_chat_export`` and
      shows a confirmation toast.
    - Reset popover: one reset button per panel plus a "Reset both" button;
      each reset closes the popover and reruns the script.
    """
    # Popovers don't auto-close when a button inside them is clicked.
    # Incrementing this nonce after a reset changes the popover's widget key,
    # so Streamlit mounts a fresh (closed) popover on the next run.
    nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
    if nonce_key not in st.session_state:
        st.session_state[nonce_key] = 0

    def _close_menu_and_rerun() -> None:
        st.session_state[nonce_key] += 1
        st.rerun()

    with st.container():
        export_col, reset_col, _ = st.columns([1, 1.25, 20], gap="xsmall")

        with export_col:
            export_clicked = st.button(
                "",
                icon=":material/download:",
                key=widget_key(context_key, "cmp_export"),
                help="Export both chats",
            )
            if export_clicked:
                for panel in panels:
                    save_chat_export(
                        model_name=model_name,
                        dataset_source=dataset_source,
                        persona_id=panel.persona.id,
                        persona_name=getattr(panel.persona, "name", None),
                        prompt_mode=panel.state["prompt_mode"],
                        system_prompt=panel.prompt,
                        messages=panel.state["messages"],
                        generation=generation.to_export_dict(),
                        panel_label=panel.side,
                    )
                st.toast("Exported", icon=":material/check:")

        with reset_col:
            menu_key = widget_key(
                context_key,
                "cmp_reset_menu",
                str(st.session_state[nonce_key]),
            )
            with st.popover(
                "",
                icon=":material/delete_sweep:",
                help="Reset chat",
                key=menu_key,
            ):
                for panel in panels:
                    single_reset_key = widget_key(context_key, f"cmp_reset_{panel.side}")
                    if st.button(f"Reset {panel.side}", key=single_reset_key):
                        _reset_compare_panel(panel)
                        _close_menu_and_rerun()
                reset_both_clicked = st.button(
                    "Reset both",
                    key=widget_key(context_key, "cmp_reset_both"),
                    type="primary",
                )
                if reset_both_clicked:
                    for panel in panels:
                        _reset_compare_panel(panel)
                    _close_menu_and_rerun()
def _append_user_prompt(panels: list[ComparePanel], user_prompt: str) -> None:
    """Add the same user turn to every panel and echo it into each chat log right away."""

    def _user_turn() -> dict[str, str]:
        # A fresh dict per use, so stored and rendered messages are separate objects.
        return {"role": "user", "content": user_prompt}

    for panel in panels:
        panel.state["messages"].append(_user_turn())
        with panel.log:
            render_chat_message(_user_turn())
def _compute_new_reply_contrast(
*,
model: StandardizedTransformer,
remote: bool,
panels: list[ComparePanel],
pre_gen_contexts: list[list[dict[str, str]]],
results: list[ChatReply | None],
) -> None:
if len(results) != 2 or any(
result is None or result.generated_ids is None for result in results
):
return
left, right = panels
with st.spinner("Computing token contrast..."):
try:
left_contrast, right_contrast = compute_contrast_pair(
model=model,
context_a=pre_gen_contexts[0],
context_b=pre_gen_contexts[1],
response_ids_a=results[0].generated_ids,
response_ids_b=results[1].generated_ids,
label_a=persona_label(left.persona),
label_b=persona_label(right.persona),
remote=remote,
ndif_api_key=session_ndif_api_key(),
)
if left_contrast is not None:
left.state["messages"][-1]["_contrast"] = left_contrast
if right_contrast is not None:
right.state["messages"][-1]["_contrast"] = right_contrast
except Exception as exc:
st.warning(f"Token contrast failed: {exc}")
def _render_compare_panels(
    *,
    context_key: str,
    personas: list[PersonaData],
) -> list[ComparePanel]:
    """Lay out two equal-width columns and render one compare panel in each.

    Returns:
        ``[left_panel, right_panel]``.
    """
    rendered: list[ComparePanel] = []
    for column, side in zip(st.columns(2), ("left", "right")):
        with column:
            rendered.append(
                _render_compare_panel(
                    context_key=context_key,
                    side=side,
                    personas=personas,
                )
            )
    return rendered
def render_compare_mode(
    remote: bool,
    model_name: str,
    context_key: str,
    dataset_source: str,
    personas: list[PersonaData],
    generation: GenerationConfig,
    *,
    contrast_enabled: bool,
) -> None:
    """Render the full side-by-side comparison UI.

    Each script run does the following, in order:
      1. Draw both panels (selectors and system prompt).
      2. If any panel's regeneration flag is set, consume it, regenerate
         those panels, and rerun.
      3. If contrast is enabled and flagged replies exist, recompute their
         contrast and rerun.
      4. Draw stored history, the footer toolbar, and the shared chat input.
      5. On a new prompt: append it to both panels, generate both replies,
         optionally compute token contrast for the new pair, then rerun.
    """

    def _model() -> StandardizedTransformer:
        return cached_model(model_name=model_name)

    panels = _render_compare_panels(context_key=context_key, personas=personas)

    # pop() consumes each regeneration flag so it fires only once.
    regen_panels: list[ComparePanel] = []
    for panel in panels:
        if st.session_state.pop(panel.pending_key, False):
            regen_panels.append(panel)
    if regen_panels:
        regen_outcomes = _generate_panels(
            model=_model(),
            remote=remote,
            panels=regen_panels,
            generation=generation,
            spinner_label="Regenerating...",
        )
        _apply_panel_results(
            panels=regen_panels,
            results=regen_outcomes,
            rollback_user_on_error=False,
        )
        st.rerun()

    if contrast_enabled:
        if _recompute_pending_contrast(model=_model(), remote=remote, panels=panels):
            st.rerun()

    _render_compare_history(panels=panels, contrast_enabled=contrast_enabled)
    _render_compare_footer(
        context_key=context_key,
        model_name=model_name,
        dataset_source=dataset_source,
        panels=panels,
        generation=generation,
    )

    user_prompt = st.chat_input(
        "Ask both...",
        key=widget_key(context_key, "cmp_input"),
    )
    if not user_prompt:
        return

    _append_user_prompt(panels, user_prompt)
    # Snapshot each panel's context (system prompt + history including the new
    # user turn) before replies are appended; contrast is scored against it.
    pre_gen_contexts = [
        build_chat_messages(p.prompt, p.state["messages"]) for p in panels
    ]
    model = _model()
    outcomes = _generate_panels(
        model=model,
        remote=remote,
        panels=panels,
        generation=generation,
        spinner_label="Generating...",
    )
    accepted = _apply_panel_results(
        panels=panels,
        results=outcomes,
        rollback_user_on_error=True,
    )
    if contrast_enabled:
        _compute_new_reply_contrast(
            model=model,
            remote=remote,
            panels=panels,
            pre_gen_contexts=pre_gen_contexts,
            results=accepted,
        )
    st.rerun()