"""Streamlit rendering helpers for layered persona-vector figures.

Covers the similarity-matrix figure and the projection figures (PCA, UMAP,
Isomap): control widgets, session-state cache keys, figure construction and
the final render/save step.
"""

import gc
from collections.abc import Callable

import plotly.graph_objects as go
import streamlit as st

from persona_vectors.attributes import attribute_color_kwargs, attribute_display_label
from persona_vectors.extraction import MaskStrategy
from persona_vectors.plots import (
    build_layered_figure,
    build_pair_similarity_figure,
    build_similarity_figures,
)
from tabs.analysis._shared import (
    _gray_out_unselected_personas,
    _load_persona_vectors,
    _plotly_chart,
    _render_save_buttons,
    _select_single_variant_samples,
)
from tabs.analysis._state import (
    _CLUSTER_MODES,
    _DEFAULT_GRAPH_NEIGHBORS,
    _LAST_PROJECTION_ATTRIBUTE_KEY,
    _LAST_PROJECTION_CLUSTER_K_KEY,
    _LAST_PROJECTION_CLUSTER_MODE_KEY,
    _LAST_PROJECTION_COLOR_MODE_KEY,
    _LAST_PROJECTION_HIGHLIGHTS_KEY,
    _LAST_PROJECTION_NORMALIZE_KEY,
    _LAST_PROJECTION_PERSONAS_KEY,
    _LAST_PROJECTION_VARIANT_KEY,
    _LAST_SIMILARITY_VARIANT_KEY,
    _MAX_ATTRIBUTE_CATEGORIES,
    _MAX_PAIR_TRAJECTORY_TRACES,
    _MAX_SIMILARITY_CELLS,
    _PROJECTION_COLOR_MODES,
    _PROJECTION_KINDS,
    LayeredFigureStateKeys,
    ProjectionColorConfig,
    _clear_old_figure_states,
    _clear_old_prepared_states,
    _highlight_persona_groups,
    _persona_display_label,
    _persona_names_state_key,
    _remember_multiselect,
    _remembered_selectbox,
    _store_figure_state,
)
from utils.analysis_metadata import (
    synth_persona_attribute_names,
    synth_persona_dataset_cached,
)
from utils.analysis_sources import (
    Store,
    kmeans_groups_cached,
    projection_data_cached,
    store_cache_parts,
    store_id,
)
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key


def _render_pair_trajectory_control(
    *,
    enabled: bool,
    persona_count: int,
    scope: str,
    store: Store,
) -> bool:
    """Render the "Pair trajectories" checkbox and return whether it is on.

    Returns ``False`` without rendering the checkbox when the control is
    disabled, or when the number of unordered persona pairs exceeds
    ``_MAX_PAIR_TRAJECTORY_TRACES`` (a caption explains why in that case).
    """
    if not enabled:
        return False

    # One trace per unordered pair: n * (n - 1) / 2.
    pair_count = persona_count * (persona_count - 1) // 2
    if pair_count > _MAX_PAIR_TRAJECTORY_TRACES:
        st.caption(
            "Pair trajectories hidden because this selection would create "
            f"{pair_count:,} Plotly traces."
        )
        return False

    return st.checkbox(
        "Pair trajectories",
        value=False,
        key=widget_key("load", "pair_trajectories", scope, store_id(store)),
        help="Adds one line per persona pair. Keep this off for larger selections.",
    )


def _validate_layered_figure_size(
    figure_kind: str,
    persona_count: int,
    selected_layers: list[int],
) -> bool:
    """Return whether the requested figure is small enough to build.

    Only the ``"similarity"`` kind is limited: the cell count
    (personas x personas x layers) must not exceed ``_MAX_SIMILARITY_CELLS``.
    An error message is shown when the limit is exceeded.
    """
    if figure_kind != "similarity":
        return True

    similarity_cells = persona_count * persona_count * len(selected_layers)
    if similarity_cells <= _MAX_SIMILARITY_CELLS:
        return True

    st.error(
        "Reduce personas or layer frames before generating the similarity "
        f"matrix ({similarity_cells:,} cells selected)."
    )
    return False


