Jac-Zac commited on
Commit
93d5dc5
·
1 Parent(s): a9950fb

Small cleanups

Browse files

Small reset bottom cleanup

Files changed (5) hide show
  1. state.py +14 -9
  2. tabs/chat.py +55 -67
  3. tabs/compare_chat.py +40 -17
  4. utils/chat.py +2 -2
  5. utils/contrast.py +11 -8
state.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
 
3
  _CHAT_STATE_PREFIX = "chat_state::"
 
4
 
5
 
6
  def chat_session_key(model_name: str, dataset_source: str) -> str:
@@ -37,13 +38,9 @@ def reset_chat_context_state(
37
  def _evict_inactive_kv_caches(active_key: str) -> None:
38
  """Drop past_key_values from every chat context except the active one."""
39
 
40
- for key in st.session_state:
41
- if (
42
- isinstance(key, str)
43
- and key.startswith(_CHAT_STATE_PREFIX)
44
- and key != active_key
45
- ):
46
- state = st.session_state[key]
47
  if isinstance(state, dict) and state.get("past_key_values") is not None:
48
  state["past_key_values"] = None
49
 
@@ -54,13 +51,21 @@ def get_chat_state(
54
  """Return the mutable chat state for the active context."""
55
 
56
  key = chat_session_key(model_name, dataset_source)
 
 
 
 
 
 
57
  state = st.session_state.get(key)
58
  if state is None:
59
  state = default_chat_state()
60
  st.session_state[key] = state
61
  else:
62
- for default_key, default_value in default_chat_state().items():
63
- state.setdefault(default_key, default_value)
 
 
64
  _evict_inactive_kv_caches(key)
65
  if remote and state.get("past_key_values") is not None:
66
  state["past_key_values"] = None
 
1
  import streamlit as st
2
 
3
  _CHAT_STATE_PREFIX = "chat_state::"
4
+ _CHAT_KEYS_REGISTRY = "chat_state::_registered_keys"
5
 
6
 
7
  def chat_session_key(model_name: str, dataset_source: str) -> str:
 
38
  def _evict_inactive_kv_caches(active_key: str) -> None:
39
  """Drop past_key_values from every chat context except the active one."""
40
 
41
+ for key in st.session_state.get(_CHAT_KEYS_REGISTRY, ()):
42
+ if key != active_key:
43
+ state = st.session_state.get(key)
 
 
 
 
44
  if isinstance(state, dict) and state.get("past_key_values") is not None:
45
  state["past_key_values"] = None
46
 
 
51
  """Return the mutable chat state for the active context."""
52
 
53
  key = chat_session_key(model_name, dataset_source)
54
+ registry = st.session_state.get(_CHAT_KEYS_REGISTRY)
55
+ if registry is None:
56
+ registry = set()
57
+ st.session_state[_CHAT_KEYS_REGISTRY] = registry
58
+ registry.add(key)
59
+
60
  state = st.session_state.get(key)
61
  if state is None:
62
  state = default_chat_state()
63
  st.session_state[key] = state
64
  else:
65
+ state.setdefault("messages", [])
66
+ state.setdefault("persona_id", None)
67
+ state.setdefault("prompt_mode", "templated")
68
+ state.setdefault("past_key_values", None)
69
  _evict_inactive_kv_caches(key)
70
  if remote and state.get("past_key_values") is not None:
71
  state["past_key_values"] = None
tabs/chat.py CHANGED
@@ -17,10 +17,19 @@ from utils.helpers import (
17
  )
18
  from utils.runtime import cached_model
19
 
20
-
21
- def _render_collapsible_markdown(content: str) -> None:
22
- st.markdown(content)
23
-
 
 
 
 
 
 
 
 
 
24
 
25
  # ── Dialogs ───────────────────────────────────────────────────────────────────
26
 
@@ -91,7 +100,7 @@ def _open_system_prompt_dialog(*, prompt_key: str, current_value: str) -> None:
91
  # ── Message renderers ─────────────────────────────────────────────────────────
92
 
93
 
94
- def _render_chat_message(
95
  message: dict[str, str],
96
  show_contrast: bool = False,
97
  ) -> None:
@@ -103,7 +112,7 @@ def _render_chat_message(
103
  if tc is not None:
104
  st.html(render_contrast_html(tc))
105
  else:
106
- _render_collapsible_markdown(message["content"])
107
 
108
 
109
  def _render_editable_message(
@@ -129,7 +138,7 @@ def _render_editable_message(
129
  if tc is not None:
130
  st.html(render_contrast_html(tc))
131
  else:
132
- _render_collapsible_markdown(message["content"])
133
  with edit_col:
134
  if st.button(
135
  "", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
@@ -142,7 +151,7 @@ def _render_editable_message(
142
  )
143
 
144
 
145
- def _render_system_prompt(
146
  prompt_key: str,
147
  prompt_mode: str,
148
  active_system_prompt: str | None,
@@ -159,7 +168,7 @@ def _render_system_prompt(
159
  return st.session_state.get(prompt_key) or None
160
 
161
 
162
- def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
163
  return {
164
  "max_new_tokens": int(gen_kwargs["max_new_tokens"]),
165
  "advanced_generation": bool(advanced_generation),
@@ -172,7 +181,7 @@ def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, o
172
  }
173
 
174
 
175
- def _render_persona_prompt_controls(
176
  personas: list[PersonaData],
177
  current_persona_id: str | None,
178
  current_prompt_mode: str,
@@ -209,7 +218,7 @@ def _render_persona_prompt_controls(
209
  return selected_persona, prompt_mode, changed
210
 
211
 
212
- def _render_chat_window(
213
  *,
214
  chat_log: Any,
215
  messages: list[dict[str, str]],
@@ -233,10 +242,10 @@ def _render_chat_window(
233
  column_ratio=edit_column_ratio,
234
  )
235
  else:
236
- _render_chat_message(message, show_contrast=show_contrast)
237
 
238
 
239
- def _build_chat_messages(
240
  system_prompt: str | None,
241
  messages: list[dict[str, str]],
242
  ) -> list[dict[str, str]]:
@@ -245,31 +254,6 @@ def _build_chat_messages(
245
  ) + messages
246
 
247
 
248
- def _save_chat_export_message(
249
- *,
250
- model_name: str,
251
- dataset_source: str,
252
- persona_id: str,
253
- persona_name: str | None,
254
- prompt_mode: str,
255
- system_prompt: str | None,
256
- messages: list[dict[str, str]],
257
- generation: dict[str, object],
258
- panel_label: str | None = None,
259
- ) -> None:
260
- save_chat_export(
261
- model_name=model_name,
262
- dataset_source=dataset_source,
263
- persona_id=persona_id,
264
- persona_name=persona_name,
265
- panel_label=panel_label,
266
- prompt_mode=prompt_mode,
267
- system_prompt=system_prompt,
268
- messages=messages,
269
- generation=generation,
270
- )
271
-
272
-
273
  # ── Main tab entry point ───────────────────────────────────────────────────────
274
 
275
 
@@ -286,7 +270,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
286
  "Max new tokens",
287
  min_value=16,
288
  max_value=512,
289
- value=256,
290
  step=16,
291
  key=widget_key(context_key, "max_new_tokens"),
292
  )
@@ -295,7 +279,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
295
  "Repetition penalty",
296
  min_value=0.5,
297
  max_value=2.0,
298
- value=1.0,
299
  step=0.05,
300
  key=widget_key(context_key, "repetition_penalty"),
301
  )
@@ -313,7 +297,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
313
  "Temperature",
314
  min_value=0.01,
315
  max_value=2.0,
316
- value=1.0,
317
  step=0.01,
318
  disabled=sampling_disabled,
319
  key=widget_key(context_key, "temperature"),
@@ -323,7 +307,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
323
  "Top-p",
324
  min_value=0.01,
325
  max_value=1.0,
326
- value=1.0,
327
  step=0.01,
328
  disabled=sampling_disabled,
329
  key=widget_key(context_key, "top_p"),
@@ -333,7 +317,7 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
333
  "Top-k (0 = off)",
334
  min_value=0,
335
  max_value=100,
336
- value=50,
337
  step=1,
338
  disabled=sampling_disabled,
339
  key=widget_key(context_key, "top_k"),
@@ -365,12 +349,12 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
365
  st.caption("Seed is local-only and disabled for remote runs.")
366
 
367
  advanced_generation = (
368
- max_new_tokens != 256
369
  or use_sampling
370
- or temperature != 1.0
371
- or top_p != 1.0
372
- or top_k != 50
373
- or repetition_penalty != 1.0
374
  or seed is not None
375
  )
376
 
@@ -395,6 +379,14 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
395
 
396
  context_key = chat_session_key(model_name, dataset_source)
397
  chat_state = get_chat_state(model_name, remote, dataset_source)
 
 
 
 
 
 
 
 
398
  try:
399
  dataset, dataset_status = load_dataset(
400
  dataset_source,
@@ -416,12 +408,17 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
416
  gen_kwargs, advanced_generation = _render_generation_settings(context_key, remote)
417
 
418
  # ── Mode toggle ───────────────────────────────────────────────────────────
 
 
 
 
 
419
  compare_mode = st.toggle(
420
  "Compare mode",
421
- value=False,
422
- key=widget_key(context_key, "compare_mode"),
423
  help="Side-by-side: send one message to two independent persona/prompt configurations.",
424
  )
 
425
 
426
  if compare_mode:
427
  from tabs.compare_chat import render_compare_mode
@@ -458,7 +455,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
458
  )
459
  st.session_state.pop(edit_key, None)
460
 
461
- selected_persona, prompt_mode, changed_context = _render_persona_prompt_controls(
462
  personas,
463
  chat_state["persona_id"],
464
  chat_state["prompt_mode"],
@@ -466,6 +463,8 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
466
  prompt_mode_select_key,
467
  column_widths=(2, 1),
468
  )
 
 
469
 
470
  active_system_prompt = resolve_system_prompt(
471
  persona=selected_persona,
@@ -481,13 +480,13 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
481
  chat_log = st.container()
482
 
483
  with chat_log:
484
- active_system_prompt = _render_system_prompt(
485
  prompt_key,
486
  prompt_mode,
487
  active_system_prompt,
488
  )
489
 
490
- _render_chat_window(
491
  chat_log=chat_log,
492
  messages=chat_state["messages"],
493
  chat_state=chat_state,
@@ -505,7 +504,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
505
  key=export_key,
506
  help="Export chat",
507
  ):
508
- _save_chat_export_message(
509
  model_name=model_name,
510
  dataset_source=dataset_source,
511
  persona_id=selected_persona.id,
@@ -513,7 +512,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
513
  prompt_mode=prompt_mode,
514
  system_prompt=active_system_prompt,
515
  messages=chat_state["messages"],
516
- generation=_generation_dict(gen_kwargs, advanced_generation),
517
  )
518
  st.toast("Exported", icon=":material/check:")
519
  with rst_col:
@@ -538,7 +537,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
538
  if not st.session_state.pop(pending_key, False):
539
  return
540
 
541
- messages = _build_chat_messages(active_system_prompt, chat_state["messages"])
542
 
543
  with st.spinner("Generating reply..."):
544
  model = cached_model(model_name=model_name, remote=remote)
@@ -559,15 +558,4 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
559
 
560
  chat_state["messages"].append({"role": "assistant", "content": reply.text})
561
  chat_state["past_key_values"] = reply.past_key_values if not remote else None
562
-
563
- save_chat_export(
564
- model_name=model_name,
565
- dataset_source=dataset_source,
566
- persona_id=selected_persona.id,
567
- persona_name=getattr(selected_persona, "name", None),
568
- prompt_mode=prompt_mode,
569
- system_prompt=active_system_prompt,
570
- messages=chat_state["messages"],
571
- generation=_generation_dict(gen_kwargs, advanced_generation),
572
- )
573
  st.rerun()
 
17
  )
18
  from utils.runtime import cached_model
19
 
20
+ # ── Persistence keys for surviving model / remote switches ────────────────────
21
+ _LAST_PERSONA_ID_KEY = "chat:last_persona_id"
22
+ _LAST_PROMPT_MODE_KEY = "chat:last_prompt_mode"
23
+ _LAST_COMPARE_MODE_KEY = "chat:last_compare_mode"
24
+
25
+ # ── Generation defaults (single source of truth) ─────────────────────────────
26
+ _GEN_DEFAULTS = {
27
+ "max_new_tokens": 256,
28
+ "temperature": 1.0,
29
+ "top_p": 1.0,
30
+ "top_k": 50,
31
+ "repetition_penalty": 1.0,
32
+ }
33
 
34
  # ── Dialogs ───────────────────────────────────────────────────────────────────
35
 
 
100
  # ── Message renderers ─────────────────────────────────────────────────────────
101
 
102
 
103
+ def render_chat_message(
104
  message: dict[str, str],
105
  show_contrast: bool = False,
106
  ) -> None:
 
112
  if tc is not None:
113
  st.html(render_contrast_html(tc))
114
  else:
115
+ st.markdown(message["content"])
116
 
117
 
118
  def _render_editable_message(
 
138
  if tc is not None:
139
  st.html(render_contrast_html(tc))
140
  else:
141
+ st.markdown(message["content"])
142
  with edit_col:
143
  if st.button(
144
  "", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
 
151
  )
152
 
153
 
154
+ def render_system_prompt(
155
  prompt_key: str,
156
  prompt_mode: str,
157
  active_system_prompt: str | None,
 
168
  return st.session_state.get(prompt_key) or None
169
 
170
 
171
+ def generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
172
  return {
173
  "max_new_tokens": int(gen_kwargs["max_new_tokens"]),
174
  "advanced_generation": bool(advanced_generation),
 
181
  }
182
 
183
 
184
+ def render_persona_prompt_controls(
185
  personas: list[PersonaData],
186
  current_persona_id: str | None,
187
  current_prompt_mode: str,
 
218
  return selected_persona, prompt_mode, changed
219
 
220
 
221
+ def render_chat_window(
222
  *,
223
  chat_log: Any,
224
  messages: list[dict[str, str]],
 
242
  column_ratio=edit_column_ratio,
243
  )
244
  else:
245
+ render_chat_message(message, show_contrast=show_contrast)
246
 
247
 
248
+ def build_chat_messages(
249
  system_prompt: str | None,
250
  messages: list[dict[str, str]],
251
  ) -> list[dict[str, str]]:
 
254
  ) + messages
255
 
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  # ── Main tab entry point ───────────────────────────────────────────────────────
258
 
259
 
 
270
  "Max new tokens",
271
  min_value=16,
272
  max_value=512,
273
+ value=_GEN_DEFAULTS["max_new_tokens"],
274
  step=16,
275
  key=widget_key(context_key, "max_new_tokens"),
276
  )
 
279
  "Repetition penalty",
280
  min_value=0.5,
281
  max_value=2.0,
282
+ value=_GEN_DEFAULTS["repetition_penalty"],
283
  step=0.05,
284
  key=widget_key(context_key, "repetition_penalty"),
285
  )
 
297
  "Temperature",
298
  min_value=0.01,
299
  max_value=2.0,
300
+ value=_GEN_DEFAULTS["temperature"],
301
  step=0.01,
302
  disabled=sampling_disabled,
303
  key=widget_key(context_key, "temperature"),
 
307
  "Top-p",
308
  min_value=0.01,
309
  max_value=1.0,
310
+ value=_GEN_DEFAULTS["top_p"],
311
  step=0.01,
312
  disabled=sampling_disabled,
313
  key=widget_key(context_key, "top_p"),
 
317
  "Top-k (0 = off)",
318
  min_value=0,
319
  max_value=100,
320
+ value=_GEN_DEFAULTS["top_k"],
321
  step=1,
322
  disabled=sampling_disabled,
323
  key=widget_key(context_key, "top_k"),
 
349
  st.caption("Seed is local-only and disabled for remote runs.")
350
 
351
  advanced_generation = (
352
+ max_new_tokens != _GEN_DEFAULTS["max_new_tokens"]
353
  or use_sampling
354
+ or temperature != _GEN_DEFAULTS["temperature"]
355
+ or top_p != _GEN_DEFAULTS["top_p"]
356
+ or top_k != _GEN_DEFAULTS["top_k"]
357
+ or repetition_penalty != _GEN_DEFAULTS["repetition_penalty"]
358
  or seed is not None
359
  )
360
 
 
379
 
380
  context_key = chat_session_key(model_name, dataset_source)
381
  chat_state = get_chat_state(model_name, remote, dataset_source)
382
+
383
+ # Carry over persona / prompt selections across model or remote switches.
384
+ if chat_state["persona_id"] is None:
385
+ chat_state["persona_id"] = st.session_state.get(_LAST_PERSONA_ID_KEY)
386
+ chat_state["prompt_mode"] = st.session_state.get(
387
+ _LAST_PROMPT_MODE_KEY, "templated"
388
+ )
389
+
390
  try:
391
  dataset, dataset_status = load_dataset(
392
  dataset_source,
 
408
  gen_kwargs, advanced_generation = _render_generation_settings(context_key, remote)
409
 
410
  # ── Mode toggle ───────────────────────────────────────────────────────────
411
+ compare_key = widget_key(context_key, "compare_mode")
412
+ if compare_key not in st.session_state:
413
+ st.session_state[compare_key] = st.session_state.get(
414
+ _LAST_COMPARE_MODE_KEY, False
415
+ )
416
  compare_mode = st.toggle(
417
  "Compare mode",
418
+ key=compare_key,
 
419
  help="Side-by-side: send one message to two independent persona/prompt configurations.",
420
  )
421
+ st.session_state[_LAST_COMPARE_MODE_KEY] = compare_mode
422
 
423
  if compare_mode:
424
  from tabs.compare_chat import render_compare_mode
 
455
  )
456
  st.session_state.pop(edit_key, None)
457
 
458
+ selected_persona, prompt_mode, changed_context = render_persona_prompt_controls(
459
  personas,
460
  chat_state["persona_id"],
461
  chat_state["prompt_mode"],
 
463
  prompt_mode_select_key,
464
  column_widths=(2, 1),
465
  )
466
+ st.session_state[_LAST_PERSONA_ID_KEY] = selected_persona.id
467
+ st.session_state[_LAST_PROMPT_MODE_KEY] = prompt_mode
468
 
469
  active_system_prompt = resolve_system_prompt(
470
  persona=selected_persona,
 
480
  chat_log = st.container()
481
 
482
  with chat_log:
483
+ active_system_prompt = render_system_prompt(
484
  prompt_key,
485
  prompt_mode,
486
  active_system_prompt,
487
  )
488
 
489
+ render_chat_window(
490
  chat_log=chat_log,
491
  messages=chat_state["messages"],
492
  chat_state=chat_state,
 
504
  key=export_key,
505
  help="Export chat",
506
  ):
507
+ save_chat_export(
508
  model_name=model_name,
509
  dataset_source=dataset_source,
510
  persona_id=selected_persona.id,
 
512
  prompt_mode=prompt_mode,
513
  system_prompt=active_system_prompt,
514
  messages=chat_state["messages"],
515
+ generation=generation_dict(gen_kwargs, advanced_generation),
516
  )
517
  st.toast("Exported", icon=":material/check:")
518
  with rst_col:
 
537
  if not st.session_state.pop(pending_key, False):
538
  return
539
 
540
+ messages = build_chat_messages(active_system_prompt, chat_state["messages"])
541
 
542
  with st.spinner("Generating reply..."):
543
  model = cached_model(model_name=model_name, remote=remote)
 
558
 
559
  chat_state["messages"].append({"role": "assistant", "content": reply.text})
560
  chat_state["past_key_values"] = reply.past_key_values if not remote else None
 
 
 
 
 
 
 
 
 
 
 
561
  st.rerun()
tabs/compare_chat.py CHANGED
@@ -4,18 +4,18 @@ from persona_data.synth_persona import PersonaData
4
 
5
  from state import default_chat_state, reset_chat_context_state
6
  from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
 
7
  from utils.contrast import compute_contrast, compute_contrast_pair
8
  from utils.helpers import persona_label, widget_key
9
  from utils.runtime import cached_model
10
 
11
  from .chat import (
12
- _build_chat_messages,
13
- _generation_dict,
14
- _render_chat_message,
15
- _render_chat_window,
16
- _render_persona_prompt_controls,
17
- _render_system_prompt,
18
- _save_chat_export_message,
19
  )
20
 
21
 
@@ -47,7 +47,7 @@ def _generate_panel_reply(
47
  ) -> ChatReply:
48
  return generate_chat_reply(
49
  model=model,
50
- messages=_build_chat_messages(panel_prompt, panel_state["messages"]),
51
  remote=remote,
52
  past_key_values=panel_state["past_key_values"],
53
  **gen_kwargs,
@@ -90,17 +90,28 @@ def render_compare_mode(
90
  def render_panel(side: str) -> tuple[dict, object, str | None, str, PersonaData]:
91
  panel_key = widget_key(context_key, f"cmp_{side}")
92
  state = _panel_state(panel_key)
 
 
 
 
 
 
 
 
93
  prompt_key = widget_key(panel_key, "custom_prompt")
94
  edit_key = widget_key(panel_key, "edit_idx")
95
  pending_regen_key = widget_key(panel_key, "pending_regen")
96
 
97
- selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
98
  personas,
99
  state["persona_id"],
100
  state["prompt_mode"],
101
  widget_key(panel_key, "persona"),
102
  widget_key(panel_key, "prompt_mode"),
103
  )
 
 
 
104
  if changed:
105
  reset_chat_context_state(
106
  state,
@@ -117,7 +128,7 @@ def render_compare_mode(
117
 
118
  chat_log = st.container()
119
  with chat_log:
120
- active_system_prompt = _render_system_prompt(
121
  prompt_key,
122
  prompt_mode,
123
  active_system_prompt,
@@ -220,10 +231,10 @@ def render_compare_mode(
220
  ):
221
  msg.pop("_needs_contrast", None)
222
  continue
223
- context_a = _build_chat_messages(
224
  left_prompt, left_state["messages"][:msg_idx]
225
  )
226
- context_b = _build_chat_messages(
227
  right_prompt, right_state["messages"][:msg_idx]
228
  )
229
  try:
@@ -256,7 +267,7 @@ def render_compare_mode(
256
  panel_edit_key,
257
  _,
258
  ) in panels:
259
- _render_chat_window(
260
  chat_log=panel_log,
261
  messages=panel_state["messages"],
262
  chat_state=panel_state,
@@ -267,6 +278,9 @@ def render_compare_mode(
267
  )
268
 
269
  footer = st.container()
 
 
 
270
  with footer:
271
  exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
272
  with exp_col:
@@ -280,7 +294,7 @@ def render_compare_mode(
280
  ("left", left_state, left_prompt, left_persona),
281
  ("right", right_state, right_prompt, right_persona),
282
  ):
283
- _save_chat_export_message(
284
  model_name=model_name,
285
  dataset_source=dataset_source,
286
  persona_id=panel_persona.id,
@@ -288,15 +302,21 @@ def render_compare_mode(
288
  prompt_mode=panel_state["prompt_mode"],
289
  system_prompt=panel_prompt,
290
  messages=panel_state["messages"],
291
- generation=_generation_dict(gen_kwargs, advanced_generation),
292
  panel_label=side,
293
  )
294
  st.toast("Exported", icon=":material/check:")
295
  with rst_col:
 
 
 
 
 
296
  with st.popover(
297
  "",
298
  icon=":material/delete_sweep:",
299
  help="Reset chat",
 
300
  ):
301
  if st.button(
302
  "Reset left",
@@ -310,6 +330,7 @@ def render_compare_mode(
310
  left_prompt_key,
311
  left_pending_key,
312
  )
 
313
  st.rerun()
314
  if st.button(
315
  "Reset right",
@@ -323,6 +344,7 @@ def render_compare_mode(
323
  right_prompt_key,
324
  right_pending_key,
325
  )
 
326
  st.rerun()
327
  if st.button(
328
  "Reset both",
@@ -345,6 +367,7 @@ def render_compare_mode(
345
  right_prompt_key,
346
  right_pending_key,
347
  )
 
348
  st.rerun()
349
 
350
  user_prompt = st.chat_input(
@@ -360,11 +383,11 @@ def render_compare_mode(
360
  for panel_state, panel_log, _panel_prompt, _p_pending, _panel_edit_key, _ in panels:
361
  panel_state["messages"].append({"role": "user", "content": user_prompt})
362
  with panel_log:
363
- _render_chat_message({"role": "user", "content": user_prompt})
364
 
365
  # Snapshot contexts before the new assistant turn is appended (needed for contrast).
366
  pre_gen_contexts = [
367
- _build_chat_messages(panel_prompt, panel_state["messages"])
368
  for panel_state, _panel_log, panel_prompt, _p_pending, _panel_edit_key, _ in panels
369
  ]
370
 
 
4
 
5
  from state import default_chat_state, reset_chat_context_state
6
  from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
7
+ from utils.chat_export import save_chat_export
8
  from utils.contrast import compute_contrast, compute_contrast_pair
9
  from utils.helpers import persona_label, widget_key
10
  from utils.runtime import cached_model
11
 
12
  from .chat import (
13
+ build_chat_messages,
14
+ generation_dict,
15
+ render_chat_message,
16
+ render_chat_window,
17
+ render_persona_prompt_controls,
18
+ render_system_prompt,
 
19
  )
20
 
21
 
 
47
  ) -> ChatReply:
48
  return generate_chat_reply(
49
  model=model,
50
+ messages=build_chat_messages(panel_prompt, panel_state["messages"]),
51
  remote=remote,
52
  past_key_values=panel_state["past_key_values"],
53
  **gen_kwargs,
 
90
  def render_panel(side: str) -> tuple[dict, object, str | None, str, PersonaData]:
91
  panel_key = widget_key(context_key, f"cmp_{side}")
92
  state = _panel_state(panel_key)
93
+
94
+ # Carry over persona / prompt selections across model or remote switches.
95
+ persist_persona_key = f"chat:last_cmp_{side}_persona"
96
+ persist_prompt_key = f"chat:last_cmp_{side}_prompt"
97
+ if state["persona_id"] is None:
98
+ state["persona_id"] = st.session_state.get(persist_persona_key)
99
+ state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
100
+
101
  prompt_key = widget_key(panel_key, "custom_prompt")
102
  edit_key = widget_key(panel_key, "edit_idx")
103
  pending_regen_key = widget_key(panel_key, "pending_regen")
104
 
105
+ selected_persona, prompt_mode, changed = render_persona_prompt_controls(
106
  personas,
107
  state["persona_id"],
108
  state["prompt_mode"],
109
  widget_key(panel_key, "persona"),
110
  widget_key(panel_key, "prompt_mode"),
111
  )
112
+ st.session_state[persist_persona_key] = selected_persona.id
113
+ st.session_state[persist_prompt_key] = prompt_mode
114
+
115
  if changed:
116
  reset_chat_context_state(
117
  state,
 
128
 
129
  chat_log = st.container()
130
  with chat_log:
131
+ active_system_prompt = render_system_prompt(
132
  prompt_key,
133
  prompt_mode,
134
  active_system_prompt,
 
231
  ):
232
  msg.pop("_needs_contrast", None)
233
  continue
234
+ context_a = build_chat_messages(
235
  left_prompt, left_state["messages"][:msg_idx]
236
  )
237
+ context_b = build_chat_messages(
238
  right_prompt, right_state["messages"][:msg_idx]
239
  )
240
  try:
 
267
  panel_edit_key,
268
  _,
269
  ) in panels:
270
+ render_chat_window(
271
  chat_log=panel_log,
272
  messages=panel_state["messages"],
273
  chat_state=panel_state,
 
278
  )
279
 
280
  footer = st.container()
281
+ reset_menu_nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
282
+ if reset_menu_nonce_key not in st.session_state:
283
+ st.session_state[reset_menu_nonce_key] = 0
284
  with footer:
285
  exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
286
  with exp_col:
 
294
  ("left", left_state, left_prompt, left_persona),
295
  ("right", right_state, right_prompt, right_persona),
296
  ):
297
+ save_chat_export(
298
  model_name=model_name,
299
  dataset_source=dataset_source,
300
  persona_id=panel_persona.id,
 
302
  prompt_mode=panel_state["prompt_mode"],
303
  system_prompt=panel_prompt,
304
  messages=panel_state["messages"],
305
+ generation=generation_dict(gen_kwargs, advanced_generation),
306
  panel_label=side,
307
  )
308
  st.toast("Exported", icon=":material/check:")
309
  with rst_col:
310
+ popover_key = widget_key(
311
+ context_key,
312
+ "cmp_reset_menu",
313
+ str(st.session_state[reset_menu_nonce_key]),
314
+ )
315
  with st.popover(
316
  "",
317
  icon=":material/delete_sweep:",
318
  help="Reset chat",
319
+ key=popover_key,
320
  ):
321
  if st.button(
322
  "Reset left",
 
330
  left_prompt_key,
331
  left_pending_key,
332
  )
333
+ st.session_state[reset_menu_nonce_key] += 1
334
  st.rerun()
335
  if st.button(
336
  "Reset right",
 
344
  right_prompt_key,
345
  right_pending_key,
346
  )
347
+ st.session_state[reset_menu_nonce_key] += 1
348
  st.rerun()
349
  if st.button(
350
  "Reset both",
 
367
  right_prompt_key,
368
  right_pending_key,
369
  )
370
+ st.session_state[reset_menu_nonce_key] += 1
371
  st.rerun()
372
 
373
  user_prompt = st.chat_input(
 
383
  for panel_state, panel_log, _panel_prompt, _p_pending, _panel_edit_key, _ in panels:
384
  panel_state["messages"].append({"role": "user", "content": user_prompt})
385
  with panel_log:
386
+ render_chat_message({"role": "user", "content": user_prompt})
387
 
388
  # Snapshot contexts before the new assistant turn is appended (needed for contrast).
389
  pre_gen_contexts = [
390
+ build_chat_messages(panel_prompt, panel_state["messages"])
391
  for panel_state, _panel_log, panel_prompt, _p_pending, _panel_edit_key, _ in panels
392
  ]
393
 
utils/chat.py CHANGED
@@ -73,7 +73,7 @@ def _format_plain_messages(
73
  return "\n\n".join(lines)
74
 
75
 
76
- def _format_generation_prompt(
77
  messages: list[dict[str, str]], tokenizer: object
78
  ) -> tuple[str, int]:
79
  """Render messages into a single prompt string and count prompt tokens.
@@ -169,7 +169,7 @@ def generate_chat_reply(
169
  """
170
 
171
  tokenizer = model.tokenizer
172
- prompt, prompt_token_count = _format_generation_prompt(messages, tokenizer)
173
 
174
  generation_kwargs: dict[str, object] = {
175
  "max_new_tokens": max_new_tokens,
 
73
  return "\n\n".join(lines)
74
 
75
 
76
+ def format_generation_prompt(
77
  messages: list[dict[str, str]], tokenizer: object
78
  ) -> tuple[str, int]:
79
  """Render messages into a single prompt string and count prompt tokens.
 
169
  """
170
 
171
  tokenizer = model.tokenizer
172
+ prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
173
 
174
  generation_kwargs: dict[str, object] = {
175
  "max_new_tokens": max_new_tokens,
utils/contrast.py CHANGED
@@ -1,7 +1,3 @@
1
- # WARNING: This is mostly vibecoded and need reviews
2
- # - Check that the model is runned once with normally for gneration and things are beeing traced perphaps at the last step of generation with iter.last or somrething liek that from the docs
3
- # - Then the model is runned again with the entire context of the conversation from the other context on the rifht ? or on the left dependeing on which one we are doing at the moment. And this will then compute the prob diff and show them.
4
-
5
  """
6
  Contrastive token-level log-probability comparison for compare mode.
7
 
@@ -15,13 +11,16 @@ Negative (blue) → token is more characteristic of persona B.
15
  Near-zero (gray) → both personas would emit this token with similar likelihood.
16
  """
17
 
 
18
  from dataclasses import dataclass
19
  from html import escape
20
 
21
  import torch
22
  from nnterp import StandardizedTransformer
23
 
24
- from utils.chat import _format_generation_prompt
 
 
25
 
26
 
27
  @dataclass
@@ -48,6 +47,7 @@ def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
48
 
49
 
50
  def _decode_ids(tokenizer: object, ids: list[int]) -> str:
 
51
  try:
52
  return tokenizer.decode(
53
  ids,
@@ -79,15 +79,18 @@ def _prepare_trace_text(
79
  response_ids: torch.Tensor,
80
  ) -> tuple[str, int, int]:
81
  """Build the trace text and return ``(full_text, n_ctx, n_resp)``."""
82
- context_prompt, _ = _format_generation_prompt(context_messages, tokenizer)
83
  context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
84
  response_text = _decode_ids(tokenizer, response_ids.tolist())
85
  full_text = context_prompt + response_text
86
  full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
87
  expected_ids = torch.cat([context_ids, response_ids.cpu()])
88
  if full_ids.tolist() != expected_ids.tolist():
89
- raise ValueError(
90
- "contrast trace text did not round-trip to the expected token ids"
 
 
 
91
  )
92
  n_ctx = len(context_ids)
93
  n_resp = len(response_ids)
 
 
 
 
 
1
  """
2
  Contrastive token-level log-probability comparison for compare mode.
3
 
 
11
  Near-zero (gray) → both personas would emit this token with similar likelihood.
12
  """
13
 
14
+ import logging
15
  from dataclasses import dataclass
16
  from html import escape
17
 
18
  import torch
19
  from nnterp import StandardizedTransformer
20
 
21
+ from utils.chat import format_generation_prompt
22
+
23
+ logger = logging.getLogger(__name__)
24
 
25
 
26
  @dataclass
 
47
 
48
 
49
  def _decode_ids(tokenizer: object, ids: list[int]) -> str:
50
+ """Decode token IDs, falling back when clean_up_tokenization_spaces is unsupported."""
51
  try:
52
  return tokenizer.decode(
53
  ids,
 
79
  response_ids: torch.Tensor,
80
  ) -> tuple[str, int, int]:
81
  """Build the trace text and return ``(full_text, n_ctx, n_resp)``."""
82
+ context_prompt, _ = format_generation_prompt(context_messages, tokenizer)
83
  context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
84
  response_text = _decode_ids(tokenizer, response_ids.tolist())
85
  full_text = context_prompt + response_text
86
  full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
87
  expected_ids = torch.cat([context_ids, response_ids.cpu()])
88
  if full_ids.tolist() != expected_ids.tolist():
89
+ logger.warning(
90
+ "contrast trace text did not round-trip to the expected token ids "
91
+ "(expected %d tokens, got %d); contrast scores may be slightly misaligned",
92
+ len(expected_ids),
93
+ len(full_ids),
94
  )
95
  n_ctx = len(context_ids)
96
  n_resp = len(response_ids)