Jac-Zac committed on
Commit
a9950fb
·
1 Parent(s): e2cecb1

Updated to new chat edit and comparison

Browse files

- Cleaned up chat
- Improved flow and UI
- Added comparison
- Fixed remote to make it more robust

Files changed (10) hide show
  1. README.md +6 -2
  2. state.py +19 -13
  3. tabs/chat.py +183 -416
  4. tabs/compare_chat.py +443 -0
  5. tabs/extract.py +0 -2
  6. utils/chat.py +3 -7
  7. utils/chat_export.py +1 -1
  8. utils/contrast.py +311 -0
  9. utils/helpers.py +0 -2
  10. utils/runtime.py +39 -10
README.md CHANGED
@@ -29,13 +29,17 @@ A web app built on top of [persona-vectors](../persona-vectors) that provides th
29
  persona-ui/
30
  ├── app.py # Main entry point (Streamlit)
31
  ├── state.py # Session state management (chat history, KV cache)
 
 
32
  ├── tabs/
33
  │ ├── chat.py # Chat tab
34
  │ ├── compare.py # Activation comparison tab
 
35
  │ └── extract.py # Extraction tab
36
  └── utils/
37
  ├── chat.py # Chat generation logic
38
  ├── chat_export.py # Export chat logs to JSON
 
39
  ├── datasets.py # Dataset loader wrapper
40
  ├── helpers.py # UI labels and slug helpers
41
  └── runtime.py # Model caching and NDIF queries
@@ -121,8 +125,8 @@ artifacts/
121
  ├── activations/<model_dir>/<prompt_variant>/<persona_id>/
122
  │ ├── activations.safetensors
123
  │ └── metadata.json # used for persona names and layer counts
124
- └── chats/<model_dir>/<prompt_variant>/
125
  └── <export>.json
126
  ```
127
 
128
- `<model_dir>` is the model name with `/` replaced by `__` (e.g. `google__gemma-2-9b-it`).
 
29
  persona-ui/
30
  ├── app.py # Main entry point (Streamlit)
31
  ├── state.py # Session state management (chat history, KV cache)
32
+ ├── scripts/
33
+ │ └── oracle_probe.py # Notebook-style activation oracle script
34
  ├── tabs/
35
  │ ├── chat.py # Chat tab
36
  │ ├── compare.py # Activation comparison tab
37
+ │ ├── compare_chat.py # Side-by-side chat comparison mode
38
  │ └── extract.py # Extraction tab
39
  └── utils/
40
  ├── chat.py # Chat generation logic
41
  ├── chat_export.py # Export chat logs to JSON
42
+ ├── contrast.py # Contrastive token log-prob coloring
43
  ├── datasets.py # Dataset loader wrapper
44
  ├── helpers.py # UI labels and slug helpers
45
  └── runtime.py # Model caching and NDIF queries
 
125
  ├── activations/<model_dir>/<prompt_variant>/<persona_id>/
126
  │ ├── activations.safetensors
127
  │ └── metadata.json # used for persona names and layer counts
128
+ └── chats/<model_dir>/<persona_id>/
129
  └── <export>.json
130
  ```
131
 
132
+ `<model_dir>` is the model name with `/` replaced by `__` (e.g. `google__gemma-2-9b-it`). Chat exports still store `dataset_source` in the JSON payload.
state.py CHANGED
@@ -9,7 +9,7 @@ def chat_session_key(model_name: str, dataset_source: str) -> str:
9
  return f"{_CHAT_STATE_PREFIX}{model_name}::{dataset_source}"
10
 
11
 
12
- def _default_chat_state() -> dict[str, object]:
13
  return {
14
  "messages": [],
15
  "persona_id": None,
@@ -18,6 +18,22 @@ def _default_chat_state() -> dict[str, object]:
18
  }
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def _evict_inactive_kv_caches(active_key: str) -> None:
22
  """Drop past_key_values from every chat context except the active one."""
23
 
