# persona-ui / utils / source_controls.py
# Jac-Zac
# Big refactoring
# b279884
from __future__ import annotations
from pathlib import Path
import streamlit as st
from persona_data.environment import get_artifacts_dir
from persona_vectors.extraction import MaskStrategy
from utils.analysis_sources import (
DEFAULT_COMPARE_MODEL,
DEFAULT_HUB_REPO,
SOURCE_HUB,
SOURCE_LOCAL,
SOURCES,
Store,
activation_store_cached,
hub_models_by_mask_strategy,
local_model_matches,
local_model_options_cached,
)
from utils.helpers import widget_key
from utils.selection_controls import remembered_segmented_control
# Session-state keys that do not depend on a caller's `state_prefix` or
# `widget_scope`. The render functions in this module write the latest
# selection to these keys and read them back as the initial value when the
# caller-specific key holds nothing yet.
_SHARED_SOURCE_KEY = "source:last_source"  # most recent source selection
_SHARED_HUB_REPO_KEY = "source:hub_repo"  # most recent Hub repo text
_SHARED_HUB_MODEL_KEY = "source:hub_model"  # most recent Hub model selection
_SHARED_LOCAL_ROOT_KEY = "source:local_root"  # most recent (user-expanded) artifacts root
_SHARED_LOCAL_MODEL_KEY = "source:local_model"  # most recent local model selection
def render_source_select(
    *,
    widget_scope: str,
    last_source_key: str | None = None,
) -> str:
    """Render the source segmented control and record the selection.

    Args:
        widget_scope: Scope passed to ``widget_key`` to build the widget key.
        last_source_key: Optional caller-specific session-state key that
            remembers the selection. When it is given but not yet present in
            session state, it is seeded from the shared source key (if that
            holds a non-``None`` value) before the control is rendered.

    Returns:
        The value returned by ``remembered_segmented_control``.
    """
    state = st.session_state
    control_key = widget_key(widget_scope, "source")
    has_own_key = last_source_key is not None

    # First render for this caller: inherit whatever was last selected
    # elsewhere so the control starts from the shared selection.
    if has_own_key and last_source_key not in state:
        inherited = state.get(_SHARED_SOURCE_KEY)
        if inherited is not None:
            state[last_source_key] = inherited

    # A missing (or empty) caller key falls back to the shared key.
    remember_key = last_source_key or _SHARED_SOURCE_KEY
    choice = remembered_segmented_control(
        "Source",
        options=SOURCES,
        key=control_key,
        remember_key=remember_key,
        default=SOURCE_HUB,
        label_visibility="collapsed",
    )

    # Publish the selection to the shared key and, when present, the caller's.
    state[_SHARED_SOURCE_KEY] = choice
    if has_own_key:
        state[last_source_key] = choice
    return choice
