| import json |
| import logging |
| import os |
| from collections.abc import Iterable |
|
|
| import streamlit as st |
|
|
| from utils.helpers import env_int, session_key |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Class names (the last dotted component of a model key's prefix) that are
# treated as language models when filtering NDIF deployments.
_LANGUAGE_MODEL_CLASSES = {
    "LanguageModel",
    "StandardizedTransformer",
}

# ``application_state`` values that are considered normal; any other value is
# collected and logged as unexpected.
_EXPECTED_NDIF_STATES = {
    "RUNNING",
    "NOT DEPLOYED",
    "DEPLOYING",
    "DELETING",
}

# Upper bound on cached model objects (passed as ``max_entries`` to the model
# cache).  NOTE(review): presumably ``env_int`` reads the named environment
# variable and falls back to 1 -- confirm against ``utils.helpers``.
_MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)

# ``st.session_state`` key under which the visitor's NDIF API key is looked up.
# NOTE(review): the exact key format is decided by ``session_key`` -- confirm.
_SESSION_NDIF_API_KEY = session_key("sidebar", "ndif_api_key")
|
|
|
|
| def _iter_deployments(raw: object) -> Iterable[dict]: |
| if not isinstance(raw, dict): |
| return () |
| deployments = raw.get("deployments", {}) |
| if not isinstance(deployments, dict): |
| return () |
| return (value for value in deployments.values() if isinstance(value, dict)) |
|
|
|
|
| def _is_visible_deployment(deployment: dict) -> bool: |
| return deployment.get("deployment_level") in {"HOT", "WARM"} or ( |
| "schedule" in deployment |
| ) |
|
|
|
|
| def _repo_id_from_model_key(model_key: str) -> str: |
| try: |
| repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id") |
| except Exception: |
| return model_key |
| return repo_id if isinstance(repo_id, str) else model_key |
|
|
|
|
def _running_language_model(deployment: dict) -> str | None:
    """Return the model identifier of a visible, running language model.

    Args:
        deployment: One deployment entry from the raw status payload.

    Returns:
        The value produced by ``_repo_id_from_model_key`` when the deployment
        is visible, its model key names a class in
        ``_LANGUAGE_MODEL_CLASSES``, and its ``application_state`` is
        ``"RUNNING"``; ``None`` in every other case, including when the model
        key is not a string.
    """
    if not _is_visible_deployment(deployment):
        return None

    model_key = deployment.get("model_key", "")
    if not isinstance(model_key, str):
        # ``.get`` only applies its default when the key is absent; an explicit
        # JSON ``null`` (or any other non-string) would otherwise raise on
        # ``.split`` and abort processing of every other deployment.
        return None

    # The key looks like ``dotted.path.ClassName:<json>``; keep the class name.
    model_class = model_key.split(":", 1)[0].split(".")[-1]
    if model_class not in _LANGUAGE_MODEL_CLASSES:
        return None
    if deployment.get("application_state", "NOT DEPLOYED") != "RUNNING":
        return None
    return _repo_id_from_model_key(model_key)
|
|
|
|
def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
    """Return ``(model identifier, state)`` for an unexpected deployment state.

    The state is read from ``application_state`` and defaults to
    ``"NOT DEPLOYED"`` when that key is absent.

    Args:
        deployment: One deployment entry from the raw status payload.

    Returns:
        ``None`` when the state is one of ``_EXPECTED_NDIF_STATES``; otherwise
        a pair of the identifier produced by ``_repo_id_from_model_key`` and
        the raw state value, intended for logging by the caller.
    """
    state = deployment.get("application_state", "NOT DEPLOYED")
    # Check the type before the set lookup: a malformed, unhashable state
    # (e.g. a JSON array or object) would raise TypeError on ``in`` against a
    # set, whereas it should simply be reported as unexpected.
    if isinstance(state, str) and state in _EXPECTED_NDIF_STATES:
        return None
    return _repo_id_from_model_key(deployment.get("model_key", "")), state
