persona-ui / tabs /analysis /layered.py
Jac-Zac
Updated to latest persona-vector version
7460198
import gc
from collections.abc import Callable
import plotly.graph_objects as go
import streamlit as st
from persona_vectors.attributes import attribute_color_kwargs, attribute_display_label
from persona_vectors.extraction import MaskStrategy
from persona_vectors.plots import (
build_layered_figure,
build_pair_similarity_figure,
build_similarity_figures,
)
from tabs.analysis._shared import (
_gray_out_unselected_personas,
_load_persona_vectors,
_plotly_chart,
_render_save_buttons,
_select_single_variant_samples,
)
from tabs.analysis._state import (
_CLUSTER_MODES,
_DEFAULT_GRAPH_NEIGHBORS,
_LAST_PROJECTION_ATTRIBUTE_KEY,
_LAST_PROJECTION_CLUSTER_K_KEY,
_LAST_PROJECTION_CLUSTER_MODE_KEY,
_LAST_PROJECTION_COLOR_MODE_KEY,
_LAST_PROJECTION_HIGHLIGHTS_KEY,
_LAST_PROJECTION_NORMALIZE_KEY,
_LAST_PROJECTION_PERSONAS_KEY,
_LAST_PROJECTION_VARIANT_KEY,
_LAST_SIMILARITY_VARIANT_KEY,
_MAX_ATTRIBUTE_CATEGORIES,
_MAX_PAIR_TRAJECTORY_TRACES,
_MAX_SIMILARITY_CELLS,
_PROJECTION_COLOR_MODES,
_PROJECTION_KINDS,
LayeredFigureStateKeys,
ProjectionColorConfig,
_clear_old_figure_states,
_clear_old_prepared_states,
_highlight_persona_groups,
_persona_display_label,
_persona_names_state_key,
_remember_multiselect,
_remembered_selectbox,
_store_figure_state,
)
from utils.analysis_metadata import (
synth_persona_attribute_names,
synth_persona_dataset_cached,
)
from utils.analysis_sources import (
Store,
kmeans_groups_cached,
projection_data_cached,
store_cache_parts,
store_id,
)
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
def _render_pair_trajectory_control(
    *,
    enabled: bool,
    persona_count: int,
    scope: str,
    store: Store,
) -> bool:
    """Render the optional "Pair trajectories" checkbox and return its value.

    Returns ``False`` without drawing a checkbox when the control is disabled,
    or when the number of unordered persona pairs exceeds
    ``_MAX_PAIR_TRAJECTORY_TRACES`` (a caption explains the latter case).
    """
    if enabled:
        # One trace per unordered persona pair: n choose 2.
        n_pairs = persona_count * (persona_count - 1) // 2
        if n_pairs <= _MAX_PAIR_TRAJECTORY_TRACES:
            checkbox_key = widget_key(
                "load", "pair_trajectories", scope, store_id(store)
            )
            return st.checkbox(
                "Pair trajectories",
                value=False,
                key=checkbox_key,
                help="Adds one line per persona pair. Keep this off for larger selections.",
            )
        st.caption(
            "Pair trajectories hidden because this selection would create "
            f"{n_pairs:,} Plotly traces."
        )
    return False
def _validate_layered_figure_size(
    figure_kind: str,
    persona_count: int,
    selected_layers: list[int],
) -> bool:
    """Return whether the requested figure is small enough to build.

    Only ``"similarity"`` figures are size-limited: the total cell count
    (personas x personas x selected layers) must not exceed
    ``_MAX_SIMILARITY_CELLS``. When the limit is exceeded an error is shown
    and ``False`` is returned.
    """
    if figure_kind == "similarity":
        cell_total = persona_count * persona_count * len(selected_layers)
        if cell_total > _MAX_SIMILARITY_CELLS:
            st.error(
                "Reduce personas or layer frames before generating the similarity "
                f"matrix ({cell_total:,} cells selected)."
            )
            return False
    return True
