| from __future__ import annotations |
|
|
| from collections.abc import Callable |
| from dataclasses import asdict, dataclass |
| from typing import TYPE_CHECKING, Any |
|
|
| import streamlit as st |
|
|
| from utils.helpers import ( |
| CHAT_PROMPT_MODE_LABEL_TO_KEY, |
| CHAT_PROMPT_MODE_LABELS, |
| VARIANT_LABELS, |
| persona_label, |
| widget_key, |
| ) |
|
|
| if TYPE_CHECKING: |
| from persona_data.synth_persona import PersonaData |
|
|
| from utils.contrast import TokenContrast |
|
|
# Fallback values for the "Generation" widgets, used when no earlier value
# for a setting has been recorded in session state.
GENERATION_DEFAULTS: dict[str, int | float] = dict(
    max_new_tokens=256,
    temperature=1.0,
    top_p=1.0,
    top_k=50,
    repetition_penalty=1.0,
)
|
|
# Namespace prefix for the session-state entries that remember the most
# recent value of each generation setting, independent of chat context.
_LAST_GEN_PREFIX = "chat:last_gen:"


def _last_generation_key(name: str) -> str:
    """Return the session-state key holding the last value of setting *name*."""
    return _LAST_GEN_PREFIX + str(name)
|
|
|
|
def _persisted_key(context_key: str, name: str, default: object) -> str:
    """Return the per-context widget key for *name*, seeding it on first use.

    If the key is not yet in ``st.session_state`` it is initialised with the
    value stored under ``_last_generation_key(name)``, or with *default* when
    no such value exists. An entry that already exists is left untouched.
    """
    state = st.session_state
    scoped_key = widget_key(context_key, name)
    if scoped_key in state:
        return scoped_key
    state[scoped_key] = state.get(_last_generation_key(name), default)
    return scoped_key
|
|
|
|
@dataclass(frozen=True)
class GenerationConfig:
    """Immutable snapshot of the generation settings.

    The fields mirror the controls of the "Generation" expander. ``seed`` is
    ``None`` whenever no fixed seed applies.
    """

    max_new_tokens: int
    do_sample: bool
    temperature: float
    top_p: float
    top_k: int
    repetition_penalty: float
    seed: int | None

    def to_generate_kwargs(self) -> dict[str, object]:
        """Return all fields as a plain dict keyed by field name."""
        kwargs: dict[str, object] = asdict(self)
        return kwargs

    def to_export_dict(self) -> dict[str, object]:
        """Return the settings for export.

        Same values as the fields, except that ``do_sample`` is reported
        under the key ``use_sampling``.
        """
        return dict(
            max_new_tokens=self.max_new_tokens,
            use_sampling=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            repetition_penalty=self.repetition_penalty,
            seed=self.seed,
        )
|
|
|
|
@dataclass(frozen=True)
class ChatTools:
    """Immutable state of the three switches in the "Chat tools" expander."""

    # State of the "Probe" toggle.
    probe_enabled: bool
    # State of the "Compare mode" toggle.
    compare_mode: bool
    # State of the "Token contrast" toggle.
    token_contrast: bool
|
|
|
|
@st.dialog("Edit", width="medium")
def _open_edit_dialog(
    *,
    msg_index: int,
    messages: list[dict[str, str]],
    pending_key: str,
) -> None:
    """Show a modal for editing one chat message in place.

    Saving overwrites the message content, removes every message after it,
    and discards the cached per-token renderings (``_contrast`` and
    ``_probe_overlay``) that were computed for the previous text. Both Save
    and Cancel end with ``st.rerun()``.

    Args:
        msg_index: Position of the message to edit within *messages*.
        messages: Conversation history; mutated in place on save.
        pending_key: Session-state key that receives
            ``"regenerate_after_edit"`` when a user message is saved.
    """
    message = messages[msg_index]
    role = message["role"]
    n_after = len(messages) - msg_index - 1
    suffix = (
        f" - {n_after} subsequent {'message' if n_after == 1 else 'messages'} will be cleared"
        if n_after > 0
        else ""
    )
    st.caption(f"**{role}**{suffix}")

    new_content = st.text_area(
        "Content",
        value=message["content"],
        height=320,
        label_visibility="collapsed",
    )

    save_col, cancel_col = st.columns(2)
    with save_col:
        if st.button("Save", type="primary", width="stretch"):
            message["content"] = new_content
            # Cached token renderings describe the old text. The probe overlay
            # has to go as well: render_chat_message shows a stored overlay in
            # preference to the content, so a stale one would hide the edit.
            message.pop("_contrast", None)
            message.pop("_probe_overlay", None)
            if role == "assistant":
                # Flag the edited reply so its contrast can be recomputed.
                message["_needs_contrast"] = True
            # Everything after the edited message no longer follows from it.
            del messages[msg_index + 1 :]
            if role == "user":
                st.session_state[pending_key] = "regenerate_after_edit"
            st.rerun()
    with cancel_col:
        if st.button("Cancel", width="stretch"):
            st.rerun()
