Jac-Zac committed on
Commit
f4259c0
·
1 Parent(s): 5bf7fd5

Cleaned up code with the new updates

Browse files
Files changed (10) hide show
  1. README.md +1 -4
  2. app.py +1 -1
  3. tabs/chat.py +327 -281
  4. tabs/compare.py +253 -166
  5. tabs/extract.py +32 -34
  6. utils/artifacts.py +0 -244
  7. utils/chat.py +0 -1
  8. utils/chat_export.py +8 -48
  9. utils/datasets.py +5 -1
  10. utils/helpers.py +15 -9
README.md CHANGED
@@ -24,13 +24,10 @@ persona-ui/
24
  │ ├── compare.py # Activation comparison tab
25
  │ └── extract.py # Extraction tab
26
  └── utils/
27
- ├── artifacts.py # Load saved activations metadata
28
  ├── chat.py # Chat generation logic
29
  ├── chat_export.py # Export chat logs to JSON
30
  ├── datasets.py # Dataset loader wrapper
31
- ├── extraction.py # Extraction orchestration
32
  ├── helpers.py # UI labels and slug helpers
33
- ├── local_dataset.py # Local JSONL dataset parsing
34
  └── runtime.py # Model caching and NDIF queries
35
  ```
36
 
@@ -81,7 +78,7 @@ HF_HOME=... # Optional: HuggingFace cache directory
81
  ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
82
  ```
83
 
84
- The app picks up this file automatically via `load_env()` on startup.
85
 
86
  ## Saved Artifacts
87
 
 
24
  │ ├── compare.py # Activation comparison tab
25
  │ └── extract.py # Extraction tab
26
  └── utils/
 
27
  ├── chat.py # Chat generation logic
28
  ├── chat_export.py # Export chat logs to JSON
29
  ├── datasets.py # Dataset loader wrapper
 
30
  ├── helpers.py # UI labels and slug helpers
 
31
  └── runtime.py # Model caching and NDIF queries
32
  ```
33
 
 
78
  ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
79
  ```
80
 
81
+ The app picks up this file automatically via `load_dotenv()` on startup.
82
 
83
  ## Saved Artifacts
84
 
app.py CHANGED
@@ -26,7 +26,7 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
26
  if st.button(
27
  tab_name,
28
  key=f"sidebar__tab__{tab_name.lower()}",
29
- use_container_width=True,
30
  type="primary" if is_selected else "secondary",
31
  icon=icon,
32
  ):
 
26
  if st.button(
27
  tab_name,
28
  key=f"sidebar__tab__{tab_name.lower()}",
29
+ width="stretch",
30
  type="primary" if is_selected else "secondary",
31
  icon=icon,
32
  ):
tabs/chat.py CHANGED
@@ -1,10 +1,15 @@
1
- import threading
2
  from concurrent.futures import ThreadPoolExecutor
3
- from contextlib import nullcontext
4
 
5
  import streamlit as st
 
6
 
7
- from state import chat_session_key, get_chat_state, reset_chat_state
 
 
 
 
 
8
  from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
9
  from utils.chat_export import save_chat_export
10
  from utils.datasets import load_dataset
@@ -12,14 +17,12 @@ from utils.helpers import (
12
  MODE_LABEL_TO_KEY,
13
  MODE_LABELS,
14
  VARIANT_LABELS,
 
15
  persona_label,
16
  widget_key,
17
  )
18
  from utils.runtime import cached_model
19
 
20
- _VISIBLE_MESSAGE_COUNT = 5
21
- _model_lock = threading.Lock()
22
-
23
 
24
  def _render_chat_message(message: dict[str, str]) -> None:
25
  if not message.get("content"):
@@ -33,6 +36,21 @@ def _clear_chat_ui_state(*keys: str) -> None:
33
  st.session_state.pop(key, None)
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
37
  return {
38
  "max_new_tokens": int(gen_kwargs["max_new_tokens"]),
@@ -46,186 +64,146 @@ def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, o
46
  }
47
 
48
 
49
- # ── Compare mode helpers ───────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
50
 
51
-
52
- def _panel_state(panel_key: str) -> dict:
53
- """Get or initialise compare-panel chat state stored in session_state."""
54
- if panel_key not in st.session_state:
55
- st.session_state[panel_key] = {
56
- "messages": [],
57
- "persona_id": None,
58
- "prompt_mode": "templated",
59
- "past_key_values": None,
60
- }
61
- return st.session_state[panel_key]
62
-
63
-
64
- def _render_compare_panel(
65
- side: str,
66
- context_key: str,
67
- personas: list,
68
- remote: bool,
69
- model_name: str,
70
- dataset_source: str,
71
- gen_kwargs: dict,
72
- advanced_generation: bool,
73
- ) -> dict:
74
- """Render persona/prompt controls + chat log for one compare panel.
75
-
76
- Returns a dict with keys needed by the generation step:
77
- panel_key, state, active_system_prompt, selected_persona, chat_log
78
- """
79
- panel_key = widget_key(context_key, f"cmp_{side}")
80
- state = _panel_state(panel_key)
81
-
82
- # ── Per-panel selectors ──────────────────────────────────────────────────
83
- p_col, m_col = st.columns([3, 2])
84
  with p_col:
85
  selected_index = next(
86
- (i for i, p in enumerate(personas) if p.id == state["persona_id"]), 0
87
  )
88
  selected_persona = st.selectbox(
89
  "Persona",
90
  options=personas,
91
  index=selected_index,
92
  format_func=persona_label,
93
- key=widget_key(panel_key, "persona"),
94
  )
95
  with m_col:
96
- current_label = VARIANT_LABELS.get(state["prompt_mode"], "None")
97
  prompt_mode_label = st.selectbox(
98
  "Prompt",
99
  options=MODE_LABELS,
100
  index=MODE_LABELS.index(current_label),
101
- key=widget_key(panel_key, "prompt_mode"),
102
  )
103
  prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]
104
-
105
- # Reset state when persona or mode changes.
106
  changed = (
107
- state["persona_id"] != selected_persona.id
108
- or state["prompt_mode"] != prompt_mode
109
  )