def _render_projection_color_config(
    store: Store,
    scope: str,
    persona_ids: list[str],
) -> ProjectionColorConfig | None:
    """Render the "Color by" controls for a projection and return the choice.

    Depending on the selected color mode this renders:

    * ``"K-means clusters"``: a cluster-count slider and a "Cluster fit"
      selectbox;
    * ``"Persona attribute"``: an attribute selectbox;
    * any other mode: a "Highlight personas" multiselect (only when
      ``persona_ids`` is non-empty).

    Args:
        store: Vector store the widget keys are scoped to (via ``store_id``).
        scope: Widget-key namespace of the calling tab.
        persona_ids: Personas currently selected for the figure.

    Returns:
        The chosen ``ProjectionColorConfig``, or ``None`` when the selected
        mode cannot be used (fewer than two personas for K-means, or no
        persona attributes available); an info message is shown in that case.
    """
    widget_scope = f"{scope}:{store_id(store)}"
    persona_key = personas_fingerprint(persona_ids)
    persona_names = st.session_state.get(
        _persona_names_state_key(widget_scope),
        {},
    )
    color_mode_key = widget_key("load", "color_mode", scope, store_id(store))
    color_mode = _remembered_selectbox(
        "Color by",
        key=color_mode_key,
        remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
        options=_PROJECTION_COLOR_MODES,
        default="Persona attribute",
    )

    if color_mode == "K-means clusters":
        max_clusters = min(10, len(persona_ids))
        if max_clusters < 2:
            st.info("Select at least two personas to use K-means coloring.")
            return None
        cluster_key = widget_key("load", "cluster_k", scope, store_id(store))
        default_clusters = min(3, len(persona_ids))
        # The slider key does not depend on the persona selection, so a value
        # stored on an earlier run can exceed the current maximum once fewer
        # personas are selected. Re-clamp on every run (not only when the key
        # is first created) so the slider never receives an out-of-range value.
        if cluster_key in st.session_state:
            stored_clusters = st.session_state[cluster_key]
        else:
            stored_clusters = st.session_state.get(
                _LAST_PROJECTION_CLUSTER_K_KEY,
                default_clusters,
            )
        clamped_clusters = min(max(int(stored_clusters), 2), max_clusters)
        if st.session_state.get(cluster_key) != clamped_clusters:
            # Must happen before the slider is instantiated below.
            st.session_state[cluster_key] = clamped_clusters
        n_clusters = st.slider(
            "K (clusters)",
            min_value=2,
            max_value=max_clusters,
            key=cluster_key,
        )
        mode_key = widget_key("load", "cluster_mode", scope, store_id(store))
        mode_options = list(_CLUSTER_MODES)
        mode_label = _remembered_selectbox(
            "Cluster fit",
            key=mode_key,
            remember_key=_LAST_PROJECTION_CLUSTER_MODE_KEY,
            options=mode_options,
            default=mode_options[0],
            help=(
                "Mean across layers is the previous behavior. First selected "
                "layer keeps one fixed clustering from the first frame. Per layer "
                "recomputes clustering for each animation frame."
            ),
        )
        # Remember K so other scopes/stores start from the same value.
        st.session_state[_LAST_PROJECTION_CLUSTER_K_KEY] = n_clusters
        return ProjectionColorConfig(
            color_mode=color_mode,
            n_clusters=n_clusters,
            cluster_mode=_CLUSTER_MODES[mode_label],
        )

    if color_mode == "Persona attribute":
        persona_dataset = synth_persona_dataset_cached()
        attribute_options = list(synth_persona_attribute_names())
        if not attribute_options:
            st.info("No persona attributes are available for this dataset.")
            return None
        # Prefer "sex" as the initial attribute when the dataset provides it.
        default_attribute = (
            "sex" if "sex" in attribute_options else attribute_options[0]
        )
        attribute_key = widget_key("load", "attribute", scope, store_id(store))
        attribute_name = _remembered_selectbox(
            "Attribute",
            key=attribute_key,
            remember_key=_LAST_PROJECTION_ATTRIBUTE_KEY,
            options=attribute_options,
            default=default_attribute,
            format_func=lambda name: attribute_display_label(persona_dataset, name),
        )
        info = persona_dataset.attribute_info(attribute_name)
        if info.get("high_cardinality"):
            st.caption(
                "High-cardinality categorical attributes are grouped to the "
                f"top {_MAX_ATTRIBUTE_CATEGORIES} values plus Other."
            )
        return ProjectionColorConfig(
            color_mode=color_mode,
            attribute_name=attribute_name,
        )

    # Remaining color modes: optionally highlight a subset of personas.
    highlight_persona_ids: tuple[str, ...] = ()
    if persona_ids:
        # The key includes the persona fingerprint, so each persona selection
        # gets its own multiselect widget state.
        highlight_key = widget_key(
            "load", "persona_highlight", scope, store_id(store), persona_key
        )
        highlighted = st.multiselect(
            "Highlight personas",
            options=persona_ids,
            default=_remember_multiselect(
                key=highlight_key,
                remember_key=_LAST_PROJECTION_HIGHLIGHTS_KEY,
                options=persona_ids,
            ),
            format_func=lambda persona_id: _persona_display_label(
                persona_names, persona_id
            ),
            key=highlight_key,
            help=(
                "Select a few personas to keep their default colors while the rest "
                "are grayed out."
            ),
        )
        highlight_persona_ids = tuple(highlighted)
        st.session_state[_LAST_PROJECTION_HIGHLIGHTS_KEY] = list(highlighted)
    highlight_persona_key = (
        personas_fingerprint(highlight_persona_ids) if highlight_persona_ids else ""
    )
    return ProjectionColorConfig(
        color_mode=color_mode,
        highlight_persona_ids=highlight_persona_ids,
        highlight_persona_key=highlight_persona_key,
    )