|
|
|
|
@st.dialog("Edit system prompt", width="large")
def _open_system_prompt_dialog(
    *,
    prompt_key: str,
    current_value: str,
    on_save: Callable[[], None] | None = None,
) -> None:
    """Show a modal for editing the system prompt.

    Args:
        prompt_key: Session-state key that receives the edited text on save.
        current_value: Text shown in the editor when the dialog opens.
        on_save: Optional callback invoked after the new text is stored and
            before the rerun.
    """
    edited_prompt = st.text_area(
        "System prompt",
        value=current_value,
        height=320,
        label_visibility="collapsed",
    )

    confirm_col, dismiss_col = st.columns(2)
    with confirm_col:
        save_clicked = st.button("Save", type="primary", width="stretch")
        if save_clicked:
            st.session_state[prompt_key] = edited_prompt
            if on_save is not None:
                on_save()
            st.rerun()
    with dismiss_col:
        cancel_clicked = st.button("Cancel", width="stretch")
        if cancel_clicked:
            st.rerun()
|
|
|
|
def render_advanced_settings(
    context_key: str,
    remote: bool,
    *,
    last_compare_mode_key: str,
    last_probe_enabled_key: str = "",
    last_token_contrast_key: str = "",
) -> tuple[GenerationConfig, ChatTools]:
    """Render the "Chat tools" and "Generation" expanders.

    Each toggle is keyed per *context_key* and, on first use, seeded from the
    value stored under the matching ``last_*_key`` (``False`` if absent).
    After rendering, the current toggle values are written back to those
    keys; the probe and token-contrast values are only written when their
    key is non-empty.

    Args:
        context_key: Scope used to build the per-context widget keys.
        remote: Forwarded to the generation controls.
        last_compare_mode_key: Session-state key remembering compare mode.
        last_probe_enabled_key: Session-state key remembering the probe toggle.
        last_token_contrast_key: Session-state key remembering token contrast.

    Returns:
        The generation configuration and the chat-tool switches. Token
        contrast is only reported as on when compare mode is on as well.
    """
    state = st.session_state

    # (toggle name, key remembering its last value), in seeding order.
    remembered = (
        ("compare_mode", last_compare_mode_key),
        ("probe_enabled", last_probe_enabled_key),
        ("token_contrast", last_token_contrast_key),
    )
    toggle_keys: dict[str, str] = {}
    for toggle_name, last_key in remembered:
        scoped_key = widget_key(context_key, toggle_name)
        if scoped_key not in state:
            state[scoped_key] = state.get(last_key, False)
        toggle_keys[toggle_name] = scoped_key

    with st.expander("Chat tools", expanded=False):
        probe_col, compare_col, contrast_col = st.columns(3)
        with probe_col:
            probe_enabled = st.toggle(
                "Probe",
                key=toggle_keys["probe_enabled"],
                help="Color each assistant token by a loaded `.pt` probe's prediction.",
            )
        with compare_col:
            compare_mode = st.toggle(
                "Compare mode",
                key=toggle_keys["compare_mode"],
                help="Side-by-side: send one message to two independent persona/prompt configurations.",
            )
        with contrast_col:
            contrast_help = (
                "Color each generated token by how characteristic it is of each persona. "
                "Red = more likely under the left persona, blue = more likely under the "
                "right. Requires up to four extra scoring passes after each turn. "
                "Available only in Compare mode."
            )
            token_contrast = st.toggle(
                "Token contrast",
                key=toggle_keys["token_contrast"],
                disabled=not compare_mode,
                help=contrast_help,
            )

    # Remember the latest choices so other contexts can be seeded from them.
    state[last_compare_mode_key] = compare_mode
    if last_probe_enabled_key:
        state[last_probe_enabled_key] = probe_enabled
    if last_token_contrast_key:
        state[last_token_contrast_key] = token_contrast

    with st.expander("Generation", expanded=False):
        generation = _render_generation_fragment(context_key, remote)

    return generation, ChatTools(
        probe_enabled=probe_enabled,
        compare_mode=compare_mode,
        token_contrast=token_contrast and compare_mode,
    )
