import os
from dataclasses import dataclass

import streamlit as st
from dotenv import load_dotenv

from utils.analysis_sources import DEFAULT_COMPARE_MODEL, DEFAULT_HUB_REPO, SOURCE_HUB
from utils.helpers import DATASET_SOURCES, session_key, widget_key
from utils.preload import preload_once
from utils.runtime import configured_ndif_api_key, list_remote_models
from utils.theme import active_base, install_catppuccin_theme

load_dotenv()

DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")

# Session-state keys owned by the sidebar.
_LAST_LOCAL_MODEL_KEY = session_key("sidebar", "last_local_model")
_LAST_REMOTE_MODEL_KEY = session_key("sidebar", "last_remote_model")
_SIDEBAR_ACTIVE_TAB_KEY = session_key("sidebar", "active_tab")
_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY = session_key(
    "sidebar", "remote_model_custom_value"
)
_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY = session_key(
    "sidebar", "remote_model_custom_enabled"
)
_SIDEBAR_REMOTE_MODEL_KEY = session_key("sidebar", "remote_model")
_SIDEBAR_LOCAL_MODEL_KEY = session_key("sidebar", "local_model")
_SIDEBAR_REMOTE_KEY = session_key("sidebar", "remote")
_SIDEBAR_DATASET_SOURCE_KEY = session_key("sidebar", "dataset_source")
_SIDEBAR_NDIF_API_KEY = session_key("sidebar", "ndif_api_key")

NDIF_REGISTRATION_URL = "https://login.ndif.us/"

# Sidebar navigation: tab names and their icons, zipped pairwise (same order).
_TABS = ["Chat", "Analysis", "Probing", "Extract"]
_TAB_ICONS = [
    ":material/chat:",
    ":material/search:",
    ":material/biotech:",
    ":material/tune:",
]

# Passed to preload_once() after the keyed tab renders: modules to import
# in advance and "module:function" targets to call.
_TAB_PRELOAD_MODULES = {
    "Chat": ("tabs.analysis_core", "tabs.extract", "tabs.compare_chat", "tabs.probe"),
    "Analysis": ("tabs.chat", "tabs.extract", "tabs.probe"),
    "Probing": ("tabs.chat", "tabs.analysis_core", "tabs.extract"),
    "Extract": ("tabs.chat", "tabs.analysis_core", "tabs.probe"),
}
_TAB_PRELOAD_FUNCTIONS = {
    "Chat": ("utils.analysis_metadata:synth_persona_attribute_names",),
    "Probing": ("utils.analysis_metadata:synth_persona_attribute_names",),
    "Extract": ("utils.analysis_metadata:synth_persona_attribute_names",),
}

# Positional args for prefetch_hub_metadata: (repo, model, mask_strategy, variant).
_HubPrefetchArgs = tuple[str, str, str, str | None]
# One preload_once() call spec: ("module:function", args).
_PreloadCall = tuple[str, _HubPrefetchArgs]

_HUB_PREFETCH_TARGET = "utils.analysis_sources:prefetch_hub_metadata"


def _resolve_hub_selection(
    prefix: str,
    model_widget_namespace: str,
    shared_mask_strategy: str,
) -> tuple[str, str, str]:
    """Return ``(repo, model, mask_strategy)`` last chosen for a tab's Hub source.

    Args:
        prefix: Session-state key prefix of the tab (``"analysis"``/``"probe"``).
        model_widget_namespace: First ``widget_key`` part of the tab's Hub
            model widget (``"load"`` for Analysis, ``"probe"`` for Probing).
        shared_mask_strategy: Mask strategy to use when the tab has none.

    Each value falls back from the tab-specific key, to the shared
    ``source:*`` key, to the module default.
    """
    state = st.session_state
    repo = state.get(
        f"{prefix}:hub_repo",
        state.get("source:hub_repo", DEFAULT_HUB_REPO),
    )
    mask_strategy = state.get(f"{prefix}:last_mask_strategy", shared_mask_strategy)
    # The model widget is keyed by (repo, mask_strategy), so resolve those first.
    model = state.get(
        widget_key(model_widget_namespace, "hub_model", repo, mask_strategy),
        state.get(
            f"{prefix}:hub_model_fallback",
            state.get("source:hub_model", DEFAULT_COMPARE_MODEL),
        ),
    )
    return repo, model, mask_strategy