|
|
|
|
@st.cache_data(show_spinner=False, ttl=30)
def list_remote_models() -> list[str]:
    """List the NDIF language models that are running right now.

    The raw NDIF payload is parsed here rather than relying on the formatted
    ``nnsight.ndif.status()`` output: that formatting crashes as soon as NDIF
    reports a deployment whose ``application_state`` is missing from nnsight's
    ``ModelStatus`` enum (e.g. ``UNHEALTHY``), so a single bad deployment
    would poison the whole response.  See nnsight 0.6.3 ``ndif.py::status``.

    Returns:
        The sorted, de-duplicated model identifiers of running language
        models.  An empty list is returned (and a warning logged) when the
        status request itself fails.
    """
    from nnsight.ndif import status

    try:
        payload = status(raw=True)
    except Exception:
        logger.warning("Failed to fetch NDIF status", exc_info=True)
        return []

    running: set[str] = set()
    surprises: list[tuple[str, str]] = []

    for entry in _iter_deployments(payload):
        surprise = _unexpected_state(entry)
        if surprise:
            surprises.append(surprise)

        name = _running_language_model(entry)
        if name:
            running.add(name)

    if surprises:
        # Surface unknown states once per refresh instead of failing on them.
        logger.warning(
            "NDIF reported deployments with unexpected application_state values "
            "(nnsight's ModelStatus enum may not know about these): %s",
            surprises,
        )

    return sorted(running)
|
|
|
|
def session_ndif_api_key() -> str | None:
    """Look up the current visitor's NDIF key in ``st.session_state``.

    Only session state is read; no process-wide globals are touched.

    Returns:
        The stored key when it is a non-empty string, otherwise ``None``.
    """
    stored = st.session_state.get(_SESSION_NDIF_API_KEY)
    if isinstance(stored, str) and stored:
        return stored
    return None
|
|
|
|
| def configured_ndif_api_key() -> str | None: |
| """Return an app-level NDIF key configured through the environment, if any.""" |
|
|
| value = os.environ.get("NDIF_API_KEY") |
| return value if value else None |
|
|
|
|
def remote_backend(model: object, api_key: str | None = None, *, on_status=None):
    """Create an NDIF ``RemoteBackend`` whose credentials belong to one session.

    Args:
        model: Object exposing ``to_model_key()``; its result identifies the
            remote model.
        api_key: Explicit NDIF key.  When falsy, the session key is tried
            next, then the environment-configured key.
        on_status: Optional callable invoked as
            ``on_status(job_id, status_name, description)`` on every status
            update that carries a non-empty ``status_name``.

    Returns:
        The configured ``RemoteBackend``.

    Raises:
        RuntimeError: If no API key can be resolved from any source.
    """
    from nnsight.intervention.backends.remote import JobStatusDisplay, RemoteBackend

    resolved_key = api_key
    if not resolved_key:
        resolved_key = session_ndif_api_key() or configured_ndif_api_key()
    if not resolved_key:
        raise RuntimeError("Enter your NDIF API key before using remote execution.")

    backend = RemoteBackend(model.to_model_key(), api_key=resolved_key)
    # NOTE(review): presumably seconds -- confirm against nnsight.
    backend.CONNECT_TIMEOUT = 300.0

    if on_status is not None:

        class _CallbackJobStatusDisplay(JobStatusDisplay):
            """Status display that also forwards named updates to ``on_status``."""

            def update(
                self,
                job_id: str = "",
                status_name: str = "",
                description: str = "",
            ):
                super().update(job_id, status_name, description)
                if not status_name:
                    return
                on_status(job_id, status_name, description)

        backend.status_display = _CallbackJobStatusDisplay(
            enabled=True,
            verbose=backend.verbose,
        )

    return backend
|
|
|
|
@st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
def cached_model(model_name: str):
    """Load a standardized nnterp model once and reuse it across reruns.

    Streamlit re-executes the app on every interaction; caching keeps a single
    loaded instance rather than reloading weights on each widget change.

    ``remote`` is deliberately absent from the cache key.  It matters at
    generation/trace time, but the current ``StandardizedTransformer``
    constructor ignores it, and leaving it out avoids duplicate local model
    objects when NDIF is toggled.  The cache defaults to one model so that
    several large models are not held in RAM at once.

    Args:
        model_name: Name passed straight to ``StandardizedTransformer``.

    Returns:
        The loaded ``StandardizedTransformer`` instance.
    """
    import torch
    from nnterp import StandardizedTransformer

    # Gradients are switched off globally before the model is constructed.
    torch.set_grad_enabled(False)

    model = StandardizedTransformer(model_name)
    return model
|
|