Jac-Zac committed on
Commit ·
4df7d97
1
Parent(s): 1b16c40
Performance speedup
Browse files- tabs/compare.py +206 -31
- utils/compare_sources.py +181 -0
tabs/compare.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from collections.abc import Callable
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from itertools import combinations
|
|
@@ -7,7 +8,6 @@ import plotly.graph_objects as go
|
|
| 7 |
import streamlit as st
|
| 8 |
from persona_data.environment import get_artifacts_dir
|
| 9 |
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
| 10 |
-
from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.plots import (
|
| 13 |
build_layered_figure,
|
|
@@ -28,10 +28,13 @@ from utils.compare_sources import (
|
|
| 28 |
activation_store_cached,
|
| 29 |
available_variants,
|
| 30 |
hub_models_by_mask_strategy,
|
|
|
|
|
|
|
| 31 |
local_model_matches,
|
| 32 |
local_model_options_cached,
|
| 33 |
persona_names_cached,
|
| 34 |
personas_cached,
|
|
|
|
| 35 |
store_cache_parts,
|
| 36 |
store_id,
|
| 37 |
store_layers_cached,
|
|
@@ -56,9 +59,20 @@ def _filename(*parts: str) -> str:
|
|
| 56 |
# overwrite cosine similarity defaults.
|
| 57 |
_LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
|
| 58 |
_LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
|
|
|
|
| 59 |
_LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
|
| 60 |
_LAST_SOURCE_KEY = "compare:last_source"
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
|
| 64 |
persona_id_normalized = persona_id.strip().lower()
|
|
@@ -101,6 +115,92 @@ def _layers_for_variant(
|
|
| 101 |
)
|
| 102 |
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
def _load_persona_options(
|
| 105 |
store: Store,
|
| 106 |
variants: list[str],
|
|
@@ -156,6 +256,7 @@ def _seed_persona_memory(
|
|
| 156 |
options: PersonaOptions,
|
| 157 |
*,
|
| 158 |
default_all: bool,
|
|
|
|
| 159 |
) -> tuple[int, bool]:
|
| 160 |
remembered_count_key = f"{remember_key}:count"
|
| 161 |
remembered_assistant_key = f"{remember_key}:include_assistant"
|
|
@@ -170,9 +271,12 @@ def _seed_persona_memory(
|
|
| 170 |
options.assistant_id in legacy_ids,
|
| 171 |
)
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
| 176 |
remembered_count = int(st.session_state.get(remembered_count_key, default_count))
|
| 177 |
persona_count = min(max(remembered_count, 0), len(options.regular_ids))
|
| 178 |
include_assistant = bool(st.session_state.get(remembered_assistant_key, False))
|
|
@@ -236,6 +340,7 @@ def _select_artifact_personas(
|
|
| 236 |
widget_scope: str,
|
| 237 |
remember_key: str,
|
| 238 |
default_all: bool = False,
|
|
|
|
| 239 |
) -> list[str]:
|
| 240 |
empty_message = (
|
| 241 |
"No personas have vectors for all selected variants. "
|
|
@@ -256,6 +361,7 @@ def _select_artifact_personas(
|
|
| 256 |
remember_key,
|
| 257 |
options,
|
| 258 |
default_all=default_all,
|
|
|
|
| 259 |
)
|
| 260 |
persona_count, include_assistant = _render_persona_count_controls(
|
| 261 |
store,
|
|
@@ -376,15 +482,17 @@ def _render_cosine_selection(
|
|
| 376 |
|
| 377 |
def _build_cosine_figures(
|
| 378 |
store: Store,
|
|
|
|
| 379 |
selection: CosineSelection,
|
| 380 |
) -> tuple[object, object | None, int, int] | None:
|
| 381 |
variant_sample_cache: dict[str, object] = {}
|
| 382 |
|
| 383 |
def _load_variant(variant: str):
|
| 384 |
if variant not in variant_sample_cache:
|
| 385 |
-
samples =
|
| 386 |
store,
|
| 387 |
[variant],
|
|
|
|
| 388 |
persona_ids=selection.persona_ids,
|
| 389 |
)
|
| 390 |
variant_sample_cache[variant] = samples[variant]
|
|
@@ -479,6 +587,7 @@ def _render_cosine_similarity(
|
|
| 479 |
mask_strategy.value,
|
| 480 |
"_".join(selection.variants),
|
| 481 |
)
|
|
|
|
| 482 |
|
| 483 |
if st.button(
|
| 484 |
"Compare vectors",
|
|
@@ -497,14 +606,15 @@ def _render_cosine_similarity(
|
|
| 497 |
progress = st.progress(0, text="Loading activation vectors…")
|
| 498 |
try:
|
| 499 |
progress.progress(15, text="Loading activation vectors…")
|
| 500 |
-
figures = _build_cosine_figures(store, selection)
|
| 501 |
if figures is None:
|
| 502 |
st.session_state.pop(cosine_fig_key, None)
|
| 503 |
return
|
| 504 |
progress.progress(90, text="Storing figure state…")
|
| 505 |
-
|
| 506 |
progress.progress(100, text="Done.")
|
| 507 |
finally:
|
|
|
|
| 508 |
progress.empty()
|
| 509 |
|
| 510 |
if cosine_fig_key in st.session_state:
|
|
@@ -527,6 +637,9 @@ def _select_single_variant_samples(
|
|
| 527 |
store: Store,
|
| 528 |
mask_strategy: MaskStrategy,
|
| 529 |
scope: str,
|
|
|
|
|
|
|
|
|
|
| 530 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 531 |
variants = available_variants(store, mask_strategy)
|
| 532 |
if not variants:
|
|
@@ -544,8 +657,8 @@ def _select_single_variant_samples(
|
|
| 544 |
[variant],
|
| 545 |
mask_strategy,
|
| 546 |
widget_scope=f"{scope}:{store_id(store)}",
|
| 547 |
-
remember_key=
|
| 548 |
-
|
| 549 |
)
|
| 550 |
if not persona_ids:
|
| 551 |
return None
|
|
@@ -556,8 +669,8 @@ def _select_single_variant_samples(
|
|
| 556 |
st.info("No shared layers are available for the selected personas.")
|
| 557 |
return None
|
| 558 |
|
| 559 |
-
|
| 560 |
-
return variant, persona_ids, persona_key,
|
| 561 |
|
| 562 |
|
| 563 |
def _render_layered_figure_analysis(
|
|
@@ -570,17 +683,50 @@ def _render_layered_figure_analysis(
|
|
| 570 |
title_fn: Callable[[str], str],
|
| 571 |
include_pair_trajectories: bool = False,
|
| 572 |
n_components: int = 2,
|
|
|
|
|
|
|
| 573 |
) -> None:
|
| 574 |
"""Render a single-variant layered analysis: select → button → figure(s).
|
| 575 |
|
| 576 |
Used for similarity matrix, PCA, and UMAP. Set ``include_pair_trajectories``
|
| 577 |
to add the pair-similarity-trajectory figure (similarity matrix only).
|
| 578 |
"""
|
| 579 |
-
selected = _select_single_variant_samples(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
if selected is None:
|
| 581 |
return
|
| 582 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
n_clusters = None
|
| 585 |
if figure_kind in {"pca", "umap"}:
|
| 586 |
use_kmeans = st.toggle(
|
|
@@ -610,8 +756,11 @@ def _render_layered_figure_analysis(
|
|
| 610 |
variant,
|
| 611 |
"persona_vector",
|
| 612 |
persona_key,
|
|
|
|
|
|
|
| 613 |
)
|
| 614 |
filename = scope
|
|
|
|
| 615 |
|
| 616 |
if st.button(button_label, type="primary"):
|
| 617 |
build_label = {
|
|
@@ -622,11 +771,11 @@ def _render_layered_figure_analysis(
|
|
| 622 |
progress = st.progress(0, text="Loading activation vectors…")
|
| 623 |
try:
|
| 624 |
progress.progress(15, text="Loading activation vectors…")
|
| 625 |
-
samples =
|
| 626 |
store,
|
| 627 |
variant,
|
| 628 |
-
mask_strategy
|
| 629 |
-
persona_ids
|
| 630 |
)
|
| 631 |
progress.progress(55, text=build_label)
|
| 632 |
build_kwargs = {}
|
|
@@ -634,7 +783,7 @@ def _render_layered_figure_analysis(
|
|
| 634 |
build_kwargs["n_components"] = n_components
|
| 635 |
if n_clusters is not None:
|
| 636 |
build_kwargs["n_clusters"] = n_clusters
|
| 637 |
-
if figure_kind == "similarity" and
|
| 638 |
main_fig, extra_fig = build_similarity_figures(
|
| 639 |
samples,
|
| 640 |
layers=selected_layers,
|
|
@@ -663,16 +812,19 @@ def _render_layered_figure_analysis(
|
|
| 663 |
f"{prompt_variant_label(variant)} - persona vectors"
|
| 664 |
),
|
| 665 |
)
|
| 666 |
-
if
|
| 667 |
else None
|
| 668 |
)
|
| 669 |
progress.progress(90, text="Storing figure state…")
|
| 670 |
-
|
|
|
|
|
|
|
| 671 |
progress.progress(100, text="Done.")
|
| 672 |
except Exception as exc:
|
| 673 |
st.error(f"Could not build figure: {exc}")
|
| 674 |
st.session_state.pop(fig_key, None)
|
| 675 |
finally:
|
|
|
|
| 676 |
progress.empty()
|
| 677 |
|
| 678 |
if fig_key in st.session_state:
|
|
@@ -734,7 +886,7 @@ def _render_dendrogram_analysis(
|
|
| 734 |
mask_strategy,
|
| 735 |
widget_scope=f"dendro:{store_id(store)}",
|
| 736 |
remember_key=_LAST_DENDRO_PERSONAS_KEY,
|
| 737 |
-
|
| 738 |
)
|
| 739 |
if not persona_ids:
|
| 740 |
return
|
|
@@ -755,6 +907,22 @@ def _render_dendrogram_analysis(
|
|
| 755 |
key=widget_key("load", "dendro_linkage", store_id(store)),
|
| 756 |
)
|
| 757 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
persona_key = personas_fingerprint(persona_ids)
|
| 759 |
fig_key = widget_key(
|
| 760 |
"load",
|
|
@@ -767,7 +935,9 @@ def _render_dendrogram_analysis(
|
|
| 767 |
persona_key,
|
| 768 |
str(layered_mode),
|
| 769 |
linkage,
|
|
|
|
| 770 |
)
|
|
|
|
| 771 |
|
| 772 |
if st.button(
|
| 773 |
"Generate dendrograms",
|
|
@@ -779,50 +949,52 @@ def _render_dendrogram_analysis(
|
|
| 779 |
progress = st.progress(0, text="Loading first variant vectors…")
|
| 780 |
try:
|
| 781 |
progress.progress(15, text="Loading first variant vectors…")
|
| 782 |
-
samples_a =
|
| 783 |
store,
|
| 784 |
variant_a,
|
| 785 |
-
mask_strategy
|
| 786 |
-
persona_ids
|
| 787 |
)
|
| 788 |
progress.progress(40, text="Building first dendrogram…")
|
| 789 |
fig_a = plot_persona_dendrogram(
|
| 790 |
samples_a,
|
| 791 |
layered=layered_mode,
|
|
|
|
| 792 |
linkage=linkage,
|
| 793 |
title=f"Dendrogram — {prompt_variant_label(variant_a)}",
|
| 794 |
)
|
| 795 |
fig_a.update_layout(height=750)
|
|
|
|
| 796 |
fig_b = None
|
| 797 |
if variant_a != variant_b:
|
| 798 |
progress.progress(60, text="Loading second variant vectors…")
|
| 799 |
-
samples_b =
|
| 800 |
store,
|
| 801 |
variant_b,
|
| 802 |
-
mask_strategy
|
| 803 |
-
persona_ids
|
| 804 |
)
|
| 805 |
progress.progress(75, text="Building second dendrogram…")
|
| 806 |
fig_b = plot_persona_dendrogram(
|
| 807 |
samples_b,
|
| 808 |
layered=layered_mode,
|
|
|
|
| 809 |
linkage=linkage,
|
| 810 |
title=f"Dendrogram — {prompt_variant_label(variant_b)}",
|
| 811 |
)
|
| 812 |
fig_b.update_layout(height=750)
|
|
|
|
| 813 |
progress.progress(90, text="Storing figure state…")
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
fig_b,
|
| 817 |
-
len(persona_ids),
|
| 818 |
-
variant_a,
|
| 819 |
-
variant_b,
|
| 820 |
)
|
| 821 |
progress.progress(100, text="Done.")
|
| 822 |
except Exception as exc:
|
| 823 |
st.error(f"Could not build dendrogram: {exc}")
|
| 824 |
st.session_state.pop(fig_key, None)
|
| 825 |
finally:
|
|
|
|
| 826 |
progress.empty()
|
| 827 |
|
| 828 |
if fig_key in st.session_state:
|
|
@@ -1033,6 +1205,8 @@ def render_compare_tab() -> None:
|
|
| 1033 |
f"Centered similarity - {prompt_variant_label(v)} - persona vectors"
|
| 1034 |
),
|
| 1035 |
include_pair_trajectories=True,
|
|
|
|
|
|
|
| 1036 |
)
|
| 1037 |
return
|
| 1038 |
|
|
@@ -1059,4 +1233,5 @@ def render_compare_tab() -> None:
|
|
| 1059 |
f"{analysis_mode}{dim_suffix} - {prompt_variant_label(v)} - persona vectors"
|
| 1060 |
),
|
| 1061 |
n_components=n_components,
|
|
|
|
| 1062 |
)
|
|
|
|
| 1 |
+
import gc
|
| 2 |
from collections.abc import Callable
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from itertools import combinations
|
|
|
|
| 8 |
import streamlit as st
|
| 9 |
from persona_data.environment import get_artifacts_dir
|
| 10 |
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
|
|
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.plots import (
|
| 13 |
build_layered_figure,
|
|
|
|
| 28 |
activation_store_cached,
|
| 29 |
available_variants,
|
| 30 |
hub_models_by_mask_strategy,
|
| 31 |
+
load_persona_vectors_lean,
|
| 32 |
+
load_variant_vectors_lean,
|
| 33 |
local_model_matches,
|
| 34 |
local_model_options_cached,
|
| 35 |
persona_names_cached,
|
| 36 |
personas_cached,
|
| 37 |
+
release_store_cache,
|
| 38 |
store_cache_parts,
|
| 39 |
store_id,
|
| 40 |
store_layers_cached,
|
|
|
|
| 59 |
# overwrite cosine similarity defaults.
|
| 60 |
_LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
|
| 61 |
_LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
|
| 62 |
+
_LAST_SIMILARITY_PERSONAS_KEY = "compare:last_personas:similarity"
|
| 63 |
_LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
|
| 64 |
_LAST_SOURCE_KEY = "compare:last_source"
|
| 65 |
|
| 66 |
+
_DEFAULT_LAYER_FRAMES = 16
|
| 67 |
+
_DEFAULT_PERSONA_LIMITS = {
|
| 68 |
+
"similarity": 120,
|
| 69 |
+
"pca": 500,
|
| 70 |
+
"umap": 500,
|
| 71 |
+
"dendro": 160,
|
| 72 |
+
}
|
| 73 |
+
_MAX_SIMILARITY_CELLS = 4_000_000
|
| 74 |
+
_MAX_PAIR_TRAJECTORY_TRACES = 500
|
| 75 |
+
|
| 76 |
|
| 77 |
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
|
| 78 |
persona_id_normalized = persona_id.strip().lower()
|
|
|
|
| 115 |
)
|
| 116 |
|
| 117 |
|
| 118 |
+
def _load_persona_vectors(
    store: Store,
    variant: str,
    mask_strategy: MaskStrategy,
    persona_ids: list[str],
):
    """Load the vectors of ``persona_ids`` for one prompt variant.

    The store is reduced to the three values returned by
    ``store_cache_parts`` and the remaining arguments are passed as plain
    values (``mask_strategy.value``, a tuple of persona ids) to
    ``load_persona_vectors_lean``, whose result is returned unchanged.
    """
    # NOTE(review): plain strings/tuples are presumably used so the lean
    # loader can be cached on its arguments -- confirm in compare_sources.
    cache_source, cache_location, cache_model = store_cache_parts(store)
    strategy_value = mask_strategy.value
    persona_tuple = tuple(persona_ids)
    return load_persona_vectors_lean(
        cache_source,
        cache_location,
        cache_model,
        strategy_value,
        variant,
        persona_tuple,
    )
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _load_variant_vectors(
    store: Store,
    variants: list[str] | tuple[str, ...],
    mask_strategy: MaskStrategy,
    persona_ids: list[str],
):
    """Load the vectors of ``persona_ids`` for several prompt variants.

    The store is reduced to the three values returned by
    ``store_cache_parts``; ``variants`` and ``persona_ids`` are converted to
    tuples and ``mask_strategy`` to its ``.value`` before everything is
    forwarded to ``load_variant_vectors_lean``. The loader's result is
    returned unchanged.
    """
    cache_source, cache_location, cache_model = store_cache_parts(store)
    strategy_value = mask_strategy.value
    variant_tuple = tuple(variants)
    persona_tuple = tuple(persona_ids)
    return load_variant_vectors_lean(
        cache_source,
        cache_location,
        cache_model,
        strategy_value,
        variant_tuple,
        persona_tuple,
    )
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _clear_old_figure_states(current_key: str) -> None:
    """Remove stored figure states from the session, keeping ``current_key``.

    A session key counts as a figure state when it is a string whose first
    ``"::"``-separated segment is ``"load"`` and whose second segment ends
    with ``"_fig_state"``. Non-string keys and ``current_key`` itself are
    never touched.
    """

    def _is_figure_state(candidate: object) -> bool:
        if not isinstance(candidate, str):
            return False
        segments = candidate.split("::", 2)
        if len(segments) < 2:
            return False
        return segments[0] == "load" and segments[1].endswith("_fig_state")

    # Snapshot the keys first so entries can be removed safely afterwards.
    stale_keys = [
        candidate
        for candidate in list(st.session_state)
        if candidate != current_key and _is_figure_state(candidate)
    ]
    for stale in stale_keys:
        st.session_state.pop(stale, None)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _store_figure_state(key: str, value: object) -> None:
    """Store ``value`` under ``key`` in the session as the only figure state.

    Every other figure-state entry is cleared first via
    ``_clear_old_figure_states`` so at most one built figure stays in
    ``st.session_state`` at a time.
    """
    _clear_old_figure_states(current_key=key)
    st.session_state[key] = value
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _release_vector_memory(store: Store, variants: list[str] | tuple[str, ...]) -> None:
    """Release cached data for ``variants`` and run a garbage-collection pass.

    Calls ``release_store_cache(store, variants)`` followed by
    ``gc.collect()``.
    """
    # NOTE(review): what exactly release_store_cache frees is defined in
    # utils/compare_sources -- confirm there.
    release_store_cache(store, variants)
    # Collect right away so objects freed above are reclaimed immediately.
    gc.collect()
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _evenly_spaced_layers(layers: list[int], max_count: int) -> list[int]:
|
| 172 |
+
if max_count >= len(layers):
|
| 173 |
+
return layers
|
| 174 |
+
if max_count <= 1:
|
| 175 |
+
return [layers[0]]
|
| 176 |
+
|
| 177 |
+
last = len(layers) - 1
|
| 178 |
+
indices = [round(i * last / (max_count - 1)) for i in range(max_count)]
|
| 179 |
+
return [layers[index] for index in dict.fromkeys(indices)]
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _render_layer_frame_controls(
    store: Store,
    scope: str,
    layers: list[int],
) -> list[int]:
    """Render the layer-frame slider and return the layers to plot.

    When ``layers`` has more than ``_DEFAULT_LAYER_FRAMES`` entries, a
    "Layer frames" slider (2 .. ``len(layers)``, defaulting to
    ``_DEFAULT_LAYER_FRAMES``) is shown and the chosen count is turned into
    an evenly spaced subset. Otherwise no slider is drawn and ``layers`` is
    returned as given. A caption states how many layers are in use.
    """
    total = len(layers)
    if total > _DEFAULT_LAYER_FRAMES:
        slider_key = widget_key("load", "layer_frames", scope, store_id(store))
        requested = st.slider(
            "Layer frames",
            min_value=2,
            max_value=total,
            value=_DEFAULT_LAYER_FRAMES,
            key=slider_key,
            help="Limit animated Plotly frames to keep browser and RAM usage bounded.",
        )
        chosen = _evenly_spaced_layers(layers, requested)
        st.caption(f"Using {len(chosen)} of {total} layers.")
        return chosen

    # Few enough layers: nothing to thin out.
    st.caption(f"Using all {total} available layer(s).")
    return layers
|
| 202 |
+
|
| 203 |
+
|
| 204 |
def _load_persona_options(
|
| 205 |
store: Store,
|
| 206 |
variants: list[str],
|
|
|
|
| 256 |
options: PersonaOptions,
|
| 257 |
*,
|
| 258 |
default_all: bool,
|
| 259 |
+
default_count_limit: int | None = None,
|
| 260 |
) -> tuple[int, bool]:
|
| 261 |
remembered_count_key = f"{remember_key}:count"
|
| 262 |
remembered_assistant_key = f"{remember_key}:include_assistant"
|
|
|
|
| 271 |
options.assistant_id in legacy_ids,
|
| 272 |
)
|
| 273 |
|
| 274 |
+
if default_count_limit is not None:
|
| 275 |
+
default_count = min(default_count_limit, len(options.regular_ids))
|
| 276 |
+
elif default_all:
|
| 277 |
+
default_count = len(options.regular_ids)
|
| 278 |
+
else:
|
| 279 |
+
default_count = min(1, len(options.regular_ids))
|
| 280 |
remembered_count = int(st.session_state.get(remembered_count_key, default_count))
|
| 281 |
persona_count = min(max(remembered_count, 0), len(options.regular_ids))
|
| 282 |
include_assistant = bool(st.session_state.get(remembered_assistant_key, False))
|
|
|
|
| 340 |
widget_scope: str,
|
| 341 |
remember_key: str,
|
| 342 |
default_all: bool = False,
|
| 343 |
+
default_count_limit: int | None = None,
|
| 344 |
) -> list[str]:
|
| 345 |
empty_message = (
|
| 346 |
"No personas have vectors for all selected variants. "
|
|
|
|
| 361 |
remember_key,
|
| 362 |
options,
|
| 363 |
default_all=default_all,
|
| 364 |
+
default_count_limit=default_count_limit,
|
| 365 |
)
|
| 366 |
persona_count, include_assistant = _render_persona_count_controls(
|
| 367 |
store,
|
|
|
|
| 482 |
|
| 483 |
def _build_cosine_figures(
|
| 484 |
store: Store,
|
| 485 |
+
mask_strategy: MaskStrategy,
|
| 486 |
selection: CosineSelection,
|
| 487 |
) -> tuple[object, object | None, int, int] | None:
|
| 488 |
variant_sample_cache: dict[str, object] = {}
|
| 489 |
|
| 490 |
def _load_variant(variant: str):
|
| 491 |
if variant not in variant_sample_cache:
|
| 492 |
+
samples = _load_variant_vectors(
|
| 493 |
store,
|
| 494 |
[variant],
|
| 495 |
+
mask_strategy,
|
| 496 |
persona_ids=selection.persona_ids,
|
| 497 |
)
|
| 498 |
variant_sample_cache[variant] = samples[variant]
|
|
|
|
| 587 |
mask_strategy.value,
|
| 588 |
"_".join(selection.variants),
|
| 589 |
)
|
| 590 |
+
_clear_old_figure_states(cosine_fig_key)
|
| 591 |
|
| 592 |
if st.button(
|
| 593 |
"Compare vectors",
|
|
|
|
| 606 |
progress = st.progress(0, text="Loading activation vectors…")
|
| 607 |
try:
|
| 608 |
progress.progress(15, text="Loading activation vectors…")
|
| 609 |
+
figures = _build_cosine_figures(store, mask_strategy, selection)
|
| 610 |
if figures is None:
|
| 611 |
st.session_state.pop(cosine_fig_key, None)
|
| 612 |
return
|
| 613 |
progress.progress(90, text="Storing figure state…")
|
| 614 |
+
_store_figure_state(cosine_fig_key, figures)
|
| 615 |
progress.progress(100, text="Done.")
|
| 616 |
finally:
|
| 617 |
+
_release_vector_memory(store, selection.variants)
|
| 618 |
progress.empty()
|
| 619 |
|
| 620 |
if cosine_fig_key in st.session_state:
|
|
|
|
| 637 |
store: Store,
|
| 638 |
mask_strategy: MaskStrategy,
|
| 639 |
scope: str,
|
| 640 |
+
*,
|
| 641 |
+
remember_key: str,
|
| 642 |
+
default_count_limit: int,
|
| 643 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 644 |
variants = available_variants(store, mask_strategy)
|
| 645 |
if not variants:
|
|
|
|
| 657 |
[variant],
|
| 658 |
mask_strategy,
|
| 659 |
widget_scope=f"{scope}:{store_id(store)}",
|
| 660 |
+
remember_key=remember_key,
|
| 661 |
+
default_count_limit=default_count_limit,
|
| 662 |
)
|
| 663 |
if not persona_ids:
|
| 664 |
return None
|
|
|
|
| 669 |
st.info("No shared layers are available for the selected personas.")
|
| 670 |
return None
|
| 671 |
|
| 672 |
+
selected_layers = _render_layer_frame_controls(store, scope, layer_options)
|
| 673 |
+
return variant, persona_ids, persona_key, selected_layers
|
| 674 |
|
| 675 |
|
| 676 |
def _render_layered_figure_analysis(
|
|
|
|
| 683 |
title_fn: Callable[[str], str],
|
| 684 |
include_pair_trajectories: bool = False,
|
| 685 |
n_components: int = 2,
|
| 686 |
+
remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
|
| 687 |
+
default_count_limit: int = 500,
|
| 688 |
) -> None:
|
| 689 |
"""Render a single-variant layered analysis: select → button → figure(s).
|
| 690 |
|
| 691 |
Used for similarity matrix, PCA, and UMAP. Set ``include_pair_trajectories``
|
| 692 |
to add the pair-similarity-trajectory figure (similarity matrix only).
|
| 693 |
"""
|
| 694 |
+
selected = _select_single_variant_samples(
|
| 695 |
+
store,
|
| 696 |
+
mask_strategy,
|
| 697 |
+
scope,
|
| 698 |
+
remember_key=remember_key,
|
| 699 |
+
default_count_limit=default_count_limit,
|
| 700 |
+
)
|
| 701 |
if selected is None:
|
| 702 |
return
|
| 703 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 704 |
|
| 705 |
+
pair_trajectories = False
|
| 706 |
+
if include_pair_trajectories:
|
| 707 |
+
pair_count = len(persona_ids) * (len(persona_ids) - 1) // 2
|
| 708 |
+
if pair_count > _MAX_PAIR_TRAJECTORY_TRACES:
|
| 709 |
+
st.caption(
|
| 710 |
+
"Pair trajectories hidden because this selection would create "
|
| 711 |
+
f"{pair_count:,} Plotly traces."
|
| 712 |
+
)
|
| 713 |
+
else:
|
| 714 |
+
pair_trajectories = st.checkbox(
|
| 715 |
+
"Pair trajectories",
|
| 716 |
+
value=False,
|
| 717 |
+
key=widget_key("load", "pair_trajectories", scope, store_id(store)),
|
| 718 |
+
help="Adds one line per persona pair. Keep this off for larger selections.",
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
if figure_kind == "similarity":
|
| 722 |
+
similarity_cells = len(persona_ids) * len(persona_ids) * len(selected_layers)
|
| 723 |
+
if similarity_cells > _MAX_SIMILARITY_CELLS:
|
| 724 |
+
st.error(
|
| 725 |
+
"Reduce personas or layer frames before generating the similarity "
|
| 726 |
+
f"matrix ({similarity_cells:,} cells selected)."
|
| 727 |
+
)
|
| 728 |
+
return
|
| 729 |
+
|
| 730 |
n_clusters = None
|
| 731 |
if figure_kind in {"pca", "umap"}:
|
| 732 |
use_kmeans = st.toggle(
|
|
|
|
| 756 |
variant,
|
| 757 |
"persona_vector",
|
| 758 |
persona_key,
|
| 759 |
+
"_".join(map(str, selected_layers)),
|
| 760 |
+
str(pair_trajectories),
|
| 761 |
)
|
| 762 |
filename = scope
|
| 763 |
+
_clear_old_figure_states(fig_key)
|
| 764 |
|
| 765 |
if st.button(button_label, type="primary"):
|
| 766 |
build_label = {
|
|
|
|
| 771 |
progress = st.progress(0, text="Loading activation vectors…")
|
| 772 |
try:
|
| 773 |
progress.progress(15, text="Loading activation vectors…")
|
| 774 |
+
samples = _load_persona_vectors(
|
| 775 |
store,
|
| 776 |
variant,
|
| 777 |
+
mask_strategy,
|
| 778 |
+
persona_ids,
|
| 779 |
)
|
| 780 |
progress.progress(55, text=build_label)
|
| 781 |
build_kwargs = {}
|
|
|
|
| 783 |
build_kwargs["n_components"] = n_components
|
| 784 |
if n_clusters is not None:
|
| 785 |
build_kwargs["n_clusters"] = n_clusters
|
| 786 |
+
if figure_kind == "similarity" and pair_trajectories:
|
| 787 |
main_fig, extra_fig = build_similarity_figures(
|
| 788 |
samples,
|
| 789 |
layers=selected_layers,
|
|
|
|
| 812 |
f"{prompt_variant_label(variant)} - persona vectors"
|
| 813 |
),
|
| 814 |
)
|
| 815 |
+
if pair_trajectories
|
| 816 |
else None
|
| 817 |
)
|
| 818 |
progress.progress(90, text="Storing figure state…")
|
| 819 |
+
n_samples = samples.vectors.shape[0]
|
| 820 |
+
del samples
|
| 821 |
+
_store_figure_state(fig_key, (main_fig, extra_fig, n_samples))
|
| 822 |
progress.progress(100, text="Done.")
|
| 823 |
except Exception as exc:
|
| 824 |
st.error(f"Could not build figure: {exc}")
|
| 825 |
st.session_state.pop(fig_key, None)
|
| 826 |
finally:
|
| 827 |
+
_release_vector_memory(store, [variant])
|
| 828 |
progress.empty()
|
| 829 |
|
| 830 |
if fig_key in st.session_state:
|
|
|
|
| 886 |
mask_strategy,
|
| 887 |
widget_scope=f"dendro:{store_id(store)}",
|
| 888 |
remember_key=_LAST_DENDRO_PERSONAS_KEY,
|
| 889 |
+
default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"],
|
| 890 |
)
|
| 891 |
if not persona_ids:
|
| 892 |
return
|
|
|
|
| 907 |
key=widget_key("load", "dendro_linkage", store_id(store)),
|
| 908 |
)
|
| 909 |
|
| 910 |
+
selected_layers: list[int] | None = None
|
| 911 |
+
if layered_mode:
|
| 912 |
+
source, location, model_name = store_cache_parts(store)
|
| 913 |
+
layer_options = store_layers_cached(
|
| 914 |
+
source,
|
| 915 |
+
location,
|
| 916 |
+
model_name,
|
| 917 |
+
mask_strategy.value,
|
| 918 |
+
tuple(shared_variants),
|
| 919 |
+
tuple(persona_ids),
|
| 920 |
+
)
|
| 921 |
+
if not layer_options:
|
| 922 |
+
st.info("No shared layers are available for the selected personas.")
|
| 923 |
+
return
|
| 924 |
+
selected_layers = _render_layer_frame_controls(store, "dendro", layer_options)
|
| 925 |
+
|
| 926 |
persona_key = personas_fingerprint(persona_ids)
|
| 927 |
fig_key = widget_key(
|
| 928 |
"load",
|
|
|
|
| 935 |
persona_key,
|
| 936 |
str(layered_mode),
|
| 937 |
linkage,
|
| 938 |
+
"_".join(map(str, selected_layers or [])),
|
| 939 |
)
|
| 940 |
+
_clear_old_figure_states(fig_key)
|
| 941 |
|
| 942 |
if st.button(
|
| 943 |
"Generate dendrograms",
|
|
|
|
| 949 |
progress = st.progress(0, text="Loading first variant vectors…")
|
| 950 |
try:
|
| 951 |
progress.progress(15, text="Loading first variant vectors…")
|
| 952 |
+
samples_a = _load_persona_vectors(
|
| 953 |
store,
|
| 954 |
variant_a,
|
| 955 |
+
mask_strategy,
|
| 956 |
+
persona_ids,
|
| 957 |
)
|
| 958 |
progress.progress(40, text="Building first dendrogram…")
|
| 959 |
fig_a = plot_persona_dendrogram(
|
| 960 |
samples_a,
|
| 961 |
layered=layered_mode,
|
| 962 |
+
layers=selected_layers,
|
| 963 |
linkage=linkage,
|
| 964 |
title=f"Dendrogram — {prompt_variant_label(variant_a)}",
|
| 965 |
)
|
| 966 |
fig_a.update_layout(height=750)
|
| 967 |
+
del samples_a
|
| 968 |
fig_b = None
|
| 969 |
if variant_a != variant_b:
|
| 970 |
progress.progress(60, text="Loading second variant vectors…")
|
| 971 |
+
samples_b = _load_persona_vectors(
|
| 972 |
store,
|
| 973 |
variant_b,
|
| 974 |
+
mask_strategy,
|
| 975 |
+
persona_ids,
|
| 976 |
)
|
| 977 |
progress.progress(75, text="Building second dendrogram…")
|
| 978 |
fig_b = plot_persona_dendrogram(
|
| 979 |
samples_b,
|
| 980 |
layered=layered_mode,
|
| 981 |
+
layers=selected_layers,
|
| 982 |
linkage=linkage,
|
| 983 |
title=f"Dendrogram — {prompt_variant_label(variant_b)}",
|
| 984 |
)
|
| 985 |
fig_b.update_layout(height=750)
|
| 986 |
+
del samples_b
|
| 987 |
progress.progress(90, text="Storing figure state…")
|
| 988 |
+
_store_figure_state(
|
| 989 |
+
fig_key,
|
| 990 |
+
(fig_a, fig_b, len(persona_ids), variant_a, variant_b),
|
|
|
|
|
|
|
|
|
|
| 991 |
)
|
| 992 |
progress.progress(100, text="Done.")
|
| 993 |
except Exception as exc:
|
| 994 |
st.error(f"Could not build dendrogram: {exc}")
|
| 995 |
st.session_state.pop(fig_key, None)
|
| 996 |
finally:
|
| 997 |
+
_release_vector_memory(store, shared_variants)
|
| 998 |
progress.empty()
|
| 999 |
|
| 1000 |
if fig_key in st.session_state:
|
|
|
|
| 1205 |
f"Centered similarity - {prompt_variant_label(v)} - persona vectors"
|
| 1206 |
),
|
| 1207 |
include_pair_trajectories=True,
|
| 1208 |
+
remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
|
| 1209 |
+
default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
|
| 1210 |
)
|
| 1211 |
return
|
| 1212 |
|
|
|
|
| 1233 |
f"{analysis_mode}{dim_suffix} - {prompt_variant_label(v)} - persona vectors"
|
| 1234 |
),
|
| 1235 |
n_components=n_components,
|
| 1236 |
+
default_count_limit=_DEFAULT_PERSONA_LIMITS[analysis_mode.lower()],
|
| 1237 |
)
|
utils/compare_sources.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import streamlit as st
|
|
|
|
|
|
|
| 4 |
from persona_vectors.artifacts import (
|
| 5 |
ActivationStore,
|
| 6 |
HFActivationStore,
|
|
|
|
| 7 |
discover_activation_models,
|
| 8 |
model_dir_name,
|
| 9 |
)
|
|
@@ -22,6 +25,28 @@ SOURCE_LOCAL = "Local activations"
|
|
| 22 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
@st.cache_resource(show_spinner=False, max_entries=1)
|
| 26 |
def activation_store_cached(
|
| 27 |
source: str,
|
|
@@ -54,6 +79,26 @@ def personas_cached(
|
|
| 54 |
mask_strategy_value: str,
|
| 55 |
variants: tuple[str, ...],
|
| 56 |
) -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 58 |
return store.list_personas(
|
| 59 |
list(variants),
|
|
@@ -70,6 +115,25 @@ def persona_names_cached(
|
|
| 70 |
variants: tuple[str, ...],
|
| 71 |
persona_ids: tuple[str, ...],
|
| 72 |
) -> dict[str, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 74 |
return store.persona_names(
|
| 75 |
list(persona_ids),
|
|
@@ -126,6 +190,26 @@ def store_layers_cached(
|
|
| 126 |
variants: tuple[str, ...],
|
| 127 |
persona_ids: tuple[str, ...],
|
| 128 |
) -> list[int]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 130 |
return store.list_layers(
|
| 131 |
list(variants),
|
|
@@ -136,3 +220,100 @@ def store_layers_cached(
|
|
| 136 |
|
| 137 |
def local_model_matches(left: str, right: str) -> bool:
|
| 138 |
return model_dir_name(left) == model_dir_name(right)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
+
import torch
|
| 5 |
+
from persona_vectors.analysis import LayeredSamples
|
| 6 |
from persona_vectors.artifacts import (
|
| 7 |
ActivationStore,
|
| 8 |
HFActivationStore,
|
| 9 |
+
activation_config_name,
|
| 10 |
discover_activation_models,
|
| 11 |
model_dir_name,
|
| 12 |
)
|
|
|
|
| 25 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 26 |
|
| 27 |
|
| 28 |
+
def _hub_split(repo_id: str, model_name: str, mask_strategy_value: str, variant: str):
    """Load one split of the activation dataset with ``datasets.load_dataset``.

    The dataset config name is ``activation_config_name(model_name,
    mask_strategy_value)`` and the split name is ``variant``. The dataset is
    requested with ``keep_in_memory=False``.
    """
    # Imported inside the function, not at module level.
    from datasets import load_dataset

    config_name = activation_config_name(model_name, mask_strategy_value)
    return load_dataset(
        repo_id,
        name=config_name,
        split=variant,
        keep_in_memory=False,
    )
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _hub_split_columns(
    repo_id: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    columns: list[str],
):
    """Return the ``variant`` split restricted to ``columns``.

    The split is obtained from ``_hub_split`` and narrowed with
    ``select_columns``.
    """
    full_split = _hub_split(repo_id, model_name, mask_strategy_value, variant)
    narrowed = full_split.select_columns(columns)
    return narrowed
|
| 48 |
+
|
| 49 |
+
|
| 50 |
@st.cache_resource(show_spinner=False, max_entries=1)
|
| 51 |
def activation_store_cached(
|
| 52 |
source: str,
|
|
|
|
| 79 |
mask_strategy_value: str,
|
| 80 |
variants: tuple[str, ...],
|
| 81 |
) -> list[str]:
|
| 82 |
+
if source == SOURCE_HUB:
|
| 83 |
+
variant_ids = [
|
| 84 |
+
list(
|
| 85 |
+
_hub_split_columns(
|
| 86 |
+
location,
|
| 87 |
+
model_name,
|
| 88 |
+
mask_strategy_value,
|
| 89 |
+
variant,
|
| 90 |
+
["persona_id"],
|
| 91 |
+
)["persona_id"]
|
| 92 |
+
)
|
| 93 |
+
for variant in variants
|
| 94 |
+
]
|
| 95 |
+
if not variant_ids:
|
| 96 |
+
return []
|
| 97 |
+
shared = set(variant_ids[0])
|
| 98 |
+
for ids in variant_ids[1:]:
|
| 99 |
+
shared &= set(ids)
|
| 100 |
+
return [persona_id for persona_id in variant_ids[0] if persona_id in shared]
|
| 101 |
+
|
| 102 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 103 |
return store.list_personas(
|
| 104 |
list(variants),
|
|
|
|
| 115 |
variants: tuple[str, ...],
|
| 116 |
persona_ids: tuple[str, ...],
|
| 117 |
) -> dict[str, str]:
|
| 118 |
+
if source == SOURCE_HUB:
|
| 119 |
+
requested = set(persona_ids)
|
| 120 |
+
names: dict[str, str] = {}
|
| 121 |
+
for variant in variants:
|
| 122 |
+
metadata = _hub_split_columns(
|
| 123 |
+
location,
|
| 124 |
+
model_name,
|
| 125 |
+
mask_strategy_value,
|
| 126 |
+
variant,
|
| 127 |
+
["persona_id", "name"],
|
| 128 |
+
)
|
| 129 |
+
for row in metadata:
|
| 130 |
+
persona_id = row["persona_id"]
|
| 131 |
+
if persona_id in requested and persona_id not in names:
|
| 132 |
+
names[persona_id] = row.get("name") or persona_id
|
| 133 |
+
if len(names) == len(requested):
|
| 134 |
+
return {pid: names.get(pid, pid) for pid in persona_ids}
|
| 135 |
+
return {pid: names.get(pid, pid) for pid in persona_ids}
|
| 136 |
+
|
| 137 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 138 |
return store.persona_names(
|
| 139 |
list(persona_ids),
|
|
|
|
| 190 |
variants: tuple[str, ...],
|
| 191 |
persona_ids: tuple[str, ...],
|
| 192 |
) -> list[int]:
|
| 193 |
+
if source == SOURCE_HUB:
|
| 194 |
+
shared_layers: set[int] | None = None
|
| 195 |
+
requested = list(persona_ids)
|
| 196 |
+
for variant in variants:
|
| 197 |
+
dataset = _hub_split(location, model_name, mask_strategy_value, variant)
|
| 198 |
+
ids = list(dataset.select_columns(["persona_id"])["persona_id"])
|
| 199 |
+
sample_id = requested[0] if requested else (ids[0] if ids else None)
|
| 200 |
+
if sample_id is None:
|
| 201 |
+
return []
|
| 202 |
+
if requested and any(persona_id not in ids for persona_id in requested):
|
| 203 |
+
return []
|
| 204 |
+
vector = torch.as_tensor(dataset[ids.index(sample_id)]["vector"])
|
| 205 |
+
if vector.ndim != 2:
|
| 206 |
+
raise ValueError(
|
| 207 |
+
f"tensor for {sample_id!r} must have shape (num_layers, hidden_size)"
|
| 208 |
+
)
|
| 209 |
+
layers = set(range(int(vector.shape[0])))
|
| 210 |
+
shared_layers = layers if shared_layers is None else shared_layers & layers
|
| 211 |
+
return sorted(shared_layers or set())
|
| 212 |
+
|
| 213 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 214 |
return store.list_layers(
|
| 215 |
list(variants),
|
|
|
|
| 220 |
|
| 221 |
def local_model_matches(left: str, right: str) -> bool:
    """Return True when both model names resolve to the same model directory name.

    Each argument is passed through ``model_dir_name`` and the two results
    are compared for equality.
    """
    left_dir = model_dir_name(left)
    right_dir = model_dir_name(right)
    return left_dir == right_dir
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def load_persona_vectors_lean(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
) -> LayeredSamples:
    """Load the vectors of the requested personas for one variant.

    For any source other than ``SOURCE_HUB`` the work is delegated to
    ``persona_vectors.analysis.load_persona_vectors`` using the cached
    activation store. For the Hub source, only the ``persona_id`` / ``name``
    columns are scanned to locate the requested rows, and then just those
    rows are read to obtain their ``vector`` field.

    Args:
        source: Source selector; compared against ``SOURCE_HUB``.
        location: Store location (used as the dataset repo id on the Hub path).
        model_name: Model whose activations are loaded.
        mask_strategy_value: Mask strategy value (string form).
        variant: Variant (dataset split on the Hub path) to read.
        persona_ids: Personas to load; output order follows this tuple.

    Returns:
        A ``LayeredSamples`` built from the stacked vectors, the persona
        display names as labels, and one hover string per persona.

    Raises:
        FileNotFoundError: A requested persona id is absent from the split.
        ValueError: A loaded vector is not 2-dimensional.
    """
    if source != SOURCE_HUB:
        # Imported lazily; only needed when not reading from the Hub.
        from persona_vectors.analysis import load_persona_vectors

        return load_persona_vectors(
            activation_store_cached(
                source,
                location,
                model_name,
                mask_strategy_value,
            ),
            variant,
            mask_strategy=MaskStrategy(mask_strategy_value),
            persona_ids=list(persona_ids),
        )

    split = _hub_split(location, model_name, mask_strategy_value, variant)
    id_and_name = split.select_columns(["persona_id", "name"])
    wanted = set(persona_ids)
    row_of: dict[str, int] = {}
    display_name: dict[str, str] = {}
    for row_number, record in enumerate(id_and_name):
        pid = record["persona_id"]
        if pid not in wanted:
            continue
        row_of[pid] = row_number
        # Fall back to the id when the name is missing or empty.
        display_name[pid] = record.get("name") or pid
        if len(row_of) == len(wanted):
            # Every requested persona has been located; stop scanning.
            break

    absent = [pid for pid in persona_ids if pid not in row_of]
    if absent:
        raise FileNotFoundError(
            f"Missing {len(absent)} persona vector(s) in {variant!r}: {absent[:3]}"
        )

    vectors, labels, hover_text = [], [], []
    for pid in persona_ids:
        name = display_name.get(pid, pid)
        tensor = torch.as_tensor(split[row_of[pid]]["vector"], dtype=torch.float32)
        if tensor.ndim != 2:
            raise ValueError(
                f"tensor for {pid!r} must have shape (num_layers, hidden_size)"
            )
        vectors.append(tensor)
        labels.append(name)
        hover_text.append(f"Persona: {name}<br>ID: {pid}")
    return LayeredSamples(torch.stack(vectors), labels, hover_text)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def load_variant_vectors_lean(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> dict[str, LayeredSamples]:
    """Load persona vectors for several variants.

    Calls ``load_persona_vectors_lean`` once per entry of ``variants`` with
    the same source, location, model, mask strategy and persona ids.

    Returns:
        A dict mapping each variant name to its ``LayeredSamples``, in the
        order the variants were given.
    """
    samples_by_variant: dict[str, LayeredSamples] = {}
    for variant in variants:
        samples_by_variant[variant] = load_persona_vectors_lean(
            source,
            location,
            model_name,
            mask_strategy_value,
            variant,
            persona_ids,
        )
    return samples_by_variant
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def release_store_cache(
    store: Store,
    variants: list[str] | tuple[str, ...] | None = None,
) -> None:
    """Drop entries from a store's ``_cache`` dict, if it has one.

    Args:
        store: Object whose ``_cache`` attribute is inspected. Nothing
            happens when the attribute is missing or is not a ``dict``.
        variants: Keys to evict. ``None`` (the default) empties the whole
            cache; keys that are not present are ignored.
    """
    cached = getattr(store, "_cache", None)
    if isinstance(cached, dict):
        if variants is None:
            cached.clear()
        else:
            for key in variants:
                cached.pop(key, None)
|