Jac-Zac commited on
Commit ·
d39b2dd
1
Parent(s): ecd19ae
Imrpoved for performances and better support
Browse files- pyproject.toml +1 -1
- tabs/compare.py +118 -50
- tabs/extract.py +1 -1
- utils/compare_sources.py +5 -5
- utils/helpers.py +14 -0
- uv.lock +4 -4
pyproject.toml
CHANGED
|
@@ -5,7 +5,7 @@ description = "Streamlit UI for persona-vectors"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
-
"persona-vectors>=0.7.
|
| 9 |
"persona-data>=0.4.2",
|
| 10 |
"datasets>=4.8.5",
|
| 11 |
"huggingface-hub>=1.14.0",
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
+
"persona-vectors>=0.7.2",
|
| 9 |
"persona-data>=0.4.2",
|
| 10 |
"datasets>=4.8.5",
|
| 11 |
"huggingface-hub>=1.14.0",
|
tabs/compare.py
CHANGED
|
@@ -11,6 +11,7 @@ from persona_vectors.extraction import MaskStrategy
|
|
| 11 |
from persona_vectors.plots import (
|
| 12 |
build_layered_figure,
|
| 13 |
build_pair_similarity_figure,
|
|
|
|
| 14 |
plot_layer_similarity,
|
| 15 |
plot_persona_dendrogram,
|
| 16 |
save_plot_html,
|
|
@@ -38,6 +39,7 @@ from utils.controls import render_mask_strategy_select
|
|
| 38 |
from utils.helpers import (
|
| 39 |
ANALYSIS_HELP_TEXT,
|
| 40 |
ANALYSIS_MODES,
|
|
|
|
| 41 |
prompt_variant_label,
|
| 42 |
slugify,
|
| 43 |
widget_key,
|
|
@@ -171,9 +173,7 @@ def _seed_persona_memory(
|
|
| 171 |
)
|
| 172 |
remembered_count = int(st.session_state.get(remembered_count_key, default_count))
|
| 173 |
persona_count = min(max(remembered_count, 0), len(options.regular_ids))
|
| 174 |
-
include_assistant = bool(
|
| 175 |
-
st.session_state.get(remembered_assistant_key, options.assistant_id is not None)
|
| 176 |
-
)
|
| 177 |
return persona_count, include_assistant
|
| 178 |
|
| 179 |
|
|
@@ -355,7 +355,7 @@ def _render_cosine_selection(
|
|
| 355 |
variant_a=variant_a,
|
| 356 |
variant_b=variant_b,
|
| 357 |
persona_ids=persona_ids,
|
| 358 |
-
persona_key=
|
| 359 |
)
|
| 360 |
|
| 361 |
|
|
@@ -363,30 +363,31 @@ def _build_cosine_figures(
|
|
| 363 |
store: Store,
|
| 364 |
selection: CosineSelection,
|
| 365 |
) -> tuple[object, object | None, int, int] | None:
|
| 366 |
-
variant_sample_cache = {}
|
| 367 |
|
| 368 |
-
def
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
variant_sample_cache[key] = load_variant_vectors(
|
| 372 |
store,
|
| 373 |
-
[
|
| 374 |
persona_ids=selection.persona_ids,
|
| 375 |
)
|
| 376 |
-
|
|
|
|
| 377 |
|
| 378 |
try:
|
| 379 |
-
|
|
|
|
| 380 |
except Exception as exc:
|
| 381 |
st.error(f"Could not load vectors: {exc}")
|
| 382 |
return None
|
| 383 |
|
| 384 |
-
labels =
|
| 385 |
display_traces = [
|
| 386 |
(
|
| 387 |
label,
|
| 388 |
-
|
| 389 |
-
|
| 390 |
)
|
| 391 |
for index, label in enumerate(labels)
|
| 392 |
]
|
|
@@ -403,12 +404,13 @@ def _build_cosine_figures(
|
|
| 403 |
pair_errors = []
|
| 404 |
for left, right in combinations(selection.variants, 2):
|
| 405 |
try:
|
| 406 |
-
|
|
|
|
| 407 |
pair_traces.append(
|
| 408 |
(
|
| 409 |
f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
|
| 410 |
-
|
| 411 |
-
|
| 412 |
)
|
| 413 |
)
|
| 414 |
except Exception as exc:
|
|
@@ -477,11 +479,18 @@ def _render_cosine_similarity(
|
|
| 477 |
selection.persona_key,
|
| 478 |
),
|
| 479 |
):
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
if cosine_fig_key in st.session_state:
|
| 487 |
fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
|
|
@@ -526,7 +535,7 @@ def _select_single_variant_samples(
|
|
| 526 |
if not persona_ids:
|
| 527 |
return None
|
| 528 |
|
| 529 |
-
persona_key =
|
| 530 |
layer_options = _layers_for_variant(store, variant, persona_ids, mask_strategy)
|
| 531 |
if not layer_options:
|
| 532 |
st.info("No shared layers are available for the selected personas.")
|
|
@@ -590,43 +599,66 @@ def _render_layered_figure_analysis(
|
|
| 590 |
filename = scope
|
| 591 |
|
| 592 |
if st.button(button_label, type="primary"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
try:
|
|
|
|
| 594 |
samples = load_persona_vectors(
|
| 595 |
store,
|
| 596 |
variant,
|
| 597 |
mask_strategy=mask_strategy,
|
| 598 |
persona_ids=persona_ids,
|
| 599 |
)
|
|
|
|
| 600 |
build_kwargs = {}
|
| 601 |
if figure_kind in {"umap", "pca"}:
|
| 602 |
build_kwargs["n_components"] = n_components
|
| 603 |
if n_clusters is not None:
|
| 604 |
build_kwargs["n_clusters"] = n_clusters
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
figure_kind,
|
| 608 |
-
layers=selected_layers,
|
| 609 |
-
title=title_fn(variant),
|
| 610 |
-
**build_kwargs,
|
| 611 |
-
)
|
| 612 |
-
if figure_kind in {"umap", "pca"}:
|
| 613 |
-
main_fig.update_layout(height=700)
|
| 614 |
-
extra_fig = (
|
| 615 |
-
build_pair_similarity_figure(
|
| 616 |
samples,
|
| 617 |
layers=selected_layers,
|
| 618 |
-
title=(
|
|
|
|
| 619 |
"Pair similarity trajectories - "
|
| 620 |
f"{prompt_variant_label(variant)} - persona vectors"
|
| 621 |
),
|
| 622 |
)
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 626 |
st.session_state[fig_key] = (main_fig, extra_fig, samples.vectors.shape[0])
|
|
|
|
| 627 |
except Exception as exc:
|
| 628 |
st.error(f"Could not build figure: {exc}")
|
| 629 |
st.session_state.pop(fig_key, None)
|
|
|
|
|
|
|
| 630 |
|
| 631 |
if fig_key in st.session_state:
|
| 632 |
main_fig, extra_fig, n_samples = st.session_state[fig_key]
|
|
@@ -658,7 +690,11 @@ def _render_dendrogram_analysis(
|
|
| 658 |
with st.expander("Variant selection", expanded=True):
|
| 659 |
col1, col2 = st.columns(2)
|
| 660 |
default_a = "biography" if "biography" in variants else variants[0]
|
| 661 |
-
default_b_idx =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
with col1:
|
| 663 |
variant_a = st.selectbox(
|
| 664 |
"Variant A",
|
|
@@ -704,26 +740,37 @@ def _render_dendrogram_analysis(
|
|
| 704 |
key=widget_key("load", "dendro_linkage", store_id(store)),
|
| 705 |
)
|
| 706 |
|
| 707 |
-
persona_key =
|
| 708 |
fig_key = widget_key(
|
| 709 |
-
"load",
|
|
|
|
| 710 |
store_id(store),
|
| 711 |
store.model_name,
|
| 712 |
mask_strategy.value,
|
| 713 |
-
variant_a,
|
|
|
|
| 714 |
persona_key,
|
| 715 |
-
str(layered_mode),
|
|
|
|
| 716 |
)
|
| 717 |
|
| 718 |
if st.button(
|
| 719 |
"Generate dendrograms",
|
| 720 |
type="primary",
|
| 721 |
-
key=widget_key(
|
|
|
|
|
|
|
| 722 |
):
|
|
|
|
| 723 |
try:
|
|
|
|
| 724 |
samples_a = load_persona_vectors(
|
| 725 |
-
store,
|
|
|
|
|
|
|
|
|
|
| 726 |
)
|
|
|
|
| 727 |
fig_a = plot_persona_dendrogram(
|
| 728 |
samples_a,
|
| 729 |
layered=layered_mode,
|
|
@@ -733,9 +780,14 @@ def _render_dendrogram_analysis(
|
|
| 733 |
fig_a.update_layout(height=750)
|
| 734 |
fig_b = None
|
| 735 |
if variant_a != variant_b:
|
|
|
|
| 736 |
samples_b = load_persona_vectors(
|
| 737 |
-
store,
|
|
|
|
|
|
|
|
|
|
| 738 |
)
|
|
|
|
| 739 |
fig_b = plot_persona_dendrogram(
|
| 740 |
samples_b,
|
| 741 |
layered=layered_mode,
|
|
@@ -743,10 +795,20 @@ def _render_dendrogram_analysis(
|
|
| 743 |
title=f"Dendrogram — {prompt_variant_label(variant_b)}",
|
| 744 |
)
|
| 745 |
fig_b.update_layout(height=750)
|
| 746 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 747 |
except Exception as exc:
|
| 748 |
st.error(f"Could not build dendrogram: {exc}")
|
| 749 |
st.session_state.pop(fig_key, None)
|
|
|
|
|
|
|
| 750 |
|
| 751 |
if fig_key in st.session_state:
|
| 752 |
fig_a, fig_b, n_personas, va, vb = st.session_state[fig_key]
|
|
@@ -764,7 +826,11 @@ def _render_dendrogram_analysis(
|
|
| 764 |
figs = [fig_a] + ([fig_b] if fig_b else [])
|
| 765 |
filenames = [
|
| 766 |
_filename("dendro", store.model_name, mask_strategy.value, va),
|
| 767 |
-
*(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 768 |
]
|
| 769 |
_render_save_buttons(figs, filenames, "dendro")
|
| 770 |
st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
|
|
@@ -917,7 +983,9 @@ def render_compare_tab() -> None:
|
|
| 917 |
"""Render the analysis tab."""
|
| 918 |
|
| 919 |
st.title("Analysis")
|
| 920 |
-
st.caption(
|
|
|
|
|
|
|
| 921 |
|
| 922 |
source = _render_source_select()
|
| 923 |
|
|
|
|
| 11 |
from persona_vectors.plots import (
|
| 12 |
build_layered_figure,
|
| 13 |
build_pair_similarity_figure,
|
| 14 |
+
build_similarity_figures,
|
| 15 |
plot_layer_similarity,
|
| 16 |
plot_persona_dendrogram,
|
| 17 |
save_plot_html,
|
|
|
|
| 39 |
from utils.helpers import (
|
| 40 |
ANALYSIS_HELP_TEXT,
|
| 41 |
ANALYSIS_MODES,
|
| 42 |
+
personas_fingerprint,
|
| 43 |
prompt_variant_label,
|
| 44 |
slugify,
|
| 45 |
widget_key,
|
|
|
|
| 173 |
)
|
| 174 |
remembered_count = int(st.session_state.get(remembered_count_key, default_count))
|
| 175 |
persona_count = min(max(remembered_count, 0), len(options.regular_ids))
|
| 176 |
+
include_assistant = bool(st.session_state.get(remembered_assistant_key, False))
|
|
|
|
|
|
|
| 177 |
return persona_count, include_assistant
|
| 178 |
|
| 179 |
|
|
|
|
| 355 |
variant_a=variant_a,
|
| 356 |
variant_b=variant_b,
|
| 357 |
persona_ids=persona_ids,
|
| 358 |
+
persona_key=personas_fingerprint(persona_ids),
|
| 359 |
)
|
| 360 |
|
| 361 |
|
|
|
|
| 363 |
store: Store,
|
| 364 |
selection: CosineSelection,
|
| 365 |
) -> tuple[object, object | None, int, int] | None:
|
| 366 |
+
variant_sample_cache: dict[str, object] = {}
|
| 367 |
|
| 368 |
+
def _load_variant(variant: str):
|
| 369 |
+
if variant not in variant_sample_cache:
|
| 370 |
+
samples = load_variant_vectors(
|
|
|
|
| 371 |
store,
|
| 372 |
+
[variant],
|
| 373 |
persona_ids=selection.persona_ids,
|
| 374 |
)
|
| 375 |
+
variant_sample_cache[variant] = samples[variant]
|
| 376 |
+
return variant_sample_cache[variant]
|
| 377 |
|
| 378 |
try:
|
| 379 |
+
samples_a = _load_variant(selection.variant_a)
|
| 380 |
+
samples_b = _load_variant(selection.variant_b)
|
| 381 |
except Exception as exc:
|
| 382 |
st.error(f"Could not load vectors: {exc}")
|
| 383 |
return None
|
| 384 |
|
| 385 |
+
labels = samples_a.labels
|
| 386 |
display_traces = [
|
| 387 |
(
|
| 388 |
label,
|
| 389 |
+
samples_a.vectors[index],
|
| 390 |
+
samples_b.vectors[index],
|
| 391 |
)
|
| 392 |
for index, label in enumerate(labels)
|
| 393 |
]
|
|
|
|
| 404 |
pair_errors = []
|
| 405 |
for left, right in combinations(selection.variants, 2):
|
| 406 |
try:
|
| 407 |
+
left_samples = _load_variant(left)
|
| 408 |
+
right_samples = _load_variant(right)
|
| 409 |
pair_traces.append(
|
| 410 |
(
|
| 411 |
f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
|
| 412 |
+
left_samples.vectors.mean(dim=0),
|
| 413 |
+
right_samples.vectors.mean(dim=0),
|
| 414 |
)
|
| 415 |
)
|
| 416 |
except Exception as exc:
|
|
|
|
| 479 |
selection.persona_key,
|
| 480 |
),
|
| 481 |
):
|
| 482 |
+
progress = st.progress(0, text="Loading activation vectors…")
|
| 483 |
+
try:
|
| 484 |
+
progress.progress(15, text="Loading activation vectors…")
|
| 485 |
+
figures = _build_cosine_figures(store, selection)
|
| 486 |
+
if figures is None:
|
| 487 |
+
st.session_state.pop(cosine_fig_key, None)
|
| 488 |
+
return
|
| 489 |
+
progress.progress(90, text="Storing figure state…")
|
| 490 |
+
st.session_state[cosine_fig_key] = figures
|
| 491 |
+
progress.progress(100, text="Done.")
|
| 492 |
+
finally:
|
| 493 |
+
progress.empty()
|
| 494 |
|
| 495 |
if cosine_fig_key in st.session_state:
|
| 496 |
fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
|
|
|
|
| 535 |
if not persona_ids:
|
| 536 |
return None
|
| 537 |
|
| 538 |
+
persona_key = personas_fingerprint(persona_ids)
|
| 539 |
layer_options = _layers_for_variant(store, variant, persona_ids, mask_strategy)
|
| 540 |
if not layer_options:
|
| 541 |
st.info("No shared layers are available for the selected personas.")
|
|
|
|
| 599 |
filename = scope
|
| 600 |
|
| 601 |
if st.button(button_label, type="primary"):
|
| 602 |
+
build_label = {
|
| 603 |
+
"umap": "Computing UMAP projections…",
|
| 604 |
+
"pca": "Computing PCA projections…",
|
| 605 |
+
"similarity": "Computing similarity matrices…",
|
| 606 |
+
}.get(figure_kind, "Building figure…")
|
| 607 |
+
progress = st.progress(0, text="Loading activation vectors…")
|
| 608 |
try:
|
| 609 |
+
progress.progress(15, text="Loading activation vectors…")
|
| 610 |
samples = load_persona_vectors(
|
| 611 |
store,
|
| 612 |
variant,
|
| 613 |
mask_strategy=mask_strategy,
|
| 614 |
persona_ids=persona_ids,
|
| 615 |
)
|
| 616 |
+
progress.progress(55, text=build_label)
|
| 617 |
build_kwargs = {}
|
| 618 |
if figure_kind in {"umap", "pca"}:
|
| 619 |
build_kwargs["n_components"] = n_components
|
| 620 |
if n_clusters is not None:
|
| 621 |
build_kwargs["n_clusters"] = n_clusters
|
| 622 |
+
if figure_kind == "similarity" and include_pair_trajectories:
|
| 623 |
+
main_fig, extra_fig = build_similarity_figures(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
samples,
|
| 625 |
layers=selected_layers,
|
| 626 |
+
title=title_fn(variant),
|
| 627 |
+
pair_title=(
|
| 628 |
"Pair similarity trajectories - "
|
| 629 |
f"{prompt_variant_label(variant)} - persona vectors"
|
| 630 |
),
|
| 631 |
)
|
| 632 |
+
else:
|
| 633 |
+
main_fig = build_layered_figure(
|
| 634 |
+
samples,
|
| 635 |
+
figure_kind,
|
| 636 |
+
layers=selected_layers,
|
| 637 |
+
title=title_fn(variant),
|
| 638 |
+
**build_kwargs,
|
| 639 |
+
)
|
| 640 |
+
if figure_kind in {"umap", "pca"}:
|
| 641 |
+
main_fig.update_layout(height=700)
|
| 642 |
+
extra_fig = (
|
| 643 |
+
build_pair_similarity_figure(
|
| 644 |
+
samples,
|
| 645 |
+
layers=selected_layers,
|
| 646 |
+
title=(
|
| 647 |
+
"Pair similarity trajectories - "
|
| 648 |
+
f"{prompt_variant_label(variant)} - persona vectors"
|
| 649 |
+
),
|
| 650 |
+
)
|
| 651 |
+
if include_pair_trajectories
|
| 652 |
+
else None
|
| 653 |
+
)
|
| 654 |
+
progress.progress(90, text="Storing figure state…")
|
| 655 |
st.session_state[fig_key] = (main_fig, extra_fig, samples.vectors.shape[0])
|
| 656 |
+
progress.progress(100, text="Done.")
|
| 657 |
except Exception as exc:
|
| 658 |
st.error(f"Could not build figure: {exc}")
|
| 659 |
st.session_state.pop(fig_key, None)
|
| 660 |
+
finally:
|
| 661 |
+
progress.empty()
|
| 662 |
|
| 663 |
if fig_key in st.session_state:
|
| 664 |
main_fig, extra_fig, n_samples = st.session_state[fig_key]
|
|
|
|
| 690 |
with st.expander("Variant selection", expanded=True):
|
| 691 |
col1, col2 = st.columns(2)
|
| 692 |
default_a = "biography" if "biography" in variants else variants[0]
|
| 693 |
+
default_b_idx = (
|
| 694 |
+
variants.index("templated")
|
| 695 |
+
if "templated" in variants
|
| 696 |
+
else min(1, len(variants) - 1)
|
| 697 |
+
)
|
| 698 |
with col1:
|
| 699 |
variant_a = st.selectbox(
|
| 700 |
"Variant A",
|
|
|
|
| 740 |
key=widget_key("load", "dendro_linkage", store_id(store)),
|
| 741 |
)
|
| 742 |
|
| 743 |
+
persona_key = personas_fingerprint(persona_ids)
|
| 744 |
fig_key = widget_key(
|
| 745 |
+
"load",
|
| 746 |
+
"dendro_fig_state",
|
| 747 |
store_id(store),
|
| 748 |
store.model_name,
|
| 749 |
mask_strategy.value,
|
| 750 |
+
variant_a,
|
| 751 |
+
variant_b,
|
| 752 |
persona_key,
|
| 753 |
+
str(layered_mode),
|
| 754 |
+
linkage,
|
| 755 |
)
|
| 756 |
|
| 757 |
if st.button(
|
| 758 |
"Generate dendrograms",
|
| 759 |
type="primary",
|
| 760 |
+
key=widget_key(
|
| 761 |
+
"load", "dendro_btn", store_id(store), variant_a, variant_b, persona_key
|
| 762 |
+
),
|
| 763 |
):
|
| 764 |
+
progress = st.progress(0, text="Loading first variant vectors…")
|
| 765 |
try:
|
| 766 |
+
progress.progress(15, text="Loading first variant vectors…")
|
| 767 |
samples_a = load_persona_vectors(
|
| 768 |
+
store,
|
| 769 |
+
variant_a,
|
| 770 |
+
mask_strategy=mask_strategy,
|
| 771 |
+
persona_ids=persona_ids,
|
| 772 |
)
|
| 773 |
+
progress.progress(40, text="Building first dendrogram…")
|
| 774 |
fig_a = plot_persona_dendrogram(
|
| 775 |
samples_a,
|
| 776 |
layered=layered_mode,
|
|
|
|
| 780 |
fig_a.update_layout(height=750)
|
| 781 |
fig_b = None
|
| 782 |
if variant_a != variant_b:
|
| 783 |
+
progress.progress(60, text="Loading second variant vectors…")
|
| 784 |
samples_b = load_persona_vectors(
|
| 785 |
+
store,
|
| 786 |
+
variant_b,
|
| 787 |
+
mask_strategy=mask_strategy,
|
| 788 |
+
persona_ids=persona_ids,
|
| 789 |
)
|
| 790 |
+
progress.progress(75, text="Building second dendrogram…")
|
| 791 |
fig_b = plot_persona_dendrogram(
|
| 792 |
samples_b,
|
| 793 |
layered=layered_mode,
|
|
|
|
| 795 |
title=f"Dendrogram — {prompt_variant_label(variant_b)}",
|
| 796 |
)
|
| 797 |
fig_b.update_layout(height=750)
|
| 798 |
+
progress.progress(90, text="Storing figure state…")
|
| 799 |
+
st.session_state[fig_key] = (
|
| 800 |
+
fig_a,
|
| 801 |
+
fig_b,
|
| 802 |
+
len(persona_ids),
|
| 803 |
+
variant_a,
|
| 804 |
+
variant_b,
|
| 805 |
+
)
|
| 806 |
+
progress.progress(100, text="Done.")
|
| 807 |
except Exception as exc:
|
| 808 |
st.error(f"Could not build dendrogram: {exc}")
|
| 809 |
st.session_state.pop(fig_key, None)
|
| 810 |
+
finally:
|
| 811 |
+
progress.empty()
|
| 812 |
|
| 813 |
if fig_key in st.session_state:
|
| 814 |
fig_a, fig_b, n_personas, va, vb = st.session_state[fig_key]
|
|
|
|
| 826 |
figs = [fig_a] + ([fig_b] if fig_b else [])
|
| 827 |
filenames = [
|
| 828 |
_filename("dendro", store.model_name, mask_strategy.value, va),
|
| 829 |
+
*(
|
| 830 |
+
[_filename("dendro", store.model_name, mask_strategy.value, vb)]
|
| 831 |
+
if fig_b
|
| 832 |
+
else []
|
| 833 |
+
),
|
| 834 |
]
|
| 835 |
_render_save_buttons(figs, filenames, "dendro")
|
| 836 |
st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
|
|
|
|
| 983 |
"""Render the analysis tab."""
|
| 984 |
|
| 985 |
st.title("Analysis")
|
| 986 |
+
st.caption(
|
| 987 |
+
"Analyse persona vectors by cosine similarity, PCA, UMAP, or hierarchical clustering."
|
| 988 |
+
)
|
| 989 |
|
| 990 |
source = _render_source_select()
|
| 991 |
|
tabs/extract.py
CHANGED
|
@@ -92,7 +92,7 @@ def _render_variant_controls(
|
|
| 92 |
)
|
| 93 |
include_baseline = st.checkbox(
|
| 94 |
"Extract Assistant baseline",
|
| 95 |
-
value=st.session_state.get(_LAST_BASELINE_KEY,
|
| 96 |
key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
|
| 97 |
help="Also extract the Assistant baseline persona using the first persona's QA set.",
|
| 98 |
)
|
|
|
|
| 92 |
)
|
| 93 |
include_baseline = st.checkbox(
|
| 94 |
"Extract Assistant baseline",
|
| 95 |
+
value=st.session_state.get(_LAST_BASELINE_KEY, False),
|
| 96 |
key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
|
| 97 |
help="Also extract the Assistant baseline persona using the first persona's QA set.",
|
| 98 |
)
|
utils/compare_sources.py
CHANGED
|
@@ -22,7 +22,7 @@ SOURCE_LOCAL = "Local activations"
|
|
| 22 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 23 |
|
| 24 |
|
| 25 |
-
@st.cache_resource(show_spinner=False)
|
| 26 |
def activation_store_cached(
|
| 27 |
source: str,
|
| 28 |
location: str,
|
|
@@ -35,7 +35,7 @@ def activation_store_cached(
|
|
| 35 |
return ActivationStore(model_name, location, mask_strategy=mask_strategy)
|
| 36 |
|
| 37 |
|
| 38 |
-
@st.cache_data(show_spinner=False
|
| 39 |
def available_variants_cached(
|
| 40 |
source: str,
|
| 41 |
location: str,
|
|
@@ -46,7 +46,7 @@ def available_variants_cached(
|
|
| 46 |
return store.available_variants()
|
| 47 |
|
| 48 |
|
| 49 |
-
@st.cache_data(show_spinner=False
|
| 50 |
def personas_cached(
|
| 51 |
source: str,
|
| 52 |
location: str,
|
|
@@ -61,7 +61,7 @@ def personas_cached(
|
|
| 61 |
)
|
| 62 |
|
| 63 |
|
| 64 |
-
@st.cache_data(show_spinner=False
|
| 65 |
def persona_names_cached(
|
| 66 |
source: str,
|
| 67 |
location: str,
|
|
@@ -78,7 +78,7 @@ def persona_names_cached(
|
|
| 78 |
)
|
| 79 |
|
| 80 |
|
| 81 |
-
@st.cache_data(show_spinner=False
|
| 82 |
def local_model_options_cached(
|
| 83 |
artifacts_root: str, mask_strategy_value: str
|
| 84 |
) -> list[str]:
|
|
|
|
| 22 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 23 |
|
| 24 |
|
| 25 |
+
@st.cache_resource(show_spinner=False, max_entries=1)
|
| 26 |
def activation_store_cached(
|
| 27 |
source: str,
|
| 28 |
location: str,
|
|
|
|
| 35 |
return ActivationStore(model_name, location, mask_strategy=mask_strategy)
|
| 36 |
|
| 37 |
|
| 38 |
+
@st.cache_data(show_spinner=False)
|
| 39 |
def available_variants_cached(
|
| 40 |
source: str,
|
| 41 |
location: str,
|
|
|
|
| 46 |
return store.available_variants()
|
| 47 |
|
| 48 |
|
| 49 |
+
@st.cache_data(show_spinner=False)
|
| 50 |
def personas_cached(
|
| 51 |
source: str,
|
| 52 |
location: str,
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
|
| 64 |
+
@st.cache_data(show_spinner=False)
|
| 65 |
def persona_names_cached(
|
| 66 |
source: str,
|
| 67 |
location: str,
|
|
|
|
| 78 |
)
|
| 79 |
|
| 80 |
|
| 81 |
+
@st.cache_data(show_spinner=False)
|
| 82 |
def local_model_options_cached(
|
| 83 |
artifacts_root: str, mask_strategy_value: str
|
| 84 |
) -> list[str]:
|
utils/helpers.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
|
|
| 1 |
import re
|
|
|
|
| 2 |
|
| 3 |
from persona_data.synth_persona import PersonaData
|
| 4 |
|
|
@@ -54,6 +56,18 @@ def widget_key(*parts: str) -> str:
|
|
| 54 |
return "::".join(parts)
|
| 55 |
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def prompt_variant_label(variant: str) -> str:
|
| 58 |
"""Return a human-friendly prompt-variant label."""
|
| 59 |
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
import re
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
|
| 5 |
from persona_data.synth_persona import PersonaData
|
| 6 |
|
|
|
|
| 56 |
return "::".join(parts)
|
| 57 |
|
| 58 |
|
| 59 |
+
def personas_fingerprint(persona_ids: Iterable[str]) -> str:
|
| 60 |
+
"""Stable short fingerprint for a set of persona ids.
|
| 61 |
+
|
| 62 |
+
Used as a discriminator in widget keys and session-state keys. At ~1k
|
| 63 |
+
personas, joining ids would produce ~20 KB strings; the sha1 prefix is
|
| 64 |
+
fixed-length and keeps tracebacks readable.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
joined = "|".join(sorted(persona_ids))
|
| 68 |
+
return hashlib.sha1(joined.encode()).hexdigest()[:16]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
def prompt_variant_label(variant: str) -> str:
|
| 72 |
"""Return a human-friendly prompt-variant label."""
|
| 73 |
|
uv.lock
CHANGED
|
@@ -1594,7 +1594,7 @@ requires-dist = [
|
|
| 1594 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1595 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1596 |
{ name = "persona-data", specifier = ">=0.4.2" },
|
| 1597 |
-
{ name = "persona-vectors", specifier = ">=0.7.
|
| 1598 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1599 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1600 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
@@ -1602,7 +1602,7 @@ requires-dist = [
|
|
| 1602 |
|
| 1603 |
[[package]]
|
| 1604 |
name = "persona-vectors"
|
| 1605 |
-
version = "0.7.
|
| 1606 |
source = { registry = "https://pypi.org/simple" }
|
| 1607 |
dependencies = [
|
| 1608 |
{ name = "datasets" },
|
|
@@ -1621,9 +1621,9 @@ dependencies = [
|
|
| 1621 |
{ name = "transformers" },
|
| 1622 |
{ name = "umap-learn" },
|
| 1623 |
]
|
| 1624 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1625 |
wheels = [
|
| 1626 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1627 |
]
|
| 1628 |
|
| 1629 |
[[package]]
|
|
|
|
| 1594 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1595 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1596 |
{ name = "persona-data", specifier = ">=0.4.2" },
|
| 1597 |
+
{ name = "persona-vectors", specifier = ">=0.7.2" },
|
| 1598 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1599 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1600 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
|
|
| 1602 |
|
| 1603 |
[[package]]
|
| 1604 |
name = "persona-vectors"
|
| 1605 |
+
version = "0.7.2"
|
| 1606 |
source = { registry = "https://pypi.org/simple" }
|
| 1607 |
dependencies = [
|
| 1608 |
{ name = "datasets" },
|
|
|
|
| 1621 |
{ name = "transformers" },
|
| 1622 |
{ name = "umap-learn" },
|
| 1623 |
]
|
| 1624 |
+
sdist = { url = "https://files.pythonhosted.org/packages/03/ee/3550a015546ad9e7842e4d94346e2268c1aa65a7278f39ab574dd2b702bd/persona_vectors-0.7.2.tar.gz", hash = "sha256:22f3e4349bd020fcc287ed5e00ff9d0b752dfe95b3f145b8778efc485cf7f67f", size = 27777, upload-time = "2026-05-10T10:29:07.67Z" }
|
| 1625 |
wheels = [
|
| 1626 |
+
{ url = "https://files.pythonhosted.org/packages/c7/c2/334c66944f1669c7cc19e4e3c2f49ab65b0a186e6cd6cb709afd253b4268/persona_vectors-0.7.2-py3-none-any.whl", hash = "sha256:cbd5b6de2afb17fccef9414eba1b0096ffa3407a1496b253bb2d30ed548d123e", size = 32387, upload-time = "2026-05-10T10:29:06.808Z" },
|
| 1627 |
]
|
| 1628 |
|
| 1629 |
[[package]]
|