@@ -40,22 +56,12 @@ def get_chat_state(
40
  key = chat_session_key(model_name, dataset_source)
41
  state = st.session_state.get(key)
42
  if state is None:
43
- state = _default_chat_state()
44
  st.session_state[key] = state
45
  else:
46
- for default_key, default_value in _default_chat_state().items():
47
  state.setdefault(default_key, default_value)
48
  _evict_inactive_kv_caches(key)
49
  if remote and state.get("past_key_values") is not None:
50
  state["past_key_values"] = None
51
  return state
52
-
53
-
54
- def reset_chat_state(model_name: str, dataset_source: str) -> None:
55
- """Reset chat history and cache for the active context."""
56
-
57
- key = chat_session_key(model_name, dataset_source)
58
- if key in st.session_state:
59
- state = st.session_state[key]
60
- state["messages"] = []
61
- state["past_key_values"] = None
 
9
  return f"{_CHAT_STATE_PREFIX}{model_name}::{dataset_source}"
10
 
11
 
12
+ def default_chat_state() -> dict[str, object]:
13
  return {
14
  "messages": [],
15
  "persona_id": None,
 
18
  }
19
 
20
 
21
+ def reset_chat_context_state(
22
+ state: dict[str, object],
23
+ persona_id: str,
24
+ prompt_mode: str,
25
+ *ui_keys: str,
26
+ ) -> None:
27
+ """Reset one chat context and clear any related widget state."""
28
+
29
+ state["messages"] = []
30
+ state["past_key_values"] = None
31
+ state["persona_id"] = persona_id
32
+ state["prompt_mode"] = prompt_mode
33
+ for key in ui_keys:
34
+ st.session_state.pop(key, None)
35
+
36
+
37
  def _evict_inactive_kv_caches(active_key: str) -> None:
38
  """Drop past_key_values from every chat context except the active one."""
39
 
 
56
  key = chat_session_key(model_name, dataset_source)
57
  state = st.session_state.get(key)
58
  if state is None:
59
+ state = default_chat_state()
60
  st.session_state[key] = state
61
  else:
62
+ for default_key, default_value in default_chat_state().items():
63
  state.setdefault(default_key, default_value)
64
  _evict_inactive_kv_caches(key)
65
  if remote and state.get("past_key_values") is not None:
66
  state["past_key_values"] = None
67
  return state
 
 
 
 
 
 
 
 
 
 
tabs/chat.py CHANGED
@@ -1,72 +1,109 @@
1
- from concurrent.futures import ThreadPoolExecutor
2
  from typing import Any
3
 
4
  import streamlit as st
5
  from persona_data.synth_persona import PersonaData
6
 
7
- from state import (
8
- _default_chat_state,
9
- chat_session_key,
10
- get_chat_state,
11
- reset_chat_state,
12
- )
13
  from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
14
  from utils.chat_export import save_chat_export
 
15
  from utils.datasets import load_dataset
16
  from utils.helpers import (
17
  MODE_LABEL_TO_KEY,
18
  MODE_LABELS,
19
  VARIANT_LABELS,
20
- VISIBLE_MESSAGE_COUNT,
21
  persona_label,
22
  widget_key,
23
  )
24
  from utils.runtime import cached_model
25
 
26
- COLLAPSED_MESSAGE_CHAR_LIMIT = 500
27
-
28
 
29
  def _render_collapsible_markdown(content: str) -> None:
30
- if len(content) <= COLLAPSED_MESSAGE_CHAR_LIMIT:
31
- st.markdown(content)
32
- return
33
 
34
- with st.expander(f"Show full text ({len(content)} chars)", expanded=False):
35
- st.markdown(content)
36
 
 
37
 
38
- def _render_chat_message(message: dict[str, str]) -> None:
39
- if not message.get("content"):
40
- return
41
- with st.container(border=True):
42
- st.caption(message["role"])
43
- _render_collapsible_markdown(message["content"])
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- def _render_inline_system_prompt(
47
- prompt_key: str,
48
- prompt_mode: str,
49
- active_system_prompt: str | None,
50
- height: int = 200,
51
- ) -> str | None:
52
- """Render the system prompt as an always-editable text area at the top of the chat."""
53
- if prompt_mode == "empty":
54
- return active_system_prompt
55
 
56
- if prompt_key not in st.session_state:
57
- st.session_state[prompt_key] = active_system_prompt or ""
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- with st.container(border=True):
60
- st.caption("System prompt")
61
- st.text_area(
62
- "system_prompt_edit",
63
- value=st.session_state[prompt_key],
64
- height=height,
65
- label_visibility="collapsed",
66
- key=prompt_key,
67
- )
68
 
69
- return st.session_state.get(prompt_key) or None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  def _render_editable_message(
@@ -76,63 +113,50 @@ def _render_editable_message(
76
  chat_state: dict[str, object],
77
  edit_key: str,
78
  pending_key: str,
 
 
79
  ) -> None:
80
- """Render a single message with an inline edit button."""
81
  if not message.get("content"):
82
  return
 
 
83
 
84
- is_editing = st.session_state.get(edit_key) == msg_index
85
-
86
- with st.container(border=True):
87
- st.caption(message["role"])
88
- if is_editing:
89
- new_content = st.text_area(
90
- "Edit",
91
- value=message["content"],
92
- height=100,
93
- label_visibility="collapsed",
94
- key=f"{edit_key}_msg_{msg_index}",
 
 
 
 
 
 
 
95
  )
96
- c1, c2 = st.columns(2)
97
- with c1:
98
- if st.button(
99
- "Save", key=f"{edit_key}_msg_save_{msg_index}", type="primary"
100
- ):
101
- messages[msg_index]["content"] = new_content
102
- del messages[msg_index + 1 :]
103
- chat_state["past_key_values"] = None
104
- st.session_state[edit_key] = None
105
- if message["role"] == "user":
106
- st.session_state[pending_key] = True
107
- st.rerun()
108
- with c2:
109
- if st.button("Cancel", key=f"{edit_key}_msg_cancel_{msg_index}"):
110
- st.session_state[edit_key] = None
111
- st.rerun()
112
- else:
113
- st.markdown(message["content"])
114
- if st.button("Edit", key=f"{edit_key}_msg_edit_{msg_index}"):
115
- st.session_state[edit_key] = msg_index
116
- st.rerun()
117
-
118
 
119
- def _clear_chat_ui_state(*keys: str) -> None:
120
- for key in keys:
121
- st.session_state.pop(key, None)
122
 
123
-
124
- def _reset_single_chat_context(
125
- model_name: str,
126
- dataset_source: str,
127
- chat_state: dict[str, object],
128
- persona_id: str,
129
  prompt_mode: str,
130
- *ui_keys: str,
131
- ) -> None:
132
- reset_chat_state(model_name, dataset_source)
133
- chat_state["persona_id"] = persona_id
134
- chat_state["prompt_mode"] = prompt_mode
135
- _clear_chat_ui_state(*ui_keys)
 
 
 
 
 
 
136
 
137
 
138
  def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
@@ -189,42 +213,27 @@ def _render_chat_window(
189
  *,
190
  chat_log: Any,
191
  messages: list[dict[str, str]],
192
- show_all_key: str,
193
- show_all_btn_key: str,
194
- show_earlier_label: str,
195
  chat_state: dict[str, object] | None = None,
196
  edit_key: str | None = None,
197
  pending_key: str | None = None,
198
- ) -> Any:
199
- """Render the visible chat history inside one container."""
200
-
201
  with chat_log:
202
- if len(messages) > VISIBLE_MESSAGE_COUNT and not st.session_state.get(
203
- show_all_key, False
204
- ):
205
- hidden_count = len(messages) - VISIBLE_MESSAGE_COUNT
206
- if st.button(
207
- f"{show_earlier_label} ({hidden_count} hidden)",
208
- key=show_all_btn_key,
209
- ):
210
- st.session_state[show_all_key] = True
211
- st.rerun()
212
- visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
213
- index_offset = len(messages) - VISIBLE_MESSAGE_COUNT
214
- else:
215
- visible_messages = messages
216
- index_offset = 0
217
-
218
- for i, message in enumerate(visible_messages):
219
- actual_index = index_offset + i
220
  if edit_key and pending_key:
221
  _render_editable_message(
222
- message, actual_index, messages, chat_state, edit_key, pending_key
 
 
 
 
 
 
 
223
  )
224
  else:
225
- _render_chat_message(message)
226
-
227
- return chat_log
228
 
229
 
230
  def _build_chat_messages(
@@ -247,8 +256,8 @@ def _save_chat_export_message(
247
  messages: list[dict[str, str]],
248
  generation: dict[str, object],
249
  panel_label: str | None = None,
250
- ) -> str:
251
- export_path = save_chat_export(
252
  model_name=model_name,
253
  dataset_source=dataset_source,
254
  persona_id=persona_id,
@@ -259,230 +268,12 @@ def _save_chat_export_message(
259
  messages=messages,
260
  generation=generation,
261
  )
262
- return f"Saved chat export to {export_path}"
263
-
264
-
265
- # ── Compare mode helpers ───────────────────────────────────────────────────────
266
-
267
-
268
- def _panel_state(panel_key: str) -> dict:
269
- """Get or initialise compare-panel chat state stored in session_state."""
270
- if panel_key not in st.session_state:
271
- st.session_state[panel_key] = _default_chat_state()
272
- return st.session_state[panel_key]
273
-
274
-
275
- def _render_compare_mode(
276
- remote: bool,
277
- model_name: str,
278
- context_key: str,
279
- dataset_source: str,
280
- personas: list[PersonaData],
281
- gen_kwargs: dict,
282
- advanced_generation: bool,
283
- ) -> None:
284
- """Render the full side-by-side comparison UI."""
285
- left_col, right_col = st.columns(2)
286
-
287
- def render_panel(side: str) -> tuple[dict[str, object], Any, str | None, str]:
288
- panel_key = widget_key(context_key, f"cmp_{side}")
289
- state = _panel_state(panel_key)
290
- prompt_key = widget_key(panel_key, "custom_prompt")
291
- show_all_key = widget_key(panel_key, "show_all")
292
- edit_key = widget_key(panel_key, "edit_idx")
293
- pending_regen_key = widget_key(panel_key, "pending_regen")
294
-
295
- selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
296
- personas,
297
- state["persona_id"],
298
- state["prompt_mode"],
299
- widget_key(panel_key, "persona"),
300
- widget_key(panel_key, "prompt_mode"),
301
- )
302
- if changed:
303
- state["messages"] = []
304
- state["past_key_values"] = None
305
- state["persona_id"] = selected_persona.id
306
- state["prompt_mode"] = prompt_mode
307
- _clear_chat_ui_state(prompt_key, show_all_key)
308
- st.session_state.pop(edit_key, None)
309
-
310
- active_system_prompt = resolve_system_prompt(
311
- persona=selected_persona, mode=prompt_mode
312
- )
313
-
314
- btn_col1, btn_col2 = st.columns(2)
315
- with btn_col1:
316
- if st.button(
317
- "Export chat", key=widget_key(panel_key, "export_chat"), width="stretch"
318
- ):
319
- st.success(
320
- _save_chat_export_message(
321
- model_name=model_name,
322
- dataset_source=dataset_source,
323
- persona_id=selected_persona.id,
324
- persona_name=getattr(selected_persona, "name", None),
325
- prompt_mode=prompt_mode,
326
- system_prompt=active_system_prompt,
327
- messages=state["messages"],
328
- generation=_generation_dict(gen_kwargs, advanced_generation),
329
- panel_label=side,
330
- )
331
- )
332
- with btn_col2:
333
- if st.button(
334
- "Reset chat",
335
- key=widget_key(panel_key, "reset"),
336
- width="stretch",
337
- type="secondary",
338
- ):
339
- state["messages"] = []
340
- state["past_key_values"] = None
341
- _clear_chat_ui_state(prompt_key, show_all_key)
342
- st.session_state.pop(edit_key, None)
343
- st.rerun()
344
-
345
- chat_log = st.container()
346
- with chat_log:
347
- active_system_prompt = _render_inline_system_prompt(
348
- prompt_key,
349
- prompt_mode,
350
- active_system_prompt,
351
- height=150,
352
- )
353
- _render_chat_window(
354
- chat_log=chat_log,
355
- messages=state["messages"],
356
- show_all_key=show_all_key,
357
- show_all_btn_key=widget_key(panel_key, "show_all_btn"),
358
- show_earlier_label="Show earlier",
359
- chat_state=state,
360
- edit_key=edit_key,
361
- pending_key=pending_regen_key,
362
- )
363
- return state, chat_log, active_system_prompt, pending_regen_key
364
-
365
- with left_col:
366
- left_state, left_log, left_prompt, left_pending = render_panel("left")
367
- with right_col:
368
- right_state, right_log, right_prompt, right_pending = render_panel("right")
369
-
370
- panels = [
371
- (left_state, left_log, left_prompt, left_pending),
372
- (right_state, right_log, right_prompt, right_pending),
373
- ]
374
-
375
- # Handle per-panel regeneration triggered by message edits
376
- any_regen = any(st.session_state.get(p_pending) for _, _, _, p_pending in panels)
377
- if any_regen:
378
- model = cached_model(model_name=model_name, remote=remote)
379
- for panel_state, panel_log, panel_prompt, p_pending in panels:
380
- if not st.session_state.pop(p_pending, False):
381
- continue
382
- regen_messages = _build_chat_messages(panel_prompt, panel_state["messages"])
383
- with st.spinner("Regenerating..."):
384
- try:
385
- result = generate_chat_reply(
386
- model=model,
387
- messages=regen_messages,
388
- remote=remote,
389
- past_key_values=panel_state["past_key_values"],
390
- **gen_kwargs,
391
- )
392
- except Exception as exc:
393
- with panel_log:
394
- st.error(f"Generation failed: {exc}")
395
- panel_state["messages"].pop()
396
- continue
397
- panel_state["messages"].append(
398
- {"role": "assistant", "content": result.text}
399
- )
400
- panel_state["past_key_values"] = (
401
- result.past_key_values if not remote else None
402
- )
403
- with panel_log:
404
- _render_chat_message({"role": "assistant", "content": result.text})
405
- st.rerun()
406
-
407
- user_prompt = st.chat_input(
408
- "Ask both...",
409
- key=widget_key(context_key, "cmp_input"),
410
- )
411
- if not user_prompt:
412
- return
413
-
414
- model = cached_model(model_name=model_name, remote=remote)
415
-
416
- for panel_state, panel_log, _panel_prompt, _p_pending in panels:
417
- panel_state["messages"].append({"role": "user", "content": user_prompt})
418
- with panel_log:
419
- _render_chat_message({"role": "user", "content": user_prompt})
420
-
421
- with st.spinner("Generating..."):
422
- if remote:
423
- with ThreadPoolExecutor(max_workers=2) as executor:
424
- futures = [
425
- executor.submit(
426
- generate_chat_reply,
427
- model=model,
428
- messages=_build_chat_messages(
429
- panel_prompt, panel_state["messages"]
430
- ),
431
- remote=remote,
432
- past_key_values=panel_state["past_key_values"],
433
- **gen_kwargs,
434
- )
435
- for panel_state, _panel_log, panel_prompt, _p_pending in panels
436
- ]
437
- results: list[ChatReply | Exception] = []
438
- for future in futures:
439
- try:
440
- results.append(future.result())
441
- except Exception as exc:
442
- results.append(exc)
443
- else:
444
- results = []
445
- for panel_state, _panel_log, panel_prompt, _p_pending in panels:
446
- try:
447
- results.append(
448
- generate_chat_reply(
449
- model=model,
450
- messages=_build_chat_messages(
451
- panel_prompt, panel_state["messages"]
452
- ),
453
- remote=remote,
454
- past_key_values=panel_state["past_key_values"],
455
- **gen_kwargs,
456
- )
457
- )
458
- except Exception as exc:
459
- results.append(exc)
460
-
461
- for (panel_state, panel_log, _panel_prompt, _p_pending), result in zip(
462
- panels, results
463
- ):
464
- if isinstance(result, Exception):
465
- with panel_log:
466
- st.error(f"Generation failed: {result}")
467
- panel_state["messages"].pop()
468
- continue
469
-
470
- panel_state["messages"].append({"role": "assistant", "content": result.text})
471
- panel_state["past_key_values"] = result.past_key_values if not remote else None
472
- with panel_log:
473
- _render_chat_message({"role": "assistant", "content": result.text})
474
-
475
- # Rerun so the newly appended turns are redrawn through the editable history
476
- # renderer instead of only appearing in the one-off generation pass.
477
- st.rerun()
478
 
479
 
480
  # ── Main tab entry point ───────────────────────────────────────────────────────
481
 
482
 
483
- def _render_generation_settings(
484
- context_key: str, remote: bool
485
- ) -> tuple[dict, bool]:
486
  """Render the Advanced generation settings expander.
487
 
488
  Returns ``(gen_kwargs, advanced_generation)`` where ``advanced_generation``
@@ -633,7 +424,9 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
633
  )
634
 
635
  if compare_mode:
636
- _render_compare_mode(
 
 
637
  remote,
638
  model_name,
639
  context_key,
@@ -648,76 +441,70 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
648
  persona_select_key = widget_key(context_key, "persona_select")
649
  prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
650
  prompt_key = widget_key(context_key, "custom_system_prompt")
651
- show_all_key = widget_key(context_key, "show_all_messages")
652
  chat_input_key = widget_key(context_key, "chat_input")
653
  pending_key = widget_key(context_key, "pending_prompt")
654
  export_key = widget_key(context_key, "export_chat")
655
  reset_key = widget_key(context_key, "reset")
656
  edit_key = widget_key(context_key, "edit_idx")
657
 
658
- col1, col2 = st.columns([2, 1])
659
- with col1:
660
- selected_index = next(
661
- (i for i, p in enumerate(personas) if p.id == chat_state["persona_id"]),
662
- 0,
663
- )
664
- selected_persona = st.selectbox(
665
- "Persona",
666
- options=personas,
667
- index=selected_index,
668
- format_func=persona_label,
669
- key=persona_select_key,
670
- )
671
- with col2:
672
- current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
673
- st.selectbox(
674
- "Prompt",
675
- options=MODE_LABELS,
676
- index=MODE_LABELS.index(current_mode_label),
677
- key=prompt_mode_select_key,
678
  )
679
- prompt_mode = MODE_LABEL_TO_KEY[st.session_state[prompt_mode_select_key]]
 
 
 
 
 
 
 
 
 
680
 
681
  active_system_prompt = resolve_system_prompt(
682
  persona=selected_persona,
683
  mode=prompt_mode,
684
  )
685
 
686
- changed_context = (
687
- chat_state["persona_id"] != selected_persona.id
688
- or chat_state["prompt_mode"] != prompt_mode
689
- )
690
  if changed_context:
691
  had_history = bool(chat_state["messages"])
692
- _reset_single_chat_context(
693
- model_name,
694
- dataset_source,
695
- chat_state,
696
- selected_persona.id,
697
- prompt_mode,
698
- chat_input_key,
699
- show_all_key,
700
- prompt_key,
701
- pending_key,
702
- )
703
- st.session_state.pop(edit_key, None)
704
  if had_history:
705
  st.info("Chat history reset because the persona or system prompt changed.")
706
 
707
  chat_log = st.container()
708
 
709
  with chat_log:
710
- active_system_prompt = _render_inline_system_prompt(
711
  prompt_key,
712
  prompt_mode,
713
  active_system_prompt,
714
- height=200,
715
  )
716
 
717
- action_col1, action_col2 = st.columns(2)
718
- with action_col1:
719
- if st.button("Export chat", key=export_key, width="stretch"):
720
- st.success(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
  _save_chat_export_message(
722
  model_name=model_name,
723
  dataset_source=dataset_source,
@@ -728,38 +515,18 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
728
  messages=chat_state["messages"],
729
  generation=_generation_dict(gen_kwargs, advanced_generation),
730
  )
731
- )
732
- with action_col2:
733
- if st.button("Reset chat", key=reset_key, width="stretch", type="secondary"):
734
- _reset_single_chat_context(
735
- model_name,
736
- dataset_source,
737
- chat_state,
738
- selected_persona.id,
739
- prompt_mode,
740
- chat_input_key,
741
- show_all_key,
742
- prompt_key,
743
- pending_key,
744
- )
745
- st.session_state.pop(edit_key, None)
746
- st.rerun()
747
-
748
- _render_chat_window(
749
- chat_log=chat_log,
750
- messages=chat_state["messages"],
751
- show_all_key=show_all_key,
752
- show_all_btn_key=widget_key(context_key, "show_all_btn"),
753
- show_earlier_label="Show earlier messages",
754
- chat_state=chat_state,
755
- edit_key=edit_key,
756
- pending_key=pending_key,
757
- )
758
 
759
- user_prompt = st.chat_input(
760
- "Ask something...",
761
- key=chat_input_key,
762
- )
763
 
764
  # Pass 1: user submitted — append message and rerun so it renders before generation.
765
  if user_prompt:
 
 
1
  from typing import Any
2
 
3
  import streamlit as st
4
  from persona_data.synth_persona import PersonaData
5
 
6
+ from state import chat_session_key, get_chat_state, reset_chat_context_state
 
 
 
 
 
7
  from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
8
  from utils.chat_export import save_chat_export
9
+ from utils.contrast import TokenContrast, render_contrast_html
10
  from utils.datasets import load_dataset
11
  from utils.helpers import (
12
  MODE_LABEL_TO_KEY,
13
  MODE_LABELS,
14
  VARIANT_LABELS,
 
15
  persona_label,
16
  widget_key,
17
  )
18
  from utils.runtime import cached_model
19
 
 
 
20
 
21
  def _render_collapsible_markdown(content: str) -> None:
22
+ st.markdown(content)
 
 
23
 
 
 
24
 
25
+ # ── Dialogs ───────────────────────────────────────────────────────────────────
26
 
 
 
 
 
 
 
27
 
28
+ @st.dialog("Edit", width="medium")
29
+ def _open_edit_dialog(
30
+ *,
31
+ msg_index: int,
32
+ messages: list[dict[str, str]],
33
+ chat_state: dict[str, object],
34
+ pending_key: str,
35
+ ) -> None:
36
+ message = messages[msg_index]
37
+ role = message["role"]
38
+
39
+ n_after = len(messages) - msg_index - 1
40
+ st.caption(
41
+ f"**{role}**"
42
+ + (
43
+ f" — {n_after} subsequent {'message' if n_after == 1 else 'messages'} will be cleared"
44
+ if n_after > 0
45
+ else ""
46
+ )
47
+ )
48
 
49
+ new_content = st.text_area(
50
+ "Content",
51
+ value=message["content"],
52
+ height=320,
53
+ label_visibility="collapsed",
54
+ )
 
 
 
55
 
56
+ save_col, cancel_col = st.columns(2)
57
+ with save_col:
58
+ if st.button("Save", type="primary", use_container_width=True):
59
+ messages[msg_index]["content"] = new_content
60
+ messages[msg_index].pop("_contrast", None)
61
+ if role == "assistant":
62
+ messages[msg_index]["_needs_contrast"] = True
63
+ del messages[msg_index + 1 :]
64
+ chat_state["past_key_values"] = None
65
+ if role == "user":
66
+ st.session_state[pending_key] = True
67
+ st.rerun()
68
+ with cancel_col:
69
+ if st.button("Cancel", use_container_width=True):
70
+ st.rerun()
71
 
 
 
 
 
 
 
 
 
 
72
 
73
+ @st.dialog("Edit system prompt", width="large")
74
+ def _open_system_prompt_dialog(*, prompt_key: str, current_value: str) -> None:
75
+ new_value = st.text_area(
76
+ "System prompt",
77
+ value=current_value,
78
+ height=320,
79
+ label_visibility="collapsed",
80
+ )
81
+ save_col, cancel_col = st.columns(2)
82
+ with save_col:
83
+ if st.button("Save", type="primary", use_container_width=True):
84
+ st.session_state[prompt_key] = new_value
85
+ st.rerun()
86
+ with cancel_col:
87
+ if st.button("Cancel", use_container_width=True):
88
+ st.rerun()
89
+
90
+
91
+ # ── Message renderers ─────────────────────────────────────────────────────────
92
+
93
+
94
+ def _render_chat_message(
95
+ message: dict[str, str],
96
+ show_contrast: bool = False,
97
+ ) -> None:
98
+ if not message.get("content"):
99
+ return
100
+ role = message["role"]
101
+ tc: TokenContrast | None = message.get("_contrast") if show_contrast else None
102
+ with st.chat_message(role):
103
+ if tc is not None:
104
+ st.html(render_contrast_html(tc))
105
+ else:
106
+ _render_collapsible_markdown(message["content"])
107
 
108
 
109
  def _render_editable_message(
 
113
  chat_state: dict[str, object],
114
  edit_key: str,
115
  pending_key: str,
116
+ show_contrast: bool = False,
117
+ column_ratio: tuple[int, int] = (25, 1),
118
  ) -> None:
 
119
  if not message.get("content"):
120
  return
121
+ role = message["role"]
122
+ tc: TokenContrast | None = message.get("_contrast") if show_contrast else None
123
 
124
+ msg_col, edit_col = st.columns(
125
+ list(column_ratio), gap="xsmall", vertical_alignment="center"
126
+ )
127
+ with msg_col:
128
+ with st.chat_message(role):
129
+ if tc is not None:
130
+ st.html(render_contrast_html(tc))
131
+ else:
132
+ _render_collapsible_markdown(message["content"])
133
+ with edit_col:
134
+ if st.button(
135
+ "", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
136
+ ):
137
+ _open_edit_dialog(
138
+ msg_index=msg_index,
139
+ messages=messages,
140
+ chat_state=chat_state,
141
+ pending_key=pending_key,
142
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
 
 
 
144
 
145
+ def _render_system_prompt(
146
+ prompt_key: str,
 
 
 
 
147
  prompt_mode: str,
148
+ active_system_prompt: str | None,
149
+ ) -> str | None:
150
+ if prompt_key not in st.session_state:
151
+ st.session_state[prompt_key] = active_system_prompt or ""
152
+ current = st.session_state.get(prompt_key) or ""
153
+ with st.expander("System prompt"):
154
+ st.markdown(current or "*empty*")
155
+ if prompt_mode != "empty" and st.button(
156
+ "Edit", icon=":material/edit:", key=f"{prompt_key}_edit"
157
+ ):
158
+ _open_system_prompt_dialog(prompt_key=prompt_key, current_value=current)
159
+ return st.session_state.get(prompt_key) or None
160
 
161
 
162
  def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
 
213
  *,
214
  chat_log: Any,
215
  messages: list[dict[str, str]],
 
 
 
216
  chat_state: dict[str, object] | None = None,
217
  edit_key: str | None = None,
218
  pending_key: str | None = None,
219
+ show_contrast: bool = False,
220
+ edit_column_ratio: tuple[int, int] = (25, 1),
221
+ ) -> None:
222
  with chat_log:
223
+ for i, message in enumerate(messages):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  if edit_key and pending_key:
225
  _render_editable_message(
226
+ message,
227
+ i,
228
+ messages,
229
+ chat_state,
230
+ edit_key,
231
+ pending_key,
232
+ show_contrast=show_contrast,
233
+ column_ratio=edit_column_ratio,
234
  )
235
  else:
236
+ _render_chat_message(message, show_contrast=show_contrast)
 
 
237
 
238
 
239
  def _build_chat_messages(
 
256
  messages: list[dict[str, str]],
257
  generation: dict[str, object],
258
  panel_label: str | None = None,
259
+ ) -> None:
260
+ save_chat_export(
261
  model_name=model_name,
262
  dataset_source=dataset_source,
263
  persona_id=persona_id,
 
268
  messages=messages,
269
  generation=generation,
270
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
 
273
  # ── Main tab entry point ───────────────────────────────────────────────────────
274
 
275
 
276
+ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, bool]:
 
 
277
  """Render the Advanced generation settings expander.
278
 
279
  Returns ``(gen_kwargs, advanced_generation)`` where ``advanced_generation``
 
424
  )
425
 
426
  if compare_mode:
427
+ from tabs.compare_chat import render_compare_mode
428
+
429
+ render_compare_mode(
430
  remote,
431
  model_name,
432
  context_key,
 
441
  persona_select_key = widget_key(context_key, "persona_select")
442
  prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
443
  prompt_key = widget_key(context_key, "custom_system_prompt")
 
444
  chat_input_key = widget_key(context_key, "chat_input")
445
  pending_key = widget_key(context_key, "pending_prompt")
446
  export_key = widget_key(context_key, "export_chat")
447
  reset_key = widget_key(context_key, "reset")
448
  edit_key = widget_key(context_key, "edit_idx")
449
 
450
+ def _reset_active_chat_context() -> None:
451
+ reset_chat_context_state(
452
+ chat_state,
453
+ selected_persona.id,
454
+ prompt_mode,
455
+ chat_input_key,
456
+ prompt_key,
457
+ pending_key,
 
 
 
 
 
 
 
 
 
 
 
 
458
  )
459
+ st.session_state.pop(edit_key, None)
460
+
461
+ selected_persona, prompt_mode, changed_context = _render_persona_prompt_controls(
462
+ personas,
463
+ chat_state["persona_id"],
464
+ chat_state["prompt_mode"],
465
+ persona_select_key,
466
+ prompt_mode_select_key,
467
+ column_widths=(2, 1),
468
+ )
469
 
470
  active_system_prompt = resolve_system_prompt(
471
  persona=selected_persona,
472
  mode=prompt_mode,
473
  )
474
 
 
 
 
 
475
  if changed_context:
476
  had_history = bool(chat_state["messages"])
477
+ _reset_active_chat_context()
 
 
 
 
 
 
 
 
 
 
 
478
  if had_history:
479
  st.info("Chat history reset because the persona or system prompt changed.")
480
 
481
  chat_log = st.container()
482
 
483
  with chat_log:
484
+ active_system_prompt = _render_system_prompt(
485
  prompt_key,
486
  prompt_mode,
487
  active_system_prompt,
 
488
  )
489
 
490
+ _render_chat_window(
491
+ chat_log=chat_log,
492
+ messages=chat_state["messages"],
493
+ chat_state=chat_state,
494
+ edit_key=edit_key,
495
+ pending_key=pending_key,
496
+ )
497
+
498
+ footer = st.container()
499
+ with footer:
500
+ exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
501
+ with exp_col:
502
+ if st.button(
503
+ "",
504
+ icon=":material/download:",
505
+ key=export_key,
506
+ help="Export chat",
507
+ ):
508
  _save_chat_export_message(
509
  model_name=model_name,
510
  dataset_source=dataset_source,
 
515
  messages=chat_state["messages"],
516
  generation=_generation_dict(gen_kwargs, advanced_generation),
517
  )
518
+ st.toast("Exported", icon=":material/check:")
519
+ with rst_col:
520
+ if st.button(
521
+ "",
522
+ icon=":material/delete_sweep:",
523
+ key=reset_key,
524
+ help="Reset chat",
525
+ ):
526
+ _reset_active_chat_context()
527
+ st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
 
529
+ user_prompt = st.chat_input("Ask something...", key=chat_input_key)
 
 
 
530
 
531
  # Pass 1: user submitted — append message and rerun so it renders before generation.
532
  if user_prompt:
tabs/compare_chat.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from nnterp import StandardizedTransformer
3
+ from persona_data.synth_persona import PersonaData
4
+
5
+ from state import default_chat_state, reset_chat_context_state
6
+ from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
7
+ from utils.contrast import compute_contrast, compute_contrast_pair
8
+ from utils.helpers import persona_label, widget_key
9
+ from utils.runtime import cached_model
10
+
11
+ from .chat import (
12
+ _build_chat_messages,
13
+ _generation_dict,
14
+ _render_chat_message,
15
+ _render_chat_window,
16
+ _render_persona_prompt_controls,
17
+ _render_system_prompt,
18
+ _save_chat_export_message,
19
+ )
20
+
21
+
22
+ def _panel_state(panel_key: str) -> dict[str, object]:
23
+ """Get or initialise compare-panel chat state stored in session_state."""
24
+ if panel_key not in st.session_state:
25
+ st.session_state[panel_key] = default_chat_state()
26
+ return st.session_state[panel_key]
27
+
28
+
29
+ def _reset_compare_panel(
30
+ panel_state: dict,
31
+ edit_key: str,
32
+ persona_id: str,
33
+ prompt_mode: str,
34
+ *ui_keys: str,
35
+ ) -> None:
36
+ reset_chat_context_state(panel_state, persona_id, prompt_mode, *ui_keys)
37
+ st.session_state.pop(edit_key, None)
38
+
39
+
40
+ def _generate_panel_reply(
41
+ *,
42
+ model: StandardizedTransformer,
43
+ remote: bool,
44
+ panel_state: dict[str, object],
45
+ panel_prompt: str | None,
46
+ gen_kwargs: dict,
47
+ ) -> ChatReply:
48
+ return generate_chat_reply(
49
+ model=model,
50
+ messages=_build_chat_messages(panel_prompt, panel_state["messages"]),
51
+ remote=remote,
52
+ past_key_values=panel_state["past_key_values"],
53
+ **gen_kwargs,
54
+ )
55
+
56
+
57
+ def render_compare_mode(
58
+ remote: bool,
59
+ model_name: str,
60
+ context_key: str,
61
+ dataset_source: str,
62
+ personas: list[PersonaData],
63
+ gen_kwargs: dict,
64
+ advanced_generation: bool,
65
+ ) -> None:
66
+ """Render the full side-by-side comparison UI."""
67
+ contrast_key = widget_key(context_key, "token_contrast")
68
+ contrast_enabled = st.toggle(
69
+ "Token contrast",
70
+ value=False,
71
+ key=contrast_key,
72
+ help=(
73
+ "Color each generated token by how characteristic it is of each persona. "
74
+ "Red = more likely under the left persona, blue = more likely under the right. "
75
+ "Requires four extra forward passes after each turn (batched into one "
76
+ "remote session when running on NDIF)."
77
+ ),
78
+ )
79
+
80
+ left_col, right_col = st.columns(2)
81
+ left_panel_key = widget_key(context_key, "cmp_left")
82
+ right_panel_key = widget_key(context_key, "cmp_right")
83
+ left_prompt_key = widget_key(left_panel_key, "custom_prompt")
84
+ right_prompt_key = widget_key(right_panel_key, "custom_prompt")
85
+ left_edit_key = widget_key(left_panel_key, "edit_idx")
86
+ right_edit_key = widget_key(right_panel_key, "edit_idx")
87
+ left_pending_key = widget_key(left_panel_key, "pending_regen")
88
+ right_pending_key = widget_key(right_panel_key, "pending_regen")
89
+
90
+ def render_panel(side: str) -> tuple[dict, object, str | None, str, PersonaData]:
91
+ panel_key = widget_key(context_key, f"cmp_{side}")
92
+ state = _panel_state(panel_key)
93
+ prompt_key = widget_key(panel_key, "custom_prompt")
94
+ edit_key = widget_key(panel_key, "edit_idx")
95
+ pending_regen_key = widget_key(panel_key, "pending_regen")
96
+
97
+ selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
98
+ personas,
99
+ state["persona_id"],
100
+ state["prompt_mode"],
101
+ widget_key(panel_key, "persona"),
102
+ widget_key(panel_key, "prompt_mode"),
103
+ )
104
+ if changed:
105
+ reset_chat_context_state(
106
+ state,
107
+ selected_persona.id,
108
+ prompt_mode,
109
+ prompt_key,
110
+ pending_regen_key,
111
+ )
112
+ st.session_state.pop(edit_key, None)
113
+
114
+ active_system_prompt = resolve_system_prompt(
115
+ persona=selected_persona, mode=prompt_mode
116
+ )
117
+
118
+ chat_log = st.container()
119
+ with chat_log:
120
+ active_system_prompt = _render_system_prompt(
121
+ prompt_key,
122
+ prompt_mode,
123
+ active_system_prompt,
124
+ )
125
+ return (
126
+ state,
127
+ chat_log,
128
+ active_system_prompt,
129
+ pending_regen_key,
130
+ selected_persona,
131
+ )
132
+
133
+ with left_col:
134
+ left_state, left_log, left_prompt, left_pending, left_persona = render_panel(
135
+ "left"
136
+ )
137
+ with right_col:
138
+ right_state, right_log, right_prompt, right_pending, right_persona = (
139
+ render_panel("right")
140
+ )
141
+
142
+ panels = [
143
+ (
144
+ left_state,
145
+ left_log,
146
+ left_prompt,
147
+ left_pending,
148
+ left_edit_key,
149
+ left_persona,
150
+ ),
151
+ (
152
+ right_state,
153
+ right_log,
154
+ right_prompt,
155
+ right_pending,
156
+ right_edit_key,
157
+ right_persona,
158
+ ),
159
+ ]
160
+
161
+ # Handle per-panel regeneration triggered by message edits
162
+ regen_panels = [
163
+ (panel_state, panel_log, panel_prompt)
164
+ for panel_state, panel_log, panel_prompt, p_pending, _panel_edit_key, _ in panels
165
+ if st.session_state.pop(p_pending, False)
166
+ ]
167
+ if regen_panels:
168
+ model = cached_model(model_name=model_name, remote=remote)
169
+
170
+ results: list[ChatReply | Exception] = []
171
+ with st.spinner("Regenerating..."):
172
+ for panel_state, _panel_log, panel_prompt in regen_panels:
173
+ try:
174
+ results.append(
175
+ _generate_panel_reply(
176
+ model=model,
177
+ remote=remote,
178
+ panel_state=panel_state,
179
+ panel_prompt=panel_prompt,
180
+ gen_kwargs=gen_kwargs,
181
+ )
182
+ )
183
+ except Exception as exc:
184
+ results.append(exc)
185
+
186
+ for (panel_state, panel_log, _panel_prompt), result in zip(
187
+ regen_panels, results
188
+ ):
189
+ if isinstance(result, Exception):
190
+ with panel_log:
191
+ st.error(f"Generation failed: {result}")
192
+ panel_state["messages"].pop()
193
+ continue
194
+ panel_state["messages"].append(
195
+ {"role": "assistant", "content": result.text}
196
+ )
197
+ panel_state["past_key_values"] = (
198
+ result.past_key_values if not remote else None
199
+ )
200
+ st.rerun()
201
+
202
+ # Recompute contrast for assistant messages that were edited in place.
203
+ if contrast_enabled:
204
+ pending_edits: list[tuple[int, int]] = [
205
+ (panel_idx, msg_idx)
206
+ for panel_idx, (panel_state, *_rest) in enumerate(panels)
207
+ for msg_idx, msg in enumerate(panel_state["messages"])
208
+ if msg.get("_needs_contrast") and msg.get("role") == "assistant"
209
+ ]
210
+ if pending_edits:
211
+ model = cached_model(model_name=model_name, remote=remote)
212
+ label_a = persona_label(left_persona)
213
+ label_b = persona_label(right_persona)
214
+ with st.spinner("Recomputing token contrast…"):
215
+ for panel_idx, msg_idx in pending_edits:
216
+ panel_state = panels[panel_idx][0]
217
+ msg = panel_state["messages"][msg_idx]
218
+ if msg_idx >= len(left_state["messages"]) or msg_idx >= len(
219
+ right_state["messages"]
220
+ ):
221
+ msg.pop("_needs_contrast", None)
222
+ continue
223
+ context_a = _build_chat_messages(
224
+ left_prompt, left_state["messages"][:msg_idx]
225
+ )
226
+ context_b = _build_chat_messages(
227
+ right_prompt, right_state["messages"][:msg_idx]
228
+ )
229
+ try:
230
+ response_ids = model.tokenizer(
231
+ msg["content"],
232
+ add_special_tokens=False,
233
+ return_tensors="pt",
234
+ ).input_ids[0]
235
+ tc = compute_contrast(
236
+ model=model,
237
+ context_a=context_a,
238
+ context_b=context_b,
239
+ response_ids=response_ids,
240
+ label_a=label_a,
241
+ label_b=label_b,
242
+ remote=remote,
243
+ )
244
+ if tc is not None:
245
+ msg["_contrast"] = tc
246
+ except Exception as exc:
247
+ st.warning(f"Token contrast recompute failed: {exc}")
248
+ msg.pop("_needs_contrast", None)
249
+ st.rerun()
250
+
251
+ for (
252
+ panel_state,
253
+ panel_log,
254
+ _panel_prompt,
255
+ panel_pending,
256
+ panel_edit_key,
257
+ _,
258
+ ) in panels:
259
+ _render_chat_window(
260
+ chat_log=panel_log,
261
+ messages=panel_state["messages"],
262
+ chat_state=panel_state,
263
+ edit_key=panel_edit_key,
264
+ pending_key=panel_pending,
265
+ show_contrast=contrast_enabled,
266
+ edit_column_ratio=(10, 1),
267
+ )
268
+
269
+ footer = st.container()
270
+ with footer:
271
+ exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
272
+ with exp_col:
273
+ if st.button(
274
+ "",
275
+ icon=":material/download:",
276
+ key=widget_key(context_key, "cmp_export"),
277
+ help="Export both chats",
278
+ ):
279
+ for side, panel_state, panel_prompt, panel_persona in (
280
+ ("left", left_state, left_prompt, left_persona),
281
+ ("right", right_state, right_prompt, right_persona),
282
+ ):
283
+ _save_chat_export_message(
284
+ model_name=model_name,
285
+ dataset_source=dataset_source,
286
+ persona_id=panel_persona.id,
287
+ persona_name=getattr(panel_persona, "name", None),
288
+ prompt_mode=panel_state["prompt_mode"],
289
+ system_prompt=panel_prompt,
290
+ messages=panel_state["messages"],
291
+ generation=_generation_dict(gen_kwargs, advanced_generation),
292
+ panel_label=side,
293
+ )
294
+ st.toast("Exported", icon=":material/check:")
295
+ with rst_col:
296
+ with st.popover(
297
+ "",
298
+ icon=":material/delete_sweep:",
299
+ help="Reset chat",
300
+ ):
301
+ if st.button(
302
+ "Reset left",
303
+ key=widget_key(context_key, "cmp_reset_left"),
304
+ ):
305
+ _reset_compare_panel(
306
+ left_state,
307
+ left_edit_key,
308
+ left_persona.id,
309
+ left_state["prompt_mode"],
310
+ left_prompt_key,
311
+ left_pending_key,
312
+ )
313
+ st.rerun()
314
+ if st.button(
315
+ "Reset right",
316
+ key=widget_key(context_key, "cmp_reset_right"),
317
+ ):
318
+ _reset_compare_panel(
319
+ right_state,
320
+ right_edit_key,
321
+ right_persona.id,
322
+ right_state["prompt_mode"],
323
+ right_prompt_key,
324
+ right_pending_key,
325
+ )
326
+ st.rerun()
327
+ if st.button(
328
+ "Reset both",
329
+ key=widget_key(context_key, "cmp_reset_both"),
330
+ type="primary",
331
+ ):
332
+ _reset_compare_panel(
333
+ left_state,
334
+ left_edit_key,
335
+ left_persona.id,
336
+ left_state["prompt_mode"],
337
+ left_prompt_key,
338
+ left_pending_key,
339
+ )
340
+ _reset_compare_panel(
341
+ right_state,
342
+ right_edit_key,
343
+ right_persona.id,
344
+ right_state["prompt_mode"],
345
+ right_prompt_key,
346
+ right_pending_key,
347
+ )
348
+ st.rerun()
349
+
350
+ user_prompt = st.chat_input(
351
+ "Ask both...",
352
+ key=widget_key(context_key, "cmp_input"),
353
+ )
354
+
355
+ if not user_prompt:
356
+ return
357
+
358
+ model = cached_model(model_name=model_name, remote=remote)
359
+
360
+ for panel_state, panel_log, _panel_prompt, _p_pending, _panel_edit_key, _ in panels:
361
+ panel_state["messages"].append({"role": "user", "content": user_prompt})
362
+ with panel_log:
363
+ _render_chat_message({"role": "user", "content": user_prompt})
364
+
365
+ # Snapshot contexts before the new assistant turn is appended (needed for contrast).
366
+ pre_gen_contexts = [
367
+ _build_chat_messages(panel_prompt, panel_state["messages"])
368
+ for panel_state, _panel_log, panel_prompt, _p_pending, _panel_edit_key, _ in panels
369
+ ]
370
+
371
+ results: list[ChatReply | Exception] = []
372
+ with st.spinner("Generating..."):
373
+ # Keep compare-mode generation sequential so both panels use the same
374
+ # model/session state safely.
375
+ for (
376
+ panel_state,
377
+ _panel_log,
378
+ panel_prompt,
379
+ _p_pending,
380
+ _panel_edit_key,
381
+ _,
382
+ ) in panels:
383
+ try:
384
+ results.append(
385
+ _generate_panel_reply(
386
+ model=model,
387
+ remote=remote,
388
+ panel_state=panel_state,
389
+ panel_prompt=panel_prompt,
390
+ gen_kwargs=gen_kwargs,
391
+ )
392
+ )
393
+ except Exception as exc:
394
+ results.append(exc)
395
+
396
+ valid_results: list[ChatReply | None] = []
397
+ for (
398
+ panel_state,
399
+ panel_log,
400
+ _panel_prompt,
401
+ _p_pending,
402
+ _panel_edit_key,
403
+ _,
404
+ ), result in zip(panels, results):
405
+ if isinstance(result, Exception):
406
+ with panel_log:
407
+ st.error(f"Generation failed: {result}")
408
+ panel_state["messages"].pop()
409
+ valid_results.append(None)
410
+ continue
411
+
412
+ panel_state["messages"].append({"role": "assistant", "content": result.text})
413
+ panel_state["past_key_values"] = result.past_key_values if not remote else None
414
+ valid_results.append(result)
415
+
416
+ # Compute contrastive token coloring when both panels succeeded.
417
+ if (
418
+ contrast_enabled
419
+ and len(valid_results) == 2
420
+ and all(r is not None and r.generated_ids is not None for r in valid_results)
421
+ ):
422
+ with st.spinner("Computing token contrast…"):
423
+ try:
424
+ tc_a, tc_b = compute_contrast_pair(
425
+ model=model,
426
+ context_a=pre_gen_contexts[0],
427
+ context_b=pre_gen_contexts[1],
428
+ response_ids_a=valid_results[0].generated_ids,
429
+ response_ids_b=valid_results[1].generated_ids,
430
+ label_a=persona_label(left_persona),
431
+ label_b=persona_label(right_persona),
432
+ remote=remote,
433
+ )
434
+ if tc_a is not None:
435
+ left_state["messages"][-1]["_contrast"] = tc_a
436
+ if tc_b is not None:
437
+ right_state["messages"][-1]["_contrast"] = tc_b
438
+ except Exception as exc:
439
+ st.warning(f"Token contrast failed: {exc}")
440
+
441
+ # Rerun so the newly appended turns are redrawn through the editable history
442
+ # renderer instead of only appearing in the one-off generation pass.
443
+ st.rerun()
tabs/extract.py CHANGED
@@ -111,8 +111,6 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
111
  st.info("Select at least one persona.")
112
  return
113
 
114
- max_questions = 0
115
-
116
  with st.expander("Advanced", expanded=False):
117
  st.caption("Filters")
118
 
 
111
  st.info("Select at least one persona.")
112
  return
113
 
 
 
114
  with st.expander("Advanced", expanded=False):
115
  st.caption("Filters")
116
 
utils/chat.py CHANGED
@@ -1,5 +1,5 @@
1
  import logging
2
- from contextlib import contextmanager
3
  from dataclasses import dataclass
4
  from typing import Literal
5
 
@@ -15,9 +15,8 @@ SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
15
  @dataclass
16
  class ChatReply:
17
  text: str
18
- prompt_tokens: int
19
- output_tokens: int
20
  past_key_values: object | None
 
21
 
22
 
23
  def resolve_system_prompt(
@@ -204,13 +203,10 @@ def generate_chat_reply(
204
 
205
  generated_ids = sequences[0, prompt_token_count:]
206
  text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
207
- output_tokens = int(sequences.shape[1] - prompt_token_count)
208
-
209
  return ChatReply(
210
  text=text,
211
- prompt_tokens=prompt_token_count,
212
- output_tokens=max(0, output_tokens),
213
  past_key_values=(
214
  getattr(generated, "past_key_values", None) if not remote else None
215
  ),
 
216
  )
 
1
  import logging
2
+ from contextlib import contextmanager, nullcontext
3
  from dataclasses import dataclass
4
  from typing import Literal
5
 
 
15
  @dataclass
16
  class ChatReply:
17
  text: str
 
 
18
  past_key_values: object | None
19
+ generated_ids: torch.Tensor | None = None
20
 
21
 
22
  def resolve_system_prompt(
 
203
 
204
  generated_ids = sequences[0, prompt_token_count:]
205
  text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
 
 
206
  return ChatReply(
207
  text=text,
 
 
208
  past_key_values=(
209
  getattr(generated, "past_key_values", None) if not remote else None
210
  ),
211
+ generated_ids=generated_ids.detach().cpu(),
212
  )
utils/chat_export.py CHANGED
@@ -30,6 +30,7 @@ def save_chat_export(
30
  system_prompt: Current system prompt text, if any.
31
  messages: Conversation messages without the system prompt.
32
  generation: Generation settings used for the chat.
 
33
 
34
  Returns:
35
  The path the export was written to.
@@ -55,7 +56,6 @@ def save_chat_export(
55
  get_artifacts_dir()
56
  / "chats"
57
  / "__".join(slugify(part) for part in model_name.split("/"))
58
- / slugify(dataset_source)
59
  / slugify(persona_id)
60
  )
61
  export_dir.mkdir(parents=True, exist_ok=True)
 
30
  system_prompt: Current system prompt text, if any.
31
  messages: Conversation messages without the system prompt.
32
  generation: Generation settings used for the chat.
33
+ panel_label: Optional side label (e.g. "left"/"right") for compare-mode exports.
34
 
35
  Returns:
36
  The path the export was written to.
 
56
  get_artifacts_dir()
57
  / "chats"
58
  / "__".join(slugify(part) for part in model_name.split("/"))
 
59
  / slugify(persona_id)
60
  )
61
  export_dir.mkdir(parents=True, exist_ok=True)
utils/contrast.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # WARNING: This module is mostly vibe-coded and still needs review.
2
+ # - Check that the model is run once normally for generation, with things being traced perhaps at the last step of generation using iter.last or something like that from the docs.
3
+ # - Then the model is run again with the entire conversation context from the other panel (the right one, or the left one, depending on which response is being scored at the moment). This then computes the probability differences and shows them.
4
+
5
+ """
6
+ Contrastive token-level log-probability comparison for compare mode.
7
+
8
+ For a pair of responses generated under different persona contexts, each token
9
+ gets a weight:
10
+
11
+ w(token) = log P(token | context_A) − log P(token | context_B)
12
+
13
+ Positive (red) → token is more characteristic of persona A.
14
+ Negative (blue) → token is more characteristic of persona B.
15
+ Near-zero (gray) → both personas would emit this token with similar likelihood.
16
+ """
17
+
18
+ from dataclasses import dataclass
19
+ from html import escape
20
+
21
+ import torch
22
+ from nnterp import StandardizedTransformer
23
+
24
+ from utils.chat import _format_generation_prompt
25
+
26
+
27
+ @dataclass
28
+ class TokenContrast:
29
+ tokens: list[str]
30
+ weights: list[float] # normalised to [-1, 1], used for coloring
31
+ raw_diffs: list[float] # unclipped log P(A) - log P(B) per token
32
+ label_a: str
33
+ label_b: str
34
+
35
+
36
+ # ── Weight computation ────────────────────────────────────────────────────────
37
+
38
+
39
+ def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
40
+ """
41
+ Clip at the 95th percentile of |diff| and scale to [-1, 1] so a few
42
+ high-magnitude tokens don't wash out everything else.
43
+ """
44
+ if len(diffs) < 2:
45
+ return diffs.tolist()
46
+ clip_val = max(torch.quantile(diffs.abs(), 0.95).item(), 0.3)
47
+ return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist()
48
+
49
+
50
+ def _decode_ids(tokenizer: object, ids: list[int]) -> str:
51
+ try:
52
+ return tokenizer.decode(
53
+ ids,
54
+ skip_special_tokens=False,
55
+ clean_up_tokenization_spaces=False,
56
+ )
57
+ except TypeError:
58
+ return tokenizer.decode(ids, skip_special_tokens=False)
59
+
60
+
61
+ def _strip_special_ids(
62
+ ids: torch.Tensor,
63
+ tokenizer: object,
64
+ ) -> tuple[torch.Tensor, torch.Tensor]:
65
+ """Return display ids and a mask that excludes special tokens."""
66
+ ids = ids.cpu()
67
+ special_ids = set(getattr(tokenizer, "all_special_ids", []) or [])
68
+ if not special_ids or ids.numel() == 0:
69
+ return ids, torch.ones(ids.shape[0], dtype=torch.bool)
70
+ keep = torch.tensor(
71
+ [tid.item() not in special_ids for tid in ids], dtype=torch.bool
72
+ )
73
+ return ids[keep], keep
74
+
75
+
76
+ def _prepare_trace_text(
77
+ tokenizer: object,
78
+ context_messages: list[dict[str, str]],
79
+ response_ids: torch.Tensor,
80
+ ) -> tuple[str, int, int]:
81
+ """Build the trace text and return ``(full_text, n_ctx, n_resp)``."""
82
+ context_prompt, _ = _format_generation_prompt(context_messages, tokenizer)
83
+ context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
84
+ response_text = _decode_ids(tokenizer, response_ids.tolist())
85
+ full_text = context_prompt + response_text
86
+ full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
87
+ expected_ids = torch.cat([context_ids, response_ids.cpu()])
88
+ if full_ids.tolist() != expected_ids.tolist():
89
+ raise ValueError(
90
+ "contrast trace text did not round-trip to the expected token ids"
91
+ )
92
+ n_ctx = len(context_ids)
93
+ n_resp = len(response_ids)
94
+ return full_text, n_ctx, n_resp
95
+
96
+
97
+ def _build_contrast(
98
+ tokenizer: object,
99
+ response_ids: torch.Tensor,
100
+ lp_a: torch.Tensor,
101
+ lp_b: torch.Tensor,
102
+ label_a: str,
103
+ label_b: str,
104
+ ) -> TokenContrast:
105
+ diffs = (lp_a - lp_b).cpu()
106
+ display_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
107
+ display_diffs = diffs[keep_mask]
108
+ return TokenContrast(
109
+ tokens=[_token_display(tokenizer, tid.item()) for tid in display_ids],
110
+ weights=_normalise_diffs(display_diffs),
111
+ raw_diffs=display_diffs.float().tolist(),
112
+ label_a=label_a,
113
+ label_b=label_b,
114
+ )
115
+
116
+
117
+ def _token_display(tokenizer: object, token_id: int) -> str:
118
+ """Render a single token id as normal decoded text."""
119
+ return _decode_ids(tokenizer, [token_id])
120
+
121
+
122
+ # Each spec: (key, full_text, n_ctx, n_resp, target_ids).
123
+ PassSpec = tuple[str, str, int, int, torch.Tensor]
124
+
125
+
126
+ def _score_passes(
127
+ model: StandardizedTransformer,
128
+ specs: list[PassSpec],
129
+ remote: bool,
130
+ ) -> dict[str, torch.Tensor]:
131
+ """
132
+ Run one forward pass per spec and return reduced per-token logprobs.
133
+
134
+ The log-softmax and target-pick happen *inside* the trace, so only the
135
+ reduced ``[n_resp]`` logprob vector per pass is shipped back — not the full
136
+ ``[1, seq, vocab]`` logits (which would be hundreds of MB per pass on NDIF).
137
+ """
138
+
139
+ def _score_pass(
140
+ full_text: str,
141
+ n_ctx: int,
142
+ n_resp: int,
143
+ target_ids: torch.Tensor,
144
+ ) -> torch.Tensor:
145
+ with torch.no_grad(), model.trace(full_text, remote=remote):
146
+ # logit at position i predicts token i+1, so response token j
147
+ # (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
148
+ resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
149
+ log_probs = torch.log_softmax(resp_logits, dim=-1)
150
+ targets = target_ids.to(log_probs.device).view(-1, 1)
151
+ picked = log_probs.gather(1, targets).view(-1)
152
+ out = picked.detach().cpu().save()
153
+
154
+ if hasattr(out, "value") and getattr(out, "value") is not None:
155
+ out = out.value
156
+ if not isinstance(out, torch.Tensor):
157
+ raise TypeError(
158
+ f"contrast score did not resolve to a tensor: {type(out)!r}"
159
+ )
160
+ return out.detach().cpu()
161
+
162
+ saved = [
163
+ _score_pass(full_text, n_ctx, n_resp, target_ids)
164
+ for _key, full_text, n_ctx, n_resp, target_ids in specs
165
+ ]
166
+
167
+ if len(saved) != len(specs):
168
+ raise RuntimeError(
169
+ f"contrast scoring returned {len(saved)} result(s) for {len(specs)} spec(s)"
170
+ )
171
+
172
+ return {spec[0]: tensor for spec, tensor in zip(specs, saved)}
173
+
174
+
175
+ def _specs_for_response(
176
+ tokenizer: object,
177
+ response_ids: torch.Tensor,
178
+ context_a: list[dict[str, str]],
179
+ context_b: list[dict[str, str]],
180
+ prefix: str,
181
+ ) -> list[PassSpec]:
182
+ """Build the (under_a, under_b) pass specs for a single response."""
183
+ text_a, n_ctx_a, n_resp = _prepare_trace_text(tokenizer, context_a, response_ids)
184
+ text_b, n_ctx_b, _ = _prepare_trace_text(tokenizer, context_b, response_ids)
185
+ return [
186
+ (f"{prefix}_under_a", text_a, n_ctx_a, n_resp, response_ids),
187
+ (f"{prefix}_under_b", text_b, n_ctx_b, n_resp, response_ids),
188
+ ]
189
+
190
+
191
+ def compute_contrast(
192
+ model: StandardizedTransformer,
193
+ context_a: list[dict[str, str]],
194
+ context_b: list[dict[str, str]],
195
+ response_ids: torch.Tensor,
196
+ label_a: str,
197
+ label_b: str,
198
+ remote: bool = False,
199
+ ) -> "TokenContrast | None":
200
+ """Compute per-token contrast weights for a single response (2 forward passes)."""
201
+ tokenizer = model.tokenizer
202
+ if response_ids.numel() == 0:
203
+ return None
204
+
205
+ specs = _specs_for_response(tokenizer, response_ids, context_a, context_b, "r")
206
+ out = _score_passes(model, specs, remote)
207
+ return _build_contrast(
208
+ tokenizer, response_ids, out["r_under_a"], out["r_under_b"], label_a, label_b
209
+ )
210
+
211
+
212
+ def compute_contrast_pair(
213
+ model: StandardizedTransformer,
214
+ context_a: list[dict[str, str]],
215
+ context_b: list[dict[str, str]],
216
+ response_ids_a: torch.Tensor,
217
+ response_ids_b: torch.Tensor,
218
+ label_a: str,
219
+ label_b: str,
220
+ remote: bool = False,
221
+ ) -> tuple["TokenContrast | None", "TokenContrast | None"]:
222
+ """
223
+ Compute contrast weights for both panel responses (up to 4 forward passes, run remotely when ``remote`` is true).
224
+ """
225
+ tokenizer = model.tokenizer
226
+ if response_ids_a.numel() == 0 and response_ids_b.numel() == 0:
227
+ return None, None
228
+
229
+ specs: list[PassSpec] = []
230
+ if response_ids_a.numel() > 0:
231
+ specs += _specs_for_response(
232
+ tokenizer, response_ids_a, context_a, context_b, "a"
233
+ )
234
+ if response_ids_b.numel() > 0:
235
+ specs += _specs_for_response(
236
+ tokenizer, response_ids_b, context_a, context_b, "b"
237
+ )
238
+
239
+ out = _score_passes(model, specs, remote)
240
+
241
+ def _build(resp_ids: torch.Tensor, prefix: str) -> "TokenContrast | None":
242
+ k_a, k_b = f"{prefix}_under_a", f"{prefix}_under_b"
243
+ if resp_ids.numel() == 0 or k_a not in out or k_b not in out:
244
+ return None
245
+ return _build_contrast(
246
+ tokenizer, resp_ids, out[k_a], out[k_b], label_a, label_b
247
+ )
248
+
249
+ return _build(response_ids_a, "a"), _build(response_ids_b, "b")
250
+
251
+
252
+ # ── HTML rendering ────────────────────────────────────────────────────────────
253
+
254
+
255
+ def _weight_to_bg(w: float) -> str:
256
+ """Map a normalised weight in [-1, 1] to a CSS rgba background color."""
257
+ w = max(-1.0, min(1.0, w))
258
+ alpha = abs(w) * 0.5 # cap at 0.5 opacity so text stays readable
259
+ if w > 0.05:
260
+ return f"rgba(210,60,60,{alpha:.3f})"
261
+ if w < -0.05:
262
+ return f"rgba(50,110,210,{alpha:.3f})"
263
+ return "rgba(0,0,0,0)"
264
+
265
+
266
+ _CONTRAST_CSS = (
267
+ "<style>"
268
+ ".contrast-tok{position:relative;border-radius:2px;padding:0 1px;"
269
+ "cursor:default;white-space:pre;}"
270
+ ".contrast-tok>.contrast-tip{display:none;position:absolute;bottom:100%;"
271
+ "left:50%;transform:translateX(-50%);margin-bottom:4px;padding:2px 6px;"
272
+ "border-radius:3px;background:#222;color:#eee;font-size:0.72em;"
273
+ "font-family:ui-monospace,monospace;white-space:nowrap;pointer-events:none;"
274
+ "z-index:10;box-shadow:0 2px 6px rgba(0,0,0,0.3);}"
275
+ ".contrast-tok:hover>.contrast-tip{display:block;}"
276
+ "</style>"
277
+ )
278
+
279
+
280
+ def render_contrast_html(result: TokenContrast) -> str:
281
+ """
282
+ Render each token with a colored background reflecting how A- or B-specific
283
+ it is, with a hover tooltip showing the raw Δlog P, plus a legend.
284
+ """
285
+ spans: list[str] = []
286
+ for token, weight, raw in zip(result.tokens, result.weights, result.raw_diffs):
287
+ bg = _weight_to_bg(weight)
288
+ tip = escape(f"Δlog P(A−B): {raw:+.3f}")
289
+ text = escape(token)
290
+ spans.append(
291
+ f'<span class="contrast-tok" style="background:{bg};">'
292
+ f'{text}<span class="contrast-tip">{tip}</span></span>'
293
+ )
294
+
295
+ la = escape(result.label_a)
296
+ lb = escape(result.label_b)
297
+
298
+ return (
299
+ _CONTRAST_CSS + '<div style="font-family:inherit;line-height:1.75;'
300
+ 'white-space:pre-wrap;word-break:break-word;padding:2px 0 6px 0;">'
301
+ + "".join(spans)
302
+ + '<div style="margin-top:10px;font-size:0.72em;color:#888;'
303
+ + 'display:flex;gap:12px;flex-wrap:wrap;">'
304
+ + f'<span><span style="background:rgba(210,60,60,0.45);'
305
+ + f'padding:1px 6px;border-radius:2px;">&thinsp;</span>&nbsp;{la}</span>'
306
+ + f'<span><span style="background:rgba(50,110,210,0.45);'
307
+ + f'padding:1px 6px;border-radius:2px;">&thinsp;</span>&nbsp;{lb}</span>'
308
+ + '<span style="color:#aaa;">gray = shared by both</span>'
309
+ + "</div>"
310
+ + "</div>"
311
+ )
utils/helpers.py CHANGED
@@ -16,8 +16,6 @@ MODE_LABELS = list(VARIANT_LABELS.values())
16
  # Reverse lookup: label -> key
17
  MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
18
 
19
- VISIBLE_MESSAGE_COUNT = 5
20
-
21
  DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
22
  ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
23
 
 
16
  # Reverse lookup: label -> key
17
  MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
18
 
 
 
19
  DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
20
  ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
21
 
utils/runtime.py CHANGED
@@ -7,33 +7,62 @@ logger = logging.getLogger(__name__)
7
 
8
  @st.cache_data(show_spinner=False, ttl=30)
9
  def list_remote_models() -> list[str]:
10
- """Return the NDIF language models that are currently running."""
 
 
 
 
 
 
 
 
 
11
 
12
  import nnsight
13
 
14
  try:
15
- status = nnsight.ndif_status()
16
  except Exception:
17
  logger.warning("Failed to fetch NDIF status", exc_info=True)
18
  return []
19
 
20
  model_names: list[str] = []
 
21
 
22
- for entry in status.values():
23
- if not isinstance(entry, dict):
24
  continue
25
- if entry.get("model_class") not in {"LanguageModel", "StandardizedTransformer"}:
 
 
 
26
  continue
27
 
28
- state = entry.get("state")
29
- state_name = getattr(state, "name", None) or getattr(state, "value", None)
30
- if state_name != "RUNNING":
31
- continue
 
 
32
 
33
- repo_id = entry.get("repo_id")
 
 
 
 
 
 
 
34
  if isinstance(repo_id, str):
35
  model_names.append(repo_id)
36
 
 
 
 
 
 
 
 
37
  return sorted(set(model_names))
38
 
39
 
 
7
 
8
@st.cache_data(show_spinner=False, ttl=30)
def list_remote_models() -> list[str]:
    """Return the NDIF language models that are currently running.

    Parses the raw NDIF response directly instead of going through
    ``nnsight.ndif_status()`` because that call crashes whenever NDIF reports
    any deployment with an ``application_state`` that isn't in nnsight's
    ``ModelStatus`` enum (e.g. ``UNHEALTHY``) — one bad deployment poisons
    the whole response. See nnsight 0.6.3 ``ndif.py::status``.

    Parsing is tolerant for the same reason: a payload that is not a mapping,
    a missing or non-mapping ``deployments`` entry, or a deployment whose
    ``model_key`` is not a string is skipped or degraded rather than raising.

    Returns:
        Sorted, de-duplicated repo ids of deployments whose model class is
        ``LanguageModel`` or ``StandardizedTransformer`` and whose
        ``application_state`` is ``RUNNING``. An empty list when the status
        request fails or the payload has no usable deployments.
    """

    import json

    import nnsight

    try:
        raw = nnsight.ndif_status(raw=True)
    except Exception:
        # Best-effort UI helper: any failure talking to NDIF means "no models".
        logger.warning("Failed to fetch NDIF status", exc_info=True)
        return []

    deployments = raw.get("deployments") if isinstance(raw, dict) else None
    if not isinstance(deployments, dict):
        if raw:
            logger.warning(
                "NDIF status payload has no usable 'deployments' mapping "
                "(payload type: %s)",
                type(raw).__name__,
            )
        return []

    active_levels = {"HOT", "WARM"}
    language_model_classes = {"LanguageModel", "StandardizedTransformer"}
    known_states = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}

    model_names: list[str] = []
    bad_states: list[tuple[str, str]] = []  # (repo_id_or_key, application_state)

    for value in deployments.values():
        if not isinstance(value, dict):
            continue
        # Only deployments that are HOT/WARM or carry a schedule are considered.
        if (
            value.get("deployment_level") not in active_levels
            and "schedule" not in value
        ):
            continue

        model_key = value.get("model_key")
        if not isinstance(model_key, str):
            # A null/non-string key must not abort the whole listing.
            model_key = ""

        # model_key is split as "<dotted.path.ClassName>:<json payload>".
        class_path, _, payload = model_key.partition(":")
        model_class = class_path.rsplit(".", 1)[-1]
        # Without a ":" the whole key is treated as the payload.
        payload = payload if ":" in model_key else model_key
        try:
            parsed = json.loads(payload)
        except (ValueError, RecursionError):
            repo_id = model_key
        else:
            # Non-object JSON has no repo_id; fall back to the raw key.
            repo_id = parsed.get("repo_id") if isinstance(parsed, dict) else model_key

        state = value.get("application_state", "NOT DEPLOYED")
        if state not in known_states:
            bad_states.append((repo_id or model_key, state))

        if model_class not in language_model_classes:
            continue
        if state != "RUNNING":
            continue
        if isinstance(repo_id, str):
            model_names.append(repo_id)

    if bad_states:
        logger.warning(
            "NDIF reported deployments with unexpected application_state values "
            "(nnsight's ModelStatus enum may not know about these): %s",
            bad_states,
        )

    return sorted(set(model_names))
67
 
68