|
|
|
|
@st.fragment
def _render_generation_fragment(context_key: str, remote: bool) -> GenerationConfig:
    """Render the generation controls and return the resulting configuration.

    The controls live in a fragment so that adjusting them does not trigger a
    full rerun. Every widget is keyed per *context_key* and seeded from the
    last value recorded for the same setting; the current values are written
    back to those "last value" entries on each run.

    Args:
        context_key: Scope used to build the per-context widget keys.
        remote: When true, the seed controls are disabled and no seed is
            returned.

    Returns:
        The configuration described by the current widget values. The seed is
        only set when sampling is on, a seed is fixed, and the run is local.
    """

    def seeded(name: str, fallback: object) -> str:
        """Per-context widget key for *name*, initialised on first use."""
        return _persisted_key(context_key, name, fallback)

    def seeded_default(name: str) -> str:
        """Like ``seeded``, with the fallback taken from GENERATION_DEFAULTS."""
        return seeded(name, GENERATION_DEFAULTS[name])

    wide_col, narrow_col = st.columns([2, 1])
    with wide_col:
        max_new_tokens = st.slider(
            "Max new tokens",
            min_value=16,
            max_value=512,
            step=16,
            key=seeded_default("max_new_tokens"),
        )
    with narrow_col:
        repetition_penalty = st.slider(
            "Repetition penalty",
            min_value=0.5,
            max_value=2.0,
            step=0.05,
            key=seeded_default("repetition_penalty"),
        )

    use_sampling = st.checkbox(
        "Random sampling",
        key=seeded("use_sampling", False),
    )

    # The sampling sliders stay visible but are greyed out without sampling.
    sampling_off = not use_sampling
    temperature_col, top_p_col, top_k_col = st.columns(3)
    with temperature_col:
        temperature = st.slider(
            "Temperature",
            min_value=0.01,
            max_value=2.0,
            step=0.01,
            disabled=sampling_off,
            key=seeded_default("temperature"),
        )
    with top_p_col:
        top_p = st.slider(
            "Top-p",
            min_value=0.01,
            max_value=1.0,
            step=0.01,
            disabled=sampling_off,
            key=seeded_default("top_p"),
        )
    with top_k_col:
        top_k = st.slider(
            "Top-k (0 = off)",
            min_value=0,
            max_value=100,
            step=1,
            disabled=sampling_off,
            key=seeded_default("top_k"),
        )

    seed_locked = sampling_off or remote
    seed_enabled = st.checkbox(
        "Fix seed",
        disabled=seed_locked,
        key=seeded("seed_enabled", False),
    )
    seed = None
    if seed_enabled:
        raw_seed = st.number_input(
            "Seed",
            min_value=0,
            max_value=2_147_483_647,
            step=1,
            disabled=seed_locked,
            key=seeded("seed", 0),
        )
        seed = int(raw_seed)

    if remote:
        st.caption("Seed is local-only and disabled for remote runs.")

    # Record the current values as the cross-context "last used" settings.
    latest_values = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "use_sampling": use_sampling,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "seed_enabled": seed_enabled,
    }
    for setting, current in latest_values.items():
        st.session_state[_last_generation_key(setting)] = current
    if seed is not None:
        st.session_state[_last_generation_key("seed")] = seed

    do_sample = bool(use_sampling)
    seed_applies = do_sample and seed is not None and not remote
    return GenerationConfig(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        seed=seed if seed_applies else None,
    )
|
|
|
|
def render_chat_message(
    message: dict[str, str],
    show_contrast: bool = False,
) -> None:
    """Render one chat message inside ``st.chat_message``.

    Messages without content are skipped. A stored ``_contrast`` is rendered
    when *show_contrast* is set; otherwise a stored ``_probe_overlay`` is
    rendered; otherwise the content is shown as markdown.
    """
    if not message.get("content"):
        return

    contrast: TokenContrast | None = None
    if show_contrast:
        contrast = message.get("_contrast")
    overlay = message.get("_probe_overlay")

    with st.chat_message(message["role"]):
        if contrast is None and overlay is None:
            st.markdown(message["content"])
        elif contrast is not None:
            # Imported lazily, only when a contrast rendering is needed.
            from utils.contrast import render_contrast_html

            st.html(render_contrast_html(contrast))
        else:
            # Imported lazily, only when a probe overlay is needed.
            from utils.probe_overlay import render_probe_html

            st.html(render_probe_html(overlay))
