Spaces:
Sleeping
Sleeping
Jac-Zac commited on
Commit ·
93d5dc5
1
Parent(s): a9950fb
Small cleanups
Browse filesSmall reset bottom cleanup
- state.py +14 -9
- tabs/chat.py +55 -67
- tabs/compare_chat.py +40 -17
- utils/chat.py +2 -2
- utils/contrast.py +11 -8
state.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
_CHAT_STATE_PREFIX = "chat_state::"
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
|
@@ -37,13 +38,9 @@ def reset_chat_context_state(
|
|
| 37 |
def _evict_inactive_kv_caches(active_key: str) -> None:
|
| 38 |
"""Drop past_key_values from every chat context except the active one."""
|
| 39 |
|
| 40 |
-
for key in st.session_state:
|
| 41 |
-
if
|
| 42 |
-
|
| 43 |
-
and key.startswith(_CHAT_STATE_PREFIX)
|
| 44 |
-
and key != active_key
|
| 45 |
-
):
|
| 46 |
-
state = st.session_state[key]
|
| 47 |
if isinstance(state, dict) and state.get("past_key_values") is not None:
|
| 48 |
state["past_key_values"] = None
|
| 49 |
|
|
@@ -54,13 +51,21 @@ def get_chat_state(
|
|
| 54 |
"""Return the mutable chat state for the active context."""
|
| 55 |
|
| 56 |
key = chat_session_key(model_name, dataset_source)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
state = st.session_state.get(key)
|
| 58 |
if state is None:
|
| 59 |
state = default_chat_state()
|
| 60 |
st.session_state[key] = state
|
| 61 |
else:
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
_evict_inactive_kv_caches(key)
|
| 65 |
if remote and state.get("past_key_values") is not None:
|
| 66 |
state["past_key_values"] = None
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
_CHAT_STATE_PREFIX = "chat_state::"
|
| 4 |
+
_CHAT_KEYS_REGISTRY = "chat_state::_registered_keys"
|
| 5 |
|
| 6 |
|
| 7 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
|
|
|
| 38 |
def _evict_inactive_kv_caches(active_key: str) -> None:
|
| 39 |
"""Drop past_key_values from every chat context except the active one."""
|
| 40 |
|
| 41 |
+
for key in st.session_state.get(_CHAT_KEYS_REGISTRY, ()):
|
| 42 |
+
if key != active_key:
|
| 43 |
+
state = st.session_state.get(key)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
if isinstance(state, dict) and state.get("past_key_values") is not None:
|
| 45 |
state["past_key_values"] = None
|
| 46 |
|
|
|
|
| 51 |
"""Return the mutable chat state for the active context."""
|
| 52 |
|
| 53 |
key = chat_session_key(model_name, dataset_source)
|
| 54 |
+
registry = st.session_state.get(_CHAT_KEYS_REGISTRY)
|
| 55 |
+
if registry is None:
|
| 56 |
+
registry = set()
|
| 57 |
+
st.session_state[_CHAT_KEYS_REGISTRY] = registry
|
| 58 |
+
registry.add(key)
|
| 59 |
+
|
| 60 |
state = st.session_state.get(key)
|
| 61 |
if state is None:
|
| 62 |
state = default_chat_state()
|
| 63 |
st.session_state[key] = state
|
| 64 |
else:
|
| 65 |
+
state.setdefault("messages", [])
|
| 66 |
+
state.setdefault("persona_id", None)
|
| 67 |
+
state.setdefault("prompt_mode", "templated")
|
| 68 |
+
state.setdefault("past_key_values", None)
|
| 69 |
_evict_inactive_kv_caches(key)
|
| 70 |
if remote and state.get("past_key_values") is not None:
|
| 71 |
state["past_key_values"] = None
|
tabs/chat.py
CHANGED
|
@@ -17,10 +17,19 @@ from utils.helpers import (
|
|
| 17 |
)
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# ── Dialogs ───────────────────────────────────────────────────────────────────
|
| 26 |
|
|
@@ -91,7 +100,7 @@ def _open_system_prompt_dialog(*, prompt_key: str, current_value: str) -> None:
|
|
| 91 |
# ── Message renderers ─────────────────────────────────────────────────────────
|
| 92 |
|
| 93 |
|
| 94 |
-
def
|
| 95 |
message: dict[str, str],
|
| 96 |
show_contrast: bool = False,
|
| 97 |
) -> None:
|
|
@@ -103,7 +112,7 @@ def _render_chat_message(
|
|
| 103 |
if tc is not None:
|
| 104 |
st.html(render_contrast_html(tc))
|
| 105 |
else:
|
| 106 |
-
|
| 107 |
|
| 108 |
|
| 109 |
def _render_editable_message(
|
|
@@ -129,7 +138,7 @@ def _render_editable_message(
|
|
| 129 |
if tc is not None:
|
| 130 |
st.html(render_contrast_html(tc))
|
| 131 |
else:
|
| 132 |
-
|
| 133 |
with edit_col:
|
| 134 |
if st.button(
|
| 135 |
"", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
|
|
@@ -142,7 +151,7 @@ def _render_editable_message(
|
|
| 142 |
)
|
| 143 |
|
| 144 |
|
| 145 |
-
def
|
| 146 |
prompt_key: str,
|
| 147 |
prompt_mode: str,
|
| 148 |
active_system_prompt: str | None,
|
|
@@ -159,7 +168,7 @@ def _render_system_prompt(
|
|
| 159 |
return st.session_state.get(prompt_key) or None
|
| 160 |
|
| 161 |
|
| 162 |
-
def
|
| 163 |
return {
|
| 164 |
"max_new_tokens": int(gen_kwargs["max_new_tokens"]),
|
| 165 |
"advanced_generation": bool(advanced_generation),
|
|
@@ -172,7 +181,7 @@ def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, o
|
|
| 172 |
}
|
| 173 |
|
| 174 |
|
| 175 |
-
def
|
| 176 |
personas: list[PersonaData],
|
| 177 |
current_persona_id: str | None,
|
| 178 |
current_prompt_mode: str,
|
|
@@ -209,7 +218,7 @@ def _render_persona_prompt_controls(
|
|
| 209 |
return selected_persona, prompt_mode, changed
|
| 210 |
|
| 211 |
|
| 212 |
-
def
|
| 213 |
*,
|
| 214 |
chat_log: Any,
|
| 215 |
messages: list[dict[str, str]],
|
|
@@ -233,10 +242,10 @@ def _render_chat_window(
|
|
| 233 |
column_ratio=edit_column_ratio,
|
| 234 |
)
|
| 235 |
else:
|
| 236 |
-
|
| 237 |
|
| 238 |
|
| 239 |
-
def
|
| 240 |
system_prompt: str | None,
|
| 241 |
messages: list[dict[str, str]],
|
| 242 |
) -> list[dict[str, str]]:
|
|
@@ -245,31 +254,6 @@ def _build_chat_messages(
|
|
| 245 |
) + messages
|
| 246 |
|
| 247 |
|
| 248 |
-
def _save_chat_export_message(
|
| 249 |
-
*,
|
| 250 |
-
model_name: str,
|
| 251 |
-
dataset_source: str,
|
| 252 |
-
persona_id: str,
|
| 253 |
-
persona_name: str | None,
|
| 254 |
-
prompt_mode: str,
|
| 255 |
-
system_prompt: str | None,
|
| 256 |
-
messages: list[dict[str, str]],
|
| 257 |
-
generation: dict[str, object],
|
| 258 |
-
panel_label: str | None = None,
|
| 259 |
-
) -> None:
|
| 260 |
-
save_chat_export(
|
| 261 |
-
model_name=model_name,
|
| 262 |
-
dataset_source=dataset_source,
|
| 263 |
-
persona_id=persona_id,
|
| 264 |
-
persona_name=persona_name,
|
| 265 |
-
panel_label=panel_label,
|
| 266 |
-
prompt_mode=prompt_mode,
|
| 267 |
-
system_prompt=system_prompt,
|
| 268 |
-
messages=messages,
|
| 269 |
-
generation=generation,
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
|
| 273 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
| 274 |
|
| 275 |
|
|
@@ -286,7 +270,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 286 |
"Max new tokens",
|
| 287 |
min_value=16,
|
| 288 |
max_value=512,
|
| 289 |
-
value=
|
| 290 |
step=16,
|
| 291 |
key=widget_key(context_key, "max_new_tokens"),
|
| 292 |
)
|
|
@@ -295,7 +279,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 295 |
"Repetition penalty",
|
| 296 |
min_value=0.5,
|
| 297 |
max_value=2.0,
|
| 298 |
-
value=
|
| 299 |
step=0.05,
|
| 300 |
key=widget_key(context_key, "repetition_penalty"),
|
| 301 |
)
|
|
@@ -313,7 +297,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 313 |
"Temperature",
|
| 314 |
min_value=0.01,
|
| 315 |
max_value=2.0,
|
| 316 |
-
value=
|
| 317 |
step=0.01,
|
| 318 |
disabled=sampling_disabled,
|
| 319 |
key=widget_key(context_key, "temperature"),
|
|
@@ -323,7 +307,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 323 |
"Top-p",
|
| 324 |
min_value=0.01,
|
| 325 |
max_value=1.0,
|
| 326 |
-
value=
|
| 327 |
step=0.01,
|
| 328 |
disabled=sampling_disabled,
|
| 329 |
key=widget_key(context_key, "top_p"),
|
|
@@ -333,7 +317,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 333 |
"Top-k (0 = off)",
|
| 334 |
min_value=0,
|
| 335 |
max_value=100,
|
| 336 |
-
value=
|
| 337 |
step=1,
|
| 338 |
disabled=sampling_disabled,
|
| 339 |
key=widget_key(context_key, "top_k"),
|
|
@@ -365,12 +349,12 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 365 |
st.caption("Seed is local-only and disabled for remote runs.")
|
| 366 |
|
| 367 |
advanced_generation = (
|
| 368 |
-
max_new_tokens !=
|
| 369 |
or use_sampling
|
| 370 |
-
or temperature !=
|
| 371 |
-
or top_p !=
|
| 372 |
-
or top_k !=
|
| 373 |
-
or repetition_penalty !=
|
| 374 |
or seed is not None
|
| 375 |
)
|
| 376 |
|
|
@@ -395,6 +379,14 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 395 |
|
| 396 |
context_key = chat_session_key(model_name, dataset_source)
|
| 397 |
chat_state = get_chat_state(model_name, remote, dataset_source)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
try:
|
| 399 |
dataset, dataset_status = load_dataset(
|
| 400 |
dataset_source,
|
|
@@ -416,12 +408,17 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 416 |
gen_kwargs, advanced_generation = _render_generation_settings(context_key, remote)
|
| 417 |
|
| 418 |
# ── Mode toggle ───────────────────────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
compare_mode = st.toggle(
|
| 420 |
"Compare mode",
|
| 421 |
-
|
| 422 |
-
key=widget_key(context_key, "compare_mode"),
|
| 423 |
help="Side-by-side: send one message to two independent persona/prompt configurations.",
|
| 424 |
)
|
|
|
|
| 425 |
|
| 426 |
if compare_mode:
|
| 427 |
from tabs.compare_chat import render_compare_mode
|
|
@@ -458,7 +455,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 458 |
)
|
| 459 |
st.session_state.pop(edit_key, None)
|
| 460 |
|
| 461 |
-
selected_persona, prompt_mode, changed_context =
|
| 462 |
personas,
|
| 463 |
chat_state["persona_id"],
|
| 464 |
chat_state["prompt_mode"],
|
|
@@ -466,6 +463,8 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 466 |
prompt_mode_select_key,
|
| 467 |
column_widths=(2, 1),
|
| 468 |
)
|
|
|
|
|
|
|
| 469 |
|
| 470 |
active_system_prompt = resolve_system_prompt(
|
| 471 |
persona=selected_persona,
|
|
@@ -481,13 +480,13 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 481 |
chat_log = st.container()
|
| 482 |
|
| 483 |
with chat_log:
|
| 484 |
-
active_system_prompt =
|
| 485 |
prompt_key,
|
| 486 |
prompt_mode,
|
| 487 |
active_system_prompt,
|
| 488 |
)
|
| 489 |
|
| 490 |
-
|
| 491 |
chat_log=chat_log,
|
| 492 |
messages=chat_state["messages"],
|
| 493 |
chat_state=chat_state,
|
|
@@ -505,7 +504,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 505 |
key=export_key,
|
| 506 |
help="Export chat",
|
| 507 |
):
|
| 508 |
-
|
| 509 |
model_name=model_name,
|
| 510 |
dataset_source=dataset_source,
|
| 511 |
persona_id=selected_persona.id,
|
|
@@ -513,7 +512,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 513 |
prompt_mode=prompt_mode,
|
| 514 |
system_prompt=active_system_prompt,
|
| 515 |
messages=chat_state["messages"],
|
| 516 |
-
generation=
|
| 517 |
)
|
| 518 |
st.toast("Exported", icon=":material/check:")
|
| 519 |
with rst_col:
|
|
@@ -538,7 +537,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 538 |
if not st.session_state.pop(pending_key, False):
|
| 539 |
return
|
| 540 |
|
| 541 |
-
messages =
|
| 542 |
|
| 543 |
with st.spinner("Generating reply..."):
|
| 544 |
model = cached_model(model_name=model_name, remote=remote)
|
|
@@ -559,15 +558,4 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 559 |
|
| 560 |
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 561 |
chat_state["past_key_values"] = reply.past_key_values if not remote else None
|
| 562 |
-
|
| 563 |
-
save_chat_export(
|
| 564 |
-
model_name=model_name,
|
| 565 |
-
dataset_source=dataset_source,
|
| 566 |
-
persona_id=selected_persona.id,
|
| 567 |
-
persona_name=getattr(selected_persona, "name", None),
|
| 568 |
-
prompt_mode=prompt_mode,
|
| 569 |
-
system_prompt=active_system_prompt,
|
| 570 |
-
messages=chat_state["messages"],
|
| 571 |
-
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 572 |
-
)
|
| 573 |
st.rerun()
|
|
|
|
| 17 |
)
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
+
# ── Persistence keys for surviving model / remote switches ────────────────────
|
| 21 |
+
_LAST_PERSONA_ID_KEY = "chat:last_persona_id"
|
| 22 |
+
_LAST_PROMPT_MODE_KEY = "chat:last_prompt_mode"
|
| 23 |
+
_LAST_COMPARE_MODE_KEY = "chat:last_compare_mode"
|
| 24 |
+
|
| 25 |
+
# ── Generation defaults (single source of truth) ─────────────────────────────
|
| 26 |
+
_GEN_DEFAULTS = {
|
| 27 |
+
"max_new_tokens": 256,
|
| 28 |
+
"temperature": 1.0,
|
| 29 |
+
"top_p": 1.0,
|
| 30 |
+
"top_k": 50,
|
| 31 |
+
"repetition_penalty": 1.0,
|
| 32 |
+
}
|
| 33 |
|
| 34 |
# ── Dialogs ───────────────────────────────────────────────────────────────────
|
| 35 |
|
|
|
|
| 100 |
# ── Message renderers ─────────────────────────────────────────────────────────
|
| 101 |
|
| 102 |
|
| 103 |
+
def render_chat_message(
|
| 104 |
message: dict[str, str],
|
| 105 |
show_contrast: bool = False,
|
| 106 |
) -> None:
|
|
|
|
| 112 |
if tc is not None:
|
| 113 |
st.html(render_contrast_html(tc))
|
| 114 |
else:
|
| 115 |
+
st.markdown(message["content"])
|
| 116 |
|
| 117 |
|
| 118 |
def _render_editable_message(
|
|
|
|
| 138 |
if tc is not None:
|
| 139 |
st.html(render_contrast_html(tc))
|
| 140 |
else:
|
| 141 |
+
st.markdown(message["content"])
|
| 142 |
with edit_col:
|
| 143 |
if st.button(
|
| 144 |
"", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
|
|
|
|
| 151 |
)
|
| 152 |
|
| 153 |
|
| 154 |
+
def render_system_prompt(
|
| 155 |
prompt_key: str,
|
| 156 |
prompt_mode: str,
|
| 157 |
active_system_prompt: str | None,
|
|
|
|
| 168 |
return st.session_state.get(prompt_key) or None
|
| 169 |
|
| 170 |
|
| 171 |
+
def generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
|
| 172 |
return {
|
| 173 |
"max_new_tokens": int(gen_kwargs["max_new_tokens"]),
|
| 174 |
"advanced_generation": bool(advanced_generation),
|
|
|
|
| 181 |
}
|
| 182 |
|
| 183 |
|
| 184 |
+
def render_persona_prompt_controls(
|
| 185 |
personas: list[PersonaData],
|
| 186 |
current_persona_id: str | None,
|
| 187 |
current_prompt_mode: str,
|
|
|
|
| 218 |
return selected_persona, prompt_mode, changed
|
| 219 |
|
| 220 |
|
| 221 |
+
def render_chat_window(
|
| 222 |
*,
|
| 223 |
chat_log: Any,
|
| 224 |
messages: list[dict[str, str]],
|
|
|
|
| 242 |
column_ratio=edit_column_ratio,
|
| 243 |
)
|
| 244 |
else:
|
| 245 |
+
render_chat_message(message, show_contrast=show_contrast)
|
| 246 |
|
| 247 |
|
| 248 |
+
def build_chat_messages(
|
| 249 |
system_prompt: str | None,
|
| 250 |
messages: list[dict[str, str]],
|
| 251 |
) -> list[dict[str, str]]:
|
|
|
|
| 254 |
) + messages
|
| 255 |
|
| 256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
| 258 |
|
| 259 |
|
|
|
|
| 270 |
"Max new tokens",
|
| 271 |
min_value=16,
|
| 272 |
max_value=512,
|
| 273 |
+
value=_GEN_DEFAULTS["max_new_tokens"],
|
| 274 |
step=16,
|
| 275 |
key=widget_key(context_key, "max_new_tokens"),
|
| 276 |
)
|
|
|
|
| 279 |
"Repetition penalty",
|
| 280 |
min_value=0.5,
|
| 281 |
max_value=2.0,
|
| 282 |
+
value=_GEN_DEFAULTS["repetition_penalty"],
|
| 283 |
step=0.05,
|
| 284 |
key=widget_key(context_key, "repetition_penalty"),
|
| 285 |
)
|
|
|
|
| 297 |
"Temperature",
|
| 298 |
min_value=0.01,
|
| 299 |
max_value=2.0,
|
| 300 |
+
value=_GEN_DEFAULTS["temperature"],
|
| 301 |
step=0.01,
|
| 302 |
disabled=sampling_disabled,
|
| 303 |
key=widget_key(context_key, "temperature"),
|
|
|
|
| 307 |
"Top-p",
|
| 308 |
min_value=0.01,
|
| 309 |
max_value=1.0,
|
| 310 |
+
value=_GEN_DEFAULTS["top_p"],
|
| 311 |
step=0.01,
|
| 312 |
disabled=sampling_disabled,
|
| 313 |
key=widget_key(context_key, "top_p"),
|
|
|
|
| 317 |
"Top-k (0 = off)",
|
| 318 |
min_value=0,
|
| 319 |
max_value=100,
|
| 320 |
+
value=_GEN_DEFAULTS["top_k"],
|
| 321 |
step=1,
|
| 322 |
disabled=sampling_disabled,
|
| 323 |
key=widget_key(context_key, "top_k"),
|
|
|
|
| 349 |
st.caption("Seed is local-only and disabled for remote runs.")
|
| 350 |
|
| 351 |
advanced_generation = (
|
| 352 |
+
max_new_tokens != _GEN_DEFAULTS["max_new_tokens"]
|
| 353 |
or use_sampling
|
| 354 |
+
or temperature != _GEN_DEFAULTS["temperature"]
|
| 355 |
+
or top_p != _GEN_DEFAULTS["top_p"]
|
| 356 |
+
or top_k != _GEN_DEFAULTS["top_k"]
|
| 357 |
+
or repetition_penalty != _GEN_DEFAULTS["repetition_penalty"]
|
| 358 |
or seed is not None
|
| 359 |
)
|
| 360 |
|
|
|
|
| 379 |
|
| 380 |
context_key = chat_session_key(model_name, dataset_source)
|
| 381 |
chat_state = get_chat_state(model_name, remote, dataset_source)
|
| 382 |
+
|
| 383 |
+
# Carry over persona / prompt selections across model or remote switches.
|
| 384 |
+
if chat_state["persona_id"] is None:
|
| 385 |
+
chat_state["persona_id"] = st.session_state.get(_LAST_PERSONA_ID_KEY)
|
| 386 |
+
chat_state["prompt_mode"] = st.session_state.get(
|
| 387 |
+
_LAST_PROMPT_MODE_KEY, "templated"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
try:
|
| 391 |
dataset, dataset_status = load_dataset(
|
| 392 |
dataset_source,
|
|
|
|
| 408 |
gen_kwargs, advanced_generation = _render_generation_settings(context_key, remote)
|
| 409 |
|
| 410 |
# ── Mode toggle ───────────────────────────────────────────────────────────
|
| 411 |
+
compare_key = widget_key(context_key, "compare_mode")
|
| 412 |
+
if compare_key not in st.session_state:
|
| 413 |
+
st.session_state[compare_key] = st.session_state.get(
|
| 414 |
+
_LAST_COMPARE_MODE_KEY, False
|
| 415 |
+
)
|
| 416 |
compare_mode = st.toggle(
|
| 417 |
"Compare mode",
|
| 418 |
+
key=compare_key,
|
|
|
|
| 419 |
help="Side-by-side: send one message to two independent persona/prompt configurations.",
|
| 420 |
)
|
| 421 |
+
st.session_state[_LAST_COMPARE_MODE_KEY] = compare_mode
|
| 422 |
|
| 423 |
if compare_mode:
|
| 424 |
from tabs.compare_chat import render_compare_mode
|
|
|
|
| 455 |
)
|
| 456 |
st.session_state.pop(edit_key, None)
|
| 457 |
|
| 458 |
+
selected_persona, prompt_mode, changed_context = render_persona_prompt_controls(
|
| 459 |
personas,
|
| 460 |
chat_state["persona_id"],
|
| 461 |
chat_state["prompt_mode"],
|
|
|
|
| 463 |
prompt_mode_select_key,
|
| 464 |
column_widths=(2, 1),
|
| 465 |
)
|
| 466 |
+
st.session_state[_LAST_PERSONA_ID_KEY] = selected_persona.id
|
| 467 |
+
st.session_state[_LAST_PROMPT_MODE_KEY] = prompt_mode
|
| 468 |
|
| 469 |
active_system_prompt = resolve_system_prompt(
|
| 470 |
persona=selected_persona,
|
|
|
|
| 480 |
chat_log = st.container()
|
| 481 |
|
| 482 |
with chat_log:
|
| 483 |
+
active_system_prompt = render_system_prompt(
|
| 484 |
prompt_key,
|
| 485 |
prompt_mode,
|
| 486 |
active_system_prompt,
|
| 487 |
)
|
| 488 |
|
| 489 |
+
render_chat_window(
|
| 490 |
chat_log=chat_log,
|
| 491 |
messages=chat_state["messages"],
|
| 492 |
chat_state=chat_state,
|
|
|
|
| 504 |
key=export_key,
|
| 505 |
help="Export chat",
|
| 506 |
):
|
| 507 |
+
save_chat_export(
|
| 508 |
model_name=model_name,
|
| 509 |
dataset_source=dataset_source,
|
| 510 |
persona_id=selected_persona.id,
|
|
|
|
| 512 |
prompt_mode=prompt_mode,
|
| 513 |
system_prompt=active_system_prompt,
|
| 514 |
messages=chat_state["messages"],
|
| 515 |
+
generation=generation_dict(gen_kwargs, advanced_generation),
|
| 516 |
)
|
| 517 |
st.toast("Exported", icon=":material/check:")
|
| 518 |
with rst_col:
|
|
|
|
| 537 |
if not st.session_state.pop(pending_key, False):
|
| 538 |
return
|
| 539 |
|
| 540 |
+
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
| 541 |
|
| 542 |
with st.spinner("Generating reply..."):
|
| 543 |
model = cached_model(model_name=model_name, remote=remote)
|
|
|
|
| 558 |
|
| 559 |
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 560 |
chat_state["past_key_values"] = reply.past_key_values if not remote else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
st.rerun()
|
tabs/compare_chat.py
CHANGED
|
@@ -4,18 +4,18 @@ from persona_data.synth_persona import PersonaData
|
|
| 4 |
|
| 5 |
from state import default_chat_state, reset_chat_context_state
|
| 6 |
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
|
|
|
| 7 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 8 |
from utils.helpers import persona_label, widget_key
|
| 9 |
from utils.runtime import cached_model
|
| 10 |
|
| 11 |
from .chat import (
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
_save_chat_export_message,
|
| 19 |
)
|
| 20 |
|
| 21 |
|
|
@@ -47,7 +47,7 @@ def _generate_panel_reply(
|
|
| 47 |
) -> ChatReply:
|
| 48 |
return generate_chat_reply(
|
| 49 |
model=model,
|
| 50 |
-
messages=
|
| 51 |
remote=remote,
|
| 52 |
past_key_values=panel_state["past_key_values"],
|
| 53 |
**gen_kwargs,
|
|
@@ -90,17 +90,28 @@ def render_compare_mode(
|
|
| 90 |
def render_panel(side: str) -> tuple[dict, object, str | None, str, PersonaData]:
|
| 91 |
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 92 |
state = _panel_state(panel_key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 94 |
edit_key = widget_key(panel_key, "edit_idx")
|
| 95 |
pending_regen_key = widget_key(panel_key, "pending_regen")
|
| 96 |
|
| 97 |
-
selected_persona, prompt_mode, changed =
|
| 98 |
personas,
|
| 99 |
state["persona_id"],
|
| 100 |
state["prompt_mode"],
|
| 101 |
widget_key(panel_key, "persona"),
|
| 102 |
widget_key(panel_key, "prompt_mode"),
|
| 103 |
)
|
|
|
|
|
|
|
|
|
|
| 104 |
if changed:
|
| 105 |
reset_chat_context_state(
|
| 106 |
state,
|
|
@@ -117,7 +128,7 @@ def render_compare_mode(
|
|
| 117 |
|
| 118 |
chat_log = st.container()
|
| 119 |
with chat_log:
|
| 120 |
-
active_system_prompt =
|
| 121 |
prompt_key,
|
| 122 |
prompt_mode,
|
| 123 |
active_system_prompt,
|
|
@@ -220,10 +231,10 @@ def render_compare_mode(
|
|
| 220 |
):
|
| 221 |
msg.pop("_needs_contrast", None)
|
| 222 |
continue
|
| 223 |
-
context_a =
|
| 224 |
left_prompt, left_state["messages"][:msg_idx]
|
| 225 |
)
|
| 226 |
-
context_b =
|
| 227 |
right_prompt, right_state["messages"][:msg_idx]
|
| 228 |
)
|
| 229 |
try:
|
|
@@ -256,7 +267,7 @@ def render_compare_mode(
|
|
| 256 |
panel_edit_key,
|
| 257 |
_,
|
| 258 |
) in panels:
|
| 259 |
-
|
| 260 |
chat_log=panel_log,
|
| 261 |
messages=panel_state["messages"],
|
| 262 |
chat_state=panel_state,
|
|
@@ -267,6 +278,9 @@ def render_compare_mode(
|
|
| 267 |
)
|
| 268 |
|
| 269 |
footer = st.container()
|
|
|
|
|
|
|
|
|
|
| 270 |
with footer:
|
| 271 |
exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
|
| 272 |
with exp_col:
|
|
@@ -280,7 +294,7 @@ def render_compare_mode(
|
|
| 280 |
("left", left_state, left_prompt, left_persona),
|
| 281 |
("right", right_state, right_prompt, right_persona),
|
| 282 |
):
|
| 283 |
-
|
| 284 |
model_name=model_name,
|
| 285 |
dataset_source=dataset_source,
|
| 286 |
persona_id=panel_persona.id,
|
|
@@ -288,15 +302,21 @@ def render_compare_mode(
|
|
| 288 |
prompt_mode=panel_state["prompt_mode"],
|
| 289 |
system_prompt=panel_prompt,
|
| 290 |
messages=panel_state["messages"],
|
| 291 |
-
generation=
|
| 292 |
panel_label=side,
|
| 293 |
)
|
| 294 |
st.toast("Exported", icon=":material/check:")
|
| 295 |
with rst_col:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
with st.popover(
|
| 297 |
"",
|
| 298 |
icon=":material/delete_sweep:",
|
| 299 |
help="Reset chat",
|
|
|
|
| 300 |
):
|
| 301 |
if st.button(
|
| 302 |
"Reset left",
|
|
@@ -310,6 +330,7 @@ def render_compare_mode(
|
|
| 310 |
left_prompt_key,
|
| 311 |
left_pending_key,
|
| 312 |
)
|
|
|
|
| 313 |
st.rerun()
|
| 314 |
if st.button(
|
| 315 |
"Reset right",
|
|
@@ -323,6 +344,7 @@ def render_compare_mode(
|
|
| 323 |
right_prompt_key,
|
| 324 |
right_pending_key,
|
| 325 |
)
|
|
|
|
| 326 |
st.rerun()
|
| 327 |
if st.button(
|
| 328 |
"Reset both",
|
|
@@ -345,6 +367,7 @@ def render_compare_mode(
|
|
| 345 |
right_prompt_key,
|
| 346 |
right_pending_key,
|
| 347 |
)
|
|
|
|
| 348 |
st.rerun()
|
| 349 |
|
| 350 |
user_prompt = st.chat_input(
|
|
@@ -360,11 +383,11 @@ def render_compare_mode(
|
|
| 360 |
for panel_state, panel_log, _panel_prompt, _p_pending, _panel_edit_key, _ in panels:
|
| 361 |
panel_state["messages"].append({"role": "user", "content": user_prompt})
|
| 362 |
with panel_log:
|
| 363 |
-
|
| 364 |
|
| 365 |
# Snapshot contexts before the new assistant turn is appended (needed for contrast).
|
| 366 |
pre_gen_contexts = [
|
| 367 |
-
|
| 368 |
for panel_state, _panel_log, panel_prompt, _p_pending, _panel_edit_key, _ in panels
|
| 369 |
]
|
| 370 |
|
|
|
|
| 4 |
|
| 5 |
from state import default_chat_state, reset_chat_context_state
|
| 6 |
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
| 7 |
+
from utils.chat_export import save_chat_export
|
| 8 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 9 |
from utils.helpers import persona_label, widget_key
|
| 10 |
from utils.runtime import cached_model
|
| 11 |
|
| 12 |
from .chat import (
|
| 13 |
+
build_chat_messages,
|
| 14 |
+
generation_dict,
|
| 15 |
+
render_chat_message,
|
| 16 |
+
render_chat_window,
|
| 17 |
+
render_persona_prompt_controls,
|
| 18 |
+
render_system_prompt,
|
|
|
|
| 19 |
)
|
| 20 |
|
| 21 |
|
|
|
|
| 47 |
) -> ChatReply:
|
| 48 |
return generate_chat_reply(
|
| 49 |
model=model,
|
| 50 |
+
messages=build_chat_messages(panel_prompt, panel_state["messages"]),
|
| 51 |
remote=remote,
|
| 52 |
past_key_values=panel_state["past_key_values"],
|
| 53 |
**gen_kwargs,
|
|
|
|
| 90 |
def render_panel(side: str) -> tuple[dict, object, str | None, str, PersonaData]:
|
| 91 |
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 92 |
state = _panel_state(panel_key)
|
| 93 |
+
|
| 94 |
+
# Carry over persona / prompt selections across model or remote switches.
|
| 95 |
+
persist_persona_key = f"chat:last_cmp_{side}_persona"
|
| 96 |
+
persist_prompt_key = f"chat:last_cmp_{side}_prompt"
|
| 97 |
+
if state["persona_id"] is None:
|
| 98 |
+
state["persona_id"] = st.session_state.get(persist_persona_key)
|
| 99 |
+
state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
|
| 100 |
+
|
| 101 |
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 102 |
edit_key = widget_key(panel_key, "edit_idx")
|
| 103 |
pending_regen_key = widget_key(panel_key, "pending_regen")
|
| 104 |
|
| 105 |
+
selected_persona, prompt_mode, changed = render_persona_prompt_controls(
|
| 106 |
personas,
|
| 107 |
state["persona_id"],
|
| 108 |
state["prompt_mode"],
|
| 109 |
widget_key(panel_key, "persona"),
|
| 110 |
widget_key(panel_key, "prompt_mode"),
|
| 111 |
)
|
| 112 |
+
st.session_state[persist_persona_key] = selected_persona.id
|
| 113 |
+
st.session_state[persist_prompt_key] = prompt_mode
|
| 114 |
+
|
| 115 |
if changed:
|
| 116 |
reset_chat_context_state(
|
| 117 |
state,
|
|
|
|
| 128 |
|
| 129 |
chat_log = st.container()
|
| 130 |
with chat_log:
|
| 131 |
+
active_system_prompt = render_system_prompt(
|
| 132 |
prompt_key,
|
| 133 |
prompt_mode,
|
| 134 |
active_system_prompt,
|
|
|
|
| 231 |
):
|
| 232 |
msg.pop("_needs_contrast", None)
|
| 233 |
continue
|
| 234 |
+
context_a = build_chat_messages(
|
| 235 |
left_prompt, left_state["messages"][:msg_idx]
|
| 236 |
)
|
| 237 |
+
context_b = build_chat_messages(
|
| 238 |
right_prompt, right_state["messages"][:msg_idx]
|
| 239 |
)
|
| 240 |
try:
|
|
|
|
| 267 |
panel_edit_key,
|
| 268 |
_,
|
| 269 |
) in panels:
|
| 270 |
+
render_chat_window(
|
| 271 |
chat_log=panel_log,
|
| 272 |
messages=panel_state["messages"],
|
| 273 |
chat_state=panel_state,
|
|
|
|
| 278 |
)
|
| 279 |
|
| 280 |
footer = st.container()
|
| 281 |
+
reset_menu_nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
|
| 282 |
+
if reset_menu_nonce_key not in st.session_state:
|
| 283 |
+
st.session_state[reset_menu_nonce_key] = 0
|
| 284 |
with footer:
|
| 285 |
exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
|
| 286 |
with exp_col:
|
|
|
|
| 294 |
("left", left_state, left_prompt, left_persona),
|
| 295 |
("right", right_state, right_prompt, right_persona),
|
| 296 |
):
|
| 297 |
+
save_chat_export(
|
| 298 |
model_name=model_name,
|
| 299 |
dataset_source=dataset_source,
|
| 300 |
persona_id=panel_persona.id,
|
|
|
|
| 302 |
prompt_mode=panel_state["prompt_mode"],
|
| 303 |
system_prompt=panel_prompt,
|
| 304 |
messages=panel_state["messages"],
|
| 305 |
+
generation=generation_dict(gen_kwargs, advanced_generation),
|
| 306 |
panel_label=side,
|
| 307 |
)
|
| 308 |
st.toast("Exported", icon=":material/check:")
|
| 309 |
with rst_col:
|
| 310 |
+
popover_key = widget_key(
|
| 311 |
+
context_key,
|
| 312 |
+
"cmp_reset_menu",
|
| 313 |
+
str(st.session_state[reset_menu_nonce_key]),
|
| 314 |
+
)
|
| 315 |
with st.popover(
|
| 316 |
"",
|
| 317 |
icon=":material/delete_sweep:",
|
| 318 |
help="Reset chat",
|
| 319 |
+
key=popover_key,
|
| 320 |
):
|
| 321 |
if st.button(
|
| 322 |
"Reset left",
|
|
|
|
| 330 |
left_prompt_key,
|
| 331 |
left_pending_key,
|
| 332 |
)
|
| 333 |
+
st.session_state[reset_menu_nonce_key] += 1
|
| 334 |
st.rerun()
|
| 335 |
if st.button(
|
| 336 |
"Reset right",
|
|
|
|
| 344 |
right_prompt_key,
|
| 345 |
right_pending_key,
|
| 346 |
)
|
| 347 |
+
st.session_state[reset_menu_nonce_key] += 1
|
| 348 |
st.rerun()
|
| 349 |
if st.button(
|
| 350 |
"Reset both",
|
|
|
|
| 367 |
right_prompt_key,
|
| 368 |
right_pending_key,
|
| 369 |
)
|
| 370 |
+
st.session_state[reset_menu_nonce_key] += 1
|
| 371 |
st.rerun()
|
| 372 |
|
| 373 |
user_prompt = st.chat_input(
|
|
|
|
| 383 |
for panel_state, panel_log, _panel_prompt, _p_pending, _panel_edit_key, _ in panels:
|
| 384 |
panel_state["messages"].append({"role": "user", "content": user_prompt})
|
| 385 |
with panel_log:
|
| 386 |
+
render_chat_message({"role": "user", "content": user_prompt})
|
| 387 |
|
| 388 |
# Snapshot contexts before the new assistant turn is appended (needed for contrast).
|
| 389 |
pre_gen_contexts = [
|
| 390 |
+
build_chat_messages(panel_prompt, panel_state["messages"])
|
| 391 |
for panel_state, _panel_log, panel_prompt, _p_pending, _panel_edit_key, _ in panels
|
| 392 |
]
|
| 393 |
|
utils/chat.py
CHANGED
|
@@ -73,7 +73,7 @@ def _format_plain_messages(
|
|
| 73 |
return "\n\n".join(lines)
|
| 74 |
|
| 75 |
|
| 76 |
-
def
|
| 77 |
messages: list[dict[str, str]], tokenizer: object
|
| 78 |
) -> tuple[str, int]:
|
| 79 |
"""Render messages into a single prompt string and count prompt tokens.
|
|
@@ -169,7 +169,7 @@ def generate_chat_reply(
|
|
| 169 |
"""
|
| 170 |
|
| 171 |
tokenizer = model.tokenizer
|
| 172 |
-
prompt, prompt_token_count =
|
| 173 |
|
| 174 |
generation_kwargs: dict[str, object] = {
|
| 175 |
"max_new_tokens": max_new_tokens,
|
|
|
|
| 73 |
return "\n\n".join(lines)
|
| 74 |
|
| 75 |
|
| 76 |
+
def format_generation_prompt(
|
| 77 |
messages: list[dict[str, str]], tokenizer: object
|
| 78 |
) -> tuple[str, int]:
|
| 79 |
"""Render messages into a single prompt string and count prompt tokens.
|
|
|
|
| 169 |
"""
|
| 170 |
|
| 171 |
tokenizer = model.tokenizer
|
| 172 |
+
prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
|
| 173 |
|
| 174 |
generation_kwargs: dict[str, object] = {
|
| 175 |
"max_new_tokens": max_new_tokens,
|
utils/contrast.py
CHANGED
|
@@ -1,7 +1,3 @@
|
|
| 1 |
-
# WARNING: This is mostly vibecoded and need reviews
|
| 2 |
-
# - Check that the model is runned once with normally for gneration and things are beeing traced perphaps at the last step of generation with iter.last or somrething liek that from the docs
|
| 3 |
-
# - Then the model is runned again with the entire context of the conversation from the other context on the rifht ? or on the left dependeing on which one we are doing at the moment. And this will then compute the prob diff and show them.
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
Contrastive token-level log-probability comparison for compare mode.
|
| 7 |
|
|
@@ -15,13 +11,16 @@ Negative (blue) → token is more characteristic of persona B.
|
|
| 15 |
Near-zero (gray) → both personas would emit this token with similar likelihood.
|
| 16 |
"""
|
| 17 |
|
|
|
|
| 18 |
from dataclasses import dataclass
|
| 19 |
from html import escape
|
| 20 |
|
| 21 |
import torch
|
| 22 |
from nnterp import StandardizedTransformer
|
| 23 |
|
| 24 |
-
from utils.chat import
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
@dataclass
|
|
@@ -48,6 +47,7 @@ def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
|
|
| 48 |
|
| 49 |
|
| 50 |
def _decode_ids(tokenizer: object, ids: list[int]) -> str:
|
|
|
|
| 51 |
try:
|
| 52 |
return tokenizer.decode(
|
| 53 |
ids,
|
|
@@ -79,15 +79,18 @@ def _prepare_trace_text(
|
|
| 79 |
response_ids: torch.Tensor,
|
| 80 |
) -> tuple[str, int, int]:
|
| 81 |
"""Build the trace text and return ``(full_text, n_ctx, n_resp)``."""
|
| 82 |
-
context_prompt, _ =
|
| 83 |
context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
|
| 84 |
response_text = _decode_ids(tokenizer, response_ids.tolist())
|
| 85 |
full_text = context_prompt + response_text
|
| 86 |
full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
|
| 87 |
expected_ids = torch.cat([context_ids, response_ids.cpu()])
|
| 88 |
if full_ids.tolist() != expected_ids.tolist():
|
| 89 |
-
|
| 90 |
-
"contrast trace text did not round-trip to the expected token ids"
|
|
|
|
|
|
|
|
|
|
| 91 |
)
|
| 92 |
n_ctx = len(context_ids)
|
| 93 |
n_resp = len(response_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Contrastive token-level log-probability comparison for compare mode.
|
| 3 |
|
|
|
|
| 11 |
Near-zero (gray) → both personas would emit this token with similar likelihood.
|
| 12 |
"""
|
| 13 |
|
| 14 |
+
import logging
|
| 15 |
from dataclasses import dataclass
|
| 16 |
from html import escape
|
| 17 |
|
| 18 |
import torch
|
| 19 |
from nnterp import StandardizedTransformer
|
| 20 |
|
| 21 |
+
from utils.chat import format_generation_prompt
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
| 26 |
@dataclass
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
def _decode_ids(tokenizer: object, ids: list[int]) -> str:
|
| 50 |
+
"""Decode token IDs, falling back when clean_up_tokenization_spaces is unsupported."""
|
| 51 |
try:
|
| 52 |
return tokenizer.decode(
|
| 53 |
ids,
|
|
|
|
| 79 |
response_ids: torch.Tensor,
|
| 80 |
) -> tuple[str, int, int]:
|
| 81 |
"""Build the trace text and return ``(full_text, n_ctx, n_resp)``."""
|
| 82 |
+
context_prompt, _ = format_generation_prompt(context_messages, tokenizer)
|
| 83 |
context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
|
| 84 |
response_text = _decode_ids(tokenizer, response_ids.tolist())
|
| 85 |
full_text = context_prompt + response_text
|
| 86 |
full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
|
| 87 |
expected_ids = torch.cat([context_ids, response_ids.cpu()])
|
| 88 |
if full_ids.tolist() != expected_ids.tolist():
|
| 89 |
+
logger.warning(
|
| 90 |
+
"contrast trace text did not round-trip to the expected token ids "
|
| 91 |
+
"(expected %d tokens, got %d); contrast scores may be slightly misaligned",
|
| 92 |
+
len(expected_ids),
|
| 93 |
+
len(full_ids),
|
| 94 |
)
|
| 95 |
n_ctx = len(context_ids)
|
| 96 |
n_resp = len(response_ids)
|