Jac-Zac committed on
Commit
b279884
·
1 Parent(s): 9edffb7

Big refactoring

Browse files

- Speed gains
- Improved dendrogram figures
- Better information while chatting with models or loading datasets
- Faster overall UI
- Probing UI improvements
- Default values changed for better user experience
- Code structure refactoring

.env.example CHANGED
@@ -25,3 +25,8 @@ ARTIFACTS_DIR=artifacts
25
  # PERSONA_UI_STORE_CACHE_ENTRIES=4
26
  # PERSONA_UI_VECTOR_CACHE_ENTRIES=4
27
  # PERSONA_UI_PREPARED_CACHE_ENTRIES=8
 
 
 
 
 
 
25
  # PERSONA_UI_STORE_CACHE_ENTRIES=4
26
  # PERSONA_UI_VECTOR_CACHE_ENTRIES=4
27
  # PERSONA_UI_PREPARED_CACHE_ENTRIES=8
28
+ # PERSONA_UI_FIGURE_STATE_ENTRIES=2
29
+ # PERSONA_UI_PREPARED_STATE_ENTRIES=4
30
+ # PERSONA_UI_PROBE_CACHE_ENTRIES=8
31
+ # PERSONA_UI_PROBE_SWEEP_CACHE_ENTRIES=4
32
+ # PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES=12
README.md CHANGED
@@ -118,6 +118,8 @@ ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default:
118
  PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
119
  PERSONA_UI_VECTOR_CACHE_ENTRIES=4 # Optional: loaded analysis datasets kept warm
120
  PERSONA_UI_PREPARED_CACHE_ENTRIES=8 # Optional: prepared projections / k-means groups kept warm
 
 
121
  ```
122
 
123
  The app picks up this file automatically via `load_dotenv()` on startup.
@@ -153,4 +155,4 @@ The store classes are `PersonaVectorStore` (local) and `HFPersonaVectorStore`
153
 
154
  ## Analysis responsiveness
155
 
156
- The Analysis tab keeps a small bounded cache of loaded vector datasets and prepared projection data. Once a projection has been computed, recoloring it by persona, attribute, or k-means group reuses the same coordinates; nearby Hub interactions also keep metadata warm instead of re-scanning after every figure. Tune `PERSONA_UI_VECTOR_CACHE_ENTRIES` if RAM is tight or you regularly switch among many selections, and `PERSONA_UI_PREPARED_CACHE_ENTRIES` if you revisit several projection configurations in one session.
 
118
  PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
119
  PERSONA_UI_VECTOR_CACHE_ENTRIES=4 # Optional: loaded analysis datasets kept warm
120
  PERSONA_UI_PREPARED_CACHE_ENTRIES=8 # Optional: prepared projections / k-means groups kept warm
121
+ PERSONA_UI_FIGURE_STATE_ENTRIES=2 # Optional: recent rendered Analysis figures kept in-session
122
+ PERSONA_UI_PREPARED_STATE_ENTRIES=4 # Optional: recent projection-ready markers kept in-session
123
  ```
124
 
125
  The app picks up this file automatically via `load_dotenv()` on startup.
 
155
 
156
  ## Analysis responsiveness
157
 
158
+ The Analysis tab keeps small bounded caches of loaded vector datasets, prepared projection data, and a tiny MRU window of rendered figures. Once a projection has been computed, recoloring it by persona, attribute, or k-means group reuses the same coordinates; nearby method switches can reuse the last couple of figures instead of rebuilding immediately, while the caps keep RAM bounded. Tune `PERSONA_UI_VECTOR_CACHE_ENTRIES` if RAM is tight or you regularly switch among many selections, `PERSONA_UI_PREPARED_CACHE_ENTRIES` if you revisit several projection configurations in one session, and `PERSONA_UI_FIGURE_STATE_ENTRIES` if you want more or less method-switch warmth. Probe loading, probe sweeps, and per-trace probe outputs are bounded separately via `PERSONA_UI_PROBE_CACHE_ENTRIES`, `PERSONA_UI_PROBE_SWEEP_CACHE_ENTRIES`, and `PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES`; the derived-output cache defaults to a wider MRU window because those tensors are small compared with traced activations and are cheap wins to keep warm.
app.py CHANGED
@@ -4,11 +4,7 @@ from dataclasses import dataclass
4
  import streamlit as st
5
  from dotenv import load_dotenv
6
 
7
- from utils.analysis_sources import (
8
- DEFAULT_COMPARE_MODEL,
9
- DEFAULT_HUB_REPO,
10
- SOURCE_HUB,
11
- )
12
  from utils.helpers import DATASET_SOURCES, session_key, widget_key
13
  from utils.preload import preload_once
14
  from utils.runtime import list_remote_models