def _render_projection_normalize_control(scope: str, store: Store) -> bool:
    """Render the "Normalize vectors" checkbox and return its state.

    The checkbox is seeded from the last remembered choice (default ``True``)
    the first time its widget key appears, and the current choice is written
    back to ``_LAST_PROJECTION_NORMALIZE_KEY`` on every run.
    """
    checkbox_key = widget_key(
        "load", "projection_normalize", scope, store_id(store)
    )
    if checkbox_key not in st.session_state:
        remembered = st.session_state.get(_LAST_PROJECTION_NORMALIZE_KEY, True)
        st.session_state[checkbox_key] = bool(remembered)
    is_normalized = st.checkbox(
        "Normalize vectors",
        key=checkbox_key,
        help="Center and L2-normalize persona vectors before PCA/UMAP projection.",
    )
    st.session_state[_LAST_PROJECTION_NORMALIZE_KEY] = is_normalized
    return is_normalized
def _layered_figure_state_keys(
    store: Store,
    mask_strategy: MaskStrategy,
    *,
    scope: str,
    figure_kind: str,
    n_components: int,
    projection_normalize: bool,
    color_config: ProjectionColorConfig,
    variant: str,
    persona_key: str,
    selected_layers: list[int],
    pair_trajectories: bool,
) -> LayeredFigureStateKeys:
    """Build the session-state keys for a layered figure.

    The figure key encodes every input that changes the rendered figure,
    including the coloring options. For projection kinds a second "prepared"
    key is produced that omits the coloring options and pair-trajectory flag,
    so it stays the same when only the coloring changes.
    """
    layers_token = "_".join(str(layer) for layer in selected_layers)
    figure_state_key = widget_key(
        "load",
        f"{scope}_fig_state",
        store_id(store),
        store.model_name,
        mask_strategy.value,
        figure_kind,
        str(n_components),
        str(projection_normalize),
        color_config.color_mode,
        str(color_config.attribute_name),
        str(color_config.n_clusters),
        str(color_config.cluster_mode),
        str(color_config.highlight_persona_key),
        variant,
        "persona_vector",
        persona_key,
        layers_token,
        str(pair_trajectories),
    )
    if figure_kind in _PROJECTION_KINDS:
        uses_graph_overlay = figure_kind == "isomap"
        prepared_state_key = widget_key(
            "load",
            f"{scope}_projection_ready",
            store_id(store),
            store.model_name,
            mask_strategy.value,
            figure_kind,
            str(n_components),
            str(projection_normalize),
            str(uses_graph_overlay),
            str(_DEFAULT_GRAPH_NEIGHBORS),
            variant,
            persona_key,
            layers_token,
        )
        return LayeredFigureStateKeys(
            figure=figure_state_key,
            prepared=prepared_state_key,
        )
    # Non-projection figures have no separate "prepared" marker.
    return LayeredFigureStateKeys(figure=figure_state_key)
