Jac-Zac commited on
Commit
2bf3d21
·
1 Parent(s): 12cdb17

App new feel and look revamp

Browse files
Files changed (7) hide show
  1. app.py +35 -8
  2. pyproject.toml +3 -1
  3. tabs/compare.py +339 -137
  4. tabs/extract.py +3 -19
  5. utils/compare_sources.py +186 -0
  6. utils/controls.py +29 -0
  7. uv.lock +22 -18
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
 
3
  import streamlit as st
4
  from dotenv import load_dotenv
@@ -16,6 +17,14 @@ _TABS = ["Chat", "Compare", "Extract"]
16
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
17
 
18
 
 
 
 
 
 
 
 
 
19
  def _remote_model_input(remote_models: list[str]) -> str:
20
  """Return the active remote model id, picking from running NDIF deployments or a custom value."""
21
 
@@ -74,7 +83,7 @@ def _remote_model_input(remote_models: list[str]) -> str:
74
  return model_name
75
 
76
 
77
- def _sidebar_controls() -> tuple[bool, str, str, str]:
78
  from utils.runtime import list_remote_models
79
 
80
  with st.sidebar:
@@ -96,6 +105,19 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
96
  st.session_state["sidebar__active_tab"] = tab_name
97
  st.rerun()
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  st.divider()
100
  st.caption("Runtime")
101
  remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
@@ -119,7 +141,12 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
119
  help="Dataset for Chat and Extract.",
120
  )
121
 
122
- return remote, model_name, dataset_source, active_tab
 
 
 
 
 
123
 
124
 
125
  def main() -> None:
@@ -136,20 +163,20 @@ def main() -> None:
136
 
137
  torch.set_grad_enabled(False)
138
 
139
- remote, model_name, dataset_source, active_tab = _sidebar_controls()
140
 
141
- if active_tab == "Extract":
142
  from tabs.extract import render_extract_tab
143
 
144
- render_extract_tab(remote, model_name, dataset_source)
145
- elif active_tab == "Compare":
146
  from tabs.compare import render_compare_tab
147
 
148
- render_compare_tab(model_name)
149
  else:
150
  from tabs.chat import render_chat_tab
151
 
152
- render_chat_tab(remote, model_name, dataset_source)
153
 
154
 
155
  if __name__ == "__main__":
 
1
  import os
2
+ from dataclasses import dataclass
3
 
4
  import streamlit as st
5
  from dotenv import load_dotenv
 
17
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
18
 
19
 
20
+ @dataclass(frozen=True)
21
+ class SidebarState:
22
+ remote: bool
23
+ model_name: str
24
+ dataset_source: str
25
+ active_tab: str
26
+
27
+
28
  def _remote_model_input(remote_models: list[str]) -> str:
29
  """Return the active remote model id, picking from running NDIF deployments or a custom value."""
30
 
 
83
  return model_name
84
 
85
 
86
+ def _sidebar_controls() -> SidebarState:
87
  from utils.runtime import list_remote_models
88
 
89
  with st.sidebar:
 
105
  st.session_state["sidebar__active_tab"] = tab_name
106
  st.rerun()
107
 
108
+ if active_tab == "Compare":
109
+ model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
110
+ dataset_source = st.session_state.get(
111
+ "sidebar__dataset_source",
112
+ DATASET_SOURCES[0],
113
+ )
114
+ return SidebarState(
115
+ remote=False,
116
+ model_name=model_name,
117
+ dataset_source=dataset_source,
118
+ active_tab=active_tab,
119
+ )
120
+
121
  st.divider()
122
  st.caption("Runtime")
123
  remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
 
141
  help="Dataset for Chat and Extract.",
142
  )
143
 
144
+ return SidebarState(
145
+ remote=remote,
146
+ model_name=model_name,
147
+ dataset_source=dataset_source,
148
+ active_tab=active_tab,
149
+ )
150
 
151
 
152
  def main() -> None:
 
163
 
164
  torch.set_grad_enabled(False)
165
 
166
+ sidebar = _sidebar_controls()
167
 
168
+ if sidebar.active_tab == "Extract":
169
  from tabs.extract import render_extract_tab
170
 
171
+ render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
172
+ elif sidebar.active_tab == "Compare":
173
  from tabs.compare import render_compare_tab
174
 
175
+ render_compare_tab()
176
  else:
177
  from tabs.chat import render_chat_tab
178
 
179
+ render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
180
 
181
 
182
  if __name__ == "__main__":
pyproject.toml CHANGED
@@ -1,12 +1,14 @@
1
  [project]
2
  name = "persona-ui"
3
- version = "0.3.0"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
  "persona-vectors>=0.6.4",
9
  "persona-data>=0.4.2",
 
 
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
 
1
  [project]
2
  name = "persona-ui"
3
+ version = "0.4.0"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
  "persona-vectors>=0.6.4",
9
  "persona-data>=0.4.2",
10
+ "datasets>=4.8.5",
11
+ "huggingface-hub>=1.14.0",
12
  "streamlit>=1.44.0",
13
  "plotly>=6.6.0",
14
  "python-dotenv>=1.2.2",
tabs/compare.py CHANGED
@@ -1,13 +1,13 @@
1
- import os
2
  from collections.abc import Callable
3
  from dataclasses import dataclass
4
  from itertools import combinations
 
5
 
6
  import streamlit as st
7
  from persona_data.environment import get_artifacts_dir
 
8
  from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
