Jac-Zac commited on
Commit
9edffb7
·
1 Parent(s): fee1567

Updated to latest probing options

Browse files

- Cleaned up repo
- Improved performance drastically updating to the latest versions of
the librarires + less reloading smarter caching and prefetchign

.env.example CHANGED
@@ -23,3 +23,5 @@ ARTIFACTS_DIR=artifacts
23
  # Keep model cache at 1 unless you have enough RAM for multiple loaded models.
24
  # PERSONA_UI_MODEL_CACHE_ENTRIES=1
25
  # PERSONA_UI_STORE_CACHE_ENTRIES=4
 
 
 
23
  # Keep model cache at 1 unless you have enough RAM for multiple loaded models.
24
  # PERSONA_UI_MODEL_CACHE_ENTRIES=1
25
  # PERSONA_UI_STORE_CACHE_ENTRIES=4
26
+ # PERSONA_UI_VECTOR_CACHE_ENTRIES=4
27
+ # PERSONA_UI_PREPARED_CACHE_ENTRIES=8
README.md CHANGED
@@ -116,6 +116,8 @@ NDIF_API_KEY=... # Required for remote (NDIF) model execution
116
  HF_HOME=... # Optional: HuggingFace cache directory
117
  ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default: ./artifacts)
118
  PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
 
 
119
  ```
120
 
121
  The app picks up this file automatically via `load_dotenv()` on startup.
@@ -148,3 +150,7 @@ the Analysis/Probing tab's Local source path) at the tree you want to load.
148
 
149
  The store classes are `PersonaVectorStore` (local) and `HFPersonaVectorStore`
150
  (Hub) — same API, both imported by `utils/analysis_sources.py`.
 
 
 
 
 
116
  HF_HOME=... # Optional: HuggingFace cache directory
117
  ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default: ./artifacts)
118
  PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
119
+ PERSONA_UI_VECTOR_CACHE_ENTRIES=4 # Optional: loaded analysis datasets kept warm
120
+ PERSONA_UI_PREPARED_CACHE_ENTRIES=8 # Optional: prepared projections / k-means groups kept warm
121
  ```
122
 
123
  The app picks up this file automatically via `load_dotenv()` on startup.
 
150
 
151
  The store classes are `PersonaVectorStore` (local) and `HFPersonaVectorStore`
152
  (Hub) — same API, both imported by `utils/analysis_sources.py`.
153
+
154
+ ## Analysis responsiveness
155
+
156
+ The Analysis tab keeps a small bounded cache of loaded vector datasets and prepared projection data. Once a projection has been computed, recoloring it by persona, attribute, or k-means group reuses the same coordinates; nearby Hub interactions also keep metadata warm instead of re-scanning after every figure. Tune `PERSONA_UI_VECTOR_CACHE_ENTRIES` if RAM is tight or you regularly switch among many selections, and `PERSONA_UI_PREPARED_CACHE_ENTRIES` if you revisit several projection configurations in one session.
pyproject.toml CHANGED
@@ -5,7 +5,7 @@ description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
- "persona-vectors>=0.8.2",
9
  "datasets>=4.8.5",
10
  "huggingface-hub>=1.14.0",
11
  "streamlit>=1.44.0",
 
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "persona-vectors>=0.8.3",
9
  "datasets>=4.8.5",
10
  "huggingface-hub>=1.14.0",
11
  "streamlit>=1.44.0",
tabs/analysis/_shared.py CHANGED
@@ -6,6 +6,19 @@ from persona_data.synth_persona import BASELINE_PERSONA_ID
6
  from persona_vectors.extraction import MaskStrategy