def _projection_build_kwargs(
    *,
    store: Store,
    mask_strategy: MaskStrategy,
    variant: str,
    figure_kind: str,
    selected_layers: list[int],
    n_components: int,
    projection_normalize: bool,
    color_config: ProjectionColorConfig,
    persona_ids: list[str],
    persona_names: dict[str, str],
) -> dict:
    """Assemble the extra keyword arguments for building a projection figure.

    Returns an empty dict for non-projection figure kinds. Otherwise the
    result carries the projection settings, the cached projection data and,
    depending on ``color_config``, K-means groups, attribute coloring kwargs
    or persona-highlight groups.
    """
    if figure_kind not in _PROJECTION_KINDS:
        return {}

    # Only Isomap figures get the kNN graph overlay.
    use_graph_overlay = figure_kind == "isomap"
    source, location, model_name = store_cache_parts(store)
    # Hashable arguments shared by both cached computations below.
    shared_cache_args = (
        source,
        location,
        model_name,
        mask_strategy.value,
        variant,
        tuple(persona_ids),
        tuple(selected_layers),
    )
    kwargs = {
        "n_components": n_components,
        "projection_normalize": projection_normalize,
        "graph_overlay": use_graph_overlay,
        "graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
        "projection_data": projection_data_cached(
            *shared_cache_args,
            figure_kind,
            n_components,
            projection_normalize,
            use_graph_overlay,
            _DEFAULT_GRAPH_NEIGHBORS,
        ),
    }

    cluster_count = color_config.n_clusters
    if cluster_count is not None:
        kwargs["groups"] = kmeans_groups_cached(
            *shared_cache_args,
            cluster_count,
            color_config.cluster_mode or "mean_across_layers",
        )

    attribute = color_config.attribute_name
    if attribute is not None:
        attribute_kwargs = attribute_color_kwargs(
            synth_persona_dataset_cached(),
            attribute,
            persona_ids,
            max_categories=_MAX_ATTRIBUTE_CATEGORIES,
        )
        kwargs.update(attribute_kwargs)

    if color_config.color_mode == "Persona" and color_config.highlight_persona_ids:
        highlight_groups = _highlight_persona_groups(
            persona_ids,
            persona_names,
            color_config.highlight_persona_ids,
        )
        if highlight_groups is not None:
            kwargs["groups"] = highlight_groups

    return kwargs
def _build_layered_analysis_figures(
    samples,
    *,
    figure_kind: str,
    selected_layers: list[int],
    variant: str,
    title_fn: Callable[[str], str],
    pair_trajectories: bool,
    build_kwargs: dict,
) -> tuple[go.Figure, go.Figure | None]:
    """Build the main figure and, optionally, the pair-trajectory figure.

    Returns ``(main_figure, extra_figure)`` where ``extra_figure`` is ``None``
    unless ``pair_trajectories`` is set. Similarity figures with pair
    trajectories are delegated to ``build_similarity_figures``.
    """

    def _pair_title() -> str:
        # Evaluated lazily so the variant label is only looked up when needed.
        return (
            "Pair similarity trajectories - "
            f"{prompt_variant_label(variant)} - persona vectors"
        )

    if pair_trajectories and figure_kind == "similarity":
        return build_similarity_figures(
            samples,
            layers=selected_layers,
            title=title_fn(variant),
            pair_title=_pair_title(),
        )

    figure = build_layered_figure(
        samples,
        figure_kind,
        layers=selected_layers,
        title=title_fn(variant),
        **build_kwargs,
    )
    if figure_kind == "isomap":
        _add_isomap_connection_toggle(figure)
    if figure_kind in _PROJECTION_KINDS:
        figure.update_layout(height=700)

    if not pair_trajectories:
        return figure, None
    pair_figure = build_pair_similarity_figure(
        samples,
        layers=selected_layers,
        title=_pair_title(),
    )
    return figure, pair_figure