def _clamp_cluster_count(value: object, max_clusters: int) -> int:
    """Coerce ``value`` to an int and clamp it into ``[2, max_clusters]``."""
    return min(max(int(value), 2), max_clusters)


def _render_projection_color_config(
    store: Store,
    scope: str,
    persona_ids: list[str],
) -> ProjectionColorConfig | None:
    """Render the "Color by" controls and return the chosen configuration.

    Three paths, selected by the "Color by" selectbox:

    * ``"K-means clusters"``: a K slider and a "Cluster fit" selectbox.
    * ``"Persona attribute"``: an attribute selectbox.
    * anything else: a "Highlight personas" multiselect.

    Returns ``None`` (after showing an info message) when the chosen mode
    cannot be used: fewer than two personas for K-means, or no attribute
    names available.
    """
    widget_scope = f"{scope}:{store_id(store)}"
    persona_key = personas_fingerprint(persona_ids)
    persona_names = st.session_state.get(
        _persona_names_state_key(widget_scope),
        {},
    )

    color_mode_key = widget_key("load", "color_mode", scope, store_id(store))
    color_mode = _remembered_selectbox(
        "Color by",
        key=color_mode_key,
        remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
        options=_PROJECTION_COLOR_MODES,
        default="Persona attribute",
    )

    if color_mode == "K-means clusters":
        max_clusters = min(10, len(persona_ids))
        if max_clusters < 2:
            st.info("Select at least two personas to use K-means coloring.")
            return None

        cluster_key = widget_key("load", "cluster_k", scope, store_id(store))
        default_clusters = min(3, len(persona_ids))

        # The slider key is built from scope and store only, so a value chosen
        # under a larger persona selection can exceed the current maximum.
        # Clamp both the fresh seed and any existing value before the slider
        # is created so it never starts outside [2, max_clusters].
        # NOTE(review): how Streamlit treats an out-of-range session value
        # (raise vs. reset) appears to be version-dependent - confirm.
        if cluster_key in st.session_state:
            seed = st.session_state[cluster_key]
        else:
            seed = st.session_state.get(
                _LAST_PROJECTION_CLUSTER_K_KEY,
                default_clusters,
            )
        clamped = _clamp_cluster_count(seed, max_clusters)
        # Write only when something changes, leaving in-range state untouched.
        if st.session_state.get(cluster_key) != clamped:
            st.session_state[cluster_key] = clamped

        n_clusters = st.slider(
            "K (clusters)",
            min_value=2,
            max_value=max_clusters,
            key=cluster_key,
        )

        mode_key = widget_key("load", "cluster_mode", scope, store_id(store))
        mode_options = list(_CLUSTER_MODES)
        mode_label = _remembered_selectbox(
            "Cluster fit",
            key=mode_key,
            remember_key=_LAST_PROJECTION_CLUSTER_MODE_KEY,
            options=mode_options,
            default=mode_options[0],
            help=(
                "Mean across layers is the previous behavior. First selected "
                "layer keeps one fixed clustering from the first frame. Per layer "
                "recomputes clustering for each animation frame."
            ),
        )
        # Remember K so other scopes/stores can seed their slider from it.
        st.session_state[_LAST_PROJECTION_CLUSTER_K_KEY] = n_clusters
        return ProjectionColorConfig(
            color_mode=color_mode,
            n_clusters=n_clusters,
            cluster_mode=_CLUSTER_MODES[mode_label],
        )

    if color_mode == "Persona attribute":
        persona_dataset = synth_persona_dataset_cached()
        attribute_options = list(synth_persona_attribute_names())
        if not attribute_options:
            st.info("No persona attributes are available for this dataset.")
            return None

        # Prefer "sex" as the initial attribute when it exists.
        default_attribute = (
            attribute_options.index("sex") if "sex" in attribute_options else 0
        )
        attribute_key = widget_key("load", "attribute", scope, store_id(store))
        attribute_name = _remembered_selectbox(
            "Attribute",
            key=attribute_key,
            remember_key=_LAST_PROJECTION_ATTRIBUTE_KEY,
            options=attribute_options,
            default=attribute_options[default_attribute],
            format_func=lambda name: attribute_display_label(persona_dataset, name),
        )
        info = persona_dataset.attribute_info(attribute_name)
        if info.get("high_cardinality"):
            st.caption(
                "High-cardinality categorical attributes are grouped to the "
                f"top {_MAX_ATTRIBUTE_CATEGORIES} values plus Other."
            )
        return ProjectionColorConfig(
            color_mode=color_mode,
            attribute_name=attribute_name,
        )

    # Remaining color mode: optional per-persona highlighting.
    highlight_persona_ids: tuple[str, ...] = ()
    if persona_ids:
        # The key includes the persona fingerprint, so each persona selection
        # gets its own multiselect state.
        highlight_key = widget_key(
            "load", "persona_highlight", scope, store_id(store), persona_key
        )
        highlighted = st.multiselect(
            "Highlight personas",
            options=persona_ids,
            default=_remember_multiselect(
                key=highlight_key,
                remember_key=_LAST_PROJECTION_HIGHLIGHTS_KEY,
                options=persona_ids,
            ),
            format_func=lambda persona_id: _persona_display_label(
                persona_names, persona_id
            ),
            key=highlight_key,
            help=(
                "Select a few personas to keep their default colors while the rest "
                "are grayed out."
            ),
        )
        highlight_persona_ids = tuple(highlighted)
        st.session_state[_LAST_PROJECTION_HIGHLIGHTS_KEY] = list(highlighted)

    highlight_persona_key = (
        personas_fingerprint(highlight_persona_ids) if highlight_persona_ids else ""
    )
    return ProjectionColorConfig(
        color_mode=color_mode,
        highlight_persona_ids=highlight_persona_ids,
        highlight_persona_key=highlight_persona_key,
    )


