Jac-Zac commited on
Commit ·
a9950fb
1
Parent(s): e2cecb1
Updated to new chat edit and comparison
Browse files- Cleaned up chat
- Improved flow and UI
- Added comparison
- Fixed remote to make it more robus
- README.md +6 -2
- state.py +19 -13
- tabs/chat.py +183 -416
- tabs/compare_chat.py +443 -0
- tabs/extract.py +0 -2
- utils/chat.py +3 -7
- utils/chat_export.py +1 -1
- utils/contrast.py +311 -0
- utils/helpers.py +0 -2
- utils/runtime.py +39 -10
README.md
CHANGED
|
@@ -29,13 +29,17 @@ A web app built on top of [persona-vectors](../persona-vectors) that provides th
|
|
| 29 |
persona-ui/
|
| 30 |
├── app.py # Main entry point (Streamlit)
|
| 31 |
├── state.py # Session state management (chat history, KV cache)
|
|
|
|
|
|
|
| 32 |
├── tabs/
|
| 33 |
│ ├── chat.py # Chat tab
|
| 34 |
│ ├── compare.py # Activation comparison tab
|
|
|
|
| 35 |
│ └── extract.py # Extraction tab
|
| 36 |
└── utils/
|
| 37 |
├── chat.py # Chat generation logic
|
| 38 |
├── chat_export.py # Export chat logs to JSON
|
|
|
|
| 39 |
├── datasets.py # Dataset loader wrapper
|
| 40 |
├── helpers.py # UI labels and slug helpers
|
| 41 |
└── runtime.py # Model caching and NDIF queries
|
|
@@ -121,8 +125,8 @@ artifacts/
|
|
| 121 |
├── activations/<model_dir>/<prompt_variant>/<persona_id>/
|
| 122 |
│ ├── activations.safetensors
|
| 123 |
│ └── metadata.json # used for persona names and layer counts
|
| 124 |
-
└── chats/<model_dir>/<
|
| 125 |
└── <export>.json
|
| 126 |
```
|
| 127 |
|
| 128 |
-
`<model_dir>` is the model name with `/` replaced by `__` (e.g. `google__gemma-2-9b-it`).
|
|
|
|
| 29 |
persona-ui/
|
| 30 |
├── app.py # Main entry point (Streamlit)
|
| 31 |
├── state.py # Session state management (chat history, KV cache)
|
| 32 |
+
├── scripts/
|
| 33 |
+
│ └── oracle_probe.py # Notebook-style activation oracle script
|
| 34 |
├── tabs/
|
| 35 |
│ ├── chat.py # Chat tab
|
| 36 |
│ ├── compare.py # Activation comparison tab
|
| 37 |
+
│ ├── compare_chat.py # Side-by-side chat comparison mode
|
| 38 |
│ └── extract.py # Extraction tab
|
| 39 |
└── utils/
|
| 40 |
├── chat.py # Chat generation logic
|
| 41 |
├── chat_export.py # Export chat logs to JSON
|
| 42 |
+
├── contrast.py # Contrastive token log-prob coloring
|
| 43 |
├── datasets.py # Dataset loader wrapper
|
| 44 |
├── helpers.py # UI labels and slug helpers
|
| 45 |
└── runtime.py # Model caching and NDIF queries
|
|
|
|
| 125 |
├── activations/<model_dir>/<prompt_variant>/<persona_id>/
|
| 126 |
│ ├── activations.safetensors
|
| 127 |
│ └── metadata.json # used for persona names and layer counts
|
| 128 |
+
└── chats/<model_dir>/<persona_id>/
|
| 129 |
└── <export>.json
|
| 130 |
```
|
| 131 |
|
| 132 |
+
`<model_dir>` is the model name with `/` replaced by `__` (e.g. `google__gemma-2-9b-it`). Chat exports still store `dataset_source` in the JSON payload.
|
state.py
CHANGED
|
@@ -9,7 +9,7 @@ def chat_session_key(model_name: str, dataset_source: str) -> str:
|
|
| 9 |
return f"{_CHAT_STATE_PREFIX}{model_name}::{dataset_source}"
|
| 10 |
|
| 11 |
|
| 12 |
-
def
|
| 13 |
return {
|
| 14 |
"messages": [],
|
| 15 |
"persona_id": None,
|
|
@@ -18,6 +18,22 @@ def _default_chat_state() -> dict[str, object]:
|
|
| 18 |
}
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def _evict_inactive_kv_caches(active_key: str) -> None:
|
| 22 |
"""Drop past_key_values from every chat context except the active one."""
|
| 23 |
|
|
@@ -40,22 +56,12 @@ def get_chat_state(
|
|
| 40 |
key = chat_session_key(model_name, dataset_source)
|
| 41 |
state = st.session_state.get(key)
|
| 42 |
if state is None:
|
| 43 |
-
state =
|
| 44 |
st.session_state[key] = state
|
| 45 |
else:
|
| 46 |
-
for default_key, default_value in
|
| 47 |
state.setdefault(default_key, default_value)
|
| 48 |
_evict_inactive_kv_caches(key)
|
| 49 |
if remote and state.get("past_key_values") is not None:
|
| 50 |
state["past_key_values"] = None
|
| 51 |
return state
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def reset_chat_state(model_name: str, dataset_source: str) -> None:
|
| 55 |
-
"""Reset chat history and cache for the active context."""
|
| 56 |
-
|
| 57 |
-
key = chat_session_key(model_name, dataset_source)
|
| 58 |
-
if key in st.session_state:
|
| 59 |
-
state = st.session_state[key]
|
| 60 |
-
state["messages"] = []
|
| 61 |
-
state["past_key_values"] = None
|
|
|
|
| 9 |
return f"{_CHAT_STATE_PREFIX}{model_name}::{dataset_source}"
|
| 10 |
|
| 11 |
|
| 12 |
+
def default_chat_state() -> dict[str, object]:
|
| 13 |
return {
|
| 14 |
"messages": [],
|
| 15 |
"persona_id": None,
|
|
|
|
| 18 |
}
|
| 19 |
|
| 20 |
|
| 21 |
+
def reset_chat_context_state(
|
| 22 |
+
state: dict[str, object],
|
| 23 |
+
persona_id: str,
|
| 24 |
+
prompt_mode: str,
|
| 25 |
+
*ui_keys: str,
|
| 26 |
+
) -> None:
|
| 27 |
+
"""Reset one chat context and clear any related widget state."""
|
| 28 |
+
|
| 29 |
+
state["messages"] = []
|
| 30 |
+
state["past_key_values"] = None
|
| 31 |
+
state["persona_id"] = persona_id
|
| 32 |
+
state["prompt_mode"] = prompt_mode
|
| 33 |
+
for key in ui_keys:
|
| 34 |
+
st.session_state.pop(key, None)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
def _evict_inactive_kv_caches(active_key: str) -> None:
|
| 38 |
"""Drop past_key_values from every chat context except the active one."""
|
| 39 |
|
|
|
|
| 56 |
key = chat_session_key(model_name, dataset_source)
|
| 57 |
state = st.session_state.get(key)
|
| 58 |
if state is None:
|
| 59 |
+
state = default_chat_state()
|
| 60 |
st.session_state[key] = state
|
| 61 |
else:
|
| 62 |
+
for default_key, default_value in default_chat_state().items():
|
| 63 |
state.setdefault(default_key, default_value)
|
| 64 |
_evict_inactive_kv_caches(key)
|
| 65 |
if remote and state.get("past_key_values") is not None:
|
| 66 |
state["past_key_values"] = None
|
| 67 |
return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tabs/chat.py
CHANGED
|
@@ -1,72 +1,109 @@
|
|
| 1 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
from typing import Any
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
from persona_data.synth_persona import PersonaData
|
| 6 |
|
| 7 |
-
from state import
|
| 8 |
-
_default_chat_state,
|
| 9 |
-
chat_session_key,
|
| 10 |
-
get_chat_state,
|
| 11 |
-
reset_chat_state,
|
| 12 |
-
)
|
| 13 |
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
| 14 |
from utils.chat_export import save_chat_export
|
|
|
|
| 15 |
from utils.datasets import load_dataset
|
| 16 |
from utils.helpers import (
|
| 17 |
MODE_LABEL_TO_KEY,
|
| 18 |
MODE_LABELS,
|
| 19 |
VARIANT_LABELS,
|
| 20 |
-
VISIBLE_MESSAGE_COUNT,
|
| 21 |
persona_label,
|
| 22 |
widget_key,
|
| 23 |
)
|
| 24 |
from utils.runtime import cached_model
|
| 25 |
|
| 26 |
-
COLLAPSED_MESSAGE_CHAR_LIMIT = 500
|
| 27 |
-
|
| 28 |
|
| 29 |
def _render_collapsible_markdown(content: str) -> None:
|
| 30 |
-
|
| 31 |
-
st.markdown(content)
|
| 32 |
-
return
|
| 33 |
|
| 34 |
-
with st.expander(f"Show full text ({len(content)} chars)", expanded=False):
|
| 35 |
-
st.markdown(content)
|
| 36 |
|
|
|
|
| 37 |
|
| 38 |
-
def _render_chat_message(message: dict[str, str]) -> None:
|
| 39 |
-
if not message.get("content"):
|
| 40 |
-
return
|
| 41 |
-
with st.container(border=True):
|
| 42 |
-
st.caption(message["role"])
|
| 43 |
-
_render_collapsible_markdown(message["content"])
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
)
|
| 52 |
-
"""Render the system prompt as an always-editable text area at the top of the chat."""
|
| 53 |
-
if prompt_mode == "empty":
|
| 54 |
-
return active_system_prompt
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
with st.container(border=True):
|
| 60 |
-
st.caption("System prompt")
|
| 61 |
-
st.text_area(
|
| 62 |
-
"system_prompt_edit",
|
| 63 |
-
value=st.session_state[prompt_key],
|
| 64 |
-
height=height,
|
| 65 |
-
label_visibility="collapsed",
|
| 66 |
-
key=prompt_key,
|
| 67 |
-
)
|
| 68 |
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
def _render_editable_message(
|
|
@@ -76,63 +113,50 @@ def _render_editable_message(
|
|
| 76 |
chat_state: dict[str, object],
|
| 77 |
edit_key: str,
|
| 78 |
pending_key: str,
|
|
|
|
|
|
|
| 79 |
) -> None:
|
| 80 |
-
"""Render a single message with an inline edit button."""
|
| 81 |
if not message.get("content"):
|
| 82 |
return
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
)
|
| 96 |
-
c1, c2 = st.columns(2)
|
| 97 |
-
with c1:
|
| 98 |
-
if st.button(
|
| 99 |
-
"Save", key=f"{edit_key}_msg_save_{msg_index}", type="primary"
|
| 100 |
-
):
|
| 101 |
-
messages[msg_index]["content"] = new_content
|
| 102 |
-
del messages[msg_index + 1 :]
|
| 103 |
-
chat_state["past_key_values"] = None
|
| 104 |
-
st.session_state[edit_key] = None
|
| 105 |
-
if message["role"] == "user":
|
| 106 |
-
st.session_state[pending_key] = True
|
| 107 |
-
st.rerun()
|
| 108 |
-
with c2:
|
| 109 |
-
if st.button("Cancel", key=f"{edit_key}_msg_cancel_{msg_index}"):
|
| 110 |
-
st.session_state[edit_key] = None
|
| 111 |
-
st.rerun()
|
| 112 |
-
else:
|
| 113 |
-
st.markdown(message["content"])
|
| 114 |
-
if st.button("Edit", key=f"{edit_key}_msg_edit_{msg_index}"):
|
| 115 |
-
st.session_state[edit_key] = msg_index
|
| 116 |
-
st.rerun()
|
| 117 |
-
|
| 118 |
|
| 119 |
-
def _clear_chat_ui_state(*keys: str) -> None:
|
| 120 |
-
for key in keys:
|
| 121 |
-
st.session_state.pop(key, None)
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
model_name: str,
|
| 126 |
-
dataset_source: str,
|
| 127 |
-
chat_state: dict[str, object],
|
| 128 |
-
persona_id: str,
|
| 129 |
prompt_mode: str,
|
| 130 |
-
|
| 131 |
-
) -> None:
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
|
| 138 |
def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
|
|
@@ -189,42 +213,27 @@ def _render_chat_window(
|
|
| 189 |
*,
|
| 190 |
chat_log: Any,
|
| 191 |
messages: list[dict[str, str]],
|
| 192 |
-
show_all_key: str,
|
| 193 |
-
show_all_btn_key: str,
|
| 194 |
-
show_earlier_label: str,
|
| 195 |
chat_state: dict[str, object] | None = None,
|
| 196 |
edit_key: str | None = None,
|
| 197 |
pending_key: str | None = None,
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
with chat_log:
|
| 202 |
-
|
| 203 |
-
show_all_key, False
|
| 204 |
-
):
|
| 205 |
-
hidden_count = len(messages) - VISIBLE_MESSAGE_COUNT
|
| 206 |
-
if st.button(
|
| 207 |
-
f"{show_earlier_label} ({hidden_count} hidden)",
|
| 208 |
-
key=show_all_btn_key,
|
| 209 |
-
):
|
| 210 |
-
st.session_state[show_all_key] = True
|
| 211 |
-
st.rerun()
|
| 212 |
-
visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
|
| 213 |
-
index_offset = len(messages) - VISIBLE_MESSAGE_COUNT
|
| 214 |
-
else:
|
| 215 |
-
visible_messages = messages
|
| 216 |
-
index_offset = 0
|
| 217 |
-
|
| 218 |
-
for i, message in enumerate(visible_messages):
|
| 219 |
-
actual_index = index_offset + i
|
| 220 |
if edit_key and pending_key:
|
| 221 |
_render_editable_message(
|
| 222 |
-
message,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
)
|
| 224 |
else:
|
| 225 |
-
_render_chat_message(message)
|
| 226 |
-
|
| 227 |
-
return chat_log
|
| 228 |
|
| 229 |
|
| 230 |
def _build_chat_messages(
|
|
@@ -247,8 +256,8 @@ def _save_chat_export_message(
|
|
| 247 |
messages: list[dict[str, str]],
|
| 248 |
generation: dict[str, object],
|
| 249 |
panel_label: str | None = None,
|
| 250 |
-
) ->
|
| 251 |
-
|
| 252 |
model_name=model_name,
|
| 253 |
dataset_source=dataset_source,
|
| 254 |
persona_id=persona_id,
|
|
@@ -259,230 +268,12 @@ def _save_chat_export_message(
|
|
| 259 |
messages=messages,
|
| 260 |
generation=generation,
|
| 261 |
)
|
| 262 |
-
return f"Saved chat export to {export_path}"
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
# ── Compare mode helpers ───────────────────────────────────────────────────────
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
def _panel_state(panel_key: str) -> dict:
|
| 269 |
-
"""Get or initialise compare-panel chat state stored in session_state."""
|
| 270 |
-
if panel_key not in st.session_state:
|
| 271 |
-
st.session_state[panel_key] = _default_chat_state()
|
| 272 |
-
return st.session_state[panel_key]
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
def _render_compare_mode(
|
| 276 |
-
remote: bool,
|
| 277 |
-
model_name: str,
|
| 278 |
-
context_key: str,
|
| 279 |
-
dataset_source: str,
|
| 280 |
-
personas: list[PersonaData],
|
| 281 |
-
gen_kwargs: dict,
|
| 282 |
-
advanced_generation: bool,
|
| 283 |
-
) -> None:
|
| 284 |
-
"""Render the full side-by-side comparison UI."""
|
| 285 |
-
left_col, right_col = st.columns(2)
|
| 286 |
-
|
| 287 |
-
def render_panel(side: str) -> tuple[dict[str, object], Any, str | None, str]:
|
| 288 |
-
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 289 |
-
state = _panel_state(panel_key)
|
| 290 |
-
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 291 |
-
show_all_key = widget_key(panel_key, "show_all")
|
| 292 |
-
edit_key = widget_key(panel_key, "edit_idx")
|
| 293 |
-
pending_regen_key = widget_key(panel_key, "pending_regen")
|
| 294 |
-
|
| 295 |
-
selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
|
| 296 |
-
personas,
|
| 297 |
-
state["persona_id"],
|
| 298 |
-
state["prompt_mode"],
|
| 299 |
-
widget_key(panel_key, "persona"),
|
| 300 |
-
widget_key(panel_key, "prompt_mode"),
|
| 301 |
-
)
|
| 302 |
-
if changed:
|
| 303 |
-
state["messages"] = []
|
| 304 |
-
state["past_key_values"] = None
|
| 305 |
-
state["persona_id"] = selected_persona.id
|
| 306 |
-
state["prompt_mode"] = prompt_mode
|
| 307 |
-
_clear_chat_ui_state(prompt_key, show_all_key)
|
| 308 |
-
st.session_state.pop(edit_key, None)
|
| 309 |
-
|
| 310 |
-
active_system_prompt = resolve_system_prompt(
|
| 311 |
-
persona=selected_persona, mode=prompt_mode
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
btn_col1, btn_col2 = st.columns(2)
|
| 315 |
-
with btn_col1:
|
| 316 |
-
if st.button(
|
| 317 |
-
"Export chat", key=widget_key(panel_key, "export_chat"), width="stretch"
|
| 318 |
-
):
|
| 319 |
-
st.success(
|
| 320 |
-
_save_chat_export_message(
|
| 321 |
-
model_name=model_name,
|
| 322 |
-
dataset_source=dataset_source,
|
| 323 |
-
persona_id=selected_persona.id,
|
| 324 |
-
persona_name=getattr(selected_persona, "name", None),
|
| 325 |
-
prompt_mode=prompt_mode,
|
| 326 |
-
system_prompt=active_system_prompt,
|
| 327 |
-
messages=state["messages"],
|
| 328 |
-
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 329 |
-
panel_label=side,
|
| 330 |
-
)
|
| 331 |
-
)
|
| 332 |
-
with btn_col2:
|
| 333 |
-
if st.button(
|
| 334 |
-
"Reset chat",
|
| 335 |
-
key=widget_key(panel_key, "reset"),
|
| 336 |
-
width="stretch",
|
| 337 |
-
type="secondary",
|
| 338 |
-
):
|
| 339 |
-
state["messages"] = []
|
| 340 |
-
state["past_key_values"] = None
|
| 341 |
-
_clear_chat_ui_state(prompt_key, show_all_key)
|
| 342 |
-
st.session_state.pop(edit_key, None)
|
| 343 |
-
st.rerun()
|
| 344 |
-
|
| 345 |
-
chat_log = st.container()
|
| 346 |
-
with chat_log:
|
| 347 |
-
active_system_prompt = _render_inline_system_prompt(
|
| 348 |
-
prompt_key,
|
| 349 |
-
prompt_mode,
|
| 350 |
-
active_system_prompt,
|
| 351 |
-
height=150,
|
| 352 |
-
)
|
| 353 |
-
_render_chat_window(
|
| 354 |
-
chat_log=chat_log,
|
| 355 |
-
messages=state["messages"],
|
| 356 |
-
show_all_key=show_all_key,
|
| 357 |
-
show_all_btn_key=widget_key(panel_key, "show_all_btn"),
|
| 358 |
-
show_earlier_label="Show earlier",
|
| 359 |
-
chat_state=state,
|
| 360 |
-
edit_key=edit_key,
|
| 361 |
-
pending_key=pending_regen_key,
|
| 362 |
-
)
|
| 363 |
-
return state, chat_log, active_system_prompt, pending_regen_key
|
| 364 |
-
|
| 365 |
-
with left_col:
|
| 366 |
-
left_state, left_log, left_prompt, left_pending = render_panel("left")
|
| 367 |
-
with right_col:
|
| 368 |
-
right_state, right_log, right_prompt, right_pending = render_panel("right")
|
| 369 |
-
|
| 370 |
-
panels = [
|
| 371 |
-
(left_state, left_log, left_prompt, left_pending),
|
| 372 |
-
(right_state, right_log, right_prompt, right_pending),
|
| 373 |
-
]
|
| 374 |
-
|
| 375 |
-
# Handle per-panel regeneration triggered by message edits
|
| 376 |
-
any_regen = any(st.session_state.get(p_pending) for _, _, _, p_pending in panels)
|
| 377 |
-
if any_regen:
|
| 378 |
-
model = cached_model(model_name=model_name, remote=remote)
|
| 379 |
-
for panel_state, panel_log, panel_prompt, p_pending in panels:
|
| 380 |
-
if not st.session_state.pop(p_pending, False):
|
| 381 |
-
continue
|
| 382 |
-
regen_messages = _build_chat_messages(panel_prompt, panel_state["messages"])
|
| 383 |
-
with st.spinner("Regenerating..."):
|
| 384 |
-
try:
|
| 385 |
-
result = generate_chat_reply(
|
| 386 |
-
model=model,
|
| 387 |
-
messages=regen_messages,
|
| 388 |
-
remote=remote,
|
| 389 |
-
past_key_values=panel_state["past_key_values"],
|
| 390 |
-
**gen_kwargs,
|
| 391 |
-
)
|
| 392 |
-
except Exception as exc:
|
| 393 |
-
with panel_log:
|
| 394 |
-
st.error(f"Generation failed: {exc}")
|
| 395 |
-
panel_state["messages"].pop()
|
| 396 |
-
continue
|
| 397 |
-
panel_state["messages"].append(
|
| 398 |
-
{"role": "assistant", "content": result.text}
|
| 399 |
-
)
|
| 400 |
-
panel_state["past_key_values"] = (
|
| 401 |
-
result.past_key_values if not remote else None
|
| 402 |
-
)
|
| 403 |
-
with panel_log:
|
| 404 |
-
_render_chat_message({"role": "assistant", "content": result.text})
|
| 405 |
-
st.rerun()
|
| 406 |
-
|
| 407 |
-
user_prompt = st.chat_input(
|
| 408 |
-
"Ask both...",
|
| 409 |
-
key=widget_key(context_key, "cmp_input"),
|
| 410 |
-
)
|
| 411 |
-
if not user_prompt:
|
| 412 |
-
return
|
| 413 |
-
|
| 414 |
-
model = cached_model(model_name=model_name, remote=remote)
|
| 415 |
-
|
| 416 |
-
for panel_state, panel_log, _panel_prompt, _p_pending in panels:
|
| 417 |
-
panel_state["messages"].append({"role": "user", "content": user_prompt})
|
| 418 |
-
with panel_log:
|
| 419 |
-
_render_chat_message({"role": "user", "content": user_prompt})
|
| 420 |
-
|
| 421 |
-
with st.spinner("Generating..."):
|
| 422 |
-
if remote:
|
| 423 |
-
with ThreadPoolExecutor(max_workers=2) as executor:
|
| 424 |
-
futures = [
|
| 425 |
-
executor.submit(
|
| 426 |
-
generate_chat_reply,
|
| 427 |
-
model=model,
|
| 428 |
-
messages=_build_chat_messages(
|
| 429 |
-
panel_prompt, panel_state["messages"]
|
| 430 |
-
),
|
| 431 |
-
remote=remote,
|
| 432 |
-
past_key_values=panel_state["past_key_values"],
|
| 433 |
-
**gen_kwargs,
|
| 434 |
-
)
|
| 435 |
-
for panel_state, _panel_log, panel_prompt, _p_pending in panels
|
| 436 |
-
]
|
| 437 |
-
results: list[ChatReply | Exception] = []
|
| 438 |
-
for future in futures:
|
| 439 |
-
try:
|
| 440 |
-
results.append(future.result())
|
| 441 |
-
except Exception as exc:
|
| 442 |
-
results.append(exc)
|
| 443 |
-
else:
|
| 444 |
-
results = []
|
| 445 |
-
for panel_state, _panel_log, panel_prompt, _p_pending in panels:
|
| 446 |
-
try:
|
| 447 |
-
results.append(
|
| 448 |
-
generate_chat_reply(
|
| 449 |
-
model=model,
|
| 450 |
-
messages=_build_chat_messages(
|
| 451 |
-
panel_prompt, panel_state["messages"]
|
| 452 |
-
),
|
| 453 |
-
remote=remote,
|
| 454 |
-
past_key_values=panel_state["past_key_values"],
|
| 455 |
-
**gen_kwargs,
|
| 456 |
-
)
|
| 457 |
-
)
|
| 458 |
-
except Exception as exc:
|
| 459 |
-
results.append(exc)
|
| 460 |
-
|
| 461 |
-
for (panel_state, panel_log, _panel_prompt, _p_pending), result in zip(
|
| 462 |
-
panels, results
|
| 463 |
-
):
|
| 464 |
-
if isinstance(result, Exception):
|
| 465 |
-
with panel_log:
|
| 466 |
-
st.error(f"Generation failed: {result}")
|
| 467 |
-
panel_state["messages"].pop()
|
| 468 |
-
continue
|
| 469 |
-
|
| 470 |
-
panel_state["messages"].append({"role": "assistant", "content": result.text})
|
| 471 |
-
panel_state["past_key_values"] = result.past_key_values if not remote else None
|
| 472 |
-
with panel_log:
|
| 473 |
-
_render_chat_message({"role": "assistant", "content": result.text})
|
| 474 |
-
|
| 475 |
-
# Rerun so the newly appended turns are redrawn through the editable history
|
| 476 |
-
# renderer instead of only appearing in the one-off generation pass.
|
| 477 |
-
st.rerun()
|
| 478 |
|
| 479 |
|
| 480 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
| 481 |
|
| 482 |
|
| 483 |
-
def _render_generation_settings(
|
| 484 |
-
context_key: str, remote: bool
|
| 485 |
-
) -> tuple[dict, bool]:
|
| 486 |
"""Render the Advanced generation settings expander.
|
| 487 |
|
| 488 |
Returns ``(gen_kwargs, advanced_generation)`` where ``advanced_generation``
|
|
@@ -633,7 +424,9 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 633 |
)
|
| 634 |
|
| 635 |
if compare_mode:
|
| 636 |
-
|
|
|
|
|
|
|
| 637 |
remote,
|
| 638 |
model_name,
|
| 639 |
context_key,
|
|
@@ -648,76 +441,70 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 648 |
persona_select_key = widget_key(context_key, "persona_select")
|
| 649 |
prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
|
| 650 |
prompt_key = widget_key(context_key, "custom_system_prompt")
|
| 651 |
-
show_all_key = widget_key(context_key, "show_all_messages")
|
| 652 |
chat_input_key = widget_key(context_key, "chat_input")
|
| 653 |
pending_key = widget_key(context_key, "pending_prompt")
|
| 654 |
export_key = widget_key(context_key, "export_chat")
|
| 655 |
reset_key = widget_key(context_key, "reset")
|
| 656 |
edit_key = widget_key(context_key, "edit_idx")
|
| 657 |
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
options=personas,
|
| 667 |
-
index=selected_index,
|
| 668 |
-
format_func=persona_label,
|
| 669 |
-
key=persona_select_key,
|
| 670 |
-
)
|
| 671 |
-
with col2:
|
| 672 |
-
current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
|
| 673 |
-
st.selectbox(
|
| 674 |
-
"Prompt",
|
| 675 |
-
options=MODE_LABELS,
|
| 676 |
-
index=MODE_LABELS.index(current_mode_label),
|
| 677 |
-
key=prompt_mode_select_key,
|
| 678 |
)
|
| 679 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 680 |
|
| 681 |
active_system_prompt = resolve_system_prompt(
|
| 682 |
persona=selected_persona,
|
| 683 |
mode=prompt_mode,
|
| 684 |
)
|
| 685 |
|
| 686 |
-
changed_context = (
|
| 687 |
-
chat_state["persona_id"] != selected_persona.id
|
| 688 |
-
or chat_state["prompt_mode"] != prompt_mode
|
| 689 |
-
)
|
| 690 |
if changed_context:
|
| 691 |
had_history = bool(chat_state["messages"])
|
| 692 |
-
|
| 693 |
-
model_name,
|
| 694 |
-
dataset_source,
|
| 695 |
-
chat_state,
|
| 696 |
-
selected_persona.id,
|
| 697 |
-
prompt_mode,
|
| 698 |
-
chat_input_key,
|
| 699 |
-
show_all_key,
|
| 700 |
-
prompt_key,
|
| 701 |
-
pending_key,
|
| 702 |
-
)
|
| 703 |
-
st.session_state.pop(edit_key, None)
|
| 704 |
if had_history:
|
| 705 |
st.info("Chat history reset because the persona or system prompt changed.")
|
| 706 |
|
| 707 |
chat_log = st.container()
|
| 708 |
|
| 709 |
with chat_log:
|
| 710 |
-
active_system_prompt =
|
| 711 |
prompt_key,
|
| 712 |
prompt_mode,
|
| 713 |
active_system_prompt,
|
| 714 |
-
height=200,
|
| 715 |
)
|
| 716 |
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 721 |
_save_chat_export_message(
|
| 722 |
model_name=model_name,
|
| 723 |
dataset_source=dataset_source,
|
|
@@ -728,38 +515,18 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 728 |
messages=chat_state["messages"],
|
| 729 |
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 730 |
)
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
show_all_key,
|
| 742 |
-
prompt_key,
|
| 743 |
-
pending_key,
|
| 744 |
-
)
|
| 745 |
-
st.session_state.pop(edit_key, None)
|
| 746 |
-
st.rerun()
|
| 747 |
-
|
| 748 |
-
_render_chat_window(
|
| 749 |
-
chat_log=chat_log,
|
| 750 |
-
messages=chat_state["messages"],
|
| 751 |
-
show_all_key=show_all_key,
|
| 752 |
-
show_all_btn_key=widget_key(context_key, "show_all_btn"),
|
| 753 |
-
show_earlier_label="Show earlier messages",
|
| 754 |
-
chat_state=chat_state,
|
| 755 |
-
edit_key=edit_key,
|
| 756 |
-
pending_key=pending_key,
|
| 757 |
-
)
|
| 758 |
|
| 759 |
-
user_prompt = st.chat_input(
|
| 760 |
-
"Ask something...",
|
| 761 |
-
key=chat_input_key,
|
| 762 |
-
)
|
| 763 |
|
| 764 |
# Pass 1: user submitted — append message and rerun so it renders before generation.
|
| 765 |
if user_prompt:
|
|
|
|
|
|
|
| 1 |
from typing import Any
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
from persona_data.synth_persona import PersonaData
|
| 5 |
|
| 6 |
+
from state import chat_session_key, get_chat_state, reset_chat_context_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
| 8 |
from utils.chat_export import save_chat_export
|
| 9 |
+
from utils.contrast import TokenContrast, render_contrast_html
|
| 10 |
from utils.datasets import load_dataset
|
| 11 |
from utils.helpers import (
|
| 12 |
MODE_LABEL_TO_KEY,
|
| 13 |
MODE_LABELS,
|
| 14 |
VARIANT_LABELS,
|
|
|
|
| 15 |
persona_label,
|
| 16 |
widget_key,
|
| 17 |
)
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def _render_collapsible_markdown(content: str) -> None:
|
| 22 |
+
st.markdown(content)
|
|
|
|
|
|
|
| 23 |
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
# ── Dialogs ───────────────────────────────────────────────────────────────────
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
@st.dialog("Edit", width="medium")
|
| 29 |
+
def _open_edit_dialog(
|
| 30 |
+
*,
|
| 31 |
+
msg_index: int,
|
| 32 |
+
messages: list[dict[str, str]],
|
| 33 |
+
chat_state: dict[str, object],
|
| 34 |
+
pending_key: str,
|
| 35 |
+
) -> None:
|
| 36 |
+
message = messages[msg_index]
|
| 37 |
+
role = message["role"]
|
| 38 |
+
|
| 39 |
+
n_after = len(messages) - msg_index - 1
|
| 40 |
+
st.caption(
|
| 41 |
+
f"**{role}**"
|
| 42 |
+
+ (
|
| 43 |
+
f" — {n_after} subsequent {'message' if n_after == 1 else 'messages'} will be cleared"
|
| 44 |
+
if n_after > 0
|
| 45 |
+
else ""
|
| 46 |
+
)
|
| 47 |
+
)
|
| 48 |
|
| 49 |
+
new_content = st.text_area(
|
| 50 |
+
"Content",
|
| 51 |
+
value=message["content"],
|
| 52 |
+
height=320,
|
| 53 |
+
label_visibility="collapsed",
|
| 54 |
+
)
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
save_col, cancel_col = st.columns(2)
|
| 57 |
+
with save_col:
|
| 58 |
+
if st.button("Save", type="primary", use_container_width=True):
|
| 59 |
+
messages[msg_index]["content"] = new_content
|
| 60 |
+
messages[msg_index].pop("_contrast", None)
|
| 61 |
+
if role == "assistant":
|
| 62 |
+
messages[msg_index]["_needs_contrast"] = True
|
| 63 |
+
del messages[msg_index + 1 :]
|
| 64 |
+
chat_state["past_key_values"] = None
|
| 65 |
+
if role == "user":
|
| 66 |
+
st.session_state[pending_key] = True
|
| 67 |
+
st.rerun()
|
| 68 |
+
with cancel_col:
|
| 69 |
+
if st.button("Cancel", use_container_width=True):
|
| 70 |
+
st.rerun()
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
@st.dialog("Edit system prompt", width="large")
|
| 74 |
+
def _open_system_prompt_dialog(*, prompt_key: str, current_value: str) -> None:
|
| 75 |
+
new_value = st.text_area(
|
| 76 |
+
"System prompt",
|
| 77 |
+
value=current_value,
|
| 78 |
+
height=320,
|
| 79 |
+
label_visibility="collapsed",
|
| 80 |
+
)
|
| 81 |
+
save_col, cancel_col = st.columns(2)
|
| 82 |
+
with save_col:
|
| 83 |
+
if st.button("Save", type="primary", use_container_width=True):
|
| 84 |
+
st.session_state[prompt_key] = new_value
|
| 85 |
+
st.rerun()
|
| 86 |
+
with cancel_col:
|
| 87 |
+
if st.button("Cancel", use_container_width=True):
|
| 88 |
+
st.rerun()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ── Message renderers ─────────────────────────────────────────────────────────
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _render_chat_message(
|
| 95 |
+
message: dict[str, str],
|
| 96 |
+
show_contrast: bool = False,
|
| 97 |
+
) -> None:
|
| 98 |
+
if not message.get("content"):
|
| 99 |
+
return
|
| 100 |
+
role = message["role"]
|
| 101 |
+
tc: TokenContrast | None = message.get("_contrast") if show_contrast else None
|
| 102 |
+
with st.chat_message(role):
|
| 103 |
+
if tc is not None:
|
| 104 |
+
st.html(render_contrast_html(tc))
|
| 105 |
+
else:
|
| 106 |
+
_render_collapsible_markdown(message["content"])
|
| 107 |
|
| 108 |
|
| 109 |
def _render_editable_message(
|
|
|
|
| 113 |
chat_state: dict[str, object],
|
| 114 |
edit_key: str,
|
| 115 |
pending_key: str,
|
| 116 |
+
show_contrast: bool = False,
|
| 117 |
+
column_ratio: tuple[int, int] = (25, 1),
|
| 118 |
) -> None:
|
|
|
|
| 119 |
if not message.get("content"):
|
| 120 |
return
|
| 121 |
+
role = message["role"]
|
| 122 |
+
tc: TokenContrast | None = message.get("_contrast") if show_contrast else None
|
| 123 |
|
| 124 |
+
msg_col, edit_col = st.columns(
|
| 125 |
+
list(column_ratio), gap="xsmall", vertical_alignment="center"
|
| 126 |
+
)
|
| 127 |
+
with msg_col:
|
| 128 |
+
with st.chat_message(role):
|
| 129 |
+
if tc is not None:
|
| 130 |
+
st.html(render_contrast_html(tc))
|
| 131 |
+
else:
|
| 132 |
+
_render_collapsible_markdown(message["content"])
|
| 133 |
+
with edit_col:
|
| 134 |
+
if st.button(
|
| 135 |
+
"", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
|
| 136 |
+
):
|
| 137 |
+
_open_edit_dialog(
|
| 138 |
+
msg_index=msg_index,
|
| 139 |
+
messages=messages,
|
| 140 |
+
chat_state=chat_state,
|
| 141 |
+
pending_key=pending_key,
|
| 142 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
+
def _render_system_prompt(
|
| 146 |
+
prompt_key: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
prompt_mode: str,
|
| 148 |
+
active_system_prompt: str | None,
|
| 149 |
+
) -> str | None:
|
| 150 |
+
if prompt_key not in st.session_state:
|
| 151 |
+
st.session_state[prompt_key] = active_system_prompt or ""
|
| 152 |
+
current = st.session_state.get(prompt_key) or ""
|
| 153 |
+
with st.expander("System prompt"):
|
| 154 |
+
st.markdown(current or "*empty*")
|
| 155 |
+
if prompt_mode != "empty" and st.button(
|
| 156 |
+
"Edit", icon=":material/edit:", key=f"{prompt_key}_edit"
|
| 157 |
+
):
|
| 158 |
+
_open_system_prompt_dialog(prompt_key=prompt_key, current_value=current)
|
| 159 |
+
return st.session_state.get(prompt_key) or None
|
| 160 |
|
| 161 |
|
| 162 |
def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
|
|
|
|
| 213 |
*,
|
| 214 |
chat_log: Any,
|
| 215 |
messages: list[dict[str, str]],
|
|
|
|
|
|
|
|
|
|
| 216 |
chat_state: dict[str, object] | None = None,
|
| 217 |
edit_key: str | None = None,
|
| 218 |
pending_key: str | None = None,
|
| 219 |
+
show_contrast: bool = False,
|
| 220 |
+
edit_column_ratio: tuple[int, int] = (25, 1),
|
| 221 |
+
) -> None:
|
| 222 |
with chat_log:
|
| 223 |
+
for i, message in enumerate(messages):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
if edit_key and pending_key:
|
| 225 |
_render_editable_message(
|
| 226 |
+
message,
|
| 227 |
+
i,
|
| 228 |
+
messages,
|
| 229 |
+
chat_state,
|
| 230 |
+
edit_key,
|
| 231 |
+
pending_key,
|
| 232 |
+
show_contrast=show_contrast,
|
| 233 |
+
column_ratio=edit_column_ratio,
|
| 234 |
)
|
| 235 |
else:
|
| 236 |
+
_render_chat_message(message, show_contrast=show_contrast)
|
|
|
|
|
|
|
| 237 |
|
| 238 |
|
| 239 |
def _build_chat_messages(
|
|
|
|
| 256 |
messages: list[dict[str, str]],
|
| 257 |
generation: dict[str, object],
|
| 258 |
panel_label: str | None = None,
|
| 259 |
+
) -> None:
|
| 260 |
+
save_chat_export(
|
| 261 |
model_name=model_name,
|
| 262 |
dataset_source=dataset_source,
|
| 263 |
persona_id=persona_id,
|
|
|
|
| 268 |
messages=messages,
|
| 269 |
generation=generation,
|
| 270 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
|
| 273 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
| 274 |
|
| 275 |
|
| 276 |
+
def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, bool]:
|
|
|
|
|
|
|
| 277 |
"""Render the Advanced generation settings expander.
|
| 278 |
|
| 279 |
Returns ``(gen_kwargs, advanced_generation)`` where ``advanced_generation``
|
|
|
|
| 424 |
)
|
| 425 |
|
| 426 |
if compare_mode:
|
| 427 |
+
from tabs.compare_chat import render_compare_mode
|
| 428 |
+
|
| 429 |
+
render_compare_mode(
|
| 430 |
remote,
|
| 431 |
model_name,
|
| 432 |
context_key,
|
|
|
|
| 441 |
persona_select_key = widget_key(context_key, "persona_select")
|
| 442 |
prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
|
| 443 |
prompt_key = widget_key(context_key, "custom_system_prompt")
|
|
|
|
| 444 |
chat_input_key = widget_key(context_key, "chat_input")
|
| 445 |
pending_key = widget_key(context_key, "pending_prompt")
|
| 446 |
export_key = widget_key(context_key, "export_chat")
|
| 447 |
reset_key = widget_key(context_key, "reset")
|
| 448 |
edit_key = widget_key(context_key, "edit_idx")
|
| 449 |
|
| 450 |
+
def _reset_active_chat_context() -> None:
|
| 451 |
+
reset_chat_context_state(
|
| 452 |
+
chat_state,
|
| 453 |
+
selected_persona.id,
|
| 454 |
+
prompt_mode,
|
| 455 |
+
chat_input_key,
|
| 456 |
+
prompt_key,
|
| 457 |
+
pending_key,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
)
|
| 459 |
+
st.session_state.pop(edit_key, None)
|
| 460 |
+
|
| 461 |
+
selected_persona, prompt_mode, changed_context = _render_persona_prompt_controls(
|
| 462 |
+
personas,
|
| 463 |
+
chat_state["persona_id"],
|
| 464 |
+
chat_state["prompt_mode"],
|
| 465 |
+
persona_select_key,
|
| 466 |
+
prompt_mode_select_key,
|
| 467 |
+
column_widths=(2, 1),
|
| 468 |
+
)
|
| 469 |
|
| 470 |
active_system_prompt = resolve_system_prompt(
|
| 471 |
persona=selected_persona,
|
| 472 |
mode=prompt_mode,
|
| 473 |
)
|
| 474 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
if changed_context:
|
| 476 |
had_history = bool(chat_state["messages"])
|
| 477 |
+
_reset_active_chat_context()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
if had_history:
|
| 479 |
st.info("Chat history reset because the persona or system prompt changed.")
|
| 480 |
|
| 481 |
chat_log = st.container()
|
| 482 |
|
| 483 |
with chat_log:
|
| 484 |
+
active_system_prompt = _render_system_prompt(
|
| 485 |
prompt_key,
|
| 486 |
prompt_mode,
|
| 487 |
active_system_prompt,
|
|
|
|
| 488 |
)
|
| 489 |
|
| 490 |
+
_render_chat_window(
|
| 491 |
+
chat_log=chat_log,
|
| 492 |
+
messages=chat_state["messages"],
|
| 493 |
+
chat_state=chat_state,
|
| 494 |
+
edit_key=edit_key,
|
| 495 |
+
pending_key=pending_key,
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
footer = st.container()
|
| 499 |
+
with footer:
|
| 500 |
+
exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
|
| 501 |
+
with exp_col:
|
| 502 |
+
if st.button(
|
| 503 |
+
"",
|
| 504 |
+
icon=":material/download:",
|
| 505 |
+
key=export_key,
|
| 506 |
+
help="Export chat",
|
| 507 |
+
):
|
| 508 |
_save_chat_export_message(
|
| 509 |
model_name=model_name,
|
| 510 |
dataset_source=dataset_source,
|
|
|
|
| 515 |
messages=chat_state["messages"],
|
| 516 |
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 517 |
)
|
| 518 |
+
st.toast("Exported", icon=":material/check:")
|
| 519 |
+
with rst_col:
|
| 520 |
+
if st.button(
|
| 521 |
+
"",
|
| 522 |
+
icon=":material/delete_sweep:",
|
| 523 |
+
key=reset_key,
|
| 524 |
+
help="Reset chat",
|
| 525 |
+
):
|
| 526 |
+
_reset_active_chat_context()
|
| 527 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
|
| 529 |
+
user_prompt = st.chat_input("Ask something...", key=chat_input_key)
|
|
|
|
|
|
|
|
|
|
| 530 |
|
| 531 |
# Pass 1: user submitted — append message and rerun so it renders before generation.
|
| 532 |
if user_prompt:
|
tabs/compare_chat.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from nnterp import StandardizedTransformer
|
| 3 |
+
from persona_data.synth_persona import PersonaData
|
| 4 |
+
|
| 5 |
+
from state import default_chat_state, reset_chat_context_state
|
| 6 |
+
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
| 7 |
+
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 8 |
+
from utils.helpers import persona_label, widget_key
|
| 9 |
+
from utils.runtime import cached_model
|
| 10 |
+
|
| 11 |
+
from .chat import (
|
| 12 |
+
_build_chat_messages,
|
| 13 |
+
_generation_dict,
|
| 14 |
+
_render_chat_message,
|
| 15 |
+
_render_chat_window,
|
| 16 |
+
_render_persona_prompt_controls,
|
| 17 |
+
_render_system_prompt,
|
| 18 |
+
_save_chat_export_message,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _panel_state(panel_key: str) -> dict[str, object]:
|
| 23 |
+
"""Get or initialise compare-panel chat state stored in session_state."""
|
| 24 |
+
if panel_key not in st.session_state:
|
| 25 |
+
st.session_state[panel_key] = default_chat_state()
|
| 26 |
+
return st.session_state[panel_key]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _reset_compare_panel(
|
| 30 |
+
panel_state: dict,
|
| 31 |
+
edit_key: str,
|
| 32 |
+
persona_id: str,
|
| 33 |
+
prompt_mode: str,
|
| 34 |
+
*ui_keys: str,
|
| 35 |
+
) -> None:
|
| 36 |
+
reset_chat_context_state(panel_state, persona_id, prompt_mode, *ui_keys)
|
| 37 |
+
st.session_state.pop(edit_key, None)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _generate_panel_reply(
|
| 41 |
+
*,
|
| 42 |
+
model: StandardizedTransformer,
|
| 43 |
+
remote: bool,
|
| 44 |
+
panel_state: dict[str, object],
|
| 45 |
+
panel_prompt: str | None,
|
| 46 |
+
gen_kwargs: dict,
|
| 47 |
+
) -> ChatReply:
|
| 48 |
+
return generate_chat_reply(
|
| 49 |
+
model=model,
|
| 50 |
+
messages=_build_chat_messages(panel_prompt, panel_state["messages"]),
|
| 51 |
+
remote=remote,
|
| 52 |
+
past_key_values=panel_state["past_key_values"],
|
| 53 |
+
**gen_kwargs,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def render_compare_mode(
|
| 58 |
+
remote: bool,
|
| 59 |
+
model_name: str,
|
| 60 |
+
context_key: str,
|
| 61 |
+
dataset_source: str,
|
| 62 |
+
personas: list[PersonaData],
|
| 63 |
+
gen_kwargs: dict,
|
| 64 |
+
advanced_generation: bool,
|
| 65 |
+
) -> None:
|
| 66 |
+
"""Render the full side-by-side comparison UI."""
|
| 67 |
+
contrast_key = widget_key(context_key, "token_contrast")
|
| 68 |
+
contrast_enabled = st.toggle(
|
| 69 |
+
"Token contrast",
|
| 70 |
+
value=False,
|
| 71 |
+
key=contrast_key,
|
| 72 |
+
help=(
|
| 73 |
+
"Color each generated token by how characteristic it is of each persona. "
|
| 74 |
+
"Red = more likely under the left persona, blue = more likely under the right. "
|
| 75 |
+
"Requires four extra forward passes after each turn (batched into one "
|
| 76 |
+
"remote session when running on NDIF)."
|
| 77 |
+
),
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
left_col, right_col = st.columns(2)
|
| 81 |
+
left_panel_key = widget_key(context_key, "cmp_left")
|
| 82 |
+
right_panel_key = widget_key(context_key, "cmp_right")
|
| 83 |
+
left_prompt_key = widget_key(left_panel_key, "custom_prompt")
|
| 84 |
+
right_prompt_key = widget_key(right_panel_key, "custom_prompt")
|
| 85 |
+
left_edit_key = widget_key(left_panel_key, "edit_idx")
|
| 86 |
+
right_edit_key = widget_key(right_panel_key, "edit_idx")
|
| 87 |
+
left_pending_key = widget_key(left_panel_key, "pending_regen")
|
| 88 |
+
right_pending_key = widget_key(right_panel_key, "pending_regen")
|
| 89 |
+
|
| 90 |
+
def render_panel(side: str) -> tuple[dict, object, str | None, str, PersonaData]:
|
| 91 |
+
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 92 |
+
state = _panel_state(panel_key)
|
| 93 |
+
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 94 |
+
edit_key = widget_key(panel_key, "edit_idx")
|
| 95 |
+
pending_regen_key = widget_key(panel_key, "pending_regen")
|
| 96 |
+
|
| 97 |
+
selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
|
| 98 |
+
personas,
|
| 99 |
+
state["persona_id"],
|
| 100 |
+
state["prompt_mode"],
|
| 101 |
+
widget_key(panel_key, "persona"),
|
| 102 |
+
widget_key(panel_key, "prompt_mode"),
|
| 103 |
+
)
|
| 104 |
+
if changed:
|
| 105 |
+
reset_chat_context_state(
|
| 106 |
+
state,
|
| 107 |
+
selected_persona.id,
|
| 108 |
+
prompt_mode,
|
| 109 |
+
prompt_key,
|
| 110 |
+
pending_regen_key,
|
| 111 |
+
)
|
| 112 |
+
st.session_state.pop(edit_key, None)
|
| 113 |
+
|
| 114 |
+
active_system_prompt = resolve_system_prompt(
|
| 115 |
+
persona=selected_persona, mode=prompt_mode
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
chat_log = st.container()
|
| 119 |
+
with chat_log:
|
| 120 |
+
active_system_prompt = _render_system_prompt(
|
| 121 |
+
prompt_key,
|
| 122 |
+
prompt_mode,
|
| 123 |
+
active_system_prompt,
|
| 124 |
+
)
|
| 125 |
+
return (
|
| 126 |
+
state,
|
| 127 |
+
chat_log,
|
| 128 |
+
active_system_prompt,
|
| 129 |
+
pending_regen_key,
|
| 130 |
+
selected_persona,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
with left_col:
|
| 134 |
+
left_state, left_log, left_prompt, left_pending, left_persona = render_panel(
|
| 135 |
+
"left"
|
| 136 |
+
)
|
| 137 |
+
with right_col:
|
| 138 |
+
right_state, right_log, right_prompt, right_pending, right_persona = (
|
| 139 |
+
render_panel("right")
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
panels = [
|
| 143 |
+
(
|
| 144 |
+
left_state,
|
| 145 |
+
left_log,
|
| 146 |
+
left_prompt,
|
| 147 |
+
left_pending,
|
| 148 |
+
left_edit_key,
|
| 149 |
+
left_persona,
|
| 150 |
+
),
|
| 151 |
+
(
|
| 152 |
+
right_state,
|
| 153 |
+
right_log,
|
| 154 |
+
right_prompt,
|
| 155 |
+
right_pending,
|
| 156 |
+
right_edit_key,
|
| 157 |
+
right_persona,
|
| 158 |
+
),
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
# Handle per-panel regeneration triggered by message edits
|
| 162 |
+
regen_panels = [
|
| 163 |
+
(panel_state, panel_log, panel_prompt)
|
| 164 |
+
for panel_state, panel_log, panel_prompt, p_pending, _panel_edit_key, _ in panels
|
| 165 |
+
if st.session_state.pop(p_pending, False)
|
| 166 |
+
]
|
| 167 |
+
if regen_panels:
|
| 168 |
+
model = cached_model(model_name=model_name, remote=remote)
|
| 169 |
+
|
| 170 |
+
results: list[ChatReply | Exception] = []
|
| 171 |
+
with st.spinner("Regenerating..."):
|
| 172 |
+
for panel_state, _panel_log, panel_prompt in regen_panels:
|
| 173 |
+
try:
|
| 174 |
+
results.append(
|
| 175 |
+
_generate_panel_reply(
|
| 176 |
+
model=model,
|
| 177 |
+
remote=remote,
|
| 178 |
+
panel_state=panel_state,
|
| 179 |
+
panel_prompt=panel_prompt,
|
| 180 |
+
gen_kwargs=gen_kwargs,
|
| 181 |
+
)
|
| 182 |
+
)
|
| 183 |
+
except Exception as exc:
|
| 184 |
+
results.append(exc)
|
| 185 |
+
|
| 186 |
+
for (panel_state, panel_log, _panel_prompt), result in zip(
|
| 187 |
+
regen_panels, results
|
| 188 |
+
):
|
| 189 |
+
if isinstance(result, Exception):
|
| 190 |
+
with panel_log:
|
| 191 |
+
st.error(f"Generation failed: {result}")
|
| 192 |
+
panel_state["messages"].pop()
|
| 193 |
+
continue
|
| 194 |
+
panel_state["messages"].append(
|
| 195 |
+
{"role": "assistant", "content": result.text}
|
| 196 |
+
)
|
| 197 |
+
panel_state["past_key_values"] = (
|
| 198 |
+
result.past_key_values if not remote else None
|
| 199 |
+
)
|
| 200 |
+
st.rerun()
|
| 201 |
+
|
| 202 |
+
# Recompute contrast for assistant messages that were edited in place.
|
| 203 |
+
if contrast_enabled:
|
| 204 |
+
pending_edits: list[tuple[int, int]] = [
|
| 205 |
+
(panel_idx, msg_idx)
|
| 206 |
+
for panel_idx, (panel_state, *_rest) in enumerate(panels)
|
| 207 |
+
for msg_idx, msg in enumerate(panel_state["messages"])
|
| 208 |
+
if msg.get("_needs_contrast") and msg.get("role") == "assistant"
|
| 209 |
+
]
|
| 210 |
+
if pending_edits:
|
| 211 |
+
model = cached_model(model_name=model_name, remote=remote)
|
| 212 |
+
label_a = persona_label(left_persona)
|
| 213 |
+
label_b = persona_label(right_persona)
|
| 214 |
+
with st.spinner("Recomputing token contrast…"):
|
| 215 |
+
for panel_idx, msg_idx in pending_edits:
|
| 216 |
+
panel_state = panels[panel_idx][0]
|
| 217 |
+
msg = panel_state["messages"][msg_idx]
|
| 218 |
+
if msg_idx >= len(left_state["messages"]) or msg_idx >= len(
|
| 219 |
+
right_state["messages"]
|
| 220 |
+
):
|
| 221 |
+
msg.pop("_needs_contrast", None)
|
| 222 |
+
continue
|
| 223 |
+
context_a = _build_chat_messages(
|
| 224 |
+
left_prompt, left_state["messages"][:msg_idx]
|
| 225 |
+
)
|
| 226 |
+
context_b = _build_chat_messages(
|
| 227 |
+
right_prompt, right_state["messages"][:msg_idx]
|
| 228 |
+
)
|
| 229 |
+
try:
|
| 230 |
+
response_ids = model.tokenizer(
|
| 231 |
+
msg["content"],
|
| 232 |
+
add_special_tokens=False,
|
| 233 |
+
return_tensors="pt",
|
| 234 |
+
).input_ids[0]
|
| 235 |
+
tc = compute_contrast(
|
| 236 |
+
model=model,
|
| 237 |
+
context_a=context_a,
|
| 238 |
+
context_b=context_b,
|
| 239 |
+
response_ids=response_ids,
|
| 240 |
+
label_a=label_a,
|
| 241 |
+
label_b=label_b,
|
| 242 |
+
remote=remote,
|
| 243 |
+
)
|
| 244 |
+
if tc is not None:
|
| 245 |
+
msg["_contrast"] = tc
|
| 246 |
+
except Exception as exc:
|
| 247 |
+
st.warning(f"Token contrast recompute failed: {exc}")
|
| 248 |
+
msg.pop("_needs_contrast", None)
|
| 249 |
+
st.rerun()
|
| 250 |
+
|
| 251 |
+
for (
|
| 252 |
+
panel_state,
|
| 253 |
+
panel_log,
|
| 254 |
+
_panel_prompt,
|
| 255 |
+
panel_pending,
|
| 256 |
+
panel_edit_key,
|
| 257 |
+
_,
|
| 258 |
+
) in panels:
|
| 259 |
+
_render_chat_window(
|
| 260 |
+
chat_log=panel_log,
|
| 261 |
+
messages=panel_state["messages"],
|
| 262 |
+
chat_state=panel_state,
|
| 263 |
+
edit_key=panel_edit_key,
|
| 264 |
+
pending_key=panel_pending,
|
| 265 |
+
show_contrast=contrast_enabled,
|
| 266 |
+
edit_column_ratio=(10, 1),
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
footer = st.container()
|
| 270 |
+
with footer:
|
| 271 |
+
exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
|
| 272 |
+
with exp_col:
|
| 273 |
+
if st.button(
|
| 274 |
+
"",
|
| 275 |
+
icon=":material/download:",
|
| 276 |
+
key=widget_key(context_key, "cmp_export"),
|
| 277 |
+
help="Export both chats",
|
| 278 |
+
):
|
| 279 |
+
for side, panel_state, panel_prompt, panel_persona in (
|
| 280 |
+
("left", left_state, left_prompt, left_persona),
|
| 281 |
+
("right", right_state, right_prompt, right_persona),
|
| 282 |
+
):
|
| 283 |
+
_save_chat_export_message(
|
| 284 |
+
model_name=model_name,
|
| 285 |
+
dataset_source=dataset_source,
|
| 286 |
+
persona_id=panel_persona.id,
|
| 287 |
+
persona_name=getattr(panel_persona, "name", None),
|
| 288 |
+
prompt_mode=panel_state["prompt_mode"],
|
| 289 |
+
system_prompt=panel_prompt,
|
| 290 |
+
messages=panel_state["messages"],
|
| 291 |
+
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 292 |
+
panel_label=side,
|
| 293 |
+
)
|
| 294 |
+
st.toast("Exported", icon=":material/check:")
|
| 295 |
+
with rst_col:
|
| 296 |
+
with st.popover(
|
| 297 |
+
"",
|
| 298 |
+
icon=":material/delete_sweep:",
|
| 299 |
+
help="Reset chat",
|
| 300 |
+
):
|
| 301 |
+
if st.button(
|
| 302 |
+
"Reset left",
|
| 303 |
+
key=widget_key(context_key, "cmp_reset_left"),
|
| 304 |
+
):
|
| 305 |
+
_reset_compare_panel(
|
| 306 |
+
left_state,
|
| 307 |
+
left_edit_key,
|
| 308 |
+
left_persona.id,
|
| 309 |
+
left_state["prompt_mode"],
|
| 310 |
+
left_prompt_key,
|
| 311 |
+
left_pending_key,
|
| 312 |
+
)
|
| 313 |
+
st.rerun()
|
| 314 |
+
if st.button(
|
| 315 |
+
"Reset right",
|
| 316 |
+
key=widget_key(context_key, "cmp_reset_right"),
|
| 317 |
+
):
|
| 318 |
+
_reset_compare_panel(
|
| 319 |
+
right_state,
|
| 320 |
+
right_edit_key,
|
| 321 |
+
right_persona.id,
|
| 322 |
+
right_state["prompt_mode"],
|
| 323 |
+
right_prompt_key,
|
| 324 |
+
right_pending_key,
|
| 325 |
+
)
|
| 326 |
+
st.rerun()
|
| 327 |
+
if st.button(
|
| 328 |
+
"Reset both",
|
| 329 |
+
key=widget_key(context_key, "cmp_reset_both"),
|
| 330 |
+
type="primary",
|
| 331 |
+
):
|
| 332 |
+
_reset_compare_panel(
|
| 333 |
+
left_state,
|
| 334 |
+
left_edit_key,
|
| 335 |
+
left_persona.id,
|
| 336 |
+
left_state["prompt_mode"],
|
| 337 |
+
left_prompt_key,
|
| 338 |
+
left_pending_key,
|
| 339 |
+
)
|
| 340 |
+
_reset_compare_panel(
|
| 341 |
+
right_state,
|
| 342 |
+
right_edit_key,
|
| 343 |
+
right_persona.id,
|
| 344 |
+
right_state["prompt_mode"],
|
| 345 |
+
right_prompt_key,
|
| 346 |
+
right_pending_key,
|
| 347 |
+
)
|
| 348 |
+
st.rerun()
|
| 349 |
+
|
| 350 |
+
user_prompt = st.chat_input(
|
| 351 |
+
"Ask both...",
|
| 352 |
+
key=widget_key(context_key, "cmp_input"),
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
if not user_prompt:
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
model = cached_model(model_name=model_name, remote=remote)
|
| 359 |
+
|
| 360 |
+
for panel_state, panel_log, _panel_prompt, _p_pending, _panel_edit_key, _ in panels:
|
| 361 |
+
panel_state["messages"].append({"role": "user", "content": user_prompt})
|
| 362 |
+
with panel_log:
|
| 363 |
+
_render_chat_message({"role": "user", "content": user_prompt})
|
| 364 |
+
|
| 365 |
+
# Snapshot contexts before the new assistant turn is appended (needed for contrast).
|
| 366 |
+
pre_gen_contexts = [
|
| 367 |
+
_build_chat_messages(panel_prompt, panel_state["messages"])
|
| 368 |
+
for panel_state, _panel_log, panel_prompt, _p_pending, _panel_edit_key, _ in panels
|
| 369 |
+
]
|
| 370 |
+
|
| 371 |
+
results: list[ChatReply | Exception] = []
|
| 372 |
+
with st.spinner("Generating..."):
|
| 373 |
+
# Keep compare-mode generation sequential so both panels use the same
|
| 374 |
+
# model/session state safely.
|
| 375 |
+
for (
|
| 376 |
+
panel_state,
|
| 377 |
+
_panel_log,
|
| 378 |
+
panel_prompt,
|
| 379 |
+
_p_pending,
|
| 380 |
+
_panel_edit_key,
|
| 381 |
+
_,
|
| 382 |
+
) in panels:
|
| 383 |
+
try:
|
| 384 |
+
results.append(
|
| 385 |
+
_generate_panel_reply(
|
| 386 |
+
model=model,
|
| 387 |
+
remote=remote,
|
| 388 |
+
panel_state=panel_state,
|
| 389 |
+
panel_prompt=panel_prompt,
|
| 390 |
+
gen_kwargs=gen_kwargs,
|
| 391 |
+
)
|
| 392 |
+
)
|
| 393 |
+
except Exception as exc:
|
| 394 |
+
results.append(exc)
|
| 395 |
+
|
| 396 |
+
valid_results: list[ChatReply | None] = []
|
| 397 |
+
for (
|
| 398 |
+
panel_state,
|
| 399 |
+
panel_log,
|
| 400 |
+
_panel_prompt,
|
| 401 |
+
_p_pending,
|
| 402 |
+
_panel_edit_key,
|
| 403 |
+
_,
|
| 404 |
+
), result in zip(panels, results):
|
| 405 |
+
if isinstance(result, Exception):
|
| 406 |
+
with panel_log:
|
| 407 |
+
st.error(f"Generation failed: {result}")
|
| 408 |
+
panel_state["messages"].pop()
|
| 409 |
+
valid_results.append(None)
|
| 410 |
+
continue
|
| 411 |
+
|
| 412 |
+
panel_state["messages"].append({"role": "assistant", "content": result.text})
|
| 413 |
+
panel_state["past_key_values"] = result.past_key_values if not remote else None
|
| 414 |
+
valid_results.append(result)
|
| 415 |
+
|
| 416 |
+
# Compute contrastive token coloring when both panels succeeded.
|
| 417 |
+
if (
|
| 418 |
+
contrast_enabled
|
| 419 |
+
and len(valid_results) == 2
|
| 420 |
+
and all(r is not None and r.generated_ids is not None for r in valid_results)
|
| 421 |
+
):
|
| 422 |
+
with st.spinner("Computing token contrast…"):
|
| 423 |
+
try:
|
| 424 |
+
tc_a, tc_b = compute_contrast_pair(
|
| 425 |
+
model=model,
|
| 426 |
+
context_a=pre_gen_contexts[0],
|
| 427 |
+
context_b=pre_gen_contexts[1],
|
| 428 |
+
response_ids_a=valid_results[0].generated_ids,
|
| 429 |
+
response_ids_b=valid_results[1].generated_ids,
|
| 430 |
+
label_a=persona_label(left_persona),
|
| 431 |
+
label_b=persona_label(right_persona),
|
| 432 |
+
remote=remote,
|
| 433 |
+
)
|
| 434 |
+
if tc_a is not None:
|
| 435 |
+
left_state["messages"][-1]["_contrast"] = tc_a
|
| 436 |
+
if tc_b is not None:
|
| 437 |
+
right_state["messages"][-1]["_contrast"] = tc_b
|
| 438 |
+
except Exception as exc:
|
| 439 |
+
st.warning(f"Token contrast failed: {exc}")
|
| 440 |
+
|
| 441 |
+
# Rerun so the newly appended turns are redrawn through the editable history
|
| 442 |
+
# renderer instead of only appearing in the one-off generation pass.
|
| 443 |
+
st.rerun()
|
tabs/extract.py
CHANGED
|
@@ -111,8 +111,6 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 111 |
st.info("Select at least one persona.")
|
| 112 |
return
|
| 113 |
|
| 114 |
-
max_questions = 0
|
| 115 |
-
|
| 116 |
with st.expander("Advanced", expanded=False):
|
| 117 |
st.caption("Filters")
|
| 118 |
|
|
|
|
| 111 |
st.info("Select at least one persona.")
|
| 112 |
return
|
| 113 |
|
|
|
|
|
|
|
| 114 |
with st.expander("Advanced", expanded=False):
|
| 115 |
st.caption("Filters")
|
| 116 |
|
utils/chat.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import logging
|
| 2 |
-
from contextlib import contextmanager
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import Literal
|
| 5 |
|
|
@@ -15,9 +15,8 @@ SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
|
|
| 15 |
@dataclass
|
| 16 |
class ChatReply:
|
| 17 |
text: str
|
| 18 |
-
prompt_tokens: int
|
| 19 |
-
output_tokens: int
|
| 20 |
past_key_values: object | None
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def resolve_system_prompt(
|
|
@@ -204,13 +203,10 @@ def generate_chat_reply(
|
|
| 204 |
|
| 205 |
generated_ids = sequences[0, prompt_token_count:]
|
| 206 |
text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 207 |
-
output_tokens = int(sequences.shape[1] - prompt_token_count)
|
| 208 |
-
|
| 209 |
return ChatReply(
|
| 210 |
text=text,
|
| 211 |
-
prompt_tokens=prompt_token_count,
|
| 212 |
-
output_tokens=max(0, output_tokens),
|
| 213 |
past_key_values=(
|
| 214 |
getattr(generated, "past_key_values", None) if not remote else None
|
| 215 |
),
|
|
|
|
| 216 |
)
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from contextlib import contextmanager, nullcontext
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from typing import Literal
|
| 5 |
|
|
|
|
| 15 |
@dataclass
|
| 16 |
class ChatReply:
|
| 17 |
text: str
|
|
|
|
|
|
|
| 18 |
past_key_values: object | None
|
| 19 |
+
generated_ids: torch.Tensor | None = None
|
| 20 |
|
| 21 |
|
| 22 |
def resolve_system_prompt(
|
|
|
|
| 203 |
|
| 204 |
generated_ids = sequences[0, prompt_token_count:]
|
| 205 |
text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
| 206 |
return ChatReply(
|
| 207 |
text=text,
|
|
|
|
|
|
|
| 208 |
past_key_values=(
|
| 209 |
getattr(generated, "past_key_values", None) if not remote else None
|
| 210 |
),
|
| 211 |
+
generated_ids=generated_ids.detach().cpu(),
|
| 212 |
)
|
utils/chat_export.py
CHANGED
|
@@ -30,6 +30,7 @@ def save_chat_export(
|
|
| 30 |
system_prompt: Current system prompt text, if any.
|
| 31 |
messages: Conversation messages without the system prompt.
|
| 32 |
generation: Generation settings used for the chat.
|
|
|
|
| 33 |
|
| 34 |
Returns:
|
| 35 |
The path the export was written to.
|
|
@@ -55,7 +56,6 @@ def save_chat_export(
|
|
| 55 |
get_artifacts_dir()
|
| 56 |
/ "chats"
|
| 57 |
/ "__".join(slugify(part) for part in model_name.split("/"))
|
| 58 |
-
/ slugify(dataset_source)
|
| 59 |
/ slugify(persona_id)
|
| 60 |
)
|
| 61 |
export_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 30 |
system_prompt: Current system prompt text, if any.
|
| 31 |
messages: Conversation messages without the system prompt.
|
| 32 |
generation: Generation settings used for the chat.
|
| 33 |
+
panel_label: Optional side label (e.g. "left"/"right") for compare-mode exports.
|
| 34 |
|
| 35 |
Returns:
|
| 36 |
The path the export was written to.
|
|
|
|
| 56 |
get_artifacts_dir()
|
| 57 |
/ "chats"
|
| 58 |
/ "__".join(slugify(part) for part in model_name.split("/"))
|
|
|
|
| 59 |
/ slugify(persona_id)
|
| 60 |
)
|
| 61 |
export_dir.mkdir(parents=True, exist_ok=True)
|
utils/contrast.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# WARNING: This is mostly vibecoded and need reviews
|
| 2 |
+
# - Check that the model is runned once with normally for gneration and things are beeing traced perphaps at the last step of generation with iter.last or somrething liek that from the docs
|
| 3 |
+
# - Then the model is runned again with the entire context of the conversation from the other context on the rifht ? or on the left dependeing on which one we are doing at the moment. And this will then compute the prob diff and show them.
|
| 4 |
+
|
| 5 |
+
"""
|
| 6 |
+
Contrastive token-level log-probability comparison for compare mode.
|
| 7 |
+
|
| 8 |
+
For a pair of responses generated under different persona contexts, each token
|
| 9 |
+
gets a weight:
|
| 10 |
+
|
| 11 |
+
w(token) = log P(token | context_A) − log P(token | context_B)
|
| 12 |
+
|
| 13 |
+
Positive (red) → token is more characteristic of persona A.
|
| 14 |
+
Negative (blue) → token is more characteristic of persona B.
|
| 15 |
+
Near-zero (gray) → both personas would emit this token with similar likelihood.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from html import escape
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from nnterp import StandardizedTransformer
|
| 23 |
+
|
| 24 |
+
from utils.chat import _format_generation_prompt
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class TokenContrast:
|
| 29 |
+
tokens: list[str]
|
| 30 |
+
weights: list[float] # normalised to [-1, 1], used for coloring
|
| 31 |
+
raw_diffs: list[float] # unclipped log P(A) - log P(B) per token
|
| 32 |
+
label_a: str
|
| 33 |
+
label_b: str
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ── Weight computation ────────────────────────────────────────────────────────
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
|
| 40 |
+
"""
|
| 41 |
+
Clip at the 95th percentile of |diff| and scale to [-1, 1] so a few
|
| 42 |
+
high-magnitude tokens don't wash out everything else.
|
| 43 |
+
"""
|
| 44 |
+
if len(diffs) < 2:
|
| 45 |
+
return diffs.tolist()
|
| 46 |
+
clip_val = max(torch.quantile(diffs.abs(), 0.95).item(), 0.3)
|
| 47 |
+
return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _decode_ids(tokenizer: object, ids: list[int]) -> str:
|
| 51 |
+
try:
|
| 52 |
+
return tokenizer.decode(
|
| 53 |
+
ids,
|
| 54 |
+
skip_special_tokens=False,
|
| 55 |
+
clean_up_tokenization_spaces=False,
|
| 56 |
+
)
|
| 57 |
+
except TypeError:
|
| 58 |
+
return tokenizer.decode(ids, skip_special_tokens=False)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _strip_special_ids(
|
| 62 |
+
ids: torch.Tensor,
|
| 63 |
+
tokenizer: object,
|
| 64 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 65 |
+
"""Return display ids and a mask that excludes special tokens."""
|
| 66 |
+
ids = ids.cpu()
|
| 67 |
+
special_ids = set(getattr(tokenizer, "all_special_ids", []) or [])
|
| 68 |
+
if not special_ids or ids.numel() == 0:
|
| 69 |
+
return ids, torch.ones(ids.shape[0], dtype=torch.bool)
|
| 70 |
+
keep = torch.tensor(
|
| 71 |
+
[tid.item() not in special_ids for tid in ids], dtype=torch.bool
|
| 72 |
+
)
|
| 73 |
+
return ids[keep], keep
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _prepare_trace_text(
|
| 77 |
+
tokenizer: object,
|
| 78 |
+
context_messages: list[dict[str, str]],
|
| 79 |
+
response_ids: torch.Tensor,
|
| 80 |
+
) -> tuple[str, int, int]:
|
| 81 |
+
"""Build the trace text and return ``(full_text, n_ctx, n_resp)``."""
|
| 82 |
+
context_prompt, _ = _format_generation_prompt(context_messages, tokenizer)
|
| 83 |
+
context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
|
| 84 |
+
response_text = _decode_ids(tokenizer, response_ids.tolist())
|
| 85 |
+
full_text = context_prompt + response_text
|
| 86 |
+
full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
|
| 87 |
+
expected_ids = torch.cat([context_ids, response_ids.cpu()])
|
| 88 |
+
if full_ids.tolist() != expected_ids.tolist():
|
| 89 |
+
raise ValueError(
|
| 90 |
+
"contrast trace text did not round-trip to the expected token ids"
|
| 91 |
+
)
|
| 92 |
+
n_ctx = len(context_ids)
|
| 93 |
+
n_resp = len(response_ids)
|
| 94 |
+
return full_text, n_ctx, n_resp
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _build_contrast(
|
| 98 |
+
tokenizer: object,
|
| 99 |
+
response_ids: torch.Tensor,
|
| 100 |
+
lp_a: torch.Tensor,
|
| 101 |
+
lp_b: torch.Tensor,
|
| 102 |
+
label_a: str,
|
| 103 |
+
label_b: str,
|
| 104 |
+
) -> TokenContrast:
|
| 105 |
+
diffs = (lp_a - lp_b).cpu()
|
| 106 |
+
display_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
|
| 107 |
+
display_diffs = diffs[keep_mask]
|
| 108 |
+
return TokenContrast(
|
| 109 |
+
tokens=[_token_display(tokenizer, tid.item()) for tid in display_ids],
|
| 110 |
+
weights=_normalise_diffs(display_diffs),
|
| 111 |
+
raw_diffs=display_diffs.float().tolist(),
|
| 112 |
+
label_a=label_a,
|
| 113 |
+
label_b=label_b,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _token_display(tokenizer: object, token_id: int) -> str:
|
| 118 |
+
"""Render a single token id as normal decoded text."""
|
| 119 |
+
return _decode_ids(tokenizer, [token_id])
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# Each spec: (key, full_text, n_ctx, n_resp, target_ids).
|
| 123 |
+
PassSpec = tuple[str, str, int, int, torch.Tensor]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _score_passes(
|
| 127 |
+
model: StandardizedTransformer,
|
| 128 |
+
specs: list[PassSpec],
|
| 129 |
+
remote: bool,
|
| 130 |
+
) -> dict[str, torch.Tensor]:
|
| 131 |
+
"""
|
| 132 |
+
Run one forward pass per spec and return reduced per-token logprobs.
|
| 133 |
+
|
| 134 |
+
The log-softmax and target-pick happen *inside* the trace, so only the
|
| 135 |
+
reduced ``[n_resp]`` logprob vector per pass is shipped back — not the full
|
| 136 |
+
``[1, seq, vocab]`` logits (which would be hundreds of MB per pass on NDIF).
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def _score_pass(
|
| 140 |
+
full_text: str,
|
| 141 |
+
n_ctx: int,
|
| 142 |
+
n_resp: int,
|
| 143 |
+
target_ids: torch.Tensor,
|
| 144 |
+
) -> torch.Tensor:
|
| 145 |
+
with torch.no_grad(), model.trace(full_text, remote=remote):
|
| 146 |
+
# logit at position i predicts token i+1, so response token j
|
| 147 |
+
# (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
|
| 148 |
+
resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
|
| 149 |
+
log_probs = torch.log_softmax(resp_logits, dim=-1)
|
| 150 |
+
targets = target_ids.to(log_probs.device).view(-1, 1)
|
| 151 |
+
picked = log_probs.gather(1, targets).view(-1)
|
| 152 |
+
out = picked.detach().cpu().save()
|
| 153 |
+
|
| 154 |
+
if hasattr(out, "value") and getattr(out, "value") is not None:
|
| 155 |
+
out = out.value
|
| 156 |
+
if not isinstance(out, torch.Tensor):
|
| 157 |
+
raise TypeError(
|
| 158 |
+
f"contrast score did not resolve to a tensor: {type(out)!r}"
|
| 159 |
+
)
|
| 160 |
+
return out.detach().cpu()
|
| 161 |
+
|
| 162 |
+
saved = [
|
| 163 |
+
_score_pass(full_text, n_ctx, n_resp, target_ids)
|
| 164 |
+
for _key, full_text, n_ctx, n_resp, target_ids in specs
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
if len(saved) != len(specs):
|
| 168 |
+
raise RuntimeError(
|
| 169 |
+
f"contrast scoring returned {len(saved)} result(s) for {len(specs)} spec(s)"
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
return {spec[0]: tensor for spec, tensor in zip(specs, saved)}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _specs_for_response(
|
| 176 |
+
tokenizer: object,
|
| 177 |
+
response_ids: torch.Tensor,
|
| 178 |
+
context_a: list[dict[str, str]],
|
| 179 |
+
context_b: list[dict[str, str]],
|
| 180 |
+
prefix: str,
|
| 181 |
+
) -> list[PassSpec]:
|
| 182 |
+
"""Build the (under_a, under_b) pass specs for a single response."""
|
| 183 |
+
text_a, n_ctx_a, n_resp = _prepare_trace_text(tokenizer, context_a, response_ids)
|
| 184 |
+
text_b, n_ctx_b, _ = _prepare_trace_text(tokenizer, context_b, response_ids)
|
| 185 |
+
return [
|
| 186 |
+
(f"{prefix}_under_a", text_a, n_ctx_a, n_resp, response_ids),
|
| 187 |
+
(f"{prefix}_under_b", text_b, n_ctx_b, n_resp, response_ids),
|
| 188 |
+
]
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def compute_contrast(
|
| 192 |
+
model: StandardizedTransformer,
|
| 193 |
+
context_a: list[dict[str, str]],
|
| 194 |
+
context_b: list[dict[str, str]],
|
| 195 |
+
response_ids: torch.Tensor,
|
| 196 |
+
label_a: str,
|
| 197 |
+
label_b: str,
|
| 198 |
+
remote: bool = False,
|
| 199 |
+
) -> "TokenContrast | None":
|
| 200 |
+
"""Compute per-token contrast weights for a single response (2 forward passes)."""
|
| 201 |
+
tokenizer = model.tokenizer
|
| 202 |
+
if response_ids.numel() == 0:
|
| 203 |
+
return None
|
| 204 |
+
|
| 205 |
+
specs = _specs_for_response(tokenizer, response_ids, context_a, context_b, "r")
|
| 206 |
+
out = _score_passes(model, specs, remote)
|
| 207 |
+
return _build_contrast(
|
| 208 |
+
tokenizer, response_ids, out["r_under_a"], out["r_under_b"], label_a, label_b
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def compute_contrast_pair(
|
| 213 |
+
model: StandardizedTransformer,
|
| 214 |
+
context_a: list[dict[str, str]],
|
| 215 |
+
context_b: list[dict[str, str]],
|
| 216 |
+
response_ids_a: torch.Tensor,
|
| 217 |
+
response_ids_b: torch.Tensor,
|
| 218 |
+
label_a: str,
|
| 219 |
+
label_b: str,
|
| 220 |
+
remote: bool = False,
|
| 221 |
+
) -> tuple["TokenContrast | None", "TokenContrast | None"]:
|
| 222 |
+
"""
|
| 223 |
+
Compute contrast weights for both panel responses (up to 4 remote passes).
|
| 224 |
+
"""
|
| 225 |
+
tokenizer = model.tokenizer
|
| 226 |
+
if response_ids_a.numel() == 0 and response_ids_b.numel() == 0:
|
| 227 |
+
return None, None
|
| 228 |
+
|
| 229 |
+
specs: list[PassSpec] = []
|
| 230 |
+
if response_ids_a.numel() > 0:
|
| 231 |
+
specs += _specs_for_response(
|
| 232 |
+
tokenizer, response_ids_a, context_a, context_b, "a"
|
| 233 |
+
)
|
| 234 |
+
if response_ids_b.numel() > 0:
|
| 235 |
+
specs += _specs_for_response(
|
| 236 |
+
tokenizer, response_ids_b, context_a, context_b, "b"
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
out = _score_passes(model, specs, remote)
|
| 240 |
+
|
| 241 |
+
def _build(resp_ids: torch.Tensor, prefix: str) -> "TokenContrast | None":
|
| 242 |
+
k_a, k_b = f"{prefix}_under_a", f"{prefix}_under_b"
|
| 243 |
+
if resp_ids.numel() == 0 or k_a not in out or k_b not in out:
|
| 244 |
+
return None
|
| 245 |
+
return _build_contrast(
|
| 246 |
+
tokenizer, resp_ids, out[k_a], out[k_b], label_a, label_b
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
return _build(response_ids_a, "a"), _build(response_ids_b, "b")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ── HTML rendering ────────────────────────────────────────────────────────────
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _weight_to_bg(w: float) -> str:
|
| 256 |
+
"""Map a normalised weight in [-1, 1] to a CSS rgba background color."""
|
| 257 |
+
w = max(-1.0, min(1.0, w))
|
| 258 |
+
alpha = abs(w) * 0.5 # cap at 0.5 opacity so text stays readable
|
| 259 |
+
if w > 0.05:
|
| 260 |
+
return f"rgba(210,60,60,{alpha:.3f})"
|
| 261 |
+
if w < -0.05:
|
| 262 |
+
return f"rgba(50,110,210,{alpha:.3f})"
|
| 263 |
+
return "rgba(0,0,0,0)"
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
_CONTRAST_CSS = (
|
| 267 |
+
"<style>"
|
| 268 |
+
".contrast-tok{position:relative;border-radius:2px;padding:0 1px;"
|
| 269 |
+
"cursor:default;white-space:pre;}"
|
| 270 |
+
".contrast-tok>.contrast-tip{display:none;position:absolute;bottom:100%;"
|
| 271 |
+
"left:50%;transform:translateX(-50%);margin-bottom:4px;padding:2px 6px;"
|
| 272 |
+
"border-radius:3px;background:#222;color:#eee;font-size:0.72em;"
|
| 273 |
+
"font-family:ui-monospace,monospace;white-space:nowrap;pointer-events:none;"
|
| 274 |
+
"z-index:10;box-shadow:0 2px 6px rgba(0,0,0,0.3);}"
|
| 275 |
+
".contrast-tok:hover>.contrast-tip{display:block;}"
|
| 276 |
+
"</style>"
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def render_contrast_html(result: TokenContrast) -> str:
|
| 281 |
+
"""
|
| 282 |
+
Render each token with a colored background reflecting how A- or B-specific
|
| 283 |
+
it is, with a hover tooltip showing the raw Δlog P, plus a legend.
|
| 284 |
+
"""
|
| 285 |
+
spans: list[str] = []
|
| 286 |
+
for token, weight, raw in zip(result.tokens, result.weights, result.raw_diffs):
|
| 287 |
+
bg = _weight_to_bg(weight)
|
| 288 |
+
tip = escape(f"Δlog P(A−B): {raw:+.3f}")
|
| 289 |
+
text = escape(token)
|
| 290 |
+
spans.append(
|
| 291 |
+
f'<span class="contrast-tok" style="background:{bg};">'
|
| 292 |
+
f'{text}<span class="contrast-tip">{tip}</span></span>'
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
la = escape(result.label_a)
|
| 296 |
+
lb = escape(result.label_b)
|
| 297 |
+
|
| 298 |
+
return (
|
| 299 |
+
_CONTRAST_CSS + '<div style="font-family:inherit;line-height:1.75;'
|
| 300 |
+
'white-space:pre-wrap;word-break:break-word;padding:2px 0 6px 0;">'
|
| 301 |
+
+ "".join(spans)
|
| 302 |
+
+ '<div style="margin-top:10px;font-size:0.72em;color:#888;'
|
| 303 |
+
+ 'display:flex;gap:12px;flex-wrap:wrap;">'
|
| 304 |
+
+ f'<span><span style="background:rgba(210,60,60,0.45);'
|
| 305 |
+
+ f'padding:1px 6px;border-radius:2px;"> </span> {la}</span>'
|
| 306 |
+
+ f'<span><span style="background:rgba(50,110,210,0.45);'
|
| 307 |
+
+ f'padding:1px 6px;border-radius:2px;"> </span> {lb}</span>'
|
| 308 |
+
+ '<span style="color:#aaa;">gray = shared by both</span>'
|
| 309 |
+
+ "</div>"
|
| 310 |
+
+ "</div>"
|
| 311 |
+
)
|
utils/helpers.py
CHANGED
|
@@ -16,8 +16,6 @@ MODE_LABELS = list(VARIANT_LABELS.values())
|
|
| 16 |
# Reverse lookup: label -> key
|
| 17 |
MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
|
| 18 |
|
| 19 |
-
VISIBLE_MESSAGE_COUNT = 5
|
| 20 |
-
|
| 21 |
DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
|
| 22 |
ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
|
| 23 |
|
|
|
|
| 16 |
# Reverse lookup: label -> key
|
| 17 |
MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
|
| 18 |
|
|
|
|
|
|
|
| 19 |
DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
|
| 20 |
ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
|
| 21 |
|
utils/runtime.py
CHANGED
|
@@ -7,33 +7,62 @@ logger = logging.getLogger(__name__)
|
|
| 7 |
|
| 8 |
@st.cache_data(show_spinner=False, ttl=30)
|
| 9 |
def list_remote_models() -> list[str]:
|
| 10 |
-
"""Return the NDIF language models that are currently running.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
import nnsight
|
| 13 |
|
| 14 |
try:
|
| 15 |
-
|
| 16 |
except Exception:
|
| 17 |
logger.warning("Failed to fetch NDIF status", exc_info=True)
|
| 18 |
return []
|
| 19 |
|
| 20 |
model_names: list[str] = []
|
|
|
|
| 21 |
|
| 22 |
-
for
|
| 23 |
-
if not isinstance(
|
| 24 |
continue
|
| 25 |
-
if
|
|
|
|
|
|
|
|
|
|
| 26 |
continue
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
if isinstance(repo_id, str):
|
| 35 |
model_names.append(repo_id)
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
return sorted(set(model_names))
|
| 38 |
|
| 39 |
|
|
|
|
| 7 |
|
| 8 |
@st.cache_data(show_spinner=False, ttl=30)
|
| 9 |
def list_remote_models() -> list[str]:
|
| 10 |
+
"""Return the NDIF language models that are currently running.
|
| 11 |
+
|
| 12 |
+
Parses the raw NDIF response directly instead of going through
|
| 13 |
+
``nnsight.ndif_status()`` because that call crashes whenever NDIF reports
|
| 14 |
+
any deployment with an ``application_state`` that isn't in nnsight's
|
| 15 |
+
``ModelStatus`` enum (e.g. ``UNHEALTHY``) — one bad deployment poisons
|
| 16 |
+
the whole response. See nnsight 0.6.3 ``ndif.py::status``.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
|
| 21 |
import nnsight
|
| 22 |
|
| 23 |
try:
|
| 24 |
+
raw = nnsight.ndif_status(raw=True)
|
| 25 |
except Exception:
|
| 26 |
logger.warning("Failed to fetch NDIF status", exc_info=True)
|
| 27 |
return []
|
| 28 |
|
| 29 |
model_names: list[str] = []
|
| 30 |
+
bad_states: list[tuple[str, str]] = [] # (repo_id_or_key, application_state)
|
| 31 |
|
| 32 |
+
for value in (raw or {}).get("deployments", {}).values():
|
| 33 |
+
if not isinstance(value, dict):
|
| 34 |
continue
|
| 35 |
+
if (
|
| 36 |
+
value.get("deployment_level") not in {"HOT", "WARM"}
|
| 37 |
+
and "schedule" not in value
|
| 38 |
+
):
|
| 39 |
continue
|
| 40 |
|
| 41 |
+
model_key = value.get("model_key", "")
|
| 42 |
+
model_class = model_key.split(":", 1)[0].split(".")[-1]
|
| 43 |
+
try:
|
| 44 |
+
repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
|
| 45 |
+
except Exception:
|
| 46 |
+
repo_id = model_key
|
| 47 |
|
| 48 |
+
state = value.get("application_state", "NOT DEPLOYED")
|
| 49 |
+
if state not in {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}:
|
| 50 |
+
bad_states.append((repo_id or model_key, state))
|
| 51 |
+
|
| 52 |
+
if model_class not in {"LanguageModel", "StandardizedTransformer"}:
|
| 53 |
+
continue
|
| 54 |
+
if state != "RUNNING":
|
| 55 |
+
continue
|
| 56 |
if isinstance(repo_id, str):
|
| 57 |
model_names.append(repo_id)
|
| 58 |
|
| 59 |
+
if bad_states:
|
| 60 |
+
logger.warning(
|
| 61 |
+
"NDIF reported deployments with unexpected application_state values "
|
| 62 |
+
"(nnsight's ModelStatus enum may not know about these): %s",
|
| 63 |
+
bad_states,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
return sorted(set(model_names))
|
| 67 |
|
| 68 |
|