def _hub_metadata_preload_calls() -> tuple[_PreloadCall, ...]:
    """Build preload_once() call specs that prefetch Hub metadata.

    One ``prefetch_hub_metadata`` call is emitted for the Analysis tab and one
    for the Probing tab, each only when that tab's source is ``SOURCE_HUB``.

    Returns:
        The call specs with duplicates removed, in first-seen order.
    """
    state = st.session_state
    shared_source = state.get("source:last_source", SOURCE_HUB)
    shared_mask_strategy = state.get("source:last_mask_strategy", "answer_mean")
    calls: list[_PreloadCall] = []

    if state.get("analysis:last_source", shared_source) == SOURCE_HUB:
        repo, model, mask_strategy = _resolve_hub_selection(
            "analysis", "load", shared_mask_strategy
        )
        variant = state.get(
            "analysis:last_projection_variant",
            state.get("analysis:last_similarity_variant"),
        )
        calls.append((_HUB_PREFETCH_TARGET, (repo, model, mask_strategy, variant)))

    if state.get(widget_key("probe", "source"), shared_source) == SOURCE_HUB:
        repo, model, mask_strategy = _resolve_hub_selection(
            "probe", "probe", shared_mask_strategy
        )
        calls.append(
            (
                _HUB_PREFETCH_TARGET,
                (repo, model, mask_strategy, state.get("probe:variant")),
            )
        )

    # Both tabs may point at the same selection; dict keys keep insertion order.
    return tuple(dict.fromkeys(calls))
# Plain (non-widget) mirrors of sidebar widget values. Streamlit drops a
# widget's session-state entry on any run where the widget is not rendered
# (e.g. while Analysis/Probing is active, or while Remote is off). Like
# _LAST_LOCAL_MODEL_KEY / _LAST_REMOTE_MODEL_KEY, these keys keep the last
# value and seed the widget when it is shown again.
_LAST_REMOTE_ENABLED_KEY = session_key("sidebar", "last_remote_enabled")
_LAST_REMOTE_CUSTOM_ENABLED_KEY = session_key(
    "sidebar", "last_remote_model_custom_enabled"
)
_LAST_DATASET_SOURCE_KEY = session_key("sidebar", "last_dataset_source")
_LAST_NDIF_API_KEY = session_key("sidebar", "last_ndif_api_key")


@dataclass(frozen=True)
class SidebarState:
    """Selections made in the sidebar for the current script run."""

    remote: bool  # True when the Remote (NDIF) toggle is on.
    model_name: str  # Local model id/path, or the chosen remote model id.
    dataset_source: str  # Selected entry of DATASET_SOURCES.
    active_tab: str  # One of _TABS.


def _remote_model_input(remote_models: list[str]) -> str:
    """Return the active remote model id, picking from running NDIF deployments or a custom value.

    Args:
        remote_models: Model ids reported as running on NDIF; may be empty.

    Returns:
        The selected or typed model id. The value is also stored under
        ``_LAST_REMOTE_MODEL_KEY``.
    """
    last_remote = st.session_state.get(_LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL)

    if not remote_models:
        # Nothing to pick from: free-text entry only.
        st.warning("No running NDIF models found.")
        model_name = st.text_input(
            "Model",
            value=st.session_state.get(
                _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote
            ),
            key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY,
            help="NDIF model id. Use this to cold-load a remote model.",
        )
        st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
        return model_name

    custom = st.toggle(
        "Custom remote model",
        value=st.session_state.get(_LAST_REMOTE_CUSTOM_ENABLED_KEY, False),
        key=_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY,
        help="Enter any NDIF-loadable model id, even if it is not currently running.",
    )
    st.session_state[_LAST_REMOTE_CUSTOM_ENABLED_KEY] = custom

    if custom:
        model_name = st.text_input(
            "Model",
            value=st.session_state.get(
                _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote
            ),
            key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY,
            help="NDIF model id. Example: openai/gpt-oss-20b",
        )
        st.caption(
            f"{len(remote_models)} running NDIF model(s) detected. "
            "Custom model ids can cold-load if your NDIF account allows it."
        )
    else:
        default_model = st.session_state.get(_SIDEBAR_REMOTE_MODEL_KEY, last_remote)
        if default_model not in remote_models:
            # The remembered model is no longer running; fall back.
            default_model = (
                REMOTE_DEFAULT_MODEL
                if REMOTE_DEFAULT_MODEL in remote_models
                else remote_models[0]
            )
        model_name = st.selectbox(
            "Model",
            options=remote_models,
            index=remote_models.index(default_model),
            key=_SIDEBAR_REMOTE_MODEL_KEY,
            help="Running NDIF model.",
        )
    st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
    return model_name