def _render_projection_normalize_control(scope: str, store: Store) -> bool:
    """Render the "Normalize vectors" checkbox and return its value.

    The checkbox is seeded from the last remembered choice (default ``True``)
    the first time its key is seen, and the current choice is remembered.
    """
    key = widget_key("load", "projection_normalize", scope, store_id(store))
    if key not in st.session_state:
        st.session_state[key] = bool(
            st.session_state.get(_LAST_PROJECTION_NORMALIZE_KEY, True)
        )
    normalize = st.checkbox(
        "Normalize vectors",
        key=key,
        help=("Center and L2-normalize persona vectors before PCA/UMAP projection."),
    )
    st.session_state[_LAST_PROJECTION_NORMALIZE_KEY] = normalize
    return normalize


def _layered_figure_state_keys(
    store: Store,
    mask_strategy: MaskStrategy,
    *,
    scope: str,
    figure_kind: str,
    n_components: int,
    projection_normalize: bool,
    color_config: ProjectionColorConfig,
    variant: str,
    persona_key: str,
    selected_layers: list[int],
    pair_trajectories: bool,
) -> LayeredFigureStateKeys:
    """Build the session-state keys for a layered figure.

    The figure key encodes every input listed in the signature, including the
    color configuration. For projection kinds a second "prepared" key is
    returned as well; it leaves out the color configuration and the
    pair-trajectory flag, so it stays the same when only those change.
    """
    layer_key = "_".join(map(str, selected_layers))
    figure_key = widget_key(
        "load",
        f"{scope}_fig_state",
        store_id(store),
        store.model_name,
        mask_strategy.value,
        figure_kind,
        str(n_components),
        str(projection_normalize),
        color_config.color_mode,
        str(color_config.attribute_name),
        str(color_config.n_clusters),
        str(color_config.cluster_mode),
        str(color_config.highlight_persona_key),
        variant,
        "persona_vector",
        persona_key,
        layer_key,
        str(pair_trajectories),
    )
    if figure_kind not in _PROJECTION_KINDS:
        return LayeredFigureStateKeys(figure=figure_key)

    prepared_key = widget_key(
        "load",
        f"{scope}_projection_ready",
        store_id(store),
        store.model_name,
        mask_strategy.value,
        figure_kind,
        str(n_components),
        str(projection_normalize),
        str(figure_kind == "isomap"),
        str(_DEFAULT_GRAPH_NEIGHBORS),
        variant,
        persona_key,
        layer_key,
    )
    return LayeredFigureStateKeys(figure=figure_key, prepared=prepared_key)