def _render_hub_model_select(
    *,
    state_prefix: str,
    widget_scope: str,
    repo_id: str,
    mask_strategy: MaskStrategy,
    model_label: str,
    fallback_help: str,
    selection_help: str,
) -> str:
    """Render the Hub model picker, or a free-text input when discovery fails.

    A selectbox is shown when ``hub_models_by_mask_strategy(repo_id)`` returns
    a non-empty entry for ``mask_strategy``. If that call raises, or the entry
    is missing or empty, a warning is shown and a text input is rendered
    instead.

    Args:
        state_prefix: Prefix for the text-input session-state key.
        widget_scope: Scope passed to ``widget_key`` for the selectbox key.
        repo_id: Hub repo whose configs are listed.
        mask_strategy: Strategy used to look up the model list.
        model_label: Label for both the selectbox and the text input.
        fallback_help: Help text for the text input.
        selection_help: Help text for the selectbox.

    Returns:
        The selected or typed model name. It is also written to the shared
        Hub model key.
    """
    state = st.session_state
    manual_key = f"{state_prefix}:hub_model_fallback"
    manual_default = state.get(
        manual_key,
        state.get(_SHARED_HUB_MODEL_KEY, DEFAULT_COMPARE_MODEL),
    )

    # Work out whether discovery produced anything usable; `problem` carries
    # the warning text when it did not.
    problem = None
    available = []
    try:
        by_strategy = hub_models_by_mask_strategy(repo_id)
    except Exception as exc:
        problem = f"Could not load Hub configs for `{repo_id}`: {exc}"
    else:
        available = by_strategy.get(mask_strategy, [])
        if not available:
            problem = (
                f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`."
            )

    if problem is not None:
        st.warning(problem)
        typed = st.text_input(
            model_label,
            value=manual_default,
            key=manual_key,
            help=fallback_help,
        )
        state[_SHARED_HUB_MODEL_KEY] = typed
        return typed

    picker_key = widget_key(widget_scope, "hub_model", repo_id, mask_strategy.value)
    remembered = state.get(
        picker_key,
        state.get(_SHARED_HUB_MODEL_KEY, manual_default),
    )
    # Preselect the remembered model when it is still offered, else the first.
    start_index = available.index(remembered) if remembered in available else 0
    choice = st.selectbox(
        model_label,
        options=available,
        index=start_index,
        key=picker_key,
        help=selection_help,
    )

    # Keep the text-input default and the shared key in step with the picker.
    state[manual_key] = choice
    state[_SHARED_HUB_MODEL_KEY] = choice
    return choice
def _render_local_model_select(
    *,
    state_prefix: str,
    artifacts_root: str,
    mask_strategy: MaskStrategy,
    allow_custom_toggle: bool,
    model_label: str,
) -> str:
    """Render the local model picker, with free-text entry where needed.

    A text input is rendered when ``local_model_options_cached`` returns
    nothing, or when ``allow_custom_toggle`` is set and the user switches the
    "Custom local model" toggle on. Otherwise a selectbox over the discovered
    options is rendered.

    Args:
        state_prefix: Prefix for every session-state/widget key used here.
        artifacts_root: Root passed to ``local_model_options_cached``.
        mask_strategy: Strategy whose ``value`` is passed to the lookup.
        allow_custom_toggle: Whether to offer the custom-entry toggle.
        model_label: Label for the selectbox and the no-options text input.

    Returns:
        The selected or typed model name. It is also written to the shared
        local model key.
    """
    state = st.session_state
    typed_key = f"{state_prefix}:local_model"
    typed_default = state.get(
        typed_key,
        state.get(_SHARED_LOCAL_MODEL_KEY, DEFAULT_COMPARE_MODEL),
    )

    def free_text(label: str) -> str:
        # Text-input path shared by "nothing discovered" and "custom" modes.
        typed = st.text_input(label, value=typed_default, key=typed_key)
        state[_SHARED_LOCAL_MODEL_KEY] = typed
        return typed

    discovered = local_model_options_cached(artifacts_root, mask_strategy.value)
    if not discovered:
        return free_text(model_label)

    if allow_custom_toggle:
        wants_custom = st.toggle(
            "Custom local model",
            value=False,
            key=f"{state_prefix}:local_model_custom_enabled",
            help="Enter a model id/path manually instead of choosing from activation directories.",
        )
        if wants_custom:
            return free_text("Local model")

    picker_key = f"{state_prefix}:local_model_select"
    candidate = state.get(
        picker_key,
        state.get(_SHARED_LOCAL_MODEL_KEY, typed_default),
    )

    # Drop the remembered value when no discovered option matches it.
    # NOTE(review): the argument order here is (candidate, option) while the
    # lookup below uses (option, candidate); confirm `local_model_matches`
    # is symmetric.
    for option in discovered:
        if local_model_matches(candidate, option):
            break
    else:
        candidate = typed_default

    # First matching option wins; fall back to the first discovered option.
    preselected = discovered[0]
    for option in discovered:
        if local_model_matches(option, candidate):
            preselected = option
            break

    choice = st.selectbox(
        model_label,
        options=discovered,
        index=discovered.index(preselected),
        key=picker_key,
        help="Models discovered under the selected artifacts root.",
    )

    # Keep the text-input default and the shared key in step with the picker.
    state[typed_key] = choice
    state[_SHARED_LOCAL_MODEL_KEY] = choice
    return choice
def render_store_select(
    source: str,
    mask_strategy: MaskStrategy,
    *,
    state_prefix: str,
    widget_scope: str,
    artifacts_root_key: str,
    model_label: str = "Model",
    local_model_label: str = "Model",
    allow_custom_local_model: bool = False,
    repo_help: str | None = None,
    fallback_help: str = "Model id to use if Hub config discovery is unavailable.",
) -> Store:
    """Render the location and model inputs for ``source`` and build a store.

    When ``source`` equals ``SOURCE_HUB`` a "Hub repo" text input and the Hub
    model picker are rendered. Any other value renders an "Artifacts root"
    text input and the local model picker.

    Args:
        source: Selected source; compared against ``SOURCE_HUB``.
        mask_strategy: Strategy forwarded to the model picker; its ``value``
            is forwarded to ``activation_store_cached``.
        state_prefix: Prefix for caller-specific session-state keys.
        widget_scope: Scope forwarded to the Hub model picker.
        artifacts_root_key: Widget key for the artifacts-root text input.
        model_label: Label for the Hub model picker.
        local_model_label: Label for the local model picker.
        allow_custom_local_model: Whether the local picker offers the
            custom-entry toggle.
        repo_help: Help text for the Hub repo input.
        fallback_help: Help text for the Hub model text-input fallback.

    Returns:
        The result of ``activation_store_cached`` for the chosen source,
        location, model and mask strategy.
    """
    state = st.session_state

    if source != SOURCE_HUB:
        # Default root is computed up front, then overridden by the last
        # root recorded under the shared key, if any.
        default_root = str(get_artifacts_dir() / "activations")
        entered_root = st.text_input(
            "Artifacts root",
            value=state.get(_SHARED_LOCAL_ROOT_KEY, default_root),
            key=artifacts_root_key,
        )
        # Expand a leading "~" before the root is stored and used.
        local_root = str(Path(entered_root).expanduser())
        state[_SHARED_LOCAL_ROOT_KEY] = local_root
        local_model = _render_local_model_select(
            state_prefix=state_prefix,
            artifacts_root=local_root,
            mask_strategy=mask_strategy,
            allow_custom_toggle=allow_custom_local_model,
            model_label=local_model_label,
        )
        return activation_store_cached(
            SOURCE_LOCAL, local_root, local_model, mask_strategy.value
        )

    repo_key = f"{state_prefix}:hub_repo"
    # Caller-specific value first, then the shared one, then the default.
    initial_repo = state.get(
        repo_key,
        state.get(_SHARED_HUB_REPO_KEY, DEFAULT_HUB_REPO),
    )
    hub_repo = st.text_input(
        "Hub repo",
        value=initial_repo,
        key=repo_key,
        help=repo_help,
    )
    state[_SHARED_HUB_REPO_KEY] = hub_repo
    hub_model = _render_hub_model_select(
        state_prefix=state_prefix,
        widget_scope=widget_scope,
        repo_id=hub_repo,
        mask_strategy=mask_strategy,
        model_label=model_label,
        fallback_help=fallback_help,
        selection_help="Models with vectors in the selected Hub repo and mask strategy.",
    )
    return activation_store_cached(
        SOURCE_HUB, hub_repo, hub_model, mask_strategy.value
    )