persona-ui / tabs /chat_ui.py
Jac-Zac
Big refactoring
b279884
from __future__ import annotations
from collections.abc import Callable
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any
import streamlit as st
from utils.helpers import (
CHAT_PROMPT_MODE_LABEL_TO_KEY,
CHAT_PROMPT_MODE_LABELS,
VARIANT_LABELS,
persona_label,
widget_key,
)
if TYPE_CHECKING:
from persona_data.synth_persona import PersonaData
from utils.contrast import TokenContrast
# Fallback values for the generation widgets; used when no earlier value is
# available for a setting.
GENERATION_DEFAULTS = dict(
    max_new_tokens=256,
    temperature=1.0,
    top_p=1.0,
    top_k=50,
    repetition_penalty=1.0,
)
# Prefix of the session-state entries that remember the most recently used
# generation settings.
_LAST_GEN_PREFIX = "chat:last_gen:"


def _last_generation_key(name: str) -> str:
    """Return the session-state key that stores the last value of ``name``."""
    return "{}{}".format(_LAST_GEN_PREFIX, name)
def _persisted_key(context_key: str, name: str, default: object) -> str:
    """Return the per-context widget key for ``name``, seeding it if unset.

    When the key is not yet in ``st.session_state`` it is initialised from the
    last value stored under ``_last_generation_key(name)``, or from
    ``default`` when no such value exists. An existing entry is left alone.
    """
    key = widget_key(context_key, name)
    if key in st.session_state:
        return key
    seed_value = st.session_state.get(_last_generation_key(name), default)
    st.session_state[key] = seed_value
    return key
@dataclass(frozen=True)
class GenerationConfig:
    """Immutable bundle of the generation settings selected in the UI."""

    max_new_tokens: int
    do_sample: bool
    temperature: float
    top_p: float
    top_k: int
    repetition_penalty: float
    seed: int | None

    def to_generate_kwargs(self) -> dict[str, object]:
        """Return all fields as a plain dict keyed by field name."""
        return asdict(self)

    def to_export_dict(self) -> dict[str, object]:
        """Return the settings for export, naming ``do_sample`` as ``use_sampling``."""
        attribute_order = (
            "max_new_tokens",
            "do_sample",
            "temperature",
            "top_p",
            "top_k",
            "repetition_penalty",
            "seed",
        )
        # Only this one field is exported under a different name.
        export_names = {"do_sample": "use_sampling"}
        return {
            export_names.get(attribute, attribute): getattr(self, attribute)
            for attribute in attribute_order
        }
@dataclass(frozen=True)
class ChatTools:
    """Immutable on/off state of the optional chat tools."""

    # State of the "Probe" toggle.
    probe_enabled: bool
    # State of the "Compare mode" toggle.
    compare_mode: bool
    # Whether token contrast is in effect.
    token_contrast: bool
@st.dialog("Edit", width="medium")
def _open_edit_dialog(
    *,
    msg_index: int,
    messages: list[dict[str, str]],
    pending_key: str,
) -> None:
    """Show a modal editor for one chat message.

    Saving overwrites the message text in place, removes every message after
    it, and reruns the app. Cancelling only reruns.

    Args:
        msg_index: Position in ``messages`` of the message being edited.
        messages: The conversation; mutated in place on save.
        pending_key: Session-state key that receives ``"regenerate_after_edit"``
            when a user message is saved (presumably read by the caller to
            trigger a new reply - the consumer is outside this function).
    """
    message = messages[msg_index]
    role = message["role"]
    n_after = len(messages) - msg_index - 1
    suffix = ""
    if n_after > 0:
        noun = "message" if n_after == 1 else "messages"
        suffix = f" - {n_after} subsequent {noun} will be cleared"
    st.caption(f"**{role}**{suffix}")
    new_content = st.text_area(
        "Content",
        value=message["content"],
        height=320,
        label_visibility="collapsed",
    )
    save_col, cancel_col = st.columns(2)
    with save_col:
        if st.button("Save", type="primary", width="stretch"):
            message["content"] = new_content
            # Cached per-token renderings belong to the old text. The probe
            # overlay must go too: render_chat_message draws it in preference
            # to ``content``, so a stale one would hide the edit.
            message.pop("_contrast", None)
            message.pop("_probe_overlay", None)
            if role == "assistant":
                # NOTE(review): presumably asks for the contrast to be
                # recomputed; the reader of this flag is not in this function.
                message["_needs_contrast"] = True
            # Everything after the edited message no longer follows from it.
            del messages[msg_index + 1 :]
            if role == "user":
                st.session_state[pending_key] = "regenerate_after_edit"
            st.rerun()
    with cancel_col:
        if st.button("Cancel", width="stretch"):
            st.rerun()
