| import os |
| from contextlib import contextmanager |
|
|
| import streamlit as st |
| from persona_vectors.analysis import ( |
| AnalysisDataset, |
| LayeredSamples, |
| load_analysis_dataset, |
| ) |
| from persona_vectors.artifacts import ( |
| HFPersonaVectorStore, |
| PersonaVectorStore, |
| discover_activation_models, |
| model_dir_name, |
| ) |
| from persona_vectors.extraction import MaskStrategy |
| from persona_vectors.hub import list_hub_vector_models |
| from persona_vectors.plots import ( |
| LayeredProjectionData, |
| prepare_kmeans_groups, |
| prepare_layered_projection_data, |
| ) |
|
|
| from utils.helpers import env_int |
|
|
# Either store flavour accepted by the helpers in this module: the local
# artifact-backed store or the Hugging Face Hub-backed one.
Store = PersonaVectorStore | HFPersonaVectorStore
|
|
# Hub repository to use unless PERSONA_VECTORS_HUB_REPO overrides it.
DEFAULT_HUB_REPO = os.getenv(
    "PERSONA_VECTORS_HUB_REPO",
    "implicit-personalization/synth-persona-vectors",
)

# Model name read from DEFAULT_MODEL, with a built-in fallback.
DEFAULT_COMPARE_MODEL = os.getenv("DEFAULT_MODEL", "google/gemma-2-2b-it")

# Labels identifying where vectors come from; compared against the ``source``
# argument of the cached helpers below.
SOURCE_HUB = "Hugging Face Hub"
SOURCE_LOCAL = "Local artifacts"
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
|
|
|
# ``max_entries`` limits for the ``st.cache_resource`` caches in this module.
# Each is read through ``env_int`` with the listed fallback value.
_STORE_CACHE_ENTRIES = env_int("PERSONA_UI_STORE_CACHE_ENTRIES", 4)
_VECTOR_CACHE_ENTRIES = env_int("PERSONA_UI_VECTOR_CACHE_ENTRIES", 4)
_PREPARED_CACHE_ENTRIES = env_int("PERSONA_UI_PREPARED_CACHE_ENTRIES", 8)
|
|
|
|
def _hub_variants_pending(store: Store, variants: tuple[str, ...]) -> tuple[str, ...]:
    """Return the requested variants that a Hub store has not opened yet.

    Args:
        store: Store to inspect.
        variants: Variant names the caller is about to use.

    Returns:
        The subset of ``variants`` (original order preserved) that is absent
        from the store's ``_datasets``. Always empty for non-Hub stores.
    """
    if isinstance(store, HFPersonaVectorStore):
        # NOTE(review): this reads the store's private ``_datasets`` member;
        # presumably it maps already-opened variants -- confirm against
        # HFPersonaVectorStore.
        return tuple(name for name in variants if name not in store._datasets)
    return ()
|
|
|
|
@contextmanager
def _hub_vector_notice(store: Store, variants: tuple[str, ...]):
    """Display a temporary warning while unopened Hub variants are loaded.

    When ``_hub_variants_pending`` reports nothing pending, the managed body
    runs with no UI output. Otherwise a warning is written into an
    ``st.empty()`` placeholder for the duration of the body and cleared
    afterwards, even if the body raises.
    """
    if _hub_variants_pending(store, variants):
        placeholder = st.empty()
        placeholder.warning(
            "Loading persona vectors from Hugging Face. "
            "On a cold cache, this may download Hub dataset files."
        )
        try:
            yield
        finally:
            # Always remove the notice, whether the body succeeded or failed.
            placeholder.empty()
    else:
        yield
|
|
|
|
@st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
def activation_store_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
) -> Store:
    """Construct the store for ``source``, cached as a Streamlit resource.

    Args:
        source: ``SOURCE_HUB`` selects the Hub store; any other value selects
            the local store.
        location: Passed to the chosen store constructor alongside
            ``model_name``.
        model_name: Model whose vectors the store serves.
        mask_strategy_value: Value converted to a ``MaskStrategy``.

    Returns:
        An ``HFPersonaVectorStore`` or a ``PersonaVectorStore``.
    """
    strategy = MaskStrategy(mask_strategy_value)
    if source != SOURCE_HUB:
        # The local store takes the model name first, then the location.
        return PersonaVectorStore(model_name, location, mask_strategy=strategy)
    # The Hub store takes the location first, then the model name.
    return HFPersonaVectorStore(location, model_name, mask_strategy=strategy)
