Jac-Zac committed on
Commit ·
ecd19ae
1
Parent(s): f40862f
Adding plotting options
Browse files
- app.py +3 -3
- tabs/compare.py +156 -3
- utils/helpers.py +2 -1
app.py
CHANGED
|
@@ -13,7 +13,7 @@ _LAST_LOCAL_MODEL_KEY = "sidebar:last_local_model"
|
|
| 13 |
_LAST_REMOTE_MODEL_KEY = "sidebar:last_remote_model"
|
| 14 |
|
| 15 |
|
| 16 |
-
_TABS = ["Chat", "
|
| 17 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 18 |
|
| 19 |
|
|
@@ -105,7 +105,7 @@ def _sidebar_controls() -> SidebarState:
|
|
| 105 |
st.session_state["sidebar__active_tab"] = tab_name
|
| 106 |
st.rerun()
|
| 107 |
|
| 108 |
-
if active_tab == "
|
| 109 |
model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
|
| 110 |
dataset_source = st.session_state.get(
|
| 111 |
"sidebar__dataset_source",
|
|
@@ -169,7 +169,7 @@ def main() -> None:
|
|
| 169 |
from tabs.extract import render_extract_tab
|
| 170 |
|
| 171 |
render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 172 |
-
elif sidebar.active_tab == "
|
| 173 |
from tabs.compare import render_compare_tab
|
| 174 |
|
| 175 |
render_compare_tab()
|
|
|
|
| 13 |
_LAST_REMOTE_MODEL_KEY = "sidebar:last_remote_model"
|
| 14 |
|
| 15 |
|
| 16 |
+
_TABS = ["Chat", "Analysis", "Extract"]
|
| 17 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 18 |
|
| 19 |
|
|
|
|
| 105 |
st.session_state["sidebar__active_tab"] = tab_name
|
| 106 |
st.rerun()
|
| 107 |
|
| 108 |
+
if active_tab == "Analysis":
|
| 109 |
model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
|
| 110 |
dataset_source = st.session_state.get(
|
| 111 |
"sidebar__dataset_source",
|
|
|
|
| 169 |
from tabs.extract import render_extract_tab
|
| 170 |
|
| 171 |
render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 172 |
+
elif sidebar.active_tab == "Analysis":
|
| 173 |
from tabs.compare import render_compare_tab
|
| 174 |
|
| 175 |
render_compare_tab()
|
tabs/compare.py
CHANGED
|
@@ -12,6 +12,7 @@ from persona_vectors.plots import (
|
|
| 12 |
build_layered_figure,
|
| 13 |
build_pair_similarity_figure,
|
| 14 |
plot_layer_similarity,
|
|
|
|
| 15 |
save_plot_html,
|
| 16 |
)
|
| 17 |
|
|
@@ -556,6 +557,23 @@ def _render_layered_figure_analysis(
|
|
| 556 |
return
|
| 557 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 558 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 559 |
fig_key = widget_key(
|
| 560 |
"load",
|
| 561 |
f"{scope}_fig_state",
|
|
@@ -564,6 +582,7 @@ def _render_layered_figure_analysis(
|
|
| 564 |
mask_strategy.value,
|
| 565 |
figure_kind,
|
| 566 |
str(n_components),
|
|
|
|
| 567 |
variant,
|
| 568 |
"persona_vector",
|
| 569 |
persona_key,
|
|
@@ -581,6 +600,8 @@ def _render_layered_figure_analysis(
|
|
| 581 |
build_kwargs = {}
|
| 582 |
if figure_kind in {"umap", "pca"}:
|
| 583 |
build_kwargs["n_components"] = n_components
|
|
|
|
|
|
|
| 584 |
main_fig = build_layered_figure(
|
| 585 |
samples,
|
| 586 |
figure_kind,
|
|
@@ -621,6 +642,134 @@ def _render_layered_figure_analysis(
|
|
| 621 |
st.success(f"Loaded {n_samples} samples.")
|
| 622 |
|
| 623 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
def _render_source_select() -> str:
|
| 625 |
last_source = st.session_state.get(_LAST_SOURCE_KEY, SOURCE_HUB)
|
| 626 |
source = st.segmented_control(
|
|
@@ -765,10 +914,10 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
|
|
| 765 |
|
| 766 |
|
| 767 |
def render_compare_tab() -> None:
|
| 768 |
-
"""Render the
|
| 769 |
|
| 770 |
-
st.title("
|
| 771 |
-
st.caption("
|
| 772 |
|
| 773 |
source = _render_source_select()
|
| 774 |
|
|
@@ -804,6 +953,10 @@ def render_compare_tab() -> None:
|
|
| 804 |
)
|
| 805 |
return
|
| 806 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
dimension_choice = st.segmented_control(
|
| 808 |
"Projection dimensions",
|
| 809 |
options=["2D", "3D"],
|
|
|
|
| 12 |
build_layered_figure,
|
| 13 |
build_pair_similarity_figure,
|
| 14 |
plot_layer_similarity,
|
| 15 |
+
plot_persona_dendrogram,
|
| 16 |
save_plot_html,
|
| 17 |
)
|
| 18 |
|
|
|
|
| 557 |
return
|
| 558 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 559 |
|
| 560 |
+
n_clusters = None
|
| 561 |
+
if figure_kind in {"pca", "umap"}:
|
| 562 |
+
use_kmeans = st.toggle(
|
| 563 |
+
"Color by K-means clusters",
|
| 564 |
+
value=False,
|
| 565 |
+
key=widget_key("load", "kmeans_enabled", scope, store_id(store)),
|
| 566 |
+
help="Run K-means on persona vectors and color each persona by cluster.",
|
| 567 |
+
)
|
| 568 |
+
if use_kmeans:
|
| 569 |
+
n_clusters = st.slider(
|
| 570 |
+
"K (clusters)",
|
| 571 |
+
min_value=2,
|
| 572 |
+
max_value=min(10, len(persona_ids)),
|
| 573 |
+
value=min(3, len(persona_ids)),
|
| 574 |
+
key=widget_key("load", "kmeans_k", scope, store_id(store)),
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
fig_key = widget_key(
|
| 578 |
"load",
|
| 579 |
f"{scope}_fig_state",
|
|
|
|
| 582 |
mask_strategy.value,
|
| 583 |
figure_kind,
|
| 584 |
str(n_components),
|
| 585 |
+
str(n_clusters),
|
| 586 |
variant,
|
| 587 |
"persona_vector",
|
| 588 |
persona_key,
|
|
|
|
| 600 |
build_kwargs = {}
|
| 601 |
if figure_kind in {"umap", "pca"}:
|
| 602 |
build_kwargs["n_components"] = n_components
|
| 603 |
+
if n_clusters is not None:
|
| 604 |
+
build_kwargs["n_clusters"] = n_clusters
|
| 605 |
main_fig = build_layered_figure(
|
| 606 |
samples,
|
| 607 |
figure_kind,
|
|
|
|
| 642 |
st.success(f"Loaded {n_samples} samples.")
|
| 643 |
|
| 644 |
|
| 645 |
+
_LAST_DENDRO_PERSONAS_KEY = "compare:last_personas:dendro"
|
| 646 |
+
_DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def _render_dendrogram_analysis(
|
| 650 |
+
store: Store,
|
| 651 |
+
mask_strategy: MaskStrategy,
|
| 652 |
+
) -> None:
|
| 653 |
+
variants = available_variants(store, mask_strategy)
|
| 654 |
+
if not variants:
|
| 655 |
+
st.info("No variants with saved vectors for this model.")
|
| 656 |
+
return
|
| 657 |
+
|
| 658 |
+
with st.expander("Variant selection", expanded=True):
|
| 659 |
+
col1, col2 = st.columns(2)
|
| 660 |
+
default_a = "biography" if "biography" in variants else variants[0]
|
| 661 |
+
default_b_idx = variants.index("templated") if "templated" in variants else min(1, len(variants) - 1)
|
| 662 |
+
with col1:
|
| 663 |
+
variant_a = st.selectbox(
|
| 664 |
+
"Variant A",
|
| 665 |
+
options=variants,
|
| 666 |
+
index=variants.index(default_a),
|
| 667 |
+
format_func=prompt_variant_label,
|
| 668 |
+
key=widget_key("load", "dendro_variant_a", store_id(store)),
|
| 669 |
+
)
|
| 670 |
+
with col2:
|
| 671 |
+
variant_b = st.selectbox(
|
| 672 |
+
"Variant B",
|
| 673 |
+
options=variants,
|
| 674 |
+
index=default_b_idx,
|
| 675 |
+
format_func=prompt_variant_label,
|
| 676 |
+
key=widget_key("load", "dendro_variant_b", store_id(store)),
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
shared_variants = list(dict.fromkeys([variant_a, variant_b]))
|
| 680 |
+
persona_ids = _select_artifact_personas(
|
| 681 |
+
store,
|
| 682 |
+
shared_variants,
|
| 683 |
+
mask_strategy,
|
| 684 |
+
widget_scope=f"dendro:{store_id(store)}",
|
| 685 |
+
remember_key=_LAST_DENDRO_PERSONAS_KEY,
|
| 686 |
+
default_all=True,
|
| 687 |
+
)
|
| 688 |
+
if not persona_ids:
|
| 689 |
+
return
|
| 690 |
+
|
| 691 |
+
col_opts1, col_opts2 = st.columns(2)
|
| 692 |
+
with col_opts1:
|
| 693 |
+
layered_mode = st.toggle(
|
| 694 |
+
"Per-layer animated",
|
| 695 |
+
value=False,
|
| 696 |
+
key=widget_key("load", "dendro_layered", store_id(store)),
|
| 697 |
+
help="Animated dendrogram with one frame per layer instead of averaging all layers.",
|
| 698 |
+
)
|
| 699 |
+
with col_opts2:
|
| 700 |
+
linkage = st.selectbox(
|
| 701 |
+
"Linkage",
|
| 702 |
+
options=_DENDRO_LINKAGE_OPTIONS,
|
| 703 |
+
index=0,
|
| 704 |
+
key=widget_key("load", "dendro_linkage", store_id(store)),
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
persona_key = "_".join(sorted(persona_ids))
|
| 708 |
+
fig_key = widget_key(
|
| 709 |
+
"load", "dendro_fig_state",
|
| 710 |
+
store_id(store),
|
| 711 |
+
store.model_name,
|
| 712 |
+
mask_strategy.value,
|
| 713 |
+
variant_a, variant_b,
|
| 714 |
+
persona_key,
|
| 715 |
+
str(layered_mode), linkage,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
if st.button(
|
| 719 |
+
"Generate dendrograms",
|
| 720 |
+
type="primary",
|
| 721 |
+
key=widget_key("load", "dendro_btn", store_id(store), variant_a, variant_b, persona_key),
|
| 722 |
+
):
|
| 723 |
+
try:
|
| 724 |
+
samples_a = load_persona_vectors(
|
| 725 |
+
store, variant_a, mask_strategy=mask_strategy, persona_ids=persona_ids,
|
| 726 |
+
)
|
| 727 |
+
fig_a = plot_persona_dendrogram(
|
| 728 |
+
samples_a,
|
| 729 |
+
layered=layered_mode,
|
| 730 |
+
linkage=linkage,
|
| 731 |
+
title=f"Dendrogram — {prompt_variant_label(variant_a)}",
|
| 732 |
+
)
|
| 733 |
+
fig_a.update_layout(height=750)
|
| 734 |
+
fig_b = None
|
| 735 |
+
if variant_a != variant_b:
|
| 736 |
+
samples_b = load_persona_vectors(
|
| 737 |
+
store, variant_b, mask_strategy=mask_strategy, persona_ids=persona_ids,
|
| 738 |
+
)
|
| 739 |
+
fig_b = plot_persona_dendrogram(
|
| 740 |
+
samples_b,
|
| 741 |
+
layered=layered_mode,
|
| 742 |
+
linkage=linkage,
|
| 743 |
+
title=f"Dendrogram — {prompt_variant_label(variant_b)}",
|
| 744 |
+
)
|
| 745 |
+
fig_b.update_layout(height=750)
|
| 746 |
+
st.session_state[fig_key] = (fig_a, fig_b, len(persona_ids), variant_a, variant_b)
|
| 747 |
+
except Exception as exc:
|
| 748 |
+
st.error(f"Could not build dendrogram: {exc}")
|
| 749 |
+
st.session_state.pop(fig_key, None)
|
| 750 |
+
|
| 751 |
+
if fig_key in st.session_state:
|
| 752 |
+
fig_a, fig_b, n_personas, va, vb = st.session_state[fig_key]
|
| 753 |
+
if fig_b is not None:
|
| 754 |
+
col_a, col_b = st.columns(2)
|
| 755 |
+
with col_a:
|
| 756 |
+
st.subheader(prompt_variant_label(va))
|
| 757 |
+
st.plotly_chart(fig_a, width="stretch")
|
| 758 |
+
with col_b:
|
| 759 |
+
st.subheader(prompt_variant_label(vb))
|
| 760 |
+
st.plotly_chart(fig_b, width="stretch")
|
| 761 |
+
else:
|
| 762 |
+
st.plotly_chart(fig_a, width="stretch")
|
| 763 |
+
|
| 764 |
+
figs = [fig_a] + ([fig_b] if fig_b else [])
|
| 765 |
+
filenames = [
|
| 766 |
+
_filename("dendro", store.model_name, mask_strategy.value, va),
|
| 767 |
+
*([_filename("dendro", store.model_name, mask_strategy.value, vb)] if fig_b else []),
|
| 768 |
+
]
|
| 769 |
+
_render_save_buttons(figs, filenames, "dendro")
|
| 770 |
+
st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
|
| 771 |
+
|
| 772 |
+
|
| 773 |
def _render_source_select() -> str:
|
| 774 |
last_source = st.session_state.get(_LAST_SOURCE_KEY, SOURCE_HUB)
|
| 775 |
source = st.segmented_control(
|
|
|
|
| 914 |
|
| 915 |
|
| 916 |
def render_compare_tab() -> None:
|
| 917 |
+
"""Render the analysis tab."""
|
| 918 |
|
| 919 |
+
st.title("Analysis")
|
| 920 |
+
st.caption("Analyse persona vectors by cosine similarity, PCA, UMAP, or hierarchical clustering.")
|
| 921 |
|
| 922 |
source = _render_source_select()
|
| 923 |
|
|
|
|
| 953 |
)
|
| 954 |
return
|
| 955 |
|
| 956 |
+
if analysis_mode == "Dendrogram":
|
| 957 |
+
_render_dendrogram_analysis(store, mask_strategy)
|
| 958 |
+
return
|
| 959 |
+
|
| 960 |
dimension_choice = st.segmented_control(
|
| 961 |
"Projection dimensions",
|
| 962 |
options=["2D", "3D"],
|
utils/helpers.py
CHANGED
|
@@ -22,13 +22,14 @@ DATASET_SOURCES = [
|
|
| 22 |
"HuggingFace: nemotron-usa",
|
| 23 |
"Local JSONL upload",
|
| 24 |
]
|
| 25 |
-
ANALYSIS_MODES = ["Cosine similarity", "Similarity matrix", "PCA", "UMAP"]
|
| 26 |
|
| 27 |
ANALYSIS_HELP_TEXT = {
|
| 28 |
"Cosine similarity": "Compare layer-wise alignment between variants.",
|
| 29 |
"Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
|
| 30 |
"PCA": "Project per-persona vectors into a 2D or 3D global view.",
|
| 31 |
"UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
|
|
|
|
| 32 |
}
|
| 33 |
|
| 34 |
NDIF_STATUS_ICONS = {
|
|
|
|
| 22 |
"HuggingFace: nemotron-usa",
|
| 23 |
"Local JSONL upload",
|
| 24 |
]
|
| 25 |
+
ANALYSIS_MODES = ["Cosine similarity", "Similarity matrix", "PCA", "UMAP", "Dendrogram"]
|
| 26 |
|
| 27 |
ANALYSIS_HELP_TEXT = {
|
| 28 |
"Cosine similarity": "Compare layer-wise alignment between variants.",
|
| 29 |
"Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
|
| 30 |
"PCA": "Project per-persona vectors into a 2D or 3D global view.",
|
| 31 |
"UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
|
| 32 |
+
"Dendrogram": "Hierarchical clustering of persona vectors — shows biography and templated side by side for direct comparison.",
|
| 33 |
}
|
| 34 |
|
| 35 |
NDIF_STATUS_ICONS = {
|