import gc from copy import deepcopy import plotly.graph_objects as go import streamlit as st from persona_vectors.extraction import MaskStrategy from persona_vectors.plots import plot_persona_dendrogram from plotly.subplots import make_subplots from tabs.analysis._shared import ( _load_persona_options, _load_variant_vectors, _plotly_chart, _render_layer_frame_controls, _render_persona_select_controls, _render_save_buttons, _select_artifact_personas, ) from tabs.analysis._state import ( _DEFAULT_PERSONA_LIMITS, _MAX_PERSONA_COUNTS, _clear_old_figure_states, _filename, _persona_names_state_key, _personas_empty_message, _store_figure_state, ) from utils.analysis_sources import ( Store, available_variants, store_cache_parts, store_id, store_layers_cached, ) from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key _LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro" _DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"] def _comparison_dendrogram_figure( fig_a: go.Figure, fig_b: go.Figure, *, title_a: str, title_b: str, ) -> go.Figure: """Merge two layered dendrograms so one slider drives both panels.""" combined = make_subplots( rows=1, cols=2, subplot_titles=(title_a, title_b), shared_yaxes=True, horizontal_spacing=0.05, ) for trace in fig_a.data: combined.add_trace(deepcopy(trace), row=1, col=1) for trace in fig_b.data: combined.add_trace(deepcopy(trace), row=1, col=2) frames: list[go.Frame] = [] for frame_a, frame_b in zip(fig_a.frames, fig_b.frames, strict=True): right_data = [] for trace in frame_b.data: copied = deepcopy(trace) copied.update(xaxis="x2", yaxis="y2") right_data.append(copied) frame_xaxis = frame_a.layout.xaxis.to_plotly_json() frame_xaxis2 = frame_b.layout.xaxis.to_plotly_json() frame_xaxis2["matches"] = None frame_xaxis2["anchor"] = "y2" frame_yaxis = frame_a.layout.yaxis.to_plotly_json() frame_yaxis2 = frame_b.layout.yaxis.to_plotly_json() frame_yaxis2["matches"] = "y" frame_yaxis2["anchor"] = "x2" frames.append( go.Frame( name=frame_a.name, data=[*deepcopy(frame_a.data), *right_data], layout={ "title": {"text": f"Dendrogram comparison - Layer {frame_a.name}"}, "xaxis": frame_xaxis, "xaxis2": frame_xaxis2, "yaxis": frame_yaxis, "yaxis2": frame_yaxis2, }, ) ) y_ranges = [ fig_a.layout.yaxis.range, fig_b.layout.yaxis.range, ] max_y = max(float(axis_range[1]) for axis_range in y_ranges if axis_range) first_layer = fig_a.frames[0].name if fig_a.frames else "" combined.frames = frames combined.update_layout( title={ "text": f"Dendrogram comparison - Layer {first_layer}", "font": {"size": 24}, "y": 0.98, "yanchor": "top", }, template="plotly_white", height=750, margin=dict(t=140, b=260), updatemenus=fig_a.layout.updatemenus, sliders=fig_a.layout.sliders, ) left_xaxis = fig_a.layout.xaxis.to_plotly_json() right_xaxis = fig_b.layout.xaxis.to_plotly_json() right_xaxis["matches"] = None right_xaxis["anchor"] = "y2" combined.update_layout(xaxis=left_xaxis, xaxis2=right_xaxis) combined.update_xaxes(tickangle=-45, automargin=True) combined.update_yaxes( title_text=fig_a.layout.yaxis.title.text, range=[0.0, max_y], automargin=True, ) return combined def _render_dendrogram_analysis( store: Store, mask_strategy: MaskStrategy, ) -> None: variants = available_variants(store, mask_strategy) if not variants: st.info("No variants with saved vectors for this model.") return with st.expander("Variant selection", expanded=True): col1, col2 = st.columns(2) default_a = "biography" if "biography" in variants else variants[0] default_b_idx = ( variants.index("templated") if "templated" in variants else min(1, len(variants) - 1) ) with col1: variant_a = st.selectbox( "Variant A", options=variants, index=variants.index(default_a), format_func=prompt_variant_label, key=widget_key("load", "dendro_variant_a", store_id(store)), ) with col2: variant_b = st.selectbox( "Variant B", options=variants, index=default_b_idx, format_func=prompt_variant_label, key=widget_key("load", "dendro_variant_b", store_id(store)), ) shared_variants = list(dict.fromkeys([variant_a, variant_b])) select_specific = st.toggle( "Select specific personas", value=False, key=widget_key("load", "dendro_select_mode", store_id(store)), help="Search and select specific personas instead of using the first N.", ) if select_specific: empty_message = _personas_empty_message(shared_variants) options = _load_persona_options( store, shared_variants, mask_strategy, empty_message=empty_message, ) if options is None: st.session_state.pop( _persona_names_state_key(f"dendro:{store_id(store)}"), None ) return persona_ids = _render_persona_select_controls( options, widget_scope=f"dendro:{store_id(store)}", max_selections=_MAX_PERSONA_COUNTS["dendro"], ) if not persona_ids: return else: persona_ids = _select_artifact_personas( store, shared_variants, mask_strategy, widget_scope=f"dendro:{store_id(store)}", remember_key=_LAST_DENDRO_PERSONAS_KEY, default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"], max_count_limit=_MAX_PERSONA_COUNTS["dendro"], ) if not persona_ids: return col_opts1, col_opts2 = st.columns(2) with col_opts1: layered_mode = st.toggle( "Per-layer animated", value=False, key=widget_key("load", "dendro_layered", store_id(store)), help="Animated dendrogram with one frame per layer instead of averaging all layers.", ) with col_opts2: linkage = st.selectbox( "Linkage", options=_DENDRO_LINKAGE_OPTIONS, index=0, key=widget_key("load", "dendro_linkage", store_id(store)), ) selected_layers: list[int] | None = None if layered_mode: source, location, model_name = store_cache_parts(store) layer_options = store_layers_cached( source, location, model_name, mask_strategy.value, tuple(shared_variants), tuple(persona_ids), ) if not layer_options: st.info("No shared layers are available for the selected personas.") return selected_layers = _render_layer_frame_controls(store, "dendro", layer_options) persona_key = personas_fingerprint(persona_ids) fig_key = widget_key( "load", "dendro_fig_state", store_id(store), store.model_name, mask_strategy.value, variant_a, variant_b, persona_key, str(layered_mode), linkage, "_".join(map(str, selected_layers or [])), ) _clear_old_figure_states(fig_key) if st.button( "Generate dendrograms", type="primary", key=widget_key( "load", "dendro_btn", store_id(store), variant_a, variant_b, persona_key ), ): progress = st.progress(0, text="Loading first variant vectors…") try: progress.progress(15, text="Loading variant vectors…") by_variant = _load_variant_vectors( store, shared_variants, mask_strategy, persona_ids, ) samples_a = by_variant[variant_a] progress.progress(40, text="Building first dendrogram…") fig_a = plot_persona_dendrogram( samples_a, layered=layered_mode, layers=selected_layers, linkage=linkage, title=f"Dendrogram — {prompt_variant_label(variant_a)}", ) fig_a.update_layout(height=750) fig_b = None if variant_a != variant_b: progress.progress(60, text="Building second dendrogram…") samples_b = by_variant[variant_b] progress.progress(75, text="Building second dendrogram…") fig_b = plot_persona_dendrogram( samples_b, layered=layered_mode, layers=selected_layers, linkage=linkage, title=f"Dendrogram — {prompt_variant_label(variant_b)}", ) fig_b.update_layout(height=750) del samples_b del samples_a comparison_fig = None if fig_b is not None and layered_mode: comparison_fig = _comparison_dendrogram_figure( fig_a, fig_b, title_a=prompt_variant_label(variant_a), title_b=prompt_variant_label(variant_b), ) progress.progress(90, text="Storing figure state…") _store_figure_state( fig_key, ( None if comparison_fig is not None else fig_a, None if comparison_fig is not None else fig_b, comparison_fig, len(persona_ids), variant_a, variant_b, ), ) progress.progress(100, text="Done.") except Exception as exc: st.error(f"Could not build dendrogram: {exc}") st.session_state.pop(fig_key, None) finally: gc.collect() progress.empty() if fig_key in st.session_state: saved = st.session_state[fig_key] fig_a, fig_b, comparison_fig, n_personas, va, vb = saved if comparison_fig is not None: _plotly_chart(comparison_fig) elif fig_b is not None: col_a, col_b = st.columns(2) with col_a: st.subheader(prompt_variant_label(va)) _plotly_chart(fig_a) with col_b: st.subheader(prompt_variant_label(vb)) _plotly_chart(fig_b) else: _plotly_chart(fig_a) figs = ( [comparison_fig] if comparison_fig is not None else [fig_a] + ([fig_b] if fig_b else []) ) filenames = ( [_filename("dendro_compare", store.model_name, mask_strategy.value, va, vb)] if comparison_fig is not None else [ _filename("dendro", store.model_name, mask_strategy.value, va), *( [_filename("dendro", store.model_name, mask_strategy.value, vb)] if fig_b else [] ), ] ) _render_save_buttons(figs, filenames, "dendro") st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")