def _projection_build_kwargs(
    *,
    store: Store,
    mask_strategy: MaskStrategy,
    variant: str,
    figure_kind: str,
    selected_layers: list[int],
    n_components: int,
    projection_normalize: bool,
    color_config: ProjectionColorConfig,
    persona_ids: list[str],
    persona_names: dict[str, str],
) -> dict:
    """Assemble the extra keyword arguments for a projection figure.

    Returns an empty dict for non-projection kinds. Otherwise the result
    holds the projection settings, the cached projection data, and whichever
    coloring inputs the color configuration calls for (K-means groups,
    attribute color kwargs, or highlight groups).
    """
    if figure_kind not in _PROJECTION_KINDS:
        return {}

    # The graph overlay is enabled for Isomap only.
    graph_overlay = figure_kind == "isomap"
    build_kwargs = {
        "n_components": n_components,
        "projection_normalize": projection_normalize,
        "graph_overlay": graph_overlay,
        "graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
    }

    source, location, model_name = store_cache_parts(store)
    # Tuples rather than lists: these arguments are passed to cached helpers.
    cache_args = (
        source,
        location,
        model_name,
        mask_strategy.value,
        variant,
        tuple(persona_ids),
        tuple(selected_layers),
    )
    build_kwargs["projection_data"] = projection_data_cached(
        *cache_args,
        figure_kind,
        n_components,
        projection_normalize,
        graph_overlay,
        _DEFAULT_GRAPH_NEIGHBORS,
    )

    if color_config.n_clusters is not None:
        build_kwargs["groups"] = kmeans_groups_cached(
            *cache_args,
            color_config.n_clusters,
            color_config.cluster_mode or "mean_across_layers",
        )

    if color_config.attribute_name is not None:
        build_kwargs.update(
            attribute_color_kwargs(
                synth_persona_dataset_cached(),
                color_config.attribute_name,
                persona_ids,
                max_categories=_MAX_ATTRIBUTE_CATEGORIES,
            )
        )

    if color_config.color_mode == "Persona" and color_config.highlight_persona_ids:
        groups = _highlight_persona_groups(
            persona_ids,
            persona_names,
            color_config.highlight_persona_ids,
        )
        if groups is not None:
            build_kwargs["groups"] = groups

    return build_kwargs


def _pair_trajectory_title(variant: str) -> str:
    """Return the title used for the pair-similarity-trajectory figure."""
    return (
        "Pair similarity trajectories - "
        f"{prompt_variant_label(variant)} - persona vectors"
    )