110
- if changed:
111
- state["messages"] = []
112
- state["past_key_values"] = None
113
- state["persona_id"] = selected_persona.id
114
- state["prompt_mode"] = prompt_mode
115
- _clear_chat_ui_state(
116
- widget_key(panel_key, "custom_prompt"),
117
- widget_key(panel_key, "show_all"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  )
 
119
 
120
- # ── System prompt ────────────────────────────────────────────────────────
121
- active_system_prompt = resolve_system_prompt(
122
- persona=selected_persona, mode=prompt_mode
123
- )
124
- custom_prompt_key = widget_key(panel_key, "custom_prompt")
125
- if prompt_mode != "empty":
126
- if custom_prompt_key not in st.session_state:
127
- st.session_state[custom_prompt_key] = active_system_prompt
128
- with st.expander("Edit prompt", expanded=False):
129
- active_system_prompt = (
130
- st.text_area(
131
- "prompt",
132
- key=custom_prompt_key,
133
- height=150,
134
- label_visibility="collapsed",
135
- )
136
- or None
137
- )
138
 
139
- export_success_message: str | None = None
140
- action_col1, action_col2 = st.columns(2)
141
- with action_col1:
142
- if st.button(
143
- "Export chat",
144
- key=widget_key(panel_key, "export_chat"),
145
- use_container_width=True,
146
- ):
147
- export_path = save_chat_export(
148
- model_name=model_name,
149
- dataset_source=dataset_source,
150
- persona_id=selected_persona.id,
151
- persona_name=getattr(selected_persona, "name", None),
152
- panel_label=side,
153
- prompt_mode=prompt_mode,
154
- system_prompt=active_system_prompt,
155
- messages=state["messages"],
156
- generation=_generation_dict(gen_kwargs, advanced_generation),
157
- )
158
- export_success_message = f"Saved chat export to {export_path}"
159
- with action_col2:
160
- if st.button(
161
- "Reset chat",
162
- key=widget_key(panel_key, "reset"),
163
- use_container_width=True,
164
- type="secondary",
165
- ):
166
- state["messages"] = []
167
- state["past_key_values"] = None
168
- _clear_chat_ui_state(
169
- widget_key(panel_key, "custom_prompt"),
170
- widget_key(panel_key, "show_all"),
171
- )
172
- st.rerun()
173
 
174
- if export_success_message:
175
- st.success(export_success_message)
176
-
177
- # ── Message history ──────────────────────────────────────────────────────
178
- show_all_key = widget_key(panel_key, "show_all")
179
- messages = state["messages"]
180
- if len(messages) > _VISIBLE_MESSAGE_COUNT and not st.session_state.get(
181
- show_all_key, False
182
- ):
183
- hidden_count = len(messages) - _VISIBLE_MESSAGE_COUNT
184
- if st.button(
185
- f"Show earlier ({hidden_count} hidden)",
186
- key=widget_key(panel_key, "show_all_btn"),
187
  ):
188
- st.session_state[show_all_key] = True
189
- st.rerun()
190
- visible = messages[-_VISIBLE_MESSAGE_COUNT:]
191
- else:
192
- visible = messages
 
 
 
 
 
193
 
194
- chat_log = st.container()
195
- with chat_log:
196
- for msg in visible:
197
- _render_chat_message(msg)
198
 
199
- return {
200
- "panel_key": panel_key,
201
- "state": state,
202
- "active_system_prompt": active_system_prompt,
203
- "selected_persona": selected_persona,
204
- "chat_log": chat_log,
205
- }
206
 
207
 
208
- def _generate_for_panel(
209
- panel: dict,
210
- model,
211
- remote: bool,
212
- gen_kwargs: dict,
213
- ) -> ChatReply:
214
- """Run generate_chat_reply for one compare panel. Thread-safe."""
215
- messages = []
216
- if panel["active_system_prompt"]:
217
- messages.append({"role": "system", "content": panel["active_system_prompt"]})
218
- messages.extend(panel["state"]["messages"])
219
-
220
- ctx = nullcontext() if remote else _model_lock
221
- with ctx:
222
- return generate_chat_reply(
223
- model=model,
224
- messages=messages,
225
- remote=remote,
226
- past_key_values=panel["state"]["past_key_values"],
227
- **gen_kwargs,
228
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
 
231
  def _render_compare_mode(
@@ -233,35 +211,90 @@ def _render_compare_mode(
233
  model_name: str,
234
  context_key: str,
235
  dataset_source: str,
236
- personas: list,
237
  gen_kwargs: dict,
238
  advanced_generation: bool,
239
  ) -> None:
240
  """Render the full side-by-side comparison UI."""
241
  left_col, right_col = st.columns(2)
242
 
243
- with left_col:
244
- left = _render_compare_panel(
245
- "left",
246
- context_key,
 
 
 
 
 
 
247
  personas,
248
- remote,
249
- model_name,
250
- dataset_source,
251
- gen_kwargs,
252
- advanced_generation,
253
  )
254
- with right_col:
255
- right = _render_compare_panel(
256
- "right",
257
- context_key,
258
- personas,
259
- remote,
260
- model_name,
261
- dataset_source,
262
- gen_kwargs,
263
- advanced_generation,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  )
 
 
 
 
 
 
265
 
266
  user_prompt = st.chat_input(
267
  "Ask both...",
@@ -271,43 +304,73 @@ def _render_compare_mode(
271
  return
272
 
273
  model = cached_model(model_name=model_name, remote=remote)
274
- panels = [(left, left_col), (right, right_col)]
 
 
 
275
 
276
- for panel, col in panels:
277
- panel["state"]["messages"].append({"role": "user", "content": user_prompt})
278
- with col:
279
- with panel["chat_log"]:
280
- _render_chat_message({"role": "user", "content": user_prompt})
281
 
282
- # Generate both responses in parallel (remote: truly concurrent; local: serialised via lock).
283
  with st.spinner("Generating..."):
284
- with ThreadPoolExecutor(max_workers=2) as executor:
285
- futures = [
286
- executor.submit(_generate_for_panel, panel, model, remote, gen_kwargs)
287
- for panel, col in panels
288
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  results = []
290
- for future in futures:
291
  try:
292
- results.append(future.result())
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  except Exception as exc:
294
  results.append(exc)
295
 
296
- for (panel, col), result in zip(panels, results):
297
  if isinstance(result, Exception):
298
- with col:
299
- with panel["chat_log"]:
300
- st.error(f"Generation failed: {result}")
301
- panel["state"]["messages"].pop()
302
  continue
303
 
304
- panel["state"]["messages"].append({"role": "assistant", "content": result.text})
305
- panel["state"]["past_key_values"] = (
306
- result.past_key_values if not remote else None
307
- )
308
- with col:
309
- with panel["chat_log"]:
310
- _render_chat_message({"role": "assistant", "content": result.text})
311
 
312
 
313
  # ── Main tab entry point ───────────────────────────────────────────────────────
@@ -465,6 +528,12 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
465
  # ── Single-chat mode ──────────────────────────────────────────────────────
466
  persona_select_key = widget_key(context_key, "persona_select")
467
  prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
 
 
 
 
 
 
468
 
469
  col1, col2 = st.columns([2, 1])
470
  with col1:
@@ -481,66 +550,35 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
481
  )
482
  with col2:
483
  current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
484
- prompt_mode_label = st.selectbox(
485
  "Prompt",
486
  options=MODE_LABELS,
487
  index=MODE_LABELS.index(current_mode_label),
488
  key=prompt_mode_select_key,
489
  )
490
- prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]
491
 
492
  active_system_prompt = resolve_system_prompt(
493
  persona=selected_persona,
494
  mode=prompt_mode,
495
  )
496
 
497
- chat_input_key = widget_key(context_key, "chat_input")
498
- show_all_key = widget_key(context_key, "show_all_messages")
499
- custom_prompt_key = widget_key(context_key, "custom_system_prompt")
500
- pending_key = widget_key(context_key, "pending_prompt")
501
- export_success_message: str | None = None
502
-
503
- action_col1, action_col2 = st.columns(2)
504
- with action_col1:
505
- if st.button("Reset chat", use_container_width=True, type="secondary"):
506
- reset_chat_state(model_name, remote, dataset_source)
507
- _clear_chat_ui_state(
508
- chat_input_key,
509
- show_all_key,
510
- custom_prompt_key,
511
- pending_key,
512
- )
513
- st.rerun()
514
- with action_col2:
515
- if st.button("Export chat", use_container_width=True):
516
- export_path = save_chat_export(
517
- model_name=model_name,
518
- dataset_source=dataset_source,
519
- persona_id=selected_persona.id,
520
- persona_name=getattr(selected_persona, "name", None),
521
- prompt_mode=prompt_mode,
522
- system_prompt=active_system_prompt,
523
- messages=chat_state["messages"],
524
- generation=_generation_dict(gen_kwargs, advanced_generation),
525
- )
526
- export_success_message = f"Saved chat export to {export_path}"
527
-
528
- if export_success_message:
529
- st.success(export_success_message)
530
-
531
  changed_context = (
532
  chat_state["persona_id"] != selected_persona.id
533
  or chat_state["prompt_mode"] != prompt_mode
534
  )
535
  if changed_context:
536
  had_history = bool(chat_state["messages"])
537
- chat_state["persona_id"] = selected_persona.id
538
- chat_state["prompt_mode"] = prompt_mode
539
- reset_chat_state(model_name, remote, dataset_source)
540
- _clear_chat_ui_state(
 
 
 
541
  chat_input_key,
542
  show_all_key,
543
- custom_prompt_key,
544
  pending_key,
545
  )
546
  if had_history:
@@ -548,40 +586,51 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
548
 
549
  chat_log = st.container()
550
 
551
- with chat_log:
552
- # System prompt as first item in conversation — collapsed by default, editable.
553
- if prompt_mode != "empty":
554
- if custom_prompt_key not in st.session_state:
555
- st.session_state[custom_prompt_key] = active_system_prompt
556
- with st.expander("Edit prompt", expanded=False):
557
- active_system_prompt = (
558
- st.text_area(
559
- "Prompt",
560
- key=custom_prompt_key,
561
- height=200,
562
- label_visibility="collapsed",
563
- )
564
- or None
565
- )
566
 
567
- # Collapse older messages, show only the most recent ones.
568
- messages = chat_state["messages"]
569
- if len(messages) > _VISIBLE_MESSAGE_COUNT and not st.session_state.get(
570
- show_all_key, False
571
- ):
572
- hidden_count = len(messages) - _VISIBLE_MESSAGE_COUNT
573
- if st.button(
574
- f"Show earlier messages ({hidden_count} hidden)",
575
- key=widget_key(context_key, "show_all_btn"),
576
- ):
577
- st.session_state[show_all_key] = True
578
- st.rerun()
579
- visible_messages = messages[-_VISIBLE_MESSAGE_COUNT:]
580
- else:
581
- visible_messages = messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
- for message in visible_messages:
584
- _render_chat_message(message)
 
 
 
 
 
585
 
586
  user_prompt = st.chat_input(
587
  "Ask something...",
@@ -598,10 +647,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
598
  if not st.session_state.pop(pending_key, False):
599
  return
600
 
601
- messages = []
602
- if active_system_prompt:
603
- messages.append({"role": "system", "content": active_system_prompt})
604
- messages.extend(chat_state["messages"])
605
 
606
  with st.spinner("Generating reply..."):
607
  model = cached_model(model_name=model_name, remote=remote)
 
 
1
  from concurrent.futures import ThreadPoolExecutor
2
+ from typing import Any
3
 
4
  import streamlit as st
5
+ from persona_data.synth_persona import PersonaData
6
 
7
+ from state import (
8
+ _default_chat_state,
9
+ chat_session_key,
10
+ get_chat_state,
11
+ reset_chat_state,
12
+ )
13
  from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
14
  from utils.chat_export import save_chat_export
15
  from utils.datasets import load_dataset
 
17
  MODE_LABEL_TO_KEY,
18
  MODE_LABELS,
19
  VARIANT_LABELS,
20
+ VISIBLE_MESSAGE_COUNT,
21
  persona_label,
22
  widget_key,
23
  )
24
  from utils.runtime import cached_model
25
 
 
 
 
26
 
27
  def _render_chat_message(message: dict[str, str]) -> None:
28
  if not message.get("content"):
 
36
  st.session_state.pop(key, None)
37
 
38
 
39
+ def _reset_single_chat_context(
40
+ model_name: str,
41
+ remote: bool,
42
+ dataset_source: str,
43
+ chat_state: dict[str, object],
44
+ persona_id: str,
45
+ prompt_mode: str,
46
+ *ui_keys: str,
47
+ ) -> None:
48
+ reset_chat_state(model_name, remote, dataset_source)
49
+ chat_state["persona_id"] = persona_id
50
+ chat_state["prompt_mode"] = prompt_mode
51
+ _clear_chat_ui_state(*ui_keys)
52
+
53
+
54
  def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
55
  return {
56
  "max_new_tokens": int(gen_kwargs["max_new_tokens"]),
 
64
  }
65
 
66
 
67
+ def _render_persona_prompt_controls(
68
+ personas: list[PersonaData],
69
+ current_persona_id: str | None,
70
+ current_prompt_mode: str,
71
+ persona_key: str,
72
+ prompt_key: str,
73
+ column_widths: tuple[int, int] = (3, 2),
74
+ ) -> tuple[PersonaData, str, bool]:
75
+ """Render persona and prompt selectors, returning the selected values."""
76
 
77
+ p_col, m_col = st.columns(list(column_widths))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  with p_col:
79
  selected_index = next(
80
+ (i for i, p in enumerate(personas) if p.id == current_persona_id), 0
81
  )
82
  selected_persona = st.selectbox(
83
  "Persona",
84
  options=personas,
85
  index=selected_index,
86
  format_func=persona_label,
87
+ key=persona_key,
88
  )
89
  with m_col:
90
+ current_label = VARIANT_LABELS.get(current_prompt_mode, "None")
91
  prompt_mode_label = st.selectbox(
92
  "Prompt",
93
  options=MODE_LABELS,
94
  index=MODE_LABELS.index(current_label),
95
+ key=prompt_key,
96
  )
97
  prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]
 
 
98
  changed = (
99
+ current_persona_id != selected_persona.id or current_prompt_mode != prompt_mode
 
100
  )
101
+ return selected_persona, prompt_mode, changed
102
+
103
+
104
+ def _render_system_prompt_editor(
105
+ prompt_key: str,
106
+ prompt_mode: str,
107
+ active_system_prompt: str | None,
108
+ *,
109
+ height: int,
110
+ label: str = "Prompt",
111
+ ) -> str | None:
112
+ """Render the editable system prompt area for a chat panel."""
113
+
114
+ if prompt_mode == "empty":
115
+ return active_system_prompt
116
+
117
+ if prompt_key not in st.session_state:
118
+ st.session_state[prompt_key] = active_system_prompt or ""
119
+
120
+ with st.expander("Edit prompt", expanded=False):
121
+ edited_prompt = (
122
+ st.text_area(
123
+ label,
124
+ key=prompt_key,
125
+ height=height,
126
+ label_visibility="collapsed",
127
+ )
128
+ or None
129
  )
130
+ return edited_prompt
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ def _render_chat_window(
134
+ *,
135
+ chat_log: Any,
136
+ messages: list[dict[str, str]],
137
+ show_all_key: str,
138
+ show_all_btn_key: str,
139
+ show_earlier_label: str,
140
+ ) -> Any:
141
+ """Render the visible chat history inside one container."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ with chat_log:
144
+ if len(messages) > VISIBLE_MESSAGE_COUNT and not st.session_state.get(
145
+ show_all_key, False
 
 
 
 
 
 
 
 
 
 
146
  ):
147
+ hidden_count = len(messages) - VISIBLE_MESSAGE_COUNT
148
+ if st.button(
149
+ f"{show_earlier_label} ({hidden_count} hidden)",
150
+ key=show_all_btn_key,
151
+ ):
152
+ st.session_state[show_all_key] = True
153
+ st.rerun()
154
+ visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
155
+ else:
156
+ visible_messages = messages
157
 
158
+ for message in visible_messages:
159
+ _render_chat_message(message)
 
 
160
 
161
+ return chat_log
 
 
 
 
 
 
162
 
163
 
164
+ def _build_chat_messages(
165
+ system_prompt: str | None,
166
+ messages: list[dict[str, str]],
167
+ ) -> list[dict[str, str]]:
168
+ return (
169
+ [{"role": "system", "content": system_prompt}] if system_prompt else []
170
+ ) + messages
171
+
172
+
173
+ def _save_chat_export_message(
174
+ *,
175
+ model_name: str,
176
+ dataset_source: str,
177
+ persona_id: str,
178
+ persona_name: str | None,
179
+ prompt_mode: str,
180
+ system_prompt: str | None,
181
+ messages: list[dict[str, str]],
182
+ generation: dict[str, object],
183
+ panel_label: str | None = None,
184
+ ) -> str:
185
+ export_path = save_chat_export(
186
+ model_name=model_name,
187
+ dataset_source=dataset_source,
188
+ persona_id=persona_id,
189
+ persona_name=persona_name,
190
+ panel_label=panel_label,
191
+ prompt_mode=prompt_mode,
192
+ system_prompt=system_prompt,
193
+ messages=messages,
194
+ generation=generation,
195
+ )
196
+ return f"Saved chat export to {export_path}"
197
+
198
+
199
+ # ── Compare mode helpers ───────────────────────────────────────────────────────
200
+
201
+
202
+ def _panel_state(panel_key: str) -> dict:
203
+ """Get or initialise compare-panel chat state stored in session_state."""
204
+ if panel_key not in st.session_state:
205
+ st.session_state[panel_key] = _default_chat_state()
206
+ return st.session_state[panel_key]
207
 
208
 
209
  def _render_compare_mode(
 
211
  model_name: str,
212
  context_key: str,
213
  dataset_source: str,
214
+ personas: list[PersonaData],
215
  gen_kwargs: dict,
216
  advanced_generation: bool,
217
  ) -> None:
218
  """Render the full side-by-side comparison UI."""
219
  left_col, right_col = st.columns(2)
220
 
221
+ def render_panel(side: str, column) -> tuple[dict[str, object], Any, str | None]:
222
+ panel_key = widget_key(context_key, f"cmp_{side}")
223
+ state = st.session_state.get(panel_key)
224
+ if state is None:
225
+ state = _default_chat_state()
226
+ st.session_state[panel_key] = state
227
+ prompt_key = widget_key(panel_key, "custom_prompt")
228
+ show_all_key = widget_key(panel_key, "show_all")
229
+
230
+ selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
231
  personas,
232
+ state["persona_id"],
233
+ state["prompt_mode"],
234
+ widget_key(panel_key, "persona"),
235
+ widget_key(panel_key, "prompt_mode"),
 
236
  )
237
+ if changed:
238
+ state["messages"] = []
239
+ state["past_key_values"] = None
240
+ state["persona_id"] = selected_persona.id
241
+ state["prompt_mode"] = prompt_mode
242
+ _clear_chat_ui_state(prompt_key, show_all_key)
243
+
244
+ active_system_prompt = resolve_system_prompt(
245
+ persona=selected_persona, mode=prompt_mode
246
+ )
247
+ active_system_prompt = _render_system_prompt_editor(
248
+ prompt_key,
249
+ prompt_mode,
250
+ active_system_prompt,
251
+ height=150,
252
+ )
253
+
254
+ btn_col1, btn_col2 = st.columns(2)
255
+ with btn_col1:
256
+ if st.button(
257
+ "Export chat", key=widget_key(panel_key, "export_chat"), width="stretch"
258
+ ):
259
+ st.success(
260
+ _save_chat_export_message(
261
+ model_name=model_name,
262
+ dataset_source=dataset_source,
263
+ persona_id=selected_persona.id,
264
+ persona_name=getattr(selected_persona, "name", None),
265
+ prompt_mode=prompt_mode,
266
+ system_prompt=active_system_prompt,
267
+ messages=state["messages"],
268
+ generation=_generation_dict(gen_kwargs, advanced_generation),
269
+ panel_label=side,
270
+ )
271
+ )
272
+ with btn_col2:
273
+ if st.button(
274
+ "Reset chat",
275
+ key=widget_key(panel_key, "reset"),
276
+ width="stretch",
277
+ type="secondary",
278
+ ):
279
+ state["messages"] = []
280
+ state["past_key_values"] = None
281
+ _clear_chat_ui_state(prompt_key, show_all_key)
282
+ st.rerun()
283
+
284
+ chat_log = st.container()
285
+ _render_chat_window(
286
+ chat_log=chat_log,
287
+ messages=state["messages"],
288
+ show_all_key=show_all_key,
289
+ show_all_btn_key=widget_key(panel_key, "show_all_btn"),
290
+ show_earlier_label="Show earlier",
291
  )
292
+ return state, chat_log, active_system_prompt
293
+
294
+ with left_col:
295
+ left_state, left_log, left_prompt = render_panel("left", left_col)
296
+ with right_col:
297
+ right_state, right_log, right_prompt = render_panel("right", right_col)
298
 
299
  user_prompt = st.chat_input(
300
  "Ask both...",
 
304
  return
305
 
306
  model = cached_model(model_name=model_name, remote=remote)
307
+ panels = [
308
+ (left_state, left_log, left_prompt),
309
+ (right_state, right_log, right_prompt),
310
+ ]
311
 
312
+ for panel_state, panel_log, _panel_prompt in panels:
313
+ panel_state["messages"].append({"role": "user", "content": user_prompt})
314
+ with panel_log:
315
+ _render_chat_message({"role": "user", "content": user_prompt})
 
316
 
 
317
  with st.spinner("Generating..."):
318
+ if remote:
319
+ with ThreadPoolExecutor(max_workers=2) as executor:
320
+ futures = [
321
+ executor.submit(
322
+ generate_chat_reply,
323
+ model=model,
324
+ messages=(
325
+ [{"role": "system", "content": panel_prompt}]
326
+ if panel_prompt
327
+ else []
328
+ )
329
+ + panel_state["messages"],
330
+ remote=remote,
331
+ past_key_values=panel_state["past_key_values"],
332
+ **gen_kwargs,
333
+ )
334
+ for panel_state, _panel_log, panel_prompt in panels
335
+ ]
336
+ results: list[ChatReply | Exception] = []
337
+ for future in futures:
338
+ try:
339
+ results.append(future.result())
340
+ except Exception as exc:
341
+ results.append(exc)
342
+ else:
343
  results = []
344
+ for panel_state, _panel_log, panel_prompt in panels:
345
  try:
346
+ results.append(
347
+ generate_chat_reply(
348
+ model=model,
349
+ messages=(
350
+ [{"role": "system", "content": panel_prompt}]
351
+ if panel_prompt
352
+ else []
353
+ )
354
+ + panel_state["messages"],
355
+ remote=remote,
356
+ past_key_values=panel_state["past_key_values"],
357
+ **gen_kwargs,
358
+ )
359
+ )
360
  except Exception as exc:
361
  results.append(exc)
362
 
363
+ for (panel_state, panel_log, _panel_prompt), result in zip(panels, results):
364
  if isinstance(result, Exception):
365
+ with panel_log:
366
+ st.error(f"Generation failed: {result}")
367
+ panel_state["messages"].pop()
 
368
  continue
369
 
370
+ panel_state["messages"].append({"role": "assistant", "content": result.text})
371
+ panel_state["past_key_values"] = result.past_key_values if not remote else None
372
+ with panel_log:
373
+ _render_chat_message({"role": "assistant", "content": result.text})
 
 
 
374
 
375
 
376
  # ── Main tab entry point ───────────────────────────────────────────────────────
 
528
  # ── Single-chat mode ──────────────────────────────────────────────────────
529
  persona_select_key = widget_key(context_key, "persona_select")
530
  prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
531
+ prompt_key = widget_key(context_key, "custom_system_prompt")
532
+ show_all_key = widget_key(context_key, "show_all_messages")
533
+ chat_input_key = widget_key(context_key, "chat_input")
534
+ pending_key = widget_key(context_key, "pending_prompt")
535
+ export_key = widget_key(context_key, "export_chat")
536
+ reset_key = widget_key(context_key, "reset")
537
 
538
  col1, col2 = st.columns([2, 1])
539
  with col1:
 
550
  )
551
  with col2:
552
  current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
553
+ st.selectbox(
554
  "Prompt",
555
  options=MODE_LABELS,
556
  index=MODE_LABELS.index(current_mode_label),
557
  key=prompt_mode_select_key,
558
  )
559
+ prompt_mode = MODE_LABEL_TO_KEY[st.session_state[prompt_mode_select_key]]
560
 
561
  active_system_prompt = resolve_system_prompt(
562
  persona=selected_persona,
563
  mode=prompt_mode,
564
  )
565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  changed_context = (
567
  chat_state["persona_id"] != selected_persona.id
568
  or chat_state["prompt_mode"] != prompt_mode
569
  )
570
  if changed_context:
571
  had_history = bool(chat_state["messages"])
572
+ _reset_single_chat_context(
573
+ model_name,
574
+ remote,
575
+ dataset_source,
576
+ chat_state,
577
+ selected_persona.id,
578
+ prompt_mode,
579
  chat_input_key,
580
  show_all_key,
581
+ prompt_key,
582
  pending_key,
583
  )
584
  if had_history:
 
586
 
587
  chat_log = st.container()
588
 
589
+ active_system_prompt = _render_system_prompt_editor(
590
+ prompt_key,
591
+ prompt_mode,
592
+ active_system_prompt,
593
+ height=200,
594
+ )
 
 
 
 
 
 
 
 
 
595
 
596
+ action_col1, action_col2 = st.columns(2)
597
+ with action_col1:
598
+ if st.button("Export chat", key=export_key, width="stretch"):
599
+ st.success(
600
+ _save_chat_export_message(
601
+ model_name=model_name,
602
+ dataset_source=dataset_source,
603
+ persona_id=selected_persona.id,
604
+ persona_name=getattr(selected_persona, "name", None),
605
+ prompt_mode=prompt_mode,
606
+ system_prompt=active_system_prompt,
607
+ messages=chat_state["messages"],
608
+ generation=_generation_dict(gen_kwargs, advanced_generation),
609
+ )
610
+ )
611
+ with action_col2:
612
+ if st.button("Reset chat", key=reset_key, width="stretch", type="secondary"):
613
+ _reset_single_chat_context(
614
+ model_name,
615
+ remote,
616
+ dataset_source,
617
+ chat_state,
618
+ selected_persona.id,
619
+ prompt_mode,
620
+ chat_input_key,
621
+ show_all_key,
622
+ prompt_key,
623
+ pending_key,
624
+ )
625
+ st.rerun()
626
 
627
+ _render_chat_window(
628
+ chat_log=chat_log,
629
+ messages=chat_state["messages"],
630
+ show_all_key=show_all_key,
631
+ show_all_btn_key=widget_key(context_key, "show_all_btn"),
632
+ show_earlier_label="Show earlier messages",
633
+ )
634
 
635
  user_prompt = st.chat_input(
636
  "Ask something...",
 
647
  if not st.session_state.pop(pending_key, False):
648
  return
649
 
650
+ messages = _build_chat_messages(active_system_prompt, chat_state["messages"])
 
 
 
651
 
652
  with st.spinner("Generating reply..."):
653
  model = cached_model(model_name=model_name, remote=remote)
tabs/compare.py CHANGED
@@ -1,21 +1,18 @@
 
 
 
1
  import streamlit as st
 
2
  from persona_data.environment import get_artifacts_dir
3
  from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
4
- from persona_vectors.plots import (
5
- plot_multiple_layer_similarities,
6
- save_plot_html,
7
- save_plot_png,
8
- )
9
 
10
- from utils.artifacts import (
11
- artifact_persona_options,
12
- list_available_layers,
13
- load_cosine_traces,
14
- load_embedding_samples,
15
- )
16
  from utils.helpers import (
17
  ANALYSIS_HELP_TEXT,
18
- ANALYSIS_LABELS,
19
  ANALYSIS_MODES,
20
  PROMPT_VARIANTS,
21
  persona_display_label,
@@ -29,15 +26,151 @@ def _filename(*parts: str) -> str:
29
  return "__".join(slugify(part) for part in parts if part)
30
 
31
 
32
- def _select_artifact_personas(
33
- artifacts_root: str,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  model_name: str,
35
  variants: list[str],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  ) -> tuple[list[str], dict[str, str]]:
37
- persona_options, persona_names = artifact_persona_options(
38
- artifacts_root,
39
- model_name,
40
- variants,
 
41
  )
42
  if not persona_options:
43
  if len(variants) > 1:
@@ -55,15 +188,81 @@ def _select_artifact_personas(
55
  format_func=lambda persona_id: persona_display_label(
56
  persona_id, persona_names.get(persona_id)
57
  ),
58
- key=widget_key("load", "personas", model_name, *variants),
59
  )
60
  return persona_ids, persona_names
61
 
62
 
63
- def _render_cosine_similarity(
64
- artifacts_root: str,
65
- model_name: str,
 
66
  ) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  col1, col2 = st.columns(2)
68
  with col1:
69
  variant_a = st.selectbox(
@@ -86,24 +285,16 @@ def _render_cosine_similarity(
86
  st.warning("Choose two different variants to compare.")
87
  return
88
 
89
- persona_ids, _ = _select_artifact_personas(
90
- artifacts_root,
91
- model_name,
92
- [variant_a, variant_b],
93
- )
94
  if not persona_ids:
95
  return
96
 
97
- cosine_fig_key = widget_key("load", "cosine_fig_state", model_name)
98
- filename = _filename("compare", "cosine", model_name, variant_a, variant_b)
99
 
100
  if st.button("Compare vectors", type="primary"):
101
- traces, loaded_names, errors = load_cosine_traces(
102
- artifacts_root,
103
- model_name,
104
- persona_ids,
105
- variant_a,
106
- variant_b,
107
  )
108
 
109
  if errors:
@@ -125,7 +316,7 @@ def _render_cosine_similarity(
125
  )
126
  for persona_id, short, long in traces
127
  ]
128
- fig = plot_multiple_layer_similarities(
129
  display_traces,
130
  title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
131
  show=False,
@@ -134,82 +325,27 @@ def _render_cosine_similarity(
134
 
135
  if cosine_fig_key in st.session_state:
136
  fig, n_traces = st.session_state[cosine_fig_key]
137
- st.plotly_chart(fig, use_container_width=True)
138
- save_col1, save_col2 = st.columns(2)
139
- with save_col1:
140
- if st.button("Save HTML", key=widget_key("load", "save_cosine_html")):
141
- output_path = save_plot_html(fig, filename)
142
- st.success(f"Saved HTML to `{output_path}`")
143
- with save_col2:
144
- if st.button("Save PNG", key=widget_key("load", "save_cosine_png")):
145
- try:
146
- output_path = save_plot_png(fig, filename)
147
- st.success(f"Saved PNG to `{output_path}`")
148
- except Exception as exc:
149
- st.error(f"Could not save PNG: {exc}")
150
  st.success(f"Loaded {n_traces} personas for cosine comparison.")
151
 
152
 
153
- def _render_embedding_analysis(
154
- artifacts_root: str,
155
- model_name: str,
156
- analysis_mode: str,
157
- ) -> None:
158
- selected_variant = st.selectbox(
159
- "Variant",
160
- options=PROMPT_VARIANTS,
161
- format_func=prompt_variant_label,
162
- key=widget_key("load", "variant"),
163
- )
164
-
165
- persona_ids, persona_names = _select_artifact_personas(
166
- artifacts_root,
167
- model_name,
168
- [selected_variant],
169
- )
170
- if not persona_ids:
171
- return
172
-
173
- layer_options = list_available_layers(
174
- artifacts_root,
175
- model_name,
176
- [selected_variant],
177
- persona_ids,
178
- )
179
- if not layer_options:
180
- st.info(
181
- "No shared layers are available for the selected personas. Try fewer personas or a different variant."
182
- )
183
  return
184
-
185
  persona_key = "_".join(sorted(persona_ids))
186
- layer_key = widget_key("load", "layers", model_name, selected_variant, persona_key)
187
- default_layers = [
188
- layer
189
- for layer in st.session_state.get(layer_key, layer_options[:3])
190
- if layer in layer_options
191
- ] or layer_options[:3]
192
- selected_layers = st.multiselect(
193
- "Layers",
194
- options=layer_options,
195
- default=default_layers,
196
- key=layer_key,
197
- )
198
- if not selected_layers:
199
- st.info("Select at least one layer.")
200
  return
201
 
202
- button_label = (
203
- "Generate PCA projection"
204
- if analysis_mode == "PCA"
205
- else "Generate UMAP projection"
206
- )
207
-
208
  embedding_fig_key = widget_key(
209
- "load", "embedding_fig_state", model_name, analysis_mode
210
  )
211
 
212
- if st.button(button_label, type="primary"):
213
  progress = st.progress(0, text="Preparing projections...")
214
 
215
  def update_progress(current: int, total: int, loaded: int) -> None:
@@ -219,15 +355,13 @@ def _render_embedding_analysis(
219
  text=f"Processing layer {current}/{total} ({loaded} plot(s) ready)",
220
  )
221
 
222
- project_fn = project_pca if analysis_mode == "PCA" else project_umap
223
  try:
224
- plots, errors = load_embedding_samples(
225
- artifacts_root,
226
- model_name,
227
  persona_ids,
228
  selected_variant,
229
  selected_layers,
230
- project_fn,
231
  persona_names,
232
  progress_fn=update_progress,
233
  )
@@ -248,18 +382,7 @@ def _render_embedding_analysis(
248
  st.info("Try fewer personas, fewer layers, or a different variant.")
249
  st.session_state.pop(embedding_fig_key, None)
250
  else:
251
- title_prefix, x_label, y_label = ANALYSIS_LABELS[analysis_mode]
252
- rendered_figures: list[tuple[int, object]] = []
253
- for layer_idx, coords, labels, hover_text in plots:
254
- fig = build_embedding_figure(
255
- coords=coords,
256
- labels=labels,
257
- title=f"{title_prefix}, layer {layer_idx}",
258
- x_label=x_label,
259
- y_label=y_label,
260
- hover_text=hover_text,
261
- )
262
- rendered_figures.append((layer_idx, fig))
263
  total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
264
  st.session_state[embedding_fig_key] = (
265
  rendered_figures,
@@ -274,52 +397,14 @@ def _render_embedding_analysis(
274
  rendered_figures, saved_persona_key, saved_variant, total_samples = (
275
  st.session_state[embedding_fig_key]
276
  )
277
- cols = st.columns(2)
278
- for idx, (layer_idx, fig) in enumerate(rendered_figures):
279
- with cols[idx % 2]:
280
- st.plotly_chart(fig, use_container_width=True)
281
- st.success(
282
- f"Loaded {total_samples} samples across {len(rendered_figures)} layers."
 
283
  )
284
- filenames = [
285
- _filename(
286
- "compare",
287
- analysis_mode,
288
- model_name,
289
- saved_variant,
290
- saved_persona_key,
291
- str(layer_idx),
292
- )
293
- for layer_idx, _ in rendered_figures
294
- ]
295
- save_col1, save_col2 = st.columns(2)
296
- with save_col1:
297
- if st.button(
298
- "Save HTML",
299
- key=widget_key("load", "save_embedding_html", analysis_mode),
300
- ):
301
- saved_paths = [
302
- save_plot_html(fig, fn)
303
- for (_, fig), fn in zip(rendered_figures, filenames)
304
- ]
305
- st.success(
306
- f"Saved {len(saved_paths)} HTML plot(s) to `artifacts/plots`."
307
- )
308
- with save_col2:
309
- if st.button(
310
- "Save PNG",
311
- key=widget_key("load", "save_embedding_png", analysis_mode),
312
- ):
313
- try:
314
- saved_paths = [
315
- save_plot_png(fig, fn)
316
- for (_, fig), fn in zip(rendered_figures, filenames)
317
- ]
318
- st.success(
319
- f"Saved {len(saved_paths)} PNG plot(s) to `artifacts/plots`."
320
- )
321
- except Exception as exc:
322
- st.error(f"Could not save PNGs: {exc}")
323
 
324
 
325
  def render_compare_tab(model_name: str) -> None:
@@ -336,6 +421,8 @@ def render_compare_tab(model_name: str) -> None:
336
  value=str(get_artifacts_dir() / "activations"),
337
  )
338
 
 
 
339
  analysis_mode = st.segmented_control(
340
  "Analysis mode",
341
  options=ANALYSIS_MODES,
@@ -348,7 +435,7 @@ def render_compare_tab(model_name: str) -> None:
348
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
349
 
350
  if analysis_mode == "Cosine similarity":
351
- _render_cosine_similarity(artifacts_root, model_name)
352
  return
353
 
354
- _render_embedding_analysis(artifacts_root, model_name, analysis_mode)
 
1
+ from collections.abc import Callable
2
+ from dataclasses import dataclass
3
+
4
  import streamlit as st
5
+ import torch
6
  from persona_data.environment import get_artifacts_dir
7
  from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
8
+ from persona_vectors.artifacts import ActivationStore
9
+ from persona_vectors.artifacts import list_layers as list_available_layers
10
+ from persona_vectors.artifacts import list_personas as list_available_personas
11
+ from persona_vectors.artifacts import load_mean_activations, load_persona_names
12
+ from persona_vectors.plots import plot_layer_similarity, save_plot_html, save_plot_png
13
 
 
 
 
 
 
 
14
  from utils.helpers import (
15
  ANALYSIS_HELP_TEXT,
 
16
  ANALYSIS_MODES,
17
  PROMPT_VARIANTS,
18
  persona_display_label,
 
26
  return "__".join(slugify(part) for part in parts if part)
27
 
28
 
29
+ @dataclass(frozen=True)
30
+ class ProjectionConfig:
31
+ title_prefix: str
32
+ x_label: str
33
+ y_label: str
34
+ project_fn: Callable[[torch.Tensor], torch.Tensor]
35
+
36
+
37
+ _PROJECTION_CONFIGS: dict[str, ProjectionConfig] = {
38
+ "PCA": ProjectionConfig("PCA", "PC1", "PC2", project_pca),
39
+ "UMAP": ProjectionConfig("UMAP", "UMAP 1", "UMAP 2", project_umap),
40
+ }
41
+
42
+
43
+ @st.cache_data(show_spinner=False)
44
+ def _list_layers(
45
+ root_dir: str,
46
  model_name: str,
47
  variants: list[str],
48
+ persona_ids: list[str],
49
+ ) -> list[int]:
50
+ return list_available_layers(root_dir, model_name, variants, persona_ids)
51
+
52
+
53
+ def _load_embedding_samples(
54
+ store: ActivationStore,
55
+ persona_ids: list[str],
56
+ variant: str,
57
+ selected_layers: list[int],
58
+ project_fn: Callable[[torch.Tensor], torch.Tensor],
59
+ persona_names: dict[str, str],
60
+ progress_fn: Callable[[int, int, int], None] | None = None,
61
+ ) -> tuple[list[tuple[int, torch.Tensor, list[str], list[str]]], list[str]]:
62
+ """Load samples for 2D projections without re-reading each layer from disk."""
63
+
64
+ plots: list[tuple[int, torch.Tensor, list[str], list[str]]] = []
65
+ errors: list[str] = []
66
+ vectors_by_persona: dict[str, torch.Tensor] = {}
67
+
68
+ for persona_id in persona_ids:
69
+ try:
70
+ vectors, _ = store.load(variant, persona_id)
71
+ except (FileNotFoundError, KeyError, OSError, ValueError) as exc:
72
+ errors.append(f"{persona_id} / {variant}: {exc}")
73
+ continue
74
+
75
+ vectors_by_persona[persona_id] = vectors
76
+
77
+ total_layers = len(selected_layers)
78
+ for idx, layer_idx in enumerate(selected_layers, start=1):
79
+ samples: list[torch.Tensor] = []
80
+ labels: list[str] = []
81
+ hover_text: list[str] = []
82
+
83
+ for persona_id, vectors in vectors_by_persona.items():
84
+ if layer_idx >= vectors.shape[1]:
85
+ errors.append(f"{persona_id} / {variant}: missing layer {layer_idx}")
86
+ continue
87
+
88
+ layer_vectors = vectors[:, layer_idx, :]
89
+ samples.append(layer_vectors)
90
+ labels.extend([persona_id] * layer_vectors.shape[0])
91
+ display_name = persona_names.get(persona_id) or persona_id
92
+ hover_text.extend(
93
+ [f"<b>{display_name}</b><br>{variant}"] * layer_vectors.shape[0]
94
+ )
95
+
96
+ if not samples:
97
+ errors.append(f"Layer {layer_idx}: no selected personas have this layer")
98
+ else:
99
+ all_samples = torch.cat(samples, dim=0)
100
+ if all_samples.shape[0] < 2:
101
+ errors.append(
102
+ f"Layer {layer_idx}: need at least 2 samples after filtering selected personas"
103
+ )
104
+ else:
105
+ try:
106
+ coords = project_fn(all_samples)
107
+ plots.append((layer_idx, coords, labels, hover_text))
108
+ except Exception as exc:
109
+ errors.append(f"Layer {layer_idx}: {exc}")
110
+
111
+ if progress_fn is not None:
112
+ progress_fn(idx, total_layers, len(plots))
113
+
114
+ return plots, errors
115
+
116
+
117
+ def _build_embedding_figures(
118
+ plots: list[tuple[int, torch.Tensor, list[str], list[str]]],
119
+ config: ProjectionConfig,
120
+ ) -> list[tuple[int, object]]:
121
+ return [
122
+ (
123
+ layer_idx,
124
+ build_embedding_figure(
125
+ coords=coords,
126
+ labels=labels,
127
+ title=f"{config.title_prefix}, layer {layer_idx}",
128
+ x_label=config.x_label,
129
+ y_label=config.y_label,
130
+ hover_text=hover_text,
131
+ ),
132
+ )
133
+ for layer_idx, coords, labels, hover_text in plots
134
+ ]
135
+
136
+
137
+ def _render_embedding_results(
138
+ store: ActivationStore,
139
+ analysis_mode: str,
140
+ rendered_figures: list[tuple[int, object]],
141
+ saved_variant: str,
142
+ saved_persona_key: str,
143
+ total_samples: int,
144
+ ) -> None:
145
+ cols = st.columns(2)
146
+ for idx, (_, fig) in enumerate(rendered_figures):
147
+ with cols[idx % 2]:
148
+ st.plotly_chart(fig, width="stretch")
149
+
150
+ st.success(f"Loaded {total_samples} samples across {len(rendered_figures)} layers.")
151
+ filenames = [
152
+ _filename(
153
+ "compare",
154
+ analysis_mode,
155
+ store.model_name,
156
+ saved_variant,
157
+ saved_persona_key,
158
+ str(layer_idx),
159
+ )
160
+ for layer_idx, _ in rendered_figures
161
+ ]
162
+ _render_save_buttons([fig for _, fig in rendered_figures], filenames, analysis_mode)
163
+
164
+
165
+ def _select_artifact_personas(
166
+ store: ActivationStore,
167
+ variants: list[str],
168
  ) -> tuple[list[str], dict[str, str]]:
169
+ persona_options = list_available_personas(
170
+ store.root_dir, store.model_name, variants
171
+ )
172
+ persona_names = load_persona_names(
173
+ store.root_dir, store.model_name, variants, persona_options
174
  )
175
  if not persona_options:
176
  if len(variants) > 1:
 
188
  format_func=lambda persona_id: persona_display_label(
189
  persona_id, persona_names.get(persona_id)
190
  ),
191
+ key=widget_key("load", "personas", store.model_name, *variants),
192
  )
193
  return persona_ids, persona_names
194
 
195
 
196
+ def _render_save_buttons(
197
+ figs: list[object],
198
+ filenames: list[str],
199
+ key_suffix: str,
200
  ) -> None:
201
+ """Render Save HTML / Save PNG column buttons for one or more figures."""
202
+ col1, col2 = st.columns(2)
203
+ with col1:
204
+ if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
205
+ paths = [save_plot_html(fig, fn) for fig, fn in zip(figs, filenames)]
206
+ st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
207
+ with col2:
208
+ if st.button("Save PNG", key=widget_key("load", "save_png", key_suffix)):
209
+ try:
210
+ paths = [save_plot_png(fig, fn) for fig, fn in zip(figs, filenames)]
211
+ st.success(f"Saved {len(paths)} PNG file(s) to `artifacts/plots`.")
212
+ except Exception as exc:
213
+ st.error(f"Could not save PNG: {exc}")
214
+
215
+
216
+ def _select_embedding_config(
217
+ store: ActivationStore,
218
+ ) -> tuple[str, list[str], dict[str, str], list[int]] | None:
219
+ """Render variant / persona / layer selectors and return the selection, or None on early exit."""
220
+ selected_variant = st.selectbox(
221
+ "Variant",
222
+ options=PROMPT_VARIANTS,
223
+ format_func=prompt_variant_label,
224
+ key=widget_key("load", "variant"),
225
+ )
226
+
227
+ persona_ids, persona_names = _select_artifact_personas(store, [selected_variant])
228
+ if not persona_ids:
229
+ return None
230
+
231
+ layer_options = _list_layers(
232
+ str(store.root_dir),
233
+ store.model_name,
234
+ [selected_variant],
235
+ persona_ids,
236
+ )
237
+ if not layer_options:
238
+ st.info(
239
+ "No shared layers are available for the selected personas. Try fewer personas or a different variant."
240
+ )
241
+ return None
242
+
243
+ persona_key = "_".join(sorted(persona_ids))
244
+ layer_key = widget_key(
245
+ "load", "layers", store.model_name, selected_variant, persona_key
246
+ )
247
+ default_layers = [
248
+ layer
249
+ for layer in st.session_state.get(layer_key, layer_options[:3])
250
+ if layer in layer_options
251
+ ] or layer_options[:3]
252
+ selected_layers = st.multiselect(
253
+ "Layers",
254
+ options=layer_options,
255
+ default=default_layers,
256
+ key=layer_key,
257
+ )
258
+ if not selected_layers:
259
+ st.info("Select at least one layer.")
260
+ return None
261
+
262
+ return selected_variant, persona_ids, persona_names, selected_layers
263
+
264
+
265
+ def _render_cosine_similarity(store: ActivationStore) -> None:
266
  col1, col2 = st.columns(2)
267
  with col1:
268
  variant_a = st.selectbox(
 
285
  st.warning("Choose two different variants to compare.")
286
  return
287
 
288
+ persona_ids, _ = _select_artifact_personas(store, [variant_a, variant_b])
 
 
 
 
289
  if not persona_ids:
290
  return
291
 
292
+ cosine_fig_key = widget_key("load", "cosine_fig_state", store.model_name)
293
+ filename = _filename("compare", "cosine", store.model_name, variant_a, variant_b)
294
 
295
  if st.button("Compare vectors", type="primary"):
296
+ traces, loaded_names, errors = load_mean_activations(
297
+ store.root_dir, store.model_name, persona_ids, variant_a, variant_b
 
 
 
 
298
  )
299
 
300
  if errors:
 
316
  )
317
  for persona_id, short, long in traces
318
  ]
319
+ fig = plot_layer_similarity(
320
  display_traces,
321
  title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
322
  show=False,
 
325
 
326
  if cosine_fig_key in st.session_state:
327
  fig, n_traces = st.session_state[cosine_fig_key]
328
+ st.plotly_chart(fig, width="stretch")
329
+ _render_save_buttons([fig], [filename], "cosine")
 
 
 
 
 
 
 
 
 
 
 
330
  st.success(f"Loaded {n_traces} personas for cosine comparison.")
331
 
332
 
333
+ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> None:
334
+ config = _select_embedding_config(store)
335
+ if config is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  return
337
+ selected_variant, persona_ids, persona_names, selected_layers = config
338
  persona_key = "_".join(sorted(persona_ids))
339
+ projection_config = _PROJECTION_CONFIGS.get(analysis_mode)
340
+ if projection_config is None:
341
+ st.error(f"Unsupported analysis mode: {analysis_mode}")
 
 
 
 
 
 
 
 
 
 
 
342
  return
343
 
 
 
 
 
 
 
344
  embedding_fig_key = widget_key(
345
+ "load", "embedding_fig_state", store.model_name, analysis_mode
346
  )
347
 
348
+ if st.button(f"Generate {analysis_mode} projection", type="primary"):
349
  progress = st.progress(0, text="Preparing projections...")
350
 
351
  def update_progress(current: int, total: int, loaded: int) -> None:
 
355
  text=f"Processing layer {current}/{total} ({loaded} plot(s) ready)",
356
  )
357
 
 
358
  try:
359
+ plots, errors = _load_embedding_samples(
360
+ store,
 
361
  persona_ids,
362
  selected_variant,
363
  selected_layers,
364
+ projection_config.project_fn,
365
  persona_names,
366
  progress_fn=update_progress,
367
  )
 
382
  st.info("Try fewer personas, fewer layers, or a different variant.")
383
  st.session_state.pop(embedding_fig_key, None)
384
  else:
385
+ rendered_figures = _build_embedding_figures(plots, projection_config)
 
 
 
 
 
 
 
 
 
 
 
386
  total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
387
  st.session_state[embedding_fig_key] = (
388
  rendered_figures,
 
397
  rendered_figures, saved_persona_key, saved_variant, total_samples = (
398
  st.session_state[embedding_fig_key]
399
  )
400
+ _render_embedding_results(
401
+ store,
402
+ analysis_mode,
403
+ rendered_figures,
404
+ saved_variant,
405
+ saved_persona_key,
406
+ total_samples,
407
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
 
409
 
410
  def render_compare_tab(model_name: str) -> None:
 
421
  value=str(get_artifacts_dir() / "activations"),
422
  )
423
 
424
+ store = ActivationStore(model_name, artifacts_root)
425
+
426
  analysis_mode = st.segmented_control(
427
  "Analysis mode",
428
  options=ANALYSIS_MODES,
 
435
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
436
 
437
  if analysis_mode == "Cosine similarity":
438
+ _render_cosine_similarity(store)
439
  return
440
 
441
+ _render_embedding_analysis(store, analysis_mode)
tabs/extract.py CHANGED
@@ -3,6 +3,7 @@ from persona_vectors.extraction import run_extraction
3
 
4
  from utils.datasets import load_dataset
5
  from utils.helpers import (
 
6
  PROMPT_VARIANTS,
7
  persona_label,
8
  prompt_variant_label,
@@ -84,8 +85,8 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
84
  st.info("Select at least one persona.")
85
  return
86
 
87
- qa_filter_type: str | None
88
- qa_filter_difficulty: list[int] | None
89
 
90
  with st.expander("Advanced", expanded=False):
91
  st.caption("Filters")
@@ -114,35 +115,38 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
114
  )
115
  qa_filter_difficulty = difficulty_values if difficulty_values else None
116
 
117
- # Pre-load QA pairs for all selected personas to validate filters and set slider range.
118
- qa_by_persona = {
119
- p.id: dataset.get_qa(
120
- p.id, type=qa_filter_type, difficulty=qa_filter_difficulty
 
 
121
  )
122
- for p in selected_personas
123
- }
124
- personas_without_qa = [p for p in selected_personas if not qa_by_persona[p.id]]
125
- if personas_without_qa:
126
- names = ", ".join(p.name for p in personas_without_qa)
 
127
  st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
128
 
129
- personas_to_run = [p for p in selected_personas if qa_by_persona[p.id]]
130
- if not personas_to_run:
131
  st.info("No personas have matching QA pairs. Widen the filters.")
132
  return
133
 
134
- min_qa_count = min(len(qa_by_persona[p.id]) for p in personas_to_run)
 
 
 
 
 
 
 
 
 
135
 
136
- with col3:
137
- max_questions = st.slider(
138
- "Max questions",
139
- min_value=1,
140
- max_value=min_qa_count,
141
- value=min_qa_count,
142
- key=_extract_widget_key(
143
- model_name, remote, dataset_source, "max_questions"
144
- ),
145
- )
146
 
147
  run_clicked = st.button("Run extraction", type="primary")
148
  if not run_clicked:
@@ -153,25 +157,19 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
153
  progress = st.progress(0, text="Preparing extraction...")
154
  ndif_status_box = st.empty() # shows live NDIF job status when remote=True
155
 
156
- _STATUS_ICONS = {
157
- "RECEIVED": "◉", "QUEUED": "◎", "DISPATCHED": "◈",
158
- "RUNNING": "●", "COMPLETED": "✓", "ERROR": "✗",
159
- }
160
-
161
  def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
162
- icon = _STATUS_ICONS.get(status_name, "•")
163
  ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
164
 
165
  with st.spinner("Loading model..."):
166
  model = cached_model(model_name=model_name, remote=remote)
167
 
168
  try:
169
- total_steps = len(personas_to_run) * len(selected_variants)
170
  step = 0
171
  results = []
172
 
173
- for persona in personas_to_run:
174
- qa_pairs = qa_by_persona[persona.id][:max_questions]
175
  for variant in selected_variants:
176
  progress.progress(
177
  step / total_steps if total_steps else 1.0,
@@ -181,7 +179,7 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
181
  model=model,
182
  model_name=model_name,
183
  persona=persona,
184
- qa_pairs=qa_pairs,
185
  variants=[variant],
186
  remote=remote,
187
  on_status=_on_ndif_status if remote else None,
 
3
 
4
  from utils.datasets import load_dataset
5
  from utils.helpers import (
6
+ NDIF_STATUS_ICONS,
7
  PROMPT_VARIANTS,
8
  persona_label,
9
  prompt_variant_label,
 
85
  st.info("Select at least one persona.")
86
  return
87
 
88
+ runs = None
89
+ max_questions = 0
90
 
91
  with st.expander("Advanced", expanded=False):
92
  st.caption("Filters")
 
115
  )
116
  qa_filter_difficulty = difficulty_values if difficulty_values else None
117
 
118
+ runs, skipped = [], []
119
+ for persona in selected_personas:
120
+ qa = list(
121
+ dataset.get_qa(
122
+ persona.id, type=qa_filter_type, difficulty=qa_filter_difficulty
123
+ )
124
  )
125
+ if qa:
126
+ runs.append((persona, qa))
127
+ else:
128
+ skipped.append(persona)
129
+ if skipped:
130
+ names = ", ".join(p.name for p in skipped)
131
  st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
132
 
133
+ if not runs:
 
134
  st.info("No personas have matching QA pairs. Widen the filters.")
135
  return
136
 
137
+ max_q = min(len(qa_pairs) for _, qa_pairs in runs)
138
+ max_questions = st.slider(
139
+ "Max questions",
140
+ min_value=1,
141
+ max_value=max_q,
142
+ value=max_q,
143
+ key=_extract_widget_key(
144
+ model_name, remote, dataset_source, "max_questions"
145
+ ),
146
+ )
147
 
148
+ if runs is None:
149
+ return
 
 
 
 
 
 
 
 
150
 
151
  run_clicked = st.button("Run extraction", type="primary")
152
  if not run_clicked:
 
157
  progress = st.progress(0, text="Preparing extraction...")
158
  ndif_status_box = st.empty() # shows live NDIF job status when remote=True
159
 
 
 
 
 
 
160
  def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
161
+ icon = NDIF_STATUS_ICONS.get(status_name, "•")
162
  ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
163
 
164
  with st.spinner("Loading model..."):
165
  model = cached_model(model_name=model_name, remote=remote)
166
 
167
  try:
168
+ total_steps = len(runs) * len(selected_variants)
169
  step = 0
170
  results = []
171
 
172
+ for persona, qa_pairs in runs:
 
173
  for variant in selected_variants:
174
  progress.progress(
175
  step / total_steps if total_steps else 1.0,
 
179
  model=model,
180
  model_name=model_name,
181
  persona=persona,
182
+ qa_pairs=qa_pairs[:max_questions],
183
  variants=[variant],
184
  remote=remote,
185
  on_status=_on_ndif_status if remote else None,
utils/artifacts.py DELETED
@@ -1,244 +0,0 @@
1
- import logging
2
- from collections.abc import Callable
3
- from pathlib import Path
4
-
5
- import streamlit as st
6
- import torch
7
- from persona_vectors.activation_io import (
8
- load_activation_metadata,
9
- load_per_question_vectors,
10
- model_dir_name,
11
- )
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
-
16
- def list_available_personas(
17
- artifacts_root: str | Path,
18
- model_name: str,
19
- variants: list[str],
20
- ) -> list[str]:
21
- """List persona ids available for every requested variant."""
22
-
23
- shared_personas: set[str] | None = None
24
- root = Path(artifacts_root)
25
- for variant in variants:
26
- model_dir = root / model_dir_name(model_name) / variant
27
- if not model_dir.exists():
28
- return []
29
-
30
- variant_personas = {d.name for d in model_dir.iterdir() if d.is_dir()}
31
- if shared_personas is None:
32
- shared_personas = variant_personas
33
- else:
34
- shared_personas &= variant_personas
35
-
36
- if not shared_personas:
37
- return []
38
-
39
- return sorted(shared_personas or set())
40
-
41
-
42
- def load_persona_names(
43
- artifacts_root: str | Path,
44
- model_name: str,
45
- variants: list[str],
46
- persona_ids: list[str],
47
- ) -> dict[str, str]:
48
- """Load display names from saved activation metadata."""
49
-
50
- names: dict[str, str] = {}
51
- for persona_id in persona_ids:
52
- for variant in variants:
53
- try:
54
- metadata = load_activation_metadata(
55
- root_dir=artifacts_root,
56
- model_name=model_name,
57
- prompt_variant=variant,
58
- persona_id=persona_id,
59
- )
60
- except Exception:
61
- logger.debug(
62
- "Failed to load metadata for persona %s variant %s",
63
- persona_id,
64
- variant,
65
- exc_info=True,
66
- )
67
- continue
68
-
69
- persona_name = metadata.get("persona_name")
70
- if isinstance(persona_name, str) and persona_name:
71
- names[persona_id] = persona_name
72
- break
73
-
74
- return names
75
-
76
-
77
- def artifact_persona_options(
78
- artifacts_root: str | Path,
79
- model_name: str,
80
- variants: list[str],
81
- ) -> tuple[list[str], dict[str, str]]:
82
- """Return persona ids and names for the selected artifacts."""
83
-
84
- persona_options = list_available_personas(artifacts_root, model_name, variants)
85
- persona_names = load_persona_names(
86
- artifacts_root,
87
- model_name,
88
- variants,
89
- persona_options,
90
- )
91
- return persona_options, persona_names
92
-
93
-
94
- @st.cache_data(show_spinner=False)
95
- def list_available_layers(
96
- artifacts_root: str,
97
- model_name: str,
98
- variants: list[str],
99
- persona_ids: list[str],
100
- ) -> list[int]:
101
- """List layer indices shared by all matching saved activation files."""
102
-
103
- shared_layers: set[int] | None = None
104
- for variant in variants:
105
- for persona_id in persona_ids:
106
- try:
107
- vectors, _ = load_per_question_vectors(
108
- root_dir=artifacts_root,
109
- model_name=model_name,
110
- prompt_variant=variant,
111
- persona_id=persona_id,
112
- )
113
- except Exception:
114
- logger.debug(
115
- "Failed to load vectors for persona %s variant %s",
116
- persona_id,
117
- variant,
118
- exc_info=True,
119
- )
120
- continue
121
-
122
- layers = set(range(vectors.shape[1]))
123
- if shared_layers is None:
124
- shared_layers = layers
125
- else:
126
- shared_layers &= layers
127
-
128
- return sorted(shared_layers or set())
129
-
130
-
131
- def load_cosine_traces(
132
- artifacts_root: str | Path,
133
- model_name: str,
134
- persona_ids: list[str],
135
- variant_a: str,
136
- variant_b: str,
137
- ) -> tuple[list[tuple[str, torch.Tensor, torch.Tensor]], dict[str, str], list[str]]:
138
- """Load mean activation traces for pairwise cosine-similarity plots."""
139
-
140
- persona_names = load_persona_names(
141
- artifacts_root,
142
- model_name,
143
- [variant_a, variant_b],
144
- persona_ids,
145
- )
146
- traces: list[tuple[str, torch.Tensor, torch.Tensor]] = []
147
- errors: list[str] = []
148
-
149
- for persona_id in persona_ids:
150
- try:
151
- vectors_a, _ = load_per_question_vectors(
152
- root_dir=artifacts_root,
153
- model_name=model_name,
154
- prompt_variant=variant_a,
155
- persona_id=persona_id,
156
- )
157
- vectors_b, _ = load_per_question_vectors(
158
- root_dir=artifacts_root,
159
- model_name=model_name,
160
- prompt_variant=variant_b,
161
- persona_id=persona_id,
162
- )
163
- except Exception as exc:
164
- errors.append(f"{persona_id}: {exc}")
165
- continue
166
-
167
- traces.append(
168
- (persona_id, vectors_a.float().mean(dim=0), vectors_b.float().mean(dim=0))
169
- )
170
-
171
- return traces, persona_names, errors
172
-
173
-
174
- def load_embedding_samples(
175
- artifacts_root: str | Path,
176
- model_name: str,
177
- persona_ids: list[str],
178
- variant: str,
179
- selected_layers: list[int],
180
- project_fn: Callable[[torch.Tensor], torch.Tensor],
181
- persona_names: dict[str, str],
182
- progress_fn: Callable[[int, int, int], None] | None = None,
183
- ) -> tuple[list[tuple[int, torch.Tensor, list[str], list[str]]], list[str]]:
184
- """Load samples for 2D projections without re-reading each layer from disk."""
185
-
186
- plots: list[tuple[int, torch.Tensor, list[str], list[str]]] = []
187
- errors: list[str] = []
188
- vectors_by_persona: dict[str, torch.Tensor] = {}
189
-
190
- for persona_id in persona_ids:
191
- try:
192
- vectors, _ = load_per_question_vectors(
193
- root_dir=artifacts_root,
194
- model_name=model_name,
195
- prompt_variant=variant,
196
- persona_id=persona_id,
197
- )
198
- except Exception as exc:
199
- errors.append(f"{persona_id} / {variant}: {exc}")
200
- continue
201
-
202
- vectors_by_persona[persona_id] = vectors
203
-
204
- total_layers = len(selected_layers)
205
- for idx, layer_idx in enumerate(selected_layers, start=1):
206
- samples: list[torch.Tensor] = []
207
- labels: list[str] = []
208
- hover_text: list[str] = []
209
-
210
- for persona_id, vectors in vectors_by_persona.items():
211
- if layer_idx >= vectors.shape[1]:
212
- errors.append(f"{persona_id} / {variant}: missing layer {layer_idx}")
213
- continue
214
-
215
- layer_vectors = vectors[:, layer_idx, :]
216
- samples.append(layer_vectors)
217
- labels.extend([persona_id] * layer_vectors.shape[0])
218
- display_name = persona_names.get(persona_id) or persona_id
219
- hover_text.extend(
220
- [
221
- f"<b>{display_name}</b><br>{variant}",
222
- ]
223
- * layer_vectors.shape[0]
224
- )
225
-
226
- if not samples:
227
- errors.append(f"Layer {layer_idx}: no selected personas have this layer")
228
- else:
229
- all_samples = torch.cat(samples, dim=0)
230
- if all_samples.shape[0] < 2:
231
- errors.append(
232
- f"Layer {layer_idx}: need at least 2 samples after filtering selected personas"
233
- )
234
- else:
235
- try:
236
- coords = project_fn(all_samples)
237
- plots.append((layer_idx, coords, labels, hover_text))
238
- except Exception as exc:
239
- errors.append(f"Layer {layer_idx}: {exc}")
240
-
241
- if progress_fn is not None:
242
- progress_fn(idx, total_layers, len(plots))
243
-
244
- return plots, errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/chat.py CHANGED
@@ -52,7 +52,6 @@ def resolve_system_prompt(
52
  return format_biography_prompt(persona.biography_md)
53
  if mode == "custom":
54
  return format_empty_persona_prompt()
55
- return ""
56
 
57
 
58
  def _format_plain_messages(
 
52
  return format_biography_prompt(persona.biography_md)
53
  if mode == "custom":
54
  return format_empty_persona_prompt()
 
55
 
56
 
57
  def _format_plain_messages(
utils/chat_export.py CHANGED
@@ -3,24 +3,23 @@ from datetime import datetime, timezone
3
  from pathlib import Path
4
 
5
  from persona_data.environment import get_artifacts_dir
6
- from persona_vectors.activation_io import model_dir_name
7
 
8
  from utils.helpers import slugify
9
 
10
 
11
- def build_chat_export_payload(
12
  *,
13
  model_name: str,
14
  dataset_source: str,
15
  persona_id: str,
16
  persona_name: str | None,
17
- panel_label: str | None,
18
  prompt_mode: str,
19
  system_prompt: str | None,
20
  messages: list[dict[str, str]],
21
  generation: dict[str, object],
22
- ) -> dict[str, object]:
23
- """Build a JSON-serializable snapshot of the current chat session.
 
24
 
25
  Args:
26
  model_name: Model identifier used for the chat.
@@ -28,14 +27,15 @@ def build_chat_export_payload(
28
  persona_id: Selected persona id.
29
  persona_name: Selected persona display name, if available.
30
  prompt_mode: Active system prompt mode.
 
31
  messages: Conversation messages without the system prompt.
32
  generation: Generation settings used for the chat.
33
 
34
  Returns:
35
- A JSON-serializable dictionary.
36
  """
37
 
38
- return {
39
  "model_name": model_name,
40
  "dataset_source": dataset_source,
41
  "persona": {
@@ -51,50 +51,10 @@ def build_chat_export_payload(
51
  + messages,
52
  }
53
 
54
-
55
- def save_chat_export(
56
- *,
57
- model_name: str,
58
- dataset_source: str,
59
- persona_id: str,
60
- persona_name: str | None,
61
- prompt_mode: str,
62
- system_prompt: str | None,
63
- messages: list[dict[str, str]],
64
- generation: dict[str, object],
65
- panel_label: str | None = None,
66
- ) -> Path:
67
- """Save the current chat session to ``artifacts/chats`` as JSON.
68
-
69
- Args:
70
- model_name: Model identifier used for the chat.
71
- dataset_source: Human-readable dataset source label.
72
- persona_id: Selected persona id.
73
- persona_name: Selected persona display name, if available.
74
- prompt_mode: Active system prompt mode.
75
- system_prompt: Current system prompt text, if any.
76
- messages: Conversation messages without the system prompt.
77
- generation: Generation settings used for the chat.
78
-
79
- Returns:
80
- The path the export was written to.
81
- """
82
-
83
- payload = build_chat_export_payload(
84
- model_name=model_name,
85
- dataset_source=dataset_source,
86
- persona_id=persona_id,
87
- persona_name=persona_name,
88
- panel_label=panel_label,
89
- prompt_mode=prompt_mode,
90
- system_prompt=system_prompt,
91
- messages=messages,
92
- generation=generation,
93
- )
94
  export_dir = (
95
  get_artifacts_dir()
96
  / "chats"
97
- / model_dir_name(model_name)
98
  / slugify(dataset_source)
99
  / slugify(persona_id)
100
  )
 
3
  from pathlib import Path
4
 
5
  from persona_data.environment import get_artifacts_dir
 
6
 
7
  from utils.helpers import slugify
8
 
9
 
10
+ def save_chat_export(
11
  *,
12
  model_name: str,
13
  dataset_source: str,
14
  persona_id: str,
15
  persona_name: str | None,
 
16
  prompt_mode: str,
17
  system_prompt: str | None,
18
  messages: list[dict[str, str]],
19
  generation: dict[str, object],
20
+ panel_label: str | None = None,
21
+ ) -> Path:
22
+ """Save the current chat session to ``artifacts/chats`` as JSON.
23
 
24
  Args:
25
  model_name: Model identifier used for the chat.
 
27
  persona_id: Selected persona id.
28
  persona_name: Selected persona display name, if available.
29
  prompt_mode: Active system prompt mode.
30
+ system_prompt: Current system prompt text, if any.
31
  messages: Conversation messages without the system prompt.
32
  generation: Generation settings used for the chat.
33
 
34
  Returns:
35
+ The path the export was written to.
36
  """
37
 
38
+ payload = {
39
  "model_name": model_name,
40
  "dataset_source": dataset_source,
41
  "persona": {
 
51
  + messages,
52
  }
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  export_dir = (
55
  get_artifacts_dir()
56
  / "chats"
57
+ / model_name.replace("/", "__")
58
  / slugify(dataset_source)
59
  / slugify(persona_id)
60
  )
utils/datasets.py CHANGED
@@ -1,4 +1,5 @@
1
  import atexit
 
2
  import shutil
3
  from pathlib import Path
4
  from tempfile import mkdtemp
@@ -31,10 +32,13 @@ def _upload_cache_dir() -> Path:
31
  def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
32
  suffix = Path(uploaded_file.name).suffix or ".jsonl"
33
  temp_path = _upload_cache_dir() / f"{stem}{suffix}"
 
34
  data = uploaded_file.getvalue()
35
- if temp_path.exists() and temp_path.stat().st_size == len(data):
 
36
  return temp_path
37
  temp_path.write_bytes(data)
 
38
  return temp_path
39
 
40
 
 
1
  import atexit
2
+ import hashlib
3
  import shutil
4
  from pathlib import Path
5
  from tempfile import mkdtemp
 
32
  def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
33
  suffix = Path(uploaded_file.name).suffix or ".jsonl"
34
  temp_path = _upload_cache_dir() / f"{stem}{suffix}"
35
+ hash_path = temp_path.with_suffix(temp_path.suffix + ".sha256")
36
  data = uploaded_file.getvalue()
37
+ digest = hashlib.sha256(data).hexdigest()
38
+ if temp_path.exists() and hash_path.exists() and hash_path.read_text() == digest:
39
  return temp_path
40
  temp_path.write_bytes(data)
41
+ hash_path.write_text(digest)
42
  return temp_path
43
 
44
 
utils/helpers.py CHANGED
@@ -1,5 +1,7 @@
 
 
1
  from persona_data.synth_persona import PersonaData
2
- from persona_vectors.extraction import SUPPORTED_VARIANTS
3
 
4
  # Variant key -> human-readable label mapping
5
  VARIANT_LABELS = {
@@ -18,25 +20,29 @@ MODE_LABELS = list(VARIANT_LABELS.values())
18
  # Reverse lookup: label -> key
19
  MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
20
 
 
 
21
  DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
22
  ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
23
 
24
- ANALYSIS_LABELS = {
25
- "PCA": ("PCA", "PC1", "PC2"),
26
- "UMAP": ("UMAP", "UMAP 1", "UMAP 2"),
27
- }
28
-
29
  ANALYSIS_HELP_TEXT = {
30
  "Cosine similarity": "Compare layer-wise alignment between variants.",
31
  "PCA": "Project the selected layers into a global 2D view.",
32
  "UMAP": "Project the selected layers into a local-neighborhood 2D view.",
33
  }
34
 
 
 
 
 
 
 
 
 
35
 
36
- def slugify(value: str) -> str:
37
- """Convert a string to a slug safe for filenames and URLs."""
38
 
39
- import re
 
40
 
41
  return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_") or "unknown"
42
 
 
1
+ import re
2
+
3
  from persona_data.synth_persona import PersonaData
4
+ from persona_vectors.artifacts import SUPPORTED_VARIANTS
5
 
6
  # Variant key -> human-readable label mapping
7
  VARIANT_LABELS = {
 
20
  # Reverse lookup: label -> key
21
  MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
22
 
23
+ VISIBLE_MESSAGE_COUNT = 5
24
+
25
  DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
26
  ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
27
 
 
 
 
 
 
28
  ANALYSIS_HELP_TEXT = {
29
  "Cosine similarity": "Compare layer-wise alignment between variants.",
30
  "PCA": "Project the selected layers into a global 2D view.",
31
  "UMAP": "Project the selected layers into a local-neighborhood 2D view.",
32
  }
33
 
34
+ NDIF_STATUS_ICONS = {
35
+ "RECEIVED": "◉",
36
+ "QUEUED": "◎",
37
+ "DISPATCHED": "◈",
38
+ "RUNNING": "●",
39
+ "COMPLETED": "✓",
40
+ "ERROR": "✗",
41
+ }
42
 
 
 
43
 
44
+ def slugify(value: str) -> str:
45
+ """Convert a string to a filesystem-safe slug."""
46
 
47
  return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_") or "unknown"
48