@st.dialog("Edit system prompt", width="large")
def _open_system_prompt_dialog(
    *,
    prompt_key: str,
    current_value: str,
    on_save: Callable[[], None] | None = None,
) -> None:
    """Show a modal editor for the system prompt.

    Args:
        prompt_key: Session-state key that receives the edited text on save.
        current_value: Text shown in the editor when it opens.
        on_save: Optional callback invoked after the text is stored and
            before the app reruns.
    """
    edited_prompt = st.text_area(
        "System prompt",
        value=current_value,
        height=320,
        label_visibility="collapsed",
    )
    confirm_col, dismiss_col = st.columns(2)
    with confirm_col:
        save_clicked = st.button("Save", type="primary", width="stretch")
        if save_clicked:
            st.session_state[prompt_key] = edited_prompt
            if on_save is not None:
                on_save()
            st.rerun()
    with dismiss_col:
        cancel_clicked = st.button("Cancel", width="stretch")
        if cancel_clicked:
            st.rerun()
def render_advanced_settings(
    context_key: str,
    remote: bool,
    *,
    last_compare_mode_key: str,
    last_probe_enabled_key: str = "",
    last_token_contrast_key: str = "",
) -> tuple[GenerationConfig, ChatTools]:
    """Render the "Chat tools" and "Generation" expanders.

    Args:
        context_key: Prefix used to build the per-context widget keys.
        remote: Forwarded to the generation controls.
        last_compare_mode_key: Session-state key holding the last compare-mode
            value; read to seed the toggle and always written back.
        last_probe_enabled_key: Same for the probe toggle; written back only
            when non-empty.
        last_token_contrast_key: Same for the token-contrast toggle; written
            back only when non-empty.

    Returns:
        The generation settings and the chat-tool flags. Token contrast is
        reported as on only while compare mode is on.
    """
    # Seed each toggle's widget state from its last remembered value.
    toggle_keys: dict[str, str] = {}
    for toggle_name, remembered_key in (
        ("compare_mode", last_compare_mode_key),
        ("probe_enabled", last_probe_enabled_key),
        ("token_contrast", last_token_contrast_key),
    ):
        state_key = widget_key(context_key, toggle_name)
        if state_key not in st.session_state:
            st.session_state[state_key] = st.session_state.get(remembered_key, False)
        toggle_keys[toggle_name] = state_key

    with st.expander("Chat tools", expanded=False):
        probe_col, compare_col, contrast_col = st.columns(3)
        with probe_col:
            probe_enabled = st.toggle(
                "Probe",
                key=toggle_keys["probe_enabled"],
                help="Color each assistant token by a loaded `.pt` probe's prediction.",
            )
        with compare_col:
            compare_mode = st.toggle(
                "Compare mode",
                key=toggle_keys["compare_mode"],
                help="Side-by-side: send one message to two independent persona/prompt configurations.",
            )
        with contrast_col:
            token_contrast = st.toggle(
                "Token contrast",
                key=toggle_keys["token_contrast"],
                disabled=not compare_mode,
                help=(
                    "Color each generated token by how characteristic it is of each persona. "
                    "Red = more likely under the left persona, blue = more likely under the "
                    "right. Requires up to four extra scoring passes after each turn. "
                    "Available only in Compare mode."
                ),
            )

    # Remember the current choices under the caller-supplied keys.
    st.session_state[last_compare_mode_key] = compare_mode
    for remembered_key, current_value in (
        (last_probe_enabled_key, probe_enabled),
        (last_token_contrast_key, token_contrast),
    ):
        if remembered_key:
            st.session_state[remembered_key] = current_value

    with st.expander("Generation", expanded=False):
        generation = _render_generation_fragment(context_key, remote)

    tools = ChatTools(
        probe_enabled=probe_enabled,
        compare_mode=compare_mode,
        token_contrast=token_contrast and compare_mode,
    )
    return generation, tools