def _add_isomap_connection_toggle(fig: go.Figure) -> None:
    """Append show/hide buttons for the Isomap kNN graph trace to ``fig``.

    Does nothing unless the figure's first trace is named ``"kNN graph"``.
    Existing update menus on the figure are kept.
    """
    has_graph_trace = bool(fig.data) and fig.data[0].name == "kNN graph"
    if not has_graph_trace:
        return

    def _visibility_button(label: str, visible: bool) -> dict:
        # Restyle only trace 0, the kNN graph.
        return {
            "label": label,
            "method": "restyle",
            "args": [{"visible": visible}, [0]],
        }

    toggle_menu = {
        "type": "buttons",
        "direction": "left",
        "active": 0,
        "showactive": False,
        "x": 0,
        "xanchor": "left",
        "y": 1.16,
        "yanchor": "top",
        "pad": {"t": 0, "r": 10},
        "buttons": [
            _visibility_button("Show connections", True),
            _visibility_button("Hide connections", False),
        ],
    }
    current_menus = tuple(fig.layout.updatemenus or ())
    fig.update_layout(updatemenus=current_menus + (toggle_menu,))
def _render_layered_figure_analysis(
    store: Store,
    mask_strategy: MaskStrategy,
    *,
    scope: str,
    figure_kind: str,
    button_label: str,
    title_fn: Callable[[str], str],
    include_pair_trajectories: bool = False,
    n_components: int = 2,
    remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
    default_count_limit: int = 500,
    max_count_limit: int | None = None,
    allow_specific_personas: bool = False,
) -> None:
    """Render a single-variant layered analysis: select -> button -> figure(s).

    Handles the similarity matrix as well as the projection kinds (progress
    labels exist for PCA, UMAP and Isomap). Set ``include_pair_trajectories``
    to offer the pair-similarity-trajectory figure.

    Flow:
        1. Render the variant/persona/layer selection and the option widgets.
        2. Build the figure(s) when the button is clicked, or automatically
           when a projection was already prepared for this selection but no
           figure is stored under the current figure key.
        3. Display the stored figure(s) with save buttons.

    Args:
        store: Source of the persona vectors.
        mask_strategy: Mask strategy used when loading vectors.
        scope: Widget-key namespace; also used as the save filename.
        figure_kind: Kind of figure to build.
        button_label: Label of the primary build button.
        title_fn: Maps the selected variant to the figure title.
        include_pair_trajectories: Whether to offer the pair-trajectory control.
        n_components: Number of projection components.
        remember_key: Session-state key remembering the persona selection.
        default_count_limit: Forwarded to the sample selector.
        max_count_limit: Forwarded to the sample selector.
        allow_specific_personas: Forwarded to the sample selector.
    """
    selected = _select_single_variant_samples(
        store,
        mask_strategy,
        scope,
        remember_key=remember_key,
        variant_remember_key=(
            _LAST_PROJECTION_VARIANT_KEY
            if figure_kind in _PROJECTION_KINDS
            else _LAST_SIMILARITY_VARIANT_KEY
        ),
        default_count_limit=default_count_limit,
        max_count_limit=max_count_limit,
        allow_specific_personas=allow_specific_personas,
    )
    # No selection available: nothing more to render.
    if selected is None:
        return
    variant, persona_ids, persona_key, selected_layers = selected
    pair_trajectories = _render_pair_trajectory_control(
        enabled=include_pair_trajectories,
        persona_count=len(persona_ids),
        scope=scope,
        store=store,
    )
    if not _validate_layered_figure_size(
        figure_kind, len(persona_ids), selected_layers
    ):
        return
    # Non-projection figures keep the default (empty) color configuration.
    color_config = ProjectionColorConfig()
    if figure_kind in _PROJECTION_KINDS:
        color_config = _render_projection_color_config(store, scope, persona_ids)
        if color_config is None:
            return
    # Only PCA/UMAP expose the normalize checkbox; other projection kinds
    # always normalize, and non-projection figures never do.
    if figure_kind in {"pca", "umap"}:
        projection_normalize = _render_projection_normalize_control(scope, store)
    elif figure_kind in _PROJECTION_KINDS:
        projection_normalize = True
    else:
        projection_normalize = False
    state_keys = _layered_figure_state_keys(
        store,
        mask_strategy,
        scope=scope,
        figure_kind=figure_kind,
        n_components=n_components,
        projection_normalize=projection_normalize,
        color_config=color_config,
        variant=variant,
        persona_key=persona_key,
        selected_layers=selected_layers,
        pair_trajectories=pair_trajectories,
    )
    filename = scope
    _clear_old_figure_states(state_keys.figure)
    persona_names = st.session_state.get(
        _persona_names_state_key(f"{scope}:{store_id(store)}"),
        {},
    )
    build_clicked = st.button(button_label, type="primary")
    # Rebuild without a click when the "prepared" flag for this selection is
    # set but no figure is stored under the current figure key.
    recolor_from_warm_projection = (
        state_keys.prepared is not None
        and bool(st.session_state.get(state_keys.prepared))
        and state_keys.figure not in st.session_state
    )
    if build_clicked or recolor_from_warm_projection:
        build_label = {
            "umap": "Computing UMAP projections…",
            "pca": "Computing PCA projections…",
            "isomap": "Computing Isomap projections…",
            "similarity": "Computing similarity matrices…",
        }.get(figure_kind, "Building figure…")
        progress = st.progress(0, text="Loading activation vectors…")
        try:
            progress.progress(15, text="Loading activation vectors…")
            samples = _load_persona_vectors(
                store,
                variant,
                mask_strategy,
                persona_ids,
            )
            progress.progress(55, text=build_label)
            build_kwargs = _projection_build_kwargs(
                store=store,
                mask_strategy=mask_strategy,
                variant=variant,
                figure_kind=figure_kind,
                selected_layers=selected_layers,
                n_components=n_components,
                projection_normalize=projection_normalize,
                color_config=color_config,
                persona_ids=persona_ids,
                persona_names=persona_names,
            )
            main_fig, extra_fig = _build_layered_analysis_figures(
                samples,
                figure_kind=figure_kind,
                selected_layers=selected_layers,
                variant=variant,
                title_fn=title_fn,
                pair_trajectories=pair_trajectories,
                build_kwargs=build_kwargs,
            )
            if (
                color_config.color_mode == "Persona"
                and color_config.highlight_persona_ids
            ):
                _gray_out_unselected_personas(main_fig)
            progress.progress(90, text="Storing figure state…")
            # Keep only the sample count; drop the reference to the loaded
            # vectors before storing the figures (gc runs in ``finally``).
            n_samples = samples.vectors.shape[0]
            del samples
            _store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
            if state_keys.prepared is not None:
                _clear_old_prepared_states(state_keys.prepared)
                st.session_state[state_keys.prepared] = True
            progress.progress(100, text="Done.")
        except Exception as exc:
            # UI boundary: report the failure and make sure no figure state
            # remains stored under this key.
            st.error(f"Could not build figure: {exc}")
            st.session_state.pop(state_keys.figure, None)
        finally:
            gc.collect()
            progress.empty()
    # Display whatever is stored for the current key (fresh or from a
    # previous run).
    if state_keys.figure in st.session_state:
        main_fig, extra_fig, n_samples = st.session_state[state_keys.figure]
        _plotly_chart(main_fig)
        figs = [main_fig]
        filenames = [filename]
        if extra_fig is not None:
            st.subheader("Pair trajectories")
            _plotly_chart(extra_fig)
            figs.append(extra_fig)
            filenames.append(f"{filename}__pair_trajectories")
        _render_save_buttons(figs, filenames, scope)
        st.success(f"Loaded {n_samples} samples.")