Jac-Zac committed on
Commit
330d092
·
1 Parent(s): 7c332a2

Updated to latest persona-vector and loading from HF

Files changed (7)
  1. README.md +14 -12
  2. pyproject.toml +2 -2
  3. tabs/chat_ui.py +1 -2
  4. tabs/compare.py +118 -57
  5. tabs/extract.py +2 -7
  6. utils/helpers.py +3 -3
  7. uv.lock +5 -5
README.md CHANGED
@@ -20,7 +20,7 @@ Streamlit interface for persona vector extraction, analysis, and chat.
20
  A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
21
 
22
  - **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
23
- - **Compare** — load saved activations and explore layer-wise cosine similarity, persona-mean PCA, UMAP, and similarity projections
24
  - **Extract** — run activation extraction from HuggingFace persona datasets or a local JSONL dataset directly from the browser
25
 
26
  ## Repository Layout
@@ -29,19 +29,20 @@ A web app built on top of [persona-vectors](../persona-vectors) that provides th
29
  persona-ui/
30
  ├── app.py # Main entry point (Streamlit)
31
  ├── state.py # Session state management (chat history, KV cache)
32
- ├── scripts/
33
- │ └── oracle_probe.py # Notebook-style activation oracle script
34
  ├── tabs/
35
  │ ├── chat.py # Chat tab
36
  │ ├── compare.py # Activation comparison tab
37
  │ ├── compare_chat.py # Side-by-side chat comparison mode
38
- │ └── extract.py # Extraction tab
 
39
  └── utils/
40
  ├── chat.py # Chat generation logic
41
  ├── chat_export.py # Export chat logs to JSON
42
  ├── contrast.py # Contrastive token log-prob coloring
43
  ├── datasets.py # Dataset loader wrapper
44
  ├── helpers.py # UI labels and slug helpers
 
 
45
  └── runtime.py # Model caching and NDIF queries
46
  ```
47
 
@@ -58,13 +59,11 @@ cp .env.example .env
58
 
59
  ## Local Development
60
 
61
- The committed dependency graph uses git sources so `persona-ui` can install cleanly in a Hugging Face Space or any isolated environment.
 
 
62
 
63
- For local sibling checkouts, uncomment the `path` sources in `persona-ui/pyproject.toml` and `persona-vectors/pyproject.toml`, then comment out the git sources.
64
-
65
- ## Local Setup Note
66
-
67
- For local development, `persona-data` and `persona-vectors` can still be checked out in the parent directory of `persona-ui`.
68
 
69
  Example:
70
 
@@ -112,13 +111,16 @@ Copy `.env.example` to `.env` and fill in:
112
  NDIF_API_KEY=... # Required for remote (NDIF) model execution
113
  HF_HOME=... # Optional: HuggingFace cache directory
114
  ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
 
115
  ```
116
 
117
  The app picks up this file automatically via `load_dotenv()` on startup.
118
 
119
- ## Saved Artifacts
120
 
121
- The Compare and Extract tabs read from / write to:
 
 
122
 
123
  ```
124
  artifacts/
 
20
  A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
21
 
22
  - **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
23
+ - **Compare** — load local or Hub persona vectors and explore cosine similarity, PCA, UMAP, and similarity views
24
  - **Extract** — run activation extraction from HuggingFace persona datasets or a local JSONL dataset directly from the browser
25
 
26
  ## Repository Layout
 
29
  persona-ui/
30
  ├── app.py # Main entry point (Streamlit)
31
  ├── state.py # Session state management (chat history, KV cache)
 
 
32
  ├── tabs/
33
  │ ├── chat.py # Chat tab
34
  │ ├── compare.py # Activation comparison tab
35
  │ ├── compare_chat.py # Side-by-side chat comparison mode
36
+ │ ├── extract.py # Extraction tab
37
+ │ └── probe_ui.py # Probe upload and tracing controls
38
  └── utils/
39
  ├── chat.py # Chat generation logic
40
  ├── chat_export.py # Export chat logs to JSON
41
  ├── contrast.py # Contrastive token log-prob coloring
42
  ├── datasets.py # Dataset loader wrapper
43
  ├── helpers.py # UI labels and slug helpers
44
+ ├── probe_trace.py # Chat-token activation tracing
45
+ ├── probes.py # Probe loading and scoring
46
  └── runtime.py # Model caching and NDIF queries
47
  ```
48
 
 
59
 
60
  ## Local Development
61
 
62
+ This checkout is configured to use the sibling `../persona-vectors` package as
63
+ an editable dependency. For deployment, switch `persona-vectors` back to the
64
+ published package or another installable source.
65
 
66
+ `persona-data` can also be checked out next to this repo for local package work.
 
 
 
 
67
 
68
  Example:
69
 
 
111
  NDIF_API_KEY=... # Required for remote (NDIF) model execution
112
  HF_HOME=... # Optional: HuggingFace cache directory
113
  ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
114
+ PERSONA_VECTORS_HUB_REPO=... # Optional: default Compare-tab Hub dataset repo
115
  ```
116
 
117
  The app picks up this file automatically via `load_dotenv()` on startup.
118
 
119
+ ## Persona Vectors
120
 
121
+ The Compare tab reads persona vectors from either a Hugging Face dataset created
122
+ by `persona-vectors/scripts/push_to_hf.py` or from local artifacts. The Extract
123
+ tab writes local artifacts to:
124
 
125
  ```
126
  artifacts/
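Reviewer note on `PERSONA_VECTORS_HUB_REPO`: in this commit `tabs/compare.py` reads the variable with `os.environ.get` and falls back to a hard-coded repo id. The sketch below illustrates that lookup in isolation. `resolve_hub_repo` and `FALLBACK_HUB_REPO` are hypothetical names used only for illustration; the fallback string itself is copied from the diff.

```python
import os

# Fallback repo id as it appears in tabs/compare.py in this commit.
FALLBACK_HUB_REPO = "implicit-personalization/synth-persona-vectors"


def resolve_hub_repo(environ=None):
    """Return the Hub dataset repo the Compare tab would default to.

    Hypothetical helper for illustration only: the app itself evaluates
    os.environ.get(...) once, at module import time.
    """
    env = os.environ if environ is None else environ
    # dict.get only falls back when the key is missing, so an empty
    # string set in the environment is returned unchanged.
    return env.get("PERSONA_VECTORS_HUB_REPO", FALLBACK_HUB_REPO)
```

Because the real lookup happens at module level, a value placed in `.env` is presumably only seen if `load_dotenv()` has already run when `tabs/compare.py` is imported.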
pyproject.toml CHANGED
@@ -1,11 +1,11 @@
1
  [project]
2
  name = "persona-ui"
3
- version = "0.2.1"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
- "persona-vectors>=0.5.3",
9
  "persona-data>=0.4.1",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
 
1
  [project]
2
  name = "persona-ui"
3
+ version = "0.3.0"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "persona-vectors>=0.6.1",
9
  "persona-data>=0.4.1",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
tabs/chat_ui.py CHANGED
@@ -6,14 +6,13 @@ from persona_data.synth_persona import PersonaData
6
 
7
  from utils.contrast import TokenContrast, render_contrast_html
8
  from utils.helpers import (
9
- CHAT_PROMPT_MODE_LABELS,
10
  CHAT_PROMPT_MODE_LABEL_TO_KEY,
 
11
  VARIANT_LABELS,
12
  persona_label,
13
  widget_key,
14
  )
15
 
16
-
17
  GENERATION_DEFAULTS = {
18
  "max_new_tokens": 256,
19
  "temperature": 1.0,
 
6
 
7
  from utils.contrast import TokenContrast, render_contrast_html
8
  from utils.helpers import (
 
9
  CHAT_PROMPT_MODE_LABEL_TO_KEY,
10
+ CHAT_PROMPT_MODE_LABELS,
11
  VARIANT_LABELS,
12
  persona_label,
13
  widget_key,
14
  )
15
 
 
16
  GENERATION_DEFAULTS = {
17
  "max_new_tokens": 256,
18
  "temperature": 1.0,
tabs/compare.py CHANGED
@@ -1,15 +1,16 @@
 
1
  from collections.abc import Callable
2
- from itertools import combinations
3
  from dataclasses import dataclass
 
4
 
5
  import streamlit as st
6
  from persona_data.environment import get_artifacts_dir
7
  from persona_vectors.analysis import (
8
- load_persona_mean_samples,
9
- load_variant_mean_samples,
10
  )
11
- from persona_vectors.artifacts import ActivationStore
12
- from persona_vectors.artifacts import list_layers as list_available_layers
13
  from persona_vectors.extraction import MaskStrategy
14
  from persona_vectors.plots import (
15
  build_layered_figure,
@@ -28,18 +29,29 @@ from utils.helpers import (
28
  widget_key,
29
  )
30
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def _filename(*parts: str) -> str:
33
  return "__".join(slugify(part) for part in parts if part)
34
 
35
 
36
- _list_layers_cached = st.cache_data(show_spinner=False)(list_available_layers)
37
 
38
  # Keep compare-tab selection state separate so projection defaults do not
39
  # overwrite cosine similarity defaults.
40
  _LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
41
  _LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
42
  _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
 
43
 
44
 
45
  @dataclass(frozen=True)
@@ -51,8 +63,35 @@ class CosineSelection:
51
  persona_key: str
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def _select_artifact_personas(
55
- store: ActivationStore,
56
  variants: list[str],
57
  mask_strategy: MaskStrategy,
58
  *,
@@ -61,17 +100,15 @@ def _select_artifact_personas(
61
  default_all: bool = False,
62
  ) -> tuple[list[str], dict[str, str]]:
63
  persona_options = store.list_personas(variants)
64
- persona_names = store.persona_names(
65
- persona_options,
66
- variants=variants,
67
- )
68
  if not persona_options:
69
  if len(variants) > 1:
70
  st.info(
71
- "No personas have saved activations for all selected variants. Run extraction for both variants first."
 
72
  )
73
  else:
74
- st.info("No personas found for this model yet. Run extraction first.")
75
  return [], persona_names
76
 
77
  last_personas: list[str] = st.session_state.get(remember_key, [])
@@ -147,19 +184,19 @@ def _render_mask_strategy_select(scope: str) -> MaskStrategy:
147
  ),
148
  format_func=lambda strategy: strategy.value.replace("_", " ").title(),
149
  key=widget_key("load", "mask_strategy", scope),
150
- help="Which extracted activation artifact set to load.",
151
  )
152
  st.session_state[_LAST_MASK_STRATEGY_KEY] = selected.value
153
  return selected
154
 
155
 
156
  def _render_cosine_selection(
157
- store: ActivationStore,
158
  mask_strategy: MaskStrategy,
159
  ) -> CosineSelection | None:
160
- variants = list(store.variants)
161
  if len(variants) < 2:
162
- st.info("Need at least two non-baseline variants for cosine comparison.")
163
  return None
164
 
165
  with st.expander("Vector selection", expanded=True):
@@ -170,7 +207,7 @@ def _render_cosine_selection(
170
  options=variants,
171
  index=0,
172
  format_func=prompt_variant_label,
173
- key=widget_key("load", "variant_a"),
174
  )
175
  with col2:
176
  variant_b = st.selectbox(
@@ -178,7 +215,7 @@ def _render_cosine_selection(
178
  options=variants,
179
  index=min(1, len(variants) - 1),
180
  format_func=prompt_variant_label,
181
- key=widget_key("load", "variant_b"),
182
  )
183
 
184
  if variant_a == variant_b:
@@ -189,7 +226,7 @@ def _render_cosine_selection(
189
  store,
190
  [variant_a, variant_b],
191
  mask_strategy,
192
- widget_scope="cosine",
193
  remember_key=_LAST_COSINE_PERSONAS_KEY,
194
  )
195
  if not persona_ids:
@@ -204,11 +241,11 @@ def _render_cosine_selection(
204
 
205
 
206
  def _build_cosine_figures(
207
- store: ActivationStore,
208
  selection: CosineSelection,
209
  ) -> tuple[object, object | None, int, int] | None:
210
  try:
211
- variant_samples = load_variant_mean_samples(
212
  store,
213
  [selection.variant_a, selection.variant_b],
214
  persona_ids=selection.persona_ids,
@@ -242,7 +279,7 @@ def _build_cosine_figures(
242
  pair_samples = (
243
  variant_samples
244
  if {left, right} == {selection.variant_a, selection.variant_b}
245
- else load_variant_mean_samples(
246
  store,
247
  [left, right],
248
  persona_ids=selection.persona_ids,
@@ -274,7 +311,7 @@ def _build_cosine_figures(
274
 
275
 
276
  def _render_cosine_similarity(
277
- store: ActivationStore,
278
  mask_strategy: MaskStrategy,
279
  ) -> None:
280
  selection = _render_cosine_selection(store, mask_strategy)
@@ -284,6 +321,7 @@ def _render_cosine_similarity(
284
  cosine_fig_key = widget_key(
285
  "load",
286
  "cosine_fig_state",
 
287
  store.model_name,
288
  mask_strategy.value,
289
  selection.variant_a,
@@ -312,6 +350,7 @@ def _render_cosine_similarity(
312
  key=widget_key(
313
  "load",
314
  "compare_vectors",
 
315
  store.model_name,
316
  mask_strategy.value,
317
  selection.variant_a,
@@ -342,27 +381,26 @@ def _render_cosine_similarity(
342
 
343
 
344
  def _select_single_variant_samples(
345
- store: ActivationStore,
346
  mask_strategy: MaskStrategy,
347
  scope: str,
348
  ) -> tuple[str, list[str], str, list[int]] | None:
349
- variants = list(store.variants)
 
 
 
350
  variant = st.selectbox(
351
  "Variant",
352
  options=variants,
353
- index=(
354
- variants.index("biography")
355
- if "biography" in variants
356
- else 0
357
- ),
358
  format_func=prompt_variant_label,
359
- key=widget_key("load", "variant", scope),
360
  )
361
  persona_ids, _ = _select_artifact_personas(
362
  store,
363
  [variant],
364
  mask_strategy,
365
- widget_scope=scope,
366
  remember_key=_LAST_PROJECTION_PERSONAS_KEY,
367
  default_all=True,
368
  )
@@ -370,13 +408,7 @@ def _select_single_variant_samples(
370
  return None
371
 
372
  persona_key = "_".join(sorted(persona_ids))
373
- layer_options = _list_layers_cached(
374
- str(store.root_dir),
375
- store.model_name,
376
- [variant],
377
- persona_ids,
378
- mask_strategy=mask_strategy,
379
- )
380
  if not layer_options:
381
  st.info("No shared layers are available for the selected personas.")
382
  return None
@@ -389,6 +421,7 @@ def _select_single_variant_samples(
389
  "load",
390
  "layers",
391
  scope,
 
392
  store.model_name,
393
  mask_strategy.value,
394
  variant,
@@ -403,7 +436,7 @@ def _select_single_variant_samples(
403
 
404
 
405
  def _render_layered_figure_analysis(
406
- store: ActivationStore,
407
  mask_strategy: MaskStrategy,
408
  *,
409
  scope: str,
@@ -425,11 +458,12 @@ def _render_layered_figure_analysis(
425
  fig_key = widget_key(
426
  "load",
427
  f"{scope}_fig_state",
 
428
  store.model_name,
429
  mask_strategy.value,
430
  figure_kind,
431
  variant,
432
- "persona_mean",
433
  persona_key,
434
  )
435
  filename = _filename(
@@ -438,13 +472,13 @@ def _render_layered_figure_analysis(
438
  store.model_name,
439
  mask_strategy.value,
440
  variant,
441
- "persona_mean",
442
  persona_key,
443
  )
444
 
445
  if st.button(button_label, type="primary"):
446
  try:
447
- samples = load_persona_mean_samples(
448
  store,
449
  variant,
450
  mask_strategy=mask_strategy,
@@ -462,8 +496,7 @@ def _render_layered_figure_analysis(
462
  layers=selected_layers,
463
  title=(
464
  "Pair similarity trajectories - "
465
- f"{prompt_variant_label(variant)} - "
466
- "persona mean activations"
467
  ),
468
  )
469
  if include_pair_trajectories
@@ -488,17 +521,45 @@ def _render_layered_figure_analysis(
488
  st.success(f"Loaded {n_samples} samples.")
489
 
490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  def render_compare_tab(model_name: str) -> None:
492
  """Render the compare tab."""
493
 
494
  st.title("Compare")
495
- st.caption("Compare saved activations by cosine similarity, PCA, or UMAP.")
496
 
497
- with st.expander("Artifact settings", expanded=False):
498
- artifacts_root = st.text_input(
499
- "Artifacts root",
500
- value=str(get_artifacts_dir() / "activations"),
501
- )
502
 
503
  analysis_mode = st.segmented_control(
504
  "Analysis mode",
@@ -510,9 +571,10 @@ def render_compare_tab(model_name: str) -> None:
510
  if analysis_mode is None:
511
  analysis_mode = ANALYSIS_MODES[0]
512
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
513
- with st.expander("Activation settings", expanded=False):
 
514
  mask_strategy = _render_mask_strategy_select(analysis_mode)
515
- store = ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
516
 
517
  if analysis_mode == "Cosine similarity":
518
  _render_cosine_similarity(store, mask_strategy)
@@ -525,8 +587,7 @@ def render_compare_tab(model_name: str) -> None:
525
  figure_kind="similarity",
526
  button_label="Generate similarity matrix",
527
  title_fn=lambda v: (
528
- "Centered similarity - "
529
- f"{prompt_variant_label(v)} - persona mean activations"
530
  ),
531
  include_pair_trajectories=True,
532
  )
@@ -539,6 +600,6 @@ def render_compare_tab(model_name: str) -> None:
539
  figure_kind=analysis_mode.lower(),
540
  button_label=f"Generate {analysis_mode} projection",
541
  title_fn=lambda v: (
542
- f"{analysis_mode} - {prompt_variant_label(v)} - Persona means"
543
  ),
544
  )
 
1
+ import os
2
  from collections.abc import Callable
 
3
  from dataclasses import dataclass
4
+ from itertools import combinations
5
 
6
  import streamlit as st
7
  from persona_data.environment import get_artifacts_dir
8
  from persona_vectors.analysis import (
9
+ load_persona_vectors,
10
+ load_variant_vectors,
11
  )
12
+ from persona_vectors.artifacts import ActivationStore, HFActivationStore
13
+ from persona_vectors.artifacts import list_layers as list_local_layers
14
  from persona_vectors.extraction import MaskStrategy
15
  from persona_vectors.plots import (
16
  build_layered_figure,
 
29
  widget_key,
30
  )
31
 
32
+ Store = ActivationStore | HFActivationStore
33
+
34
+ DEFAULT_HUB_REPO = os.environ.get(
35
+ "PERSONA_VECTORS_HUB_REPO",
36
+ "implicit-personalization/synth-persona-vectors",
37
+ )
38
+ SOURCE_HUB = "Hugging Face Hub"
39
+ SOURCE_LOCAL = "Local activations"
40
+ SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
41
+
42
 
43
  def _filename(*parts: str) -> str:
44
  return "__".join(slugify(part) for part in parts if part)
45
 
46
 
47
+ _list_layers_cached = st.cache_data(show_spinner=False)(list_local_layers)
48
 
49
  # Keep compare-tab selection state separate so projection defaults do not
50
  # overwrite cosine similarity defaults.
51
  _LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
52
  _LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
53
  _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
54
+ _LAST_SOURCE_KEY = "compare:last_source"
55
 
56
 
57
  @dataclass(frozen=True)
 
63
  persona_key: str
64
 
65
 
66
+ def _store_id(store: Store) -> str:
67
+ """Stable identifier for cache/widget keys that distinguishes Hub vs local."""
68
+ if isinstance(store, HFActivationStore):
69
+ return f"hub:{store.repo_id}"
70
+ return f"local:{store.root_dir}"
71
+
72
+
73
+ def _layers_for_variant(
74
+ store: Store,
75
+ variant: str,
76
+ persona_ids: list[str],
77
+ mask_strategy: MaskStrategy,
78
+ ) -> list[int]:
79
+ if isinstance(store, HFActivationStore):
80
+ if not persona_ids:
81
+ return []
82
+ sample = store.load(variant, persona_ids[0])
83
+ return list(range(int(sample.shape[0])))
84
+ return _list_layers_cached(
85
+ str(store.root_dir),
86
+ store.model_name,
87
+ [variant],
88
+ persona_ids,
89
+ mask_strategy=mask_strategy,
90
+ )
91
+
92
+
93
  def _select_artifact_personas(
94
+ store: Store,
95
  variants: list[str],
96
  mask_strategy: MaskStrategy,
97
  *,
 
100
  default_all: bool = False,
101
  ) -> tuple[list[str], dict[str, str]]:
102
  persona_options = store.list_personas(variants)
103
+ persona_names = store.persona_names(persona_options, variants=variants)
 
 
 
104
  if not persona_options:
105
  if len(variants) > 1:
106
  st.info(
107
+ "No personas have vectors for all selected variants. "
108
+ "Pick a single variant or change the source."
109
  )
110
  else:
111
+ st.info("No personas found for this model and variant.")
112
  return [], persona_names
113
 
114
  last_personas: list[str] = st.session_state.get(remember_key, [])
 
184
  ),
185
  format_func=lambda strategy: strategy.value.replace("_", " ").title(),
186
  key=widget_key("load", "mask_strategy", scope),
187
+ help="Which extracted activation set to load.",
188
  )
189
  st.session_state[_LAST_MASK_STRATEGY_KEY] = selected.value
190
  return selected
191
 
192
 
193
  def _render_cosine_selection(
194
+ store: Store,
195
  mask_strategy: MaskStrategy,
196
  ) -> CosineSelection | None:
197
+ variants = store.available_variants()
198
  if len(variants) < 2:
199
+ st.info("Need at least two variants with saved vectors for cosine comparison.")
200
  return None
201
 
202
  with st.expander("Vector selection", expanded=True):
 
207
  options=variants,
208
  index=0,
209
  format_func=prompt_variant_label,
210
+ key=widget_key("load", "variant_a", _store_id(store)),
211
  )
212
  with col2:
213
  variant_b = st.selectbox(
 
215
  options=variants,
216
  index=min(1, len(variants) - 1),
217
  format_func=prompt_variant_label,
218
+ key=widget_key("load", "variant_b", _store_id(store)),
219
  )
220
 
221
  if variant_a == variant_b:
 
226
  store,
227
  [variant_a, variant_b],
228
  mask_strategy,
229
+ widget_scope=f"cosine:{_store_id(store)}",
230
  remember_key=_LAST_COSINE_PERSONAS_KEY,
231
  )
232
  if not persona_ids:
 
241
 
242
 
243
  def _build_cosine_figures(
244
+ store: Store,
245
  selection: CosineSelection,
246
  ) -> tuple[object, object | None, int, int] | None:
247
  try:
248
+ variant_samples = load_variant_vectors(
249
  store,
250
  [selection.variant_a, selection.variant_b],
251
  persona_ids=selection.persona_ids,
 
279
  pair_samples = (
280
  variant_samples
281
  if {left, right} == {selection.variant_a, selection.variant_b}
282
+ else load_variant_vectors(
283
  store,
284
  [left, right],
285
  persona_ids=selection.persona_ids,
 
311
 
312
 
313
  def _render_cosine_similarity(
314
+ store: Store,
315
  mask_strategy: MaskStrategy,
316
  ) -> None:
317
  selection = _render_cosine_selection(store, mask_strategy)
 
321
  cosine_fig_key = widget_key(
322
  "load",
323
  "cosine_fig_state",
324
+ _store_id(store),
325
  store.model_name,
326
  mask_strategy.value,
327
  selection.variant_a,
 
350
  key=widget_key(
351
  "load",
352
  "compare_vectors",
353
+ _store_id(store),
354
  store.model_name,
355
  mask_strategy.value,
356
  selection.variant_a,
 
381
 
382
 
383
  def _select_single_variant_samples(
384
+ store: Store,
385
  mask_strategy: MaskStrategy,
386
  scope: str,
387
  ) -> tuple[str, list[str], str, list[int]] | None:
388
+ variants = store.available_variants()
389
+ if not variants:
390
+ st.info("No variants with saved vectors for this model.")
391
+ return None
392
  variant = st.selectbox(
393
  "Variant",
394
  options=variants,
395
+ index=variants.index("biography") if "biography" in variants else 0,
 
 
 
 
396
  format_func=prompt_variant_label,
397
+ key=widget_key("load", "variant", scope, _store_id(store)),
398
  )
399
  persona_ids, _ = _select_artifact_personas(
400
  store,
401
  [variant],
402
  mask_strategy,
403
+ widget_scope=f"{scope}:{_store_id(store)}",
404
  remember_key=_LAST_PROJECTION_PERSONAS_KEY,
405
  default_all=True,
406
  )
 
408
  return None
409
 
410
  persona_key = "_".join(sorted(persona_ids))
411
+ layer_options = _layers_for_variant(store, variant, persona_ids, mask_strategy)
 
 
 
 
 
 
412
  if not layer_options:
413
  st.info("No shared layers are available for the selected personas.")
414
  return None
 
421
  "load",
422
  "layers",
423
  scope,
424
+ _store_id(store),
425
  store.model_name,
426
  mask_strategy.value,
427
  variant,
 
436
 
437
 
438
  def _render_layered_figure_analysis(
439
+ store: Store,
440
  mask_strategy: MaskStrategy,
441
  *,
442
  scope: str,
 
458
  fig_key = widget_key(
459
  "load",
460
  f"{scope}_fig_state",
461
+ _store_id(store),
462
  store.model_name,
463
  mask_strategy.value,
464
  figure_kind,
465
  variant,
466
+ "persona_vector",
467
  persona_key,
468
  )
469
  filename = _filename(
 
472
  store.model_name,
473
  mask_strategy.value,
474
  variant,
475
+ "persona_vector",
476
  persona_key,
477
  )
478
 
479
  if st.button(button_label, type="primary"):
480
  try:
481
+ samples = load_persona_vectors(
482
  store,
483
  variant,
484
  mask_strategy=mask_strategy,
 
496
  layers=selected_layers,
497
  title=(
498
  "Pair similarity trajectories - "
499
+ f"{prompt_variant_label(variant)} - persona vectors"
 
500
  ),
501
  )
502
  if include_pair_trajectories
 
521
  st.success(f"Loaded {n_samples} samples.")
522
 
523
 
524
+ def _render_source_select() -> str:
525
+ last_source = st.session_state.get(_LAST_SOURCE_KEY, SOURCE_HUB)
526
+ source = st.segmented_control(
527
+ "Source",
528
+ options=SOURCES,
529
+ default=last_source if last_source in SOURCES else SOURCE_HUB,
530
+ key=widget_key("load", "source"),
531
+ label_visibility="collapsed",
532
+ )
533
+ if source is None:
534
+ source = SOURCE_HUB
535
+ st.session_state[_LAST_SOURCE_KEY] = source
536
+ return source
537
+
538
+
539
+ def _build_store(source: str, model_name: str, mask_strategy: MaskStrategy) -> Store:
540
+ if source == SOURCE_HUB:
541
+ repo = st.text_input(
542
+ "Hub repo",
543
+ value=st.session_state.get("compare:hub_repo", DEFAULT_HUB_REPO),
544
+ key="compare:hub_repo",
545
+ help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
546
+ )
547
+ return HFActivationStore(repo, model_name, mask_strategy=mask_strategy)
548
+ artifacts_root = st.text_input(
549
+ "Artifacts root",
550
+ value=str(get_artifacts_dir() / "activations"),
551
+ key="compare:artifacts_root",
552
+ )
553
+ return ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
554
+
555
+
556
  def render_compare_tab(model_name: str) -> None:
557
  """Render the compare tab."""
558
 
559
  st.title("Compare")
560
+ st.caption("Compare persona vectors by cosine similarity, PCA, or UMAP.")
561
 
562
+ source = _render_source_select()
 
 
 
 
563
 
564
  analysis_mode = st.segmented_control(
565
  "Analysis mode",
 
571
  if analysis_mode is None:
572
  analysis_mode = ANALYSIS_MODES[0]
573
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
574
+
575
+ with st.expander("Source settings", expanded=False):
576
  mask_strategy = _render_mask_strategy_select(analysis_mode)
577
+ store = _build_store(source, model_name, mask_strategy)
578
 
579
  if analysis_mode == "Cosine similarity":
580
  _render_cosine_similarity(store, mask_strategy)
 
587
  figure_kind="similarity",
588
  button_label="Generate similarity matrix",
589
  title_fn=lambda v: (
590
+ f"Centered similarity - {prompt_variant_label(v)} - persona vectors"
 
591
  ),
592
  include_pair_trajectories=True,
593
  )
 
600
  figure_kind=analysis_mode.lower(),
601
  button_label=f"Generate {analysis_mode} projection",
602
  title_fn=lambda v: (
603
+ f"{analysis_mode} - {prompt_variant_label(v)} - persona vectors"
604
  ),
605
  )
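Reviewer note on `_store_id`: the new helper is threaded through every `widget_key(...)` call so that widget and figure state for a Hub store cannot collide with state for a local store. The sketch below shows the idea with stand-in classes. `HubStore`, `LocalStore`, and the body of `widget_key` here are assumptions for illustration; the real `HFActivationStore`, `ActivationStore`, and `utils.helpers.widget_key` are not shown in full in this diff.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HubStore:
    """Hypothetical stand-in for HFActivationStore (only the fields used here)."""
    repo_id: str
    model_name: str


@dataclass(frozen=True)
class LocalStore:
    """Hypothetical stand-in for ActivationStore (only the fields used here)."""
    root_dir: str
    model_name: str


def store_id(store):
    # Mirrors _store_id in the diff: prefix the identifier with the source kind.
    if isinstance(store, HubStore):
        return f"hub:{store.repo_id}"
    return f"local:{store.root_dir}"


def widget_key(*parts):
    # Assumed behaviour: join the parts into one string key.
    return ":".join(str(part) for part in parts)


hub = HubStore("org/vectors", "demo-model")
local = LocalStore("artifacts/activations", "demo-model")

# Same widget, same model, different source: the keys differ.
hub_key = widget_key("load", "variant_a", store_id(hub))
local_key = widget_key("load", "variant_a", store_id(local))
```

With the source kind in the key, switching the "Source" control between Hub and local should give each source its own remembered variant and layer selections instead of reusing values that may not exist in the other store.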
tabs/extract.py CHANGED
@@ -7,11 +7,10 @@ from persona_data.synth_persona import BASELINE_PERSONA_ID, PersonaData, QAPair
7
  from persona_vectors.artifacts import PERSONA_VARIANTS
8
  from persona_vectors.extraction import (
9
  MaskStrategy,
10
- TokenSegment,
11
  prepare_inputs_for_strategy,
12
- preview_token_segments,
13
  run_extraction,
14
  )
 
15
 
16
  from utils.datasets import load_dataset
17
  from utils.helpers import (
@@ -33,7 +32,6 @@ _DEFAULT_MAX_QUESTIONS = 50
33
 
34
  @dataclass(frozen=True)
35
  class ExtractSettings:
36
- runs: list[tuple[PersonaData, list[QAPair]]]
37
  mask_strategy: MaskStrategy
38
  max_questions: int
39
 
@@ -307,7 +305,6 @@ def _render_extract_actions() -> tuple[bool, bool]:
307
 
308
  def _render_token_preview(
309
  *,
310
- remote: bool,
311
  model_name: str,
312
  run_plan: list[tuple[PersonaData, list[QAPair], str]],
313
  settings: ExtractSettings,
@@ -387,7 +384,7 @@ def _run_extraction_plan(
387
  progress.empty()
388
  ndif_status_box.empty()
389
 
390
- status_box.success("Extraction complete")
391
  st.success(f"Saved {len(results)} artifact set(s)")
392
 
393
  for result in results:
@@ -448,7 +445,6 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
448
  dataset_source=dataset_source,
449
  )
450
  settings = ExtractSettings(
451
- runs=runs,
452
  mask_strategy=mask_strategy,
453
  max_questions=max_questions,
454
  )
@@ -458,7 +454,6 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
458
 
459
  if preview_clicked:
460
  _render_token_preview(
461
- remote=remote,
462
  model_name=model_name,
463
  run_plan=run_plan,
464
  settings=settings,
 
7
  from persona_vectors.artifacts import PERSONA_VARIANTS
8
  from persona_vectors.extraction import (
9
  MaskStrategy,
 
10
  prepare_inputs_for_strategy,
 
11
  run_extraction,
12
  )
13
+ from persona_vectors.preview import TokenSegment, preview_token_segments
14
 
15
  from utils.datasets import load_dataset
16
  from utils.helpers import (
 
32
 
33
  @dataclass(frozen=True)
34
  class ExtractSettings:
 
35
  mask_strategy: MaskStrategy
36
  max_questions: int
37
 
 
305
 
306
  def _render_token_preview(
307
  *,
 
308
  model_name: str,
309
  run_plan: list[tuple[PersonaData, list[QAPair], str]],
310
  settings: ExtractSettings,
 
384
  progress.empty()
385
  ndif_status_box.empty()
386
 
387
+ status_box.empty()
388
  st.success(f"Saved {len(results)} artifact set(s)")
389
 
390
  for result in results:
 
445
  dataset_source=dataset_source,
446
  )
447
  settings = ExtractSettings(
 
448
  mask_strategy=mask_strategy,
449
  max_questions=max_questions,
450
  )
 
454
 
455
  if preview_clicked:
456
  _render_token_preview(
 
457
  model_name=model_name,
458
  run_plan=run_plan,
459
  settings=settings,
utils/helpers.py CHANGED
@@ -28,9 +28,9 @@ ANALYSIS_MODES = ["Cosine similarity", "Similarity matrix", "PCA", "UMAP"]
28
 
29
  ANALYSIS_HELP_TEXT = {
30
  "Cosine similarity": "Compare layer-wise alignment between variants.",
31
- "Similarity matrix": "Compare centered pairwise similarity between persona means by layer, with pair trajectories across layers.",
32
- "PCA": "Project per-persona mean activations into a 2D global view.",
33
- "UMAP": "Project per-persona mean activations into a 2D local-neighborhood view.",
34
  }
35
 
36
  NDIF_STATUS_ICONS = {
 
28
 
29
  ANALYSIS_HELP_TEXT = {
30
  "Cosine similarity": "Compare layer-wise alignment between variants.",
31
+ "Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
32
+ "PCA": "Project per-persona vectors into a 2D global view.",
33
+ "UMAP": "Project per-persona vectors into a 2D local-neighborhood view.",
34
  }
35
 
36
  NDIF_STATUS_ICONS = {
uv.lock CHANGED
@@ -1566,7 +1566,7 @@ wheels = [
1566
 
1567
  [[package]]
1568
  name = "persona-ui"
1569
- version = "0.2.1"
1570
  source = { virtual = "." }
1571
  dependencies = [
1572
  { name = "persona-data" },
@@ -1579,7 +1579,7 @@ dependencies = [
1579
  [package.metadata]
1580
  requires-dist = [
1581
  { name = "persona-data", specifier = ">=0.4.1" },
1582
- { name = "persona-vectors", specifier = ">=0.5.3" },
1583
  { name = "plotly", specifier = ">=6.6.0" },
1584
  { name = "python-dotenv", specifier = ">=1.2.2" },
1585
  { name = "streamlit", specifier = ">=1.44.0" },
@@ -1587,7 +1587,7 @@ requires-dist = [
1587
 
1588
  [[package]]
1589
  name = "persona-vectors"
1590
- version = "0.5.3"
1591
  source = { registry = "https://pypi.org/simple" }
1592
  dependencies = [
1593
  { name = "datasets" },
@@ -1606,9 +1606,9 @@ dependencies = [
1606
  { name = "transformers" },
1607
  { name = "umap-learn" },
1608
  ]
1609
- sdist = { url = "https://files.pythonhosted.org/packages/c7/53/20c77e298eb864ab917d58312b679007b936af64ede5fbb72d409268d62e/persona_vectors-0.5.3.tar.gz", hash = "sha256:d8bcb088a1814702401d22e21c39662ab840a4fb4b4f57dfd79999c9debfc1b8", size = 22791, upload-time = "2026-05-07T12:16:49.969Z" }
1610
  wheels = [
1611
- { url = "https://files.pythonhosted.org/packages/32/26/a0197928e5202403094883331ff87799b54803bd3fd749b7d9c11f7332b3/persona_vectors-0.5.3-py3-none-any.whl", hash = "sha256:e44af7a6846d6d9249da12707dfae57f89d2a50a4ced05cdb4a844d39a9f03e8", size = 26805, upload-time = "2026-05-07T12:16:51.035Z" },
1612
  ]
1613
 
1614
  [[package]]
 
1566
 
1567
  [[package]]
1568
  name = "persona-ui"
1569
+ version = "0.3.0"
1570
  source = { virtual = "." }
1571
  dependencies = [
1572
  { name = "persona-data" },
 
1579
  [package.metadata]
1580
  requires-dist = [
1581
  { name = "persona-data", specifier = ">=0.4.1" },
1582
+ { name = "persona-vectors", specifier = ">=0.6.1" },
1583
  { name = "plotly", specifier = ">=6.6.0" },
1584
  { name = "python-dotenv", specifier = ">=1.2.2" },
1585
  { name = "streamlit", specifier = ">=1.44.0" },
 
1587
 
1588
  [[package]]
1589
  name = "persona-vectors"
1590
+ version = "0.6.1"
1591
  source = { registry = "https://pypi.org/simple" }
1592
  dependencies = [
1593
  { name = "datasets" },
 
1606
  { name = "transformers" },
1607
  { name = "umap-learn" },
1608
  ]
1609
+ sdist = { url = "https://files.pythonhosted.org/packages/69/f3/6da35af90c8ea5333db1763ece04a3230353ac5a76c0dc8fea705a6e86cf/persona_vectors-0.6.1.tar.gz", hash = "sha256:552ac9a0d739a453c5d9eb612cb0d0d2820a1b53ce84f490295a84105a71f7cc", size = 24311, upload-time = "2026-05-07T15:07:29.951Z" }
1610
  wheels = [
1611
+ { url = "https://files.pythonhosted.org/packages/86/66/91df378258e2c0cbc7860652b07b5e65ee1949ba14be2efdb6c646a933f1/persona_vectors-0.6.1-py3-none-any.whl", hash = "sha256:593977ad19c9f23df7d86e302fe4bcf49159425da67d83281a11858026c5e85e", size = 28683, upload-time = "2026-05-07T15:07:30.791Z" },
1612
  ]
1613
 
1614
  [[package]]