Jac-Zac committed on
Commit
ecd19ae
·
1 Parent(s): f40862f

Adding plotting options

Browse files
Files changed (3) hide show
  1. app.py +3 -3
  2. tabs/compare.py +156 -3
  3. utils/helpers.py +2 -1
app.py CHANGED
@@ -13,7 +13,7 @@ _LAST_LOCAL_MODEL_KEY = "sidebar:last_local_model"
13
  _LAST_REMOTE_MODEL_KEY = "sidebar:last_remote_model"
14
 
15
 
16
- _TABS = ["Chat", "Compare", "Extract"]
17
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
18
 
19
 
@@ -105,7 +105,7 @@ def _sidebar_controls() -> SidebarState:
105
  st.session_state["sidebar__active_tab"] = tab_name
106
  st.rerun()
107
 
108
- if active_tab == "Compare":
109
  model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
110
  dataset_source = st.session_state.get(
111
  "sidebar__dataset_source",
@@ -169,7 +169,7 @@ def main() -> None:
169
  from tabs.extract import render_extract_tab
170
 
171
  render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
172
- elif sidebar.active_tab == "Compare":
173
  from tabs.compare import render_compare_tab
174
 
175
  render_compare_tab()
 
13
  _LAST_REMOTE_MODEL_KEY = "sidebar:last_remote_model"
14
 
15
 
16
+ _TABS = ["Chat", "Analysis", "Extract"]
17
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
18
 
19
 
 
105
  st.session_state["sidebar__active_tab"] = tab_name
106
  st.rerun()
107
 
108
+ if active_tab == "Analysis":
109
  model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
110
  dataset_source = st.session_state.get(
111
  "sidebar__dataset_source",
 
169
  from tabs.extract import render_extract_tab
170
 
171
  render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
172
+ elif sidebar.active_tab == "Analysis":
173
  from tabs.compare import render_compare_tab
174
 
175
  render_compare_tab()
tabs/compare.py CHANGED
@@ -12,6 +12,7 @@ from persona_vectors.plots import (
12
  build_layered_figure,
13
  build_pair_similarity_figure,
14
  plot_layer_similarity,
 
15
  save_plot_html,
16
  )
17
 
@@ -556,6 +557,23 @@ def _render_layered_figure_analysis(
556
  return
557
  variant, persona_ids, persona_key, selected_layers = selected
558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
  fig_key = widget_key(
560
  "load",
561
  f"{scope}_fig_state",
@@ -564,6 +582,7 @@ def _render_layered_figure_analysis(
564
  mask_strategy.value,
565
  figure_kind,
566
  str(n_components),
 
567
  variant,
568
  "persona_vector",
569
  persona_key,
@@ -581,6 +600,8 @@ def _render_layered_figure_analysis(
581
  build_kwargs = {}
582
  if figure_kind in {"umap", "pca"}:
583
  build_kwargs["n_components"] = n_components
 
 
584
  main_fig = build_layered_figure(
585
  samples,
586
  figure_kind,
@@ -621,6 +642,134 @@ def _render_layered_figure_analysis(
621
  st.success(f"Loaded {n_samples} samples.")
622
 
623
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  def _render_source_select() -> str:
625
  last_source = st.session_state.get(_LAST_SOURCE_KEY, SOURCE_HUB)
626
  source = st.segmented_control(
@@ -765,10 +914,10 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
765
 
766
 
767
  def render_compare_tab() -> None:
768
- """Render the compare tab."""
769
 
770
- st.title("Compare")
771
- st.caption("Compare persona vectors by cosine similarity, PCA, or UMAP.")
772
 
773
  source = _render_source_select()
774
 
@@ -804,6 +953,10 @@ def render_compare_tab() -> None:
804
  )
805
  return
806
 
 
 
 
 
807
  dimension_choice = st.segmented_control(
808
  "Projection dimensions",
809
  options=["2D", "3D"],
 
12
  build_layered_figure,
13
  build_pair_similarity_figure,
14
  plot_layer_similarity,
15
+ plot_persona_dendrogram,
16
  save_plot_html,
17
  )
18
 
 
557
  return
558
  variant, persona_ids, persona_key, selected_layers = selected
559
 
560
+ n_clusters = None
561
+ if figure_kind in {"pca", "umap"}:
562
+ use_kmeans = st.toggle(
563
+ "Color by K-means clusters",
564
+ value=False,
565
+ key=widget_key("load", "kmeans_enabled", scope, store_id(store)),
566
+ help="Run K-means on persona vectors and color each persona by cluster.",
567
+ )
568
+ if use_kmeans:
569
+ n_clusters = st.slider(
570
+ "K (clusters)",
571
+ min_value=2,
572
+ max_value=min(10, len(persona_ids)),
573
+ value=min(3, len(persona_ids)),
574
+ key=widget_key("load", "kmeans_k", scope, store_id(store)),
575
+ )
576
+
577
  fig_key = widget_key(
578
  "load",
579
  f"{scope}_fig_state",
 
582
  mask_strategy.value,
583
  figure_kind,
584
  str(n_components),
585
+ str(n_clusters),
586
  variant,
587
  "persona_vector",
588
  persona_key,
 
600
  build_kwargs = {}
601
  if figure_kind in {"umap", "pca"}:
602
  build_kwargs["n_components"] = n_components
603
+ if n_clusters is not None:
604
+ build_kwargs["n_clusters"] = n_clusters
605
  main_fig = build_layered_figure(
606
  samples,
607
  figure_kind,
 
642
  st.success(f"Loaded {n_samples} samples.")
643
 
644
 
645
+ _LAST_DENDRO_PERSONAS_KEY = "compare:last_personas:dendro"
646
+ _DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
647
+
648
+
649
+ def _render_dendrogram_analysis(
650
+ store: Store,
651
+ mask_strategy: MaskStrategy,
652
+ ) -> None:
653
+ variants = available_variants(store, mask_strategy)
654
+ if not variants:
655
+ st.info("No variants with saved vectors for this model.")
656
+ return
657
+
658
+ with st.expander("Variant selection", expanded=True):
659
+ col1, col2 = st.columns(2)
660
+ default_a = "biography" if "biography" in variants else variants[0]
661
+ default_b_idx = variants.index("templated") if "templated" in variants else min(1, len(variants) - 1)
662
+ with col1:
663
+ variant_a = st.selectbox(
664
+ "Variant A",
665
+ options=variants,
666
+ index=variants.index(default_a),
667
+ format_func=prompt_variant_label,
668
+ key=widget_key("load", "dendro_variant_a", store_id(store)),
669
+ )
670
+ with col2:
671
+ variant_b = st.selectbox(
672
+ "Variant B",
673
+ options=variants,
674
+ index=default_b_idx,
675
+ format_func=prompt_variant_label,
676
+ key=widget_key("load", "dendro_variant_b", store_id(store)),
677
+ )
678
+
679
+ shared_variants = list(dict.fromkeys([variant_a, variant_b]))
680
+ persona_ids = _select_artifact_personas(
681
+ store,
682
+ shared_variants,
683
+ mask_strategy,
684
+ widget_scope=f"dendro:{store_id(store)}",
685
+ remember_key=_LAST_DENDRO_PERSONAS_KEY,
686
+ default_all=True,
687
+ )
688
+ if not persona_ids:
689
+ return
690
+
691
+ col_opts1, col_opts2 = st.columns(2)
692
+ with col_opts1:
693
+ layered_mode = st.toggle(
694
+ "Per-layer animated",
695
+ value=False,
696
+ key=widget_key("load", "dendro_layered", store_id(store)),
697
+ help="Animated dendrogram with one frame per layer instead of averaging all layers.",
698
+ )
699
+ with col_opts2:
700
+ linkage = st.selectbox(
701
+ "Linkage",
702
+ options=_DENDRO_LINKAGE_OPTIONS,
703
+ index=0,
704
+ key=widget_key("load", "dendro_linkage", store_id(store)),
705
+ )
706
+
707
+ persona_key = "_".join(sorted(persona_ids))
708
+ fig_key = widget_key(
709
+ "load", "dendro_fig_state",
710
+ store_id(store),
711
+ store.model_name,
712
+ mask_strategy.value,
713
+ variant_a, variant_b,
714
+ persona_key,
715
+ str(layered_mode), linkage,
716
+ )
717
+
718
+ if st.button(
719
+ "Generate dendrograms",
720
+ type="primary",
721
+ key=widget_key("load", "dendro_btn", store_id(store), variant_a, variant_b, persona_key),
722
+ ):
723
+ try:
724
+ samples_a = load_persona_vectors(
725
+ store, variant_a, mask_strategy=mask_strategy, persona_ids=persona_ids,
726
+ )
727
+ fig_a = plot_persona_dendrogram(
728
+ samples_a,
729
+ layered=layered_mode,
730
+ linkage=linkage,
731
+ title=f"Dendrogram — {prompt_variant_label(variant_a)}",
732
+ )
733
+ fig_a.update_layout(height=750)
734
+ fig_b = None
735
+ if variant_a != variant_b:
736
+ samples_b = load_persona_vectors(
737
+ store, variant_b, mask_strategy=mask_strategy, persona_ids=persona_ids,
738
+ )
739
+ fig_b = plot_persona_dendrogram(
740
+ samples_b,
741
+ layered=layered_mode,
742
+ linkage=linkage,
743
+ title=f"Dendrogram — {prompt_variant_label(variant_b)}",
744
+ )
745
+ fig_b.update_layout(height=750)
746
+ st.session_state[fig_key] = (fig_a, fig_b, len(persona_ids), variant_a, variant_b)
747
+ except Exception as exc:
748
+ st.error(f"Could not build dendrogram: {exc}")
749
+ st.session_state.pop(fig_key, None)
750
+
751
+ if fig_key in st.session_state:
752
+ fig_a, fig_b, n_personas, va, vb = st.session_state[fig_key]
753
+ if fig_b is not None:
754
+ col_a, col_b = st.columns(2)
755
+ with col_a:
756
+ st.subheader(prompt_variant_label(va))
757
+ st.plotly_chart(fig_a, width="stretch")
758
+ with col_b:
759
+ st.subheader(prompt_variant_label(vb))
760
+ st.plotly_chart(fig_b, width="stretch")
761
+ else:
762
+ st.plotly_chart(fig_a, width="stretch")
763
+
764
+ figs = [fig_a] + ([fig_b] if fig_b else [])
765
+ filenames = [
766
+ _filename("dendro", store.model_name, mask_strategy.value, va),
767
+ *([_filename("dendro", store.model_name, mask_strategy.value, vb)] if fig_b else []),
768
+ ]
769
+ _render_save_buttons(figs, filenames, "dendro")
770
+ st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
771
+
772
+
773
  def _render_source_select() -> str:
774
  last_source = st.session_state.get(_LAST_SOURCE_KEY, SOURCE_HUB)
775
  source = st.segmented_control(
 
914
 
915
 
916
  def render_compare_tab() -> None:
917
+ """Render the analysis tab."""
918
 
919
+ st.title("Analysis")
920
+ st.caption("Analyse persona vectors by cosine similarity, PCA, UMAP, or hierarchical clustering.")
921
 
922
  source = _render_source_select()
923
 
 
953
  )
954
  return
955
 
956
+ if analysis_mode == "Dendrogram":
957
+ _render_dendrogram_analysis(store, mask_strategy)
958
+ return
959
+
960
  dimension_choice = st.segmented_control(
961
  "Projection dimensions",
962
  options=["2D", "3D"],
utils/helpers.py CHANGED
@@ -22,13 +22,14 @@ DATASET_SOURCES = [
22
  "HuggingFace: nemotron-usa",
23
  "Local JSONL upload",
24
  ]
25
- ANALYSIS_MODES = ["Cosine similarity", "Similarity matrix", "PCA", "UMAP"]
26
 
27
  ANALYSIS_HELP_TEXT = {
28
  "Cosine similarity": "Compare layer-wise alignment between variants.",
29
  "Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
30
  "PCA": "Project per-persona vectors into a 2D or 3D global view.",
31
  "UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
 
32
  }
33
 
34
  NDIF_STATUS_ICONS = {
 
22
  "HuggingFace: nemotron-usa",
23
  "Local JSONL upload",
24
  ]
25
+ ANALYSIS_MODES = ["Cosine similarity", "Similarity matrix", "PCA", "UMAP", "Dendrogram"]
26
 
27
  ANALYSIS_HELP_TEXT = {
28
  "Cosine similarity": "Compare layer-wise alignment between variants.",
29
  "Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
30
  "PCA": "Project per-persona vectors into a 2D or 3D global view.",
31
  "UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
32
+ "Dendrogram": "Hierarchical clustering of persona vectors — shows biography and templated side by side for direct comparison.",
33
  }
34
 
35
  NDIF_STATUS_ICONS = {