def _build_layered_analysis_figures(
    samples,
    *,
    figure_kind: str,
    selected_layers: list[int],
    variant: str,
    title_fn: Callable[[str], str],
    pair_trajectories: bool,
    build_kwargs: dict,
) -> tuple[go.Figure, go.Figure | None]:
    """Build the main figure and, optionally, the pair-trajectory figure.

    For the similarity kind with pair trajectories enabled, both figures come
    from a single ``build_similarity_figures`` call. Otherwise the main figure
    is built with ``build_layered_figure`` and the pair figure is built
    separately only when ``pair_trajectories`` is set.

    Returns:
        ``(main_fig, extra_fig)`` where ``extra_fig`` is ``None`` when pair
        trajectories are off.
    """
    if figure_kind == "similarity" and pair_trajectories:
        return build_similarity_figures(
            samples,
            layers=selected_layers,
            title=title_fn(variant),
            pair_title=_pair_trajectory_title(variant),
        )

    main_fig = build_layered_figure(
        samples,
        figure_kind,
        layers=selected_layers,
        title=title_fn(variant),
        **build_kwargs,
    )
    if figure_kind == "isomap":
        _add_isomap_connection_toggle(main_fig)
    if figure_kind in _PROJECTION_KINDS:
        main_fig.update_layout(height=700)

    extra_fig = (
        build_pair_similarity_figure(
            samples,
            layers=selected_layers,
            title=_pair_trajectory_title(variant),
        )
        if pair_trajectories
        else None
    )
    return main_fig, extra_fig


def _add_isomap_connection_toggle(fig: go.Figure) -> None:
    """Add an in-plot control for the Isomap kNN graph trace.

    Does nothing unless the figure's first trace is named ``"kNN graph"``.
    Existing update menus are kept; a two-button menu that shows or hides
    trace 0 is appended to them.
    """
    if not fig.data or fig.data[0].name != "kNN graph":
        return

    existing_menus = tuple(fig.layout.updatemenus or ())
    connection_menu = dict(
        type="buttons",
        direction="left",
        active=0,
        showactive=False,
        x=0,
        xanchor="left",
        y=1.16,
        yanchor="top",
        pad=dict(t=0, r=10),
        buttons=[
            dict(
                label="Show connections",
                method="restyle",
                # Restyle only trace index 0 (the kNN graph trace).
                args=[{"visible": True}, [0]],
            ),
            dict(
                label="Hide connections",
                method="restyle",
                args=[{"visible": False}, [0]],
            ),
        ],
    )
    fig.update_layout(updatemenus=existing_menus + (connection_menu,))


