# persona-ui/utils/analysis_sources.py
# Jac-Zac
# add session-scoped NDIF execution and improve cold-load UX
# ae347c6
import os
from contextlib import contextmanager
import streamlit as st
from persona_vectors.analysis import (
AnalysisDataset,
LayeredSamples,
load_analysis_dataset,
)
from persona_vectors.artifacts import (
HFPersonaVectorStore,
PersonaVectorStore,
discover_activation_models,
model_dir_name,
)
from persona_vectors.extraction import MaskStrategy
from persona_vectors.hub import list_hub_vector_models
from persona_vectors.plots import (
LayeredProjectionData,
prepare_kmeans_groups,
prepare_layered_projection_data,
)
from utils.helpers import env_int
# Union of the two store backends this module hands out.
Store = PersonaVectorStore | HFPersonaVectorStore
# Default Hub repository; the PERSONA_VECTORS_HUB_REPO environment variable
# overrides the built-in fallback.
DEFAULT_HUB_REPO = os.environ.get(
"PERSONA_VECTORS_HUB_REPO",
"implicit-personalization/synth-persona-vectors",
)
# Default comparison model; the DEFAULT_MODEL environment variable overrides it.
DEFAULT_COMPARE_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
# Source labels; the cached helpers below compare their `source` argument
# against SOURCE_HUB to pick the store backend.
SOURCE_HUB = "Hugging Face Hub"
SOURCE_LOCAL = "Local artifacts"
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
# `max_entries` bounds for the Streamlit resource caches defined below.
# NOTE(review): presumably env_int reads the named environment variable as an
# int and falls back to the second argument - confirm in utils.helpers.
_STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
_VECTOR_CACHE_ENTRIES = env_int("PERSONA_UI_VECTOR_CACHE_ENTRIES", 4)
_PREPARED_CACHE_ENTRIES = env_int("PERSONA_UI_PREPARED_CACHE_ENTRIES", 8)
def _hub_variants_pending(store: Store, variants: tuple[str, ...]) -> tuple[str, ...]:
    """List the requested variants that this Hub store has not opened yet.

    Stores that are not ``HFPersonaVectorStore`` instances never have pending
    variants, so an empty tuple is returned for them. The order of ``variants``
    is preserved in the result.
    """
    if isinstance(store, HFPersonaVectorStore):
        not_opened = []
        for name in variants:
            # NOTE(review): relies on the store's private ``_datasets``
            # container supporting ``in`` - confirm against the library.
            if name not in store._datasets:
                not_opened.append(name)
        return tuple(not_opened)
    return ()
@contextmanager
def _hub_vector_notice(store: Store, variants: tuple[str, ...]):
    """Show a temporary cold-load warning while Hub vector data is fetched.

    If none of ``variants`` is pending for ``store`` the wrapped body runs
    without any UI output. Otherwise a warning is rendered into a Streamlit
    placeholder for the duration of the body and cleared afterwards.
    """
    if _hub_variants_pending(store, variants):
        placeholder = st.empty()
        placeholder.warning(
            "Loading persona vectors from Hugging Face. "
            "On a cold cache, this may download Hub dataset files."
        )
        try:
            yield
        finally:
            # Clear the note whether the wrapped work succeeded or raised.
            placeholder.empty()
    else:
        yield
@st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
def activation_store_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
) -> Store:
    """Create the store for a source/location/model/mask-strategy combination.

    The result is cached as a Streamlit resource. ``mask_strategy_value`` is
    converted to a ``MaskStrategy`` first. When ``source`` equals
    ``SOURCE_HUB`` an ``HFPersonaVectorStore`` is built with ``location`` as
    its first argument; any other source yields a ``PersonaVectorStore``
    built with ``location`` as its second argument.
    """
    strategy = MaskStrategy(mask_strategy_value)
    if source != SOURCE_HUB:
        return PersonaVectorStore(model_name, location, mask_strategy=strategy)
    return HFPersonaVectorStore(location, model_name, mask_strategy=strategy)
@st.cache_data(show_spinner=False)
def available_variants_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
) -> list[str]:
    """Return the variants reported by the cached store for these arguments.

    The result is memoised with ``st.cache_data``.
    """
    store = activation_store_cached(
        source, location, model_name, mask_strategy_value
    )
    return store.available_variants()
@st.cache_data(show_spinner=False)
def personas_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    *,
    include_baseline: bool = False,
) -> list[str]:
    """Return the store's personas for ``variants`` (memoised).

    ``include_baseline`` is forwarded unchanged to ``store.list_personas``.
    The Hub cold-load notice is shown while the store is queried.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    with _hub_vector_notice(store, variants):
        personas = store.list_personas(
            list(variants), include_baseline=include_baseline
        )
    return personas
@st.cache_data(show_spinner=False)
def persona_names_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> dict[str, str]:
    """Map each persona id to its display name (memoised).

    The returned dict follows the order of ``persona_ids``. An id missing from
    the store's name lookup maps to itself.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    with _hub_vector_notice(store, variants):
        display_names = store.persona_names(
            list(persona_ids), variants=list(variants)
        )
    ordered: dict[str, str] = {}
    for persona_id in persona_ids:
        # Fall back to the raw id when the store returned no name for it.
        ordered[persona_id] = display_names.get(persona_id, persona_id)
    return ordered
