Jac-Zac commited on
Commit
610d258
·
1 Parent(s): 39e6e65

Updated versions

Browse files
Files changed (3) hide show
  1. pyproject.toml +2 -2
  2. tabs/compare.py +22 -41
  3. uv.lock +5 -5
pyproject.toml CHANGED
@@ -1,11 +1,11 @@
1
  [project]
2
  name = "persona-ui"
3
- version = "0.2.0"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
- "persona-vectors>=0.4.1",
9
  "persona-data>=0.2.5",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
 
1
  [project]
2
  name = "persona-ui"
3
+ version = "0.2.1"
4
  description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "persona-vectors>=0.4.2",
9
  "persona-data>=0.2.5",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
tabs/compare.py CHANGED
@@ -7,10 +7,8 @@ from persona_vectors.analysis import (
7
  load_persona_mean_samples,
8
  load_variant_mean_samples,
9
  )
10
- from persona_vectors.artifacts import PERSONA_VARIANTS, ActivationStore
11
  from persona_vectors.artifacts import list_layers as list_available_layers
12
- from persona_vectors.artifacts import list_personas as list_available_personas
13
- from persona_vectors.artifacts import load_persona_names
14
  from persona_vectors.extraction import MaskStrategy
15
  from persona_vectors.plots import (
16
  build_layered_figure,
@@ -52,18 +50,10 @@ def _select_artifact_personas(
52
  remember_key: str,
53
  default_all: bool = False,
54
  ) -> tuple[list[str], dict[str, str]]:
55
- persona_options = list_available_personas(
56
- store.root_dir,
57
- store.model_name,
58
- variants,
59
- mask_strategy=mask_strategy,
60
- )
61
- persona_names = load_persona_names(
62
- store.root_dir,
63
- store.model_name,
64
- variants,
65
  persona_options,
66
- mask_strategy=mask_strategy,
67
  )
68
  if not persona_options:
69
  if len(variants) > 1:
@@ -158,7 +148,8 @@ def _render_cosine_similarity(
158
  store: ActivationStore,
159
  mask_strategy: MaskStrategy,
160
  ) -> None:
161
- if len(PERSONA_VARIANTS) < 2:
 
162
  st.info("Need at least two non-baseline variants for cosine comparison.")
163
  return
164
 
@@ -166,7 +157,7 @@ def _render_cosine_similarity(
166
  with col1:
167
  variant_a = st.selectbox(
168
  "Variant A",
169
- options=PERSONA_VARIANTS,
170
  index=0,
171
  format_func=prompt_variant_label,
172
  key=widget_key("load", "variant_a"),
@@ -174,8 +165,8 @@ def _render_cosine_similarity(
174
  with col2:
175
  variant_b = st.selectbox(
176
  "Variant B",
177
- options=PERSONA_VARIANTS,
178
- index=min(1, len(PERSONA_VARIANTS) - 1),
179
  format_func=prompt_variant_label,
180
  key=widget_key("load", "variant_b"),
181
  )
@@ -215,16 +206,14 @@ def _render_cosine_similarity(
215
  "cosine_pairs",
216
  store.model_name,
217
  mask_strategy.value,
218
- "_".join(PERSONA_VARIANTS),
219
  )
220
 
221
  if st.button("Compare vectors", type="primary"):
222
  try:
223
  variant_samples = load_variant_mean_samples(
224
- store.root_dir,
225
- store.model_name,
226
  [variant_a, variant_b],
227
- mask_strategy=mask_strategy,
228
  persona_ids=persona_ids,
229
  )
230
  except Exception as exc:
@@ -249,16 +238,14 @@ def _render_cosine_similarity(
249
 
250
  pair_traces = []
251
  pair_errors = []
252
- for left, right in combinations(PERSONA_VARIANTS, 2):
253
  try:
254
  pair_samples = (
255
  variant_samples
256
  if {left, right} == {variant_a, variant_b}
257
  else load_variant_mean_samples(
258
- store.root_dir,
259
- store.model_name,
260
  [left, right],
261
- mask_strategy=mask_strategy,
262
  persona_ids=persona_ids,
263
  )
264
  )
@@ -313,12 +300,13 @@ def _select_single_variant_samples(
313
  mask_strategy: MaskStrategy,
314
  scope: str,
315
  ) -> tuple[str, list[str], str, list[int]] | None:
 
316
  variant = st.selectbox(
317
  "Variant",
318
- options=PERSONA_VARIANTS,
319
  index=(
320
- PERSONA_VARIANTS.index("biography")
321
- if "biography" in PERSONA_VARIANTS
322
  else 0
323
  ),
324
  format_func=prompt_variant_label,
@@ -370,13 +358,9 @@ def _select_single_variant_samples(
370
 
371
  def _baseline_available(
372
  store: ActivationStore,
373
- mask_strategy: MaskStrategy,
374
  ) -> bool:
375
- return BASELINE_PERSONA_ID in list_available_personas(
376
- store.root_dir,
377
- store.model_name,
378
  [BASELINE_PERSONA_ID],
379
- mask_strategy=mask_strategy,
380
  warn_missing=False,
381
  )
382
 
@@ -386,7 +370,7 @@ def _render_baseline_reference_toggle(
386
  mask_strategy: MaskStrategy,
387
  scope: str,
388
  ) -> bool:
389
- available = _baseline_available(store, mask_strategy)
390
  return st.checkbox(
391
  "Include Assistant baseline reference",
392
  value=available,
@@ -442,8 +426,7 @@ def _render_similarity_matrix(
442
  if st.button("Generate similarity matrix", type="primary"):
443
  try:
444
  samples = load_persona_mean_samples(
445
- store.root_dir,
446
- store.model_name,
447
  variant,
448
  mask_strategy=mask_strategy,
449
  persona_ids=persona_ids,
@@ -534,8 +517,7 @@ def _render_embedding_analysis(
534
  if st.button(f"Generate {analysis_mode} projection", type="primary"):
535
  try:
536
  samples = load_persona_mean_samples(
537
- store.root_dir,
538
- store.model_name,
539
  variant,
540
  mask_strategy=mask_strategy,
541
  persona_ids=persona_ids,
@@ -575,8 +557,6 @@ def render_compare_tab(model_name: str) -> None:
575
  value=str(get_artifacts_dir() / "activations"),
576
  )
577
 
578
- store = ActivationStore(model_name, artifacts_root)
579
-
580
  analysis_mode = st.segmented_control(
581
  "Analysis mode",
582
  options=ANALYSIS_MODES,
@@ -588,6 +568,7 @@ def render_compare_tab(model_name: str) -> None:
588
  analysis_mode = ANALYSIS_MODES[0]
589
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
590
  mask_strategy = _render_mask_strategy_select(analysis_mode)
 
591
 
592
  if analysis_mode == "Cosine similarity":
593
  _render_cosine_similarity(store, mask_strategy)
 
7
  load_persona_mean_samples,
8
  load_variant_mean_samples,
9
  )
10
+ from persona_vectors.artifacts import ActivationStore
11
  from persona_vectors.artifacts import list_layers as list_available_layers
 
 
12
  from persona_vectors.extraction import MaskStrategy
13
  from persona_vectors.plots import (
14
  build_layered_figure,
 
50
  remember_key: str,
51
  default_all: bool = False,
52
  ) -> tuple[list[str], dict[str, str]]:
53
+ persona_options = store.list_personas(variants)
54
+ persona_names = store.persona_names(
 
 
 
 
 
 
 
 
55
  persona_options,
56
+ variants=variants,
57
  )
58
  if not persona_options:
59
  if len(variants) > 1:
 
148
  store: ActivationStore,
149
  mask_strategy: MaskStrategy,
150
  ) -> None:
151
+ variants = list(store.variants)
152
+ if len(variants) < 2:
153
  st.info("Need at least two non-baseline variants for cosine comparison.")
154
  return
155
 
 
157
  with col1:
158
  variant_a = st.selectbox(
159
  "Variant A",
160
+ options=variants,
161
  index=0,
162
  format_func=prompt_variant_label,
163
  key=widget_key("load", "variant_a"),
 
165
  with col2:
166
  variant_b = st.selectbox(
167
  "Variant B",
168
+ options=variants,
169
+ index=min(1, len(variants) - 1),
170
  format_func=prompt_variant_label,
171
  key=widget_key("load", "variant_b"),
172
  )
 
206
  "cosine_pairs",
207
  store.model_name,
208
  mask_strategy.value,
209
+ "_".join(variants),
210
  )
211
 
212
  if st.button("Compare vectors", type="primary"):
213
  try:
214
  variant_samples = load_variant_mean_samples(
215
+ store,
 
216
  [variant_a, variant_b],
 
217
  persona_ids=persona_ids,
218
  )
219
  except Exception as exc:
 
238
 
239
  pair_traces = []
240
  pair_errors = []
241
+ for left, right in combinations(variants, 2):
242
  try:
243
  pair_samples = (
244
  variant_samples
245
  if {left, right} == {variant_a, variant_b}
246
  else load_variant_mean_samples(
247
+ store,
 
248
  [left, right],
 
249
  persona_ids=persona_ids,
250
  )
251
  )
 
300
  mask_strategy: MaskStrategy,
301
  scope: str,
302
  ) -> tuple[str, list[str], str, list[int]] | None:
303
+ variants = list(store.variants)
304
  variant = st.selectbox(
305
  "Variant",
306
+ options=variants,
307
  index=(
308
+ variants.index("biography")
309
+ if "biography" in variants
310
  else 0
311
  ),
312
  format_func=prompt_variant_label,
 
358
 
359
  def _baseline_available(
360
  store: ActivationStore,
 
361
  ) -> bool:
362
+ return BASELINE_PERSONA_ID in store.list_personas(
 
 
363
  [BASELINE_PERSONA_ID],
 
364
  warn_missing=False,
365
  )
366
 
 
370
  mask_strategy: MaskStrategy,
371
  scope: str,
372
  ) -> bool:
373
+ available = _baseline_available(store)
374
  return st.checkbox(
375
  "Include Assistant baseline reference",
376
  value=available,
 
426
  if st.button("Generate similarity matrix", type="primary"):
427
  try:
428
  samples = load_persona_mean_samples(
429
+ store,
 
430
  variant,
431
  mask_strategy=mask_strategy,
432
  persona_ids=persona_ids,
 
517
  if st.button(f"Generate {analysis_mode} projection", type="primary"):
518
  try:
519
  samples = load_persona_mean_samples(
520
+ store,
 
521
  variant,
522
  mask_strategy=mask_strategy,
523
  persona_ids=persona_ids,
 
557
  value=str(get_artifacts_dir() / "activations"),
558
  )
559
 
 
 
560
  analysis_mode = st.segmented_control(
561
  "Analysis mode",
562
  options=ANALYSIS_MODES,
 
568
  analysis_mode = ANALYSIS_MODES[0]
569
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
570
  mask_strategy = _render_mask_strategy_select(analysis_mode)
571
+ store = ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
572
 
573
  if analysis_mode == "Cosine similarity":
574
  _render_cosine_similarity(store, mask_strategy)
uv.lock CHANGED
@@ -1225,7 +1225,7 @@ wheels = [
1225
 
1226
  [[package]]
1227
  name = "persona-ui"
1228
- version = "0.2.0"
1229
  source = { virtual = "." }
1230
  dependencies = [
1231
  { name = "persona-data" },
@@ -1239,7 +1239,7 @@ dependencies = [
1239
  [package.metadata]
1240
  requires-dist = [
1241
  { name = "persona-data", specifier = ">=0.2.5" },
1242
- { name = "persona-vectors", specifier = ">=0.4.1" },
1243
  { name = "plotly", specifier = ">=6.6.0" },
1244
  { name = "python-dotenv", specifier = ">=1.2.2" },
1245
  { name = "streamlit", specifier = ">=1.44.0" },
@@ -1248,7 +1248,7 @@ requires-dist = [
1248
 
1249
  [[package]]
1250
  name = "persona-vectors"
1251
- version = "0.4.1"
1252
  source = { registry = "https://pypi.org/simple" }
1253
  dependencies = [
1254
  { name = "kaleido" },
@@ -1265,9 +1265,9 @@ dependencies = [
1265
  { name = "transformers" },
1266
  { name = "umap-learn" },
1267
  ]
1268
- sdist = { url = "https://files.pythonhosted.org/packages/0b/55/4aa2f8dd2b4411e87cadf4dad421e0ffd2c68461e1b4087339c8bd17a857/persona_vectors-0.4.1.tar.gz", hash = "sha256:60a7ff64fd90938da7ee767fd88dac419897a26a24f28b29b5f5b34bccef143a", size = 21116, upload-time = "2026-04-29T10:45:51.396Z" }
1269
  wheels = [
1270
- { url = "https://files.pythonhosted.org/packages/02/e6/4c3824d92a366b42a1c70474c7df03f5f81a985924eb6998e93bd0124937/persona_vectors-0.4.1-py3-none-any.whl", hash = "sha256:6b889f5cf790ad45afed24dad1f2a2d48b3fb18bb23959d31200e9ffafe8f1fa", size = 25009, upload-time = "2026-04-29T10:45:52.445Z" },
1271
  ]
1272
 
1273
  [[package]]
 
1225
 
1226
  [[package]]
1227
  name = "persona-ui"
1228
+ version = "0.2.1"
1229
  source = { virtual = "." }
1230
  dependencies = [
1231
  { name = "persona-data" },
 
1239
  [package.metadata]
1240
  requires-dist = [
1241
  { name = "persona-data", specifier = ">=0.2.5" },
1242
+ { name = "persona-vectors", specifier = ">=0.4.2" },
1243
  { name = "plotly", specifier = ">=6.6.0" },
1244
  { name = "python-dotenv", specifier = ">=1.2.2" },
1245
  { name = "streamlit", specifier = ">=1.44.0" },
 
1248
 
1249
  [[package]]
1250
  name = "persona-vectors"
1251
+ version = "0.4.2"
1252
  source = { registry = "https://pypi.org/simple" }
1253
  dependencies = [
1254
  { name = "kaleido" },
 
1265
  { name = "transformers" },
1266
  { name = "umap-learn" },
1267
  ]
1268
+ sdist = { url = "https://files.pythonhosted.org/packages/01/48/e9ea34cd42213868b0e569bc25045f4098ab0fe483421f4eed3f4204f91d/persona_vectors-0.4.2.tar.gz", hash = "sha256:5c421ce6904cf7c92b8fb9ecf27a95d48a381ea307dae3dd67db54c11e9c0892", size = 21509, upload-time = "2026-04-29T14:24:00.304Z" }
1269
  wheels = [
1270
+ { url = "https://files.pythonhosted.org/packages/6b/2f/be21db62788f9880bba713412ee8ad04d9388d749346437c445d83f5bbfb/persona_vectors-0.4.2-py3-none-any.whl", hash = "sha256:c672b91e7b34d02e8ec967e698bcd3b307305cdce02322e9f8e047f9984e264e", size = 25441, upload-time = "2026-04-29T14:23:59.129Z" },
1271
  ]
1272
 
1273
  [[package]]