persona-ui / tabs /analysis /_state.py
Jac-Zac
Big refactoring
b279884
from dataclasses import dataclass
import streamlit as st
from persona_data.synth_persona import BASELINE_PERSONA_ID
from persona_vectors.attributes import DEFAULT_MAX_ATTRIBUTE_CATEGORIES
from utils.helpers import env_int, slugify, widget_key
def _filename(*parts: str) -> str:
return "__".join(slugify(part) for part in parts if part)
# Keep analysis-tab selection state separate so projection defaults do not
# overwrite cosine similarity defaults.
_LAST_COSINE_PERSONAS_KEY = "analysis:last_personas:cosine"
_LAST_PROJECTION_PERSONAS_KEY = "analysis:last_personas:projection"
_LAST_SIMILARITY_PERSONAS_KEY = "analysis:last_personas:similarity"
_LAST_MASK_STRATEGY_KEY = "analysis:last_mask_strategy"
_LAST_SOURCE_KEY = "analysis:last_source"
_LAST_PROJECTION_VARIANT_KEY = "analysis:last_projection_variant"
_LAST_SIMILARITY_VARIANT_KEY = "analysis:last_similarity_variant"
_LAST_PROJECTION_COLOR_MODE_KEY = "analysis:last_projection_color_mode"
_LAST_PROJECTION_ATTRIBUTE_KEY = "analysis:last_projection_attribute"
_LAST_PROJECTION_CLUSTER_K_KEY = "analysis:last_projection_cluster_k"
_LAST_PROJECTION_CLUSTER_MODE_KEY = "analysis:last_projection_cluster_mode"
_LAST_PROJECTION_HIGHLIGHTS_KEY = "analysis:last_projection_highlights"
_LAST_PROJECTION_DIMS_KEY = "analysis:last_projection_dims"
_LAST_LAYER_FRAMES_KEY = "analysis:last_layer_frames"
_DEFAULT_LAYER_FRAMES = 16
_DEFAULT_PERSONA_LIMITS = {
"similarity": 20,
"pca": 500,
"umap": 500,
"isomap": 500,
"dendro": 20,
}
_MAX_PERSONA_COUNTS = {
"similarity": 100,
"dendro": 100,
}
_MAX_SIMILARITY_CELLS = 4_000_000
_MAX_PAIR_TRAJECTORY_TRACES = 500
_DEFAULT_GRAPH_NEIGHBORS = 5
_PROJECTION_KINDS = {"pca", "umap", "isomap"}
_CLUSTER_MODES = {
"Mean across layers": "mean_across_layers",
"First selected layer": "first_layer",
"Per layer": "per_layer",
}
_PROJECTION_COLOR_MODES = ["Persona attribute", "Persona", "K-means clusters"]
_MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
persona_id_normalized = persona_id.strip().lower()
persona_name_normalized = (persona_name or "").strip().lower()
return (
persona_id_normalized in {"assistant", BASELINE_PERSONA_ID.lower()}
or persona_name_normalized == "assistant"
)
@dataclass(frozen=True)
class CosineSelection:
variants: list[str]
variant_a: str
variant_b: str
persona_ids: list[str]
persona_key: str
@dataclass(frozen=True)
class PersonaOptions:
regular_ids: list[str]
assistant_id: str | None
persona_names: dict[str, str]
@dataclass(frozen=True)
class ProjectionColorConfig:
color_mode: str = "Persona"
n_clusters: int | None = None
cluster_mode: str | None = None
attribute_name: str | None = None
highlight_persona_ids: tuple[str, ...] = ()
highlight_persona_key: str = ""
@dataclass(frozen=True)
class LayeredFigureStateKeys:
figure: str
prepared: str | None = None
_HIGHLIGHT_OTHER_LABEL = "Other"
_HIGHLIGHT_OTHER_COLOR = "rgba(148, 163, 184, 0.35)"
def _persona_names_state_key(widget_scope: str) -> str:
return widget_key("load", "persona_names", widget_scope)
def _persona_display_label(persona_names: dict[str, str], persona_id: str) -> str:
name = persona_names.get(persona_id, persona_id)
return f"{name} ({persona_id})" if name != persona_id else persona_id
def _highlight_persona_groups(
persona_ids: list[str],
persona_names: dict[str, str],
highlight_persona_ids: tuple[str, ...],
) -> list[str] | None:
if not highlight_persona_ids:
return None
highlighted = set(highlight_persona_ids)
return [
(
_persona_display_label(persona_names, persona_id)
if persona_id in highlighted
else _HIGHLIGHT_OTHER_LABEL
)
for persona_id in persona_ids
]
def _sequence_to_list(value: object) -> list[object] | None:
if value is None or isinstance(value, (str, bytes)):
return None
if isinstance(value, list):
return value
if isinstance(value, tuple):
return list(value)
try:
return list(value)
except TypeError:
return None
_TRACKED_STATE_KEYS_KEY = "analysis:_tracked_state_keys"
_FIGURE_STATE_ENTRIES = env_int("PERSONA_UI_FIGURE_STATE_ENTRIES", 2)
_PREPARED_STATE_ENTRIES = env_int("PERSONA_UI_PREPARED_STATE_ENTRIES", 4)
def _touch_load_state(current_key: str, suffix: str, *, max_entries: int) -> None:
# Keep a tiny MRU window of heavy state instead of scanning all of
# session_state or retaining every figure forever. This makes nearby
# method-switching feel warm while still giving RAM a hard ceiling.
tracked: dict[str, list[str]] = st.session_state.setdefault(
_TRACKED_STATE_KEYS_KEY, {}
)
keys = [key for key in tracked.get(suffix, []) if key != current_key]
keys.append(current_key)
while len(keys) > max(1, max_entries):
st.session_state.pop(keys.pop(0), None)
tracked[suffix] = keys
def _clear_old_figure_states(current_key: str) -> None:
_touch_load_state(
current_key,
"_fig_state",
max_entries=_FIGURE_STATE_ENTRIES,
)
def _clear_old_prepared_states(current_key: str) -> None:
_touch_load_state(
current_key,
"_projection_ready",
max_entries=_PREPARED_STATE_ENTRIES,
)
def _store_figure_state(key: str, value: object) -> None:
_clear_old_figure_states(key)
st.session_state[key] = value
def _seed_selectbox_key(
*,
key: str,
remember_key: str,
options: list[str],
default: str,
) -> str:
value = st.session_state.get(key, st.session_state.get(remember_key, default))
if value not in options:
value = default
return value
def _remembered_selectbox(
label: str,
*,
key: str,
remember_key: str,
options: list[str],
default: str,
**selectbox_kwargs: object,
) -> str:
selected = _seed_selectbox_key(
key=key,
remember_key=remember_key,
options=options,
default=default,
)
choice = st.selectbox(
label,
options=options,
index=options.index(selected),
key=key,
**selectbox_kwargs,
)
st.session_state[remember_key] = choice
return choice
def _personas_empty_message(variants: list[str]) -> str:
if len(variants) > 1:
return (
"No personas have vectors for all selected variants. "
"Pick a single variant or change the source."
)
return "No personas found for this model and variant."
def _remember_multiselect(
*,
key: str,
remember_key: str,
options: list[str],
) -> list[str]:
remembered = st.session_state.get(key, st.session_state.get(remember_key, []))
if not isinstance(remembered, list):
remembered = []
return [value for value in remembered if value in options]