Jac-Zac commited on
Commit
4df7d97
·
1 Parent(s): 1b16c40

Performance speedup

Browse files
Files changed (2) hide show
  1. tabs/compare.py +206 -31
  2. utils/compare_sources.py +181 -0
tabs/compare.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from collections.abc import Callable
2
  from dataclasses import dataclass
3
  from itertools import combinations
@@ -7,7 +8,6 @@ import plotly.graph_objects as go
7
  import streamlit as st
8
  from persona_data.environment import get_artifacts_dir
9
  from persona_data.synth_persona import BASELINE_PERSONA_ID
10
- from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.plots import (
13
  build_layered_figure,
@@ -28,10 +28,13 @@ from utils.compare_sources import (
28
  activation_store_cached,
29
  available_variants,
30
  hub_models_by_mask_strategy,
 
 
31
  local_model_matches,
32
  local_model_options_cached,
33
  persona_names_cached,
34
  personas_cached,
 
35
  store_cache_parts,
36
  store_id,
37
  store_layers_cached,
@@ -56,9 +59,20 @@ def _filename(*parts: str) -> str:
56
  # overwrite cosine similarity defaults.
57
  _LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
58
  _LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
 
59
  _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
60
  _LAST_SOURCE_KEY = "compare:last_source"
61
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
64
  persona_id_normalized = persona_id.strip().lower()
@@ -101,6 +115,92 @@ def _layers_for_variant(
101
  )
102
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def _load_persona_options(
105
  store: Store,
106
  variants: list[str],
@@ -156,6 +256,7 @@ def _seed_persona_memory(
156
  options: PersonaOptions,
157
  *,
158
  default_all: bool,
 
159
  ) -> tuple[int, bool]:
160
  remembered_count_key = f"{remember_key}:count"
161
  remembered_assistant_key = f"{remember_key}:include_assistant"
@@ -170,9 +271,12 @@ def _seed_persona_memory(
170
  options.assistant_id in legacy_ids,
171
  )
172
 
173
- default_count = (
174
- len(options.regular_ids) if default_all else min(1, len(options.regular_ids))
175
- )
 
 
 
176
  remembered_count = int(st.session_state.get(remembered_count_key, default_count))
177
  persona_count = min(max(remembered_count, 0), len(options.regular_ids))
178
  include_assistant = bool(st.session_state.get(remembered_assistant_key, False))
@@ -236,6 +340,7 @@ def _select_artifact_personas(
236
  widget_scope: str,
237
  remember_key: str,
238
  default_all: bool = False,
 
239
  ) -> list[str]:
240
  empty_message = (
241
  "No personas have vectors for all selected variants. "
@@ -256,6 +361,7 @@ def _select_artifact_personas(
256
  remember_key,
257
  options,
258
  default_all=default_all,
 
259
  )
260
  persona_count, include_assistant = _render_persona_count_controls(
261
  store,
@@ -376,15 +482,17 @@ def _render_cosine_selection(
376
 
377
  def _build_cosine_figures(
378
  store: Store,
 
379
  selection: CosineSelection,
380
  ) -> tuple[object, object | None, int, int] | None:
381
  variant_sample_cache: dict[str, object] = {}
382
 
383
  def _load_variant(variant: str):
384
  if variant not in variant_sample_cache:
385
- samples = load_variant_vectors(
386
  store,
387
  [variant],
 
388
  persona_ids=selection.persona_ids,
389
  )
390
  variant_sample_cache[variant] = samples[variant]
@@ -479,6 +587,7 @@ def _render_cosine_similarity(
479
  mask_strategy.value,
480
  "_".join(selection.variants),
481
  )
 
482
 
483
  if st.button(
484
  "Compare vectors",
@@ -497,14 +606,15 @@ def _render_cosine_similarity(
497
  progress = st.progress(0, text="Loading activation vectors…")
498
  try:
499
  progress.progress(15, text="Loading activation vectors…")
500
- figures = _build_cosine_figures(store, selection)
501
  if figures is None:
502
  st.session_state.pop(cosine_fig_key, None)
503
  return
504
  progress.progress(90, text="Storing figure state…")
505
- st.session_state[cosine_fig_key] = figures
506
  progress.progress(100, text="Done.")
507
  finally:
 
508
  progress.empty()
509
 
510
  if cosine_fig_key in st.session_state:
@@ -527,6 +637,9 @@ def _select_single_variant_samples(
527
  store: Store,
528
  mask_strategy: MaskStrategy,
529
  scope: str,
 
 
 
530
  ) -> tuple[str, list[str], str, list[int]] | None:
531
  variants = available_variants(store, mask_strategy)
532
  if not variants:
@@ -544,8 +657,8 @@ def _select_single_variant_samples(
544
  [variant],
545
  mask_strategy,
546
  widget_scope=f"{scope}:{store_id(store)}",
547
- remember_key=_LAST_PROJECTION_PERSONAS_KEY,
548
- default_all=True,
549
  )
550
  if not persona_ids:
551
  return None
@@ -556,8 +669,8 @@ def _select_single_variant_samples(
556
  st.info("No shared layers are available for the selected personas.")
557
  return None
558
 
559
- st.caption(f"Using all {len(layer_options)} available layer(s).")
560
- return variant, persona_ids, persona_key, layer_options
561
 
562
 
563
  def _render_layered_figure_analysis(
@@ -570,17 +683,50 @@ def _render_layered_figure_analysis(
570
  title_fn: Callable[[str], str],
571
  include_pair_trajectories: bool = False,
572
  n_components: int = 2,
 
 
573
  ) -> None:
574
  """Render a single-variant layered analysis: select → button → figure(s).
575
 
576
  Used for similarity matrix, PCA, and UMAP. Set ``include_pair_trajectories``
577
  to add the pair-similarity-trajectory figure (similarity matrix only).
578
  """
579
- selected = _select_single_variant_samples(store, mask_strategy, scope)
 
 
 
 
 
 
580
  if selected is None:
581
  return
582
  variant, persona_ids, persona_key, selected_layers = selected
583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  n_clusters = None
585
  if figure_kind in {"pca", "umap"}:
586
  use_kmeans = st.toggle(
@@ -610,8 +756,11 @@ def _render_layered_figure_analysis(
610
  variant,
611
  "persona_vector",
612
  persona_key,
 
 
613
  )
614
  filename = scope
 
615
 
616
  if st.button(button_label, type="primary"):
617
  build_label = {
@@ -622,11 +771,11 @@ def _render_layered_figure_analysis(
622
  progress = st.progress(0, text="Loading activation vectors…")
623
  try:
624
  progress.progress(15, text="Loading activation vectors…")
625
- samples = load_persona_vectors(
626
  store,
627
  variant,
628
- mask_strategy=mask_strategy,
629
- persona_ids=persona_ids,
630
  )
631
  progress.progress(55, text=build_label)
632
  build_kwargs = {}
@@ -634,7 +783,7 @@ def _render_layered_figure_analysis(
634
  build_kwargs["n_components"] = n_components
635
  if n_clusters is not None:
636
  build_kwargs["n_clusters"] = n_clusters
637
- if figure_kind == "similarity" and include_pair_trajectories:
638
  main_fig, extra_fig = build_similarity_figures(
639
  samples,
640
  layers=selected_layers,
@@ -663,16 +812,19 @@ def _render_layered_figure_analysis(
663
  f"{prompt_variant_label(variant)} - persona vectors"
664
  ),
665
  )
666
- if include_pair_trajectories
667
  else None
668
  )
669
  progress.progress(90, text="Storing figure state…")
670
- st.session_state[fig_key] = (main_fig, extra_fig, samples.vectors.shape[0])
 
 
671
  progress.progress(100, text="Done.")
672
  except Exception as exc:
673
  st.error(f"Could not build figure: {exc}")
674
  st.session_state.pop(fig_key, None)
675
  finally:
 
676
  progress.empty()
677
 
678
  if fig_key in st.session_state:
@@ -734,7 +886,7 @@ def _render_dendrogram_analysis(
734
  mask_strategy,
735
  widget_scope=f"dendro:{store_id(store)}",
736
  remember_key=_LAST_DENDRO_PERSONAS_KEY,
737
- default_all=True,
738
  )
739
  if not persona_ids:
740
  return
@@ -755,6 +907,22 @@ def _render_dendrogram_analysis(
755
  key=widget_key("load", "dendro_linkage", store_id(store)),
756
  )
757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
  persona_key = personas_fingerprint(persona_ids)
759
  fig_key = widget_key(
760
  "load",
@@ -767,7 +935,9 @@ def _render_dendrogram_analysis(
767
  persona_key,
768
  str(layered_mode),
769
  linkage,
 
770
  )
 
771
 
772
  if st.button(
773
  "Generate dendrograms",
@@ -779,50 +949,52 @@ def _render_dendrogram_analysis(
779
  progress = st.progress(0, text="Loading first variant vectors…")
780
  try:
781
  progress.progress(15, text="Loading first variant vectors…")
782
- samples_a = load_persona_vectors(
783
  store,
784
  variant_a,
785
- mask_strategy=mask_strategy,
786
- persona_ids=persona_ids,
787
  )
788
  progress.progress(40, text="Building first dendrogram…")
789
  fig_a = plot_persona_dendrogram(
790
  samples_a,
791
  layered=layered_mode,
 
792
  linkage=linkage,
793
  title=f"Dendrogram — {prompt_variant_label(variant_a)}",
794
  )
795
  fig_a.update_layout(height=750)
 
796
  fig_b = None
797
  if variant_a != variant_b:
798
  progress.progress(60, text="Loading second variant vectors…")
799
- samples_b = load_persona_vectors(
800
  store,
801
  variant_b,
802
- mask_strategy=mask_strategy,
803
- persona_ids=persona_ids,
804
  )
805
  progress.progress(75, text="Building second dendrogram…")
806
  fig_b = plot_persona_dendrogram(
807
  samples_b,
808
  layered=layered_mode,
 
809
  linkage=linkage,
810
  title=f"Dendrogram — {prompt_variant_label(variant_b)}",
811
  )
812
  fig_b.update_layout(height=750)
 
813
  progress.progress(90, text="Storing figure state…")
814
- st.session_state[fig_key] = (
815
- fig_a,
816
- fig_b,
817
- len(persona_ids),
818
- variant_a,
819
- variant_b,
820
  )
821
  progress.progress(100, text="Done.")
822
  except Exception as exc:
823
  st.error(f"Could not build dendrogram: {exc}")
824
  st.session_state.pop(fig_key, None)
825
  finally:
 
826
  progress.empty()
827
 
828
  if fig_key in st.session_state:
@@ -1033,6 +1205,8 @@ def render_compare_tab() -> None:
1033
  f"Centered similarity - {prompt_variant_label(v)} - persona vectors"
1034
  ),
1035
  include_pair_trajectories=True,
 
 
1036
  )
1037
  return
1038
 
@@ -1059,4 +1233,5 @@ def render_compare_tab() -> None:
1059
  f"{analysis_mode}{dim_suffix} - {prompt_variant_label(v)} - persona vectors"
1060
  ),
1061
  n_components=n_components,
 
1062
  )
 
1
+ import gc
2
  from collections.abc import Callable
3
  from dataclasses import dataclass
4
  from itertools import combinations
 
8
  import streamlit as st
9
  from persona_data.environment import get_artifacts_dir
10
  from persona_data.synth_persona import BASELINE_PERSONA_ID
 
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.plots import (
13
  build_layered_figure,
 
28
  activation_store_cached,
29
  available_variants,
30
  hub_models_by_mask_strategy,
31
+ load_persona_vectors_lean,
32
+ load_variant_vectors_lean,
33
  local_model_matches,
34
  local_model_options_cached,
35
  persona_names_cached,
36
  personas_cached,
37
+ release_store_cache,
38
  store_cache_parts,
39
  store_id,
40
  store_layers_cached,
 
59
  # overwrite cosine similarity defaults.
60
  _LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
61
  _LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
62
+ _LAST_SIMILARITY_PERSONAS_KEY = "compare:last_personas:similarity"
63
  _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
64
  _LAST_SOURCE_KEY = "compare:last_source"
65
 
66
+ _DEFAULT_LAYER_FRAMES = 16
67
+ _DEFAULT_PERSONA_LIMITS = {
68
+ "similarity": 120,
69
+ "pca": 500,
70
+ "umap": 500,
71
+ "dendro": 160,
72
+ }
73
+ _MAX_SIMILARITY_CELLS = 4_000_000
74
+ _MAX_PAIR_TRAJECTORY_TRACES = 500
75
+
76
 
77
  def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
78
  persona_id_normalized = persona_id.strip().lower()
 
115
  )
116
 
117
 
118
+ def _load_persona_vectors(
119
+ store: Store,
120
+ variant: str,
121
+ mask_strategy: MaskStrategy,
122
+ persona_ids: list[str],
123
+ ):
124
+ source, location, model_name = store_cache_parts(store)
125
+ return load_persona_vectors_lean(
126
+ source,
127
+ location,
128
+ model_name,
129
+ mask_strategy.value,
130
+ variant,
131
+ tuple(persona_ids),
132
+ )
133
+
134
+
135
+ def _load_variant_vectors(
136
+ store: Store,
137
+ variants: list[str] | tuple[str, ...],
138
+ mask_strategy: MaskStrategy,
139
+ persona_ids: list[str],
140
+ ):
141
+ source, location, model_name = store_cache_parts(store)
142
+ return load_variant_vectors_lean(
143
+ source,
144
+ location,
145
+ model_name,
146
+ mask_strategy.value,
147
+ tuple(variants),
148
+ tuple(persona_ids),
149
+ )
150
+
151
+
152
+ def _clear_old_figure_states(current_key: str) -> None:
153
+ for key in list(st.session_state):
154
+ if key == current_key or not isinstance(key, str):
155
+ continue
156
+ parts = key.split("::", 2)
157
+ if len(parts) >= 2 and parts[0] == "load" and parts[1].endswith("_fig_state"):
158
+ st.session_state.pop(key, None)
159
+
160
+
161
+ def _store_figure_state(key: str, value: object) -> None:
162
+ _clear_old_figure_states(key)
163
+ st.session_state[key] = value
164
+
165
+
166
+ def _release_vector_memory(store: Store, variants: list[str] | tuple[str, ...]) -> None:
167
+ release_store_cache(store, variants)
168
+ gc.collect()
169
+
170
+
171
+ def _evenly_spaced_layers(layers: list[int], max_count: int) -> list[int]:
172
+ if max_count >= len(layers):
173
+ return layers
174
+ if max_count <= 1:
175
+ return [layers[0]]
176
+
177
+ last = len(layers) - 1
178
+ indices = [round(i * last / (max_count - 1)) for i in range(max_count)]
179
+ return [layers[index] for index in dict.fromkeys(indices)]
180
+
181
+
182
+ def _render_layer_frame_controls(
183
+ store: Store,
184
+ scope: str,
185
+ layers: list[int],
186
+ ) -> list[int]:
187
+ if len(layers) <= _DEFAULT_LAYER_FRAMES:
188
+ st.caption(f"Using all {len(layers)} available layer(s).")
189
+ return layers
190
+
191
+ frame_count = st.slider(
192
+ "Layer frames",
193
+ min_value=2,
194
+ max_value=len(layers),
195
+ value=_DEFAULT_LAYER_FRAMES,
196
+ key=widget_key("load", "layer_frames", scope, store_id(store)),
197
+ help="Limit animated Plotly frames to keep browser and RAM usage bounded.",
198
+ )
199
+ selected = _evenly_spaced_layers(layers, frame_count)
200
+ st.caption(f"Using {len(selected)} of {len(layers)} layers.")
201
+ return selected
202
+
203
+
204
  def _load_persona_options(
205
  store: Store,
206
  variants: list[str],
 
256
  options: PersonaOptions,
257
  *,
258
  default_all: bool,
259
+ default_count_limit: int | None = None,
260
  ) -> tuple[int, bool]:
261
  remembered_count_key = f"{remember_key}:count"
262
  remembered_assistant_key = f"{remember_key}:include_assistant"
 
271
  options.assistant_id in legacy_ids,
272
  )
273
 
274
+ if default_count_limit is not None:
275
+ default_count = min(default_count_limit, len(options.regular_ids))
276
+ elif default_all:
277
+ default_count = len(options.regular_ids)
278
+ else:
279
+ default_count = min(1, len(options.regular_ids))
280
  remembered_count = int(st.session_state.get(remembered_count_key, default_count))
281
  persona_count = min(max(remembered_count, 0), len(options.regular_ids))
282
  include_assistant = bool(st.session_state.get(remembered_assistant_key, False))
 
340
  widget_scope: str,
341
  remember_key: str,
342
  default_all: bool = False,
343
+ default_count_limit: int | None = None,
344
  ) -> list[str]:
345
  empty_message = (
346
  "No personas have vectors for all selected variants. "
 
361
  remember_key,
362
  options,
363
  default_all=default_all,
364
+ default_count_limit=default_count_limit,
365
  )
366
  persona_count, include_assistant = _render_persona_count_controls(
367
  store,
 
482
 
483
  def _build_cosine_figures(
484
  store: Store,
485
+ mask_strategy: MaskStrategy,
486
  selection: CosineSelection,
487
  ) -> tuple[object, object | None, int, int] | None:
488
  variant_sample_cache: dict[str, object] = {}
489
 
490
  def _load_variant(variant: str):
491
  if variant not in variant_sample_cache:
492
+ samples = _load_variant_vectors(
493
  store,
494
  [variant],
495
+ mask_strategy,
496
  persona_ids=selection.persona_ids,
497
  )
498
  variant_sample_cache[variant] = samples[variant]
 
587
  mask_strategy.value,
588
  "_".join(selection.variants),
589
  )
590
+ _clear_old_figure_states(cosine_fig_key)
591
 
592
  if st.button(
593
  "Compare vectors",
 
606
  progress = st.progress(0, text="Loading activation vectors…")
607
  try:
608
  progress.progress(15, text="Loading activation vectors…")
609
+ figures = _build_cosine_figures(store, mask_strategy, selection)
610
  if figures is None:
611
  st.session_state.pop(cosine_fig_key, None)
612
  return
613
  progress.progress(90, text="Storing figure state…")
614
+ _store_figure_state(cosine_fig_key, figures)
615
  progress.progress(100, text="Done.")
616
  finally:
617
+ _release_vector_memory(store, selection.variants)
618
  progress.empty()
619
 
620
  if cosine_fig_key in st.session_state:
 
637
  store: Store,
638
  mask_strategy: MaskStrategy,
639
  scope: str,
640
+ *,
641
+ remember_key: str,
642
+ default_count_limit: int,
643
  ) -> tuple[str, list[str], str, list[int]] | None:
644
  variants = available_variants(store, mask_strategy)
645
  if not variants:
 
657
  [variant],
658
  mask_strategy,
659
  widget_scope=f"{scope}:{store_id(store)}",
660
+ remember_key=remember_key,
661
+ default_count_limit=default_count_limit,
662
  )
663
  if not persona_ids:
664
  return None
 
669
  st.info("No shared layers are available for the selected personas.")
670
  return None
671
 
672
+ selected_layers = _render_layer_frame_controls(store, scope, layer_options)
673
+ return variant, persona_ids, persona_key, selected_layers
674
 
675
 
676
  def _render_layered_figure_analysis(
 
683
  title_fn: Callable[[str], str],
684
  include_pair_trajectories: bool = False,
685
  n_components: int = 2,
686
+ remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
687
+ default_count_limit: int = 500,
688
  ) -> None:
689
  """Render a single-variant layered analysis: select → button → figure(s).
690
 
691
  Used for similarity matrix, PCA, and UMAP. Set ``include_pair_trajectories``
692
  to add the pair-similarity-trajectory figure (similarity matrix only).
693
  """
694
+ selected = _select_single_variant_samples(
695
+ store,
696
+ mask_strategy,
697
+ scope,
698
+ remember_key=remember_key,
699
+ default_count_limit=default_count_limit,
700
+ )
701
  if selected is None:
702
  return
703
  variant, persona_ids, persona_key, selected_layers = selected
704
 
705
+ pair_trajectories = False
706
+ if include_pair_trajectories:
707
+ pair_count = len(persona_ids) * (len(persona_ids) - 1) // 2
708
+ if pair_count > _MAX_PAIR_TRAJECTORY_TRACES:
709
+ st.caption(
710
+ "Pair trajectories hidden because this selection would create "
711
+ f"{pair_count:,} Plotly traces."
712
+ )
713
+ else:
714
+ pair_trajectories = st.checkbox(
715
+ "Pair trajectories",
716
+ value=False,
717
+ key=widget_key("load", "pair_trajectories", scope, store_id(store)),
718
+ help="Adds one line per persona pair. Keep this off for larger selections.",
719
+ )
720
+
721
+ if figure_kind == "similarity":
722
+ similarity_cells = len(persona_ids) * len(persona_ids) * len(selected_layers)
723
+ if similarity_cells > _MAX_SIMILARITY_CELLS:
724
+ st.error(
725
+ "Reduce personas or layer frames before generating the similarity "
726
+ f"matrix ({similarity_cells:,} cells selected)."
727
+ )
728
+ return
729
+
730
  n_clusters = None
731
  if figure_kind in {"pca", "umap"}:
732
  use_kmeans = st.toggle(
 
756
  variant,
757
  "persona_vector",
758
  persona_key,
759
+ "_".join(map(str, selected_layers)),
760
+ str(pair_trajectories),
761
  )
762
  filename = scope
763
+ _clear_old_figure_states(fig_key)
764
 
765
  if st.button(button_label, type="primary"):
766
  build_label = {
 
771
  progress = st.progress(0, text="Loading activation vectors…")
772
  try:
773
  progress.progress(15, text="Loading activation vectors…")
774
+ samples = _load_persona_vectors(
775
  store,
776
  variant,
777
+ mask_strategy,
778
+ persona_ids,
779
  )
780
  progress.progress(55, text=build_label)
781
  build_kwargs = {}
 
783
  build_kwargs["n_components"] = n_components
784
  if n_clusters is not None:
785
  build_kwargs["n_clusters"] = n_clusters
786
+ if figure_kind == "similarity" and pair_trajectories:
787
  main_fig, extra_fig = build_similarity_figures(
788
  samples,
789
  layers=selected_layers,
 
812
  f"{prompt_variant_label(variant)} - persona vectors"
813
  ),
814
  )
815
+ if pair_trajectories
816
  else None
817
  )
818
  progress.progress(90, text="Storing figure state…")
819
+ n_samples = samples.vectors.shape[0]
820
+ del samples
821
+ _store_figure_state(fig_key, (main_fig, extra_fig, n_samples))
822
  progress.progress(100, text="Done.")
823
  except Exception as exc:
824
  st.error(f"Could not build figure: {exc}")
825
  st.session_state.pop(fig_key, None)
826
  finally:
827
+ _release_vector_memory(store, [variant])
828
  progress.empty()
829
 
830
  if fig_key in st.session_state:
 
886
  mask_strategy,
887
  widget_scope=f"dendro:{store_id(store)}",
888
  remember_key=_LAST_DENDRO_PERSONAS_KEY,
889
+ default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"],
890
  )
891
  if not persona_ids:
892
  return
 
907
  key=widget_key("load", "dendro_linkage", store_id(store)),
908
  )
909
 
910
+ selected_layers: list[int] | None = None
911
+ if layered_mode:
912
+ source, location, model_name = store_cache_parts(store)
913
+ layer_options = store_layers_cached(
914
+ source,
915
+ location,
916
+ model_name,
917
+ mask_strategy.value,
918
+ tuple(shared_variants),
919
+ tuple(persona_ids),
920
+ )
921
+ if not layer_options:
922
+ st.info("No shared layers are available for the selected personas.")
923
+ return
924
+ selected_layers = _render_layer_frame_controls(store, "dendro", layer_options)
925
+
926
  persona_key = personas_fingerprint(persona_ids)
927
  fig_key = widget_key(
928
  "load",
 
935
  persona_key,
936
  str(layered_mode),
937
  linkage,
938
+ "_".join(map(str, selected_layers or [])),
939
  )
940
+ _clear_old_figure_states(fig_key)
941
 
942
  if st.button(
943
  "Generate dendrograms",
 
949
  progress = st.progress(0, text="Loading first variant vectors…")
950
  try:
951
  progress.progress(15, text="Loading first variant vectors…")
952
+ samples_a = _load_persona_vectors(
953
  store,
954
  variant_a,
955
+ mask_strategy,
956
+ persona_ids,
957
  )
958
  progress.progress(40, text="Building first dendrogram…")
959
  fig_a = plot_persona_dendrogram(
960
  samples_a,
961
  layered=layered_mode,
962
+ layers=selected_layers,
963
  linkage=linkage,
964
  title=f"Dendrogram — {prompt_variant_label(variant_a)}",
965
  )
966
  fig_a.update_layout(height=750)
967
+ del samples_a
968
  fig_b = None
969
  if variant_a != variant_b:
970
  progress.progress(60, text="Loading second variant vectors…")
971
+ samples_b = _load_persona_vectors(
972
  store,
973
  variant_b,
974
+ mask_strategy,
975
+ persona_ids,
976
  )
977
  progress.progress(75, text="Building second dendrogram…")
978
  fig_b = plot_persona_dendrogram(
979
  samples_b,
980
  layered=layered_mode,
981
+ layers=selected_layers,
982
  linkage=linkage,
983
  title=f"Dendrogram — {prompt_variant_label(variant_b)}",
984
  )
985
  fig_b.update_layout(height=750)
986
+ del samples_b
987
  progress.progress(90, text="Storing figure state…")
988
+ _store_figure_state(
989
+ fig_key,
990
+ (fig_a, fig_b, len(persona_ids), variant_a, variant_b),
 
 
 
991
  )
992
  progress.progress(100, text="Done.")
993
  except Exception as exc:
994
  st.error(f"Could not build dendrogram: {exc}")
995
  st.session_state.pop(fig_key, None)
996
  finally:
997
+ _release_vector_memory(store, shared_variants)
998
  progress.empty()
999
 
1000
  if fig_key in st.session_state:
 
1205
  f"Centered similarity - {prompt_variant_label(v)} - persona vectors"
1206
  ),
1207
  include_pair_trajectories=True,
1208
+ remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
1209
+ default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
1210
  )
1211
  return
1212
 
 
1233
  f"{analysis_mode}{dim_suffix} - {prompt_variant_label(v)} - persona vectors"
1234
  ),
1235
  n_components=n_components,
1236
+ default_count_limit=_DEFAULT_PERSONA_LIMITS[analysis_mode.lower()],
1237
  )
utils/compare_sources.py CHANGED
@@ -1,9 +1,12 @@
1
  import os
2
 
3
  import streamlit as st
 
 
4
  from persona_vectors.artifacts import (
5
  ActivationStore,
6
  HFActivationStore,
 
7
  discover_activation_models,
8
  model_dir_name,
9
  )
@@ -22,6 +25,28 @@ SOURCE_LOCAL = "Local activations"
22
  SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @st.cache_resource(show_spinner=False, max_entries=1)
26
  def activation_store_cached(
27
  source: str,
@@ -54,6 +79,26 @@ def personas_cached(
54
  mask_strategy_value: str,
55
  variants: tuple[str, ...],
56
  ) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
58
  return store.list_personas(
59
  list(variants),
@@ -70,6 +115,25 @@ def persona_names_cached(
70
  variants: tuple[str, ...],
71
  persona_ids: tuple[str, ...],
72
  ) -> dict[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
74
  return store.persona_names(
75
  list(persona_ids),
@@ -126,6 +190,26 @@ def store_layers_cached(
126
  variants: tuple[str, ...],
127
  persona_ids: tuple[str, ...],
128
  ) -> list[int]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
130
  return store.list_layers(
131
  list(variants),
@@ -136,3 +220,100 @@ def store_layers_cached(
136
 
137
  def local_model_matches(left: str, right: str) -> bool:
138
  return model_dir_name(left) == model_dir_name(right)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
 
3
  import streamlit as st
4
+ import torch
5
+ from persona_vectors.analysis import LayeredSamples
6
  from persona_vectors.artifacts import (
7
  ActivationStore,
8
  HFActivationStore,
9
+ activation_config_name,
10
  discover_activation_models,
11
  model_dir_name,
12
  )
 
25
  SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
26
 
27
 
28
+ def _hub_split(repo_id: str, model_name: str, mask_strategy_value: str, variant: str):
29
+ from datasets import load_dataset
30
+
31
+ return load_dataset(
32
+ repo_id,
33
+ name=activation_config_name(model_name, mask_strategy_value),
34
+ split=variant,
35
+ keep_in_memory=False,
36
+ )
37
+
38
+
39
+ def _hub_split_columns(
40
+ repo_id: str,
41
+ model_name: str,
42
+ mask_strategy_value: str,
43
+ variant: str,
44
+ columns: list[str],
45
+ ):
46
+ dataset = _hub_split(repo_id, model_name, mask_strategy_value, variant)
47
+ return dataset.select_columns(columns)
48
+
49
+
50
  @st.cache_resource(show_spinner=False, max_entries=1)
51
  def activation_store_cached(
52
  source: str,
 
79
  mask_strategy_value: str,
80
  variants: tuple[str, ...],
81
  ) -> list[str]:
82
+ if source == SOURCE_HUB:
83
+ variant_ids = [
84
+ list(
85
+ _hub_split_columns(
86
+ location,
87
+ model_name,
88
+ mask_strategy_value,
89
+ variant,
90
+ ["persona_id"],
91
+ )["persona_id"]
92
+ )
93
+ for variant in variants
94
+ ]
95
+ if not variant_ids:
96
+ return []
97
+ shared = set(variant_ids[0])
98
+ for ids in variant_ids[1:]:
99
+ shared &= set(ids)
100
+ return [persona_id for persona_id in variant_ids[0] if persona_id in shared]
101
+
102
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
103
  return store.list_personas(
104
  list(variants),
 
115
  variants: tuple[str, ...],
116
  persona_ids: tuple[str, ...],
117
  ) -> dict[str, str]:
118
+ if source == SOURCE_HUB:
119
+ requested = set(persona_ids)
120
+ names: dict[str, str] = {}
121
+ for variant in variants:
122
+ metadata = _hub_split_columns(
123
+ location,
124
+ model_name,
125
+ mask_strategy_value,
126
+ variant,
127
+ ["persona_id", "name"],
128
+ )
129
+ for row in metadata:
130
+ persona_id = row["persona_id"]
131
+ if persona_id in requested and persona_id not in names:
132
+ names[persona_id] = row.get("name") or persona_id
133
+ if len(names) == len(requested):
134
+ return {pid: names.get(pid, pid) for pid in persona_ids}
135
+ return {pid: names.get(pid, pid) for pid in persona_ids}
136
+
137
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
138
  return store.persona_names(
139
  list(persona_ids),
 
190
  variants: tuple[str, ...],
191
  persona_ids: tuple[str, ...],
192
  ) -> list[int]:
193
+ if source == SOURCE_HUB:
194
+ shared_layers: set[int] | None = None
195
+ requested = list(persona_ids)
196
+ for variant in variants:
197
+ dataset = _hub_split(location, model_name, mask_strategy_value, variant)
198
+ ids = list(dataset.select_columns(["persona_id"])["persona_id"])
199
+ sample_id = requested[0] if requested else (ids[0] if ids else None)
200
+ if sample_id is None:
201
+ return []
202
+ if requested and any(persona_id not in ids for persona_id in requested):
203
+ return []
204
+ vector = torch.as_tensor(dataset[ids.index(sample_id)]["vector"])
205
+ if vector.ndim != 2:
206
+ raise ValueError(
207
+ f"tensor for {sample_id!r} must have shape (num_layers, hidden_size)"
208
+ )
209
+ layers = set(range(int(vector.shape[0])))
210
+ shared_layers = layers if shared_layers is None else shared_layers & layers
211
+ return sorted(shared_layers or set())
212
+
213
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
214
  return store.list_layers(
215
  list(variants),
 
220
 
221
  def local_model_matches(left: str, right: str) -> bool:
222
  return model_dir_name(left) == model_dir_name(right)
223
+
224
+
225
+ def load_persona_vectors_lean(
226
+ source: str,
227
+ location: str,
228
+ model_name: str,
229
+ mask_strategy_value: str,
230
+ variant: str,
231
+ persona_ids: tuple[str, ...],
232
+ ) -> LayeredSamples:
233
+ if source != SOURCE_HUB:
234
+ from persona_vectors.analysis import load_persona_vectors
235
+
236
+ store = activation_store_cached(
237
+ source,
238
+ location,
239
+ model_name,
240
+ mask_strategy_value,
241
+ )
242
+ return load_persona_vectors(
243
+ store,
244
+ variant,
245
+ mask_strategy=MaskStrategy(mask_strategy_value),
246
+ persona_ids=list(persona_ids),
247
+ )
248
+
249
+ dataset = _hub_split(location, model_name, mask_strategy_value, variant)
250
+ metadata = dataset.select_columns(["persona_id", "name"])
251
+ index_by_id: dict[str, int] = {}
252
+ name_by_id: dict[str, str] = {}
253
+ requested = set(persona_ids)
254
+ for index, row in enumerate(metadata):
255
+ persona_id = row["persona_id"]
256
+ if persona_id in requested:
257
+ index_by_id[persona_id] = index
258
+ name_by_id[persona_id] = row.get("name") or persona_id
259
+ if len(index_by_id) == len(requested):
260
+ break
261
+
262
+ missing = [
263
+ persona_id for persona_id in persona_ids if persona_id not in index_by_id
264
+ ]
265
+ if missing:
266
+ raise FileNotFoundError(
267
+ f"Missing {len(missing)} persona vector(s) in {variant!r}: {missing[:3]}"
268
+ )
269
+
270
+ vectors, labels, hover_text = [], [], []
271
+ for persona_id in persona_ids:
272
+ name = name_by_id.get(persona_id, persona_id)
273
+ vector = torch.as_tensor(
274
+ dataset[index_by_id[persona_id]]["vector"],
275
+ dtype=torch.float32,
276
+ )
277
+ if vector.ndim != 2:
278
+ raise ValueError(
279
+ f"tensor for {persona_id!r} must have shape (num_layers, hidden_size)"
280
+ )
281
+ vectors.append(vector)
282
+ labels.append(name)
283
+ hover_text.append(f"Persona: {name}<br>ID: {persona_id}")
284
+ return LayeredSamples(torch.stack(vectors), labels, hover_text)
285
+
286
+
287
+ def load_variant_vectors_lean(
288
+ source: str,
289
+ location: str,
290
+ model_name: str,
291
+ mask_strategy_value: str,
292
+ variants: tuple[str, ...],
293
+ persona_ids: tuple[str, ...],
294
+ ) -> dict[str, LayeredSamples]:
295
+ return {
296
+ variant: load_persona_vectors_lean(
297
+ source,
298
+ location,
299
+ model_name,
300
+ mask_strategy_value,
301
+ variant,
302
+ persona_ids,
303
+ )
304
+ for variant in variants
305
+ }
306
+
307
+
308
+ def release_store_cache(
309
+ store: Store,
310
+ variants: list[str] | tuple[str, ...] | None = None,
311
+ ) -> None:
312
+ cache = getattr(store, "_cache", None)
313
+ if not isinstance(cache, dict):
314
+ return
315
+ if variants is None:
316
+ cache.clear()
317
+ return
318
+ for variant in variants:
319
+ cache.pop(variant, None)