7
  from persona_vectors.plots import save_plot_html
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from utils.analysis_sources import (
10
  Store,
11
  available_variants,
@@ -13,7 +26,6 @@ from utils.analysis_sources import (
13
  load_variant_vectors_cached,
14
  persona_names_cached,
15
  personas_cached,
16
- release_hf_store_cache,
17
  store_cache_parts,
18
  store_id,
19
  store_layers_cached,
@@ -22,20 +34,6 @@ from utils.controls import render_mask_strategy_select
22
  from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
23
  from utils.theme import active_base, style_plotly_layer_controls
24
 
25
- from tabs.analysis._state import (
26
- _DEFAULT_LAYER_FRAMES,
27
- _HIGHLIGHT_OTHER_COLOR,
28
- _HIGHLIGHT_OTHER_LABEL,
29
- _LAST_LAYER_FRAMES_KEY,
30
- _LAST_MASK_STRATEGY_KEY,
31
- PersonaOptions,
32
- _is_assistant_persona,
33
- _persona_names_state_key,
34
- _personas_empty_message,
35
- _remembered_selectbox,
36
- _sequence_to_list,
37
- )
38
-
39
 
40
  def _gray_out_unselected_personas(fig: go.Figure) -> None:
41
  def _gray_trace(trace: object) -> None:
@@ -118,8 +116,7 @@ def _load_variant_vectors(
118
  )
119
 
120
 
121
- def _release_vector_memory(store: Store, variants: list[str] | tuple[str, ...]) -> None:
122
- release_hf_store_cache(store, variants)
123
  gc.collect()
124
 
125
 
 
6
  from persona_vectors.extraction import MaskStrategy
7
  from persona_vectors.plots import save_plot_html
8
 
9
+ from tabs.analysis._state import (
10
+ _DEFAULT_LAYER_FRAMES,
11
+ _HIGHLIGHT_OTHER_COLOR,
12
+ _HIGHLIGHT_OTHER_LABEL,
13
+ _LAST_LAYER_FRAMES_KEY,
14
+ _LAST_MASK_STRATEGY_KEY,
15
+ PersonaOptions,
16
+ _is_assistant_persona,
17
+ _persona_names_state_key,
18
+ _personas_empty_message,
19
+ _remembered_selectbox,
20
+ _sequence_to_list,
21
+ )
22
  from utils.analysis_sources import (
23
  Store,
24
  available_variants,
 
26
  load_variant_vectors_cached,
27
  persona_names_cached,
28
  personas_cached,
 
29
  store_cache_parts,
30
  store_id,
31
  store_layers_cached,
 
34
  from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
35
  from utils.theme import active_base, style_plotly_layer_controls
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def _gray_out_unselected_personas(fig: go.Figure) -> None:
39
  def _gray_trace(trace: object) -> None:
 
116
  )
117
 
118
 
119
+ def _release_vector_memory() -> None:
 
120
  gc.collect()
121
 
122
 
tabs/analysis/_state.py CHANGED
@@ -45,7 +45,7 @@ _CLUSTER_MODES = {
45
  "First selected layer": "first_layer",
46
  "Per layer": "per_layer",
47
  }
48
- _PROJECTION_COLOR_MODES = ["Persona", "K-means clusters", "Persona attribute"]
49
  _MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
50
 
51
 
@@ -87,7 +87,7 @@ class ProjectionColorConfig:
87
  @dataclass(frozen=True)
88
  class LayeredFigureStateKeys:
89
  figure: str
90
- projection: str | None = None
91
 
92
 
93
  _HIGHLIGHT_OTHER_LABEL = "Other"
@@ -139,7 +139,7 @@ _TRACKED_STATE_KEYS_KEY = "analysis:_tracked_state_keys"
139
 
140
 
141
  def _clear_old_load_states(current_key: str, suffix: str) -> None:
142
- # Only one heavy figure/projection state should live at a time. We track
143
  # the keys we create per suffix so eviction is O(1) instead of scanning
144
  # all of session_state on every rerun. Every such key is passed through
145
  # this function before it is set, so the registry stays authoritative.
@@ -156,8 +156,8 @@ def _clear_old_figure_states(current_key: str) -> None:
156
  _clear_old_load_states(current_key, "_fig_state")
157
 
158
 
159
- def _clear_old_projection_states(current_key: str) -> None:
160
- _clear_old_load_states(current_key, "_projection_state")
161
 
162
 
163
  def _store_figure_state(key: str, value: object) -> None:
 
45
  "First selected layer": "first_layer",
46
  "Per layer": "per_layer",
47
  }
48
+ _PROJECTION_COLOR_MODES = ["Persona attribute", "Persona", "K-means clusters"]
49
  _MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
50
 
51
 
 
87
  @dataclass(frozen=True)
88
  class LayeredFigureStateKeys:
89
  figure: str
90
+ prepared: str | None = None
91
 
92
 
93
  _HIGHLIGHT_OTHER_LABEL = "Other"
 
139
 
140
 
141
  def _clear_old_load_states(current_key: str, suffix: str) -> None:
142
+ # Only one heavy figure state should live at a time. We track
143
  # the keys we create per suffix so eviction is O(1) instead of scanning
144
  # all of session_state on every rerun. Every such key is passed through
145
  # this function before it is set, so the registry stays authoritative.
 
156
  _clear_old_load_states(current_key, "_fig_state")
157
 
158
 
159
+ def _clear_old_prepared_states(current_key: str) -> None:
160
+ _clear_old_load_states(current_key, "_projection_ready")
161
 
162
 
163
  def _store_figure_state(key: str, value: object) -> None:
tabs/analysis/cosine.py CHANGED
@@ -78,22 +78,15 @@ def _build_cosine_figures(
78
  mask_strategy: MaskStrategy,
79
  selection: CosineSelection,
80
  ) -> tuple[object, object | None, int, int] | None:
81
- variant_sample_cache: dict[str, object] = {}
82
-
83
- def _load_variant(variant: str):
84
- if variant not in variant_sample_cache:
85
- samples = _load_variant_vectors(
86
- store,
87
- [variant],
88
- mask_strategy,
89
- persona_ids=selection.persona_ids,
90
- )
91
- variant_sample_cache[variant] = samples[variant]
92
- return variant_sample_cache[variant]
93
-
94
  try:
95
- samples_a = _load_variant(selection.variant_a)
96
- samples_b = _load_variant(selection.variant_b)
 
 
 
 
 
 
97
  except Exception as exc:
98
  st.error(f"Could not load vectors: {exc}")
99
  return None
@@ -120,8 +113,8 @@ def _build_cosine_figures(
120
  pair_errors = []
121
  for left, right in combinations(selection.variants, 2):
122
  try:
123
- left_samples = _load_variant(left)
124
- right_samples = _load_variant(right)
125
  pair_traces.append(
126
  (
127
  f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
@@ -207,7 +200,7 @@ def _render_cosine_similarity(
207
  _store_figure_state(cosine_fig_key, figures)
208
  progress.progress(100, text="Done.")
209
  finally:
210
- _release_vector_memory(store, selection.variants)
211
  progress.empty()
212
 
213
  if cosine_fig_key in st.session_state:
 
78
  mask_strategy: MaskStrategy,
79
  selection: CosineSelection,
80
  ) -> tuple[object, object | None, int, int] | None:
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  try:
82
+ by_variant = _load_variant_vectors(
83
+ store,
84
+ selection.variants,
85
+ mask_strategy,
86
+ persona_ids=selection.persona_ids,
87
+ )
88
+ samples_a = by_variant[selection.variant_a]
89
+ samples_b = by_variant[selection.variant_b]
90
  except Exception as exc:
91
  st.error(f"Could not load vectors: {exc}")
92
  return None
 
113
  pair_errors = []
114
  for left, right in combinations(selection.variants, 2):
115
  try:
116
+ left_samples = by_variant[left]
117
+ right_samples = by_variant[right]
118
  pair_traces.append(
119
  (
120
  f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
 
200
  _store_figure_state(cosine_fig_key, figures)
201
  progress.progress(100, text="Done.")
202
  finally:
203
+ _release_vector_memory()
204
  progress.empty()
205
 
206
  if cosine_fig_key in st.session_state:
tabs/analysis/dendrogram.py CHANGED
@@ -13,7 +13,7 @@ from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
13
 
14
  from tabs.analysis._shared import (
15
  _load_persona_options,
16
- _load_persona_vectors,
17
  _plotly_chart,
18
  _release_vector_memory,
19
  _render_layer_frame_controls,
@@ -204,13 +204,14 @@ def _render_dendrogram_analysis(
204
  ):
205
  progress = st.progress(0, text="Loading first variant vectors…")
206
  try:
207
- progress.progress(15, text="Loading first variant vectors…")
208
- samples_a = _load_persona_vectors(
209
  store,
210
- variant_a,
211
  mask_strategy,
212
  persona_ids,
213
  )
 
214
  progress.progress(40, text="Building first dendrogram…")
215
  fig_a = plot_persona_dendrogram(
216
  samples_a,
@@ -223,13 +224,8 @@ def _render_dendrogram_analysis(
223
  del samples_a
224
  fig_b = None
225
  if variant_a != variant_b:
226
- progress.progress(60, text="Loading second variant vectors…")
227
- samples_b = _load_persona_vectors(
228
- store,
229
- variant_b,
230
- mask_strategy,
231
- persona_ids,
232
- )
233
  progress.progress(75, text="Building second dendrogram…")
234
  fig_b = plot_persona_dendrogram(
235
  samples_b,
@@ -250,7 +246,7 @@ def _render_dendrogram_analysis(
250
  st.error(f"Could not build dendrogram: {exc}")
251
  st.session_state.pop(fig_key, None)
252
  finally:
253
- _release_vector_memory(store, shared_variants)
254
  progress.empty()
255
 
256
  if fig_key in st.session_state:
 
13
 
14
  from tabs.analysis._shared import (
15
  _load_persona_options,
16
+ _load_variant_vectors,
17
  _plotly_chart,
18
  _release_vector_memory,
19
  _render_layer_frame_controls,
 
204
  ):
205
  progress = st.progress(0, text="Loading first variant vectors…")
206
  try:
207
+ progress.progress(15, text="Loading variant vectors…")
208
+ by_variant = _load_variant_vectors(
209
  store,
210
+ shared_variants,
211
  mask_strategy,
212
  persona_ids,
213
  )
214
+ samples_a = by_variant[variant_a]
215
  progress.progress(40, text="Building first dendrogram…")
216
  fig_a = plot_persona_dendrogram(
217
  samples_a,
 
224
  del samples_a
225
  fig_b = None
226
  if variant_a != variant_b:
227
+ progress.progress(60, text="Building second dendrogram…")
228
+ samples_b = by_variant[variant_b]
 
 
 
 
 
229
  progress.progress(75, text="Building second dendrogram…")
230
  fig_b = plot_persona_dendrogram(
231
  samples_b,
 
246
  st.error(f"Could not build dendrogram: {exc}")
247
  st.session_state.pop(fig_key, None)
248
  finally:
249
+ _release_vector_memory()
250
  progress.empty()
251
 
252
  if fig_key in st.session_state:
tabs/analysis/layered.py CHANGED
@@ -11,14 +11,19 @@ from persona_vectors.plots import (
11
  build_layered_figure,
12
  build_pair_similarity_figure,
13
  build_similarity_figures,
14
- prepare_layered_projection_data,
15
  )
16
 
17
  from utils.analysis_metadata import (
18
  synth_persona_attribute_names,
19
  synth_persona_dataset_cached,
20
  )
21
- from utils.analysis_sources import Store, store_id
 
 
 
 
 
 
22
  from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
23
 
24
  from tabs.analysis._shared import (
@@ -48,7 +53,7 @@ from tabs.analysis._state import (
48
  LayeredFigureStateKeys,
49
  ProjectionColorConfig,
50
  _clear_old_figure_states,
51
- _clear_old_projection_states,
52
  _highlight_persona_groups,
53
  _persona_display_label,
54
  _persona_names_state_key,
@@ -116,7 +121,7 @@ def _render_projection_color_config(
116
  key=color_mode_key,
117
  remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
118
  options=_PROJECTION_COLOR_MODES,
119
- default="Persona",
120
  )
121
  if color_mode == "K-means clusters":
122
  max_clusters = min(10, len(persona_ids))
@@ -265,36 +270,34 @@ def _layered_figure_state_keys(
265
  )
266
  if figure_kind not in _PROJECTION_KINDS:
267
  return LayeredFigureStateKeys(figure=figure_key)
268
-
269
- graph_overlay = figure_kind == "isomap"
270
- projection_key = widget_key(
271
  "load",
272
- f"{scope}_projection_state",
273
  store_id(store),
274
  store.model_name,
275
  mask_strategy.value,
276
  figure_kind,
277
  str(n_components),
278
- str(graph_overlay),
279
  str(_DEFAULT_GRAPH_NEIGHBORS),
280
  variant,
281
- "persona_vector",
282
  persona_key,
283
  layer_key,
284
  )
285
- return LayeredFigureStateKeys(figure=figure_key, projection=projection_key)
286
 
287
 
288
  def _projection_build_kwargs(
289
- samples,
290
  *,
 
 
 
291
  figure_kind: str,
292
  selected_layers: list[int],
293
  n_components: int,
294
  color_config: ProjectionColorConfig,
295
  persona_ids: list[str],
296
  persona_names: dict[str, str],
297
- projection_key: str | None,
298
  ) -> dict:
299
  if figure_kind not in _PROJECTION_KINDS:
300
  return {}
@@ -305,22 +308,29 @@ def _projection_build_kwargs(
305
  "graph_overlay": graph_overlay,
306
  "graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
307
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  if color_config.n_clusters is not None:
309
- build_kwargs["n_clusters"] = color_config.n_clusters
310
- build_kwargs["cluster_mode"] = color_config.cluster_mode
311
- if projection_key is not None:
312
- projection_data = st.session_state.get(projection_key)
313
- if projection_data is None:
314
- projection_data = prepare_layered_projection_data(
315
- samples,
316
- figure_kind,
317
- layers=selected_layers,
318
- n_components=n_components,
319
- graph_overlay=graph_overlay,
320
- graph_n_neighbors=_DEFAULT_GRAPH_NEIGHBORS,
321
- )
322
- st.session_state[projection_key] = projection_data
323
- build_kwargs["projection_data"] = projection_data
324
  if color_config.attribute_name is not None:
325
  build_kwargs.update(
326
  attribute_color_kwargs(
@@ -487,8 +497,6 @@ def _render_layered_figure_analysis(
487
  selected_layers=selected_layers,
488
  pair_trajectories=pair_trajectories,
489
  )
490
- if state_keys.projection is not None:
491
- _clear_old_projection_states(state_keys.projection)
492
  filename = scope
493
  _clear_old_figure_states(state_keys.figure)
494
  persona_names = st.session_state.get(
@@ -496,7 +504,13 @@ def _render_layered_figure_analysis(
496
  {},
497
  )
498
 
499
- if st.button(button_label, type="primary"):
 
 
 
 
 
 
500
  build_label = {
501
  "umap": "Computing UMAP projections…",
502
  "pca": "Computing PCA projections…",
@@ -514,14 +528,15 @@ def _render_layered_figure_analysis(
514
  )
515
  progress.progress(55, text=build_label)
516
  build_kwargs = _projection_build_kwargs(
517
- samples,
 
 
518
  figure_kind=figure_kind,
519
  selected_layers=selected_layers,
520
  n_components=n_components,
521
  color_config=color_config,
522
  persona_ids=persona_ids,
523
  persona_names=persona_names,
524
- projection_key=state_keys.projection,
525
  )
526
  main_fig, extra_fig = _build_layered_analysis_figures(
527
  samples,
@@ -541,12 +556,15 @@ def _render_layered_figure_analysis(
541
  n_samples = samples.vectors.shape[0]
542
  del samples
543
  _store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
 
 
 
544
  progress.progress(100, text="Done.")
545
  except Exception as exc:
546
  st.error(f"Could not build figure: {exc}")
547
  st.session_state.pop(state_keys.figure, None)
548
  finally:
549
- _release_vector_memory(store, [variant])
550
  progress.empty()
551
 
552
  if state_keys.figure in st.session_state:
 
11
  build_layered_figure,
12
  build_pair_similarity_figure,
13
  build_similarity_figures,
 
14
  )
15
 
16
  from utils.analysis_metadata import (
17
  synth_persona_attribute_names,
18
  synth_persona_dataset_cached,
19
  )
20
+ from utils.analysis_sources import (
21
+ Store,
22
+ kmeans_groups_cached,
23
+ projection_data_cached,
24
+ store_cache_parts,
25
+ store_id,
26
+ )
27
  from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
28
 
29
  from tabs.analysis._shared import (
 
53
  LayeredFigureStateKeys,
54
  ProjectionColorConfig,
55
  _clear_old_figure_states,
56
+ _clear_old_prepared_states,
57
  _highlight_persona_groups,
58
  _persona_display_label,
59
  _persona_names_state_key,
 
121
  key=color_mode_key,
122
  remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
123
  options=_PROJECTION_COLOR_MODES,
124
+ default="Persona attribute",
125
  )
126
  if color_mode == "K-means clusters":
127
  max_clusters = min(10, len(persona_ids))
 
270
  )
271
  if figure_kind not in _PROJECTION_KINDS:
272
  return LayeredFigureStateKeys(figure=figure_key)
273
+ prepared_key = widget_key(
 
 
274
  "load",
275
+ f"{scope}_projection_ready",
276
  store_id(store),
277
  store.model_name,
278
  mask_strategy.value,
279
  figure_kind,
280
  str(n_components),
281
+ str(figure_kind == "isomap"),
282
  str(_DEFAULT_GRAPH_NEIGHBORS),
283
  variant,
 
284
  persona_key,
285
  layer_key,
286
  )
287
+ return LayeredFigureStateKeys(figure=figure_key, prepared=prepared_key)
288
 
289
 
290
  def _projection_build_kwargs(
 
291
  *,
292
+ store: Store,
293
+ mask_strategy: MaskStrategy,
294
+ variant: str,
295
  figure_kind: str,
296
  selected_layers: list[int],
297
  n_components: int,
298
  color_config: ProjectionColorConfig,
299
  persona_ids: list[str],
300
  persona_names: dict[str, str],
 
301
  ) -> dict:
302
  if figure_kind not in _PROJECTION_KINDS:
303
  return {}
 
308
  "graph_overlay": graph_overlay,
309
  "graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
310
  }
311
+ source, location, model_name = store_cache_parts(store)
312
+ cache_args = (
313
+ source,
314
+ location,
315
+ model_name,
316
+ mask_strategy.value,
317
+ variant,
318
+ tuple(persona_ids),
319
+ tuple(selected_layers),
320
+ )
321
+ build_kwargs["projection_data"] = projection_data_cached(
322
+ *cache_args,
323
+ figure_kind,
324
+ n_components,
325
+ graph_overlay,
326
+ _DEFAULT_GRAPH_NEIGHBORS,
327
+ )
328
  if color_config.n_clusters is not None:
329
+ build_kwargs["groups"] = kmeans_groups_cached(
330
+ *cache_args,
331
+ color_config.n_clusters,
332
+ color_config.cluster_mode or "mean_across_layers",
333
+ )
 
 
 
 
 
 
 
 
 
 
334
  if color_config.attribute_name is not None:
335
  build_kwargs.update(
336
  attribute_color_kwargs(
 
497
  selected_layers=selected_layers,
498
  pair_trajectories=pair_trajectories,
499
  )
 
 
500
  filename = scope
501
  _clear_old_figure_states(state_keys.figure)
502
  persona_names = st.session_state.get(
 
504
  {},
505
  )
506
 
507
+ build_clicked = st.button(button_label, type="primary")
508
+ recolor_from_warm_projection = (
509
+ state_keys.prepared is not None
510
+ and bool(st.session_state.get(state_keys.prepared))
511
+ and state_keys.figure not in st.session_state
512
+ )
513
+ if build_clicked or recolor_from_warm_projection:
514
  build_label = {
515
  "umap": "Computing UMAP projections…",
516
  "pca": "Computing PCA projections…",
 
528
  )
529
  progress.progress(55, text=build_label)
530
  build_kwargs = _projection_build_kwargs(
531
+ store=store,
532
+ mask_strategy=mask_strategy,
533
+ variant=variant,
534
  figure_kind=figure_kind,
535
  selected_layers=selected_layers,
536
  n_components=n_components,
537
  color_config=color_config,
538
  persona_ids=persona_ids,
539
  persona_names=persona_names,
 
540
  )
541
  main_fig, extra_fig = _build_layered_analysis_figures(
542
  samples,
 
556
  n_samples = samples.vectors.shape[0]
557
  del samples
558
  _store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
559
+ if state_keys.prepared is not None:
560
+ _clear_old_prepared_states(state_keys.prepared)
561
+ st.session_state[state_keys.prepared] = True
562
  progress.progress(100, text="Done.")
563
  except Exception as exc:
564
  st.error(f"Could not build figure: {exc}")
565
  st.session_state.pop(state_keys.figure, None)
566
  finally:
567
+ _release_vector_memory()
568
  progress.empty()
569
 
570
  if state_keys.figure in st.session_state:
tabs/probe.py CHANGED
@@ -23,6 +23,7 @@ from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layer
23
  from persona_vectors.probes import (
24
  AttributeLabels,
25
  attribute_probe_labels,
 
26
  filter_attribute_samples_min_count,
27
  infer_probe_task,
28
  layer_matrix,
@@ -85,8 +86,9 @@ class _SweepInputs:
85
  mask_value: str
86
  variant: str
87
  persona_ids: tuple[str, ...]
88
- attribute: str
89
  task: str
 
90
  n_pca_components: int | None
91
  layers: tuple[int, ...]
92
  min_class_count: int
@@ -234,22 +236,62 @@ def _select_personas(
234
  # ---------------------------------------------------------------------------
235
 
236
 
237
- def _select_attribute() -> str:
 
238
  dataset = synth_persona_dataset_cached()
239
- options = list(synth_persona_attribute_names())
240
- if "sex" in options:
241
- default_index = options.index("sex")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  else:
243
- default_index = 0
244
- return st.selectbox(
245
- "Attribute to probe",
 
246
  options=options,
247
- index=default_index,
248
  format_func=lambda name: attribute_display_label(dataset, name),
249
- key="probe:attribute",
 
 
250
  )
251
 
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  def _select_pca_components() -> int | None:
254
  use_pca = st.toggle(
255
  "Add PCA-compressed comparison",
@@ -298,61 +340,78 @@ def _select_layers(num_layers: int) -> list[int]:
298
  @st.cache_resource(show_spinner=False)
299
  def _cached_sweep(
300
  inputs: _SweepInputs,
301
- ) -> tuple[dict[str, list[dict[str, object]]], AttributeLabels, LayeredSamples]:
 
 
 
302
  samples = load_persona_vectors_cached(
303
  inputs.source, inputs.location, inputs.model_name,
304
  inputs.mask_value, inputs.variant, inputs.persona_ids,
305
  )
306
  dataset = synth_persona_dataset_cached()
307
- labels = attribute_probe_labels(
308
- dataset, inputs.attribute, list(inputs.persona_ids), task=inputs.task, # type: ignore[arg-type]
309
- )
310
- probe_samples, labels = filter_attribute_samples_min_count(
311
- samples, labels, min_count=inputs.min_class_count
312
- )
 
 
 
 
 
 
 
 
313
 
314
- def _sweep(n_pca: int | None) -> list[dict[str, object]]:
 
315
  return sweep_attribute(
316
  probe_samples, labels,
317
  layers=list(inputs.layers),
 
318
  n_pca_components=n_pca,
319
  seed=inputs.seed,
320
  )
321
 
 
 
 
 
 
 
322
  if inputs.n_pca_components is not None:
323
  # Always overlay the compressed sweep against full activations.
324
  rows_by_label = {
325
- "full": _sweep(None),
326
- f"pca{inputs.n_pca_components}": _sweep(inputs.n_pca_components),
327
  }
328
  else:
329
- rows_by_label = {"full": _sweep(None)}
330
- return rows_by_label, labels, probe_samples
331
 
332
 
333
  def _show_sweep(
334
  rows_by_label: dict[str, list[dict[str, object]]],
335
- labels: AttributeLabels,
336
- samples: LayeredSamples,
337
- attribute: str,
338
  task: str,
339
  inputs: _SweepInputs,
340
  ) -> None:
341
  primary = _PRIMARY_METRIC[task]
342
  secondary = _SECONDARY_METRIC.get(task)
343
 
344
- # Tolerate stale session state from a previous code version (bare rows).
345
- if isinstance(rows_by_label, list):
346
- rows_by_label = {"full": rows_by_label}
347
  primary_label = (
348
  f"pca{inputs.n_pca_components}" if inputs.n_pca_components else "full"
349
  )
350
  rows = rows_by_label.get(primary_label) or next(iter(rows_by_label.values()))
351
 
352
  def _plot(metric: str):
353
- if len(rows_by_label) > 1:
354
- return plot_metric_comparison(rows_by_label, attribute, metric=metric)
355
- return plot_metric_over_layers(rows, attribute, metric=metric)
 
 
356
 
357
  st.plotly_chart(_plot(primary), width="stretch")
358
  if secondary is not None:
@@ -377,21 +436,31 @@ def _show_sweep(
377
  if best is None:
378
  return
379
 
380
- if len(rows_by_label) > 1:
 
381
  summary_rows = []
382
  for label, label_rows in rows_by_label.items():
383
- label_best = _best_row(label_rows)
384
- if label_best is None:
385
- continue
386
- summary_rows.append({
387
- "features": label,
388
- "best_layer": label_best["layer"],
389
- "probe": label_best["probe_kind"],
390
- primary: round(float(label_best[primary]), 3),
391
- f"baseline_{primary}": round(
392
- float(label_best.get(f"baseline_{primary}", float("nan"))), 3
393
- ),
394
- })
 
 
 
 
 
 
 
 
 
395
  if summary_rows:
396
  st.dataframe(summary_rows, width="stretch", hide_index=True)
397
 
@@ -399,18 +468,26 @@ def _show_sweep(
399
  f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
400
  )
401
 
402
- cols = st.columns([1, 1.2, 1.8])
403
- cols[0].metric("Best layer", best["layer"])
404
- cols[1].metric(
405
- f"Best {primary}",
406
- f"{best[primary]:.3f}",
407
- delta=f"baseline {best.get(f'baseline_{primary}', float('nan')):.3f}",
408
- delta_color="off",
409
- )
410
- cols[2].metric("Probe", f"{best['probe_kind']}{feature_desc}")
 
 
 
 
 
 
 
 
411
 
412
  _render_selectivity_control(best, labels, samples, task, inputs)
413
- _render_save_artifact(best, labels, samples, attribute, task, inputs)
414
 
415
 
416
  def _render_selectivity_control(
@@ -461,7 +538,6 @@ def _render_save_artifact(
461
  best: dict[str, object],
462
  labels: AttributeLabels,
463
  samples: LayeredSamples,
464
- attribute: str,
465
  task: str,
466
  inputs: _SweepInputs,
467
  ) -> None:
@@ -540,12 +616,15 @@ def render_probing_tab() -> None:
540
  if not persona_ids:
541
  return
542
 
543
- dataset = synth_persona_dataset_cached()
544
  with st.expander("Probe configuration", expanded=True):
545
- attribute = _select_attribute()
546
- task = infer_probe_task(dataset, attribute)
 
 
 
547
  st.caption(f"Inferred task: **{task}**")
548
 
 
549
  n_pca_components = _select_pca_components()
550
 
551
  source, location, model_name = store_cache_parts(store)
@@ -563,17 +642,13 @@ def render_probing_tab() -> None:
563
  num_layers = max(available_layers) + 1
564
  layers = _select_layers(num_layers)
565
  min_class_count = _MIN_CLASS_COUNT
566
- seed = st.number_input(
567
- "Seed", min_value=0, max_value=10_000, value=0, step=1,
568
- key="probe:seed",
569
- help="Seeds the probe/PCA fit. The 80/20 split itself is fixed "
570
- "(random_state=0).",
571
- )
572
 
573
  inputs = _SweepInputs(
574
  source=source, location=location, model_name=model_name,
575
  mask_value=mask_strategy.value, variant=variant,
576
- persona_ids=tuple(persona_ids), attribute=attribute, task=task,
 
577
  n_pca_components=n_pca_components,
578
  layers=tuple(layers), min_class_count=min_class_count,
579
  seed=int(seed),
@@ -584,25 +659,21 @@ def render_probing_tab() -> None:
584
  if run:
585
  with st.spinner("Evaluating probes across layers..."):
586
  try:
587
- sweep, labels, probe_samples = _cached_sweep(inputs)
588
  except Exception as exc:
589
  st.error(f"Sweep failed: {exc}")
590
  st.session_state.pop(state_key, None)
591
  return
592
- st.session_state[state_key] = (
593
- sweep,
594
- labels,
595
- probe_samples,
596
- attribute,
597
- task,
598
- inputs,
599
- )
600
 
601
  if state_key in st.session_state:
602
  saved_result = st.session_state[state_key]
603
- if len(saved_result) == 5:
604
- sweep, labels, probe_samples, last_attr, last_task = saved_result
605
- result_inputs = inputs
606
  else:
607
- sweep, labels, probe_samples, last_attr, last_task, result_inputs = saved_result
608
- _show_sweep(sweep, labels, probe_samples, last_attr, last_task, result_inputs)
 
 
 
 
23
  from persona_vectors.probes import (
24
  AttributeLabels,
25
  attribute_probe_labels,
26
+ default_probe_kinds,
27
  filter_attribute_samples_min_count,
28
  infer_probe_task,
29
  layer_matrix,
 
86
  mask_value: str
87
  variant: str
88
  persona_ids: tuple[str, ...]
89
+ attributes: tuple[str, ...]
90
  task: str
91
+ probe_kinds: tuple[str, ...]
92
  n_pca_components: int | None
93
  layers: tuple[int, ...]
94
  min_class_count: int
 
236
  # ---------------------------------------------------------------------------
237
 
238
 
239
+ @st.cache_data(show_spinner=False)
240
+ def _attribute_tasks() -> dict[str, str]:
241
  dataset = synth_persona_dataset_cached()
242
+ return {
243
+ name: infer_probe_task(dataset, name)
244
+ for name in synth_persona_attribute_names()
245
+ }
246
+
247
+
248
+ def _select_attributes() -> list[str]:
249
+ """Multi-select locked to one task type.
250
+
251
+ Picking the first attribute fixes the task; only same-task attributes stay
252
+ selectable. Clearing the selection reopens every attribute again.
253
+ """
254
+ dataset = synth_persona_dataset_cached()
255
+ tasks = _attribute_tasks()
256
+ all_names = list(synth_persona_attribute_names())
257
+
258
+ key = "probe:attributes"
259
+ if key not in st.session_state:
260
+ st.session_state[key] = ["sex"] if "sex" in all_names else all_names[:1]
261
+
262
+ selected = st.session_state[key]
263
+ if selected:
264
+ locked = tasks[selected[0]]
265
+ options = [name for name in all_names if tasks[name] == locked]
266
  else:
267
+ options = all_names
268
+
269
+ return st.multiselect(
270
+ "Attributes to probe",
271
  options=options,
 
272
  format_func=lambda name: attribute_display_label(dataset, name),
273
+ key=key,
274
+ help="Pick one or more attributes of the same task type. They are "
275
+ "overlaid in one figure. Remove all to switch to a different task type.",
276
  )
277
 
278
 
279
+ def _select_probe_kinds(task: str) -> list[str]:
280
+ """Pick which probe families to fit. Only shown when the task has >1."""
281
+ available = list(default_probe_kinds(task)) # type: ignore[arg-type]
282
+ if len(available) < 2:
283
+ return available
284
+ selected = st.multiselect(
285
+ "Probe kinds to fit",
286
+ options=available,
287
+ default=available,
288
+ key=f"probe:kinds:{task}",
289
+ help="Which probe families to fit at each layer. Defaults to all "
290
+ "available for this task.",
291
+ )
292
+ return selected or available
293
+
294
+
295
  def _select_pca_components() -> int | None:
296
  use_pca = st.toggle(
297
  "Add PCA-compressed comparison",
 
340
  @st.cache_resource(show_spinner=False)
341
  def _cached_sweep(
342
  inputs: _SweepInputs,
343
+ ) -> tuple[
344
+ dict[str, list[dict[str, object]]],
345
+ dict[str, tuple[AttributeLabels, LayeredSamples]],
346
+ ]:
347
  samples = load_persona_vectors_cached(
348
  inputs.source, inputs.location, inputs.model_name,
349
  inputs.mask_value, inputs.variant, inputs.persona_ids,
350
  )
351
  dataset = synth_persona_dataset_cached()
352
+ # The min-count filter drops personas per attribute, so each attribute keeps
353
+ # its own (labels, samples) pair for the downstream selectivity/save tools.
354
+ per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]] = {}
355
+
356
+ def _labels_and_samples(attribute: str) -> tuple[AttributeLabels, LayeredSamples]:
357
+ if attribute not in per_attr:
358
+ labels = attribute_probe_labels(
359
+ dataset, attribute, list(inputs.persona_ids), task=inputs.task, # type: ignore[arg-type]
360
+ )
361
+ probe_samples, labels = filter_attribute_samples_min_count(
362
+ samples, labels, min_count=inputs.min_class_count
363
+ )
364
+ per_attr[attribute] = (labels, probe_samples)
365
+ return per_attr[attribute]
366
 
367
+ def _sweep(attribute: str, n_pca: int | None) -> list[dict[str, object]]:
368
+ labels, probe_samples = _labels_and_samples(attribute)
369
  return sweep_attribute(
370
  probe_samples, labels,
371
  layers=list(inputs.layers),
372
+ probe_kinds=list(inputs.probe_kinds), # type: ignore[arg-type]
373
  n_pca_components=n_pca,
374
  seed=inputs.seed,
375
  )
376
 
377
+ def _sweep_all(n_pca: int | None) -> list[dict[str, object]]:
378
+ rows: list[dict[str, object]] = []
379
+ for attribute in inputs.attributes:
380
+ rows.extend(_sweep(attribute, n_pca))
381
+ return rows
382
+
383
  if inputs.n_pca_components is not None:
384
  # Always overlay the compressed sweep against full activations.
385
  rows_by_label = {
386
+ "full": _sweep_all(None),
387
+ f"pca{inputs.n_pca_components}": _sweep_all(inputs.n_pca_components),
388
  }
389
  else:
390
+ rows_by_label = {"full": _sweep_all(None)}
391
+ return rows_by_label, per_attr
392
 
393
 
394
  def _show_sweep(
395
  rows_by_label: dict[str, list[dict[str, object]]],
396
+ per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
397
+ attributes: tuple[str, ...],
 
398
  task: str,
399
  inputs: _SweepInputs,
400
  ) -> None:
401
  primary = _PRIMARY_METRIC[task]
402
  secondary = _SECONDARY_METRIC.get(task)
403
 
 
 
 
404
  primary_label = (
405
  f"pca{inputs.n_pca_components}" if inputs.n_pca_components else "full"
406
  )
407
  rows = rows_by_label.get(primary_label) or next(iter(rows_by_label.values()))
408
 
409
  def _plot(metric: str):
410
+ if len(rows_by_label) > 1 or len(attributes) > 1:
411
+ return plot_metric_comparison(
412
+ rows_by_label, list(attributes), metric=metric
413
+ )
414
+ return plot_metric_over_layers(rows, attributes[0], metric=metric)
415
 
416
  st.plotly_chart(_plot(primary), width="stretch")
417
  if secondary is not None:
 
436
  if best is None:
437
  return
438
 
439
+ multi_attr = len(attributes) > 1
440
+ if len(rows_by_label) > 1 or multi_attr:
441
  summary_rows = []
442
  for label, label_rows in rows_by_label.items():
443
+ for attribute in attributes:
444
+ attr_rows = [
445
+ row for row in label_rows
446
+ if row.get("attribute") == attribute
447
+ ]
448
+ label_best = _best_row(attr_rows)
449
+ if label_best is None:
450
+ continue
451
+ summary_row: dict[str, object] = {}
452
+ if multi_attr:
453
+ summary_row["attribute"] = attribute
454
+ summary_row.update({
455
+ "features": label,
456
+ "best_layer": label_best["layer"],
457
+ "probe": label_best["probe_kind"],
458
+ primary: round(float(label_best[primary]), 3),
459
+ f"baseline_{primary}": round(
460
+ float(label_best.get(f"baseline_{primary}", float("nan"))), 3
461
+ ),
462
+ })
463
+ summary_rows.append(summary_row)
464
  if summary_rows:
465
  st.dataframe(summary_rows, width="stretch", hide_index=True)
466
 
 
468
  f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
469
  )
470
 
471
+ best_attr = str(best["attribute"])
472
+ labels, samples = per_attr[best_attr]
473
+ if multi_attr:
474
+ # The per-attribute summary table above already covers every result;
475
+ # a single "best" card would only show one attribute, so skip it and
476
+ # just say which one the controls below operate on.
477
+ st.caption(f"Controls below use the best result: **{best_attr}**.")
478
+ else:
479
+ cols = st.columns([1, 1.2, 1.8])
480
+ cols[0].metric("Best layer", best["layer"])
481
+ cols[1].metric(
482
+ f"Best {primary}",
483
+ f"{best[primary]:.3f}",
484
+ delta=f"baseline {best.get(f'baseline_{primary}', float('nan')):.3f}",
485
+ delta_color="off",
486
+ )
487
+ cols[2].metric("Probe", f"{best['probe_kind']}{feature_desc}")
488
 
489
  _render_selectivity_control(best, labels, samples, task, inputs)
490
+ _render_save_artifact(best, labels, samples, task, inputs)
491
 
492
 
493
  def _render_selectivity_control(
 
538
  best: dict[str, object],
539
  labels: AttributeLabels,
540
  samples: LayeredSamples,
 
541
  task: str,
542
  inputs: _SweepInputs,
543
  ) -> None:
 
616
  if not persona_ids:
617
  return
618
 
 
619
  with st.expander("Probe configuration", expanded=True):
620
+ attributes = _select_attributes()
621
+ if not attributes:
622
+ st.info("Select at least one attribute to probe.")
623
+ return
624
+ task = _attribute_tasks()[attributes[0]]
625
  st.caption(f"Inferred task: **{task}**")
626
 
627
+ probe_kinds = _select_probe_kinds(task)
628
  n_pca_components = _select_pca_components()
629
 
630
  source, location, model_name = store_cache_parts(store)
 
642
  num_layers = max(available_layers) + 1
643
  layers = _select_layers(num_layers)
644
  min_class_count = _MIN_CLASS_COUNT
645
+ seed = 0
 
 
 
 
 
646
 
647
  inputs = _SweepInputs(
648
  source=source, location=location, model_name=model_name,
649
  mask_value=mask_strategy.value, variant=variant,
650
+ persona_ids=tuple(persona_ids), attributes=tuple(attributes), task=task,
651
+ probe_kinds=tuple(probe_kinds),
652
  n_pca_components=n_pca_components,
653
  layers=tuple(layers), min_class_count=min_class_count,
654
  seed=int(seed),
 
659
  if run:
660
  with st.spinner("Evaluating probes across layers..."):
661
  try:
662
+ sweep, per_attr = _cached_sweep(inputs)
663
  except Exception as exc:
664
  st.error(f"Sweep failed: {exc}")
665
  st.session_state.pop(state_key, None)
666
  return
667
+ st.session_state[state_key] = (sweep, per_attr, inputs)
 
 
 
 
 
 
 
668
 
669
  if state_key in st.session_state:
670
  saved_result = st.session_state[state_key]
671
+ if len(saved_result) != 3:
672
+ # Stale shape from a previous code version — drop it.
673
+ st.session_state.pop(state_key, None)
674
  else:
675
+ sweep, per_attr, result_inputs = saved_result
676
+ _show_sweep(
677
+ sweep, per_attr, result_inputs.attributes,
678
+ result_inputs.task, result_inputs,
679
+ )
tests/test_probes.py CHANGED
@@ -12,9 +12,11 @@ two correctness fixes:
12
  import pytest
13
  import torch
14
 
 
15
  from utils.probes import (
16
  LoadedProbe,
17
  _LinearProbe,
 
18
  _normalize_labels,
19
  parse_probe_filename,
20
  )
@@ -196,3 +198,33 @@ def test_run_single_output_predicts_negative_when_score_low():
196
  result = probe.run(torch.tensor([1.0, 1.0]))
197
  assert result.predicted_index == 0
198
  assert result.predicted_label == "neg"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  import pytest
13
  import torch
14
 
15
+ from persona_vectors.probes import ProbeArtifact
16
  from utils.probes import (
17
  LoadedProbe,
18
  _LinearProbe,
19
+ _loaded_probe_from_artifact,
20
  _normalize_labels,
21
  parse_probe_filename,
22
  )
 
198
  result = probe.run(torch.tensor([1.0, 1.0]))
199
  assert result.predicted_index == 0
200
  assert result.predicted_label == "neg"
201
+
202
+
203
+ # --------------------------------------------------------------------------- #
204
+ # canonical persona-vectors artifacts
205
+ # --------------------------------------------------------------------------- #
206
+
207
+
208
+ def test_loaded_probe_from_canonical_artifact():
209
+ artifact = ProbeArtifact(
210
+ metadata={
211
+ "schema_version": 2,
212
+ "input_dim": 2,
213
+ "artifact_feature_dim": 2,
214
+ "class_names": ["neg", "pos"],
215
+ "task": "binary",
216
+ "probe_kind": "logistic_regression",
217
+ "layer": 3,
218
+ },
219
+ tensors={
220
+ "weight": torch.tensor([[-1.0, 0.0], [1.0, 0.0]]),
221
+ "bias": torch.zeros(2),
222
+ },
223
+ )
224
+ probe = _loaded_probe_from_artifact(
225
+ filename="m/answer_mean/templated/sex/logistic_regression_layer3/probe.json",
226
+ artifact=artifact,
227
+ )
228
+ assert probe.labels == ["neg", "pos"]
229
+ assert probe.layer == 3
230
+ assert probe.run(torch.tensor([1.0, 0.0])).predicted_label == "pos"
utils/analysis_sources.py CHANGED
@@ -1,7 +1,11 @@
1
  import os
2
 
3
  import streamlit as st
4
- from persona_vectors.analysis import LayeredSamples, load_persona_vectors
 
 
 
 
5
  from persona_vectors.artifacts import (
6
  PersonaVectorStore,
7
  HFPersonaVectorStore,
@@ -10,6 +14,11 @@ from persona_vectors.artifacts import (
10
  )
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.hub import list_hub_vector_models
 
 
 
 
 
13
 
14
  from utils.helpers import env_int
15
 
@@ -26,7 +35,8 @@ SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
26
 
27
 
28
  _STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
29
- _VECTOR_CACHE_ENTRIES = env_int("PERSONA_UI_VECTOR_CACHE_ENTRIES", 2)
 
30
 
31
 
32
  @st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
@@ -137,23 +147,41 @@ def local_model_matches(left: str, right: str) -> bool:
137
 
138
 
139
  @st.cache_resource(show_spinner=False, max_entries=_VECTOR_CACHE_ENTRIES)
140
- def load_persona_vectors_cached(
141
  source: str,
142
  location: str,
143
  model_name: str,
144
  mask_strategy_value: str,
145
- variant: str,
146
  persona_ids: tuple[str, ...],
147
- ) -> LayeredSamples:
148
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
149
- return load_persona_vectors(
150
  store,
151
- variant,
152
  mask_strategy=MaskStrategy(mask_strategy_value),
153
- persona_ids=list(persona_ids),
154
  )
155
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def load_variant_vectors_cached(
158
  source: str,
159
  location: str,
@@ -162,12 +190,64 @@ def load_variant_vectors_cached(
162
  variants: tuple[str, ...],
163
  persona_ids: tuple[str, ...],
164
  ) -> dict[str, LayeredSamples]:
165
- return {
166
- variant: load_persona_vectors_cached(
167
- source, location, model_name, mask_strategy_value, variant, persona_ids
168
- )
169
- for variant in variants
170
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
 
173
  def prefetch_hub_metadata(
@@ -194,13 +274,3 @@ def prefetch_hub_metadata(
194
  mask_strategy_value,
195
  (variant,),
196
  )
197
-
198
-
199
- def release_hf_store_cache(
200
- store: Store,
201
- variants: list[str] | tuple[str, ...] | None = None,
202
- ) -> None:
203
- """Drop cached HF data for ``variants`` (or all) on Hub stores."""
204
- release_cache = getattr(store, "release_cache", None)
205
- if isinstance(store, HFPersonaVectorStore) and callable(release_cache):
206
- release_cache(variants)
 
1
  import os
2
 
3
  import streamlit as st
4
+ from persona_vectors.analysis import (
5
+ AnalysisDataset,
6
+ LayeredSamples,
7
+ load_analysis_dataset,
8
+ )
9
  from persona_vectors.artifacts import (
10
  PersonaVectorStore,
11
  HFPersonaVectorStore,
 
14
  )
15
  from persona_vectors.extraction import MaskStrategy
16
  from persona_vectors.hub import list_hub_vector_models
17
+ from persona_vectors.plots import (
18
+ LayeredProjectionData,
19
+ prepare_kmeans_groups,
20
+ prepare_layered_projection_data,
21
+ )
22
 
23
  from utils.helpers import env_int
24
 
 
35
 
36
 
37
  _STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
38
+ _VECTOR_CACHE_ENTRIES = env_int("PERSONA_UI_VECTOR_CACHE_ENTRIES", 4)
39
+ _PREPARED_CACHE_ENTRIES = env_int("PERSONA_UI_PREPARED_CACHE_ENTRIES", 8)
40
 
41
 
42
  @st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
 
147
 
148
 
149
  @st.cache_resource(show_spinner=False, max_entries=_VECTOR_CACHE_ENTRIES)
150
+ def load_analysis_dataset_cached(
151
  source: str,
152
  location: str,
153
  model_name: str,
154
  mask_strategy_value: str,
155
+ variants: tuple[str, ...],
156
  persona_ids: tuple[str, ...],
157
+ ) -> AnalysisDataset:
158
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
159
+ return load_analysis_dataset(
160
  store,
161
+ variants,
162
  mask_strategy=MaskStrategy(mask_strategy_value),
163
+ persona_ids=persona_ids,
164
  )
165
 
166
 
167
+ def load_persona_vectors_cached(
168
+ source: str,
169
+ location: str,
170
+ model_name: str,
171
+ mask_strategy_value: str,
172
+ variant: str,
173
+ persona_ids: tuple[str, ...],
174
+ ) -> LayeredSamples:
175
+ return load_analysis_dataset_cached(
176
+ source,
177
+ location,
178
+ model_name,
179
+ mask_strategy_value,
180
+ (variant,),
181
+ persona_ids,
182
+ ).samples(variant)
183
+
184
+
185
  def load_variant_vectors_cached(
186
  source: str,
187
  location: str,
 
190
  variants: tuple[str, ...],
191
  persona_ids: tuple[str, ...],
192
  ) -> dict[str, LayeredSamples]:
193
+ return load_analysis_dataset_cached(
194
+ source,
195
+ location,
196
+ model_name,
197
+ mask_strategy_value,
198
+ variants,
199
+ persona_ids,
200
+ ).samples_by_variant
201
+
202
+
203
+ @st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
204
+ def projection_data_cached(
205
+ source: str,
206
+ location: str,
207
+ model_name: str,
208
+ mask_strategy_value: str,
209
+ variant: str,
210
+ persona_ids: tuple[str, ...],
211
+ layers: tuple[int, ...],
212
+ kind: str,
213
+ n_components: int,
214
+ graph_overlay: bool,
215
+ graph_n_neighbors: int,
216
+ ) -> LayeredProjectionData:
217
+ samples = load_persona_vectors_cached(
218
+ source, location, model_name, mask_strategy_value, variant, persona_ids
219
+ )
220
+ return prepare_layered_projection_data(
221
+ samples,
222
+ kind,
223
+ layers=list(layers),
224
+ n_components=n_components,
225
+ graph_overlay=graph_overlay,
226
+ graph_n_neighbors=graph_n_neighbors,
227
+ )
228
+
229
+
230
+ @st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
231
+ def kmeans_groups_cached(
232
+ source: str,
233
+ location: str,
234
+ model_name: str,
235
+ mask_strategy_value: str,
236
+ variant: str,
237
+ persona_ids: tuple[str, ...],
238
+ layers: tuple[int, ...],
239
+ n_clusters: int,
240
+ cluster_mode: str,
241
+ ) -> list[str] | dict[int, list[str]]:
242
+ samples = load_persona_vectors_cached(
243
+ source, location, model_name, mask_strategy_value, variant, persona_ids
244
+ )
245
+ return prepare_kmeans_groups(
246
+ samples,
247
+ layers=list(layers),
248
+ n_clusters=n_clusters,
249
+ cluster_mode=cluster_mode,
250
+ )
251
 
252
 
253
  def prefetch_hub_metadata(
 
274
  mask_strategy_value,
275
  (variant,),
276
  )
 
 
 
 
 
 
 
 
 
 
utils/probes.py CHANGED
@@ -1,7 +1,6 @@
1
  from __future__ import annotations
2
 
3
  import io
4
- import json
5
  import os
6
  import re
7
  from dataclasses import dataclass
@@ -12,6 +11,7 @@ import streamlit as st
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
 
15
 
16
  PROBE_FILENAME_RE = re.compile(
17
  r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
@@ -457,14 +457,19 @@ def _load_persona_probe_artifact(
457
  metadata_path: Path,
458
  weights_path: Path,
459
  ) -> LoadedProbe:
460
- if not metadata_path.is_file():
461
- raise FileNotFoundError(f"Missing probe metadata file: {metadata_path}")
462
- if not weights_path.is_file():
463
- raise FileNotFoundError(f"Missing probe weights file: {weights_path}")
464
- from safetensors.torch import load_file
465
-
466
- metadata = json.loads(metadata_path.read_text())
467
- tensors = load_file(str(weights_path), device="cpu")
 
 
 
 
 
468
  payload = {
469
  **metadata,
470
  "model_type": "linear",
 
1
  from __future__ import annotations
2
 
3
  import io
 
4
  import os
5
  import re
6
  from dataclasses import dataclass
 
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
14
+ from persona_vectors.probes import ProbeArtifact, load_probe_artifact
15
 
16
  PROBE_FILENAME_RE = re.compile(
17
  r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
 
457
  metadata_path: Path,
458
  weights_path: Path,
459
  ) -> LoadedProbe:
460
+ if metadata_path.parent != weights_path.parent:
461
+ raise ValueError("Canonical probe files must share one artifact directory.")
462
+ artifact = load_probe_artifact(metadata_path)
463
+ return _loaded_probe_from_artifact(filename=filename, artifact=artifact)
464
+
465
+
466
+ def _loaded_probe_from_artifact(
467
+ *,
468
+ filename: str,
469
+ artifact: ProbeArtifact,
470
+ ) -> LoadedProbe:
471
+ metadata = artifact.metadata
472
+ tensors = artifact.tensors
473
  payload = {
474
  **metadata,
475
  "model_type": "linear",
uv.lock CHANGED
@@ -1608,7 +1608,7 @@ requires-dist = [
1608
  { name = "catppuccin", specifier = ">=2.5.0" },
1609
  { name = "datasets", specifier = ">=4.8.5" },
1610
  { name = "huggingface-hub", specifier = ">=1.14.0" },
1611
- { name = "persona-vectors", specifier = ">=0.8.2" },
1612
  { name = "plotly", specifier = ">=6.6.0" },
1613
  { name = "python-dotenv", specifier = ">=1.2.2" },
1614
  { name = "safetensors", specifier = ">=0.7.0" },
@@ -1620,7 +1620,7 @@ dev = [{ name = "pytest", specifier = ">=9.0.3" }]
1620
 
1621
  [[package]]
1622
  name = "persona-vectors"
1623
- version = "0.8.2"
1624
  source = { registry = "https://pypi.org/simple" }
1625
  dependencies = [
1626
  { name = "datasets" },
@@ -1639,9 +1639,9 @@ dependencies = [
1639
  { name = "transformers" },
1640
  { name = "umap-learn" },
1641
  ]
1642
- sdist = { url = "https://files.pythonhosted.org/packages/ef/f4/66d2a1e30ed814a1ea945e27e2f9241cd7374872575e4d4c9e602a92a1cc/persona_vectors-0.8.2.tar.gz", hash = "sha256:f5b0776f8adbdfd38b9ad0f097daf88abb4c5dc504b3d3620af3f392e4a4621d", size = 42138, upload-time = "2026-05-16T22:11:53.019Z" }
1643
  wheels = [
1644
- { url = "https://files.pythonhosted.org/packages/15/02/7af86ed4040c4f866705a7ec28b50ebaa570502c3b74465fd9282856b2b7/persona_vectors-0.8.2-py3-none-any.whl", hash = "sha256:6bfa374e86d5cefc009cea07a8b43cc98d710e508d8f3e3394c24483d342799b", size = 52033, upload-time = "2026-05-16T22:11:54.128Z" },
1645
  ]
1646
 
1647
  [[package]]
 
1608
  { name = "catppuccin", specifier = ">=2.5.0" },
1609
  { name = "datasets", specifier = ">=4.8.5" },
1610
  { name = "huggingface-hub", specifier = ">=1.14.0" },
1611
+ { name = "persona-vectors", specifier = ">=0.8.3" },
1612
  { name = "plotly", specifier = ">=6.6.0" },
1613
  { name = "python-dotenv", specifier = ">=1.2.2" },
1614
  { name = "safetensors", specifier = ">=0.7.0" },
 
1620
 
1621
  [[package]]
1622
  name = "persona-vectors"
1623
+ version = "0.8.3"
1624
  source = { registry = "https://pypi.org/simple" }
1625
  dependencies = [
1626
  { name = "datasets" },
 
1639
  { name = "transformers" },
1640
  { name = "umap-learn" },
1641
  ]
1642
+ sdist = { url = "https://files.pythonhosted.org/packages/c0/1d/472284f43e2a276a035e9e3de08a92654945193699598def6d6a2aa74c96/persona_vectors-0.8.3.tar.gz", hash = "sha256:f0519846b3712865bd2562cd239df05ddd006ac3d1e73e5ec5a6c860aaed5b2e", size = 43146, upload-time = "2026-05-17T12:43:13.601Z" }
1643
  wheels = [
1644
+ { url = "https://files.pythonhosted.org/packages/60/d1/a38dc354718310122cd5d3de63e3aa9060490c8db4c2eadb1d4985684796/persona_vectors-0.8.3-py3-none-any.whl", hash = "sha256:2feeaf45b071ed417d88add48a1012455c8027e4f839e99658a9808c26786b8a", size = 53129, upload-time = "2026-05-17T12:43:12.693Z" },
1645
  ]
1646
 
1647
  [[package]]