def _ndif_api_key_input() -> None:
    """Prompt for a per-session NDIF API key.

    The prompt is skipped when configured_ndif_api_key() reports a key. The
    entered value is mirrored under ``_LAST_NDIF_API_KEY`` so it survives runs
    where this widget is not rendered.
    """
    if configured_ndif_api_key():
        st.caption("Using NDIF API key from environment.")
        return
    api_key = st.text_input(
        "NDIF API key",
        value=st.session_state.get(_LAST_NDIF_API_KEY, ""),
        type="password",
        key=_SIDEBAR_NDIF_API_KEY,
        help=f"Required for remote (NDIF) execution. Register at {NDIF_REGISTRATION_URL}",
    )
    st.session_state[_LAST_NDIF_API_KEY] = api_key
    if not api_key:
        st.caption(f"No NDIF API key found. [Get one]({NDIF_REGISTRATION_URL}).")


def _tab_selector() -> str:
    """Render the sidebar tab buttons and return the active tab name.

    Clicking a button stores that tab as active and calls ``st.rerun()`` so
    the new tab renders on the next run.
    """
    active_tab = st.session_state.get(_SIDEBAR_ACTIVE_TAB_KEY)
    if active_tab not in _TABS:
        # First run of the session (or an unrecognised stored value).
        active_tab = "Chat"
        st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = active_tab

    for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
        is_selected = tab_name == active_tab
        if st.button(
            tab_name,
            key=f"sidebar__tab__{tab_name.lower()}",
            width="stretch",
            type="primary" if is_selected else "secondary",
            icon=icon,
        ):
            st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = tab_name
            st.rerun()
    return active_tab


def _sidebar_controls() -> SidebarState:
    """Render the sidebar (tabs, runtime, data) and return its selections."""
    with st.sidebar:
        st.markdown("## Persona UI")
        active_tab = _tab_selector()

        if active_tab in {"Analysis", "Probing"}:
            # These tabs select their own model in-tab. The global sidebar
            # only carries over the last local model id and dataset source
            # for breadcrumbs; the runtime widgets are not rendered here.
            model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
            dataset_source = st.session_state.get(
                _LAST_DATASET_SOURCE_KEY,
                st.session_state.get(_SIDEBAR_DATASET_SOURCE_KEY, DATASET_SOURCES[0]),
            )
            return SidebarState(
                remote=False,
                model_name=model_name,
                dataset_source=dataset_source,
                active_tab=active_tab,
            )

        st.divider()
        st.caption("Runtime")
        _ndif_api_key_input()
        remote = st.toggle(
            "Remote (NDIF)",
            value=st.session_state.get(_LAST_REMOTE_ENABLED_KEY, False),
            key=_SIDEBAR_REMOTE_KEY,
        )
        st.session_state[_LAST_REMOTE_ENABLED_KEY] = remote
        if remote:
            model_name = _remote_model_input(list_remote_models())
        else:
            model_name = st.text_input(
                "Model",
                value=st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL),
                key=_SIDEBAR_LOCAL_MODEL_KEY,
                help="Local model id or path.",
            )
            st.session_state[_LAST_LOCAL_MODEL_KEY] = model_name

        st.caption("Data")
        source_options = list(DATASET_SOURCES)
        last_source = st.session_state.get(_LAST_DATASET_SOURCE_KEY)
        dataset_source = st.selectbox(
            "Source",
            DATASET_SOURCES,
            index=(
                source_options.index(last_source)
                if last_source in source_options
                else 0
            ),
            key=_SIDEBAR_DATASET_SOURCE_KEY,
            help="Dataset for Chat and Extract.",
        )
        st.session_state[_LAST_DATASET_SOURCE_KEY] = dataset_source

        return SidebarState(
            remote=remote,
            model_name=model_name,
            dataset_source=dataset_source,
            active_tab=active_tab,
        )


def main() -> None:
    """Run the Streamlit app."""
    st.set_page_config(page_title="Persona UI", layout="wide")
    install_catppuccin_theme(active_base())
    sidebar = _sidebar_controls()

    # Tab modules are imported lazily so only the active tab pays its import cost.
    if sidebar.active_tab == "Extract":
        from tabs.extract import render_extract_tab

        render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
    elif sidebar.active_tab == "Analysis":
        from tabs.analysis_core import render_analysis_tab

        render_analysis_tab()
    elif sidebar.active_tab == "Probing":
        from tabs.probe import render_probing_tab

        render_probing_tab()
    else:
        from tabs.chat import render_chat_tab

        render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)

    # Warm up other tabs' modules and Hub metadata after the current tab renders.
    preload_once(
        f"after-{sidebar.active_tab.lower()}",
        modules=_TAB_PRELOAD_MODULES.get(sidebar.active_tab, ()),
        functions=_TAB_PRELOAD_FUNCTIONS.get(sidebar.active_tab, ()),
        calls=_hub_metadata_preload_calls(),
    )


if __name__ == "__main__":
    main()