persona-ui / tabs /analysis /dendrogram.py
Jac-Zac
Updated to latest persona-vector
e8b71ab
import gc
from copy import deepcopy
import plotly.graph_objects as go
import streamlit as st
from persona_vectors.extraction import MaskStrategy
from persona_vectors.plots import plot_persona_dendrogram
from plotly.subplots import make_subplots
from tabs.analysis._shared import (
_load_persona_options,
_load_variant_vectors,
_plotly_chart,
_render_layer_frame_controls,
_render_persona_select_controls,
_render_save_buttons,
_select_artifact_personas,
)
from tabs.analysis._state import (
_DEFAULT_PERSONA_LIMITS,
_MAX_PERSONA_COUNTS,
_clear_old_figure_states,
_filename,
_persona_names_state_key,
_personas_empty_message,
_store_figure_state,
)
from utils.analysis_sources import (
Store,
available_variants,
store_cache_parts,
store_id,
store_layers_cached,
)
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
# Passed as ``remember_key`` to ``_select_artifact_personas`` by the
# dendrogram tab when the "first N personas" mode is active.
_LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro"

# Choices offered by the "Linkage" selectbox; the selectbox uses index 0,
# so the first entry is the default.
_DENDRO_LINKAGE_OPTIONS = [
    "ward",
    "complete",
    "average",
    "single",
]
def _comparison_dendrogram_figure(
    fig_a: go.Figure,
    fig_b: go.Figure,
    *,
    title_a: str,
    title_b: str,
) -> go.Figure:
    """Merge two layered dendrograms so one slider drives both panels.

    The traces of ``fig_a`` go into the left subplot and those of ``fig_b``
    into the right subplot. Animation frames are paired in order; each merged
    frame carries the left frame's traces followed by the right frame's
    traces (retargeted to ``x2``/``y2``). The slider and update menus are
    taken from ``fig_a``.

    Args:
        fig_a: Animated figure shown in the left panel.
        fig_b: Animated figure shown in the right panel.
        title_a: Subplot title for the left panel.
        title_b: Subplot title for the right panel.

    Returns:
        A new two-column figure with merged frames.

    Raises:
        ValueError: If the two figures have a different number of frames
            (``zip(..., strict=True)``).
    """
    combined = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(title_a, title_b),
        shared_yaxes=True,
        horizontal_spacing=0.05,
    )
    # Deep copies keep the combined figure independent of the input figures.
    for trace in fig_a.data:
        combined.add_trace(deepcopy(trace), row=1, col=1)
    for trace in fig_b.data:
        combined.add_trace(deepcopy(trace), row=1, col=2)

    frames: list[go.Frame] = []
    for frame_a, frame_b in zip(fig_a.frames, fig_b.frames, strict=True):
        # Right-panel frame traces must point at the second subplot's axes.
        right_data = []
        for trace in frame_b.data:
            copied = deepcopy(trace)
            copied.update(xaxis="x2", yaxis="y2")
            right_data.append(copied)

        frame_xaxis = frame_a.layout.xaxis.to_plotly_json()
        frame_xaxis2 = frame_b.layout.xaxis.to_plotly_json()
        # The right x-axis is not tied to the left one; it is anchored to y2.
        frame_xaxis2["matches"] = None
        frame_xaxis2["anchor"] = "y2"

        frame_yaxis = frame_a.layout.yaxis.to_plotly_json()
        frame_yaxis2 = frame_b.layout.yaxis.to_plotly_json()
        # The right y-axis follows the left one so both panels share a scale.
        frame_yaxis2["matches"] = "y"
        frame_yaxis2["anchor"] = "x2"

        frames.append(
            go.Frame(
                name=frame_a.name,
                data=[*deepcopy(frame_a.data), *right_data],
                layout={
                    "title": {"text": f"Dendrogram comparison - Layer {frame_a.name}"},
                    "xaxis": frame_xaxis,
                    "xaxis2": frame_xaxis2,
                    "yaxis": frame_yaxis,
                    "yaxis2": frame_yaxis2,
                },
            )
        )

    # Upper bounds of whichever input figures define an explicit y range.
    y_tops = [
        float(axis_range[1])
        for axis_range in (fig_a.layout.yaxis.range, fig_b.layout.yaxis.range)
        if axis_range
    ]
    first_layer = fig_a.frames[0].name if fig_a.frames else ""

    combined.frames = frames
    combined.update_layout(
        title={
            "text": f"Dendrogram comparison - Layer {first_layer}",
            "font": {"size": 24},
            "y": 0.98,
            "yanchor": "top",
        },
        template="plotly_white",
        height=750,
        margin=dict(t=140, b=260),
        updatemenus=fig_a.layout.updatemenus,
        sliders=fig_a.layout.sliders,
    )

    left_xaxis = fig_a.layout.xaxis.to_plotly_json()
    right_xaxis = fig_b.layout.xaxis.to_plotly_json()
    right_xaxis["matches"] = None
    right_xaxis["anchor"] = "y2"
    combined.update_layout(xaxis=left_xaxis, xaxis2=right_xaxis)
    combined.update_xaxes(tickangle=-45, automargin=True)

    yaxis_updates = {
        "title_text": fig_a.layout.yaxis.title.text,
        "automargin": True,
    }
    # Only pin a common range when at least one input figure defines one;
    # max() on an empty sequence would raise ValueError, so otherwise the
    # y-axis range is left unset.
    if y_tops:
        yaxis_updates["range"] = [0.0, max(y_tops)]
    combined.update_yaxes(**yaxis_updates)
    return combined