|
|
|
|
def render_chat_window(
    *,
    chat_log: Any,
    messages: list[dict[str, str]],
    edit_key: str,
    pending_key: str,
    show_contrast: bool = False,
    edit_column_ratio: tuple[int, int] = (25, 1),
) -> None:
    """Render the conversation with an edit button next to each message.

    Args:
        chat_log: Context manager the messages are rendered into.
        messages: Conversation history; entries without content are skipped.
        edit_key: Prefix for the per-message edit-button widget keys.
        pending_key: Session-state key forwarded to the edit dialog.
        show_contrast: Forwarded to ``render_chat_message``.
        edit_column_ratio: Relative widths of the message and button columns.
    """
    with chat_log:
        # ``position`` indexes the full list, including skipped entries, so
        # it stays valid as the edit dialog's message index.
        for position, entry in enumerate(messages):
            if entry.get("content"):
                body_col, button_col = st.columns(
                    list(edit_column_ratio), gap="xsmall", vertical_alignment="center"
                )
                with body_col:
                    render_chat_message(entry, show_contrast=show_contrast)
                with button_col:
                    edit_clicked = st.button(
                        "",
                        icon=":material/edit:",
                        key=f"{edit_key}_edit_{position}",
                        help="Edit",
                    )
                    if edit_clicked:
                        _open_edit_dialog(
                            msg_index=position,
                            messages=messages,
                            pending_key=pending_key,
                        )
|
|
|
|
def _assistant_first(personas: list[PersonaData]) -> list[PersonaData]:
    """Return a new list with the assistant personas moved to the front.

    A persona counts as the assistant when its ``id`` or ``name`` attribute,
    stripped and lower-cased, equals ``"assistant"``; a missing attribute is
    treated as empty. Relative order within each group is preserved.
    """
    assistants: list[PersonaData] = []
    others: list[PersonaData] = []
    for persona in personas:
        labels = {
            str(getattr(persona, attr, "")).strip().lower()
            for attr in ("id", "name")
        }
        bucket = assistants if "assistant" in labels else others
        bucket.append(persona)
    return assistants + others
|
|
|
|
def render_system_prompt(
    prompt_key: str,
    prompt_mode: str,
    active_system_prompt: str | None,
    *,
    on_save: Callable[[], None] | None = None,
) -> str | None:
    """Show the system prompt in an expander, with an optional edit button.

    The session-state entry under *prompt_key* is initialised from
    *active_system_prompt* on first use. The edit button is not rendered when
    *prompt_mode* is ``"empty"``.

    Returns:
        The prompt stored under *prompt_key*, or ``None`` when it is empty.
    """
    state = st.session_state
    if prompt_key not in state:
        state[prompt_key] = active_system_prompt or ""
    shown_prompt = state.get(prompt_key) or ""

    with st.expander("System prompt"):
        st.markdown(shown_prompt or "*empty*")
        if prompt_mode != "empty":
            edit_clicked = st.button(
                "Edit", icon=":material/edit:", key=f"{prompt_key}_edit"
            )
            if edit_clicked:
                _open_system_prompt_dialog(
                    prompt_key=prompt_key,
                    current_value=shown_prompt,
                    on_save=on_save,
                )

    return state.get(prompt_key) or None
|
|
|
|
def render_persona_prompt_controls(
    personas: list[PersonaData],
    current_persona_id: str | None,
    current_prompt_mode: str,
    persona_key: str,
    prompt_key: str,
    column_widths: tuple[int, int] = (3, 2),
) -> tuple[PersonaData, str, bool]:
    """Render the persona and prompt-mode selectors side by side.

    Args:
        personas: Selectable personas; assistant personas are listed first.
        current_persona_id: Id of the persona to preselect (first option if
            no persona matches).
        current_prompt_mode: Prompt-mode key to preselect.
        persona_key: Widget key of the persona selector.
        prompt_key: Widget key of the prompt-mode selector.
        column_widths: Relative widths of the two selector columns.

    Returns:
        The selected persona, the selected prompt-mode key, and whether either
        differs from the current values passed in.
    """
    persona_col, mode_col = st.columns(list(column_widths))

    with persona_col:
        ordered_personas = _assistant_first(personas)
        preselected = 0
        for position, candidate in enumerate(ordered_personas):
            if candidate.id == current_persona_id:
                preselected = position
                break
        selected_persona = st.selectbox(
            "Persona",
            options=ordered_personas,
            index=preselected,
            format_func=persona_label,
            key=persona_key,
        )

    with mode_col:
        active_label = VARIANT_LABELS[current_prompt_mode]
        active_label_index = CHAT_PROMPT_MODE_LABELS.index(active_label)
        chosen_label = st.selectbox(
            "Prompt",
            options=CHAT_PROMPT_MODE_LABELS,
            index=active_label_index,
            key=prompt_key,
        )

    prompt_mode = CHAT_PROMPT_MODE_LABEL_TO_KEY[chosen_label]
    if current_persona_id != selected_persona.id:
        changed = True
    else:
        changed = current_prompt_mode != prompt_mode
    return selected_persona, prompt_mode, changed
|
|