Jac-Zac committed on
Commit ·
9edffb7
1
Parent(s): fee1567
Updated to latest probing options
Browse files- Cleaned up repo
- Improved performance drastically by updating to the latest versions of
the libraries, plus less reloading, smarter caching, and prefetching
- .env.example +2 -0
- README.md +6 -0
- pyproject.toml +1 -1
- tabs/analysis/_shared.py +14 -17
- tabs/analysis/_state.py +5 -5
- tabs/analysis/cosine.py +11 -18
- tabs/analysis/dendrogram.py +8 -12
- tabs/analysis/layered.py +52 -34
- tabs/probe.py +150 -79
- tests/test_probes.py +32 -0
- utils/analysis_sources.py +94 -24
- utils/probes.py +14 -9
- uv.lock +4 -4
.env.example
CHANGED
|
@@ -23,3 +23,5 @@ ARTIFACTS_DIR=artifacts
|
|
| 23 |
# Keep model cache at 1 unless you have enough RAM for multiple loaded models.
|
| 24 |
# PERSONA_UI_MODEL_CACHE_ENTRIES=1
|
| 25 |
# PERSONA_UI_STORE_CACHE_ENTRIES=4
|
|
|
|
|
|
|
|
|
| 23 |
# Keep model cache at 1 unless you have enough RAM for multiple loaded models.
|
| 24 |
# PERSONA_UI_MODEL_CACHE_ENTRIES=1
|
| 25 |
# PERSONA_UI_STORE_CACHE_ENTRIES=4
|
| 26 |
+
# PERSONA_UI_VECTOR_CACHE_ENTRIES=4
|
| 27 |
+
# PERSONA_UI_PREPARED_CACHE_ENTRIES=8
|
README.md
CHANGED
|
@@ -116,6 +116,8 @@ NDIF_API_KEY=... # Required for remote (NDIF) model execution
|
|
| 116 |
HF_HOME=... # Optional: HuggingFace cache directory
|
| 117 |
ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default: ./artifacts)
|
| 118 |
PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
|
|
|
|
|
|
|
| 119 |
```
|
| 120 |
|
| 121 |
The app picks up this file automatically via `load_dotenv()` on startup.
|
|
@@ -148,3 +150,7 @@ the Analysis/Probing tab's Local source path) at the tree you want to load.
|
|
| 148 |
|
| 149 |
The store classes are `PersonaVectorStore` (local) and `HFPersonaVectorStore`
|
| 150 |
(Hub) — same API, both imported by `utils/analysis_sources.py`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
HF_HOME=... # Optional: HuggingFace cache directory
|
| 117 |
ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default: ./artifacts)
|
| 118 |
PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
|
| 119 |
+
PERSONA_UI_VECTOR_CACHE_ENTRIES=4 # Optional: loaded analysis datasets kept warm
|
| 120 |
+
PERSONA_UI_PREPARED_CACHE_ENTRIES=8 # Optional: prepared projections / k-means groups kept warm
|
| 121 |
```
|
| 122 |
|
| 123 |
The app picks up this file automatically via `load_dotenv()` on startup.
|
|
|
|
| 150 |
|
| 151 |
The store classes are `PersonaVectorStore` (local) and `HFPersonaVectorStore`
|
| 152 |
(Hub) — same API, both imported by `utils/analysis_sources.py`.
|
| 153 |
+
|
| 154 |
+
## Analysis responsiveness
|
| 155 |
+
|
| 156 |
+
The Analysis tab keeps a small bounded cache of loaded vector datasets and prepared projection data. Once a projection has been computed, recoloring it by persona, attribute, or k-means group reuses the same coordinates; nearby Hub interactions also keep metadata warm instead of re-scanning after every figure. Tune `PERSONA_UI_VECTOR_CACHE_ENTRIES` if RAM is tight or you regularly switch among many selections, and `PERSONA_UI_PREPARED_CACHE_ENTRIES` if you revisit several projection configurations in one session.
|
pyproject.toml
CHANGED
|
@@ -5,7 +5,7 @@ description = "Streamlit UI for persona-vectors"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
-
"persona-vectors>=0.8.
|
| 9 |
"datasets>=4.8.5",
|
| 10 |
"huggingface-hub>=1.14.0",
|
| 11 |
"streamlit>=1.44.0",
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
+
"persona-vectors>=0.8.3",
|
| 9 |
"datasets>=4.8.5",
|
| 10 |
"huggingface-hub>=1.14.0",
|
| 11 |
"streamlit>=1.44.0",
|
tabs/analysis/_shared.py
CHANGED
|
@@ -6,6 +6,19 @@ from persona_data.synth_persona import BASELINE_PERSONA_ID
|
|
| 6 |
from persona_vectors.extraction import MaskStrategy
|
| 7 |
from persona_vectors.plots import save_plot_html
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from utils.analysis_sources import (
|
| 10 |
Store,
|
| 11 |
available_variants,
|
|
@@ -13,7 +26,6 @@ from utils.analysis_sources import (
|
|
| 13 |
load_variant_vectors_cached,
|
| 14 |
persona_names_cached,
|
| 15 |
personas_cached,
|
| 16 |
-
release_hf_store_cache,
|
| 17 |
store_cache_parts,
|
| 18 |
store_id,
|
| 19 |
store_layers_cached,
|
|
@@ -22,20 +34,6 @@ from utils.controls import render_mask_strategy_select
|
|
| 22 |
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 23 |
from utils.theme import active_base, style_plotly_layer_controls
|
| 24 |
|
| 25 |
-
from tabs.analysis._state import (
|
| 26 |
-
_DEFAULT_LAYER_FRAMES,
|
| 27 |
-
_HIGHLIGHT_OTHER_COLOR,
|
| 28 |
-
_HIGHLIGHT_OTHER_LABEL,
|
| 29 |
-
_LAST_LAYER_FRAMES_KEY,
|
| 30 |
-
_LAST_MASK_STRATEGY_KEY,
|
| 31 |
-
PersonaOptions,
|
| 32 |
-
_is_assistant_persona,
|
| 33 |
-
_persona_names_state_key,
|
| 34 |
-
_personas_empty_message,
|
| 35 |
-
_remembered_selectbox,
|
| 36 |
-
_sequence_to_list,
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
|
| 40 |
def _gray_out_unselected_personas(fig: go.Figure) -> None:
|
| 41 |
def _gray_trace(trace: object) -> None:
|
|
@@ -118,8 +116,7 @@ def _load_variant_vectors(
|
|
| 118 |
)
|
| 119 |
|
| 120 |
|
| 121 |
-
def _release_vector_memory(
|
| 122 |
-
release_hf_store_cache(store, variants)
|
| 123 |
gc.collect()
|
| 124 |
|
| 125 |
|
|
|
|
| 6 |
from persona_vectors.extraction import MaskStrategy
|
| 7 |
from persona_vectors.plots import save_plot_html
|
| 8 |
|
| 9 |
+
from tabs.analysis._state import (
|
| 10 |
+
_DEFAULT_LAYER_FRAMES,
|
| 11 |
+
_HIGHLIGHT_OTHER_COLOR,
|
| 12 |
+
_HIGHLIGHT_OTHER_LABEL,
|
| 13 |
+
_LAST_LAYER_FRAMES_KEY,
|
| 14 |
+
_LAST_MASK_STRATEGY_KEY,
|
| 15 |
+
PersonaOptions,
|
| 16 |
+
_is_assistant_persona,
|
| 17 |
+
_persona_names_state_key,
|
| 18 |
+
_personas_empty_message,
|
| 19 |
+
_remembered_selectbox,
|
| 20 |
+
_sequence_to_list,
|
| 21 |
+
)
|
| 22 |
from utils.analysis_sources import (
|
| 23 |
Store,
|
| 24 |
available_variants,
|
|
|
|
| 26 |
load_variant_vectors_cached,
|
| 27 |
persona_names_cached,
|
| 28 |
personas_cached,
|
|
|
|
| 29 |
store_cache_parts,
|
| 30 |
store_id,
|
| 31 |
store_layers_cached,
|
|
|
|
| 34 |
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 35 |
from utils.theme import active_base, style_plotly_layer_controls
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def _gray_out_unselected_personas(fig: go.Figure) -> None:
|
| 39 |
def _gray_trace(trace: object) -> None:
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
|
| 119 |
+
def _release_vector_memory() -> None:
|
|
|
|
| 120 |
gc.collect()
|
| 121 |
|
| 122 |
|
tabs/analysis/_state.py
CHANGED
|
@@ -45,7 +45,7 @@ _CLUSTER_MODES = {
|
|
| 45 |
"First selected layer": "first_layer",
|
| 46 |
"Per layer": "per_layer",
|
| 47 |
}
|
| 48 |
-
_PROJECTION_COLOR_MODES = ["Persona", "K-means clusters"
|
| 49 |
_MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
|
| 50 |
|
| 51 |
|
|
@@ -87,7 +87,7 @@ class ProjectionColorConfig:
|
|
| 87 |
@dataclass(frozen=True)
|
| 88 |
class LayeredFigureStateKeys:
|
| 89 |
figure: str
|
| 90 |
-
|
| 91 |
|
| 92 |
|
| 93 |
_HIGHLIGHT_OTHER_LABEL = "Other"
|
|
@@ -139,7 +139,7 @@ _TRACKED_STATE_KEYS_KEY = "analysis:_tracked_state_keys"
|
|
| 139 |
|
| 140 |
|
| 141 |
def _clear_old_load_states(current_key: str, suffix: str) -> None:
|
| 142 |
-
# Only one heavy figure
|
| 143 |
# the keys we create per suffix so eviction is O(1) instead of scanning
|
| 144 |
# all of session_state on every rerun. Every such key is passed through
|
| 145 |
# this function before it is set, so the registry stays authoritative.
|
|
@@ -156,8 +156,8 @@ def _clear_old_figure_states(current_key: str) -> None:
|
|
| 156 |
_clear_old_load_states(current_key, "_fig_state")
|
| 157 |
|
| 158 |
|
| 159 |
-
def
|
| 160 |
-
_clear_old_load_states(current_key, "
|
| 161 |
|
| 162 |
|
| 163 |
def _store_figure_state(key: str, value: object) -> None:
|
|
|
|
| 45 |
"First selected layer": "first_layer",
|
| 46 |
"Per layer": "per_layer",
|
| 47 |
}
|
| 48 |
+
_PROJECTION_COLOR_MODES = ["Persona attribute", "Persona", "K-means clusters"]
|
| 49 |
_MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
|
| 50 |
|
| 51 |
|
|
|
|
| 87 |
@dataclass(frozen=True)
|
| 88 |
class LayeredFigureStateKeys:
|
| 89 |
figure: str
|
| 90 |
+
prepared: str | None = None
|
| 91 |
|
| 92 |
|
| 93 |
_HIGHLIGHT_OTHER_LABEL = "Other"
|
|
|
|
| 139 |
|
| 140 |
|
| 141 |
def _clear_old_load_states(current_key: str, suffix: str) -> None:
|
| 142 |
+
# Only one heavy figure state should live at a time. We track
|
| 143 |
# the keys we create per suffix so eviction is O(1) instead of scanning
|
| 144 |
# all of session_state on every rerun. Every such key is passed through
|
| 145 |
# this function before it is set, so the registry stays authoritative.
|
|
|
|
| 156 |
_clear_old_load_states(current_key, "_fig_state")
|
| 157 |
|
| 158 |
|
| 159 |
+
def _clear_old_prepared_states(current_key: str) -> None:
|
| 160 |
+
_clear_old_load_states(current_key, "_projection_ready")
|
| 161 |
|
| 162 |
|
| 163 |
def _store_figure_state(key: str, value: object) -> None:
|
tabs/analysis/cosine.py
CHANGED
|
@@ -78,22 +78,15 @@ def _build_cosine_figures(
|
|
| 78 |
mask_strategy: MaskStrategy,
|
| 79 |
selection: CosineSelection,
|
| 80 |
) -> tuple[object, object | None, int, int] | None:
|
| 81 |
-
variant_sample_cache: dict[str, object] = {}
|
| 82 |
-
|
| 83 |
-
def _load_variant(variant: str):
|
| 84 |
-
if variant not in variant_sample_cache:
|
| 85 |
-
samples = _load_variant_vectors(
|
| 86 |
-
store,
|
| 87 |
-
[variant],
|
| 88 |
-
mask_strategy,
|
| 89 |
-
persona_ids=selection.persona_ids,
|
| 90 |
-
)
|
| 91 |
-
variant_sample_cache[variant] = samples[variant]
|
| 92 |
-
return variant_sample_cache[variant]
|
| 93 |
-
|
| 94 |
try:
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
except Exception as exc:
|
| 98 |
st.error(f"Could not load vectors: {exc}")
|
| 99 |
return None
|
|
@@ -120,8 +113,8 @@ def _build_cosine_figures(
|
|
| 120 |
pair_errors = []
|
| 121 |
for left, right in combinations(selection.variants, 2):
|
| 122 |
try:
|
| 123 |
-
left_samples =
|
| 124 |
-
right_samples =
|
| 125 |
pair_traces.append(
|
| 126 |
(
|
| 127 |
f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
|
|
@@ -207,7 +200,7 @@ def _render_cosine_similarity(
|
|
| 207 |
_store_figure_state(cosine_fig_key, figures)
|
| 208 |
progress.progress(100, text="Done.")
|
| 209 |
finally:
|
| 210 |
-
_release_vector_memory(
|
| 211 |
progress.empty()
|
| 212 |
|
| 213 |
if cosine_fig_key in st.session_state:
|
|
|
|
| 78 |
mask_strategy: MaskStrategy,
|
| 79 |
selection: CosineSelection,
|
| 80 |
) -> tuple[object, object | None, int, int] | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
try:
|
| 82 |
+
by_variant = _load_variant_vectors(
|
| 83 |
+
store,
|
| 84 |
+
selection.variants,
|
| 85 |
+
mask_strategy,
|
| 86 |
+
persona_ids=selection.persona_ids,
|
| 87 |
+
)
|
| 88 |
+
samples_a = by_variant[selection.variant_a]
|
| 89 |
+
samples_b = by_variant[selection.variant_b]
|
| 90 |
except Exception as exc:
|
| 91 |
st.error(f"Could not load vectors: {exc}")
|
| 92 |
return None
|
|
|
|
| 113 |
pair_errors = []
|
| 114 |
for left, right in combinations(selection.variants, 2):
|
| 115 |
try:
|
| 116 |
+
left_samples = by_variant[left]
|
| 117 |
+
right_samples = by_variant[right]
|
| 118 |
pair_traces.append(
|
| 119 |
(
|
| 120 |
f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
|
|
|
|
| 200 |
_store_figure_state(cosine_fig_key, figures)
|
| 201 |
progress.progress(100, text="Done.")
|
| 202 |
finally:
|
| 203 |
+
_release_vector_memory()
|
| 204 |
progress.empty()
|
| 205 |
|
| 206 |
if cosine_fig_key in st.session_state:
|
tabs/analysis/dendrogram.py
CHANGED
|
@@ -13,7 +13,7 @@ from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
|
| 13 |
|
| 14 |
from tabs.analysis._shared import (
|
| 15 |
_load_persona_options,
|
| 16 |
-
|
| 17 |
_plotly_chart,
|
| 18 |
_release_vector_memory,
|
| 19 |
_render_layer_frame_controls,
|
|
@@ -204,13 +204,14 @@ def _render_dendrogram_analysis(
|
|
| 204 |
):
|
| 205 |
progress = st.progress(0, text="Loading first variant vectors…")
|
| 206 |
try:
|
| 207 |
-
progress.progress(15, text="Loading
|
| 208 |
-
|
| 209 |
store,
|
| 210 |
-
|
| 211 |
mask_strategy,
|
| 212 |
persona_ids,
|
| 213 |
)
|
|
|
|
| 214 |
progress.progress(40, text="Building first dendrogram…")
|
| 215 |
fig_a = plot_persona_dendrogram(
|
| 216 |
samples_a,
|
|
@@ -223,13 +224,8 @@ def _render_dendrogram_analysis(
|
|
| 223 |
del samples_a
|
| 224 |
fig_b = None
|
| 225 |
if variant_a != variant_b:
|
| 226 |
-
progress.progress(60, text="
|
| 227 |
-
samples_b =
|
| 228 |
-
store,
|
| 229 |
-
variant_b,
|
| 230 |
-
mask_strategy,
|
| 231 |
-
persona_ids,
|
| 232 |
-
)
|
| 233 |
progress.progress(75, text="Building second dendrogram…")
|
| 234 |
fig_b = plot_persona_dendrogram(
|
| 235 |
samples_b,
|
|
@@ -250,7 +246,7 @@ def _render_dendrogram_analysis(
|
|
| 250 |
st.error(f"Could not build dendrogram: {exc}")
|
| 251 |
st.session_state.pop(fig_key, None)
|
| 252 |
finally:
|
| 253 |
-
_release_vector_memory(
|
| 254 |
progress.empty()
|
| 255 |
|
| 256 |
if fig_key in st.session_state:
|
|
|
|
| 13 |
|
| 14 |
from tabs.analysis._shared import (
|
| 15 |
_load_persona_options,
|
| 16 |
+
_load_variant_vectors,
|
| 17 |
_plotly_chart,
|
| 18 |
_release_vector_memory,
|
| 19 |
_render_layer_frame_controls,
|
|
|
|
| 204 |
):
|
| 205 |
progress = st.progress(0, text="Loading first variant vectors…")
|
| 206 |
try:
|
| 207 |
+
progress.progress(15, text="Loading variant vectors…")
|
| 208 |
+
by_variant = _load_variant_vectors(
|
| 209 |
store,
|
| 210 |
+
shared_variants,
|
| 211 |
mask_strategy,
|
| 212 |
persona_ids,
|
| 213 |
)
|
| 214 |
+
samples_a = by_variant[variant_a]
|
| 215 |
progress.progress(40, text="Building first dendrogram…")
|
| 216 |
fig_a = plot_persona_dendrogram(
|
| 217 |
samples_a,
|
|
|
|
| 224 |
del samples_a
|
| 225 |
fig_b = None
|
| 226 |
if variant_a != variant_b:
|
| 227 |
+
progress.progress(60, text="Building second dendrogram…")
|
| 228 |
+
samples_b = by_variant[variant_b]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
progress.progress(75, text="Building second dendrogram…")
|
| 230 |
fig_b = plot_persona_dendrogram(
|
| 231 |
samples_b,
|
|
|
|
| 246 |
st.error(f"Could not build dendrogram: {exc}")
|
| 247 |
st.session_state.pop(fig_key, None)
|
| 248 |
finally:
|
| 249 |
+
_release_vector_memory()
|
| 250 |
progress.empty()
|
| 251 |
|
| 252 |
if fig_key in st.session_state:
|
tabs/analysis/layered.py
CHANGED
|
@@ -11,14 +11,19 @@ from persona_vectors.plots import (
|
|
| 11 |
build_layered_figure,
|
| 12 |
build_pair_similarity_figure,
|
| 13 |
build_similarity_figures,
|
| 14 |
-
prepare_layered_projection_data,
|
| 15 |
)
|
| 16 |
|
| 17 |
from utils.analysis_metadata import (
|
| 18 |
synth_persona_attribute_names,
|
| 19 |
synth_persona_dataset_cached,
|
| 20 |
)
|
| 21 |
-
from utils.analysis_sources import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 23 |
|
| 24 |
from tabs.analysis._shared import (
|
|
@@ -48,7 +53,7 @@ from tabs.analysis._state import (
|
|
| 48 |
LayeredFigureStateKeys,
|
| 49 |
ProjectionColorConfig,
|
| 50 |
_clear_old_figure_states,
|
| 51 |
-
|
| 52 |
_highlight_persona_groups,
|
| 53 |
_persona_display_label,
|
| 54 |
_persona_names_state_key,
|
|
@@ -116,7 +121,7 @@ def _render_projection_color_config(
|
|
| 116 |
key=color_mode_key,
|
| 117 |
remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
|
| 118 |
options=_PROJECTION_COLOR_MODES,
|
| 119 |
-
default="Persona",
|
| 120 |
)
|
| 121 |
if color_mode == "K-means clusters":
|
| 122 |
max_clusters = min(10, len(persona_ids))
|
|
@@ -265,36 +270,34 @@ def _layered_figure_state_keys(
|
|
| 265 |
)
|
| 266 |
if figure_kind not in _PROJECTION_KINDS:
|
| 267 |
return LayeredFigureStateKeys(figure=figure_key)
|
| 268 |
-
|
| 269 |
-
graph_overlay = figure_kind == "isomap"
|
| 270 |
-
projection_key = widget_key(
|
| 271 |
"load",
|
| 272 |
-
f"{scope}
|
| 273 |
store_id(store),
|
| 274 |
store.model_name,
|
| 275 |
mask_strategy.value,
|
| 276 |
figure_kind,
|
| 277 |
str(n_components),
|
| 278 |
-
str(
|
| 279 |
str(_DEFAULT_GRAPH_NEIGHBORS),
|
| 280 |
variant,
|
| 281 |
-
"persona_vector",
|
| 282 |
persona_key,
|
| 283 |
layer_key,
|
| 284 |
)
|
| 285 |
-
return LayeredFigureStateKeys(figure=figure_key,
|
| 286 |
|
| 287 |
|
| 288 |
def _projection_build_kwargs(
|
| 289 |
-
samples,
|
| 290 |
*,
|
|
|
|
|
|
|
|
|
|
| 291 |
figure_kind: str,
|
| 292 |
selected_layers: list[int],
|
| 293 |
n_components: int,
|
| 294 |
color_config: ProjectionColorConfig,
|
| 295 |
persona_ids: list[str],
|
| 296 |
persona_names: dict[str, str],
|
| 297 |
-
projection_key: str | None,
|
| 298 |
) -> dict:
|
| 299 |
if figure_kind not in _PROJECTION_KINDS:
|
| 300 |
return {}
|
|
@@ -305,22 +308,29 @@ def _projection_build_kwargs(
|
|
| 305 |
"graph_overlay": graph_overlay,
|
| 306 |
"graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
|
| 307 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
if color_config.n_clusters is not None:
|
| 309 |
-
build_kwargs["
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
projection_data = prepare_layered_projection_data(
|
| 315 |
-
samples,
|
| 316 |
-
figure_kind,
|
| 317 |
-
layers=selected_layers,
|
| 318 |
-
n_components=n_components,
|
| 319 |
-
graph_overlay=graph_overlay,
|
| 320 |
-
graph_n_neighbors=_DEFAULT_GRAPH_NEIGHBORS,
|
| 321 |
-
)
|
| 322 |
-
st.session_state[projection_key] = projection_data
|
| 323 |
-
build_kwargs["projection_data"] = projection_data
|
| 324 |
if color_config.attribute_name is not None:
|
| 325 |
build_kwargs.update(
|
| 326 |
attribute_color_kwargs(
|
|
@@ -487,8 +497,6 @@ def _render_layered_figure_analysis(
|
|
| 487 |
selected_layers=selected_layers,
|
| 488 |
pair_trajectories=pair_trajectories,
|
| 489 |
)
|
| 490 |
-
if state_keys.projection is not None:
|
| 491 |
-
_clear_old_projection_states(state_keys.projection)
|
| 492 |
filename = scope
|
| 493 |
_clear_old_figure_states(state_keys.figure)
|
| 494 |
persona_names = st.session_state.get(
|
|
@@ -496,7 +504,13 @@ def _render_layered_figure_analysis(
|
|
| 496 |
{},
|
| 497 |
)
|
| 498 |
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
build_label = {
|
| 501 |
"umap": "Computing UMAP projections…",
|
| 502 |
"pca": "Computing PCA projections…",
|
|
@@ -514,14 +528,15 @@ def _render_layered_figure_analysis(
|
|
| 514 |
)
|
| 515 |
progress.progress(55, text=build_label)
|
| 516 |
build_kwargs = _projection_build_kwargs(
|
| 517 |
-
|
|
|
|
|
|
|
| 518 |
figure_kind=figure_kind,
|
| 519 |
selected_layers=selected_layers,
|
| 520 |
n_components=n_components,
|
| 521 |
color_config=color_config,
|
| 522 |
persona_ids=persona_ids,
|
| 523 |
persona_names=persona_names,
|
| 524 |
-
projection_key=state_keys.projection,
|
| 525 |
)
|
| 526 |
main_fig, extra_fig = _build_layered_analysis_figures(
|
| 527 |
samples,
|
|
@@ -541,12 +556,15 @@ def _render_layered_figure_analysis(
|
|
| 541 |
n_samples = samples.vectors.shape[0]
|
| 542 |
del samples
|
| 543 |
_store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
|
|
|
|
|
|
|
|
|
|
| 544 |
progress.progress(100, text="Done.")
|
| 545 |
except Exception as exc:
|
| 546 |
st.error(f"Could not build figure: {exc}")
|
| 547 |
st.session_state.pop(state_keys.figure, None)
|
| 548 |
finally:
|
| 549 |
-
_release_vector_memory(
|
| 550 |
progress.empty()
|
| 551 |
|
| 552 |
if state_keys.figure in st.session_state:
|
|
|
|
| 11 |
build_layered_figure,
|
| 12 |
build_pair_similarity_figure,
|
| 13 |
build_similarity_figures,
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
from utils.analysis_metadata import (
|
| 17 |
synth_persona_attribute_names,
|
| 18 |
synth_persona_dataset_cached,
|
| 19 |
)
|
| 20 |
+
from utils.analysis_sources import (
|
| 21 |
+
Store,
|
| 22 |
+
kmeans_groups_cached,
|
| 23 |
+
projection_data_cached,
|
| 24 |
+
store_cache_parts,
|
| 25 |
+
store_id,
|
| 26 |
+
)
|
| 27 |
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 28 |
|
| 29 |
from tabs.analysis._shared import (
|
|
|
|
| 53 |
LayeredFigureStateKeys,
|
| 54 |
ProjectionColorConfig,
|
| 55 |
_clear_old_figure_states,
|
| 56 |
+
_clear_old_prepared_states,
|
| 57 |
_highlight_persona_groups,
|
| 58 |
_persona_display_label,
|
| 59 |
_persona_names_state_key,
|
|
|
|
| 121 |
key=color_mode_key,
|
| 122 |
remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
|
| 123 |
options=_PROJECTION_COLOR_MODES,
|
| 124 |
+
default="Persona attribute",
|
| 125 |
)
|
| 126 |
if color_mode == "K-means clusters":
|
| 127 |
max_clusters = min(10, len(persona_ids))
|
|
|
|
| 270 |
)
|
| 271 |
if figure_kind not in _PROJECTION_KINDS:
|
| 272 |
return LayeredFigureStateKeys(figure=figure_key)
|
| 273 |
+
prepared_key = widget_key(
|
|
|
|
|
|
|
| 274 |
"load",
|
| 275 |
+
f"{scope}_projection_ready",
|
| 276 |
store_id(store),
|
| 277 |
store.model_name,
|
| 278 |
mask_strategy.value,
|
| 279 |
figure_kind,
|
| 280 |
str(n_components),
|
| 281 |
+
str(figure_kind == "isomap"),
|
| 282 |
str(_DEFAULT_GRAPH_NEIGHBORS),
|
| 283 |
variant,
|
|
|
|
| 284 |
persona_key,
|
| 285 |
layer_key,
|
| 286 |
)
|
| 287 |
+
return LayeredFigureStateKeys(figure=figure_key, prepared=prepared_key)
|
| 288 |
|
| 289 |
|
| 290 |
def _projection_build_kwargs(
|
|
|
|
| 291 |
*,
|
| 292 |
+
store: Store,
|
| 293 |
+
mask_strategy: MaskStrategy,
|
| 294 |
+
variant: str,
|
| 295 |
figure_kind: str,
|
| 296 |
selected_layers: list[int],
|
| 297 |
n_components: int,
|
| 298 |
color_config: ProjectionColorConfig,
|
| 299 |
persona_ids: list[str],
|
| 300 |
persona_names: dict[str, str],
|
|
|
|
| 301 |
) -> dict:
|
| 302 |
if figure_kind not in _PROJECTION_KINDS:
|
| 303 |
return {}
|
|
|
|
| 308 |
"graph_overlay": graph_overlay,
|
| 309 |
"graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
|
| 310 |
}
|
| 311 |
+
source, location, model_name = store_cache_parts(store)
|
| 312 |
+
cache_args = (
|
| 313 |
+
source,
|
| 314 |
+
location,
|
| 315 |
+
model_name,
|
| 316 |
+
mask_strategy.value,
|
| 317 |
+
variant,
|
| 318 |
+
tuple(persona_ids),
|
| 319 |
+
tuple(selected_layers),
|
| 320 |
+
)
|
| 321 |
+
build_kwargs["projection_data"] = projection_data_cached(
|
| 322 |
+
*cache_args,
|
| 323 |
+
figure_kind,
|
| 324 |
+
n_components,
|
| 325 |
+
graph_overlay,
|
| 326 |
+
_DEFAULT_GRAPH_NEIGHBORS,
|
| 327 |
+
)
|
| 328 |
if color_config.n_clusters is not None:
|
| 329 |
+
build_kwargs["groups"] = kmeans_groups_cached(
|
| 330 |
+
*cache_args,
|
| 331 |
+
color_config.n_clusters,
|
| 332 |
+
color_config.cluster_mode or "mean_across_layers",
|
| 333 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
if color_config.attribute_name is not None:
|
| 335 |
build_kwargs.update(
|
| 336 |
attribute_color_kwargs(
|
|
|
|
| 497 |
selected_layers=selected_layers,
|
| 498 |
pair_trajectories=pair_trajectories,
|
| 499 |
)
|
|
|
|
|
|
|
| 500 |
filename = scope
|
| 501 |
_clear_old_figure_states(state_keys.figure)
|
| 502 |
persona_names = st.session_state.get(
|
|
|
|
| 504 |
{},
|
| 505 |
)
|
| 506 |
|
| 507 |
+
build_clicked = st.button(button_label, type="primary")
|
| 508 |
+
recolor_from_warm_projection = (
|
| 509 |
+
state_keys.prepared is not None
|
| 510 |
+
and bool(st.session_state.get(state_keys.prepared))
|
| 511 |
+
and state_keys.figure not in st.session_state
|
| 512 |
+
)
|
| 513 |
+
if build_clicked or recolor_from_warm_projection:
|
| 514 |
build_label = {
|
| 515 |
"umap": "Computing UMAP projections…",
|
| 516 |
"pca": "Computing PCA projections…",
|
|
|
|
| 528 |
)
|
| 529 |
progress.progress(55, text=build_label)
|
| 530 |
build_kwargs = _projection_build_kwargs(
|
| 531 |
+
store=store,
|
| 532 |
+
mask_strategy=mask_strategy,
|
| 533 |
+
variant=variant,
|
| 534 |
figure_kind=figure_kind,
|
| 535 |
selected_layers=selected_layers,
|
| 536 |
n_components=n_components,
|
| 537 |
color_config=color_config,
|
| 538 |
persona_ids=persona_ids,
|
| 539 |
persona_names=persona_names,
|
|
|
|
| 540 |
)
|
| 541 |
main_fig, extra_fig = _build_layered_analysis_figures(
|
| 542 |
samples,
|
|
|
|
| 556 |
n_samples = samples.vectors.shape[0]
|
| 557 |
del samples
|
| 558 |
_store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
|
| 559 |
+
if state_keys.prepared is not None:
|
| 560 |
+
_clear_old_prepared_states(state_keys.prepared)
|
| 561 |
+
st.session_state[state_keys.prepared] = True
|
| 562 |
progress.progress(100, text="Done.")
|
| 563 |
except Exception as exc:
|
| 564 |
st.error(f"Could not build figure: {exc}")
|
| 565 |
st.session_state.pop(state_keys.figure, None)
|
| 566 |
finally:
|
| 567 |
+
_release_vector_memory()
|
| 568 |
progress.empty()
|
| 569 |
|
| 570 |
if state_keys.figure in st.session_state:
|
tabs/probe.py
CHANGED
|
@@ -23,6 +23,7 @@ from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layer
|
|
| 23 |
from persona_vectors.probes import (
|
| 24 |
AttributeLabels,
|
| 25 |
attribute_probe_labels,
|
|
|
|
| 26 |
filter_attribute_samples_min_count,
|
| 27 |
infer_probe_task,
|
| 28 |
layer_matrix,
|
|
@@ -85,8 +86,9 @@ class _SweepInputs:
|
|
| 85 |
mask_value: str
|
| 86 |
variant: str
|
| 87 |
persona_ids: tuple[str, ...]
|
| 88 |
-
|
| 89 |
task: str
|
|
|
|
| 90 |
n_pca_components: int | None
|
| 91 |
layers: tuple[int, ...]
|
| 92 |
min_class_count: int
|
|
@@ -234,22 +236,62 @@ def _select_personas(
|
|
| 234 |
# ---------------------------------------------------------------------------
|
| 235 |
|
| 236 |
|
| 237 |
-
|
|
|
|
| 238 |
dataset = synth_persona_dataset_cached()
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
else:
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
| 246 |
options=options,
|
| 247 |
-
index=default_index,
|
| 248 |
format_func=lambda name: attribute_display_label(dataset, name),
|
| 249 |
-
key=
|
|
|
|
|
|
|
| 250 |
)
|
| 251 |
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
def _select_pca_components() -> int | None:
|
| 254 |
use_pca = st.toggle(
|
| 255 |
"Add PCA-compressed comparison",
|
|
@@ -298,61 +340,78 @@ def _select_layers(num_layers: int) -> list[int]:
|
|
| 298 |
@st.cache_resource(show_spinner=False)
|
| 299 |
def _cached_sweep(
|
| 300 |
inputs: _SweepInputs,
|
| 301 |
-
) -> tuple[
|
|
|
|
|
|
|
|
|
|
| 302 |
samples = load_persona_vectors_cached(
|
| 303 |
inputs.source, inputs.location, inputs.model_name,
|
| 304 |
inputs.mask_value, inputs.variant, inputs.persona_ids,
|
| 305 |
)
|
| 306 |
dataset = synth_persona_dataset_cached()
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
-
def _sweep(n_pca: int | None) -> list[dict[str, object]]:
|
|
|
|
| 315 |
return sweep_attribute(
|
| 316 |
probe_samples, labels,
|
| 317 |
layers=list(inputs.layers),
|
|
|
|
| 318 |
n_pca_components=n_pca,
|
| 319 |
seed=inputs.seed,
|
| 320 |
)
|
| 321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
if inputs.n_pca_components is not None:
|
| 323 |
# Always overlay the compressed sweep against full activations.
|
| 324 |
rows_by_label = {
|
| 325 |
-
"full":
|
| 326 |
-
f"pca{inputs.n_pca_components}":
|
| 327 |
}
|
| 328 |
else:
|
| 329 |
-
rows_by_label = {"full":
|
| 330 |
-
return rows_by_label,
|
| 331 |
|
| 332 |
|
| 333 |
def _show_sweep(
|
| 334 |
rows_by_label: dict[str, list[dict[str, object]]],
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
attribute: str,
|
| 338 |
task: str,
|
| 339 |
inputs: _SweepInputs,
|
| 340 |
) -> None:
|
| 341 |
primary = _PRIMARY_METRIC[task]
|
| 342 |
secondary = _SECONDARY_METRIC.get(task)
|
| 343 |
|
| 344 |
-
# Tolerate stale session state from a previous code version (bare rows).
|
| 345 |
-
if isinstance(rows_by_label, list):
|
| 346 |
-
rows_by_label = {"full": rows_by_label}
|
| 347 |
primary_label = (
|
| 348 |
f"pca{inputs.n_pca_components}" if inputs.n_pca_components else "full"
|
| 349 |
)
|
| 350 |
rows = rows_by_label.get(primary_label) or next(iter(rows_by_label.values()))
|
| 351 |
|
| 352 |
def _plot(metric: str):
|
| 353 |
-
if len(rows_by_label) > 1:
|
| 354 |
-
return plot_metric_comparison(
|
| 355 |
-
|
|
|
|
|
|
|
| 356 |
|
| 357 |
st.plotly_chart(_plot(primary), width="stretch")
|
| 358 |
if secondary is not None:
|
|
@@ -377,21 +436,31 @@ def _show_sweep(
|
|
| 377 |
if best is None:
|
| 378 |
return
|
| 379 |
|
| 380 |
-
|
|
|
|
| 381 |
summary_rows = []
|
| 382 |
for label, label_rows in rows_by_label.items():
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
if summary_rows:
|
| 396 |
st.dataframe(summary_rows, width="stretch", hide_index=True)
|
| 397 |
|
|
@@ -399,18 +468,26 @@ def _show_sweep(
|
|
| 399 |
f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
|
| 400 |
)
|
| 401 |
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
_render_selectivity_control(best, labels, samples, task, inputs)
|
| 413 |
-
_render_save_artifact(best, labels, samples,
|
| 414 |
|
| 415 |
|
| 416 |
def _render_selectivity_control(
|
|
@@ -461,7 +538,6 @@ def _render_save_artifact(
|
|
| 461 |
best: dict[str, object],
|
| 462 |
labels: AttributeLabels,
|
| 463 |
samples: LayeredSamples,
|
| 464 |
-
attribute: str,
|
| 465 |
task: str,
|
| 466 |
inputs: _SweepInputs,
|
| 467 |
) -> None:
|
|
@@ -540,12 +616,15 @@ def render_probing_tab() -> None:
|
|
| 540 |
if not persona_ids:
|
| 541 |
return
|
| 542 |
|
| 543 |
-
dataset = synth_persona_dataset_cached()
|
| 544 |
with st.expander("Probe configuration", expanded=True):
|
| 545 |
-
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
| 547 |
st.caption(f"Inferred task: **{task}**")
|
| 548 |
|
|
|
|
| 549 |
n_pca_components = _select_pca_components()
|
| 550 |
|
| 551 |
source, location, model_name = store_cache_parts(store)
|
|
@@ -563,17 +642,13 @@ def render_probing_tab() -> None:
|
|
| 563 |
num_layers = max(available_layers) + 1
|
| 564 |
layers = _select_layers(num_layers)
|
| 565 |
min_class_count = _MIN_CLASS_COUNT
|
| 566 |
-
seed =
|
| 567 |
-
"Seed", min_value=0, max_value=10_000, value=0, step=1,
|
| 568 |
-
key="probe:seed",
|
| 569 |
-
help="Seeds the probe/PCA fit. The 80/20 split itself is fixed "
|
| 570 |
-
"(random_state=0).",
|
| 571 |
-
)
|
| 572 |
|
| 573 |
inputs = _SweepInputs(
|
| 574 |
source=source, location=location, model_name=model_name,
|
| 575 |
mask_value=mask_strategy.value, variant=variant,
|
| 576 |
-
persona_ids=tuple(persona_ids),
|
|
|
|
| 577 |
n_pca_components=n_pca_components,
|
| 578 |
layers=tuple(layers), min_class_count=min_class_count,
|
| 579 |
seed=int(seed),
|
|
@@ -584,25 +659,21 @@ def render_probing_tab() -> None:
|
|
| 584 |
if run:
|
| 585 |
with st.spinner("Evaluating probes across layers..."):
|
| 586 |
try:
|
| 587 |
-
sweep,
|
| 588 |
except Exception as exc:
|
| 589 |
st.error(f"Sweep failed: {exc}")
|
| 590 |
st.session_state.pop(state_key, None)
|
| 591 |
return
|
| 592 |
-
st.session_state[state_key] = (
|
| 593 |
-
sweep,
|
| 594 |
-
labels,
|
| 595 |
-
probe_samples,
|
| 596 |
-
attribute,
|
| 597 |
-
task,
|
| 598 |
-
inputs,
|
| 599 |
-
)
|
| 600 |
|
| 601 |
if state_key in st.session_state:
|
| 602 |
saved_result = st.session_state[state_key]
|
| 603 |
-
if len(saved_result) =
|
| 604 |
-
|
| 605 |
-
|
| 606 |
else:
|
| 607 |
-
sweep,
|
| 608 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from persona_vectors.probes import (
|
| 24 |
AttributeLabels,
|
| 25 |
attribute_probe_labels,
|
| 26 |
+
default_probe_kinds,
|
| 27 |
filter_attribute_samples_min_count,
|
| 28 |
infer_probe_task,
|
| 29 |
layer_matrix,
|
|
|
|
| 86 |
mask_value: str
|
| 87 |
variant: str
|
| 88 |
persona_ids: tuple[str, ...]
|
| 89 |
+
attributes: tuple[str, ...]
|
| 90 |
task: str
|
| 91 |
+
probe_kinds: tuple[str, ...]
|
| 92 |
n_pca_components: int | None
|
| 93 |
layers: tuple[int, ...]
|
| 94 |
min_class_count: int
|
|
|
|
| 236 |
# ---------------------------------------------------------------------------
|
| 237 |
|
| 238 |
|
| 239 |
+
@st.cache_data(show_spinner=False)
def _attribute_tasks() -> dict[str, str]:
    """Return a mapping from each synth-persona attribute name to its probe task.

    The task for every name yielded by ``synth_persona_attribute_names()`` is
    inferred with ``infer_probe_task`` against the cached persona dataset.
    """
    persona_dataset = synth_persona_dataset_cached()
    task_by_attribute: dict[str, str] = {}
    for attribute_name in synth_persona_attribute_names():
        task_by_attribute[attribute_name] = infer_probe_task(
            persona_dataset, attribute_name
        )
    return task_by_attribute
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _select_attributes() -> list[str]:
    """Render the attribute multi-select, restricted to a single task type.

    The task of the first selected attribute decides which other attributes
    remain selectable; once the selection is emptied, every attribute is
    offered again. On first use the selection is seeded with ``"sex"`` when
    that attribute exists, otherwise with the first available attribute.
    """
    persona_dataset = synth_persona_dataset_cached()
    task_by_name = _attribute_tasks()
    attribute_names = list(synth_persona_attribute_names())

    state_key = "probe:attributes"
    if state_key not in st.session_state:
        if "sex" in attribute_names:
            st.session_state[state_key] = ["sex"]
        else:
            st.session_state[state_key] = attribute_names[:1]

    current = st.session_state[state_key]
    selectable = attribute_names
    if current:
        # The first selected attribute pins the task for the whole selection.
        locked_task = task_by_name[current[0]]
        selectable = [
            candidate
            for candidate in attribute_names
            if task_by_name[candidate] == locked_task
        ]

    def _display(name):
        return attribute_display_label(persona_dataset, name)

    return st.multiselect(
        "Attributes to probe",
        options=selectable,
        format_func=_display,
        key=state_key,
        help="Pick one or more attributes of the same task type. They are "
        "overlaid in one figure. Remove all to switch to a different task type.",
    )
|
| 277 |
|
| 278 |
|
| 279 |
+
def _select_probe_kinds(task: str) -> list[str]:
    """Choose the probe families to fit for ``task``.

    The multi-select is rendered only when the task offers at least two probe
    kinds. An empty selection falls back to every available kind.
    """
    kinds = list(default_probe_kinds(task))  # type: ignore[arg-type]
    if len(kinds) >= 2:
        chosen = st.multiselect(
            "Probe kinds to fit",
            options=kinds,
            default=kinds,
            key=f"probe:kinds:{task}",
            help="Which probe families to fit at each layer. Defaults to all "
            "available for this task.",
        )
        if chosen:
            return chosen
    return kinds
|
| 293 |
+
|
| 294 |
+
|
| 295 |
def _select_pca_components() -> int | None:
|
| 296 |
use_pca = st.toggle(
|
| 297 |
"Add PCA-compressed comparison",
|
|
|
|
| 340 |
@st.cache_resource(show_spinner=False)
def _cached_sweep(
    inputs: _SweepInputs,
) -> tuple[
    dict[str, list[dict[str, object]]],
    dict[str, tuple[AttributeLabels, LayeredSamples]],
]:
    """Sweep probes over layers for every attribute in ``inputs.attributes``.

    Returns a pair ``(rows_by_label, per_attr)``:

    * ``rows_by_label`` maps a feature label to the concatenated sweep rows of
      all attributes. ``"full"`` is always present; ``"pca<N>"`` is added when
      ``inputs.n_pca_components`` is set.
    * ``per_attr`` maps each attribute to its ``(labels, samples)`` pair after
      the min-count filter.
    """
    all_samples = load_persona_vectors_cached(
        inputs.source, inputs.location, inputs.model_name,
        inputs.mask_value, inputs.variant, inputs.persona_ids,
    )
    persona_dataset = synth_persona_dataset_cached()
    # Min-count filtering removes personas on a per-attribute basis, so every
    # attribute carries its own (labels, samples) pair; the selectivity and
    # save tools downstream read them from here.
    per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]] = {}

    def _prepared(attribute: str) -> tuple[AttributeLabels, LayeredSamples]:
        # Built lazily on first request, then reused by later sweeps.
        pair = per_attr.get(attribute)
        if pair is None:
            raw_labels = attribute_probe_labels(
                persona_dataset, attribute, list(inputs.persona_ids), task=inputs.task,  # type: ignore[arg-type]
            )
            kept_samples, kept_labels = filter_attribute_samples_min_count(
                all_samples, raw_labels, min_count=inputs.min_class_count
            )
            pair = (kept_labels, kept_samples)
            per_attr[attribute] = pair
        return pair

    def _sweep_all(n_pca: int | None) -> list[dict[str, object]]:
        collected: list[dict[str, object]] = []
        for attribute in inputs.attributes:
            attr_labels, attr_samples = _prepared(attribute)
            collected.extend(
                sweep_attribute(
                    attr_samples, attr_labels,
                    layers=list(inputs.layers),
                    probe_kinds=list(inputs.probe_kinds),  # type: ignore[arg-type]
                    n_pca_components=n_pca,
                    seed=inputs.seed,
                )
            )
        return collected

    rows_by_label = {"full": _sweep_all(None)}
    n_components = inputs.n_pca_components
    if n_components is not None:
        # The compressed sweep is always shown alongside the full activations.
        rows_by_label[f"pca{n_components}"] = _sweep_all(n_components)
    return rows_by_label, per_attr
|
| 392 |
|
| 393 |
|
| 394 |
def _show_sweep(
|
| 395 |
rows_by_label: dict[str, list[dict[str, object]]],
|
| 396 |
+
per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
|
| 397 |
+
attributes: tuple[str, ...],
|
|
|
|
| 398 |
task: str,
|
| 399 |
inputs: _SweepInputs,
|
| 400 |
) -> None:
|
| 401 |
primary = _PRIMARY_METRIC[task]
|
| 402 |
secondary = _SECONDARY_METRIC.get(task)
|
| 403 |
|
|
|
|
|
|
|
|
|
|
| 404 |
primary_label = (
|
| 405 |
f"pca{inputs.n_pca_components}" if inputs.n_pca_components else "full"
|
| 406 |
)
|
| 407 |
rows = rows_by_label.get(primary_label) or next(iter(rows_by_label.values()))
|
| 408 |
|
| 409 |
def _plot(metric: str):
|
| 410 |
+
if len(rows_by_label) > 1 or len(attributes) > 1:
|
| 411 |
+
return plot_metric_comparison(
|
| 412 |
+
rows_by_label, list(attributes), metric=metric
|
| 413 |
+
)
|
| 414 |
+
return plot_metric_over_layers(rows, attributes[0], metric=metric)
|
| 415 |
|
| 416 |
st.plotly_chart(_plot(primary), width="stretch")
|
| 417 |
if secondary is not None:
|
|
|
|
| 436 |
if best is None:
|
| 437 |
return
|
| 438 |
|
| 439 |
+
multi_attr = len(attributes) > 1
|
| 440 |
+
if len(rows_by_label) > 1 or multi_attr:
|
| 441 |
summary_rows = []
|
| 442 |
for label, label_rows in rows_by_label.items():
|
| 443 |
+
for attribute in attributes:
|
| 444 |
+
attr_rows = [
|
| 445 |
+
row for row in label_rows
|
| 446 |
+
if row.get("attribute") == attribute
|
| 447 |
+
]
|
| 448 |
+
label_best = _best_row(attr_rows)
|
| 449 |
+
if label_best is None:
|
| 450 |
+
continue
|
| 451 |
+
summary_row: dict[str, object] = {}
|
| 452 |
+
if multi_attr:
|
| 453 |
+
summary_row["attribute"] = attribute
|
| 454 |
+
summary_row.update({
|
| 455 |
+
"features": label,
|
| 456 |
+
"best_layer": label_best["layer"],
|
| 457 |
+
"probe": label_best["probe_kind"],
|
| 458 |
+
primary: round(float(label_best[primary]), 3),
|
| 459 |
+
f"baseline_{primary}": round(
|
| 460 |
+
float(label_best.get(f"baseline_{primary}", float("nan"))), 3
|
| 461 |
+
),
|
| 462 |
+
})
|
| 463 |
+
summary_rows.append(summary_row)
|
| 464 |
if summary_rows:
|
| 465 |
st.dataframe(summary_rows, width="stretch", hide_index=True)
|
| 466 |
|
|
|
|
| 468 |
f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
|
| 469 |
)
|
| 470 |
|
| 471 |
+
best_attr = str(best["attribute"])
|
| 472 |
+
labels, samples = per_attr[best_attr]
|
| 473 |
+
if multi_attr:
|
| 474 |
+
# The per-attribute summary table above already covers every result;
|
| 475 |
+
# a single "best" card would only show one attribute, so skip it and
|
| 476 |
+
# just say which one the controls below operate on.
|
| 477 |
+
st.caption(f"Controls below use the best result: **{best_attr}**.")
|
| 478 |
+
else:
|
| 479 |
+
cols = st.columns([1, 1.2, 1.8])
|
| 480 |
+
cols[0].metric("Best layer", best["layer"])
|
| 481 |
+
cols[1].metric(
|
| 482 |
+
f"Best {primary}",
|
| 483 |
+
f"{best[primary]:.3f}",
|
| 484 |
+
delta=f"baseline {best.get(f'baseline_{primary}', float('nan')):.3f}",
|
| 485 |
+
delta_color="off",
|
| 486 |
+
)
|
| 487 |
+
cols[2].metric("Probe", f"{best['probe_kind']}{feature_desc}")
|
| 488 |
|
| 489 |
_render_selectivity_control(best, labels, samples, task, inputs)
|
| 490 |
+
_render_save_artifact(best, labels, samples, task, inputs)
|
| 491 |
|
| 492 |
|
| 493 |
def _render_selectivity_control(
|
|
|
|
| 538 |
best: dict[str, object],
|
| 539 |
labels: AttributeLabels,
|
| 540 |
samples: LayeredSamples,
|
|
|
|
| 541 |
task: str,
|
| 542 |
inputs: _SweepInputs,
|
| 543 |
) -> None:
|
|
|
|
| 616 |
if not persona_ids:
|
| 617 |
return
|
| 618 |
|
|
|
|
| 619 |
with st.expander("Probe configuration", expanded=True):
|
| 620 |
+
attributes = _select_attributes()
|
| 621 |
+
if not attributes:
|
| 622 |
+
st.info("Select at least one attribute to probe.")
|
| 623 |
+
return
|
| 624 |
+
task = _attribute_tasks()[attributes[0]]
|
| 625 |
st.caption(f"Inferred task: **{task}**")
|
| 626 |
|
| 627 |
+
probe_kinds = _select_probe_kinds(task)
|
| 628 |
n_pca_components = _select_pca_components()
|
| 629 |
|
| 630 |
source, location, model_name = store_cache_parts(store)
|
|
|
|
| 642 |
num_layers = max(available_layers) + 1
|
| 643 |
layers = _select_layers(num_layers)
|
| 644 |
min_class_count = _MIN_CLASS_COUNT
|
| 645 |
+
seed = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
|
| 647 |
inputs = _SweepInputs(
|
| 648 |
source=source, location=location, model_name=model_name,
|
| 649 |
mask_value=mask_strategy.value, variant=variant,
|
| 650 |
+
persona_ids=tuple(persona_ids), attributes=tuple(attributes), task=task,
|
| 651 |
+
probe_kinds=tuple(probe_kinds),
|
| 652 |
n_pca_components=n_pca_components,
|
| 653 |
layers=tuple(layers), min_class_count=min_class_count,
|
| 654 |
seed=int(seed),
|
|
|
|
| 659 |
if run:
|
| 660 |
with st.spinner("Evaluating probes across layers..."):
|
| 661 |
try:
|
| 662 |
+
sweep, per_attr = _cached_sweep(inputs)
|
| 663 |
except Exception as exc:
|
| 664 |
st.error(f"Sweep failed: {exc}")
|
| 665 |
st.session_state.pop(state_key, None)
|
| 666 |
return
|
| 667 |
+
st.session_state[state_key] = (sweep, per_attr, inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
|
| 669 |
if state_key in st.session_state:
|
| 670 |
saved_result = st.session_state[state_key]
|
| 671 |
+
if len(saved_result) != 3:
|
| 672 |
+
# Stale shape from a previous code version — drop it.
|
| 673 |
+
st.session_state.pop(state_key, None)
|
| 674 |
else:
|
| 675 |
+
sweep, per_attr, result_inputs = saved_result
|
| 676 |
+
_show_sweep(
|
| 677 |
+
sweep, per_attr, result_inputs.attributes,
|
| 678 |
+
result_inputs.task, result_inputs,
|
| 679 |
+
)
|
tests/test_probes.py
CHANGED
|
@@ -12,9 +12,11 @@ two correctness fixes:
|
|
| 12 |
import pytest
|
| 13 |
import torch
|
| 14 |
|
|
|
|
| 15 |
from utils.probes import (
|
| 16 |
LoadedProbe,
|
| 17 |
_LinearProbe,
|
|
|
|
| 18 |
_normalize_labels,
|
| 19 |
parse_probe_filename,
|
| 20 |
)
|
|
@@ -196,3 +198,33 @@ def test_run_single_output_predicts_negative_when_score_low():
|
|
| 196 |
result = probe.run(torch.tensor([1.0, 1.0]))
|
| 197 |
assert result.predicted_index == 0
|
| 198 |
assert result.predicted_label == "neg"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
import pytest
|
| 13 |
import torch
|
| 14 |
|
| 15 |
+
from persona_vectors.probes import ProbeArtifact
|
| 16 |
from utils.probes import (
|
| 17 |
LoadedProbe,
|
| 18 |
_LinearProbe,
|
| 19 |
+
_loaded_probe_from_artifact,
|
| 20 |
_normalize_labels,
|
| 21 |
parse_probe_filename,
|
| 22 |
)
|
|
|
|
| 198 |
result = probe.run(torch.tensor([1.0, 1.0]))
|
| 199 |
assert result.predicted_index == 0
|
| 200 |
assert result.predicted_label == "neg"
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# --------------------------------------------------------------------------- #
|
| 204 |
+
# canonical persona-vectors artifacts
|
| 205 |
+
# --------------------------------------------------------------------------- #
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def test_loaded_probe_from_canonical_artifact():
    """A canonical two-class artifact yields a probe with its labels and layer."""
    metadata = {
        "schema_version": 2,
        "input_dim": 2,
        "artifact_feature_dim": 2,
        "class_names": ["neg", "pos"],
        "task": "binary",
        "probe_kind": "logistic_regression",
        "layer": 3,
    }
    tensors = {
        "weight": torch.tensor([[-1.0, 0.0], [1.0, 0.0]]),
        "bias": torch.zeros(2),
    }
    canonical = ProbeArtifact(metadata=metadata, tensors=tensors)

    loaded = _loaded_probe_from_artifact(
        filename="m/answer_mean/templated/sex/logistic_regression_layer3/probe.json",
        artifact=canonical,
    )

    assert loaded.labels == ["neg", "pos"]
    assert loaded.layer == 3
    # The second weight row scores [1, 0] highest, so "pos" must win.
    outcome = loaded.run(torch.tensor([1.0, 0.0]))
    assert outcome.predicted_label == "pos"
|
utils/analysis_sources.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
-
from persona_vectors.analysis import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from persona_vectors.artifacts import (
|
| 6 |
PersonaVectorStore,
|
| 7 |
HFPersonaVectorStore,
|
|
@@ -10,6 +14,11 @@ from persona_vectors.artifacts import (
|
|
| 10 |
)
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.hub import list_hub_vector_models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
from utils.helpers import env_int
|
| 15 |
|
|
@@ -26,7 +35,8 @@ SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
|
| 26 |
|
| 27 |
|
| 28 |
_STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
|
| 29 |
-
_VECTOR_CACHE_ENTRIES = env_int("PERSONA_UI_VECTOR_CACHE_ENTRIES",
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
@st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
|
|
@@ -137,23 +147,41 @@ def local_model_matches(left: str, right: str) -> bool:
|
|
| 137 |
|
| 138 |
|
| 139 |
@st.cache_resource(show_spinner=False, max_entries=_VECTOR_CACHE_ENTRIES)
|
| 140 |
-
def
|
| 141 |
source: str,
|
| 142 |
location: str,
|
| 143 |
model_name: str,
|
| 144 |
mask_strategy_value: str,
|
| 145 |
-
|
| 146 |
persona_ids: tuple[str, ...],
|
| 147 |
-
) ->
|
| 148 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 149 |
-
return
|
| 150 |
store,
|
| 151 |
-
|
| 152 |
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 153 |
-
persona_ids=
|
| 154 |
)
|
| 155 |
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
def load_variant_vectors_cached(
|
| 158 |
source: str,
|
| 159 |
location: str,
|
|
@@ -162,12 +190,64 @@ def load_variant_vectors_cached(
|
|
| 162 |
variants: tuple[str, ...],
|
| 163 |
persona_ids: tuple[str, ...],
|
| 164 |
) -> dict[str, LayeredSamples]:
|
| 165 |
-
return
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
|
| 173 |
def prefetch_hub_metadata(
|
|
@@ -194,13 +274,3 @@ def prefetch_hub_metadata(
|
|
| 194 |
mask_strategy_value,
|
| 195 |
(variant,),
|
| 196 |
)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
def release_hf_store_cache(
    store: Store,
    variants: list[str] | tuple[str, ...] | None = None,
) -> None:
    """Release cached HF data for ``variants`` (or everything) on Hub stores.

    Does nothing unless ``store`` is an ``HFPersonaVectorStore`` exposing a
    callable ``release_cache`` attribute.
    """
    hook = getattr(store, "release_cache", None)
    if not isinstance(store, HFPersonaVectorStore):
        return
    if not callable(hook):
        return
    hook(variants)
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
+
from persona_vectors.analysis import (
|
| 5 |
+
AnalysisDataset,
|
| 6 |
+
LayeredSamples,
|
| 7 |
+
load_analysis_dataset,
|
| 8 |
+
)
|
| 9 |
from persona_vectors.artifacts import (
|
| 10 |
PersonaVectorStore,
|
| 11 |
HFPersonaVectorStore,
|
|
|
|
| 14 |
)
|
| 15 |
from persona_vectors.extraction import MaskStrategy
|
| 16 |
from persona_vectors.hub import list_hub_vector_models
|
| 17 |
+
from persona_vectors.plots import (
|
| 18 |
+
LayeredProjectionData,
|
| 19 |
+
prepare_kmeans_groups,
|
| 20 |
+
prepare_layered_projection_data,
|
| 21 |
+
)
|
| 22 |
|
| 23 |
from utils.helpers import env_int
|
| 24 |
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
_STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
|
| 38 |
+
_VECTOR_CACHE_ENTRIES = env_int("PERSONA_UI_VECTOR_CACHE_ENTRIES", 4)
|
| 39 |
+
_PREPARED_CACHE_ENTRIES = env_int("PERSONA_UI_PREPARED_CACHE_ENTRIES", 8)
|
| 40 |
|
| 41 |
|
| 42 |
@st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
@st.cache_resource(show_spinner=False, max_entries=_VECTOR_CACHE_ENTRIES)
def load_analysis_dataset_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> AnalysisDataset:
    """Load the analysis dataset for ``variants`` and ``persona_ids``.

    The activation store is fetched through ``activation_store_cached`` and
    the mask strategy is rebuilt from its string value before loading.
    """
    activation_store = activation_store_cached(
        source, location, model_name, mask_strategy_value
    )
    strategy = MaskStrategy(mask_strategy_value)
    return load_analysis_dataset(
        activation_store,
        variants,
        mask_strategy=strategy,
        persona_ids=persona_ids,
    )
|
| 165 |
|
| 166 |
|
| 167 |
+
def load_persona_vectors_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
) -> LayeredSamples:
    """Return the layered samples of a single ``variant``.

    Delegates to ``load_analysis_dataset_cached`` with a one-element variants
    tuple and extracts that variant's samples from the result.
    """
    single_variant = (variant,)
    analysis_dataset = load_analysis_dataset_cached(
        source,
        location,
        model_name,
        mask_strategy_value,
        single_variant,
        persona_ids,
    )
    return analysis_dataset.samples(variant)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
def load_variant_vectors_cached(
|
| 186 |
source: str,
|
| 187 |
location: str,
|
|
|
|
| 190 |
variants: tuple[str, ...],
|
| 191 |
persona_ids: tuple[str, ...],
|
| 192 |
) -> dict[str, LayeredSamples]:
|
| 193 |
+
return load_analysis_dataset_cached(
|
| 194 |
+
source,
|
| 195 |
+
location,
|
| 196 |
+
model_name,
|
| 197 |
+
mask_strategy_value,
|
| 198 |
+
variants,
|
| 199 |
+
persona_ids,
|
| 200 |
+
).samples_by_variant
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
def projection_data_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
    layers: tuple[int, ...],
    kind: str,
    n_components: int,
    graph_overlay: bool,
    graph_n_neighbors: int,
) -> LayeredProjectionData:
    """Prepare layered projection data for one variant.

    Samples come from ``load_persona_vectors_cached``; the projection itself
    is produced by ``prepare_layered_projection_data``.
    """
    layered_samples = load_persona_vectors_cached(
        source, location, model_name, mask_strategy_value, variant, persona_ids
    )
    selected_layers = list(layers)
    return prepare_layered_projection_data(
        layered_samples,
        kind,
        layers=selected_layers,
        n_components=n_components,
        graph_overlay=graph_overlay,
        graph_n_neighbors=graph_n_neighbors,
    )
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
def kmeans_groups_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
    layers: tuple[int, ...],
    n_clusters: int,
    cluster_mode: str,
) -> list[str] | dict[int, list[str]]:
    """Compute k-means groupings for one variant.

    Samples come from ``load_persona_vectors_cached``; the grouping is
    delegated to ``prepare_kmeans_groups``.
    """
    layered_samples = load_persona_vectors_cached(
        source, location, model_name, mask_strategy_value, variant, persona_ids
    )
    selected_layers = list(layers)
    return prepare_kmeans_groups(
        layered_samples,
        layers=selected_layers,
        n_clusters=n_clusters,
        cluster_mode=cluster_mode,
    )
|
| 251 |
|
| 252 |
|
| 253 |
def prefetch_hub_metadata(
|
|
|
|
| 274 |
mask_strategy_value,
|
| 275 |
(variant,),
|
| 276 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/probes.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import io
|
| 4 |
-
import json
|
| 5 |
import os
|
| 6 |
import re
|
| 7 |
from dataclasses import dataclass
|
|
@@ -12,6 +11,7 @@ import streamlit as st
|
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
| 14 |
import torch.nn.functional as F
|
|
|
|
| 15 |
|
| 16 |
PROBE_FILENAME_RE = re.compile(
|
| 17 |
r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
|
|
@@ -457,14 +457,19 @@ def _load_persona_probe_artifact(
|
|
| 457 |
metadata_path: Path,
|
| 458 |
weights_path: Path,
|
| 459 |
) -> LoadedProbe:
|
| 460 |
-
if
|
| 461 |
-
raise
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
payload = {
|
| 469 |
**metadata,
|
| 470 |
"model_type": "linear",
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import io
|
|
|
|
| 4 |
import os
|
| 5 |
import re
|
| 6 |
from dataclasses import dataclass
|
|
|
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
| 13 |
import torch.nn.functional as F
|
| 14 |
+
from persona_vectors.probes import ProbeArtifact, load_probe_artifact
|
| 15 |
|
| 16 |
PROBE_FILENAME_RE = re.compile(
|
| 17 |
r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
|
|
|
|
| 457 |
metadata_path: Path,
|
| 458 |
weights_path: Path,
|
| 459 |
) -> LoadedProbe:
|
| 460 |
+
if metadata_path.parent != weights_path.parent:
|
| 461 |
+
raise ValueError("Canonical probe files must share one artifact directory.")
|
| 462 |
+
artifact = load_probe_artifact(metadata_path)
|
| 463 |
+
return _loaded_probe_from_artifact(filename=filename, artifact=artifact)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _loaded_probe_from_artifact(
|
| 467 |
+
*,
|
| 468 |
+
filename: str,
|
| 469 |
+
artifact: ProbeArtifact,
|
| 470 |
+
) -> LoadedProbe:
|
| 471 |
+
metadata = artifact.metadata
|
| 472 |
+
tensors = artifact.tensors
|
| 473 |
payload = {
|
| 474 |
**metadata,
|
| 475 |
"model_type": "linear",
|
uv.lock
CHANGED
|
@@ -1608,7 +1608,7 @@ requires-dist = [
|
|
| 1608 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
| 1609 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1610 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1611 |
-
{ name = "persona-vectors", specifier = ">=0.8.
|
| 1612 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1613 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1614 |
{ name = "safetensors", specifier = ">=0.7.0" },
|
|
@@ -1620,7 +1620,7 @@ dev = [{ name = "pytest", specifier = ">=9.0.3" }]
|
|
| 1620 |
|
| 1621 |
[[package]]
|
| 1622 |
name = "persona-vectors"
|
| 1623 |
-
version = "0.8.
|
| 1624 |
source = { registry = "https://pypi.org/simple" }
|
| 1625 |
dependencies = [
|
| 1626 |
{ name = "datasets" },
|
|
@@ -1639,9 +1639,9 @@ dependencies = [
|
|
| 1639 |
{ name = "transformers" },
|
| 1640 |
{ name = "umap-learn" },
|
| 1641 |
]
|
| 1642 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1643 |
wheels = [
|
| 1644 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1645 |
]
|
| 1646 |
|
| 1647 |
[[package]]
|
|
|
|
| 1608 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
| 1609 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1610 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1611 |
+
{ name = "persona-vectors", specifier = ">=0.8.3" },
|
| 1612 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1613 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1614 |
{ name = "safetensors", specifier = ">=0.7.0" },
|
|
|
|
| 1620 |
|
| 1621 |
[[package]]
|
| 1622 |
name = "persona-vectors"
|
| 1623 |
+
version = "0.8.3"
|
| 1624 |
source = { registry = "https://pypi.org/simple" }
|
| 1625 |
dependencies = [
|
| 1626 |
{ name = "datasets" },
|
|
|
|
| 1639 |
{ name = "transformers" },
|
| 1640 |
{ name = "umap-learn" },
|
| 1641 |
]
|
| 1642 |
+
sdist = { url = "https://files.pythonhosted.org/packages/c0/1d/472284f43e2a276a035e9e3de08a92654945193699598def6d6a2aa74c96/persona_vectors-0.8.3.tar.gz", hash = "sha256:f0519846b3712865bd2562cd239df05ddd006ac3d1e73e5ec5a6c860aaed5b2e", size = 43146, upload-time = "2026-05-17T12:43:13.601Z" }
|
| 1643 |
wheels = [
|
| 1644 |
+
{ url = "https://files.pythonhosted.org/packages/60/d1/a38dc354718310122cd5d3de63e3aa9060490c8db4c2eadb1d4985684796/persona_vectors-0.8.3-py3-none-any.whl", hash = "sha256:2feeaf45b071ed417d88add48a1012455c8027e4f839e99658a9808c26786b8a", size = 53129, upload-time = "2026-05-17T12:43:12.693Z" },
|
| 1645 |
]
|
| 1646 |
|
| 1647 |
[[package]]
|