@st.fragment
def _render_generation_fragment(context_key: str, remote: bool) -> GenerationConfig:
    """Render the generation controls and return the chosen settings.

    Decorated with ``st.fragment`` so that adjusting a control does not rerun
    the whole app. Each widget's state is seeded through ``_persisted_key``
    and the current values are stored back under the last-generation keys.

    Args:
        context_key: Prefix used to build the per-context widget keys.
        remote: When true the seed controls are disabled and no seed is
            returned.
    """

    def state_key(name: str, default: object) -> str:
        # Per-context widget key, seeded on first use.
        return _persisted_key(context_key, name, default)

    tokens_col, penalty_col = st.columns([2, 1])
    with tokens_col:
        max_new_tokens = st.slider(
            "Max new tokens",
            min_value=16,
            max_value=512,
            step=16,
            key=state_key("max_new_tokens", GENERATION_DEFAULTS["max_new_tokens"]),
        )
    with penalty_col:
        repetition_penalty = st.slider(
            "Repetition penalty",
            min_value=0.5,
            max_value=2.0,
            step=0.05,
            key=state_key(
                "repetition_penalty", GENERATION_DEFAULTS["repetition_penalty"]
            ),
        )

    use_sampling = st.checkbox(
        "Random sampling",
        key=state_key("use_sampling", False),
    )
    sampling_off = not use_sampling

    temperature_col, top_p_col, top_k_col = st.columns(3)
    with temperature_col:
        temperature = st.slider(
            "Temperature",
            min_value=0.01,
            max_value=2.0,
            step=0.01,
            disabled=sampling_off,
            key=state_key("temperature", GENERATION_DEFAULTS["temperature"]),
        )
    with top_p_col:
        top_p = st.slider(
            "Top-p",
            min_value=0.01,
            max_value=1.0,
            step=0.01,
            disabled=sampling_off,
            key=state_key("top_p", GENERATION_DEFAULTS["top_p"]),
        )
    with top_k_col:
        top_k = st.slider(
            "Top-k (0 = off)",
            min_value=0,
            max_value=100,
            step=1,
            disabled=sampling_off,
            key=state_key("top_k", GENERATION_DEFAULTS["top_k"]),
        )

    seed_locked = sampling_off or remote
    seed_enabled = st.checkbox(
        "Fix seed",
        disabled=seed_locked,
        key=state_key("seed_enabled", False),
    )
    seed = None
    if seed_enabled:
        raw_seed = st.number_input(
            "Seed",
            min_value=0,
            max_value=2_147_483_647,
            step=1,
            disabled=seed_locked,
            key=state_key("seed", 0),
        )
        seed = int(raw_seed)
    if remote:
        st.caption("Seed is local-only and disabled for remote runs.")

    # Remember the current values; the seed is only stored when one was read.
    remembered: dict[str, object] = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "use_sampling": use_sampling,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "seed_enabled": seed_enabled,
    }
    if seed is not None:
        remembered["seed"] = seed
    for name, value in remembered.items():
        st.session_state[_last_generation_key(name)] = value

    do_sample = bool(use_sampling)
    # A seed is only passed on for local runs with sampling enabled.
    effective_seed = seed if (do_sample and not remote) else None
    return GenerationConfig(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        seed=effective_seed,
    )
def render_chat_message(
    message: dict[str, str],
    show_contrast: bool = False,
) -> None:
    """Render one chat message inside a chat bubble.

    Messages without content are skipped. Otherwise the first available of
    the following is shown: the token-contrast HTML (only when
    ``show_contrast`` is set and the message has ``_contrast``), the probe
    overlay HTML (when the message has ``_probe_overlay``), or the content
    as Markdown.
    """
    if not message.get("content"):
        return

    contrast: TokenContrast | None = None
    if show_contrast:
        contrast = message.get("_contrast")
    overlay = message.get("_probe_overlay")

    with st.chat_message(message["role"]):
        if contrast is not None:
            # Imported here rather than at module level.
            from utils.contrast import render_contrast_html

            st.html(render_contrast_html(contrast))
            return
        if overlay is not None:
            from utils.probe_overlay import render_probe_html

            st.html(render_probe_html(overlay))
            return
        st.markdown(message["content"])
def render_chat_window(
    *,
    chat_log: Any,
    messages: list[dict[str, str]],
    edit_key: str,
    pending_key: str,
    show_contrast: bool = False,
    edit_column_ratio: tuple[int, int] = (25, 1),
) -> None:
    """Render the conversation with an edit button beside each message.

    Args:
        chat_log: Context manager that the messages are rendered inside.
        messages: The conversation; entries without content are skipped.
        edit_key: Prefix for the edit buttons' widget keys.
        pending_key: Passed through to the edit dialog.
        show_contrast: Passed through to ``render_chat_message``.
        edit_column_ratio: Relative widths of the message and button columns.
    """
    with chat_log:
        # Indices stay those of the full list so the dialog edits the right entry.
        visible = (
            (position, entry)
            for position, entry in enumerate(messages)
            if entry.get("content")
        )
        for position, entry in visible:
            body_col, button_col = st.columns(
                list(edit_column_ratio), gap="xsmall", vertical_alignment="center"
            )
            with body_col:
                render_chat_message(entry, show_contrast=show_contrast)
            with button_col:
                edit_clicked = st.button(
                    "",
                    icon=":material/edit:",
                    key=f"{edit_key}_edit_{position}",
                    help="Edit",
                )
                if edit_clicked:
                    _open_edit_dialog(
                        msg_index=position,
                        messages=messages,
                        pending_key=pending_key,
                    )