def _render_dendrogram_analysis(
    store: Store,
    mask_strategy: MaskStrategy,
) -> None:
    """Render the dendrogram analysis controls and any generated figures.

    Flow, as visible in this function:

    1. Choose variant A and variant B among the variants available for
       ``store`` / ``mask_strategy``.
    2. Choose personas, either explicitly or through
       ``_select_artifact_personas``.
    3. Choose the linkage method and, optionally, per-layer animated mode
       with a layer selection.
    4. On "Generate dendrograms", build one figure per distinct variant (and
       a merged comparison figure when both variants differ in layered mode)
       and hand them to ``_store_figure_state`` under ``fig_key``.
    5. If ``fig_key`` is present in ``st.session_state``, display the saved
       figures and the save buttons.

    Args:
        store: Vector store to read variants, personas and layers from.
        mask_strategy: Mask strategy whose saved vectors are analysed.
    """
    variants = available_variants(store, mask_strategy)
    if not variants:
        st.info("No variants with saved vectors for this model.")
        return

    sid = store_id(store)
    widget_scope = f"dendro:{sid}"

    # NOTE(review): the expander is assumed to wrap only the two variant
    # selectors — confirm against the intended layout.
    with st.expander("Variant selection", expanded=True):
        col1, col2 = st.columns(2)
        default_a = "biography" if "biography" in variants else variants[0]
        default_b_idx = (
            variants.index("templated")
            if "templated" in variants
            else min(1, len(variants) - 1)
        )
        with col1:
            variant_a = st.selectbox(
                "Variant A",
                options=variants,
                index=variants.index(default_a),
                format_func=prompt_variant_label,
                key=widget_key("load", "dendro_variant_a", sid),
            )
        with col2:
            variant_b = st.selectbox(
                "Variant B",
                options=variants,
                index=default_b_idx,
                format_func=prompt_variant_label,
                key=widget_key("load", "dendro_variant_b", sid),
            )

    # Order-preserving de-duplication: a single entry when A and B coincide.
    shared_variants = list(dict.fromkeys([variant_a, variant_b]))

    select_specific = st.toggle(
        "Select specific personas",
        value=False,
        key=widget_key("load", "dendro_select_mode", sid),
        help="Search and select specific personas instead of using the first N.",
    )
    if select_specific:
        options = _load_persona_options(
            store,
            shared_variants,
            mask_strategy,
            empty_message=_personas_empty_message(shared_variants),
        )
        if options is None:
            # No options could be loaded: drop any stored selection for
            # this scope and stop rendering.
            st.session_state.pop(_persona_names_state_key(widget_scope), None)
            return
        persona_ids = _render_persona_select_controls(
            options,
            widget_scope=widget_scope,
            max_selections=_MAX_PERSONA_COUNTS["dendro"],
        )
    else:
        persona_ids = _select_artifact_personas(
            store,
            shared_variants,
            mask_strategy,
            widget_scope=widget_scope,
            remember_key=_LAST_DENDRO_PERSONAS_KEY,
            default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"],
            max_count_limit=_MAX_PERSONA_COUNTS["dendro"],
        )
    if not persona_ids:
        return

    col_opts1, col_opts2 = st.columns(2)
    with col_opts1:
        layered_mode = st.toggle(
            "Per-layer animated",
            value=False,
            key=widget_key("load", "dendro_layered", sid),
            help="Animated dendrogram with one frame per layer instead of averaging all layers.",
        )
    with col_opts2:
        linkage = st.selectbox(
            "Linkage",
            options=_DENDRO_LINKAGE_OPTIONS,
            index=0,
            key=widget_key("load", "dendro_linkage", sid),
        )

    selected_layers: list[int] | None = None
    if layered_mode:
        source, location, model_name = store_cache_parts(store)
        layer_options = store_layers_cached(
            source,
            location,
            model_name,
            mask_strategy.value,
            tuple(shared_variants),
            tuple(persona_ids),
        )
        if not layer_options:
            st.info("No shared layers are available for the selected personas.")
            return
        selected_layers = _render_layer_frame_controls(store, "dendro", layer_options)

    persona_key = personas_fingerprint(persona_ids)
    # The key encodes every input that affects the figures, so a change in
    # any of them yields a different key.
    fig_key = widget_key(
        "load",
        "dendro_fig_state",
        sid,
        store.model_name,
        mask_strategy.value,
        variant_a,
        variant_b,
        persona_key,
        str(layered_mode),
        linkage,
        "_".join(map(str, selected_layers or [])),
    )
    # Presumably discards figure state saved under other keys — confirm in
    # tabs.analysis._state.
    _clear_old_figure_states(fig_key)

    generate_clicked = st.button(
        "Generate dendrograms",
        type="primary",
        key=widget_key("load", "dendro_btn", sid, variant_a, variant_b, persona_key),
    )
    if generate_clicked:
        progress = st.progress(0, text="Loading first variant vectors…")
        try:
            progress.progress(15, text="Loading variant vectors…")
            by_variant = _load_variant_vectors(
                store,
                shared_variants,
                mask_strategy,
                persona_ids,
            )
            samples_a = by_variant[variant_a]
            progress.progress(40, text="Building first dendrogram…")
            fig_a = plot_persona_dendrogram(
                samples_a,
                layered=layered_mode,
                layers=selected_layers,
                linkage=linkage,
                title=f"Dendrogram — {prompt_variant_label(variant_a)}",
            )
            fig_a.update_layout(height=750)

            fig_b = None
            if variant_a != variant_b:
                progress.progress(60, text="Building second dendrogram…")
                samples_b = by_variant[variant_b]
                progress.progress(75, text="Building second dendrogram…")
                fig_b = plot_persona_dendrogram(
                    samples_b,
                    layered=layered_mode,
                    layers=selected_layers,
                    linkage=linkage,
                    title=f"Dendrogram — {prompt_variant_label(variant_b)}",
                )
                fig_b.update_layout(height=750)
                del samples_b
            del samples_a

            # A merged figure is only built for two distinct variants in
            # layered (animated) mode.
            comparison_fig = None
            if fig_b is not None and layered_mode:
                comparison_fig = _comparison_dendrogram_figure(
                    fig_a,
                    fig_b,
                    title_a=prompt_variant_label(variant_a),
                    title_b=prompt_variant_label(variant_b),
                )

            progress.progress(90, text="Storing figure state…")
            # When a comparison figure exists, the individual figures are
            # not kept alongside it.
            has_comparison = comparison_fig is not None
            _store_figure_state(
                fig_key,
                (
                    None if has_comparison else fig_a,
                    None if has_comparison else fig_b,
                    comparison_fig,
                    len(persona_ids),
                    variant_a,
                    variant_b,
                ),
            )
            progress.progress(100, text="Done.")
        except Exception as exc:
            # UI boundary: report the failure to the user instead of crashing
            # the page, and make sure no stale figure is displayed below.
            st.error(f"Could not build dendrogram: {exc}")
            st.session_state.pop(fig_key, None)
        finally:
            gc.collect()
            progress.empty()

    if fig_key in st.session_state:
        fig_a, fig_b, comparison_fig, n_personas, va, vb = st.session_state[fig_key]

        if comparison_fig is not None:
            _plotly_chart(comparison_fig)
        elif fig_b is not None:
            col_a, col_b = st.columns(2)
            with col_a:
                st.subheader(prompt_variant_label(va))
                _plotly_chart(fig_a)
            with col_b:
                st.subheader(prompt_variant_label(vb))
                _plotly_chart(fig_b)
        else:
            _plotly_chart(fig_a)

        # Build the figure and filename lists passed to the save buttons.
        if comparison_fig is not None:
            figs = [comparison_fig]
            filenames = [
                _filename("dendro_compare", store.model_name, mask_strategy.value, va, vb)
            ]
        else:
            figs = [fig_a]
            filenames = [_filename("dendro", store.model_name, mask_strategy.value, va)]
            if fig_b is not None:
                figs.append(fig_b)
                filenames.append(
                    _filename("dendro", store.model_name, mask_strategy.value, vb)
                )
        _render_save_buttons(figs, filenames, "dendro")
        st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")