|
|
|
|
@st.cache_data(show_spinner=False)
def available_variants_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
) -> list[str]:
    """Return the store's ``available_variants()`` result, cached by arguments."""
    store = activation_store_cached(
        source, location, model_name, mask_strategy_value
    )
    return store.available_variants()
|
|
|
|
@st.cache_data(show_spinner=False)
def personas_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    *,
    include_baseline: bool = False,
) -> list[str]:
    """Return the store's ``list_personas`` result for ``variants``, cached.

    The store call runs inside ``_hub_vector_notice``. ``include_baseline``
    is forwarded to the store unchanged.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    with _hub_vector_notice(store, variants):
        personas = store.list_personas(
            list(variants), include_baseline=include_baseline
        )
    return personas
|
|
|
|
@st.cache_data(show_spinner=False)
def persona_names_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> dict[str, str]:
    """Map each requested persona id to a name, cached by arguments.

    Returns:
        A dict with one entry per id in ``persona_ids``. Ids missing from the
        store's ``persona_names`` result map to themselves.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    with _hub_vector_notice(store, variants):
        lookup = store.persona_names(list(persona_ids), variants=list(variants))

    # Fall back to the raw id so every requested persona gets an entry.
    return {
        persona_id: lookup.get(persona_id, persona_id) for persona_id in persona_ids
    }