def _assistant_first(personas: list[PersonaData]) -> list[PersonaData]:
    """Return a new list with assistant personas moved to the front.

    A persona counts as the assistant when its ``id`` or ``name`` equals
    ``"assistant"`` after stripping whitespace and lower-casing; a missing
    attribute is treated as empty. Relative order within each group is kept.
    """

    def normalized(persona: PersonaData, attribute: str) -> str:
        return str(getattr(persona, attribute, "")).strip().lower()

    assistants: list[PersonaData] = []
    others: list[PersonaData] = []
    for persona in personas:
        identifiers = (normalized(persona, "id"), normalized(persona, "name"))
        bucket = assistants if "assistant" in identifiers else others
        bucket.append(persona)
    return assistants + others
def render_system_prompt(
    prompt_key: str,
    prompt_mode: str,
    active_system_prompt: str | None,
    *,
    on_save: Callable[[], None] | None = None,
) -> str | None:
    """Show the system prompt in an expander and offer to edit it.

    Args:
        prompt_key: Session-state key holding the prompt text; initialised
            from ``active_system_prompt`` when absent.
        prompt_mode: The Edit button is hidden when this is ``"empty"``.
        active_system_prompt: Initial prompt text, or ``None`` for none.
        on_save: Optional callback passed to the edit dialog.

    Returns:
        The prompt stored in session state, or ``None`` when it is empty.
    """
    if prompt_key not in st.session_state:
        st.session_state[prompt_key] = active_system_prompt or ""

    shown_prompt = st.session_state.get(prompt_key) or ""
    editable = prompt_mode != "empty"
    with st.expander("System prompt"):
        st.markdown(shown_prompt or "*empty*")
        # The button is only created for editable modes.
        if editable and st.button(
            "Edit", icon=":material/edit:", key=f"{prompt_key}_edit"
        ):
            _open_system_prompt_dialog(
                prompt_key=prompt_key,
                current_value=shown_prompt,
                on_save=on_save,
            )

    stored_prompt = st.session_state.get(prompt_key)
    return stored_prompt or None
def render_persona_prompt_controls(
    personas: list[PersonaData],
    current_persona_id: str | None,
    current_prompt_mode: str,
    persona_key: str,
    prompt_key: str,
    column_widths: tuple[int, int] = (3, 2),
) -> tuple[PersonaData, str, bool]:
    """Render the persona and prompt-mode selectors side by side.

    Args:
        personas: Available personas; assistant personas are listed first.
        current_persona_id: Id of the persona to preselect; the first option
            is used when nothing matches.
        current_prompt_mode: Prompt mode key to preselect.
        persona_key: Widget key of the persona selector.
        prompt_key: Widget key of the prompt-mode selector.
        column_widths: Relative widths of the two columns.

    Returns:
        The selected persona, the selected prompt mode key, and whether
        either differs from the current values passed in.
    """
    persona_col, mode_col = st.columns(list(column_widths))

    with persona_col:
        ordered_personas = _assistant_first(personas)
        preselected = 0
        for position, candidate in enumerate(ordered_personas):
            if candidate.id == current_persona_id:
                preselected = position
                break
        selected_persona = st.selectbox(
            "Persona",
            options=ordered_personas,
            index=preselected,
            format_func=persona_label,
            key=persona_key,
        )

    with mode_col:
        preselected_label = VARIANT_LABELS[current_prompt_mode]
        chosen_label = st.selectbox(
            "Prompt",
            options=CHAT_PROMPT_MODE_LABELS,
            index=CHAT_PROMPT_MODE_LABELS.index(preselected_label),
            key=prompt_key,
        )
        prompt_mode = CHAT_PROMPT_MODE_LABEL_TO_KEY[chosen_label]

    persona_changed = current_persona_id != selected_persona.id
    mode_changed = current_prompt_mode != prompt_mode
    changed = persona_changed or mode_changed
    return selected_persona, prompt_mode, changed