@@ -60,21 +56,34 @@ def _hub_metadata_preload_calls() -> tuple[
60
  calls: list[tuple[str, tuple[str, str, str, str | None]]] = []
61
 
62
  def add(repo: str, model: str, mask_strategy: str, variant: str | None) -> None:
63
- calls.append((
64
- "utils.analysis_sources:prefetch_hub_metadata",
65
- (repo, model, mask_strategy, variant),
66
- ))
 
 
 
 
 
 
 
67
 
68
- analysis_source = st.session_state.get("analysis:last_source", SOURCE_HUB)
69
  if analysis_source == SOURCE_HUB:
70
- repo = st.session_state.get("analysis:hub_repo", DEFAULT_HUB_REPO)
 
 
 
71
  mask_strategy = st.session_state.get(
72
  "analysis:last_mask_strategy",
73
- "answer_mean",
74
  )
75
  model = st.session_state.get(
76
  widget_key("load", "hub_model", repo, mask_strategy),
77
- st.session_state.get("analysis:hub_model_fallback", DEFAULT_COMPARE_MODEL),
 
 
 
78
  )
79
  variant = st.session_state.get(
80
  "analysis:last_projection_variant",
@@ -82,16 +91,22 @@ def _hub_metadata_preload_calls() -> tuple[
82
  )
83
  add(repo, model, mask_strategy, variant)
84
 
85
- probe_source = st.session_state.get(widget_key("probe", "source"), SOURCE_HUB)
86
  if probe_source == SOURCE_HUB:
87
- repo = st.session_state.get("probe:hub_repo", DEFAULT_HUB_REPO)
 
 
 
88
  mask_strategy = st.session_state.get(
89
  "probe:last_mask_strategy",
90
- "answer_mean",
91
  )
92
  model = st.session_state.get(
93
  widget_key("probe", "hub_model", repo, mask_strategy),
94
- st.session_state.get("probe:hub_model_fallback", DEFAULT_COMPARE_MODEL),
 
 
 
95
  )
96
  add(repo, model, mask_strategy, st.session_state.get("probe:variant"))
97
 
 
4
  import streamlit as st
5
  from dotenv import load_dotenv
6
 
7
+ from utils.analysis_sources import DEFAULT_COMPARE_MODEL, DEFAULT_HUB_REPO, SOURCE_HUB
 
 
 
 
8
  from utils.helpers import DATASET_SOURCES, session_key, widget_key
9
  from utils.preload import preload_once
10
  from utils.runtime import list_remote_models
 
56
  calls: list[tuple[str, tuple[str, str, str, str | None]]] = []
57
 
58
  def add(repo: str, model: str, mask_strategy: str, variant: str | None) -> None:
59
+ calls.append(
60
+ (
61
+ "utils.analysis_sources:prefetch_hub_metadata",
62
+ (repo, model, mask_strategy, variant),
63
+ )
64
+ )
65
+
66
+ shared_source = st.session_state.get("source:last_source", SOURCE_HUB)
67
+ shared_mask_strategy = st.session_state.get(
68
+ "source:last_mask_strategy", "answer_mean"
69
+ )
70
 
71
+ analysis_source = st.session_state.get("analysis:last_source", shared_source)
72
  if analysis_source == SOURCE_HUB:
73
+ repo = st.session_state.get(
74
+ "analysis:hub_repo",
75
+ st.session_state.get("source:hub_repo", DEFAULT_HUB_REPO),
76
+ )
77
  mask_strategy = st.session_state.get(
78
  "analysis:last_mask_strategy",
79
+ shared_mask_strategy,
80
  )
81
  model = st.session_state.get(
82
  widget_key("load", "hub_model", repo, mask_strategy),
83
+ st.session_state.get(
84
+ "analysis:hub_model_fallback",
85
+ st.session_state.get("source:hub_model", DEFAULT_COMPARE_MODEL),
86
+ ),
87
  )
88
  variant = st.session_state.get(
89
  "analysis:last_projection_variant",
 
91
  )
92
  add(repo, model, mask_strategy, variant)
93
 
94
+ probe_source = st.session_state.get(widget_key("probe", "source"), shared_source)
95
  if probe_source == SOURCE_HUB:
96
+ repo = st.session_state.get(
97
+ "probe:hub_repo",
98
+ st.session_state.get("source:hub_repo", DEFAULT_HUB_REPO),
99
+ )
100
  mask_strategy = st.session_state.get(
101
  "probe:last_mask_strategy",
102
+ shared_mask_strategy,
103
  )
104
  model = st.session_state.get(
105
  widget_key("probe", "hub_model", repo, mask_strategy),
106
+ st.session_state.get(
107
+ "probe:hub_model_fallback",
108
+ st.session_state.get("source:hub_model", DEFAULT_COMPARE_MODEL),
109
+ ),
110
  )
111
  add(repo, model, mask_strategy, st.session_state.get("probe:variant"))
112
 
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
1
  [project]
2
  name = "persona-ui"
3
- version = "0.4.0"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
 
1
  [project]
2
  name = "persona-ui"
3
+ version = "0.5.0"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
state.py CHANGED
@@ -21,9 +21,19 @@ class ChatState(TypedDict):
21
 
22
 
23
  def chat_session_key(model_name: str, dataset_source: str) -> str:
24
- """Build the session-state key for a chat context."""
25
 
26
- return session_key("chat_state", model_name, dataset_source)
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  def default_chat_state() -> ChatState:
 
21
 
22
 
23
  def chat_session_key(model_name: str, dataset_source: str) -> str:
24
+ """Build the session-state key for a chat conversation.
25
 
26
+ A model/backend switch changes *how* the next turn is generated, not which
27
+ conversation the user is looking at. Keeping the model out of the key means
28
+ toggling local/remote execution (or selecting another model) no longer makes
29
+ an existing thread appear to vanish behind a fresh empty state.
30
+
31
+ ``model_name`` stays in the signature for call-site compatibility and to
32
+ make the intent explicit where chat state is requested.
33
+ """
34
+
35
+ _ = model_name
36
+ return session_key("chat_state", dataset_source)
37
 
38
 
39
  def default_chat_state() -> ChatState:
tabs/analysis/_shared.py CHANGED
@@ -261,6 +261,7 @@ def _render_persona_count_controls(
261
  *,
262
  default_count: int,
263
  include_assistant_default: bool,
 
264
  ) -> tuple[int, bool]:
265
  count_key = widget_key(
266
  "load",
@@ -280,11 +281,16 @@ def _render_persona_count_controls(
280
  )
281
 
282
  if options.regular_ids:
 
 
 
 
 
283
  persona_count = st.slider(
284
  "Personas",
285
  min_value=0 if options.assistant_id is not None else 1,
286
- max_value=len(options.regular_ids),
287
- value=default_count,
288
  key=count_key,
289
  help="Use the first N available non-assistant personas.",
290
  )
@@ -310,6 +316,7 @@ def _select_artifact_personas(
310
  remember_key: str,
311
  default_all: bool = False,
312
  default_count_limit: int | None = None,
 
313
  ) -> list[str]:
314
  empty_message = _personas_empty_message(variants)
315
  options = _load_persona_options(
@@ -336,6 +343,7 @@ def _select_artifact_personas(
336
  options,
337
  default_count=default_count,
338
  include_assistant_default=include_assistant_default,
 
339
  )
340
 
341
  persona_ids = options.regular_ids[:persona_count]
@@ -361,6 +369,48 @@ def _select_artifact_personas(
361
  return persona_ids
362
 
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  def _render_save_buttons(
365
  figs: list[object],
366
  filenames: list[str],
@@ -398,6 +448,7 @@ def _render_mask_strategy_select(scope: str) -> MaskStrategy:
398
  return render_mask_strategy_select(
399
  key=widget_key("load", "mask_strategy", scope),
400
  last_key=_LAST_MASK_STRATEGY_KEY,
 
401
  help_text="Which extracted activation set to load.",
402
  )
403
 
@@ -410,6 +461,8 @@ def _select_single_variant_samples(
410
  remember_key: str,
411
  variant_remember_key: str,
412
  default_count_limit: int,
 
 
413
  ) -> tuple[str, list[str], str, list[int]] | None:
414
  variants = available_variants(store, mask_strategy)
415
  if not variants:
@@ -425,14 +478,41 @@ def _select_single_variant_samples(
425
  default=default_variant,
426
  format_func=prompt_variant_label,
427
  )
428
- persona_ids = _select_artifact_personas(
429
- store,
430
- [variant],
431
- mask_strategy,
432
- widget_scope=f"{scope}:{store_id(store)}",
433
- remember_key=remember_key,
434
- default_count_limit=default_count_limit,
435
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  if not persona_ids:
437
  return None
438
 
 
261
  *,
262
  default_count: int,
263
  include_assistant_default: bool,
264
+ max_count_limit: int | None = None,
265
  ) -> tuple[int, bool]:
266
  count_key = widget_key(
267
  "load",
 
281
  )
282
 
283
  if options.regular_ids:
284
+ max_count = (
285
+ min(max_count_limit, len(options.regular_ids))
286
+ if max_count_limit is not None
287
+ else len(options.regular_ids)
288
+ )
289
  persona_count = st.slider(
290
  "Personas",
291
  min_value=0 if options.assistant_id is not None else 1,
292
+ max_value=max_count,
293
+ value=min(default_count, max_count),
294
  key=count_key,
295
  help="Use the first N available non-assistant personas.",
296
  )
 
316
  remember_key: str,
317
  default_all: bool = False,
318
  default_count_limit: int | None = None,
319
+ max_count_limit: int | None = None,
320
  ) -> list[str]:
321
  empty_message = _personas_empty_message(variants)
322
  options = _load_persona_options(
 
343
  options,
344
  default_count=default_count,
345
  include_assistant_default=include_assistant_default,
346
+ max_count_limit=max_count_limit,
347
  )
348
 
349
  persona_ids = options.regular_ids[:persona_count]
 
369
  return persona_ids
370
 
371
 
372
+ def _render_persona_select_controls(
373
+ options: PersonaOptions,
374
+ widget_scope: str,
375
+ *,
376
+ max_selections: int | None = None,
377
+ ) -> list[str]:
378
+ select_key = widget_key("load", "persona_select", widget_scope)
379
+ assistant_key = widget_key("load", "persona_select_assistant", widget_scope)
380
+
381
+ label_map = {
382
+ persona_id: f"{options.persona_names.get(persona_id, persona_id)} ({persona_id})"
383
+ for persona_id in options.regular_ids
384
+ }
385
+ sorted_labels = sorted(label_map.values())
386
+ selected_labels = st.multiselect(
387
+ "Select personas",
388
+ options=sorted_labels,
389
+ key=select_key,
390
+ placeholder="Search and select personas...",
391
+ max_selections=max_selections,
392
+ )
393
+ label_to_id = {label: persona_id for persona_id, label in label_map.items()}
394
+ selected_ids = [label_to_id[label] for label in selected_labels]
395
+
396
+ if options.assistant_id is not None:
397
+ include_assistant = st.checkbox(
398
+ "Include Assistant persona",
399
+ key=assistant_key,
400
+ )
401
+ if include_assistant:
402
+ selected_ids.append(options.assistant_id)
403
+
404
+ st.session_state[_persona_names_state_key(widget_scope)] = dict(
405
+ options.persona_names
406
+ )
407
+
408
+ if not selected_ids:
409
+ st.info("Select at least one persona.")
410
+
411
+ return selected_ids
412
+
413
+
414
  def _render_save_buttons(
415
  figs: list[object],
416
  filenames: list[str],
 
448
  return render_mask_strategy_select(
449
  key=widget_key("load", "mask_strategy", scope),
450
  last_key=_LAST_MASK_STRATEGY_KEY,
451
+ remember_key="source:last_mask_strategy",
452
  help_text="Which extracted activation set to load.",
453
  )
454
 
 
461
  remember_key: str,
462
  variant_remember_key: str,
463
  default_count_limit: int,
464
+ max_count_limit: int | None = None,
465
+ allow_specific_personas: bool = False,
466
  ) -> tuple[str, list[str], str, list[int]] | None:
467
  variants = available_variants(store, mask_strategy)
468
  if not variants:
 
478
  default=default_variant,
479
  format_func=prompt_variant_label,
480
  )
481
+ widget_scope = f"{scope}:{store_id(store)}"
482
+ select_specific = False
483
+ if allow_specific_personas:
484
+ select_specific = st.toggle(
485
+ "Select specific personas",
486
+ value=False,
487
+ key=widget_key("load", "select_specific_personas", scope, store_id(store)),
488
+ help="Search and select specific personas instead of using the first N.",
489
+ )
490
+
491
+ if select_specific:
492
+ options = _load_persona_options(
493
+ store,
494
+ [variant],
495
+ mask_strategy,
496
+ empty_message=_personas_empty_message([variant]),
497
+ )
498
+ if options is None:
499
+ st.session_state.pop(_persona_names_state_key(widget_scope), None)
500
+ return None
501
+ persona_ids = _render_persona_select_controls(
502
+ options,
503
+ widget_scope,
504
+ max_selections=max_count_limit,
505
+ )
506
+ else:
507
+ persona_ids = _select_artifact_personas(
508
+ store,
509
+ [variant],
510
+ mask_strategy,
511
+ widget_scope=widget_scope,
512
+ remember_key=remember_key,
513
+ default_count_limit=default_count_limit,
514
+ max_count_limit=max_count_limit,
515
+ )
516
  if not persona_ids:
517
  return None
518
 
tabs/analysis/_state.py CHANGED
@@ -4,7 +4,7 @@ import streamlit as st
4
  from persona_data.synth_persona import BASELINE_PERSONA_ID
5
  from persona_vectors.attributes import DEFAULT_MAX_ATTRIBUTE_CATEGORIES
6
 
7
- from utils.helpers import slugify, widget_key
8
 
9
 
10
  def _filename(*parts: str) -> str:
@@ -30,11 +30,15 @@ _LAST_LAYER_FRAMES_KEY = "analysis:last_layer_frames"
30
 
31
  _DEFAULT_LAYER_FRAMES = 16
32
  _DEFAULT_PERSONA_LIMITS = {
33
- "similarity": 120,
34
  "pca": 500,
35
  "umap": 500,
36
  "isomap": 500,
37
- "dendro": 160,
 
 
 
 
38
  }
39
  _MAX_SIMILARITY_CELLS = 4_000_000
40
  _MAX_PAIR_TRAJECTORY_TRACES = 500
@@ -136,28 +140,38 @@ def _sequence_to_list(value: object) -> list[object] | None:
136
 
137
 
138
  _TRACKED_STATE_KEYS_KEY = "analysis:_tracked_state_keys"
 
 
139
 
140
 
141
- def _clear_old_load_states(current_key: str, suffix: str) -> None:
142
- # Only one heavy figure state should live at a time. We track
143
- # the keys we create per suffix so eviction is O(1) instead of scanning
144
- # all of session_state on every rerun. Every such key is passed through
145
- # this function before it is set, so the registry stays authoritative.
146
- tracked: dict[str, set[str]] = st.session_state.setdefault(
147
  _TRACKED_STATE_KEYS_KEY, {}
148
  )
149
- for key in tracked.get(suffix, ()):
150
- if key != current_key:
151
- st.session_state.pop(key, None)
152
- tracked[suffix] = {current_key}
 
153
 
154
 
155
  def _clear_old_figure_states(current_key: str) -> None:
156
- _clear_old_load_states(current_key, "_fig_state")
 
 
 
 
157
 
158
 
159
  def _clear_old_prepared_states(current_key: str) -> None:
160
- _clear_old_load_states(current_key, "_projection_ready")
 
 
 
 
161
 
162
 
163
  def _store_figure_state(key: str, value: object) -> None:
 
4
  from persona_data.synth_persona import BASELINE_PERSONA_ID
5
  from persona_vectors.attributes import DEFAULT_MAX_ATTRIBUTE_CATEGORIES
6
 
7
+ from utils.helpers import env_int, slugify, widget_key
8
 
9
 
10
  def _filename(*parts: str) -> str:
 
30
 
31
  _DEFAULT_LAYER_FRAMES = 16
32
  _DEFAULT_PERSONA_LIMITS = {
33
+ "similarity": 20,
34
  "pca": 500,
35
  "umap": 500,
36
  "isomap": 500,
37
+ "dendro": 20,
38
+ }
39
+ _MAX_PERSONA_COUNTS = {
40
+ "similarity": 100,
41
+ "dendro": 100,
42
  }
43
  _MAX_SIMILARITY_CELLS = 4_000_000
44
  _MAX_PAIR_TRAJECTORY_TRACES = 500
 
140
 
141
 
142
  _TRACKED_STATE_KEYS_KEY = "analysis:_tracked_state_keys"
143
+ _FIGURE_STATE_ENTRIES = env_int("PERSONA_UI_FIGURE_STATE_ENTRIES", 2)
144
+ _PREPARED_STATE_ENTRIES = env_int("PERSONA_UI_PREPARED_STATE_ENTRIES", 4)
145
 
146
 
147
+ def _touch_load_state(current_key: str, suffix: str, *, max_entries: int) -> None:
148
+ # Keep a tiny MRU window of heavy state instead of scanning all of
149
+ # session_state or retaining every figure forever. This makes nearby
150
+ # method-switching feel warm while still giving RAM a hard ceiling.
151
+ tracked: dict[str, list[str]] = st.session_state.setdefault(
 
152
  _TRACKED_STATE_KEYS_KEY, {}
153
  )
154
+ keys = [key for key in tracked.get(suffix, []) if key != current_key]
155
+ keys.append(current_key)
156
+ while len(keys) > max(1, max_entries):
157
+ st.session_state.pop(keys.pop(0), None)
158
+ tracked[suffix] = keys
159
 
160
 
161
  def _clear_old_figure_states(current_key: str) -> None:
162
+ _touch_load_state(
163
+ current_key,
164
+ "_fig_state",
165
+ max_entries=_FIGURE_STATE_ENTRIES,
166
+ )
167
 
168
 
169
  def _clear_old_prepared_states(current_key: str) -> None:
170
+ _touch_load_state(
171
+ current_key,
172
+ "_projection_ready",
173
+ max_entries=_PREPARED_STATE_ENTRIES,
174
+ )
175
 
176
 
177
  def _store_figure_state(key: str, value: object) -> None:
tabs/analysis/cosine.py CHANGED
@@ -4,9 +4,6 @@ import streamlit as st
4
  from persona_vectors.extraction import MaskStrategy
5
  from persona_vectors.plots import plot_layer_similarity
6
 
7
- from utils.analysis_sources import Store, available_variants, store_id
8
- from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
9
-
10
  from tabs.analysis._shared import (
11
  _load_variant_vectors,
12
  _plotly_chart,
@@ -21,6 +18,8 @@ from tabs.analysis._state import (
21
  _filename,
22
  _store_figure_state,
23
  )
 
 
24
 
25
 
26
  def _render_cosine_selection(
 
4
  from persona_vectors.extraction import MaskStrategy
5
  from persona_vectors.plots import plot_layer_similarity
6
 
 
 
 
7
  from tabs.analysis._shared import (
8
  _load_variant_vectors,
9
  _plotly_chart,
 
18
  _filename,
19
  _store_figure_state,
20
  )
21
+ from utils.analysis_sources import Store, available_variants, store_id
22
+ from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
23
 
24
 
25
  def _render_cosine_selection(
tabs/analysis/dendrogram.py CHANGED
@@ -1,15 +1,10 @@
 
 
 
1
  import streamlit as st
2
  from persona_vectors.extraction import MaskStrategy
3
  from persona_vectors.plots import plot_persona_dendrogram
4
-
5
- from utils.analysis_sources import (
6
- Store,
7
- available_variants,
8
- store_cache_parts,
9
- store_id,
10
- store_layers_cached,
11
- )
12
- from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
13
 
14
  from tabs.analysis._shared import (
15
  _load_persona_options,
@@ -17,60 +12,113 @@ from tabs.analysis._shared import (
17
  _plotly_chart,
18
  _release_vector_memory,
19
  _render_layer_frame_controls,
 
20
  _render_save_buttons,
21
  _select_artifact_personas,
22
  )
23
  from tabs.analysis._state import (
24
  _DEFAULT_PERSONA_LIMITS,
25
- PersonaOptions,
26
  _clear_old_figure_states,
27
  _filename,
28
  _persona_names_state_key,
29
  _personas_empty_message,
30
  _store_figure_state,
31
  )
 
 
 
 
 
 
 
 
32
 
33
  _LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro"
34
  _DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
35
 
36
 
37
- def _render_persona_select_controls(
38
- options: PersonaOptions,
39
- widget_scope: str,
40
- ) -> list[str]:
41
- select_key = widget_key("load", "persona_select", widget_scope)
42
- assistant_key = widget_key("load", "persona_select_assistant", widget_scope)
43
-
44
- label_map = {
45
- pid: f"{options.persona_names.get(pid, pid)} ({pid})"
46
- for pid in options.regular_ids
47
- }
48
- sorted_labels = sorted(label_map.values())
49
- selected_labels = st.multiselect(
50
- "Select personas",
51
- options=sorted_labels,
52
- key=select_key,
53
- placeholder="Search and select personas...",
54
  )
55
- label_to_id = {v: k for k, v in label_map.items()}
56
- selected_ids = [label_to_id[lbl] for lbl in selected_labels]
 
 
57
 
58
- if options.assistant_id is not None:
59
- include_assistant = st.checkbox(
60
- "Include Assistant persona",
61
- key=assistant_key,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
- if include_assistant:
64
- selected_ids.append(options.assistant_id)
65
 
66
- st.session_state[_persona_names_state_key(widget_scope)] = dict(
67
- options.persona_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  )
69
-
70
- if not selected_ids:
71
- st.info("Select at least one persona.")
72
-
73
- return selected_ids
 
 
 
 
 
 
 
74
 
75
 
76
  def _render_dendrogram_analysis(
@@ -132,6 +180,7 @@ def _render_dendrogram_analysis(
132
  persona_ids = _render_persona_select_controls(
133
  options,
134
  widget_scope=f"dendro:{store_id(store)}",
 
135
  )
136
  if not persona_ids:
137
  return
@@ -143,6 +192,7 @@ def _render_dendrogram_analysis(
143
  widget_scope=f"dendro:{store_id(store)}",
144
  remember_key=_LAST_DENDRO_PERSONAS_KEY,
145
  default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"],
 
146
  )
147
  if not persona_ids:
148
  return
@@ -221,7 +271,6 @@ def _render_dendrogram_analysis(
221
  title=f"Dendrogram — {prompt_variant_label(variant_a)}",
222
  )
223
  fig_a.update_layout(height=750)
224
- del samples_a
225
  fig_b = None
226
  if variant_a != variant_b:
227
  progress.progress(60, text="Building second dendrogram…")
@@ -236,10 +285,26 @@ def _render_dendrogram_analysis(
236
  )
237
  fig_b.update_layout(height=750)
238
  del samples_b
 
 
 
 
 
 
 
 
 
239
  progress.progress(90, text="Storing figure state…")
240
  _store_figure_state(
241
  fig_key,
242
- (fig_a, fig_b, len(persona_ids), variant_a, variant_b),
 
 
 
 
 
 
 
243
  )
244
  progress.progress(100, text="Done.")
245
  except Exception as exc:
@@ -250,8 +315,16 @@ def _render_dendrogram_analysis(
250
  progress.empty()
251
 
252
  if fig_key in st.session_state:
253
- fig_a, fig_b, n_personas, va, vb = st.session_state[fig_key]
254
- if fig_b is not None:
 
 
 
 
 
 
 
 
255
  col_a, col_b = st.columns(2)
256
  with col_a:
257
  st.subheader(prompt_variant_label(va))
@@ -262,14 +335,22 @@ def _render_dendrogram_analysis(
262
  else:
263
  _plotly_chart(fig_a)
264
 
265
- figs = [fig_a] + ([fig_b] if fig_b else [])
266
- filenames = [
267
- _filename("dendro", store.model_name, mask_strategy.value, va),
268
- *(
269
- [_filename("dendro", store.model_name, mask_strategy.value, vb)]
270
- if fig_b
271
- else []
272
- ),
273
- ]
 
 
 
 
 
 
 
 
274
  _render_save_buttons(figs, filenames, "dendro")
275
  st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
 
1
+ from copy import deepcopy
2
+
3
+ import plotly.graph_objects as go
4
  import streamlit as st
5
  from persona_vectors.extraction import MaskStrategy
6
  from persona_vectors.plots import plot_persona_dendrogram
7
+ from plotly.subplots import make_subplots
 
 
 
 
 
 
 
 
8
 
9
  from tabs.analysis._shared import (
10
  _load_persona_options,
 
12
  _plotly_chart,
13
  _release_vector_memory,
14
  _render_layer_frame_controls,
15
+ _render_persona_select_controls,
16
  _render_save_buttons,
17
  _select_artifact_personas,
18
  )
19
  from tabs.analysis._state import (
20
  _DEFAULT_PERSONA_LIMITS,
21
+ _MAX_PERSONA_COUNTS,
22
  _clear_old_figure_states,
23
  _filename,
24
  _persona_names_state_key,
25
  _personas_empty_message,
26
  _store_figure_state,
27
  )
28
+ from utils.analysis_sources import (
29
+ Store,
30
+ available_variants,
31
+ store_cache_parts,
32
+ store_id,
33
+ store_layers_cached,
34
+ )
35
+ from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
36
 
37
  _LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro"
38
  _DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
39
 
40
 
41
+ def _comparison_dendrogram_figure(
42
+ fig_a: go.Figure,
43
+ fig_b: go.Figure,
44
+ *,
45
+ title_a: str,
46
+ title_b: str,
47
+ ) -> go.Figure:
48
+ """Merge two layered dendrograms so one slider drives both panels."""
49
+ combined = make_subplots(
50
+ rows=1,
51
+ cols=2,
52
+ subplot_titles=(title_a, title_b),
53
+ shared_yaxes=True,
54
+ horizontal_spacing=0.05,
 
 
 
55
  )
56
+ for trace in fig_a.data:
57
+ combined.add_trace(deepcopy(trace), row=1, col=1)
58
+ for trace in fig_b.data:
59
+ combined.add_trace(deepcopy(trace), row=1, col=2)
60
 
61
+ frames: list[go.Frame] = []
62
+ for frame_a, frame_b in zip(fig_a.frames, fig_b.frames, strict=True):
63
+ right_data = []
64
+ for trace in frame_b.data:
65
+ copied = deepcopy(trace)
66
+ copied.update(xaxis="x2", yaxis="y2")
67
+ right_data.append(copied)
68
+ frame_xaxis = frame_a.layout.xaxis.to_plotly_json()
69
+ frame_xaxis2 = frame_b.layout.xaxis.to_plotly_json()
70
+ frame_xaxis2["matches"] = None
71
+ frame_xaxis2["anchor"] = "y2"
72
+ frame_yaxis = frame_a.layout.yaxis.to_plotly_json()
73
+ frame_yaxis2 = frame_b.layout.yaxis.to_plotly_json()
74
+ frame_yaxis2["matches"] = "y"
75
+ frame_yaxis2["anchor"] = "x2"
76
+ frames.append(
77
+ go.Frame(
78
+ name=frame_a.name,
79
+ data=[*deepcopy(frame_a.data), *right_data],
80
+ layout={
81
+ "title": {"text": f"Dendrogram comparison - Layer {frame_a.name}"},
82
+ "xaxis": frame_xaxis,
83
+ "xaxis2": frame_xaxis2,
84
+ "yaxis": frame_yaxis,
85
+ "yaxis2": frame_yaxis2,
86
+ },
87
+ )
88
  )
 
 
89
 
90
+ y_ranges = [
91
+ fig_a.layout.yaxis.range,
92
+ fig_b.layout.yaxis.range,
93
+ ]
94
+ max_y = max(float(axis_range[1]) for axis_range in y_ranges if axis_range)
95
+ first_layer = fig_a.frames[0].name if fig_a.frames else ""
96
+ combined.frames = frames
97
+ combined.update_layout(
98
+ title={
99
+ "text": f"Dendrogram comparison - Layer {first_layer}",
100
+ "font": {"size": 24},
101
+ "y": 0.98,
102
+ "yanchor": "top",
103
+ },
104
+ template="plotly_white",
105
+ height=750,
106
+ margin=dict(t=140, b=260),
107
+ updatemenus=fig_a.layout.updatemenus,
108
+ sliders=fig_a.layout.sliders,
109
  )
110
+ left_xaxis = fig_a.layout.xaxis.to_plotly_json()
111
+ right_xaxis = fig_b.layout.xaxis.to_plotly_json()
112
+ right_xaxis["matches"] = None
113
+ right_xaxis["anchor"] = "y2"
114
+ combined.update_layout(xaxis=left_xaxis, xaxis2=right_xaxis)
115
+ combined.update_xaxes(tickangle=-45, automargin=True)
116
+ combined.update_yaxes(
117
+ title_text=fig_a.layout.yaxis.title.text,
118
+ range=[0.0, max_y],
119
+ automargin=True,
120
+ )
121
+ return combined
122
 
123
 
124
  def _render_dendrogram_analysis(
 
180
  persona_ids = _render_persona_select_controls(
181
  options,
182
  widget_scope=f"dendro:{store_id(store)}",
183
+ max_selections=_MAX_PERSONA_COUNTS["dendro"],
184
  )
185
  if not persona_ids:
186
  return
 
192
  widget_scope=f"dendro:{store_id(store)}",
193
  remember_key=_LAST_DENDRO_PERSONAS_KEY,
194
  default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"],
195
+ max_count_limit=_MAX_PERSONA_COUNTS["dendro"],
196
  )
197
  if not persona_ids:
198
  return
 
271
  title=f"Dendrogram — {prompt_variant_label(variant_a)}",
272
  )
273
  fig_a.update_layout(height=750)
 
274
  fig_b = None
275
  if variant_a != variant_b:
276
  progress.progress(60, text="Building second dendrogram…")
 
285
  )
286
  fig_b.update_layout(height=750)
287
  del samples_b
288
+ del samples_a
289
+ comparison_fig = None
290
+ if fig_b is not None and layered_mode:
291
+ comparison_fig = _comparison_dendrogram_figure(
292
+ fig_a,
293
+ fig_b,
294
+ title_a=prompt_variant_label(variant_a),
295
+ title_b=prompt_variant_label(variant_b),
296
+ )
297
  progress.progress(90, text="Storing figure state…")
298
  _store_figure_state(
299
  fig_key,
300
+ (
301
+ None if comparison_fig is not None else fig_a,
302
+ None if comparison_fig is not None else fig_b,
303
+ comparison_fig,
304
+ len(persona_ids),
305
+ variant_a,
306
+ variant_b,
307
+ ),
308
  )
309
  progress.progress(100, text="Done.")
310
  except Exception as exc:
 
315
  progress.empty()
316
 
317
  if fig_key in st.session_state:
318
+ saved = st.session_state[fig_key]
319
+ if len(saved) == 5:
320
+ # Drop pre-refactor state so hot-reloaded sessions do not unpack the
321
+ # old two-figure payload shape.
322
+ st.session_state.pop(fig_key, None)
323
+ return
324
+ fig_a, fig_b, comparison_fig, n_personas, va, vb = saved
325
+ if comparison_fig is not None:
326
+ _plotly_chart(comparison_fig)
327
+ elif fig_b is not None:
328
  col_a, col_b = st.columns(2)
329
  with col_a:
330
  st.subheader(prompt_variant_label(va))
 
335
  else:
336
  _plotly_chart(fig_a)
337
 
338
+ figs = (
339
+ [comparison_fig]
340
+ if comparison_fig is not None
341
+ else [fig_a] + ([fig_b] if fig_b else [])
342
+ )
343
+ filenames = (
344
+ [_filename("dendro_compare", store.model_name, mask_strategy.value, va, vb)]
345
+ if comparison_fig is not None
346
+ else [
347
+ _filename("dendro", store.model_name, mask_strategy.value, va),
348
+ *(
349
+ [_filename("dendro", store.model_name, mask_strategy.value, vb)]
350
+ if fig_b
351
+ else []
352
+ ),
353
+ ]
354
+ )
355
  _render_save_buttons(figs, filenames, "dendro")
356
  st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
tabs/analysis/layered.py CHANGED
@@ -2,10 +2,7 @@ from collections.abc import Callable
2
 
3
  import plotly.graph_objects as go
4
  import streamlit as st
5
- from persona_vectors.attributes import (
6
- attribute_color_kwargs,
7
- attribute_display_label,
8
- )
9
  from persona_vectors.extraction import MaskStrategy
10
  from persona_vectors.plots import (
11
  build_layered_figure,
@@ -13,19 +10,6 @@ from persona_vectors.plots import (
13
  build_similarity_figures,
14
  )
15
 
16
- from utils.analysis_metadata import (
17
- synth_persona_attribute_names,
18
- synth_persona_dataset_cached,
19
- )
20
- from utils.analysis_sources import (
21
- Store,
22
- kmeans_groups_cached,
23
- projection_data_cached,
24
- store_cache_parts,
25
- store_id,
26
- )
27
- from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
28
-
29
  from tabs.analysis._shared import (
30
  _gray_out_unselected_personas,
31
  _load_persona_vectors,
@@ -61,6 +45,18 @@ from tabs.analysis._state import (
61
  _remembered_selectbox,
62
  _store_figure_state,
63
  )
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  def _render_pair_trajectory_control(
@@ -446,6 +442,8 @@ def _render_layered_figure_analysis(
446
  n_components: int = 2,
447
  remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
448
  default_count_limit: int = 500,
 
 
449
  ) -> None:
450
  """Render a single-variant layered analysis: select → button → figure(s).
451
 
@@ -463,6 +461,8 @@ def _render_layered_figure_analysis(
463
  else _LAST_SIMILARITY_VARIANT_KEY
464
  ),
465
  default_count_limit=default_count_limit,
 
 
466
  )
467
  if selected is None:
468
  return
 
2
 
3
  import plotly.graph_objects as go
4
  import streamlit as st
5
+ from persona_vectors.attributes import attribute_color_kwargs, attribute_display_label
 
 
 
6
  from persona_vectors.extraction import MaskStrategy
7
  from persona_vectors.plots import (
8
  build_layered_figure,
 
10
  build_similarity_figures,
11
  )
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  from tabs.analysis._shared import (
14
  _gray_out_unselected_personas,
15
  _load_persona_vectors,
 
45
  _remembered_selectbox,
46
  _store_figure_state,
47
  )
48
+ from utils.analysis_metadata import (
49
+ synth_persona_attribute_names,
50
+ synth_persona_dataset_cached,
51
+ )
52
+ from utils.analysis_sources import (
53
+ Store,
54
+ kmeans_groups_cached,
55
+ projection_data_cached,
56
+ store_cache_parts,
57
+ store_id,
58
+ )
59
+ from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
60
 
61
 
62
  def _render_pair_trajectory_control(
 
442
  n_components: int = 2,
443
  remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
444
  default_count_limit: int = 500,
445
+ max_count_limit: int | None = None,
446
+ allow_specific_personas: bool = False,
447
  ) -> None:
448
  """Render a single-variant layered analysis: select → button → figure(s).
449
 
 
461
  else _LAST_SIMILARITY_VARIANT_KEY
462
  ),
463
  default_count_limit=default_count_limit,
464
+ max_count_limit=max_count_limit,
465
+ allow_specific_personas=allow_specific_personas,
466
  )
467
  if selected is None:
468
  return
tabs/analysis_core.py CHANGED
@@ -1,27 +1,4 @@
1
- from pathlib import Path
2
-
3
  import streamlit as st
4
- from persona_data.environment import get_artifacts_dir
5
- from persona_vectors.extraction import MaskStrategy
6
-
7
- from utils.analysis_sources import (
8
- DEFAULT_COMPARE_MODEL,
9
- DEFAULT_HUB_REPO,
10
- SOURCE_HUB,
11
- SOURCE_LOCAL,
12
- SOURCES,
13
- Store,
14
- activation_store_cached,
15
- hub_models_by_mask_strategy,
16
- local_model_matches,
17
- local_model_options_cached,
18
- )
19
- from utils.helpers import (
20
- ANALYSIS_HELP_TEXT,
21
- ANALYSIS_MODES,
22
- prompt_variant_label,
23
- widget_key,
24
- )
25
 
26
  from tabs.analysis._shared import _render_mask_strategy_select
27
  from tabs.analysis._state import (
@@ -29,153 +6,18 @@ from tabs.analysis._state import (
29
  _LAST_PROJECTION_DIMS_KEY,
30
  _LAST_SIMILARITY_PERSONAS_KEY,
31
  _LAST_SOURCE_KEY,
 
32
  )
33
  from tabs.analysis.cosine import _render_cosine_similarity
34
  from tabs.analysis.dendrogram import _render_dendrogram_analysis
35
  from tabs.analysis.layered import _render_layered_figure_analysis
36
-
37
-
38
- def _render_source_select() -> str:
39
- last_source = st.session_state.get(_LAST_SOURCE_KEY, SOURCE_HUB)
40
- source = st.segmented_control(
41
- "Source",
42
- options=SOURCES,
43
- default=last_source if last_source in SOURCES else SOURCE_HUB,
44
- key=widget_key("load", "source"),
45
- label_visibility="collapsed",
46
- )
47
- if source is None:
48
- source = SOURCE_HUB
49
- st.session_state[_LAST_SOURCE_KEY] = source
50
- return source
51
-
52
-
53
- def _render_hub_model_select(
54
- repo_id: str,
55
- mask_strategy: MaskStrategy,
56
- ) -> str:
57
- fallback_model = st.session_state.get(
58
- "analysis:hub_model_fallback",
59
- DEFAULT_COMPARE_MODEL,
60
- )
61
- try:
62
- models_by_strategy = hub_models_by_mask_strategy(repo_id)
63
- except Exception as exc:
64
- st.warning(f"Could not load Hub configs for `{repo_id}`: {exc}")
65
- return st.text_input(
66
- "Hub model",
67
- value=fallback_model,
68
- key="analysis:hub_model_fallback",
69
- help="Analysis-only model id to use if Hub config discovery is unavailable.",
70
- )
71
-
72
- model_options = models_by_strategy.get(mask_strategy, [])
73
- if not model_options:
74
- st.warning(
75
- f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`."
76
- )
77
- return st.text_input(
78
- "Hub model",
79
- value=fallback_model,
80
- key="analysis:hub_model_fallback",
81
- help="Analysis-only model id to use for this Hub repo.",
82
- )
83
-
84
- previous_model = st.session_state.get(
85
- widget_key("load", "hub_model", repo_id, mask_strategy.value),
86
- fallback_model,
87
- )
88
- default_model = (
89
- previous_model if previous_model in model_options else model_options[0]
90
- )
91
-
92
- return st.selectbox(
93
- "Hub model",
94
- options=model_options,
95
- index=model_options.index(default_model),
96
- key=widget_key("load", "hub_model", repo_id, mask_strategy.value),
97
- help="Models with vectors in the selected Hub repo and mask strategy.",
98
- )
99
-
100
-
101
- def _render_local_model_select(
102
- artifacts_root: str,
103
- mask_strategy: MaskStrategy,
104
- ) -> str:
105
- fallback_model = st.session_state.get("analysis:local_model", DEFAULT_COMPARE_MODEL)
106
- model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
107
- if not model_options:
108
- return st.text_input(
109
- "Local model",
110
- value=fallback_model,
111
- key="analysis:local_model",
112
- help="Analysis-only local model id or path.",
113
- )
114
-
115
- custom = st.toggle(
116
- "Custom local model",
117
- value=False,
118
- key="analysis:local_model_custom_enabled",
119
- help="Enter a model id/path manually instead of choosing from activation directories.",
120
- )
121
- if custom:
122
- return st.text_input(
123
- "Local model",
124
- value=fallback_model,
125
- key="analysis:local_model",
126
- help="Analysis-only local model id or path.",
127
- )
128
-
129
- previous_model = st.session_state.get("analysis:local_model_select", fallback_model)
130
- if not any(local_model_matches(previous_model, option) for option in model_options):
131
- previous_model = fallback_model
132
- default_model = next(
133
- (
134
- option
135
- for option in model_options
136
- if local_model_matches(option, previous_model)
137
- ),
138
- model_options[0],
139
- )
140
- selected = st.selectbox(
141
- "Local model",
142
- options=model_options,
143
- index=model_options.index(default_model),
144
- key="analysis:local_model_select",
145
- help="Models discovered under the selected artifacts root.",
146
- )
147
- st.session_state["analysis:local_model"] = selected
148
- return selected
149
-
150
-
151
- def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
152
- if source == SOURCE_HUB:
153
- repo = st.text_input(
154
- "Hub repo",
155
- value=st.session_state.get("analysis:hub_repo", DEFAULT_HUB_REPO),
156
- key="analysis:hub_repo",
157
- help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
158
- )
159
- hub_model_name = _render_hub_model_select(repo, mask_strategy)
160
- return activation_store_cached(
161
- SOURCE_HUB,
162
- repo,
163
- hub_model_name,
164
- mask_strategy.value,
165
- )
166
- artifacts_root = st.text_input(
167
- "Artifacts root",
168
- value=str(get_artifacts_dir() / "activations"),
169
- key="analysis:artifacts_root",
170
- )
171
- artifacts_root = str(Path(artifacts_root).expanduser())
172
- local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
173
- return activation_store_cached(
174
- SOURCE_LOCAL,
175
- artifacts_root,
176
- local_model_name,
177
- mask_strategy.value,
178
- )
179
 
180
 
181
  def render_analysis_tab() -> None:
@@ -186,7 +28,7 @@ def render_analysis_tab() -> None:
186
  "Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
187
  )
188
 
189
- source = _render_source_select()
190
 
191
  analysis_mode = st.segmented_control(
192
  "Analysis mode",
@@ -201,7 +43,18 @@ def render_analysis_tab() -> None:
201
 
202
  with st.expander("Source settings", expanded=True):
203
  mask_strategy = _render_mask_strategy_select(analysis_mode)
204
- store = _build_store(source, mask_strategy)
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  if analysis_mode == "Cosine similarity":
207
  _render_cosine_similarity(store, mask_strategy)
@@ -219,6 +72,8 @@ def render_analysis_tab() -> None:
219
  include_pair_trajectories=True,
220
  remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
221
  default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
 
 
222
  )
223
  return
224
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  from tabs.analysis._shared import _render_mask_strategy_select
4
  from tabs.analysis._state import (
 
6
  _LAST_PROJECTION_DIMS_KEY,
7
  _LAST_SIMILARITY_PERSONAS_KEY,
8
  _LAST_SOURCE_KEY,
9
+ _MAX_PERSONA_COUNTS,
10
  )
11
  from tabs.analysis.cosine import _render_cosine_similarity
12
  from tabs.analysis.dendrogram import _render_dendrogram_analysis
13
  from tabs.analysis.layered import _render_layered_figure_analysis
14
+ from utils.helpers import (
15
+ ANALYSIS_HELP_TEXT,
16
+ ANALYSIS_MODES,
17
+ prompt_variant_label,
18
+ widget_key,
19
+ )
20
+ from utils.source_controls import render_source_select, render_store_select
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  def render_analysis_tab() -> None:
 
28
  "Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
29
  )
30
 
31
+ source = render_source_select(widget_scope="load", last_source_key=_LAST_SOURCE_KEY)
32
 
33
  analysis_mode = st.segmented_control(
34
  "Analysis mode",
 
43
 
44
  with st.expander("Source settings", expanded=True):
45
  mask_strategy = _render_mask_strategy_select(analysis_mode)
46
+ store = render_store_select(
47
+ source,
48
+ mask_strategy,
49
+ state_prefix="analysis",
50
+ widget_scope="load",
51
+ artifacts_root_key="analysis:artifacts_root",
52
+ model_label="Hub model",
53
+ local_model_label="Local model",
54
+ allow_custom_local_model=True,
55
+ repo_help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
56
+ fallback_help="Analysis-only model id to use if Hub config discovery is unavailable.",
57
+ )
58
 
59
  if analysis_mode == "Cosine similarity":
60
  _render_cosine_similarity(store, mask_strategy)
 
72
  include_pair_trajectories=True,
73
  remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
74
  default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
75
+ max_count_limit=_MAX_PERSONA_COUNTS["similarity"],
76
+ allow_specific_personas=True,
77
  )
78
  return
79
 
tabs/chat.py CHANGED
@@ -15,6 +15,8 @@ from tabs.chat_shared import (
15
  generate_chat_reply_result,
16
  hydrate_chat_state,
17
  load_chat_personas,
 
 
18
  render_chat_selection,
19
  )
20
  from tabs.chat_ui import (
@@ -25,7 +27,7 @@ from tabs.chat_ui import (
25
  )
26
  from utils.chat import build_chat_messages, resolve_system_prompt
27
  from utils.chat_export import save_chat_export
28
- from utils.helpers import session_key, widget_key
29
  from utils.runtime import cached_model
30
 
31
  if TYPE_CHECKING:
@@ -94,9 +96,26 @@ def _handle_single_chat_generation(
94
  chat_log,
95
  ) -> None:
96
  messages = build_chat_messages(active_system_prompt, chat_state["messages"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  with st.spinner("Generating reply..."):
 
99
  model = cached_model(model_name=model_name)
 
 
100
 
101
  def _show_error(exc: Exception) -> None:
102
  with chat_log:
@@ -108,15 +127,19 @@ def _handle_single_chat_generation(
108
  messages=messages,
109
  remote=remote,
110
  generation=generation,
 
111
  on_error=_show_error,
112
  )
113
  if error is not None:
 
114
  if pending_action == "new_user_prompt" and chat_state["messages"]:
115
  chat_state["messages"].pop()
116
  return
117
  if reply is None:
 
118
  return
119
 
 
120
  chat_state["messages"].append({"role": "assistant", "content": reply.text})
121
  st.rerun()
122
 
 
15
  generate_chat_reply_result,
16
  hydrate_chat_state,
17
  load_chat_personas,
18
+ mark_model_loaded,
19
+ model_load_status,
20
  render_chat_selection,
21
  )
22
  from tabs.chat_ui import (
 
27
  )
28
  from utils.chat import build_chat_messages, resolve_system_prompt
29
  from utils.chat_export import save_chat_export
30
+ from utils.helpers import format_ndif_status, session_key, widget_key
31
  from utils.runtime import cached_model
32
 
33
  if TYPE_CHECKING:
 
96
  chat_log,
97
  ) -> None:
98
  messages = build_chat_messages(active_system_prompt, chat_state["messages"])
99
+ status_box = st.empty()
100
+
101
+ def _show_phase(text: str) -> None:
102
+ status_box.caption(text)
103
+
104
+ def _show_ndif_status(job_id: str, status_name: str, description: str) -> None:
105
+ status_box.caption(
106
+ format_ndif_status(
107
+ job_id,
108
+ status_name,
109
+ description,
110
+ completed_detail="Downloading result...",
111
+ )
112
+ )
113
 
114
  with st.spinner("Generating reply..."):
115
+ _show_phase(model_load_status(model_name))
116
  model = cached_model(model_name=model_name)
117
+ mark_model_loaded(model_name)
118
+ _show_phase("Submitting to NDIF..." if remote else "Generating locally...")
119
 
120
  def _show_error(exc: Exception) -> None:
121
  with chat_log:
 
127
  messages=messages,
128
  remote=remote,
129
  generation=generation,
130
+ on_status=_show_ndif_status if remote else None,
131
  on_error=_show_error,
132
  )
133
  if error is not None:
134
+ status_box.empty()
135
  if pending_action == "new_user_prompt" and chat_state["messages"]:
136
  chat_state["messages"].pop()
137
  return
138
  if reply is None:
139
+ status_box.empty()
140
  return
141
 
142
+ status_box.empty()
143
  chat_state["messages"].append({"role": "assistant", "content": reply.text})
144
  st.rerun()
145
 
tabs/chat_shared.py CHANGED
@@ -23,6 +23,9 @@ class ChatSelection:
23
  changed: bool
24
 
25
 
 
 
 
26
  def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
27
  personas_file_key = session_key("extract", "personas_file")
28
  qa_file_key = session_key("extract", "qa_file")
@@ -84,12 +87,27 @@ def render_chat_selection(
84
  return ChatSelection(selected_persona, prompt_mode, changed)
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def generate_chat_reply_result(
88
  *,
89
  model: object,
90
  messages: list[dict[str, str]],
91
  remote: bool,
92
  generation: GenerationConfig,
 
93
  on_error: Callable[[Exception], None] | None = None,
94
  ) -> tuple[ChatReply | None, Exception | None]:
95
  try:
@@ -98,6 +116,7 @@ def generate_chat_reply_result(
98
  model=model,
99
  messages=messages,
100
  remote=remote,
 
101
  **generation.to_generate_kwargs(),
102
  ),
103
  None,
 
23
  changed: bool
24
 
25
 
26
+ _LOADED_MODEL_NAMES_KEY = session_key("chat", "loaded_model_names")
27
+
28
+
29
  def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
30
  personas_file_key = session_key("extract", "personas_file")
31
  qa_file_key = session_key("extract", "qa_file")
 
87
  return ChatSelection(selected_persona, prompt_mode, changed)
88
 
89
 
90
+ def model_load_status(model_name: str) -> str:
91
+ """Return an honest coarse-grained loading label for the current session."""
92
+
93
+ loaded_names = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
94
+ return "Using cached model..." if model_name in loaded_names else "Loading model..."
95
+
96
+
97
+ def mark_model_loaded(model_name: str) -> None:
98
+ """Remember that this session has already requested a model once."""
99
+
100
+ loaded_names = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
101
+ loaded_names.add(model_name)
102
+
103
+
104
  def generate_chat_reply_result(
105
  *,
106
  model: object,
107
  messages: list[dict[str, str]],
108
  remote: bool,
109
  generation: GenerationConfig,
110
+ on_status: Callable[[str, str, str], None] | None = None,
111
  on_error: Callable[[Exception], None] | None = None,
112
  ) -> tuple[ChatReply | None, Exception | None]:
113
  try:
 
116
  model=model,
117
  messages=messages,
118
  remote=remote,
119
+ on_status=on_status,
120
  **generation.to_generate_kwargs(),
121
  ),
122
  None,
tabs/chat_ui.py CHANGED
@@ -16,6 +16,7 @@ from utils.helpers import (
16
 
17
  if TYPE_CHECKING:
18
  from persona_data.synth_persona import PersonaData
 
19
  from utils.contrast import TokenContrast
20
 
21
  GENERATION_DEFAULTS = {
 
16
 
17
  if TYPE_CHECKING:
18
  from persona_data.synth_persona import PersonaData
19
+
20
  from utils.contrast import TokenContrast
21
 
22
  GENERATION_DEFAULTS = {
tabs/compare_chat.py CHANGED
@@ -14,7 +14,7 @@ from tabs.chat_shared import (
14
  from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
15
  from utils.chat_export import save_chat_export
16
  from utils.contrast import compute_contrast, compute_contrast_pair
17
- from utils.helpers import persona_label, session_key, widget_key
18
  from utils.runtime import cached_model
19
 
20
  from .chat_ui import (
@@ -142,15 +142,40 @@ def _generate_panels(
142
  spinner_label: str,
143
  ) -> list[ChatReply | Exception]:
144
  results: list[ChatReply | Exception] = []
 
145
  with st.spinner(spinner_label):
146
  for panel in panels:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  reply, error = generate_chat_reply_result(
148
  model=model,
149
  messages=build_chat_messages(panel.prompt, panel.state["messages"]),
150
  remote=remote,
151
  generation=generation,
 
152
  )
153
  results.append(reply if error is None else error)
 
154
  return results
155
 
156
 
 
14
  from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
15
  from utils.chat_export import save_chat_export
16
  from utils.contrast import compute_contrast, compute_contrast_pair
17
+ from utils.helpers import format_ndif_status, persona_label, session_key, widget_key
18
  from utils.runtime import cached_model
19
 
20
  from .chat_ui import (
 
142
  spinner_label: str,
143
  ) -> list[ChatReply | Exception]:
144
  results: list[ChatReply | Exception] = []
145
+ status_box = st.empty()
146
  with st.spinner(spinner_label):
147
  for panel in panels:
148
+ panel_label = panel.side.title()
149
+ status_box.caption(
150
+ f"{panel_label}: {'Submitting to NDIF...' if remote else 'Generating locally...'}"
151
+ )
152
+
153
+ def _show_ndif_status(
154
+ job_id: str,
155
+ status_name: str,
156
+ description: str,
157
+ *,
158
+ label: str = panel_label,
159
+ ) -> None:
160
+ status_box.caption(
161
+ format_ndif_status(
162
+ job_id,
163
+ status_name,
164
+ description,
165
+ prefix=label,
166
+ completed_detail="Downloading result...",
167
+ )
168
+ )
169
+
170
  reply, error = generate_chat_reply_result(
171
  model=model,
172
  messages=build_chat_messages(panel.prompt, panel.state["messages"]),
173
  remote=remote,
174
  generation=generation,
175
+ on_status=_show_ndif_status if remote else None,
176
  )
177
  results.append(reply if error is None else error)
178
+ status_box.empty()
179
  return results
180
 
181
 
tabs/extract.py CHANGED
@@ -20,7 +20,7 @@ from utils.datasets import (
20
  warm_qa_in_background,
21
  )
22
  from utils.helpers import (
23
- NDIF_STATUS_ICONS,
24
  persona_label,
25
  prompt_variant_label,
26
  session_key,
@@ -353,8 +353,7 @@ def _run_extraction_plan(
353
  ndif_status_box = st.empty()
354
 
355
  def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
356
- icon = NDIF_STATUS_ICONS.get(status_name, "•")
357
- ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
358
 
359
  with st.spinner("Loading model..."):
360
  model = cached_model(model_name=model_name)
 
20
  warm_qa_in_background,
21
  )
22
  from utils.helpers import (
23
+ format_ndif_status,
24
  persona_label,
25
  prompt_variant_label,
26
  session_key,
 
353
  ndif_status_box = st.empty()
354
 
355
  def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
356
+ ndif_status_box.caption(format_ndif_status(job_id, status_name, description))
 
357
 
358
  with st.spinner("Loading model..."):
359
  model = cached_model(model_name=model_name)
tabs/probe.py CHANGED
@@ -11,43 +11,28 @@ is a thin Streamlit wrapper around them.
11
 
12
  from __future__ import annotations
13
 
14
- from dataclasses import dataclass
15
- from pathlib import Path
16
-
17
  import streamlit as st
18
- from persona_data.environment import get_artifacts_dir
19
  from persona_vectors.analysis import LayeredSamples
20
  from persona_vectors.attributes import attribute_display_label
21
  from persona_vectors.extraction import MaskStrategy
22
  from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layers
23
  from persona_vectors.probes import (
24
  AttributeLabels,
25
- attribute_probe_labels,
26
  default_probe_kinds,
27
- filter_attribute_samples_min_count,
28
  infer_probe_task,
29
  layer_matrix,
30
  save_probe_artifact,
31
  shuffle_label_baseline,
32
- sweep_attribute,
33
  )
34
 
 
35
  from utils.analysis_metadata import (
36
  synth_persona_attribute_names,
37
  synth_persona_dataset_cached,
38
  )
39
  from utils.analysis_sources import (
40
- DEFAULT_COMPARE_MODEL,
41
- DEFAULT_HUB_REPO,
42
- SOURCE_HUB,
43
- SOURCE_LOCAL,
44
- SOURCES,
45
  Store,
46
- activation_store_cached,
47
  available_variants,
48
- hub_models_by_mask_strategy,
49
- load_persona_vectors_cached,
50
- local_model_options_cached,
51
  persona_names_cached,
52
  personas_cached,
53
  store_cache_parts,
@@ -55,6 +40,7 @@ from utils.analysis_sources import (
55
  )
56
  from utils.controls import render_mask_strategy_select
57
  from utils.helpers import widget_key
 
58
 
59
  # ---------------------------------------------------------------------------
60
  # Constants and config
@@ -78,94 +64,6 @@ _SECONDARY_METRIC = {
78
  }
79
 
80
 
81
- @dataclass(frozen=True)
82
- class _SweepInputs:
83
- source: str
84
- location: str
85
- model_name: str
86
- mask_value: str
87
- variant: str
88
- persona_ids: tuple[str, ...]
89
- attributes: tuple[str, ...]
90
- task: str
91
- probe_kinds: tuple[str, ...]
92
- n_pca_components: int | None
93
- layers: tuple[int, ...]
94
- min_class_count: int
95
- seed: int
96
-
97
-
98
- # ---------------------------------------------------------------------------
99
- # Source / store selection (slim mirror of the analysis tab pattern)
100
- # ---------------------------------------------------------------------------
101
-
102
-
103
- def _select_source() -> str:
104
- key = widget_key("probe", "source")
105
- source = st.segmented_control(
106
- "Source",
107
- options=SOURCES,
108
- default=st.session_state.get(key, SOURCE_HUB),
109
- key=key,
110
- label_visibility="collapsed",
111
- )
112
- return source or SOURCE_HUB
113
-
114
-
115
- def _select_store(source: str, mask_strategy: MaskStrategy) -> Store:
116
- if source == SOURCE_HUB:
117
- repo = st.text_input(
118
- "Hub repo",
119
- value=st.session_state.get("probe:hub_repo", DEFAULT_HUB_REPO),
120
- key="probe:hub_repo",
121
- )
122
- models = hub_models_by_mask_strategy(repo).get(mask_strategy, [])
123
- if not models:
124
- st.warning(
125
- f"No Hub vector configs for `{mask_strategy.value}` in `{repo}`."
126
- )
127
- model_name = st.text_input(
128
- "Model",
129
- value=st.session_state.get("probe:hub_model_fallback", DEFAULT_COMPARE_MODEL),
130
- key="probe:hub_model_fallback",
131
- )
132
- else:
133
- previous = st.session_state.get(
134
- widget_key("probe", "hub_model", repo, mask_strategy.value),
135
- models[0],
136
- )
137
- model_name = st.selectbox(
138
- "Model",
139
- options=models,
140
- index=models.index(previous) if previous in models else 0,
141
- key=widget_key("probe", "hub_model", repo, mask_strategy.value),
142
- )
143
- return activation_store_cached(SOURCE_HUB, repo, model_name, mask_strategy.value)
144
-
145
- root = st.text_input(
146
- "Artifacts root",
147
- value=str(get_artifacts_dir() / "activations"),
148
- key="probe:local_root",
149
- )
150
- root = str(Path(root).expanduser())
151
- models = local_model_options_cached(root, mask_strategy.value)
152
- if models:
153
- previous = st.session_state.get("probe:local_model", models[0])
154
- model_name = st.selectbox(
155
- "Model",
156
- options=models,
157
- index=models.index(previous) if previous in models else 0,
158
- key="probe:local_model",
159
- )
160
- else:
161
- model_name = st.text_input(
162
- "Model",
163
- value=st.session_state.get("probe:local_model_fallback", DEFAULT_COMPARE_MODEL),
164
- key="probe:local_model_fallback",
165
- )
166
- return activation_store_cached(SOURCE_LOCAL, root, model_name, mask_strategy.value)
167
-
168
-
169
  def _select_variant(store: Store, mask_strategy: MaskStrategy) -> str | None:
170
  variants = available_variants(store, mask_strategy)
171
  if not variants:
@@ -184,7 +82,9 @@ def _select_personas(
184
  store: Store, variant: str, mask_strategy: MaskStrategy
185
  ) -> list[str]:
186
  source, location, model_name = store_cache_parts(store)
187
- all_ids = personas_cached(source, location, model_name, mask_strategy.value, (variant,))
 
 
188
  if not all_ids:
189
  st.info("No personas found for this variant.")
190
  return []
@@ -225,7 +125,12 @@ def _select_personas(
225
  st.session_state["probe:persona_count"] = count
226
  persona_ids = regular[:count]
227
  persona_names_cached(
228
- source, location, model_name, mask_strategy.value, (variant,), tuple(persona_ids)
 
 
 
 
 
229
  )
230
  st.caption(f"Probing {len(persona_ids)} of {len(regular)} non-assistant personas.")
231
  return persona_ids
@@ -323,13 +228,15 @@ def _select_layers(num_layers: int) -> list[int]:
323
  )
324
  if not fast:
325
  return list(range(num_layers))
326
- return sorted({
327
- 0,
328
- num_layers // 4,
329
- num_layers // 2,
330
- (3 * num_layers) // 4,
331
- num_layers - 1,
332
- })
 
 
333
 
334
 
335
  # ---------------------------------------------------------------------------
@@ -337,66 +244,12 @@ def _select_layers(num_layers: int) -> list[int]:
337
  # ---------------------------------------------------------------------------
338
 
339
 
340
- @st.cache_resource(show_spinner=False)
341
- def _cached_sweep(
342
- inputs: _SweepInputs,
343
- ) -> tuple[
344
- dict[str, list[dict[str, object]]],
345
- dict[str, tuple[AttributeLabels, LayeredSamples]],
346
- ]:
347
- samples = load_persona_vectors_cached(
348
- inputs.source, inputs.location, inputs.model_name,
349
- inputs.mask_value, inputs.variant, inputs.persona_ids,
350
- )
351
- dataset = synth_persona_dataset_cached()
352
- # The min-count filter drops personas per attribute, so each attribute keeps
353
- # its own (labels, samples) pair for the downstream selectivity/save tools.
354
- per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]] = {}
355
-
356
- def _labels_and_samples(attribute: str) -> tuple[AttributeLabels, LayeredSamples]:
357
- if attribute not in per_attr:
358
- labels = attribute_probe_labels(
359
- dataset, attribute, list(inputs.persona_ids), task=inputs.task, # type: ignore[arg-type]
360
- )
361
- probe_samples, labels = filter_attribute_samples_min_count(
362
- samples, labels, min_count=inputs.min_class_count
363
- )
364
- per_attr[attribute] = (labels, probe_samples)
365
- return per_attr[attribute]
366
-
367
- def _sweep(attribute: str, n_pca: int | None) -> list[dict[str, object]]:
368
- labels, probe_samples = _labels_and_samples(attribute)
369
- return sweep_attribute(
370
- probe_samples, labels,
371
- layers=list(inputs.layers),
372
- probe_kinds=list(inputs.probe_kinds), # type: ignore[arg-type]
373
- n_pca_components=n_pca,
374
- seed=inputs.seed,
375
- )
376
-
377
- def _sweep_all(n_pca: int | None) -> list[dict[str, object]]:
378
- rows: list[dict[str, object]] = []
379
- for attribute in inputs.attributes:
380
- rows.extend(_sweep(attribute, n_pca))
381
- return rows
382
-
383
- if inputs.n_pca_components is not None:
384
- # Always overlay the compressed sweep against full activations.
385
- rows_by_label = {
386
- "full": _sweep_all(None),
387
- f"pca{inputs.n_pca_components}": _sweep_all(inputs.n_pca_components),
388
- }
389
- else:
390
- rows_by_label = {"full": _sweep_all(None)}
391
- return rows_by_label, per_attr
392
-
393
-
394
  def _show_sweep(
395
  rows_by_label: dict[str, list[dict[str, object]]],
396
  per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
397
  attributes: tuple[str, ...],
398
  task: str,
399
- inputs: _SweepInputs,
400
  ) -> None:
401
  primary = _PRIMARY_METRIC[task]
402
  secondary = _SECONDARY_METRIC.get(task)
@@ -442,8 +295,7 @@ def _show_sweep(
442
  for label, label_rows in rows_by_label.items():
443
  for attribute in attributes:
444
  attr_rows = [
445
- row for row in label_rows
446
- if row.get("attribute") == attribute
447
  ]
448
  label_best = _best_row(attr_rows)
449
  if label_best is None:
@@ -451,22 +303,23 @@ def _show_sweep(
451
  summary_row: dict[str, object] = {}
452
  if multi_attr:
453
  summary_row["attribute"] = attribute
454
- summary_row.update({
455
- "features": label,
456
- "best_layer": label_best["layer"],
457
- "probe": label_best["probe_kind"],
458
- primary: round(float(label_best[primary]), 3),
459
- f"baseline_{primary}": round(
460
- float(label_best.get(f"baseline_{primary}", float("nan"))), 3
461
- ),
462
- })
 
 
 
463
  summary_rows.append(summary_row)
464
  if summary_rows:
465
  st.dataframe(summary_rows, width="stretch", hide_index=True)
466
 
467
- feature_desc = (
468
- f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
469
- )
470
 
471
  best_attr = str(best["attribute"])
472
  labels, samples = per_attr[best_attr]
@@ -495,7 +348,7 @@ def _render_selectivity_control(
495
  labels: AttributeLabels,
496
  samples: LayeredSamples,
497
  task: str,
498
- inputs: _SweepInputs,
499
  ) -> None:
500
  if task == "numeric":
501
  return # selectivity control is classification-only
@@ -507,14 +360,18 @@ def _render_selectivity_control(
507
  "dataset artifacts, not the property."
508
  )
509
  n_repeats = st.slider(
510
- "Shuffle repeats", min_value=3, max_value=15, value=5,
 
 
 
511
  key="probe:shuffle_repeats",
512
  )
513
  if st.button("Run selectivity control", key="probe:run_shuffle"):
514
  with st.spinner("Running shuffled-label control..."):
515
  X = layer_matrix(samples, int(best["layer"]))
516
  shuffled = shuffle_label_baseline(
517
- X, labels.y,
 
518
  task=task, # type: ignore[arg-type]
519
  layer=int(best["layer"]),
520
  probe_kind=best["probe_kind"], # type: ignore[arg-type]
@@ -539,7 +396,7 @@ def _render_save_artifact(
539
  labels: AttributeLabels,
540
  samples: LayeredSamples,
541
  task: str,
542
- inputs: _SweepInputs,
543
  ) -> None:
544
  def synced_default(key: str, default: str) -> str:
545
  default_key = f"{key}:default"
@@ -575,7 +432,9 @@ def _render_save_artifact(
575
  if st.button("Save", key="probe:save_artifact"):
576
  X = layer_matrix(samples, int(best["layer"]))
577
  directory = save_probe_artifact(
578
- X=X, y=labels.y, labels=labels,
 
 
579
  task=task, # type: ignore[arg-type]
580
  probe_kind=best["probe_kind"], # type: ignore[arg-type]
581
  n_pca_components=inputs.n_pca_components,
@@ -601,14 +460,21 @@ def _render_save_artifact(
601
  def render_probing_tab() -> None:
602
  st.title("Probing")
603
 
604
- source = _select_source()
605
  with st.expander("Source", expanded=True):
606
  mask_strategy = render_mask_strategy_select(
607
  key=widget_key("probe", "mask_strategy"),
608
  last_key="probe:last_mask_strategy",
 
609
  help_text="Which extracted activation set to load.",
610
  )
611
- store = _select_store(source, mask_strategy)
 
 
 
 
 
 
612
  variant = _select_variant(store, mask_strategy)
613
  if variant is None:
614
  return
@@ -644,13 +510,19 @@ def render_probing_tab() -> None:
644
  min_class_count = _MIN_CLASS_COUNT
645
  seed = 0
646
 
647
- inputs = _SweepInputs(
648
- source=source, location=location, model_name=model_name,
649
- mask_value=mask_strategy.value, variant=variant,
650
- persona_ids=tuple(persona_ids), attributes=tuple(attributes), task=task,
 
 
 
 
 
651
  probe_kinds=tuple(probe_kinds),
652
  n_pca_components=n_pca_components,
653
- layers=tuple(layers), min_class_count=min_class_count,
 
654
  seed=int(seed),
655
  )
656
 
@@ -659,7 +531,7 @@ def render_probing_tab() -> None:
659
  if run:
660
  with st.spinner("Evaluating probes across layers..."):
661
  try:
662
- sweep, per_attr = _cached_sweep(inputs)
663
  except Exception as exc:
664
  st.error(f"Sweep failed: {exc}")
665
  st.session_state.pop(state_key, None)
@@ -674,6 +546,9 @@ def render_probing_tab() -> None:
674
  else:
675
  sweep, per_attr, result_inputs = saved_result
676
  _show_sweep(
677
- sweep, per_attr, result_inputs.attributes,
678
- result_inputs.task, result_inputs,
 
 
 
679
  )
 
11
 
12
  from __future__ import annotations
13
 
 
 
 
14
  import streamlit as st
 
15
  from persona_vectors.analysis import LayeredSamples
16
  from persona_vectors.attributes import attribute_display_label
17
  from persona_vectors.extraction import MaskStrategy
18
  from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layers
19
  from persona_vectors.probes import (
20
  AttributeLabels,
 
21
  default_probe_kinds,
 
22
  infer_probe_task,
23
  layer_matrix,
24
  save_probe_artifact,
25
  shuffle_label_baseline,
 
26
  )
27
 
28
+ from tabs.probe_sweep import SweepInputs, cached_sweep
29
  from utils.analysis_metadata import (
30
  synth_persona_attribute_names,
31
  synth_persona_dataset_cached,
32
  )
33
  from utils.analysis_sources import (
 
 
 
 
 
34
  Store,
 
35
  available_variants,
 
 
 
36
  persona_names_cached,
37
  personas_cached,
38
  store_cache_parts,
 
40
  )
41
  from utils.controls import render_mask_strategy_select
42
  from utils.helpers import widget_key
43
+ from utils.source_controls import render_source_select, render_store_select
44
 
45
  # ---------------------------------------------------------------------------
46
  # Constants and config
 
64
  }
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def _select_variant(store: Store, mask_strategy: MaskStrategy) -> str | None:
68
  variants = available_variants(store, mask_strategy)
69
  if not variants:
 
82
  store: Store, variant: str, mask_strategy: MaskStrategy
83
  ) -> list[str]:
84
  source, location, model_name = store_cache_parts(store)
85
+ all_ids = personas_cached(
86
+ source, location, model_name, mask_strategy.value, (variant,)
87
+ )
88
  if not all_ids:
89
  st.info("No personas found for this variant.")
90
  return []
 
125
  st.session_state["probe:persona_count"] = count
126
  persona_ids = regular[:count]
127
  persona_names_cached(
128
+ source,
129
+ location,
130
+ model_name,
131
+ mask_strategy.value,
132
+ (variant,),
133
+ tuple(persona_ids),
134
  )
135
  st.caption(f"Probing {len(persona_ids)} of {len(regular)} non-assistant personas.")
136
  return persona_ids
 
228
  )
229
  if not fast:
230
  return list(range(num_layers))
231
+ return sorted(
232
+ {
233
+ 0,
234
+ num_layers // 4,
235
+ num_layers // 2,
236
+ (3 * num_layers) // 4,
237
+ num_layers - 1,
238
+ }
239
+ )
240
 
241
 
242
  # ---------------------------------------------------------------------------
 
244
  # ---------------------------------------------------------------------------
245
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def _show_sweep(
248
  rows_by_label: dict[str, list[dict[str, object]]],
249
  per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
250
  attributes: tuple[str, ...],
251
  task: str,
252
+ inputs: SweepInputs,
253
  ) -> None:
254
  primary = _PRIMARY_METRIC[task]
255
  secondary = _SECONDARY_METRIC.get(task)
 
295
  for label, label_rows in rows_by_label.items():
296
  for attribute in attributes:
297
  attr_rows = [
298
+ row for row in label_rows if row.get("attribute") == attribute
 
299
  ]
300
  label_best = _best_row(attr_rows)
301
  if label_best is None:
 
303
  summary_row: dict[str, object] = {}
304
  if multi_attr:
305
  summary_row["attribute"] = attribute
306
+ summary_row.update(
307
+ {
308
+ "features": label,
309
+ "best_layer": label_best["layer"],
310
+ "probe": label_best["probe_kind"],
311
+ primary: round(float(label_best[primary]), 3),
312
+ f"baseline_{primary}": round(
313
+ float(label_best.get(f"baseline_{primary}", float("nan"))),
314
+ 3,
315
+ ),
316
+ }
317
+ )
318
  summary_rows.append(summary_row)
319
  if summary_rows:
320
  st.dataframe(summary_rows, width="stretch", hide_index=True)
321
 
322
+ feature_desc = f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
 
 
323
 
324
  best_attr = str(best["attribute"])
325
  labels, samples = per_attr[best_attr]
 
348
  labels: AttributeLabels,
349
  samples: LayeredSamples,
350
  task: str,
351
+ inputs: SweepInputs,
352
  ) -> None:
353
  if task == "numeric":
354
  return # selectivity control is classification-only
 
360
  "dataset artifacts, not the property."
361
  )
362
  n_repeats = st.slider(
363
+ "Shuffle repeats",
364
+ min_value=3,
365
+ max_value=15,
366
+ value=5,
367
  key="probe:shuffle_repeats",
368
  )
369
  if st.button("Run selectivity control", key="probe:run_shuffle"):
370
  with st.spinner("Running shuffled-label control..."):
371
  X = layer_matrix(samples, int(best["layer"]))
372
  shuffled = shuffle_label_baseline(
373
+ X,
374
+ labels.y,
375
  task=task, # type: ignore[arg-type]
376
  layer=int(best["layer"]),
377
  probe_kind=best["probe_kind"], # type: ignore[arg-type]
 
396
  labels: AttributeLabels,
397
  samples: LayeredSamples,
398
  task: str,
399
+ inputs: SweepInputs,
400
  ) -> None:
401
  def synced_default(key: str, default: str) -> str:
402
  default_key = f"{key}:default"
 
432
  if st.button("Save", key="probe:save_artifact"):
433
  X = layer_matrix(samples, int(best["layer"]))
434
  directory = save_probe_artifact(
435
+ X=X,
436
+ y=labels.y,
437
+ labels=labels,
438
  task=task, # type: ignore[arg-type]
439
  probe_kind=best["probe_kind"], # type: ignore[arg-type]
440
  n_pca_components=inputs.n_pca_components,
 
460
  def render_probing_tab() -> None:
461
  st.title("Probing")
462
 
463
+ source = render_source_select(widget_scope="probe")
464
  with st.expander("Source", expanded=True):
465
  mask_strategy = render_mask_strategy_select(
466
  key=widget_key("probe", "mask_strategy"),
467
  last_key="probe:last_mask_strategy",
468
+ remember_key="source:last_mask_strategy",
469
  help_text="Which extracted activation set to load.",
470
  )
471
+ store = render_store_select(
472
+ source,
473
+ mask_strategy,
474
+ state_prefix="probe",
475
+ widget_scope="probe",
476
+ artifacts_root_key="probe:local_root",
477
+ )
478
  variant = _select_variant(store, mask_strategy)
479
  if variant is None:
480
  return
 
510
  min_class_count = _MIN_CLASS_COUNT
511
  seed = 0
512
 
513
+ inputs = SweepInputs(
514
+ source=source,
515
+ location=location,
516
+ model_name=model_name,
517
+ mask_value=mask_strategy.value,
518
+ variant=variant,
519
+ persona_ids=tuple(persona_ids),
520
+ attributes=tuple(attributes),
521
+ task=task,
522
  probe_kinds=tuple(probe_kinds),
523
  n_pca_components=n_pca_components,
524
+ layers=tuple(layers),
525
+ min_class_count=min_class_count,
526
  seed=int(seed),
527
  )
528
 
 
531
  if run:
532
  with st.spinner("Evaluating probes across layers..."):
533
  try:
534
+ sweep, per_attr = cached_sweep(inputs)
535
  except Exception as exc:
536
  st.error(f"Sweep failed: {exc}")
537
  st.session_state.pop(state_key, None)
 
546
  else:
547
  sweep, per_attr, result_inputs = saved_result
548
  _show_sweep(
549
+ sweep,
550
+ per_attr,
551
+ result_inputs.attributes,
552
+ result_inputs.task,
553
+ result_inputs,
554
  )
tabs/probe_sweep.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import streamlit as st
6
+ from persona_vectors.analysis import LayeredSamples
7
+ from persona_vectors.probes import (
8
+ AttributeLabels,
9
+ attribute_probe_labels,
10
+ filter_attribute_samples_min_count,
11
+ sweep_attribute,
12
+ )
13
+
14
+ from utils.analysis_metadata import synth_persona_dataset_cached
15
+ from utils.analysis_sources import load_persona_vectors_cached
16
+ from utils.helpers import env_int
17
+
18
+ _SWEEP_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_SWEEP_CACHE_ENTRIES", 4)
19
+
20
+
21
@dataclass(frozen=True)
class SweepInputs:
    """Immutable description of one probing sweep.

    ``cached_sweep`` receives an instance as its only argument, so these
    fields together determine what is loaded, labelled and swept.
    """

    # The next six fields are forwarded, in this order, to
    # ``load_persona_vectors_cached``.
    source: str
    location: str
    model_name: str
    mask_value: str
    variant: str
    persona_ids: tuple[str, ...]
    # Attribute names to probe; labels and samples are prepared per attribute.
    attributes: tuple[str, ...]
    # Forwarded as ``task=`` to ``attribute_probe_labels``.
    task: str
    # Forwarded (as a list) to ``sweep_attribute(probe_kinds=...)``.
    probe_kinds: tuple[str, ...]
    # ``None`` runs only the "full" sweep; an int adds a second ``pca<N>`` sweep.
    n_pca_components: int | None
    # Forwarded (as a list) to ``sweep_attribute(layers=...)``.
    layers: tuple[int, ...]
    # Forwarded as ``min_count=`` to ``filter_attribute_samples_min_count``.
    min_class_count: int
    # Forwarded as ``seed=`` to ``sweep_attribute``.
    seed: int
36
+
37
+
38
@st.cache_resource(show_spinner=False, max_entries=_SWEEP_CACHE_ENTRIES)
def cached_sweep(
    inputs: SweepInputs,
) -> tuple[
    dict[str, list[dict[str, object]]],
    dict[str, tuple[AttributeLabels, LayeredSamples]],
]:
    """Run the probe sweep described by ``inputs``.

    Returns:
        A pair ``(rows_by_label, per_attr)``:

        * ``rows_by_label`` maps a feature-set label to the concatenated
          ``sweep_attribute`` rows for every requested attribute. It always
          holds ``"full"`` and, when ``inputs.n_pca_components`` is set, a
          ``"pca<N>"`` entry as well.
        * ``per_attr`` maps each attribute to the ``(labels, samples)`` pair
          left after ``filter_attribute_samples_min_count``.
    """
    all_samples = load_persona_vectors_cached(
        inputs.source,
        inputs.location,
        inputs.model_name,
        inputs.mask_value,
        inputs.variant,
        inputs.persona_ids,
    )
    dataset = synth_persona_dataset_cached()
    per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]] = {}

    def prepared(attribute: str) -> tuple[AttributeLabels, LayeredSamples]:
        # Labels and filtered samples are built lazily, once per attribute,
        # and reused by the PCA pass.
        try:
            return per_attr[attribute]
        except KeyError:
            pass
        raw_labels = attribute_probe_labels(
            dataset,
            attribute,
            list(inputs.persona_ids),
            task=inputs.task,  # type: ignore[arg-type]
        )
        kept_samples, kept_labels = filter_attribute_samples_min_count(
            all_samples,
            raw_labels,
            min_count=inputs.min_class_count,
        )
        per_attr[attribute] = (kept_labels, kept_samples)
        return per_attr[attribute]

    # The full-feature sweep always runs first; the PCA sweep is optional.
    feature_sets: list[tuple[str, int | None]] = [("full", None)]
    if inputs.n_pca_components is not None:
        feature_sets.append(
            (f"pca{inputs.n_pca_components}", inputs.n_pca_components)
        )

    rows_by_label: dict[str, list[dict[str, object]]] = {}
    for feature_label, n_pca in feature_sets:
        collected: list[dict[str, object]] = []
        for attribute in inputs.attributes:
            attr_labels, attr_samples = prepared(attribute)
            collected.extend(
                sweep_attribute(
                    attr_samples,
                    attr_labels,
                    layers=list(inputs.layers),
                    probe_kinds=list(inputs.probe_kinds),  # type: ignore[arg-type]
                    n_pca_components=n_pca,
                    seed=inputs.seed,
                )
            )
        rows_by_label[feature_label] = collected
    return rows_by_label, per_attr
tabs/probe_ui.py CHANGED
@@ -6,7 +6,15 @@ import streamlit as st
6
  import torch
7
 
8
  from utils.chat import build_chat_messages
9
- from utils.helpers import session_key, widget_key
 
 
 
 
 
 
 
 
10
  from utils.probe_overlay import (
11
  attach_overlays,
12
  build_classification_overlays,
@@ -15,24 +23,25 @@ from utils.probe_overlay import (
15
  )
16
  from utils.probe_trace import ConversationTrace, trace_conversation
17
  from utils.probes import (
18
- DEFAULT_LOCAL_PROBE_DIR,
19
- DEFAULT_PROBE_REPO,
20
  LoadedProbe,
21
- list_local_probe_files,
22
- list_probe_files,
23
  load_local_probe,
24
  load_probe,
25
  load_probe_from_bytes,
26
- model_probe_dir_name,
27
- parse_probe_filename,
28
  )
29
  from utils.runtime import cached_model
 
30
 
31
  _LAST_SOURCE_KEY = session_key("probe", "last_source")
32
  _LAST_LOCAL_FILE_KEY = session_key("probe", "last_local_file")
33
  _LAST_HUB_FILE_KEY = session_key("probe", "last_hub_file")
34
 
35
  _PROBE_SOURCES = ("Local artifact", "Hugging Face repo", "Upload .pt")
 
 
 
 
 
 
36
 
37
 
38
  # ---------------------------------------------------------------------------
@@ -62,23 +71,16 @@ def _default_file(files: list[str], remembered: str | None) -> str:
62
  return files[0]
63
 
64
 
65
- def _render_probe_selector(
66
- *, context_key: str, model_name: str
67
- ) -> LoadedProbe | None:
68
  """Inline source + file selector. Returns the loaded probe or None."""
69
- source_key = widget_key(context_key, "probe_source")
70
- if source_key not in st.session_state:
71
- st.session_state[source_key] = st.session_state.get(
72
- _LAST_SOURCE_KEY, _PROBE_SOURCES[0]
73
- )
74
- source = st.segmented_control(
75
  "Probe source",
76
  options=_PROBE_SOURCES,
77
- key=source_key,
 
 
78
  label_visibility="collapsed",
79
  )
80
- source = source or _PROBE_SOURCES[0]
81
- st.session_state[_LAST_SOURCE_KEY] = source
82
 
83
  if source == "Local artifact":
84
  return _render_local_probe(context_key=context_key, model_name=model_name)
@@ -87,9 +89,7 @@ def _render_probe_selector(
87
  return _render_upload_probe(context_key=context_key)
88
 
89
 
90
- def _render_local_probe(
91
- *, context_key: str, model_name: str
92
- ) -> LoadedProbe | None:
93
  root_dir = st.text_input(
94
  "Probe directory",
95
  value=st.session_state.get(
@@ -118,9 +118,7 @@ def _render_local_probe(
118
  return None
119
 
120
 
121
- def _render_hub_probe(
122
- *, context_key: str, model_name: str
123
- ) -> LoadedProbe | None:
124
  repo_id = st.text_input(
125
  "Probe repo",
126
  value=st.session_state.get(
@@ -249,15 +247,43 @@ def _validate(
249
  # ---------------------------------------------------------------------------
250
 
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  def _classification_predictions(
253
  probe: LoadedProbe, activations: torch.Tensor, cache_key: str
254
  ) -> tuple[torch.Tensor, torch.Tensor]:
255
  full_key = widget_key("probe_predictions", cache_key, str(id(probe)))
256
- cached = st.session_state.get(full_key)
257
  if cached is not None:
258
  return cached
259
  _, probs, predicted = probe.run_batch(activations)
260
- st.session_state[full_key] = (probs, predicted)
261
  return probs, predicted
262
 
263
 
@@ -265,11 +291,11 @@ def _regression_values(
265
  probe: LoadedProbe, activations: torch.Tensor, cache_key: str
266
  ) -> torch.Tensor:
267
  full_key = widget_key("probe_values", cache_key, str(id(probe)))
268
- cached = st.session_state.get(full_key)
269
  if cached is not None:
270
  return cached
271
  values = probe.predict_batch(activations)
272
- st.session_state[full_key] = values
273
  return values
274
 
275
 
@@ -297,9 +323,7 @@ def _apply_overlays(
297
  probs, predicted = _classification_predictions(
298
  probe, trace.activations, trace.cache_key
299
  )
300
- binary = probs.shape[1] == 1 or (
301
- probs.shape[1] == 2 and len(probe.labels) == 2
302
- )
303
  overlays = build_classification_overlays(
304
  trace=trace,
305
  probs=probs,
@@ -332,9 +356,7 @@ def render_probe_inspector(
332
  def _conversation_sig() -> int:
333
  return hash(
334
  tuple(
335
- (m.get("role"), m.get("content"))
336
- for m in messages
337
- if m.get("content")
338
  )
339
  )
340
 
@@ -349,9 +371,7 @@ def render_probe_inspector(
349
  st.caption("Probe overlay shows up after the first assistant reply.")
350
  return
351
 
352
- probe = _render_probe_selector(
353
- context_key=context_key, model_name=model_name
354
- )
355
  if probe is None:
356
  _reset()
357
  return
 
6
  import torch
7
 
8
  from utils.chat import build_chat_messages
9
+ from utils.helpers import env_int, session_key, widget_key
10
+ from utils.probe_files import (
11
+ DEFAULT_LOCAL_PROBE_DIR,
12
+ DEFAULT_PROBE_REPO,
13
+ list_local_probe_files,
14
+ list_probe_files,
15
+ model_probe_dir_name,
16
+ parse_probe_filename,
17
+ )
18
  from utils.probe_overlay import (
19
  attach_overlays,
20
  build_classification_overlays,
 
23
  )
24
  from utils.probe_trace import ConversationTrace, trace_conversation
25
  from utils.probes import (
 
 
26
  LoadedProbe,
 
 
27
  load_local_probe,
28
  load_probe,
29
  load_probe_from_bytes,
 
 
30
  )
31
  from utils.runtime import cached_model
32
+ from utils.selection_controls import remembered_segmented_control
33
 
34
  _LAST_SOURCE_KEY = session_key("probe", "last_source")
35
  _LAST_LOCAL_FILE_KEY = session_key("probe", "last_local_file")
36
  _LAST_HUB_FILE_KEY = session_key("probe", "last_hub_file")
37
 
38
  _PROBE_SOURCES = ("Local artifact", "Hugging Face repo", "Upload .pt")
39
+ _DERIVED_CACHE_TRACKER_KEY = session_key("probe", "derived_cache_keys")
40
+ # Keep enough room for the three retained traces plus a few recently explored
41
+ # probes per trace. Derived outputs are much smaller than the trace activations
42
+ # themselves, so this avoids needless recomputation without reopening
43
+ # unbounded growth.
44
+ _DERIVED_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES", 12)
45
 
46
 
47
  # ---------------------------------------------------------------------------
 
71
  return files[0]
72
 
73
 
74
+ def _render_probe_selector(*, context_key: str, model_name: str) -> LoadedProbe | None:
 
 
75
  """Inline source + file selector. Returns the loaded probe or None."""
76
+ source = remembered_segmented_control(
 
 
 
 
 
77
  "Probe source",
78
  options=_PROBE_SOURCES,
79
+ key=widget_key(context_key, "probe_source"),
80
+ remember_key=_LAST_SOURCE_KEY,
81
+ default=_PROBE_SOURCES[0],
82
  label_visibility="collapsed",
83
  )
 
 
84
 
85
  if source == "Local artifact":
86
  return _render_local_probe(context_key=context_key, model_name=model_name)
 
89
  return _render_upload_probe(context_key=context_key)
90
 
91
 
92
+ def _render_local_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
 
 
93
  root_dir = st.text_input(
94
  "Probe directory",
95
  value=st.session_state.get(
 
118
  return None
119
 
120
 
121
+ def _render_hub_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
 
 
122
  repo_id = st.text_input(
123
  "Probe repo",
124
  value=st.session_state.get(
 
247
  # ---------------------------------------------------------------------------
248
 
249
 
250
def _store_derived_cache(key: str, value: object) -> None:
    """Store one derived probe result while keeping a small MRU window.

    The tracker list under ``_DERIVED_CACHE_TRACKER_KEY`` holds the keys of
    stored results from least to most recently used. Writing ``key`` moves it
    to the end; once the list is longer than the limit, the oldest keys are
    removed from both the tracker and ``st.session_state``.

    Args:
        key: Session-state key under which ``value`` is stored.
        value: The derived result to keep.
    """
    # A zero or negative limit (for example from a mis-set environment
    # variable) would either leave ``value`` stored but untracked, so it is
    # never evicted, or pop from an empty tracker and raise IndexError.
    # Always keep room for the entry being written.
    limit = max(1, _DERIVED_CACHE_ENTRIES)

    tracked = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
    if not isinstance(tracked, list):
        # Missing or corrupted tracker: start over with an empty window.
        tracked = []
    tracked = [existing for existing in tracked if existing != key]
    tracked.append(key)
    while len(tracked) > limit:
        st.session_state.pop(tracked.pop(0), None)
    st.session_state[_DERIVED_CACHE_TRACKER_KEY] = tracked
    st.session_state[key] = value
262
+
263
+
264
def _get_derived_cache(key: str) -> object | None:
    """Look up a derived probe result and mark it as most recently used.

    Returns ``None`` when nothing is stored under ``key``. On a hit, ``key``
    is moved to the end of the tracker list, provided the tracker is a list
    that already contains it.
    """
    value = st.session_state.get(key)
    if value is not None:
        order = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
        if isinstance(order, list) and key in order:
            # Write back a fresh list rather than mutating the stored one.
            st.session_state[_DERIVED_CACHE_TRACKER_KEY] = [
                *(other for other in order if other != key),
                key,
            ]
    return value
276
+
277
+
278
  def _classification_predictions(
279
  probe: LoadedProbe, activations: torch.Tensor, cache_key: str
280
  ) -> tuple[torch.Tensor, torch.Tensor]:
281
  full_key = widget_key("probe_predictions", cache_key, str(id(probe)))
282
+ cached = _get_derived_cache(full_key)
283
  if cached is not None:
284
  return cached
285
  _, probs, predicted = probe.run_batch(activations)
286
+ _store_derived_cache(full_key, (probs, predicted))
287
  return probs, predicted
288
 
289
 
 
291
  probe: LoadedProbe, activations: torch.Tensor, cache_key: str
292
  ) -> torch.Tensor:
293
  full_key = widget_key("probe_values", cache_key, str(id(probe)))
294
+ cached = _get_derived_cache(full_key)
295
  if cached is not None:
296
  return cached
297
  values = probe.predict_batch(activations)
298
+ _store_derived_cache(full_key, values)
299
  return values
300
 
301
 
 
323
  probs, predicted = _classification_predictions(
324
  probe, trace.activations, trace.cache_key
325
  )
326
+ binary = probs.shape[1] == 1 or (probs.shape[1] == 2 and len(probe.labels) == 2)
 
 
327
  overlays = build_classification_overlays(
328
  trace=trace,
329
  probs=probs,
 
356
  def _conversation_sig() -> int:
357
  return hash(
358
  tuple(
359
+ (m.get("role"), m.get("content")) for m in messages if m.get("content")
 
 
360
  )
361
  )
362
 
 
371
  st.caption("Probe overlay shows up after the first assistant reply.")
372
  return
373
 
374
+ probe = _render_probe_selector(context_key=context_key, model_name=model_name)
 
 
375
  if probe is None:
376
  _reset()
377
  return
tests/test_datasets.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from utils import datasets
4
+
5
+
6
class _Progress:
    """Test double that records every ``progress`` call it receives."""

    def __init__(self) -> None:
        # Each entry is the (value, text) pair of one progress() call.
        self.updates: list[tuple[float, str | None]] = []

    def progress(self, value: float, *, text: str | None = None) -> None:
        """Record one update instead of rendering it."""
        update = (value, text)
        self.updates.append(update)
12
+
13
+
14
def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch):
    """Only the file reported as uncached is downloaded, with a warning and progress."""
    warnings: list[str] = []
    downloads: list[tuple[str, str, str]] = []
    progress = _Progress()

    def fake_is_cached(_repo, filename):
        return filename == "already.jsonl"

    def fake_progress(value, *, text=None):
        return progress

    def fake_download(repo, filename, *, repo_type):
        downloads.append((repo, filename, repo_type))

    monkeypatch.setattr(datasets, "_is_cached", fake_is_cached)
    monkeypatch.setattr(datasets.st, "warning", warnings.append)
    monkeypatch.setattr(datasets.st, "progress", fake_progress)
    monkeypatch.setattr(datasets, "hf_hub_download", fake_download)

    datasets._download_missing_startup_files_if_needed(
        "org/repo",
        ("already.jsonl", "missing.jsonl"),
        "Example",
    )

    assert warnings and "First-time setup for Example" in warnings[0]
    assert downloads == [("org/repo", "missing.jsonl", "dataset")]
    assert progress.updates[-1] == (1.0, "Downloaded missing.jsonl (1/1)")
47
+
48
+
49
def test_download_missing_startup_files_stays_quiet_when_cached(monkeypatch):
    """With every file cached, no warning, progress bar or download is triggered."""

    def unexpected(*_args, **_kwargs):
        raise AssertionError("cold-download UI should not render for warm cache")

    monkeypatch.setattr(datasets, "_is_cached", lambda *_args: True)
    # Any of these being called means the warm-cache path was not taken.
    for target, attribute in (
        (datasets.st, "warning"),
        (datasets.st, "progress"),
        (datasets, "hf_hub_download"),
    ):
        monkeypatch.setattr(target, attribute, unexpected)

    datasets._download_missing_startup_files_if_needed(
        "org/repo",
        ("cached.jsonl",),
        "Example",
    )
64
+
65
+
66
def test_prepare_nemotron_prefetches_first_parquet_shard(monkeypatch):
    """Only the ``00000`` parquet shard is requested, even when listed last."""
    calls: list[tuple[str, tuple[str, ...], str]] = []
    repo_listing = (
        "README.md",
        "data/train-00001-of-00002.parquet",
        "data/train-00000-of-00002.parquet",
    )

    def fake_list_repo_files(*_args, **_kwargs):
        return repo_listing

    def record_download(repo, filenames, label):
        calls.append((repo, filenames, label))

    monkeypatch.setattr(datasets, "list_repo_files", fake_list_repo_files)
    monkeypatch.setattr(
        datasets,
        "_download_missing_startup_files_if_needed",
        record_download,
    )

    datasets._prepare_nemotron_startup_download(
        datasets.DatasetSource.NEMOTRON_USA.value,
        "Nemotron USA",
    )

    expected_call = (
        "nvidia/Nemotron-Personas-USA",
        ("data/train-00000-of-00002.parquet",),
        "Nemotron USA",
    )
    assert calls == [expected_call]
95
+
96
+
97
def test_warm_qa_makes_synth_qa_download_visible_before_thread(monkeypatch):
    """Warming QA requests the QA file through the download helper and starts a thread."""
    calls: list[tuple[str, tuple[str, ...], str]] = []
    started: list[bool] = []

    class DummySynth:
        def prefetch_qa(self) -> None:
            pass

    class DummyThread:
        # Accepts any constructor arguments and never runs a target.
        def __init__(self, *args, **kwargs) -> None:
            pass

        def start(self) -> None:
            started.append(True)

    def record_download(repo, filenames, label):
        calls.append((repo, filenames, label))

    monkeypatch.setattr(datasets, "SynthPersonaDataset", DummySynth)
    monkeypatch.setattr(
        datasets,
        "_download_missing_startup_files_if_needed",
        record_download,
    )
    monkeypatch.setattr(datasets.threading, "Thread", DummyThread)

    datasets.warm_qa_in_background(DummySynth())

    expected_call = (
        "implicit-personalization/synth-persona",
        ("dataset_qa.jsonl",),
        "SynthPersona QA",
    )
    assert calls == [expected_call]
    assert started == [True]
tests/test_probe_cache_bounds.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ from tabs import probe_ui
6
+ from utils import probe_trace
7
+
8
+
9
def test_store_derived_cache_evicts_oldest(monkeypatch):
    """With a limit of two, storing a third key drops the first one."""
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_ui.st, "session_state", session_state)
    monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)

    for value, key in enumerate(("k1", "k2", "k3"), start=1):
        probe_ui._store_derived_cache(key, value)

    assert "k1" not in session_state
    assert session_state["k2"] == 2
    assert session_state["k3"] == 3
    assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k2", "k3"]
22
+
23
+
24
def test_get_derived_cache_refreshes_recently_used_entry(monkeypatch):
    """Reading a key protects it from the next eviction."""
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_ui.st, "session_state", session_state)
    monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)

    probe_ui._store_derived_cache("k1", 1)
    probe_ui._store_derived_cache("k2", 2)

    # Touch k1 so that k2 becomes the oldest entry.
    refreshed = probe_ui._get_derived_cache("k1")
    assert refreshed == 1
    probe_ui._store_derived_cache("k3", 3)

    tracker = session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY]
    assert "k1" in session_state
    assert "k2" not in session_state
    assert tracker == ["k1", "k3"]
38
+
39
+
40
def test_trace_eviction_drops_derived_results(monkeypatch):
    """Evicting a cached trace also removes derived results keyed on it.

    With room for a single trace, storing "new" after "old" must delete the
    derived entry whose key mentions "old" and keep the one for "new".
    """
    session_state: dict[str, object] = {}
    monkeypatch.setattr(probe_trace.st, "session_state", session_state)
    # Only one trace may stay cached, so the second store forces an eviction.
    monkeypatch.setattr(probe_trace, "_MAX_CACHED_TRACES", 1)

    trace = probe_trace.ConversationTrace(
        cache_key="old",
        model_name="m",
        remote=False,
        prompt_text="p",
        prompt_hash="h",
        layer=0,
        location="post_reasoning",
        input_ids=torch.tensor([1]),
        activations=torch.zeros((1, 1)),
        tokens=["x"],
        assistant_spans=[],
        is_special=torch.tensor([False]),
    )
    # Derived-result keys embed the trace cache key ("old" / "new").
    old_prediction_key = "probe_predictions::old::probe"
    kept_prediction_key = "probe_predictions::new::probe"
    session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] = [
        old_prediction_key,
        kept_prediction_key,
    ]
    session_state[old_prediction_key] = object()
    session_state[kept_prediction_key] = object()

    probe_trace._store_cached_trace("old", trace)
    # Same trace contents under a different cache key.
    probe_trace._store_cached_trace(
        "new",
        probe_trace.ConversationTrace(
            **{**trace.__dict__, "cache_key": "new"},
        ),
    )

    assert old_prediction_key not in session_state
    assert kept_prediction_key in session_state
    assert session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] == [
        kept_prediction_key
    ]
tests/test_probe_sweep.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from types import SimpleNamespace
4
+
5
+ import torch
6
+ from persona_vectors.analysis import LayeredSamples
7
+ from persona_vectors.probes import AttributeLabels
8
+
9
+ from tabs import probe_sweep
10
+
11
+
12
def test_cached_sweep_keeps_per_attribute_samples_and_full_plus_pca(monkeypatch):
    """``cached_sweep`` sweeps every attribute for "full" and then for PCA.

    All collaborators of ``probe_sweep`` are replaced by fakes, so the test
    pins the orchestration only: result labels, per-attribute bookkeeping and
    the order of ``sweep_attribute`` calls.
    """
    # NOTE(review): the (3, 2, 4) shape presumably means 3 personas x 2 layers
    # x 4 features, matching the three labels and layers (0, 1) — confirm.
    samples = LayeredSamples(
        vectors=torch.zeros((3, 2, 4)),
        labels=["p0", "p1", "p2"],
        hover_text=["p0", "p1", "p2"],
    )
    # Records (attribute, n_pca_components) for each sweep_attribute call.
    sweep_calls: list[tuple[str, int | None]] = []

    monkeypatch.setattr(
        probe_sweep,
        "load_persona_vectors_cached",
        lambda *args: samples,
    )
    monkeypatch.setattr(
        probe_sweep,
        "synth_persona_dataset_cached",
        lambda: SimpleNamespace(),
    )

    def labels_for(_dataset, attribute, _persona_ids, *, task):
        # Same fixed two-class labelling for every attribute.
        return AttributeLabels(
            attribute_name=attribute,
            task=task,
            y=torch.tensor([0, 1, 0]).numpy(),
            labels=["a", "b", "a"],
            class_names=["a", "b"],
        )

    monkeypatch.setattr(probe_sweep, "attribute_probe_labels", labels_for)

    def filtered(input_samples, labels, *, min_count):
        # Pass-through filter that also checks min_class_count is forwarded.
        assert min_count == 2
        return input_samples, labels

    monkeypatch.setattr(
        probe_sweep,
        "filter_attribute_samples_min_count",
        filtered,
    )

    def sweep(input_samples, labels, *, layers, probe_kinds, n_pca_components, seed):
        # Tuples from SweepInputs must arrive as lists; the samples must be
        # the very object returned by the (pass-through) filter.
        assert input_samples is samples
        assert layers == [0, 1]
        assert probe_kinds == ["logistic_regression"]
        assert seed == 0
        sweep_calls.append((labels.attribute_name, n_pca_components))
        return [
            {
                "attribute": labels.attribute_name,
                "layer": 0,
                "probe_kind": probe_kinds[0],
                "balanced_accuracy": 0.5,
            }
        ]

    monkeypatch.setattr(probe_sweep, "sweep_attribute", sweep)

    inputs = probe_sweep.SweepInputs(
        source="src",
        location="loc",
        model_name="model",
        mask_value="answer_mean",
        variant="templated",
        persona_ids=("p0", "p1", "p2"),
        attributes=("sex", "gender"),
        task="binary",
        probe_kinds=("logistic_regression",),
        n_pca_components=2,
        layers=(0, 1),
        min_class_count=2,
        seed=0,
    )

    # ``__wrapped__`` is used to call the function underneath the
    # ``st.cache_resource`` decorator directly.
    rows_by_label, per_attr = probe_sweep.cached_sweep.__wrapped__(inputs)

    assert list(rows_by_label) == ["full", "pca2"]
    assert [row["attribute"] for row in rows_by_label["full"]] == ["sex", "gender"]
    assert set(per_attr) == {"sex", "gender"}
    # The full pass covers all attributes before the PCA pass starts.
    assert sweep_calls == [
        ("sex", None),
        ("gender", None),
        ("sex", 2),
        ("gender", 2),
    ]
tests/test_probes.py CHANGED
@@ -11,17 +11,16 @@ two correctness fixes:
11
 
12
  import pytest
13
  import torch
14
-
15
  from persona_vectors.probes import ProbeArtifact
 
 
16
  from utils.probes import (
17
  LoadedProbe,
18
  _LinearProbe,
19
  _loaded_probe_from_artifact,
20
  _normalize_labels,
21
- parse_probe_filename,
22
  )
23
 
24
-
25
  # --------------------------------------------------------------------------- #
26
  # parse_probe_filename
27
  # --------------------------------------------------------------------------- #
@@ -123,9 +122,7 @@ def test_normalize_batch_pca_only_applies_pca():
123
  probe = _probe(
124
  2,
125
  pca_mean=torch.ones(3),
126
- pca_components=torch.tensor(
127
- [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
128
- ),
129
  )
130
  batch = torch.tensor([[2.0, 4.0, 9.0]])
131
  out = probe._normalize_batch(batch)
 
11
 
12
  import pytest
13
  import torch
 
14
  from persona_vectors.probes import ProbeArtifact
15
+
16
+ from utils.probe_files import parse_probe_filename
17
  from utils.probes import (
18
  LoadedProbe,
19
  _LinearProbe,
20
  _loaded_probe_from_artifact,
21
  _normalize_labels,
 
22
  )
23
 
 
24
  # --------------------------------------------------------------------------- #
25
  # parse_probe_filename
26
  # --------------------------------------------------------------------------- #
 
122
  probe = _probe(
123
  2,
124
  pca_mean=torch.ones(3),
125
+ pca_components=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
 
 
126
  )
127
  batch = torch.tensor([[2.0, 4.0, 9.0]])
128
  out = probe._normalize_batch(batch)
tests/test_state.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from state import chat_session_key
2
+
3
+
4
def test_chat_session_key_is_stable_across_model_switches() -> None:
    """The chat session key must not change when only the model changes."""
    dataset = "HuggingFace: synth-persona"

    small_model_key = chat_session_key("google/gemma-2-2b-it", dataset)
    large_model_key = chat_session_key("google/gemma-2-9b-it", dataset)

    assert small_model_key == large_model_key
11
+
12
+
13
def test_chat_session_key_still_separates_datasets() -> None:
    """Two different datasets must yield different keys for the same model."""
    model = "google/gemma-2-2b-it"

    key_for_a = chat_session_key(model, "dataset-a")
    key_for_b = chat_session_key(model, "dataset-b")

    assert key_for_a != key_for_b
utils/analysis_sources.py CHANGED
@@ -7,8 +7,8 @@ from persona_vectors.analysis import (
7
  load_analysis_dataset,
8
  )
9
  from persona_vectors.artifacts import (
10
- PersonaVectorStore,
11
  HFPersonaVectorStore,
 
12
  discover_activation_models,
13
  model_dir_name,
14
  )
 
7
  load_analysis_dataset,
8
  )
9
  from persona_vectors.artifacts import (
 
10
  HFPersonaVectorStore,
11
+ PersonaVectorStore,
12
  discover_activation_models,
13
  model_dir_name,
14
  )
utils/chat.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import logging
 
4
  from contextlib import contextmanager, nullcontext
5
  from dataclasses import dataclass
6
  from typing import TYPE_CHECKING, Any, Literal
@@ -185,6 +186,7 @@ def generate_chat_reply(
185
  top_k: int = 50,
186
  repetition_penalty: float = 1.0,
187
  seed: int | None = None,
 
188
  ) -> ChatReply:
189
  """Generate one assistant reply from a full chat history.
190
 
@@ -228,9 +230,16 @@ def generate_chat_reply(
228
  generation_kwargs["repetition_penalty"] = repetition_penalty
229
  # `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
230
  # forwarded to the underlying model's generate
 
 
231
  with (
232
  _seeded_rng(seed if do_sample and not remote else None),
233
- model.generate(prompt, remote=remote, **generation_kwargs) as tracer,
 
 
 
 
 
234
  ):
235
  generated = tracer.result.save()
236
 
@@ -247,3 +256,34 @@ def generate_chat_reply(
247
  text=text,
248
  generated_ids=generated_ids.detach().cpu(),
249
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import logging
4
+ from collections.abc import Callable
5
  from contextlib import contextmanager, nullcontext
6
  from dataclasses import dataclass
7
  from typing import TYPE_CHECKING, Any, Literal
 
186
  top_k: int = 50,
187
  repetition_penalty: float = 1.0,
188
  seed: int | None = None,
189
+ on_status: Callable[[str, str, str], None] | None = None,
190
  ) -> ChatReply:
191
  """Generate one assistant reply from a full chat history.
192
 
 
230
  generation_kwargs["repetition_penalty"] = repetition_penalty
231
  # `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
232
  # forwarded to the underlying model's generate
233
+ backend = _build_remote_backend(model, on_status) if remote else None
234
+
235
  with (
236
  _seeded_rng(seed if do_sample and not remote else None),
237
+ model.generate(
238
+ prompt,
239
+ remote=remote,
240
+ backend=backend,
241
+ **generation_kwargs,
242
+ ) as tracer,
243
  ):
244
  generated = tracer.result.save()
245
 
 
256
  text=text,
257
  generated_ids=generated_ids.detach().cpu(),
258
  )
259
+
260
+
261
+ def _build_remote_backend(
262
+ model: StandardizedTransformer,
263
+ on_status: Callable[[str, str, str], None] | None,
264
+ ):
265
+ """Build an NDIF backend that can surface lifecycle updates to callers."""
266
+
267
+ if on_status is None:
268
+ return None
269
+
270
+ from nnsight.intervention.backends.remote import JobStatusDisplay, RemoteBackend
271
+
272
+ class _CallbackJobStatusDisplay(JobStatusDisplay):
273
+ def update(
274
+ self,
275
+ job_id: str = "",
276
+ status_name: str = "",
277
+ description: str = "",
278
+ ):
279
+ super().update(job_id, status_name, description)
280
+ if status_name:
281
+ on_status(job_id, status_name, description)
282
+
283
+ backend = RemoteBackend(model.to_model_key())
284
+ backend.CONNECT_TIMEOUT = 300.0
285
+ backend.status_display = _CallbackJobStatusDisplay(
286
+ enabled=True,
287
+ verbose=backend.verbose,
288
+ )
289
+ return backend
utils/contrast.py CHANGED
@@ -247,9 +247,7 @@ def render_contrast_html(result: TokenContrast) -> str:
247
  # those render as blank lines before the first word. Drop leading
248
  # whitespace-only tokens (and left-trim the first visible one) so the
249
  # contrast starts at real content. Display-only — weights stay aligned.
250
- items = list(
251
- zip(result.tokens, result.weights, result.raw_diffs, strict=True)
252
- )
253
  start = 0
254
  while start < len(items) and not items[start][0].strip():
255
  start += 1
 
247
  # those render as blank lines before the first word. Drop leading
248
  # whitespace-only tokens (and left-trim the first visible one) so the
249
  # contrast starts at real content. Display-only — weights stay aligned.
250
+ items = list(zip(result.tokens, result.weights, result.raw_diffs, strict=True))
 
 
251
  start = 0
252
  while start < len(items) and not items[start][0].strip():
253
  start += 1
utils/controls.py CHANGED
@@ -7,8 +7,12 @@ def render_mask_strategy_select(
7
  key: str,
8
  last_key: str,
9
  help_text: str,
 
10
  ) -> MaskStrategy:
11
- last_strategy = st.session_state.get(last_key, MaskStrategy.ANSWER_MEAN.value)
 
 
 
12
  strategies = list(MaskStrategy)
13
  selected = st.selectbox(
14
  "Mask strategy",
@@ -26,4 +30,6 @@ def render_mask_strategy_select(
26
  help=help_text,
27
  )
28
  st.session_state[last_key] = selected.value
 
 
29
  return selected
 
7
  key: str,
8
  last_key: str,
9
  help_text: str,
10
+ remember_key: str | None = None,
11
  ) -> MaskStrategy:
12
+ last_strategy = st.session_state.get(
13
+ remember_key,
14
+ st.session_state.get(last_key, MaskStrategy.ANSWER_MEAN.value),
15
+ )
16
  strategies = list(MaskStrategy)
17
  selected = st.selectbox(
18
  "Mask strategy",
 
30
  help=help_text,
31
  )
32
  st.session_state[last_key] = selected.value
33
+ if remember_key is not None:
34
+ st.session_state[remember_key] = selected.value
35
  return selected
utils/datasets.py CHANGED
@@ -7,6 +7,7 @@ from tempfile import mkdtemp
7
  from typing import Any
8
 
9
  import streamlit as st
 
10
  from persona_data.nemotron_personas import (
11
  NemotronPersonasFranceDataset,
12
  NemotronPersonasUSADataset,
@@ -16,6 +17,17 @@ from persona_data.synth_persona import SynthPersonaDataset
16
 
17
  from .helpers import DatasetSource
18
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @st.cache_resource(show_spinner=False)
21
  def _cached_dataset(cls: type) -> Any:
@@ -39,13 +51,19 @@ def warm_qa_in_background(dataset: Any) -> None:
39
  warm = getattr(dataset, "prefetch_qa", None)
40
  if warm is None:
41
  return # persona-only dataset (e.g. Nemotron) has no QA
 
 
 
 
 
 
 
 
42
  with _qa_warm_lock:
43
  if getattr(dataset, "_qa_warm_started", False):
44
  return
45
  dataset._qa_warm_started = True
46
- threading.Thread(
47
- target=warm, name="persona-ui-warm-qa", daemon=True
48
- ).start()
49
 
50
 
51
  @st.cache_resource(show_spinner=False)
@@ -118,12 +136,19 @@ def load_dataset(
118
  """Load the selected dataset source for the UI."""
119
 
120
  if dataset_source == DatasetSource.SYNTH_PERSONA.value:
 
 
 
 
 
121
  return _cached_dataset(SynthPersonaDataset), "SynthPersona"
122
 
123
  if dataset_source == DatasetSource.NEMOTRON_FRANCE.value:
 
124
  return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
125
 
126
  if dataset_source == DatasetSource.NEMOTRON_USA.value:
 
127
  return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
128
 
129
  if personas_file is None or qa_file is None:
@@ -132,3 +157,60 @@ def load_dataset(
132
  personas_path = _uploaded_file_to_temp_path(personas_file, stem="personas")
133
  qa_path = _uploaded_file_to_temp_path(qa_file, stem="qa")
134
  return _cached_local_dataset(str(personas_path), str(qa_path)), "Local upload"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from typing import Any
8
 
9
  import streamlit as st
10
+ from huggingface_hub import hf_hub_download, list_repo_files, try_to_load_from_cache
11
  from persona_data.nemotron_personas import (
12
  NemotronPersonasFranceDataset,
13
  NemotronPersonasUSADataset,
 
17
 
18
  from .helpers import DatasetSource
19
 
20
+ _SYNTH_PERSONA_REPO = "implicit-personalization/synth-persona"
21
+ _SYNTH_PERSONA_STARTUP_FILES = (
22
+ "implicit_shared_mc_bank.json",
23
+ "dataset_personas.jsonl",
24
+ )
25
+ _SYNTH_PERSONA_QA_FILE = "dataset_qa.jsonl"
26
+ _NEMOTRON_REPOS = {
27
+ DatasetSource.NEMOTRON_FRANCE.value: "nvidia/Nemotron-Personas-France",
28
+ DatasetSource.NEMOTRON_USA.value: "nvidia/Nemotron-Personas-USA",
29
+ }
30
+
31
 
32
  @st.cache_resource(show_spinner=False)
33
  def _cached_dataset(cls: type) -> Any:
 
51
  warm = getattr(dataset, "prefetch_qa", None)
52
  if warm is None:
53
  return # persona-only dataset (e.g. Nemotron) has no QA
54
+ if isinstance(dataset, SynthPersonaDataset):
55
+ # Extract will need QA soon. Make the one-time large transfer explicit,
56
+ # then leave the CPU-heavy parse on the existing background thread.
57
+ _download_missing_startup_files_if_needed(
58
+ _SYNTH_PERSONA_REPO,
59
+ (_SYNTH_PERSONA_QA_FILE,),
60
+ "SynthPersona QA",
61
+ )
62
  with _qa_warm_lock:
63
  if getattr(dataset, "_qa_warm_started", False):
64
  return
65
  dataset._qa_warm_started = True
66
+ threading.Thread(target=warm, name="persona-ui-warm-qa", daemon=True).start()
 
 
67
 
68
 
69
  @st.cache_resource(show_spinner=False)
 
136
  """Load the selected dataset source for the UI."""
137
 
138
  if dataset_source == DatasetSource.SYNTH_PERSONA.value:
139
+ _download_missing_startup_files_if_needed(
140
+ _SYNTH_PERSONA_REPO,
141
+ _SYNTH_PERSONA_STARTUP_FILES,
142
+ "SynthPersona",
143
+ )
144
  return _cached_dataset(SynthPersonaDataset), "SynthPersona"
145
 
146
  if dataset_source == DatasetSource.NEMOTRON_FRANCE.value:
147
+ _prepare_nemotron_startup_download(dataset_source, "Nemotron France")
148
  return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
149
 
150
  if dataset_source == DatasetSource.NEMOTRON_USA.value:
151
+ _prepare_nemotron_startup_download(dataset_source, "Nemotron USA")
152
  return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
153
 
154
  if personas_file is None or qa_file is None:
 
157
  personas_path = _uploaded_file_to_temp_path(personas_file, stem="personas")
158
  qa_path = _uploaded_file_to_temp_path(qa_file, stem="qa")
159
  return _cached_local_dataset(str(personas_path), str(qa_path)), "Local upload"
160
+
161
+
162
def _is_cached(repo_id: str, filename: str) -> bool:
    """Tell whether a Hub dataset file is already in the local HF cache.

    Only a string result from ``try_to_load_from_cache`` counts as a hit;
    any other return value is treated as "not cached".
    """

    return isinstance(
        try_to_load_from_cache(repo_id, filename, repo_type="dataset"),
        str,
    )
167
+
168
+
169
def _download_missing_startup_files_if_needed(
    repo_id: str,
    filenames: tuple[str, ...],
    label: str,
) -> None:
    """Download not-yet-cached Hub files while showing file-level progress.

    Does nothing when every file in ``filenames`` is already cached.
    Otherwise it shows a warning banner and a progress bar, then fetches the
    missing files one at a time, updating the bar before and after each one.

    Args:
        repo_id: Hub dataset repository to download from.
        filenames: Candidate files; only the uncached ones are fetched.
        label: Human-readable dataset name used in the UI messages.
    """

    pending = tuple(name for name in filenames if not _is_cached(repo_id, name))
    if not pending:
        return

    st.warning(
        f"First-time setup for {label}: downloading dataset files from Hugging Face. "
        "Later loads should use the local cache."
    )
    bar = st.progress(0.0, text=f"Preparing {label} download…")
    count = len(pending)
    for finished, name in enumerate(pending):
        position = finished + 1  # 1-based index shown to the user
        bar.progress(
            finished / count,
            text=f"Downloading {name} ({position}/{count})",
        )
        hf_hub_download(repo_id, name, repo_type="dataset")
        bar.progress(
            position / count,
            text=f"Downloaded {name} ({position}/{count})",
        )
202
+
203
+
204
def _prepare_nemotron_startup_download(dataset_source: str, label: str) -> None:
    """Prefetch the lexicographically first ``data/train-*.parquet`` shard.

    Looks up the repository for ``dataset_source`` in ``_NEMOTRON_REPOS``,
    lists its files, and downloads the first training shard if it is not
    already cached. Nothing happens when the repo has no matching shard.

    NOTE(review): ``list_repo_files`` is called on every invocation, so this
    needs network access even when the shard is already cached — confirm
    that is acceptable for offline use.
    """

    repo_id = _NEMOTRON_REPOS[dataset_source]
    train_shards = [
        name
        for name in list_repo_files(repo_id, repo_type="dataset")
        if name.startswith("data/train-") and name.endswith(".parquet")
    ]
    if not train_shards:
        return

    # min() picks the same element that sorting and taking index 0 would.
    first_shard = min(train_shards)
    _download_missing_startup_files_if_needed(repo_id, (first_shard,), label)
utils/helpers.py CHANGED
@@ -64,6 +64,26 @@ NDIF_STATUS_ICONS = {
64
  }
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def slugify(value: str) -> str:
68
  """Convert a string to a filesystem-safe slug."""
69
 
 
64
  }
65
 
66
 
67
def format_ndif_status(
    job_id: str,
    status_name: str,
    description: str,
    *,
    prefix: str | None = None,
    completed_detail: str | None = None,
) -> str:
    """Render a one-line NDIF job status label.

    Args:
        job_id: Job identifier, shown in backticks.
        status_name: Status name, shown in bold and used to pick the icon
            from ``NDIF_STATUS_ICONS`` (falls back to ``"•"``).
        description: Default detail text shown after the dash.
        prefix: Optional text placed before the label as ``"{prefix}: "``.
        completed_detail: Replaces ``description`` only when the status is
            ``"COMPLETED"`` and this value is not ``None``.

    Returns:
        The formatted label string.
    """

    icon = NDIF_STATUS_ICONS.get(status_name, "•")

    detail = description
    if completed_detail is not None and status_name == "COMPLETED":
        detail = completed_detail

    label = f"{icon} `{job_id}` **{status_name}** — {detail}"
    if prefix:
        return f"{prefix}: {label}"
    return label
85
+
86
+
87
  def slugify(value: str) -> str:
88
  """Convert a string to a filesystem-safe slug."""
89
 
utils/probe_files.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+ import streamlit as st
9
+
10
+ PROBE_FILENAME_RE = re.compile(
11
+ r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
12
+ r"(?P<location>pre_reasoning|post_reasoning)_all_(?P<scope>general|size\d+)\.pt$"
13
+ )
14
+
15
+ PERSONA_PROBE_DIR_RE = re.compile(
16
+ r"^(?P<probe_kind>[a-z_]+?)(?:_pca(?P<pca>\d+))?_layer(?P<layer>\d+)$"
17
+ )
18
+
19
+ DEFAULT_PROBE_REPO = "project-telos/cognitive_map_probes"
20
+ DEFAULT_LOCAL_PROBE_DIR = os.environ.get("PERSONA_PROBES_DIR", "artifacts/probes")
21
+
22
+
23
@dataclass(frozen=True)
class ProbeFileMetadata:
    """Immutable description of a probe file derived from its path."""

    # Path or name exactly as it was given to the parser.
    filename: str
    # Layer index parsed from the name; None when it could not be parsed.
    layer: int | None
    # Probe/model type token ("unknown" for unrecognised names).
    model_type: str
    # Reasoning location token, when the naming scheme carries one.
    location: str | None
    # Scope token such as "general", "size<N>" or "pca<K>", when present.
    scope: str | None
    # Human-readable label for UI display.
    label: str
    # Model name recovered from the directory layout, when available.
    model_name: str | None = None
    # Attribute name recovered from the directory layout, when available.
    attribute_name: str | None = None
33
+
34
+
35
def model_probe_dir_name(model_name: str) -> str:
    """Return ``model_name`` with every ``/`` replaced by ``__``."""
    return "__".join(model_name.split("/"))
37
+
38
+
39
def parse_probe_filename(filename: str) -> ProbeFileMetadata:
    """Derive probe metadata from a probe file path.

    Three shapes are recognised, tried in this order:

    1. A file name matching ``PROBE_FILENAME_RE``: layer, model type,
       location and scope come from the name itself.
    2. A ``probe.json`` / ``weights.safetensors`` file whose parent directory
       matches ``PERSONA_PROBE_DIR_RE``: layer, probe kind and optional PCA
       size come from the directory name, the attribute from the directory
       above it, and the model name from the first path component when the
       path has at least five parts.
    3. Anything else: a generic record with ``model_type="unknown"`` and a
       label built from the file stem.
    """

    path = Path(filename)

    cognitive = PROBE_FILENAME_RE.match(path.name)
    if cognitive is not None:
        layer = int(cognitive.group("layer"))
        model_type = cognitive.group("model_type")
        location = cognitive.group("location")
        scope = cognitive.group("scope")
        readable_location = location.replace("_", " ")
        readable_scope = scope.replace("size", "size ")
        return ProbeFileMetadata(
            filename=filename,
            layer=layer,
            model_type=model_type,
            location=location,
            scope=scope,
            label=(
                f"Layer {layer} - {model_type.upper()} - "
                f"{readable_location} - {readable_scope}"
            ),
        )

    probe_dir = path.parent
    persona = PERSONA_PROBE_DIR_RE.match(probe_dir.name)
    if persona is not None and path.name in {"probe.json", "weights.safetensors"}:
        layer = int(persona.group("layer"))
        probe_kind = persona.group("probe_kind")
        pca = persona.group("pca")
        attribute = probe_dir.parent.name or "attribute"
        parts = path.parts
        pca_suffix = f" (pca{pca})" if pca else ""
        return ProbeFileMetadata(
            filename=filename,
            layer=layer,
            model_type=probe_kind,
            location=None,
            scope=f"pca{pca}" if pca else None,
            label=f"{attribute} - layer {layer} - {probe_kind}{pca_suffix}",
            # Only paths deep enough to start with a model directory carry one.
            model_name=parts[0].replace("__", "/") if len(parts) >= 5 else None,
            attribute_name=attribute,
        )

    return ProbeFileMetadata(
        filename=filename,
        layer=None,
        model_type="unknown",
        location=None,
        scope=None,
        label=path.stem.replace("_", " "),
    )
90
+
91
+
92
@st.cache_data(show_spinner=False, ttl=300)
def list_probe_files(repo_id: str) -> list[str]:
    """List the probe entries of a Hub model repo, one per probe.

    The raw file listing is reduced by ``_dedupe_probe_entries``.
    """
    from huggingface_hub import list_repo_files

    repo_files = list_repo_files(repo_id, repo_type="model")
    return _dedupe_probe_entries(repo_files)
97
+
98
+
99
@st.cache_data(show_spinner=False, ttl=30)
def list_local_probe_files(root_dir: str) -> list[str]:
    """List probe entries found under a local directory, one per probe.

    Args:
        root_dir: Directory to scan recursively; ``~`` is expanded.

    Returns:
        Paths relative to the root, de-duplicated per probe directory and
        ordered by ``_probe_sort_key``. Empty when the root is not a
        directory.
    """
    root = Path(root_dir).expanduser()
    if not root.is_dir():
        return []

    probe_names = {"probe.pt", "probe.json", "weights.safetensors"}
    relative_paths: list[str] = []
    for candidate in root.rglob("*"):
        if candidate.is_file() and candidate.name in probe_names:
            relative_paths.append(str(candidate.relative_to(root)))

    return sorted(_dedupe_probe_entries(relative_paths), key=_probe_sort_key)
113
+
114
+
115
@st.cache_data(show_spinner=False, ttl=300)
def download_probe_file(repo_id: str, filename: str) -> str:
    """Fetch one file from a Hub model repo and return ``hf_hub_download``'s result."""
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(repo_id, filename, repo_type="model")
    return local_path
120
+
121
+
122
@st.cache_data(show_spinner=False, ttl=300)
def download_probe_json_and_weights(repo_id: str, filename: str) -> tuple[str, str]:
    """Fetch a probe metadata file and its sibling ``weights.safetensors``.

    Args:
        repo_id: Hub model repository.
        filename: Repo path of the metadata file; the weights file is looked
            up under the same directory.

    Returns:
        ``(metadata_path, weights_path)`` as returned by ``hf_hub_download``.
    """
    from huggingface_hub import hf_hub_download

    # The weights file shares the metadata file's directory.
    sibling_weights = str(Path(filename).with_name("weights.safetensors"))

    metadata_path = hf_hub_download(repo_id, filename, repo_type="model")
    weights_path = hf_hub_download(repo_id, sibling_weights, repo_type="model")
    return metadata_path, weights_path
130
+
131
+
132
def _probe_sort_key(filename: str) -> tuple[str, str, int, str]:
    """Build the sort key for a probe entry.

    Entries order by model name, then attribute name, then layer, then the
    raw filename. Missing names sort as empty strings; a missing layer sorts
    after every real layer via a large sentinel.
    """
    meta = parse_probe_filename(filename)
    layer_rank = 10**9 if meta.layer is None else meta.layer
    return (
        meta.model_name or "",
        meta.attribute_name or "",
        layer_rank,
        filename,
    )
140
+
141
+
142
def _dedupe_probe_entries(files: list[str]) -> list[str]:
    """Reduce a file listing to one entry per probe.

    Files named ``probe.pt``, ``probe.json`` or ``weights.safetensors`` are
    grouped by directory, and each directory contributes one entry,
    preferring ``probe.json``, then ``probe.pt``, then
    ``weights.safetensors``. Any other ``.pt`` file is kept as a standalone
    entry; everything else is dropped. The result is ordered by
    ``_probe_sort_key``.
    """
    bundle_names = {"probe.pt", "probe.json", "weights.safetensors"}
    names_by_dir: dict[str, set[str]] = {}
    entries: list[str] = []

    for filename in files:
        path = Path(filename)
        if path.name in bundle_names:
            names_by_dir.setdefault(str(path.parent), set()).add(path.name)
            continue
        if filename.endswith(".pt"):
            entries.append(filename)

    for directory, names in names_by_dir.items():
        # First preferred name present wins; weights-only dirs fall through.
        preferred = next(
            (name for name in ("probe.json", "probe.pt") if name in names),
            "weights.safetensors",
        )
        entries.append(str(Path(directory) / preferred))

    return sorted(entries, key=_probe_sort_key)
utils/probe_overlay.py CHANGED
@@ -124,18 +124,14 @@ def build_regression_overlays(
124
  return overlays
125
 
126
 
127
- def attach_overlays(
128
- messages: list[dict], overlays: list[ProbeOverlay]
129
- ) -> None:
130
  """Attach one overlay to each assistant message, in order.
131
 
132
  Requires a 1:1 match. If the counts don't line up (e.g. the chat template
133
  doesn't mark assistant tokens), clear overlays so the caller can show a
134
  clear status instead of painting the wrong message.
135
  """
136
- assistant_idxs = [
137
- i for i, m in enumerate(messages) if m.get("role") == "assistant"
138
- ]
139
  clear_overlays(messages)
140
  if not assistant_idxs or len(overlays) != len(assistant_idxs):
141
  return
@@ -189,8 +185,7 @@ def _tooltip(probs_row: list[float], labels: list[str | None]) -> str:
189
  # Single-output sigmoid: synthesize the complementary class so the
190
  # hover shows both label probabilities, not just one.
191
  return escape(
192
- f"{positive_label} {positive:.2f} · "
193
- f"not {positive_label} {1 - positive:.2f}"
194
  )
195
  ranked = sorted(enumerate(probs_row), key=lambda item: item[1], reverse=True)
196
  parts = [f"{_label_for(labels, idx)} {prob:.2f}" for idx, prob in ranked]
 
124
  return overlays
125
 
126
 
127
+ def attach_overlays(messages: list[dict], overlays: list[ProbeOverlay]) -> None:
 
 
128
  """Attach one overlay to each assistant message, in order.
129
 
130
  Requires a 1:1 match. If the counts don't line up (e.g. the chat template
131
  doesn't mark assistant tokens), clear overlays so the caller can show a
132
  clear status instead of painting the wrong message.
133
  """
134
+ assistant_idxs = [i for i, m in enumerate(messages) if m.get("role") == "assistant"]
 
 
135
  clear_overlays(messages)
136
  if not assistant_idxs or len(overlays) != len(assistant_idxs):
137
  return
 
185
  # Single-output sigmoid: synthesize the complementary class so the
186
  # hover shows both label probabilities, not just one.
187
  return escape(
188
+ f"{positive_label} {positive:.2f} · not {positive_label} {1 - positive:.2f}"
 
189
  )
190
  ranked = sorted(enumerate(probs_row), key=lambda item: item[1], reverse=True)
191
  parts = [f"{_label_for(labels, idx)} {prob:.2f}" for idx, prob in ranked]
utils/probe_trace.py CHANGED
@@ -11,6 +11,7 @@ from persona_data.prompts import normalize_messages, supports_system_role
11
  from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
12
 
13
  _TRACE_CACHE_KEY = "probe:trace_cache"
 
14
  _MAX_CACHED_TRACES = 3
15
 
16
 
@@ -92,9 +93,7 @@ def trace_conversation(
92
 
93
  n_tokens = int(input_ids.shape[0])
94
  assistant_spans = _clip_spans(
95
- _assistant_spans_from_offsets(
96
- model.tokenizer, prompt_text, messages, n_tokens
97
- ),
98
  n_tokens,
99
  )
100
  if not assistant_spans and assistant_mask_seq is not None:
@@ -182,6 +181,30 @@ def _store_cached_trace(cache_key: str, trace: ConversationTrace) -> None:
182
  while len(cache) > _MAX_CACHED_TRACES:
183
  oldest_key = next(iter(cache))
184
  cache.pop(oldest_key, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
 
187
  def _compute_assistant_mask(
@@ -302,9 +325,7 @@ def _assistant_spans_from_prefixes(
302
  for i, message in enumerate(messages):
303
  if message.get("role") != "assistant":
304
  continue
305
- prefix_ids = apply(
306
- messages[:i], tokenize=True, add_generation_prompt=True
307
- )
308
  through_ids = apply(
309
  messages[: i + 1], tokenize=True, add_generation_prompt=False
310
  )
@@ -332,9 +353,7 @@ def _flatten_ids(value: object) -> list[int] | None:
332
  return None
333
 
334
 
335
- def _clip_spans(
336
- spans: list[tuple[int, int]], n_tokens: int
337
- ) -> list[tuple[int, int]]:
338
  clipped: list[tuple[int, int]] = []
339
  for start, end in spans:
340
  s = max(0, min(start, n_tokens))
 
11
  from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
12
 
13
  _TRACE_CACHE_KEY = "probe:trace_cache"
14
+ _DERIVED_CACHE_TRACKER_KEY = "probe:derived_cache_keys"
15
  _MAX_CACHED_TRACES = 3
16
 
17
 
 
93
 
94
  n_tokens = int(input_ids.shape[0])
95
  assistant_spans = _clip_spans(
96
+ _assistant_spans_from_offsets(model.tokenizer, prompt_text, messages, n_tokens),
 
 
97
  n_tokens,
98
  )
99
  if not assistant_spans and assistant_mask_seq is not None:
 
181
  while len(cache) > _MAX_CACHED_TRACES:
182
  oldest_key = next(iter(cache))
183
  cache.pop(oldest_key, None)
184
+ _drop_derived_results_for_trace(oldest_key)
185
+
186
+
187
def _drop_derived_results_for_trace(cache_key: str) -> None:
    """Delete session-state entries derived from an evicted trace.

    Derived entries are those whose key starts with
    ``probe_predictions::{cache_key}::`` or ``probe_values::{cache_key}::``.
    When a list is stored under ``_DERIVED_CACHE_TRACKER_KEY``, only the keys
    it names are considered and the list is rewritten without the removed
    ones. Otherwise every session-state key is scanned.
    """

    derived_prefixes = (
        f"probe_predictions::{cache_key}::",
        f"probe_values::{cache_key}::",
    )

    def _is_derived(key: object) -> bool:
        return isinstance(key, str) and key.startswith(derived_prefixes)

    tracked = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
    if not isinstance(tracked, list):
        # No usable tracker: scan a snapshot of the keys so entries can be
        # removed while looping.
        for key in list(st.session_state):
            if _is_derived(key):
                st.session_state.pop(key, None)
        return

    surviving: list[str] = []
    for key in tracked:
        if _is_derived(key):
            st.session_state.pop(key, None)
        else:
            surviving.append(key)
    st.session_state[_DERIVED_CACHE_TRACKER_KEY] = surviving
208
 
209
 
210
  def _compute_assistant_mask(
 
325
  for i, message in enumerate(messages):
326
  if message.get("role") != "assistant":
327
  continue
328
+ prefix_ids = apply(messages[:i], tokenize=True, add_generation_prompt=True)
 
 
329
  through_ids = apply(
330
  messages[: i + 1], tokenize=True, add_generation_prompt=False
331
  )
 
353
  return None
354
 
355
 
356
+ def _clip_spans(spans: list[tuple[int, int]], n_tokens: int) -> list[tuple[int, int]]:
 
 
357
  clipped: list[tuple[int, int]] = []
358
  for start, end in spans:
359
  s = max(0, min(start, n_tokens))
utils/probes.py CHANGED
@@ -1,8 +1,6 @@
1
  from __future__ import annotations
2
 
3
  import io
4
- import os
5
- import re
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
  from typing import Any
@@ -13,33 +11,14 @@ import torch.nn as nn
13
  import torch.nn.functional as F
14
  from persona_vectors.probes import ProbeArtifact, load_probe_artifact
15
 
16
- PROBE_FILENAME_RE = re.compile(
17
- r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
18
- r"(?P<location>pre_reasoning|post_reasoning)_all_(?P<scope>general|size\d+)\.pt$"
 
 
19
  )
20
 
21
- # persona-vectors layout: .../{model}/{mask}/{variant}/{attribute}/{probe_kind}[_pca{K}]_layer{N}/weights.safetensors
22
- PERSONA_PROBE_DIR_RE = re.compile(
23
- r"^(?P<probe_kind>[a-z_]+?)(?:_pca(?P<pca>\d+))?_layer(?P<layer>\d+)$"
24
- )
25
-
26
- DEFAULT_PROBE_REPO = "project-telos/cognitive_map_probes"
27
- DEFAULT_LOCAL_PROBE_DIR = os.environ.get(
28
- "PERSONA_PROBES_DIR",
29
- "artifacts/probes",
30
- )
31
-
32
-
33
- @dataclass(frozen=True)
34
- class ProbeFileMetadata:
35
- filename: str
36
- layer: int | None
37
- model_type: str
38
- location: str | None
39
- scope: str | None
40
- label: str
41
- model_name: str | None = None
42
- attribute_name: str | None = None
43
 
44
 
45
  @dataclass(frozen=True)
@@ -195,9 +174,7 @@ class LoadedProbe:
195
  predicted = torch.argmax(probs, dim=-1)
196
  return logits, probs, predicted
197
 
198
- def _forward_batch(
199
- self, batch: torch.Tensor
200
- ) -> tuple[torch.Tensor, torch.Tensor]:
201
  normalized = self._normalize_batch(batch)
202
  with torch.no_grad():
203
  logits = self.model(normalized).detach().cpu()
@@ -233,104 +210,7 @@ class LoadedProbe:
233
  return batch
234
 
235
 
236
- def model_probe_dir_name(model_name: str) -> str:
237
- return model_name.replace("/", "__")
238
-
239
-
240
- def parse_probe_filename(filename: str) -> ProbeFileMetadata:
241
- path = Path(filename)
242
- match = PROBE_FILENAME_RE.match(path.name)
243
- if match:
244
- layer = int(match.group("layer"))
245
- model_type = match.group("model_type")
246
- location = match.group("location")
247
- scope = match.group("scope")
248
- scope_label = scope.replace("size", "size ")
249
- return ProbeFileMetadata(
250
- filename=filename,
251
- layer=layer,
252
- model_type=model_type,
253
- location=location,
254
- scope=scope,
255
- label=(
256
- f"Layer {layer} - {model_type.upper()} - "
257
- f"{location.replace('_', ' ')} - {scope_label}"
258
- ),
259
- )
260
-
261
- # persona-vectors layout: parent dir holds probe_kind[_pca{K}]_layer{N},
262
- # and the dir above that is the attribute name.
263
- parent_match = PERSONA_PROBE_DIR_RE.match(path.parent.name)
264
- if parent_match and path.name in {"probe.json", "weights.safetensors"}:
265
- layer = int(parent_match.group("layer"))
266
- probe_kind = parent_match.group("probe_kind")
267
- pca = parent_match.group("pca")
268
- scope = f"pca{pca}" if pca else None
269
- attribute = path.parent.parent.name or "attribute"
270
- model_name = path.parts[0].replace("__", "/") if len(path.parts) >= 5 else None
271
- label = f"{attribute} - layer {layer} - {probe_kind}"
272
- if pca:
273
- label += f" (pca{pca})"
274
- return ProbeFileMetadata(
275
- filename=filename,
276
- layer=layer,
277
- model_type=probe_kind,
278
- location=None,
279
- scope=scope,
280
- label=label,
281
- model_name=model_name,
282
- attribute_name=attribute,
283
- )
284
-
285
- return ProbeFileMetadata(
286
- filename=filename,
287
- layer=None,
288
- model_type="unknown",
289
- location=None,
290
- scope=None,
291
- label=path.stem.replace("_", " "),
292
- )
293
-
294
-
295
- @st.cache_data(show_spinner=False, ttl=300)
296
- def list_probe_files(repo_id: str) -> list[str]:
297
- from huggingface_hub import list_repo_files
298
-
299
- files = list_repo_files(repo_id, repo_type="model")
300
- return _dedupe_probe_entries(files)
301
-
302
-
303
- @st.cache_data(show_spinner=False, ttl=30)
304
- def list_local_probe_files(root_dir: str) -> list[str]:
305
- root = Path(root_dir).expanduser()
306
- if not root.is_dir():
307
- return []
308
- files = _dedupe_probe_entries([
309
- str(path.relative_to(root))
310
- for path in root.rglob("*")
311
- if path.is_file() and path.name in {"probe.pt", "probe.json", "weights.safetensors"}
312
- ])
313
- return sorted(files, key=_probe_sort_key)
314
-
315
-
316
- @st.cache_data(show_spinner=False, ttl=300)
317
- def download_probe_file(repo_id: str, filename: str) -> str:
318
- from huggingface_hub import hf_hub_download
319
-
320
- return hf_hub_download(repo_id, filename, repo_type="model")
321
-
322
-
323
- @st.cache_data(show_spinner=False, ttl=300)
324
- def download_probe_json_and_weights(repo_id: str, filename: str) -> tuple[str, str]:
325
- from huggingface_hub import hf_hub_download
326
-
327
- metadata_path = hf_hub_download(repo_id, filename, repo_type="model")
328
- weights_name = str(Path(filename).with_name("weights.safetensors"))
329
- weights_path = hf_hub_download(repo_id, weights_name, repo_type="model")
330
- return metadata_path, weights_path
331
-
332
-
333
- @st.cache_resource(show_spinner=False)
334
  def load_probe(repo_id: str, filename: str) -> LoadedProbe:
335
  if filename.endswith("probe.json"):
336
  metadata_path, weights_path = download_probe_json_and_weights(repo_id, filename)
@@ -346,7 +226,7 @@ def load_probe(repo_id: str, filename: str) -> LoadedProbe:
346
  )
347
 
348
 
349
- @st.cache_resource(show_spinner=False)
350
  def load_local_probe(root_dir: str, filename: str) -> LoadedProbe:
351
  root = Path(root_dir).expanduser()
352
  path = (root / filename).resolve()
@@ -370,7 +250,7 @@ def load_local_probe(root_dir: str, filename: str) -> LoadedProbe:
370
  )
371
 
372
 
373
- @st.cache_resource(show_spinner=False)
374
  def load_probe_from_bytes(filename: str, data: bytes) -> LoadedProbe:
375
  return _load_probe_payload(
376
  filename=filename,
@@ -432,16 +312,20 @@ def _load_probe_payload(
432
  _optional_str(payload.get("attribute_name")) or metadata.attribute_name
433
  ),
434
  feature_space=(
435
- (f"pca{payload['n_pca_components']}"
436
- if payload.get("n_pca_components")
437
- else None)
 
 
438
  or _optional_str(payload.get("feature_space"))
439
  or metadata.scope
440
  ),
441
  task=_optional_str(payload.get("task")),
442
  probe_kind=_optional_str(payload.get("probe_kind")),
443
  scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")),
444
- scaler_std=_as_cpu_tensor(_first_present(payload, "scaler_std", "scaler_scale")),
 
 
445
  pca_mean=_as_cpu_tensor(payload.get("pca_mean")),
446
  pca_components=_as_cpu_tensor(payload.get("pca_components")),
447
  )
@@ -617,39 +501,6 @@ def _first_present(payload: dict[str, Any], *keys: str) -> Any:
617
  return None
618
 
619
 
620
- def _probe_sort_key(filename: str) -> tuple[str, str, int, str]:
621
- metadata = parse_probe_filename(filename)
622
- return (
623
- metadata.model_name or "",
624
- metadata.attribute_name or "",
625
- metadata.layer if metadata.layer is not None else 10**9,
626
- filename,
627
- )
628
-
629
-
630
- def _dedupe_probe_entries(files: list[str]) -> list[str]:
631
- by_dir: dict[str, set[str]] = {}
632
- standalone: list[str] = []
633
- for filename in files:
634
- path = Path(filename)
635
- if path.name in {"probe.pt", "probe.json", "weights.safetensors"}:
636
- by_dir.setdefault(str(path.parent), set()).add(path.name)
637
- elif filename.endswith(".pt"):
638
- standalone.append(filename)
639
-
640
- entries = list(standalone)
641
- for directory, names in by_dir.items():
642
- selected = (
643
- "probe.json"
644
- if "probe.json" in names
645
- else "probe.pt"
646
- if "probe.pt" in names
647
- else "weights.safetensors"
648
- )
649
- entries.append(str(Path(directory) / selected))
650
- return sorted(entries, key=_probe_sort_key)
651
-
652
-
653
  def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
654
  if isinstance(raw_labels, (list, tuple)):
655
  labels = [str(label) for label in raw_labels[:num_classes]]
 
1
  from __future__ import annotations
2
 
3
  import io
 
 
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
  from typing import Any
 
11
  import torch.nn.functional as F
12
  from persona_vectors.probes import ProbeArtifact, load_probe_artifact
13
 
14
+ from utils.helpers import env_int
15
+ from utils.probe_files import (
16
+ download_probe_file,
17
+ download_probe_json_and_weights,
18
+ parse_probe_filename,
19
  )
20
 
21
+ _PROBE_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_CACHE_ENTRIES", 8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  @dataclass(frozen=True)
 
174
  predicted = torch.argmax(probs, dim=-1)
175
  return logits, probs, predicted
176
 
177
+ def _forward_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 
 
178
  normalized = self._normalize_batch(batch)
179
  with torch.no_grad():
180
  logits = self.model(normalized).detach().cpu()
 
210
  return batch
211
 
212
 
213
+ @st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  def load_probe(repo_id: str, filename: str) -> LoadedProbe:
215
  if filename.endswith("probe.json"):
216
  metadata_path, weights_path = download_probe_json_and_weights(repo_id, filename)
 
226
  )
227
 
228
 
229
+ @st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
230
  def load_local_probe(root_dir: str, filename: str) -> LoadedProbe:
231
  root = Path(root_dir).expanduser()
232
  path = (root / filename).resolve()
 
250
  )
251
 
252
 
253
+ @st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
254
  def load_probe_from_bytes(filename: str, data: bytes) -> LoadedProbe:
255
  return _load_probe_payload(
256
  filename=filename,
 
312
  _optional_str(payload.get("attribute_name")) or metadata.attribute_name
313
  ),
314
  feature_space=(
315
+ (
316
+ f"pca{payload['n_pca_components']}"
317
+ if payload.get("n_pca_components")
318
+ else None
319
+ )
320
  or _optional_str(payload.get("feature_space"))
321
  or metadata.scope
322
  ),
323
  task=_optional_str(payload.get("task")),
324
  probe_kind=_optional_str(payload.get("probe_kind")),
325
  scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")),
326
+ scaler_std=_as_cpu_tensor(
327
+ _first_present(payload, "scaler_std", "scaler_scale")
328
+ ),
329
  pca_mean=_as_cpu_tensor(payload.get("pca_mean")),
330
  pca_components=_as_cpu_tensor(payload.get("pca_components")),
331
  )
 
501
  return None
502
 
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
505
  if isinstance(raw_labels, (list, tuple)):
506
  labels = [str(label) for label in raw_labels[:num_classes]]
utils/selection_controls.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import streamlit as st
6
+
7
+
8
def remembered_segmented_control(
    label: str,
    *,
    options: Sequence[str],
    key: str,
    remember_key: str | None = None,
    default: str | None = None,
    label_visibility: str = "visible",
) -> str:
    """Render a segmented control that restores its previous selection.

    Args:
        label: Label passed to ``st.segmented_control``.
        options: Selectable values. The first one is the fallback when
            ``default`` is not given, so it must not be empty in that case.
        key: Streamlit widget key of the control.
        remember_key: Optional session-state key. When given, the selection is
            restored from it in preference to ``key`` and written back to it.
        default: Value used when nothing is remembered or the remembered value
            is not one of ``options``.
        label_visibility: Forwarded to ``st.segmented_control``.

    Returns:
        The selected option, or the fallback when the control returns a falsy
        value.
    """
    fallback = default or options[0]

    # Start from the widget's own state, then let the dedicated memory slot
    # override it. The slot is only consulted when one was actually supplied,
    # so session state is never queried with a ``None`` key.
    remembered = st.session_state.get(key, fallback)
    if remember_key is not None:
        remembered = st.session_state.get(remember_key, remembered)

    # A remembered value that is no longer offered must not be used as default.
    initial = remembered if remembered in options else fallback
    selected = (
        st.segmented_control(
            label,
            options=options,
            default=initial,
            key=key,
            label_visibility=label_visibility,
        )
        or fallback
    )
    if remember_key is not None:
        st.session_state[remember_key] = selected
    return selected
utils/source_controls.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import streamlit as st
6
+ from persona_data.environment import get_artifacts_dir
7
+ from persona_vectors.extraction import MaskStrategy
8
+
9
+ from utils.analysis_sources import (
10
+ DEFAULT_COMPARE_MODEL,
11
+ DEFAULT_HUB_REPO,
12
+ SOURCE_HUB,
13
+ SOURCE_LOCAL,
14
+ SOURCES,
15
+ Store,
16
+ activation_store_cached,
17
+ hub_models_by_mask_strategy,
18
+ local_model_matches,
19
+ local_model_options_cached,
20
+ )
21
+ from utils.helpers import widget_key
22
+ from utils.selection_controls import remembered_segmented_control
23
+
24
+ _SHARED_SOURCE_KEY = "source:last_source"
25
+ _SHARED_HUB_REPO_KEY = "source:hub_repo"
26
+ _SHARED_HUB_MODEL_KEY = "source:hub_model"
27
+ _SHARED_LOCAL_ROOT_KEY = "source:local_root"
28
+ _SHARED_LOCAL_MODEL_KEY = "source:local_model"
29
+
30
+
31
def render_source_select(
    *,
    widget_scope: str,
    last_source_key: str | None = None,
) -> str:
    """Render the "Source" segmented control and record the choice.

    Args:
        widget_scope: Scope passed to ``widget_key`` to build the widget key.
        last_source_key: Optional session-state key holding this caller's own
            last selection. When it is not yet set, it is seeded from the
            shared source key.

    Returns:
        The selected source. It is also written to the shared source key and,
        when given, to ``last_source_key``.
    """
    state = st.session_state
    control_key = widget_key(widget_scope, "source")
    has_own_slot = last_source_key is not None

    # Seed the caller-specific slot from the shared one on first use.
    if has_own_slot and last_source_key not in state:
        inherited = state.get(_SHARED_SOURCE_KEY)
        if inherited is not None:
            state[last_source_key] = inherited

    choice = remembered_segmented_control(
        "Source",
        options=SOURCES,
        key=control_key,
        remember_key=last_source_key or _SHARED_SOURCE_KEY,
        default=SOURCE_HUB,
        label_visibility="collapsed",
    )

    state[_SHARED_SOURCE_KEY] = choice
    if has_own_slot:
        state[last_source_key] = choice
    return choice
53
+
54
+
55
def _render_hub_model_select(
    *,
    state_prefix: str,
    widget_scope: str,
    repo_id: str,
    mask_strategy: MaskStrategy,
    model_label: str,
    fallback_help: str,
    selection_help: str,
) -> str:
    """Render the Hub model picker, falling back to a free-text input.

    A select box is shown when ``hub_models_by_mask_strategy(repo_id)`` yields
    models for ``mask_strategy``. If that lookup raises, or yields no models
    for the strategy, a warning and a text input are shown instead.

    Returns:
        The chosen or typed model id. It is also stored under the shared Hub
        model session-state key.
    """
    state = st.session_state
    manual_key = f"{state_prefix}:hub_model_fallback"
    manual_default = state.get(
        manual_key,
        state.get(_SHARED_HUB_MODEL_KEY, DEFAULT_COMPARE_MODEL),
    )

    def _prompt_for_model() -> str:
        # Free-text entry used whenever no discovered model list is available.
        typed = st.text_input(
            model_label,
            value=manual_default,
            key=manual_key,
            help=fallback_help,
        )
        state[_SHARED_HUB_MODEL_KEY] = typed
        return typed

    try:
        by_strategy = hub_models_by_mask_strategy(repo_id)
    except Exception as exc:
        st.warning(f"Could not load Hub configs for `{repo_id}`: {exc}")
        return _prompt_for_model()

    available = by_strategy.get(mask_strategy, [])
    if not available:
        st.warning(
            f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`."
        )
        return _prompt_for_model()

    picker_key = widget_key(widget_scope, "hub_model", repo_id, mask_strategy.value)
    preselected = state.get(
        picker_key,
        state.get(_SHARED_HUB_MODEL_KEY, manual_default),
    )
    # Fall back to the first discovered model when the previous choice is
    # not offered for this repo and strategy.
    if preselected not in available:
        preselected = available[0]

    chosen = st.selectbox(
        model_label,
        options=available,
        index=available.index(preselected),
        key=picker_key,
        help=selection_help,
    )
    # Keep the manual-entry default in step with the latest selection.
    state[manual_key] = chosen
    state[_SHARED_HUB_MODEL_KEY] = chosen
    return chosen
115
+
116
+
117
def _render_local_model_select(
    *,
    state_prefix: str,
    artifacts_root: str,
    mask_strategy: MaskStrategy,
    allow_custom_toggle: bool,
    model_label: str,
) -> str:
    """Render the local model picker, with free-text entry as a fallback.

    A text input is shown when ``local_model_options_cached`` returns nothing
    for ``artifacts_root`` and the strategy, or when the optional "Custom
    local model" toggle is switched on. Otherwise a select box over the
    discovered models is shown.

    Returns:
        The chosen or typed model. It is also stored under the shared local
        model session-state key.
    """
    state = st.session_state
    manual_key = f"{state_prefix}:local_model"
    manual_default = state.get(
        manual_key,
        state.get(_SHARED_LOCAL_MODEL_KEY, DEFAULT_COMPARE_MODEL),
    )

    def _prompt_for_model(prompt_label: str) -> str:
        # Free-text entry; shares one widget key across both fallback paths.
        typed = st.text_input(prompt_label, value=manual_default, key=manual_key)
        state[_SHARED_LOCAL_MODEL_KEY] = typed
        return typed

    discovered = local_model_options_cached(artifacts_root, mask_strategy.value)
    if not discovered:
        return _prompt_for_model(model_label)

    # The toggle is only rendered when the caller allows custom entry.
    if allow_custom_toggle and st.toggle(
        "Custom local model",
        value=False,
        key=f"{state_prefix}:local_model_custom_enabled",
        help="Enter a model id/path manually instead of choosing from activation directories.",
    ):
        return _prompt_for_model("Local model")

    picker_key = f"{state_prefix}:local_model_select"
    candidate = state.get(
        picker_key,
        state.get(_SHARED_LOCAL_MODEL_KEY, manual_default),
    )
    if not any(local_model_matches(candidate, option) for option in discovered):
        candidate = manual_default

    # First discovered option matching the candidate, else the first option.
    preselected = discovered[0]
    for option in discovered:
        if local_model_matches(option, candidate):
            preselected = option
            break

    chosen = st.selectbox(
        model_label,
        options=discovered,
        index=discovered.index(preselected),
        key=picker_key,
        help="Models discovered under the selected artifacts root.",
    )
    state[manual_key] = chosen
    state[_SHARED_LOCAL_MODEL_KEY] = chosen
    return chosen
173
+
174
+
175
def render_store_select(
    source: str,
    mask_strategy: MaskStrategy,
    *,
    state_prefix: str,
    widget_scope: str,
    artifacts_root_key: str,
    model_label: str = "Model",
    local_model_label: str = "Model",
    allow_custom_local_model: bool = False,
    repo_help: str | None = None,
    fallback_help: str = "Model id to use if Hub config discovery is unavailable.",
) -> Store:
    """Render location and model inputs for ``source`` and return its store.

    For ``SOURCE_HUB`` a "Hub repo" text input and the Hub model picker are
    shown. For any other source an "Artifacts root" text input and the local
    model picker are shown.

    Returns:
        The result of ``activation_store_cached`` for the chosen source,
        location, model and ``mask_strategy.value``.
    """
    state = st.session_state

    if source != SOURCE_HUB:
        default_root = str(get_artifacts_dir() / "activations")
        typed_root = st.text_input(
            "Artifacts root",
            value=state.get(_SHARED_LOCAL_ROOT_KEY, default_root),
            key=artifacts_root_key,
        )
        # Store and use the user-expanded form of the path.
        root = str(Path(typed_root).expanduser())
        state[_SHARED_LOCAL_ROOT_KEY] = root
        local_model = _render_local_model_select(
            state_prefix=state_prefix,
            artifacts_root=root,
            mask_strategy=mask_strategy,
            allow_custom_toggle=allow_custom_local_model,
            model_label=local_model_label,
        )
        return activation_store_cached(
            SOURCE_LOCAL, root, local_model, mask_strategy.value
        )

    repo_key = f"{state_prefix}:hub_repo"
    initial_repo = state.get(
        repo_key,
        state.get(_SHARED_HUB_REPO_KEY, DEFAULT_HUB_REPO),
    )
    repo = st.text_input(
        "Hub repo",
        value=initial_repo,
        key=repo_key,
        help=repo_help,
    )
    state[_SHARED_HUB_REPO_KEY] = repo
    hub_model = _render_hub_model_select(
        state_prefix=state_prefix,
        widget_scope=widget_scope,
        repo_id=repo,
        mask_strategy=mask_strategy,
        model_label=model_label,
        fallback_help=fallback_help,
        selection_help="Models with vectors in the selected Hub repo and mask strategy.",
    )
    return activation_store_cached(SOURCE_HUB, repo, hub_model, mask_strategy.value)
uv.lock CHANGED
@@ -464,11 +464,11 @@ wheels = [
464
 
465
  [[package]]
466
  name = "decorator"
467
- version = "5.3.0"
468
  source = { registry = "https://pypi.org/simple" }
469
- sdist = { url = "https://files.pythonhosted.org/packages/5c/50/a39dd7ab407e93978dfa07d109b7d633e37958c89f30cbcec061b77b3ebc/decorator-5.3.0.tar.gz", hash = "sha256:95fda3122972c847cf0ff7e0ce2829bf25136f2526b627b3da85b60ca5f485c0", size = 58431, upload-time = "2026-05-17T06:59:57.258Z" }
470
  wheels = [
471
- { url = "https://files.pythonhosted.org/packages/d5/6f/f8d0bba4dc2a69817d74f640d504650241ebf2f9f7263426f1b953b344d4/decorator-5.3.0-py3-none-any.whl", hash = "sha256:f8c2d71ede92f073144ddd7f3e9fbbc3bd0f2f29522c9d75ee648d66553834f4", size = 11104, upload-time = "2026-05-17T06:59:54.676Z" },
472
  ]
473
 
474
  [[package]]
@@ -1585,7 +1585,7 @@ wheels = [
1585
 
1586
  [[package]]
1587
  name = "persona-ui"
1588
- version = "0.4.0"
1589
  source = { virtual = "." }
1590
  dependencies = [
1591
  { name = "catppuccin" },
@@ -2145,11 +2145,11 @@ wheels = [
2145
 
2146
  [[package]]
2147
  name = "python-multipart"
2148
- version = "0.0.28"
2149
  source = { registry = "https://pypi.org/simple" }
2150
- sdist = { url = "https://files.pythonhosted.org/packages/82/54/a85eb421fbdd5007bc5af39d0f4ed9fa609e0fedbfdc2adcf0b34526870e/python_multipart-0.0.28.tar.gz", hash = "sha256:8550da197eac0f7ab748961fc9509b999fa2662ea25cef857f05249f6893c0f8", size = 45314, upload-time = "2026-05-10T11:05:16.596Z" }
2151
  wheels = [
2152
- { url = "https://files.pythonhosted.org/packages/f3/a2/43bbc5860b5034e2af4ef99a0e04d726ff329c43e192ef3abaa8d7ecfce5/python_multipart-0.0.28-py3-none-any.whl", hash = "sha256:10faac07eb966c3f48dc415f9dee46c04cb10d58d30a35677db8027c825ed9b6", size = 29438, upload-time = "2026-05-10T11:05:15.052Z" },
2153
  ]
2154
 
2155
  [[package]]
 
464
 
465
  [[package]]
466
  name = "decorator"
467
+ version = "5.3.1"
468
  source = { registry = "https://pypi.org/simple" }
469
+ sdist = { url = "https://files.pythonhosted.org/packages/60/8b/32f9823da46cde7df2087faa08cd98d01b908f8dcab982cdba9c84e85355/decorator-5.3.1.tar.gz", hash = "sha256:4cbcdd55a6efadb9dbea26b858f4fb3264567b52d69ca0d25b721b553f60ea82", size = 58084, upload-time = "2026-05-18T06:03:28.057Z" }
470
  wheels = [
471
+ { url = "https://files.pythonhosted.org/packages/05/7f/798705f5296a58ca505d600456748d1be48078eac8a7050d8a98bc9edb89/decorator-5.3.1-py3-none-any.whl", hash = "sha256:f47fe6fdbd2edd623ecfe36875d37aba411624e2670dd395dddae1358689bb3c", size = 10365, upload-time = "2026-05-18T06:03:26.517Z" },
472
  ]
473
 
474
  [[package]]
 
1585
 
1586
  [[package]]
1587
  name = "persona-ui"
1588
+ version = "0.5.0"
1589
  source = { virtual = "." }
1590
  dependencies = [
1591
  { name = "catppuccin" },
 
2145
 
2146
  [[package]]
2147
  name = "python-multipart"
2148
+ version = "0.0.29"
2149
  source = { registry = "https://pypi.org/simple" }
2150
+ sdist = { url = "https://files.pythonhosted.org/packages/4e/fe/70bd71a6738b09a0bdf6480ca6436b167469ca4578b2a0efbe390b4b0e70/python_multipart-0.0.29.tar.gz", hash = "sha256:643e93849196645e2dbdd81a0f8829a23123ad7f797a84a364c6fb3563f18904", size = 45678, upload-time = "2026-05-17T17:29:47.654Z" }
2151
  wheels = [
2152
+ { url = "https://files.pythonhosted.org/packages/8f/cb/769cfc37177252872a45a71f3fbdde9d51b471a3f3c14bfe95dde3407386/python_multipart-0.0.29-py3-none-any.whl", hash = "sha256:2ddcc971cef266225f54f552d8fa10bcfbb1f14446caec199060daac59ff2d69", size = 29640, upload-time = "2026-05-17T17:29:45.69Z" },
2153
  ]
2154
 
2155
  [[package]]