| import os |
| from dataclasses import dataclass |
|
|
| import streamlit as st |
| from dotenv import load_dotenv |
|
|
| from utils.analysis_sources import DEFAULT_COMPARE_MODEL, DEFAULT_HUB_REPO, SOURCE_HUB |
| from utils.helpers import DATASET_SOURCES, session_key, widget_key |
| from utils.preload import preload_once |
| from utils.runtime import configured_ndif_api_key, list_remote_models |
| from utils.theme import active_base, install_catppuccin_theme |
|
|
| load_dotenv() |
| DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it") |
| REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it") |
| _LAST_LOCAL_MODEL_KEY = session_key("sidebar", "last_local_model") |
| _LAST_REMOTE_MODEL_KEY = session_key("sidebar", "last_remote_model") |
| _SIDEBAR_ACTIVE_TAB_KEY = session_key("sidebar", "active_tab") |
| _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY = session_key( |
| "sidebar", "remote_model_custom_value" |
| ) |
| _SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY = session_key( |
| "sidebar", "remote_model_custom_enabled" |
| ) |
| _SIDEBAR_REMOTE_MODEL_KEY = session_key("sidebar", "remote_model") |
| _SIDEBAR_LOCAL_MODEL_KEY = session_key("sidebar", "local_model") |
| _SIDEBAR_REMOTE_KEY = session_key("sidebar", "remote") |
| _SIDEBAR_DATASET_SOURCE_KEY = session_key("sidebar", "dataset_source") |
| _SIDEBAR_NDIF_API_KEY = session_key("sidebar", "ndif_api_key") |
| NDIF_REGISTRATION_URL = "https://login.ndif.us/" |
|
|
|
|
| _TABS = ["Chat", "Analysis", "Probing", "Extract"] |
| _TAB_ICONS = [ |
| ":material/chat:", |
| ":material/search:", |
| ":material/biotech:", |
| ":material/tune:", |
| ] |
| _TAB_PRELOAD_MODULES = { |
| "Chat": ("tabs.analysis_core", "tabs.extract", "tabs.compare_chat", "tabs.probe"), |
| "Analysis": ("tabs.chat", "tabs.extract", "tabs.probe"), |
| "Probing": ("tabs.chat", "tabs.analysis_core", "tabs.extract"), |
| "Extract": ("tabs.chat", "tabs.analysis_core", "tabs.probe"), |
| } |
| _TAB_PRELOAD_FUNCTIONS = { |
| "Chat": ("utils.analysis_metadata:synth_persona_attribute_names",), |
| "Probing": ("utils.analysis_metadata:synth_persona_attribute_names",), |
| "Extract": ("utils.analysis_metadata:synth_persona_attribute_names",), |
| } |
|
|
|
|
| def _hub_metadata_preload_calls() -> tuple[ |
| tuple[str, tuple[str, str, str, str | None]], ... |
| ]: |
| calls: list[tuple[str, tuple[str, str, str, str | None]]] = [] |
|
|
| def add(repo: str, model: str, mask_strategy: str, variant: str | None) -> None: |
| calls.append( |
| ( |
| "utils.analysis_sources:prefetch_hub_metadata", |
| (repo, model, mask_strategy, variant), |
| ) |
| ) |
|
|
| shared_source = st.session_state.get("source:last_source", SOURCE_HUB) |
| shared_mask_strategy = st.session_state.get( |
| "source:last_mask_strategy", "answer_mean" |
| ) |
|
|
| analysis_source = st.session_state.get("analysis:last_source", shared_source) |
| if analysis_source == SOURCE_HUB: |
| repo = st.session_state.get( |
| "analysis:hub_repo", |
| st.session_state.get("source:hub_repo", DEFAULT_HUB_REPO), |
| ) |
| mask_strategy = st.session_state.get( |
| "analysis:last_mask_strategy", |
| shared_mask_strategy, |
| ) |
| model = st.session_state.get( |
| widget_key("load", "hub_model", repo, mask_strategy), |
| st.session_state.get( |
| "analysis:hub_model_fallback", |
| st.session_state.get("source:hub_model", DEFAULT_COMPARE_MODEL), |
| ), |
| ) |
| variant = st.session_state.get( |
| "analysis:last_projection_variant", |
| st.session_state.get("analysis:last_similarity_variant"), |
| ) |
| add(repo, model, mask_strategy, variant) |
|
|
| probe_source = st.session_state.get(widget_key("probe", "source"), shared_source) |
| if probe_source == SOURCE_HUB: |
| repo = st.session_state.get( |
| "probe:hub_repo", |
| st.session_state.get("source:hub_repo", DEFAULT_HUB_REPO), |
| ) |
| mask_strategy = st.session_state.get( |
| "probe:last_mask_strategy", |
| shared_mask_strategy, |
| ) |
| model = st.session_state.get( |
| widget_key("probe", "hub_model", repo, mask_strategy), |
| st.session_state.get( |
| "probe:hub_model_fallback", |
| st.session_state.get("source:hub_model", DEFAULT_COMPARE_MODEL), |
| ), |
| ) |
| add(repo, model, mask_strategy, st.session_state.get("probe:variant")) |
|
|
| deduped: dict[tuple[str, tuple[str, str, str, str | None]], None] = {} |
| for call in calls: |
| deduped[call] = None |
| return tuple(deduped) |
|
|
|
|
| @dataclass(frozen=True) |
| class SidebarState: |
| remote: bool |
| model_name: str |
| dataset_source: str |
| active_tab: str |
|
|
|
|
| def _remote_model_input(remote_models: list[str]) -> str: |
| """Return the active remote model id, picking from running NDIF deployments or a custom value.""" |
|
|
| last_remote = st.session_state.get(_LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL) |
|
|
| if not remote_models: |
| st.warning("No running NDIF models found.") |
| model_name = st.text_input( |
| "Model", |
| value=st.session_state.get( |
| _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote |
| ), |
| key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, |
| help="NDIF model id. Use this to cold-load a remote model.", |
| ) |
| st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name |
| return model_name |
|
|
| custom = st.toggle( |
| "Custom remote model", |
| value=False, |
| key=_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY, |
| help="Enter any NDIF-loadable model id, even if it is not currently running.", |
| ) |
| if custom: |
| model_name = st.text_input( |
| "Model", |
| value=st.session_state.get( |
| _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote |
| ), |
| key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, |
| help="NDIF model id. Example: openai/gpt-oss-20b", |
| ) |
| st.caption( |
| f"{len(remote_models)} running NDIF model(s) detected. " |
| "Custom model ids can cold-load if your NDIF account allows it." |
| ) |
| else: |
| default_model = st.session_state.get(_SIDEBAR_REMOTE_MODEL_KEY, last_remote) |
| if default_model not in remote_models: |
| default_model = ( |
| REMOTE_DEFAULT_MODEL |
| if REMOTE_DEFAULT_MODEL in remote_models |
| else remote_models[0] |
| ) |
| model_name = st.selectbox( |
| "Model", |
| options=remote_models, |
| index=remote_models.index(default_model), |
| key=_SIDEBAR_REMOTE_MODEL_KEY, |
| help="Running NDIF model.", |
| ) |
| st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name |
| return model_name |
|
|
|
|
| def _ndif_api_key_input() -> None: |
| """Prompt for a per-session NDIF API key.""" |
|
|
| if configured_ndif_api_key(): |
| st.caption("Using NDIF API key from environment.") |
| return |
|
|
| api_key = st.text_input( |
| "NDIF API key", |
| type="password", |
| key=_SIDEBAR_NDIF_API_KEY, |
| help=f"Required for remote (NDIF) execution. Register at {NDIF_REGISTRATION_URL}", |
| ) |
| if not api_key: |
| st.caption(f"No NDIF API key found. [Get one]({NDIF_REGISTRATION_URL}).") |
|
|
|
|
| def _sidebar_controls() -> SidebarState: |
| with st.sidebar: |
| st.markdown("## Persona UI") |
|
|
| if _SIDEBAR_ACTIVE_TAB_KEY not in st.session_state: |
| st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = "Chat" |
|
|
| active_tab = st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] |
| for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True): |
| is_selected = tab_name == active_tab |
| if st.button( |
| tab_name, |
| key=f"sidebar__tab__{tab_name.lower()}", |
| width="stretch", |
| type="primary" if is_selected else "secondary", |
| icon=icon, |
| ): |
| st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = tab_name |
| st.rerun() |
|
|
| if active_tab in {"Analysis", "Probing"}: |
| |
| |
| model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL) |
| dataset_source = st.session_state.get( |
| _SIDEBAR_DATASET_SOURCE_KEY, |
| DATASET_SOURCES[0], |
| ) |
| return SidebarState( |
| remote=False, |
| model_name=model_name, |
| dataset_source=dataset_source, |
| active_tab=active_tab, |
| ) |
|
|
| st.divider() |
| st.caption("Runtime") |
| _ndif_api_key_input() |
| remote = st.toggle("Remote (NDIF)", value=False, key=_SIDEBAR_REMOTE_KEY) |
|
|
| if remote: |
| model_name = _remote_model_input(list_remote_models()) |
| else: |
| model_name = st.text_input( |
| "Model", |
| value=st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL), |
| key=_SIDEBAR_LOCAL_MODEL_KEY, |
| help="Local model id or path.", |
| ) |
| st.session_state[_LAST_LOCAL_MODEL_KEY] = model_name |
|
|
| st.caption("Data") |
| dataset_source = st.selectbox( |
| "Source", |
| DATASET_SOURCES, |
| key=_SIDEBAR_DATASET_SOURCE_KEY, |
| help="Dataset for Chat and Extract.", |
| ) |
|
|
| return SidebarState( |
| remote=remote, |
| model_name=model_name, |
| dataset_source=dataset_source, |
| active_tab=active_tab, |
| ) |
|
|
|
|
| def main() -> None: |
| """Run the Streamlit app.""" |
|
|
| st.set_page_config(page_title="Persona UI", layout="wide") |
| install_catppuccin_theme(active_base()) |
|
|
| sidebar = _sidebar_controls() |
|
|
| if sidebar.active_tab == "Extract": |
| from tabs.extract import render_extract_tab |
|
|
| render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source) |
| elif sidebar.active_tab == "Analysis": |
| from tabs.analysis_core import render_analysis_tab |
|
|
| render_analysis_tab() |
| elif sidebar.active_tab == "Probing": |
| from tabs.probe import render_probing_tab |
|
|
| render_probing_tab() |
| else: |
| from tabs.chat import render_chat_tab |
|
|
| render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source) |
|
|
| preload_once( |
| f"after-{sidebar.active_tab.lower()}", |
| modules=_TAB_PRELOAD_MODULES.get(sidebar.active_tab, ()), |
| functions=_TAB_PRELOAD_FUNCTIONS.get(sidebar.active_tab, ()), |
| calls=_hub_metadata_preload_calls(), |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|