@st.cache_data(show_spinner=False)
def store_layers_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> list[int]:
    """Return the layers the store lists for ``variants`` and ``persona_ids``.

    The result is memoised; the Hub cold-load notice is shown during the query.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    with _hub_vector_notice(store, variants):
        layers = store.list_layers(list(variants), list(persona_ids))
    return layers
@st.cache_data(show_spinner=False)
def local_model_options_cached(
    artifacts_root: str, mask_strategy_value: str
) -> list[str]:
    """Memoised wrapper around ``discover_activation_models``."""
    discovered = discover_activation_models(artifacts_root, mask_strategy_value)
    return discovered
@st.cache_data(show_spinner=False)
def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
    """Group the Hub repo's vector models by ``MaskStrategy`` (memoised).

    Entries from ``list_hub_vector_models`` whose key is not the value of any
    ``MaskStrategy`` member are dropped.
    """
    known_values = {member.value for member in MaskStrategy}
    grouped: dict[MaskStrategy, list[str]] = {}
    for raw_value, model_names in list_hub_vector_models(repo_id).items():
        if raw_value not in known_values:
            continue
        grouped[MaskStrategy(raw_value)] = model_names
    return grouped
def store_cache_parts(store: Store) -> tuple[str, str, str]:
    """Return ``(source, location, model_name)`` describing ``store``.

    For a Hub store the location is its ``repo_id``; for any other store it is
    ``str(store.root_dir)``.
    """
    if not isinstance(store, HFPersonaVectorStore):
        return SOURCE_LOCAL, str(store.root_dir), store.model_name
    return SOURCE_HUB, store.repo_id, store.model_name
def store_id(store: Store) -> str:
    """Return a short identifier: ``hub:<repo_id>`` or ``local:<root_dir>``."""
    is_hub = isinstance(store, HFPersonaVectorStore)
    return f"hub:{store.repo_id}" if is_hub else f"local:{store.root_dir}"
def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
    """Look up the variants for ``store`` through the memoised helper.

    The cache arguments come from ``store_cache_parts(store)`` plus the value
    of the ``mask_strategy`` passed in by the caller.
    """
    cache_parts = store_cache_parts(store)
    return available_variants_cached(*cache_parts, mask_strategy.value)
def local_model_matches(left: str, right: str) -> bool:
    """Return True when both names yield the same ``model_dir_name`` result."""
    left_dir = model_dir_name(left)
    right_dir = model_dir_name(right)
    return left_dir == right_dir
@st.cache_resource(show_spinner=False, max_entries=_VECTOR_CACHE_ENTRIES)
def load_analysis_dataset_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> AnalysisDataset:
    """Load the analysis dataset for the given store, variants and personas.

    The dataset is cached as a Streamlit resource. The Hub cold-load notice is
    shown while ``load_analysis_dataset`` runs.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    with _hub_vector_notice(store, variants):
        dataset = load_analysis_dataset(
            store,
            variants,
            mask_strategy=MaskStrategy(mask_strategy_value),
            persona_ids=persona_ids,
        )
    return dataset
def load_persona_vectors_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
) -> LayeredSamples:
    """Return the samples of a single ``variant`` from the cached dataset.

    The dataset is requested with the one-element variants tuple
    ``(variant,)``.
    """
    dataset = load_analysis_dataset_cached(
        source, location, model_name, mask_strategy_value, (variant,), persona_ids
    )
    return dataset.samples(variant)
def load_variant_vectors_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> dict[str, LayeredSamples]:
    """Return the cached dataset's ``samples_by_variant`` for ``variants``."""
    dataset = load_analysis_dataset_cached(
        source, location, model_name, mask_strategy_value, variants, persona_ids
    )
    return dataset.samples_by_variant
@st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
def projection_data_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
    layers: tuple[int, ...],
    kind: str,
    n_components: int,
    graph_overlay: bool,
    graph_n_neighbors: int,
) -> LayeredProjectionData:
    """Prepare layered projection data for one variant.

    The result is cached as a Streamlit resource. Samples come from
    ``load_persona_vectors_cached``; ``layers`` is passed on as a list and the
    remaining options are forwarded unchanged to
    ``prepare_layered_projection_data``.
    """
    layered_samples = load_persona_vectors_cached(
        source, location, model_name, mask_strategy_value, variant, persona_ids
    )
    projection_options = {
        "layers": list(layers),
        "n_components": n_components,
        "graph_overlay": graph_overlay,
        "graph_n_neighbors": graph_n_neighbors,
    }
    return prepare_layered_projection_data(
        layered_samples, kind, **projection_options
    )
@st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
def kmeans_groups_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
    layers: tuple[int, ...],
    n_clusters: int,
    cluster_mode: str,
) -> list[str] | dict[int, list[str]]:
    """Prepare k-means groups for one variant.

    The result is cached as a Streamlit resource. Samples come from
    ``load_persona_vectors_cached``; ``layers`` is passed on as a list and the
    clustering options are forwarded unchanged to ``prepare_kmeans_groups``.
    """
    layered_samples = load_persona_vectors_cached(
        source, location, model_name, mask_strategy_value, variant, persona_ids
    )
    clustering_options = {
        "layers": list(layers),
        "n_clusters": n_clusters,
        "cluster_mode": cluster_mode,
    }
    return prepare_kmeans_groups(layered_samples, **clustering_options)
def prefetch_hub_metadata(
repo_id: str,
model_name: str,
mask_strategy_value: str,
variant: str | None = None,
) -> None:
"""Warm small Hub metadata caches without loading full activation tensors."""
if not repo_id or not model_name or not mask_strategy_value:
return
hub_models_by_mask_strategy(repo_id)
available_variants_cached(
SOURCE_HUB,
repo_id,
model_name,
mask_strategy_value,
)
if variant:
personas_cached(
SOURCE_HUB,
repo_id,
model_name,
mask_strategy_value,
(variant,),
)