9
- from persona_vectors.artifacts import ActivationStore, HFActivationStore
10
- from persona_vectors.artifacts import list_layers as list_local_layers
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.plots import (
13
  build_layered_figure,
@@ -16,50 +16,39 @@ from persona_vectors.plots import (
16
  save_plot_html,
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  from utils.helpers import (
20
  ANALYSIS_HELP_TEXT,
21
  ANALYSIS_MODES,
22
- persona_display_label,
23
  prompt_variant_label,
24
  slugify,
25
  widget_key,
26
  )
27
 
28
- Store = ActivationStore | HFActivationStore
29
-
30
- DEFAULT_HUB_REPO = os.environ.get(
31
- "PERSONA_VECTORS_HUB_REPO",
32
- "implicit-personalization/synth-persona-vectors",
33
- )
34
- SOURCE_HUB = "Hugging Face Hub"
35
- SOURCE_LOCAL = "Local activations"
36
- SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
37
-
38
 
39
  def _filename(*parts: str) -> str:
40
  return "__".join(slugify(part) for part in parts if part)
41
 
42
 
43
- _list_layers_cached = st.cache_data(show_spinner=False)(list_local_layers)
44
-
45
-
46
- @st.cache_data(show_spinner=False)
47
- def _hub_layers_cached(
48
- repo_id: str,
49
- model_name: str,
50
- mask_strategy_value: str,
51
- variant: str,
52
- persona_id: str,
53
- ) -> list[int]:
54
- store = HFActivationStore(
55
- repo_id,
56
- model_name,
57
- mask_strategy=MaskStrategy(mask_strategy_value),
58
- )
59
- sample = store.load(variant, persona_id)
60
- return list(range(int(sample.shape[0])))
61
-
62
-
63
  # Keep compare-tab selection state separate so projection defaults do not
64
  # overwrite cosine similarity defaults.
65
  _LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
@@ -68,6 +57,15 @@ _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
68
  _LAST_SOURCE_KEY = "compare:last_source"
69
 
70
 
 
 
 
 
 
 
 
 
 
71
  @dataclass(frozen=True)
72
  class CosineSelection:
73
  variants: list[str]
@@ -77,11 +75,10 @@ class CosineSelection:
77
  persona_key: str
78
 
79
 
80
- def _store_id(store: Store) -> str:
81
- """Stable identifier for cache/widget keys that distinguishes Hub vs local."""
82
- if isinstance(store, HFActivationStore):
83
- return f"hub:{store.repo_id}"
84
- return f"local:{store.root_dir}"
85
 
86
 
87
  def _layers_for_variant(
@@ -93,14 +90,14 @@ def _layers_for_variant(
93
  if isinstance(store, HFActivationStore):
94
  if not persona_ids:
95
  return []
96
- return _hub_layers_cached(
97
  store.repo_id,
98
  store.model_name,
99
  mask_strategy.value,
100
  variant,
101
  persona_ids[0],
102
  )
103
- return _list_layers_cached(
104
  str(store.root_dir),
105
  store.model_name,
106
  [variant],
@@ -109,59 +106,188 @@ def _layers_for_variant(
109
  )
110
 
111
 
112
- def _select_artifact_personas(
113
  store: Store,
114
  variants: list[str],
115
  mask_strategy: MaskStrategy,
116
  *,
117
- widget_scope: str,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  remember_key: str,
119
- default_all: bool = False,
120
- ) -> tuple[list[str], dict[str, str]]:
121
- persona_options = store.list_personas(variants)
122
- persona_names = store.persona_names(persona_options, variants=variants)
123
- if not persona_options:
124
- if len(variants) > 1:
125
- st.info(
126
- "No personas have vectors for all selected variants. "
127
- "Pick a single variant or change the source."
128
- )
129
- else:
130
- st.info("No personas found for this model and variant.")
131
- return [], persona_names
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- last_personas: list[str] = st.session_state.get(remember_key, [])
134
- default_personas = [p for p in last_personas if p in persona_options]
135
- if not default_personas:
136
- default_personas = persona_options if default_all else persona_options[:1]
137
 
138
- persona_key = widget_key(
 
 
 
 
 
 
 
 
 
 
139
  "load",
140
- "personas",
 
 
 
 
 
 
 
 
141
  widget_scope,
142
  store.model_name,
143
  mask_strategy.value,
144
  *variants,
145
  )
146
 
147
- def _remember_personas() -> None:
148
- st.session_state[remember_key] = [
149
- persona_id
150
- for persona_id in st.session_state.get(persona_key, [])
151
- if persona_id in persona_options
152
- ]
153
-
154
- persona_ids = st.multiselect(
155
- "Personas",
156
- options=persona_options,
157
- default=default_personas,
158
- format_func=lambda persona_id: persona_display_label(
159
- persona_id, persona_names.get(persona_id)
160
- ),
161
- key=persona_key,
162
- on_change=_remember_personas,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  )
164
- return persona_ids, persona_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  def _render_save_buttons(
@@ -179,35 +305,18 @@ def _render_save_buttons(
179
 
180
 
181
  def _render_mask_strategy_select(scope: str) -> MaskStrategy:
182
- last_strategy = st.session_state.get(
183
- _LAST_MASK_STRATEGY_KEY,
184
- MaskStrategy.ANSWER_MEAN.value,
185
- )
186
- strategies = list(MaskStrategy)
187
- selected = st.selectbox(
188
- "Mask strategy",
189
- options=strategies,
190
- index=next(
191
- (
192
- idx
193
- for idx, strategy in enumerate(strategies)
194
- if strategy.value == last_strategy
195
- ),
196
- 0,
197
- ),
198
- format_func=lambda strategy: strategy.value.replace("_", " ").title(),
199
  key=widget_key("load", "mask_strategy", scope),
 
200
  help="Which extracted activation set to load.",
201
  )
202
- st.session_state[_LAST_MASK_STRATEGY_KEY] = selected.value
203
- return selected
204
 
205
 
206
  def _render_cosine_selection(
207
  store: Store,
208
  mask_strategy: MaskStrategy,
209
  ) -> CosineSelection | None:
210
- variants = store.available_variants()
211
  if len(variants) < 2:
212
  st.info("Need at least two variants with saved vectors for cosine comparison.")
213
  return None
@@ -220,7 +329,7 @@ def _render_cosine_selection(
220
  options=variants,
221
  index=0,
222
  format_func=prompt_variant_label,
223
- key=widget_key("load", "variant_a", _store_id(store)),
224
  )
225
  with col2:
226
  variant_b = st.selectbox(
@@ -228,18 +337,18 @@ def _render_cosine_selection(
228
  options=variants,
229
  index=min(1, len(variants) - 1),
230
  format_func=prompt_variant_label,
231
- key=widget_key("load", "variant_b", _store_id(store)),
232
  )
233
 
234
  if variant_a == variant_b:
235
  st.warning("Choose two different variants to compare.")
236
  return None
237
 
238
- persona_ids, _ = _select_artifact_personas(
239
  store,
240
  [variant_a, variant_b],
241
  mask_strategy,
242
- widget_scope=f"cosine:{_store_id(store)}",
243
  remember_key=_LAST_COSINE_PERSONAS_KEY,
244
  )
245
  if not persona_ids:
@@ -334,7 +443,7 @@ def _render_cosine_similarity(
334
  cosine_fig_key = widget_key(
335
  "load",
336
  "cosine_fig_state",
337
- _store_id(store),
338
  store.model_name,
339
  mask_strategy.value,
340
  selection.variant_a,
@@ -363,7 +472,7 @@ def _render_cosine_similarity(
363
  key=widget_key(
364
  "load",
365
  "compare_vectors",
366
- _store_id(store),
367
  store.model_name,
368
  mask_strategy.value,
369
  selection.variant_a,
@@ -398,7 +507,7 @@ def _select_single_variant_samples(
398
  mask_strategy: MaskStrategy,
399
  scope: str,
400
  ) -> tuple[str, list[str], str, list[int]] | None:
401
- variants = store.available_variants()
402
  if not variants:
403
  st.info("No variants with saved vectors for this model.")
404
  return None
@@ -407,13 +516,13 @@ def _select_single_variant_samples(
407
  options=variants,
408
  index=variants.index("biography") if "biography" in variants else 0,
409
  format_func=prompt_variant_label,
410
- key=widget_key("load", "variant", scope, _store_id(store)),
411
  )
412
- persona_ids, _ = _select_artifact_personas(
413
  store,
414
  [variant],
415
  mask_strategy,
416
- widget_scope=f"{scope}:{_store_id(store)}",
417
  remember_key=_LAST_PROJECTION_PERSONAS_KEY,
418
  default_all=True,
419
  )
@@ -426,26 +535,8 @@ def _select_single_variant_samples(
426
  st.info("No shared layers are available for the selected personas.")
427
  return None
428
 
429
- selected_layers = st.multiselect(
430
- "Layers",
431
- options=layer_options,
432
- default=layer_options,
433
- key=widget_key(
434
- "load",
435
- "layers",
436
- scope,
437
- _store_id(store),
438
- store.model_name,
439
- mask_strategy.value,
440
- variant,
441
- persona_key,
442
- ),
443
- )
444
- if not selected_layers:
445
- st.info("Select at least one layer.")
446
- return None
447
-
448
- return variant, persona_ids, persona_key, selected_layers
449
 
450
 
451
  def _render_layered_figure_analysis(
@@ -472,7 +563,7 @@ def _render_layered_figure_analysis(
472
  fig_key = widget_key(
473
  "load",
474
  f"{scope}_fig_state",
475
- _store_id(store),
476
  store.model_name,
477
  mask_strategy.value,
478
  figure_kind,
@@ -481,7 +572,7 @@ def _render_layered_figure_analysis(
481
  "persona_vector",
482
  persona_key,
483
  )
484
- filename = scope if n_components == 2 else f"{scope}_3d"
485
 
486
  if st.button(button_label, type="primary"):
487
  try:
@@ -549,7 +640,105 @@ def _render_source_select() -> str:
549
  return source
550
 
551
 
552
- def _build_store(source: str, model_name: str, mask_strategy: MaskStrategy) -> Store:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  if source == SOURCE_HUB:
554
  repo = st.text_input(
555
  "Hub repo",
@@ -557,16 +746,29 @@ def _build_store(source: str, model_name: str, mask_strategy: MaskStrategy) -> S
557
  key="compare:hub_repo",
558
  help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
559
  )
560
- return HFActivationStore(repo, model_name, mask_strategy=mask_strategy)
 
 
 
 
 
 
561
  artifacts_root = st.text_input(
562
  "Artifacts root",
563
  value=str(get_artifacts_dir() / "activations"),
564
  key="compare:artifacts_root",
565
  )
566
- return ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
 
 
 
 
 
 
 
567
 
568
 
569
- def render_compare_tab(model_name: str) -> None:
570
  """Render the compare tab."""
571
 
572
  st.title("Compare")
@@ -585,9 +787,9 @@ def render_compare_tab(model_name: str) -> None:
585
  analysis_mode = ANALYSIS_MODES[0]
586
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
587
 
588
- with st.expander("Source settings", expanded=False):
589
  mask_strategy = _render_mask_strategy_select(analysis_mode)
590
- store = _build_store(source, model_name, mask_strategy)
591
 
592
  if analysis_mode == "Cosine similarity":
593
  _render_cosine_similarity(store, mask_strategy)
 
 
1
  from collections.abc import Callable
2
  from dataclasses import dataclass
3
  from itertools import combinations
4
+ from pathlib import Path
5
 
6
  import streamlit as st
7
  from persona_data.environment import get_artifacts_dir
8
+ from persona_data.synth_persona import BASELINE_PERSONA_ID
9
  from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
10
+ from persona_vectors.artifacts import HFActivationStore
 
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.plots import (
13
  build_layered_figure,
 
16
  save_plot_html,
17
  )
18
 
19
+ from utils.compare_sources import (
20
+ DEFAULT_COMPARE_MODEL,
21
+ DEFAULT_HUB_REPO,
22
+ SOURCE_HUB,
23
+ SOURCE_LOCAL,
24
+ SOURCES,
25
+ Store,
26
+ activation_store_cached,
27
+ available_variants,
28
+ hub_layers_cached,
29
+ hub_models_by_mask_strategy,
30
+ list_layers_cached,
31
+ local_model_matches,
32
+ local_model_options_cached,
33
+ persona_names_cached,
34
+ personas_cached,
35
+ store_cache_parts,
36
+ store_id,
37
+ )
38
+ from utils.controls import render_mask_strategy_select
39
  from utils.helpers import (
40
  ANALYSIS_HELP_TEXT,
41
  ANALYSIS_MODES,
 
42
  prompt_variant_label,
43
  slugify,
44
  widget_key,
45
  )
46
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  def _filename(*parts: str) -> str:
49
  return "__".join(slugify(part) for part in parts if part)
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  # Keep compare-tab selection state separate so projection defaults do not
53
  # overwrite cosine similarity defaults.
54
  _LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
 
57
  _LAST_SOURCE_KEY = "compare:last_source"
58
 
59
 
60
+ def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
61
+ persona_id_normalized = persona_id.strip().lower()
62
+ persona_name_normalized = (persona_name or "").strip().lower()
63
+ return (
64
+ persona_id_normalized in {"assistant", BASELINE_PERSONA_ID.lower()}
65
+ or persona_name_normalized == "assistant"
66
+ )
67
+
68
+
69
  @dataclass(frozen=True)
70
  class CosineSelection:
71
  variants: list[str]
 
75
  persona_key: str
76
 
77
 
78
+ @dataclass(frozen=True)
79
+ class PersonaOptions:
80
+ regular_ids: list[str]
81
+ assistant_id: str | None
 
82
 
83
 
84
  def _layers_for_variant(
 
90
  if isinstance(store, HFActivationStore):
91
  if not persona_ids:
92
  return []
93
+ return hub_layers_cached(
94
  store.repo_id,
95
  store.model_name,
96
  mask_strategy.value,
97
  variant,
98
  persona_ids[0],
99
  )
100
+ return list_layers_cached(
101
  str(store.root_dir),
102
  store.model_name,
103
  [variant],
 
106
  )
107
 
108
 
109
+ def _load_persona_options(
110
  store: Store,
111
  variants: list[str],
112
  mask_strategy: MaskStrategy,
113
  *,
114
+ empty_message: str,
115
+ ) -> PersonaOptions | None:
116
+ source, location, model_name = store_cache_parts(store)
117
+ variant_key = tuple(variants)
118
+ persona_ids = personas_cached(
119
+ source,
120
+ location,
121
+ model_name,
122
+ mask_strategy.value,
123
+ variant_key,
124
+ )
125
+ if not persona_ids:
126
+ st.info(empty_message)
127
+ return None
128
+
129
+ persona_names = persona_names_cached(
130
+ source,
131
+ location,
132
+ model_name,
133
+ mask_strategy.value,
134
+ variant_key,
135
+ tuple(persona_ids),
136
+ )
137
+ assistant_ids = [
138
+ persona_id
139
+ for persona_id in persona_ids
140
+ if _is_assistant_persona(persona_id, persona_names.get(persona_id))
141
+ ]
142
+ assistant_id = next(
143
+ (
144
+ persona_id
145
+ for persona_id in assistant_ids
146
+ if persona_id == BASELINE_PERSONA_ID
147
+ ),
148
+ assistant_ids[0] if assistant_ids else None,
149
+ )
150
+ regular_ids = [persona_id for persona_id in persona_ids if persona_id not in assistant_ids]
151
+ if not regular_ids and assistant_id is None:
152
+ st.info("No personas found for this model and variant.")
153
+ return None
154
+ return PersonaOptions(regular_ids=regular_ids, assistant_id=assistant_id)
155
+
156
+
157
+ def _seed_persona_memory(
158
  remember_key: str,
159
+ options: PersonaOptions,
160
+ *,
161
+ default_all: bool,
162
+ ) -> tuple[int, bool]:
163
+ remembered_count_key = f"{remember_key}:count"
164
+ remembered_assistant_key = f"{remember_key}:include_assistant"
165
+ legacy_ids = st.session_state.get(remember_key, [])
166
+ if isinstance(legacy_ids, list) and legacy_ids:
167
+ st.session_state.setdefault(
168
+ remembered_count_key,
169
+ sum(persona_id in options.regular_ids for persona_id in legacy_ids),
170
+ )
171
+ st.session_state.setdefault(
172
+ remembered_assistant_key,
173
+ options.assistant_id in legacy_ids,
174
+ )
175
+
176
+ default_count = len(options.regular_ids) if default_all else min(1, len(options.regular_ids))
177
+ remembered_count = int(st.session_state.get(remembered_count_key, default_count))
178
+ persona_count = min(max(remembered_count, 0), len(options.regular_ids))
179
+ include_assistant = bool(
180
+ st.session_state.get(remembered_assistant_key, options.assistant_id is not None)
181
+ )
182
+ return persona_count, include_assistant
183
 
 
 
 
 
184
 
185
+ def _render_persona_count_controls(
186
+ store: Store,
187
+ variants: list[str],
188
+ mask_strategy: MaskStrategy,
189
+ widget_scope: str,
190
+ options: PersonaOptions,
191
+ *,
192
+ default_count: int,
193
+ include_assistant_default: bool,
194
+ ) -> tuple[int, bool]:
195
+ count_key = widget_key(
196
  "load",
197
+ "persona_count",
198
+ widget_scope,
199
+ store.model_name,
200
+ mask_strategy.value,
201
+ *variants,
202
+ )
203
+ assistant_key = widget_key(
204
+ "load",
205
+ "include_assistant",
206
  widget_scope,
207
  store.model_name,
208
  mask_strategy.value,
209
  *variants,
210
  )
211
 
212
+ if options.regular_ids:
213
+ persona_count = st.slider(
214
+ "Personas",
215
+ min_value=0 if options.assistant_id is not None else 1,
216
+ max_value=len(options.regular_ids),
217
+ value=default_count,
218
+ key=count_key,
219
+ help="Use the first N available non-assistant personas.",
220
+ )
221
+ else:
222
+ persona_count = 0
223
+ st.caption("No non-assistant personas are available for this selection.")
224
+ include_assistant = False
225
+ if options.assistant_id is not None:
226
+ include_assistant = st.checkbox(
227
+ "Include Assistant persona",
228
+ value=include_assistant_default,
229
+ key=assistant_key,
230
+ )
231
+ return persona_count, include_assistant
232
+
233
+
234
+ def _select_artifact_personas(
235
+ store: Store,
236
+ variants: list[str],
237
+ mask_strategy: MaskStrategy,
238
+ *,
239
+ widget_scope: str,
240
+ remember_key: str,
241
+ default_all: bool = False,
242
+ ) -> list[str]:
243
+ empty_message = (
244
+ "No personas have vectors for all selected variants. "
245
+ "Pick a single variant or change the source."
246
+ if len(variants) > 1
247
+ else "No personas found for this model and variant."
248
  )
249
+ options = _load_persona_options(
250
+ store,
251
+ variants,
252
+ mask_strategy,
253
+ empty_message=empty_message,
254
+ )
255
+ if options is None:
256
+ return []
257
+
258
+ default_count, include_assistant_default = _seed_persona_memory(
259
+ remember_key,
260
+ options,
261
+ default_all=default_all,
262
+ )
263
+ persona_count, include_assistant = _render_persona_count_controls(
264
+ store,
265
+ variants,
266
+ mask_strategy,
267
+ widget_scope,
268
+ options,
269
+ default_count=default_count,
270
+ include_assistant_default=include_assistant_default,
271
+ )
272
+
273
+ persona_ids = options.regular_ids[:persona_count]
274
+ if include_assistant and options.assistant_id is not None:
275
+ persona_ids.append(options.assistant_id)
276
+
277
+ remembered_count_key = f"{remember_key}:count"
278
+ remembered_assistant_key = f"{remember_key}:include_assistant"
279
+ st.session_state[remembered_count_key] = persona_count
280
+ st.session_state[remembered_assistant_key] = include_assistant
281
+ st.session_state[remember_key] = persona_ids
282
+
283
+ if not persona_ids:
284
+ st.info("Select at least one persona or include the Assistant persona.")
285
+ return []
286
+
287
+ regular_label = f"{persona_count} persona{'s' if persona_count != 1 else ''}"
288
+ assistant_label = " plus Assistant" if include_assistant and options.assistant_id else ""
289
+ st.caption(f"Using {regular_label}{assistant_label}.")
290
+ return persona_ids
291
 
292
 
293
  def _render_save_buttons(
 
305
 
306
 
307
  def _render_mask_strategy_select(scope: str) -> MaskStrategy:
308
+ return render_mask_strategy_select(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  key=widget_key("load", "mask_strategy", scope),
310
+ last_key=_LAST_MASK_STRATEGY_KEY,
311
  help="Which extracted activation set to load.",
312
  )
 
 
313
 
314
 
315
  def _render_cosine_selection(
316
  store: Store,
317
  mask_strategy: MaskStrategy,
318
  ) -> CosineSelection | None:
319
+ variants = available_variants(store, mask_strategy)
320
  if len(variants) < 2:
321
  st.info("Need at least two variants with saved vectors for cosine comparison.")
322
  return None
 
329
  options=variants,
330
  index=0,
331
  format_func=prompt_variant_label,
332
+ key=widget_key("load", "variant_a", store_id(store)),
333
  )
334
  with col2:
335
  variant_b = st.selectbox(
 
337
  options=variants,
338
  index=min(1, len(variants) - 1),
339
  format_func=prompt_variant_label,
340
+ key=widget_key("load", "variant_b", store_id(store)),
341
  )
342
 
343
  if variant_a == variant_b:
344
  st.warning("Choose two different variants to compare.")
345
  return None
346
 
347
+ persona_ids = _select_artifact_personas(
348
  store,
349
  [variant_a, variant_b],
350
  mask_strategy,
351
+ widget_scope=f"cosine:{store_id(store)}",
352
  remember_key=_LAST_COSINE_PERSONAS_KEY,
353
  )
354
  if not persona_ids:
 
443
  cosine_fig_key = widget_key(
444
  "load",
445
  "cosine_fig_state",
446
+ store_id(store),
447
  store.model_name,
448
  mask_strategy.value,
449
  selection.variant_a,
 
472
  key=widget_key(
473
  "load",
474
  "compare_vectors",
475
+ store_id(store),
476
  store.model_name,
477
  mask_strategy.value,
478
  selection.variant_a,
 
507
  mask_strategy: MaskStrategy,
508
  scope: str,
509
  ) -> tuple[str, list[str], str, list[int]] | None:
510
+ variants = available_variants(store, mask_strategy)
511
  if not variants:
512
  st.info("No variants with saved vectors for this model.")
513
  return None
 
516
  options=variants,
517
  index=variants.index("biography") if "biography" in variants else 0,
518
  format_func=prompt_variant_label,
519
+ key=widget_key("load", "variant", scope, store_id(store)),
520
  )
521
+ persona_ids = _select_artifact_personas(
522
  store,
523
  [variant],
524
  mask_strategy,
525
+ widget_scope=f"{scope}:{store_id(store)}",
526
  remember_key=_LAST_PROJECTION_PERSONAS_KEY,
527
  default_all=True,
528
  )
 
535
  st.info("No shared layers are available for the selected personas.")
536
  return None
537
 
538
+ st.caption(f"Using all {len(layer_options)} available layer(s).")
539
+ return variant, persona_ids, persona_key, layer_options
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
 
541
 
542
  def _render_layered_figure_analysis(
 
563
  fig_key = widget_key(
564
  "load",
565
  f"{scope}_fig_state",
566
+ store_id(store),
567
  store.model_name,
568
  mask_strategy.value,
569
  figure_kind,
 
572
  "persona_vector",
573
  persona_key,
574
  )
575
+ filename = scope
576
 
577
  if st.button(button_label, type="primary"):
578
  try:
 
640
  return source
641
 
642
 
643
+ def _render_hub_model_select(
644
+ repo_id: str,
645
+ mask_strategy: MaskStrategy,
646
+ ) -> str:
647
+ fallback_model = st.session_state.get(
648
+ "compare:hub_model_fallback",
649
+ DEFAULT_COMPARE_MODEL,
650
+ )
651
+ try:
652
+ models_by_strategy = hub_models_by_mask_strategy(repo_id)
653
+ except Exception as exc:
654
+ st.warning(f"Could not load Hub configs for `{repo_id}`: {exc}")
655
+ return st.text_input(
656
+ "Hub model",
657
+ value=fallback_model,
658
+ key="compare:hub_model_fallback",
659
+ help="Compare-only model id to use if Hub config discovery is unavailable.",
660
+ )
661
+
662
+ model_options = models_by_strategy.get(mask_strategy, [])
663
+ if not model_options:
664
+ st.warning(
665
+ f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`."
666
+ )
667
+ return st.text_input(
668
+ "Hub model",
669
+ value=fallback_model,
670
+ key="compare:hub_model_fallback",
671
+ help="Compare-only model id to use for this Hub repo.",
672
+ )
673
+
674
+ previous_model = st.session_state.get(
675
+ widget_key("load", "hub_model", repo_id, mask_strategy.value),
676
+ fallback_model,
677
+ )
678
+ default_model = (
679
+ previous_model if previous_model in model_options else model_options[0]
680
+ )
681
+
682
+ return st.selectbox(
683
+ "Hub model",
684
+ options=model_options,
685
+ index=model_options.index(default_model),
686
+ key=widget_key("load", "hub_model", repo_id, mask_strategy.value),
687
+ help="Models with vectors in the selected Hub repo and mask strategy.",
688
+ )
689
+
690
+
691
+ def _render_local_model_select(
692
+ artifacts_root: str,
693
+ mask_strategy: MaskStrategy,
694
+ ) -> str:
695
+ fallback_model = st.session_state.get("compare:local_model", DEFAULT_COMPARE_MODEL)
696
+ model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
697
+ if not model_options:
698
+ return st.text_input(
699
+ "Local model",
700
+ value=fallback_model,
701
+ key="compare:local_model",
702
+ help="Compare-only local model id or path.",
703
+ )
704
+
705
+ custom = st.toggle(
706
+ "Custom local model",
707
+ value=False,
708
+ key="compare:local_model_custom_enabled",
709
+ help="Enter a model id/path manually instead of choosing from activation directories.",
710
+ )
711
+ if custom:
712
+ return st.text_input(
713
+ "Local model",
714
+ value=fallback_model,
715
+ key="compare:local_model",
716
+ help="Compare-only local model id or path.",
717
+ )
718
+
719
+ previous_model = st.session_state.get("compare:local_model_select", fallback_model)
720
+ if not any(local_model_matches(previous_model, option) for option in model_options):
721
+ previous_model = fallback_model
722
+ default_model = next(
723
+ (
724
+ option
725
+ for option in model_options
726
+ if local_model_matches(option, previous_model)
727
+ ),
728
+ model_options[0],
729
+ )
730
+ selected = st.selectbox(
731
+ "Local model",
732
+ options=model_options,
733
+ index=model_options.index(default_model),
734
+ key="compare:local_model_select",
735
+ help="Models discovered under the selected artifacts root.",
736
+ )
737
+ st.session_state["compare:local_model"] = selected
738
+ return selected
739
+
740
+
741
+ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
742
  if source == SOURCE_HUB:
743
  repo = st.text_input(
744
  "Hub repo",
 
746
  key="compare:hub_repo",
747
  help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
748
  )
749
+ hub_model_name = _render_hub_model_select(repo, mask_strategy)
750
+ return activation_store_cached(
751
+ SOURCE_HUB,
752
+ repo,
753
+ hub_model_name,
754
+ mask_strategy.value,
755
+ )
756
  artifacts_root = st.text_input(
757
  "Artifacts root",
758
  value=str(get_artifacts_dir() / "activations"),
759
  key="compare:artifacts_root",
760
  )
761
+ artifacts_root = str(Path(artifacts_root).expanduser())
762
+ local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
763
+ return activation_store_cached(
764
+ SOURCE_LOCAL,
765
+ artifacts_root,
766
+ local_model_name,
767
+ mask_strategy.value,
768
+ )
769
 
770
 
771
+ def render_compare_tab() -> None:
772
  """Render the compare tab."""
773
 
774
  st.title("Compare")
 
787
  analysis_mode = ANALYSIS_MODES[0]
788
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
789
 
790
+ with st.expander("Source settings", expanded=True):
791
  mask_strategy = _render_mask_strategy_select(analysis_mode)
792
+ store = _build_store(source, mask_strategy)
793
 
794
  if analysis_mode == "Cosine similarity":
795
  _render_cosine_similarity(store, mask_strategy)
tabs/extract.py CHANGED
@@ -13,6 +13,7 @@ from persona_vectors.extraction import (
13
  from persona_vectors.preview import TokenSegment, preview_token_segments
14
 
15
  from utils.datasets import load_dataset, load_persona_list
 
16
  from utils.helpers import (
17
  NDIF_STATUS_ICONS,
18
  persona_label,
@@ -211,28 +212,11 @@ def _render_mask_strategy_select(
211
  remote: bool,
212
  dataset_source: str,
213
  ) -> MaskStrategy:
214
- last_strategy = st.session_state.get(
215
- _LAST_MASK_STRATEGY_KEY,
216
- MaskStrategy.ANSWER_MEAN.value,
217
- )
218
- strategy_options = list(MaskStrategy)
219
- mask_strategy = st.selectbox(
220
- "Mask strategy",
221
- options=strategy_options,
222
- index=next(
223
- (
224
- idx
225
- for idx, strategy in enumerate(strategy_options)
226
- if strategy.value == last_strategy
227
- ),
228
- 0,
229
- ),
230
- format_func=lambda s: s.value.replace("_", " ").title(),
231
  key=_extract_widget_key(model_name, remote, dataset_source, "mask_strategy"),
 
232
  help="Which tokens contribute to the averaged hidden state.",
233
  )
234
- st.session_state[_LAST_MASK_STRATEGY_KEY] = mask_strategy.value
235
- return mask_strategy
236
 
237
 
238
  def _collect_runs(
 
13
  from persona_vectors.preview import TokenSegment, preview_token_segments
14
 
15
  from utils.datasets import load_dataset, load_persona_list
16
+ from utils.controls import render_mask_strategy_select
17
  from utils.helpers import (
18
  NDIF_STATUS_ICONS,
19
  persona_label,
 
212
  remote: bool,
213
  dataset_source: str,
214
  ) -> MaskStrategy:
215
+ return render_mask_strategy_select(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  key=_extract_widget_key(model_name, remote, dataset_source, "mask_strategy"),
217
+ last_key=_LAST_MASK_STRATEGY_KEY,
218
  help="Which tokens contribute to the averaged hidden state.",
219
  )
 
 
220
 
221
 
222
  def _collect_runs(
utils/compare_sources.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import streamlit as st
5
+ from persona_vectors.artifacts import ActivationStore, HFActivationStore
6
+ from persona_vectors.artifacts import list_layers as list_local_layers
7
+ from persona_vectors.artifacts import model_dir_name
8
+ from persona_vectors.extraction import MaskStrategy
9
+
10
+ Store = ActivationStore | HFActivationStore
11
+
12
+ DEFAULT_HUB_REPO = os.environ.get(
13
+ "PERSONA_VECTORS_HUB_REPO",
14
+ "implicit-personalization/synth-persona-vectors",
15
+ )
16
+ DEFAULT_COMPARE_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
17
+ SOURCE_HUB = "Hugging Face Hub"
18
+ SOURCE_LOCAL = "Local activations"
19
+ SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
20
+
21
+ list_layers_cached = st.cache_data(show_spinner=False)(list_local_layers)
22
+
23
+
24
+ @st.cache_resource(show_spinner=False)
25
+ def activation_store_cached(
26
+ source: str,
27
+ location: str,
28
+ model_name: str,
29
+ mask_strategy_value: str,
30
+ ) -> Store:
31
+ mask_strategy = MaskStrategy(mask_strategy_value)
32
+ if source == SOURCE_HUB:
33
+ return HFActivationStore(location, model_name, mask_strategy=mask_strategy)
34
+ return ActivationStore(model_name, location, mask_strategy=mask_strategy)
35
+
36
+
37
+ @st.cache_data(show_spinner=False, ttl=10)
38
+ def available_variants_cached(
39
+ source: str,
40
+ location: str,
41
+ model_name: str,
42
+ mask_strategy_value: str,
43
+ ) -> list[str]:
44
+ store = activation_store_cached(source, location, model_name, mask_strategy_value)
45
+ return store.available_variants()
46
+
47
+
48
+ @st.cache_data(show_spinner=False, ttl=10)
49
+ def personas_cached(
50
+ source: str,
51
+ location: str,
52
+ model_name: str,
53
+ mask_strategy_value: str,
54
+ variants: tuple[str, ...],
55
+ ) -> list[str]:
56
+ store = activation_store_cached(source, location, model_name, mask_strategy_value)
57
+ return store.list_personas(
58
+ list(variants),
59
+ mask_strategy=MaskStrategy(mask_strategy_value),
60
+ )
61
+
62
+
63
+ @st.cache_data(show_spinner=False, ttl=10)
64
+ def persona_names_cached(
65
+ source: str,
66
+ location: str,
67
+ model_name: str,
68
+ mask_strategy_value: str,
69
+ variants: tuple[str, ...],
70
+ persona_ids: tuple[str, ...],
71
+ ) -> dict[str, str]:
72
+ store = activation_store_cached(source, location, model_name, mask_strategy_value)
73
+ return store.persona_names(
74
+ list(persona_ids),
75
+ variants=list(variants),
76
+ mask_strategy=MaskStrategy(mask_strategy_value),
77
+ )
78
+
79
+
80
+ @st.cache_data(show_spinner=False, ttl=10)
81
+ def local_model_options_cached(
82
+ artifacts_root: str, mask_strategy_value: str
83
+ ) -> list[str]:
84
+ root = Path(artifacts_root).expanduser()
85
+ if not root.exists() or not root.is_dir():
86
+ return []
87
+
88
+ options = []
89
+ try:
90
+ model_roots = sorted(path for path in root.iterdir() if path.is_dir())
91
+ except OSError:
92
+ return []
93
+
94
+ for model_root in model_roots:
95
+ strategy_root = model_root / mask_strategy_value
96
+ if not strategy_root.is_dir():
97
+ continue
98
+ variant_roots = (
99
+ variant_root
100
+ for variant_root in strategy_root.iterdir()
101
+ if variant_root.is_dir()
102
+ )
103
+ if any(
104
+ (variant_root / "manifest.json").exists() for variant_root in variant_roots
105
+ ):
106
+ options.append(model_root.name.replace("__", "/"))
107
+ return options
108
+
109
+
110
+ @st.cache_data(show_spinner=False)
111
+ def hub_config_names_cached(repo_id: str) -> list[str]:
112
+ try:
113
+ from huggingface_hub import get_dataset_config_names
114
+ except ImportError:
115
+ from datasets import get_dataset_config_names
116
+
117
+ return sorted(get_dataset_config_names(repo_id))
118
+
119
+
120
+ @st.cache_data(show_spinner=False)
121
+ def hub_layers_cached(
122
+ repo_id: str,
123
+ model_name: str,
124
+ mask_strategy_value: str,
125
+ variant: str,
126
+ persona_id: str,
127
+ ) -> list[int]:
128
+ store = HFActivationStore(
129
+ repo_id,
130
+ model_name,
131
+ mask_strategy=MaskStrategy(mask_strategy_value),
132
+ )
133
+ sample = store.load(variant, persona_id)
134
+ return list(range(int(sample.shape[0])))
135
+
136
+
137
+ def parse_hub_config_name(config_name: str) -> tuple[str, MaskStrategy] | None:
138
+ for strategy in MaskStrategy:
139
+ suffix = f"__{strategy.value}"
140
+ if config_name.endswith(suffix):
141
+ model_key = config_name[: -len(suffix)]
142
+ return model_key.replace("__", "/"), strategy
143
+ return None
144
+
145
+
146
+ def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
147
+ models_by_strategy: dict[MaskStrategy, set[str]] = {
148
+ strategy: set() for strategy in MaskStrategy
149
+ }
150
+ for config_name in hub_config_names_cached(repo_id):
151
+ parsed = parse_hub_config_name(config_name)
152
+ if parsed is None:
153
+ continue
154
+ model_name, strategy = parsed
155
+ models_by_strategy[strategy].add(model_name)
156
+ return {
157
+ strategy: sorted(models)
158
+ for strategy, models in models_by_strategy.items()
159
+ if models
160
+ }
161
+
162
+
163
+ def store_cache_parts(store: Store) -> tuple[str, str, str]:
164
+ if isinstance(store, HFActivationStore):
165
+ return SOURCE_HUB, store.repo_id, store.model_name
166
+ return SOURCE_LOCAL, str(store.root_dir), store.model_name
167
+
168
+
169
+ def store_id(store: Store) -> str:
170
+ if isinstance(store, HFActivationStore):
171
+ return f"hub:{store.repo_id}"
172
+ return f"local:{store.root_dir}"
173
+
174
+
175
+ def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
176
+ source, location, model_name = store_cache_parts(store)
177
+ return available_variants_cached(
178
+ source,
179
+ location,
180
+ model_name,
181
+ mask_strategy.value,
182
+ )
183
+
184
+
185
+ def local_model_matches(left: str, right: str) -> bool:
186
+ return model_dir_name(left) == model_dir_name(right)
utils/controls.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from persona_vectors.extraction import MaskStrategy
3
+
4
+
5
+ def render_mask_strategy_select(
6
+ *,
7
+ key: str,
8
+ last_key: str,
9
+ help_text: str,
10
+ ) -> MaskStrategy:
11
+ last_strategy = st.session_state.get(last_key, MaskStrategy.ANSWER_MEAN.value)
12
+ strategies = list(MaskStrategy)
13
+ selected = st.selectbox(
14
+ "Mask strategy",
15
+ options=strategies,
16
+ index=next(
17
+ (
18
+ idx
19
+ for idx, strategy in enumerate(strategies)
20
+ if strategy.value == last_strategy
21
+ ),
22
+ 0,
23
+ ),
24
+ format_func=lambda strategy: strategy.value.replace("_", " ").title(),
25
+ key=key,
26
+ help=help_text,
27
+ )
28
+ st.session_state[last_key] = selected.value
29
+ return selected
uv.lock CHANGED
@@ -376,7 +376,7 @@ name = "cuda-bindings"
376
  version = "13.2.0"
377
  source = { registry = "https://pypi.org/simple" }
378
  dependencies = [
379
- { name = "cuda-pathfinder" },
380
  ]
381
  wheels = [
382
  { url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404, upload-time = "2026-03-11T00:12:44.041Z" },
@@ -407,37 +407,37 @@ wheels = [
407
 
408
  [package.optional-dependencies]
409
  cublas = [
410
- { name = "nvidia-cublas", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
411
  ]
412
  cudart = [
413
- { name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
414
  ]
415
  cufft = [
416
- { name = "nvidia-cufft", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
417
  ]
418
  cufile = [
419
  { name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
420
  ]
421
  cupti = [
422
- { name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
423
  ]
424
  curand = [
425
- { name = "nvidia-curand", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
426
  ]
427
  cusolver = [
428
- { name = "nvidia-cusolver", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
429
  ]
430
  cusparse = [
431
- { name = "nvidia-cusparse", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
432
  ]
433
  nvjitlink = [
434
- { name = "nvidia-nvjitlink", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
435
  ]
436
  nvrtc = [
437
- { name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
438
  ]
439
  nvtx = [
440
- { name = "nvidia-nvtx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
441
  ]
442
 
443
  [[package]]
@@ -1326,7 +1326,7 @@ name = "nvidia-cudnn-cu13"
1326
  version = "9.19.0.56"
1327
  source = { registry = "https://pypi.org/simple" }
1328
  dependencies = [
1329
- { name = "nvidia-cublas" },
1330
  ]
1331
  wheels = [
1332
  { url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
@@ -1338,7 +1338,7 @@ name = "nvidia-cufft"
1338
  version = "12.0.0.61"
1339
  source = { registry = "https://pypi.org/simple" }
1340
  dependencies = [
1341
- { name = "nvidia-nvjitlink" },
1342
  ]
1343
  wheels = [
1344
  { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
@@ -1368,9 +1368,9 @@ name = "nvidia-cusolver"
1368
  version = "12.0.4.66"
1369
  source = { registry = "https://pypi.org/simple" }
1370
  dependencies = [
1371
- { name = "nvidia-cublas" },
1372
- { name = "nvidia-cusparse" },
1373
- { name = "nvidia-nvjitlink" },
1374
  ]
1375
  wheels = [
1376
  { url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
@@ -1382,7 +1382,7 @@ name = "nvidia-cusparse"
1382
  version = "12.6.3.3"
1383
  source = { registry = "https://pypi.org/simple" }
1384
  dependencies = [
1385
- { name = "nvidia-nvjitlink" },
1386
  ]
1387
  wheels = [
1388
  { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
@@ -1575,10 +1575,12 @@ wheels = [
1575
 
1576
  [[package]]
1577
  name = "persona-ui"
1578
- version = "0.3.0"
1579
  source = { virtual = "." }
1580
  dependencies = [
1581
  { name = "catppuccin" },
 
 
1582
  { name = "persona-data" },
1583
  { name = "persona-vectors" },
1584
  { name = "plotly" },
@@ -1589,6 +1591,8 @@ dependencies = [
1589
  [package.metadata]
1590
  requires-dist = [
1591
  { name = "catppuccin", specifier = ">=2.5.0" },
 
 
1592
  { name = "persona-data", specifier = ">=0.4.2" },
1593
  { name = "persona-vectors", specifier = ">=0.6.4" },
1594
  { name = "plotly", specifier = ">=6.6.0" },
 
376
  version = "13.2.0"
377
  source = { registry = "https://pypi.org/simple" }
378
  dependencies = [
379
+ { name = "cuda-pathfinder", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
380
  ]
381
  wheels = [
382
  { url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404, upload-time = "2026-03-11T00:12:44.041Z" },
 
407
 
408
  [package.optional-dependencies]
409
  cublas = [
410
+ { name = "nvidia-cublas", marker = "sys_platform == 'linux'" },
411
  ]
412
  cudart = [
413
+ { name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux'" },
414
  ]
415
  cufft = [
416
+ { name = "nvidia-cufft", marker = "sys_platform == 'linux'" },
417
  ]
418
  cufile = [
419
  { name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
420
  ]
421
  cupti = [
422
+ { name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux'" },
423
  ]
424
  curand = [
425
+ { name = "nvidia-curand", marker = "sys_platform == 'linux'" },
426
  ]
427
  cusolver = [
428
+ { name = "nvidia-cusolver", marker = "sys_platform == 'linux'" },
429
  ]
430
  cusparse = [
431
+ { name = "nvidia-cusparse", marker = "sys_platform == 'linux'" },
432
  ]
433
  nvjitlink = [
434
+ { name = "nvidia-nvjitlink", marker = "sys_platform == 'linux'" },
435
  ]
436
  nvrtc = [
437
+ { name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux'" },
438
  ]
439
  nvtx = [
440
+ { name = "nvidia-nvtx", marker = "sys_platform == 'linux'" },
441
  ]
442
 
443
  [[package]]
 
1326
  version = "9.19.0.56"
1327
  source = { registry = "https://pypi.org/simple" }
1328
  dependencies = [
1329
+ { name = "nvidia-cublas", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
1330
  ]
1331
  wheels = [
1332
  { url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
 
1338
  version = "12.0.0.61"
1339
  source = { registry = "https://pypi.org/simple" }
1340
  dependencies = [
1341
+ { name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
1342
  ]
1343
  wheels = [
1344
  { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
 
1368
  version = "12.0.4.66"
1369
  source = { registry = "https://pypi.org/simple" }
1370
  dependencies = [
1371
+ { name = "nvidia-cublas", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
1372
+ { name = "nvidia-cusparse", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
1373
+ { name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
1374
  ]
1375
  wheels = [
1376
  { url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
 
1382
  version = "12.6.3.3"
1383
  source = { registry = "https://pypi.org/simple" }
1384
  dependencies = [
1385
+ { name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
1386
  ]
1387
  wheels = [
1388
  { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
 
1575
 
1576
  [[package]]
1577
  name = "persona-ui"
1578
+ version = "0.4.0"
1579
  source = { virtual = "." }
1580
  dependencies = [
1581
  { name = "catppuccin" },
1582
+ { name = "datasets" },
1583
+ { name = "huggingface-hub" },
1584
  { name = "persona-data" },
1585
  { name = "persona-vectors" },
1586
  { name = "plotly" },
 
1591
  [package.metadata]
1592
  requires-dist = [
1593
  { name = "catppuccin", specifier = ">=2.5.0" },
1594
+ { name = "datasets", specifier = ">=4.8.5" },
1595
+ { name = "huggingface-hub", specifier = ">=1.14.0" },
1596
  { name = "persona-data", specifier = ">=0.4.2" },
1597
  { name = "persona-vectors", specifier = ">=0.6.4" },
1598
  { name = "plotly", specifier = ">=6.6.0" },