|
|
|
|
@st.cache_data(show_spinner=False)
def store_layers_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> list[int]:
    """Return the store's ``list_layers`` result, cached by arguments.

    ``variants`` and ``persona_ids`` are converted to lists before being
    handed to the store. The call runs inside ``_hub_vector_notice``.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    with _hub_vector_notice(store, variants):
        layers = store.list_layers(list(variants), list(persona_ids))
    return layers
|
|
|
|
@st.cache_data(show_spinner=False)
def local_model_options_cached(
    artifacts_root: str, mask_strategy_value: str
) -> list[str]:
    """Return ``discover_activation_models`` for the given root, cached."""
    models = discover_activation_models(artifacts_root, mask_strategy_value)
    return models
|
|
|
|
@st.cache_data(show_spinner=False)
def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
    """Group the Hub repo's model listing by ``MaskStrategy``, cached.

    Entries from ``list_hub_vector_models`` whose key is not the value of any
    ``MaskStrategy`` member are dropped.
    """
    known_values = {member.value for member in MaskStrategy}
    grouped: dict[MaskStrategy, list[str]] = {}
    for raw_value, model_names in list_hub_vector_models(repo_id).items():
        if raw_value not in known_values:
            # Skip keys that do not correspond to a MaskStrategy member.
            continue
        grouped[MaskStrategy(raw_value)] = model_names
    return grouped
|
|
|
|
def store_cache_parts(store: Store) -> tuple[str, str, str]:
    """Return ``(source, location, model_name)`` describing ``store``.

    Hub stores yield ``SOURCE_HUB`` with their ``repo_id``; all other stores
    yield ``SOURCE_LOCAL`` with ``root_dir`` converted to ``str``.
    """
    if not isinstance(store, HFPersonaVectorStore):
        return SOURCE_LOCAL, str(store.root_dir), store.model_name
    return SOURCE_HUB, store.repo_id, store.model_name
|
|
|
|
def store_id(store: Store) -> str:
    """Return a string id for ``store``.

    The result is ``"hub:<repo_id>"`` for Hub stores and
    ``"local:<root_dir>"`` for any other store.
    """
    prefix, where = (
        ("hub", store.repo_id)
        if isinstance(store, HFPersonaVectorStore)
        else ("local", store.root_dir)
    )
    return f"{prefix}:{where}"
|
|
|
|
def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
    """Return the cached variant listing for an existing store object.

    The store is reduced to its ``store_cache_parts`` so the lookup goes
    through ``available_variants_cached`` with ``mask_strategy.value``.
    """
    cache_parts = store_cache_parts(store)
    return available_variants_cached(*cache_parts, mask_strategy.value)
|
|
|
|
def local_model_matches(left: str, right: str) -> bool:
    """Return whether both names produce the same ``model_dir_name`` result."""
    left_dir = model_dir_name(left)
    right_dir = model_dir_name(right)
    return left_dir == right_dir
|
|
|
|
@st.cache_resource(show_spinner=False, max_entries=_VECTOR_CACHE_ENTRIES)
def load_analysis_dataset_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> AnalysisDataset:
    """Load an ``AnalysisDataset`` for the selection, cached as a resource.

    The load runs inside ``_hub_vector_notice``. ``variants`` and
    ``persona_ids`` are passed to ``load_analysis_dataset`` as given.
    """
    store = activation_store_cached(source, location, model_name, mask_strategy_value)
    strategy = MaskStrategy(mask_strategy_value)
    with _hub_vector_notice(store, variants):
        dataset = load_analysis_dataset(
            store,
            variants,
            mask_strategy=strategy,
            persona_ids=persona_ids,
        )
    return dataset
|
|
|
|
def load_persona_vectors_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
) -> LayeredSamples:
    """Return the samples of a single variant from the cached dataset.

    The dataset is requested with a one-element variants tuple, then
    ``samples(variant)`` is taken from it.
    """
    dataset = load_analysis_dataset_cached(
        source,
        location,
        model_name,
        mask_strategy_value,
        (variant,),
        persona_ids,
    )
    return dataset.samples(variant)
|
|
|
|
def load_variant_vectors_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variants: tuple[str, ...],
    persona_ids: tuple[str, ...],
) -> dict[str, LayeredSamples]:
    """Return ``samples_by_variant`` from the cached dataset for ``variants``."""
    dataset = load_analysis_dataset_cached(
        source,
        location,
        model_name,
        mask_strategy_value,
        variants,
        persona_ids,
    )
    return dataset.samples_by_variant
|
|
|
|
@st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
def projection_data_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
    layers: tuple[int, ...],
    kind: str,
    n_components: int,
    graph_overlay: bool,
    graph_n_neighbors: int,
) -> LayeredProjectionData:
    """Prepare layered projection data for one variant, cached as a resource.

    Samples come from ``load_persona_vectors_cached``. The remaining
    arguments are forwarded to ``prepare_layered_projection_data``, with
    ``layers`` converted to a list.
    """
    variant_samples = load_persona_vectors_cached(
        source, location, model_name, mask_strategy_value, variant, persona_ids
    )
    projection = prepare_layered_projection_data(
        variant_samples,
        kind,
        layers=list(layers),
        n_components=n_components,
        graph_overlay=graph_overlay,
        graph_n_neighbors=graph_n_neighbors,
    )
    return projection
|
|
|
|
@st.cache_resource(show_spinner=False, max_entries=_PREPARED_CACHE_ENTRIES)
def kmeans_groups_cached(
    source: str,
    location: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str,
    persona_ids: tuple[str, ...],
    layers: tuple[int, ...],
    n_clusters: int,
    cluster_mode: str,
) -> list[str] | dict[int, list[str]]:
    """Prepare k-means groups for one variant, cached as a resource.

    Samples come from ``load_persona_vectors_cached``. ``layers`` (as a
    list), ``n_clusters`` and ``cluster_mode`` are forwarded to
    ``prepare_kmeans_groups``.
    """
    variant_samples = load_persona_vectors_cached(
        source, location, model_name, mask_strategy_value, variant, persona_ids
    )
    groups = prepare_kmeans_groups(
        variant_samples,
        layers=list(layers),
        n_clusters=n_clusters,
        cluster_mode=cluster_mode,
    )
    return groups
|
|
|
|
def prefetch_hub_metadata(
    repo_id: str,
    model_name: str,
    mask_strategy_value: str,
    variant: str | None = None,
) -> None:
    """Populate the Hub metadata caches ahead of use.

    Calls the cached model listing and variant listing for the Hub source
    and, when ``variant`` is truthy, the cached persona listing for that
    variant. Return values are discarded. Does nothing if ``repo_id``,
    ``model_name`` or ``mask_strategy_value`` is falsy.
    """
    if not (repo_id and model_name and mask_strategy_value):
        return

    hub_models_by_mask_strategy(repo_id)
    available_variants_cached(SOURCE_HUB, repo_id, model_name, mask_strategy_value)

    if not variant:
        return
    personas_cached(
        SOURCE_HUB,
        repo_id,
        model_name,
        mask_strategy_value,
        (variant,),
    )
|
|