| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import streamlit as st |
| from persona_data.environment import get_artifacts_dir |
| from persona_vectors.extraction import MaskStrategy |
|
|
| from utils.analysis_sources import ( |
| DEFAULT_COMPARE_MODEL, |
| DEFAULT_HUB_REPO, |
| SOURCE_HUB, |
| SOURCE_LOCAL, |
| SOURCES, |
| Store, |
| activation_store_cached, |
| hub_models_by_mask_strategy, |
| local_model_matches, |
| local_model_options_cached, |
| ) |
| from utils.helpers import widget_key |
| from utils.selection_controls import remembered_segmented_control |
|
|
| _SHARED_SOURCE_KEY = "source:last_source" |
| _SHARED_HUB_REPO_KEY = "source:hub_repo" |
| _SHARED_HUB_MODEL_KEY = "source:hub_model" |
| _SHARED_LOCAL_ROOT_KEY = "source:local_root" |
| _SHARED_LOCAL_MODEL_KEY = "source:local_model" |
|
|
|
|
| def render_source_select( |
| *, |
| widget_scope: str, |
| last_source_key: str | None = None, |
| ) -> str: |
| key = widget_key(widget_scope, "source") |
| if last_source_key is not None and last_source_key not in st.session_state: |
| shared_source = st.session_state.get(_SHARED_SOURCE_KEY) |
| if shared_source is not None: |
| st.session_state[last_source_key] = shared_source |
| selected = remembered_segmented_control( |
| "Source", |
| options=SOURCES, |
| key=key, |
| remember_key=last_source_key or _SHARED_SOURCE_KEY, |
| default=SOURCE_HUB, |
| label_visibility="collapsed", |
| ) |
| st.session_state[_SHARED_SOURCE_KEY] = selected |
| if last_source_key is not None: |
| st.session_state[last_source_key] = selected |
| return selected |
|
|
|
|
| def _render_hub_model_select( |
| *, |
| state_prefix: str, |
| widget_scope: str, |
| repo_id: str, |
| mask_strategy: MaskStrategy, |
| model_label: str, |
| fallback_help: str, |
| selection_help: str, |
| ) -> str: |
| fallback_key = f"{state_prefix}:hub_model_fallback" |
| fallback_model = st.session_state.get( |
| fallback_key, |
| st.session_state.get(_SHARED_HUB_MODEL_KEY, DEFAULT_COMPARE_MODEL), |
| ) |
| try: |
| models_by_strategy = hub_models_by_mask_strategy(repo_id) |
| except Exception as exc: |
| st.warning(f"Could not load Hub configs for `{repo_id}`: {exc}") |
| model = st.text_input( |
| model_label, |
| value=fallback_model, |
| key=fallback_key, |
| help=fallback_help, |
| ) |
| st.session_state[_SHARED_HUB_MODEL_KEY] = model |
| return model |
|
|
| model_options = models_by_strategy.get(mask_strategy, []) |
| if not model_options: |
| st.warning( |
| f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`." |
| ) |
| model = st.text_input( |
| model_label, |
| value=fallback_model, |
| key=fallback_key, |
| help=fallback_help, |
| ) |
| st.session_state[_SHARED_HUB_MODEL_KEY] = model |
| return model |
|
|
| select_key = widget_key(widget_scope, "hub_model", repo_id, mask_strategy.value) |
| previous_model = st.session_state.get( |
| select_key, |
| st.session_state.get(_SHARED_HUB_MODEL_KEY, fallback_model), |
| ) |
| default_model = ( |
| previous_model if previous_model in model_options else model_options[0] |
| ) |
| selected = st.selectbox( |
| model_label, |
| options=model_options, |
| index=model_options.index(default_model), |
| key=select_key, |
| help=selection_help, |
| ) |
| st.session_state[fallback_key] = selected |
| st.session_state[_SHARED_HUB_MODEL_KEY] = selected |
| return selected |
|
|
|
|
| def _render_local_model_select( |
| *, |
| state_prefix: str, |
| artifacts_root: str, |
| mask_strategy: MaskStrategy, |
| allow_custom_toggle: bool, |
| model_label: str, |
| ) -> str: |
| fallback_key = f"{state_prefix}:local_model" |
| fallback_model = st.session_state.get( |
| fallback_key, |
| st.session_state.get(_SHARED_LOCAL_MODEL_KEY, DEFAULT_COMPARE_MODEL), |
| ) |
| model_options = local_model_options_cached(artifacts_root, mask_strategy.value) |
| if not model_options: |
| model = st.text_input(model_label, value=fallback_model, key=fallback_key) |
| st.session_state[_SHARED_LOCAL_MODEL_KEY] = model |
| return model |
|
|
| if allow_custom_toggle: |
| custom = st.toggle( |
| "Custom local model", |
| value=False, |
| key=f"{state_prefix}:local_model_custom_enabled", |
| help="Enter a model id/path manually instead of choosing from activation directories.", |
| ) |
| if custom: |
| model = st.text_input("Local model", value=fallback_model, key=fallback_key) |
| st.session_state[_SHARED_LOCAL_MODEL_KEY] = model |
| return model |
|
|
| select_key = f"{state_prefix}:local_model_select" |
| previous_model = st.session_state.get( |
| select_key, |
| st.session_state.get(_SHARED_LOCAL_MODEL_KEY, fallback_model), |
| ) |
| if not any(local_model_matches(previous_model, option) for option in model_options): |
| previous_model = fallback_model |
| default_model = next( |
| ( |
| option |
| for option in model_options |
| if local_model_matches(option, previous_model) |
| ), |
| model_options[0], |
| ) |
| selected = st.selectbox( |
| model_label, |
| options=model_options, |
| index=model_options.index(default_model), |
| key=select_key, |
| help="Models discovered under the selected artifacts root.", |
| ) |
| st.session_state[fallback_key] = selected |
| st.session_state[_SHARED_LOCAL_MODEL_KEY] = selected |
| return selected |
|
|
|
|
| def render_store_select( |
| source: str, |
| mask_strategy: MaskStrategy, |
| *, |
| state_prefix: str, |
| widget_scope: str, |
| artifacts_root_key: str, |
| model_label: str = "Model", |
| local_model_label: str = "Model", |
| allow_custom_local_model: bool = False, |
| repo_help: str | None = None, |
| fallback_help: str = "Model id to use if Hub config discovery is unavailable.", |
| ) -> Store: |
| if source == SOURCE_HUB: |
| repo_key = f"{state_prefix}:hub_repo" |
| repo = st.text_input( |
| "Hub repo", |
| value=st.session_state.get( |
| repo_key, |
| st.session_state.get(_SHARED_HUB_REPO_KEY, DEFAULT_HUB_REPO), |
| ), |
| key=repo_key, |
| help=repo_help, |
| ) |
| st.session_state[_SHARED_HUB_REPO_KEY] = repo |
| model_name = _render_hub_model_select( |
| state_prefix=state_prefix, |
| widget_scope=widget_scope, |
| repo_id=repo, |
| mask_strategy=mask_strategy, |
| model_label=model_label, |
| fallback_help=fallback_help, |
| selection_help="Models with vectors in the selected Hub repo and mask strategy.", |
| ) |
| return activation_store_cached( |
| SOURCE_HUB, repo, model_name, mask_strategy.value |
| ) |
|
|
| root = st.text_input( |
| "Artifacts root", |
| value=st.session_state.get( |
| _SHARED_LOCAL_ROOT_KEY, |
| str(get_artifacts_dir() / "activations"), |
| ), |
| key=artifacts_root_key, |
| ) |
| root = str(Path(root).expanduser()) |
| st.session_state[_SHARED_LOCAL_ROOT_KEY] = root |
| model_name = _render_local_model_select( |
| state_prefix=state_prefix, |
| artifacts_root=root, |
| mask_strategy=mask_strategy, |
| allow_custom_toggle=allow_custom_local_model, |
| model_label=local_model_label, |
| ) |
| return activation_store_cached(SOURCE_LOCAL, root, model_name, mask_strategy.value) |
|
|