def _render_layered_figure_analysis(
    store: Store,
    mask_strategy: MaskStrategy,
    *,
    scope: str,
    figure_kind: str,
    button_label: str,
    title_fn: Callable[[str], str],
    include_pair_trajectories: bool = False,
    n_components: int = 2,
    remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
    default_count_limit: int = 500,
    max_count_limit: int | None = None,
    allow_specific_personas: bool = False,
) -> None:
    """Render a single-variant layered analysis: select → button → figure(s).

    Used for the similarity matrix and for the projection figures (PCA, UMAP,
    Isomap). Set ``include_pair_trajectories`` to add the
    pair-similarity-trajectory figure (similarity matrix only).

    The built figures are stored in ``st.session_state`` under a key derived
    from all inputs, so they are shown again on later reruns without a
    rebuild. A build is triggered by the button, or automatically when a
    projection was already prepared for these inputs but no figure exists for
    the current key (for example after changing only the coloring).
    Any exception raised during the build is reported with ``st.error`` and
    the figure state for the current key is dropped.
    """
    selected = _select_single_variant_samples(
        store,
        mask_strategy,
        scope,
        remember_key=remember_key,
        variant_remember_key=(
            _LAST_PROJECTION_VARIANT_KEY
            if figure_kind in _PROJECTION_KINDS
            else _LAST_SIMILARITY_VARIANT_KEY
        ),
        default_count_limit=default_count_limit,
        max_count_limit=max_count_limit,
        allow_specific_personas=allow_specific_personas,
    )
    if selected is None:
        return
    variant, persona_ids, persona_key, selected_layers = selected

    pair_trajectories = _render_pair_trajectory_control(
        enabled=include_pair_trajectories,
        persona_count=len(persona_ids),
        scope=scope,
        store=store,
    )
    if not _validate_layered_figure_size(
        figure_kind, len(persona_ids), selected_layers
    ):
        return

    color_config = ProjectionColorConfig()
    if figure_kind in _PROJECTION_KINDS:
        color_config = _render_projection_color_config(store, scope, persona_ids)
        if color_config is None:
            return

    # Only PCA and UMAP expose the normalize checkbox; other projection kinds
    # always normalize, and non-projection figures never do.
    if figure_kind in {"pca", "umap"}:
        projection_normalize = _render_projection_normalize_control(scope, store)
    elif figure_kind in _PROJECTION_KINDS:
        projection_normalize = True
    else:
        projection_normalize = False

    state_keys = _layered_figure_state_keys(
        store,
        mask_strategy,
        scope=scope,
        figure_kind=figure_kind,
        n_components=n_components,
        projection_normalize=projection_normalize,
        color_config=color_config,
        variant=variant,
        persona_key=persona_key,
        selected_layers=selected_layers,
        pair_trajectories=pair_trajectories,
    )
    filename = scope
    _clear_old_figure_states(state_keys.figure)
    persona_names = st.session_state.get(
        _persona_names_state_key(f"{scope}:{store_id(store)}"),
        {},
    )

    build_clicked = st.button(button_label, type="primary")
    # Rebuild without a click when the projection was prepared earlier for
    # these inputs but there is no figure stored under the current key.
    recolor_from_warm_projection = (
        state_keys.prepared is not None
        and bool(st.session_state.get(state_keys.prepared))
        and state_keys.figure not in st.session_state
    )

    if build_clicked or recolor_from_warm_projection:
        build_label = {
            "umap": "Computing UMAP projections…",
            "pca": "Computing PCA projections…",
            "isomap": "Computing Isomap projections…",
            "similarity": "Computing similarity matrices…",
        }.get(figure_kind, "Building figure…")
        progress = st.progress(0, text="Loading activation vectors…")
        try:
            progress.progress(15, text="Loading activation vectors…")
            samples = _load_persona_vectors(
                store,
                variant,
                mask_strategy,
                persona_ids,
            )
            progress.progress(55, text=build_label)
            build_kwargs = _projection_build_kwargs(
                store=store,
                mask_strategy=mask_strategy,
                variant=variant,
                figure_kind=figure_kind,
                selected_layers=selected_layers,
                n_components=n_components,
                projection_normalize=projection_normalize,
                color_config=color_config,
                persona_ids=persona_ids,
                persona_names=persona_names,
            )
            main_fig, extra_fig = _build_layered_analysis_figures(
                samples,
                figure_kind=figure_kind,
                selected_layers=selected_layers,
                variant=variant,
                title_fn=title_fn,
                pair_trajectories=pair_trajectories,
                build_kwargs=build_kwargs,
            )
            if (
                color_config.color_mode == "Persona"
                and color_config.highlight_persona_ids
            ):
                _gray_out_unselected_personas(main_fig)
            progress.progress(90, text="Storing figure state…")
            # Keep only the sample count; drop the local reference to the
            # loaded vectors before the figures are stored.
            n_samples = samples.vectors.shape[0]
            del samples
            _store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
            if state_keys.prepared is not None:
                _clear_old_prepared_states(state_keys.prepared)
                st.session_state[state_keys.prepared] = True
            progress.progress(100, text="Done.")
        except Exception as exc:
            # UI boundary: report the failure instead of crashing the page,
            # and make sure no figure remains stored under this key.
            st.error(f"Could not build figure: {exc}")
            st.session_state.pop(state_keys.figure, None)
        finally:
            gc.collect()
            progress.empty()

    if state_keys.figure in st.session_state:
        main_fig, extra_fig, n_samples = st.session_state[state_keys.figure]
        _plotly_chart(main_fig)
        figs = [main_fig]
        filenames = [filename]
        if extra_fig is not None:
            st.subheader("Pair trajectories")
            _plotly_chart(extra_fig)
            figs.append(extra_fig)
            filenames.append(f"{filename}__pair_trajectories")
        _render_save_buttons(figs, filenames, scope)
        st.success(f"Loaded {n_samples} samples.")