persona-ui / utils /runtime.py
Jac-Zac
add session-scoped NDIF execution and improve cold-load UX
ae347c6
import json
import logging
import os
from collections.abc import Iterable
import streamlit as st
from utils.helpers import env_int, session_key
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Class names (the last dotted component of a model key's prefix before ":")
# that _running_language_model accepts as language models.
_LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
# application_state values treated as normal; any other value is reported by
# _unexpected_state and logged by list_remote_models.
_EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
# max_entries for the cached_model resource cache, read from the environment
# (presumably falls back to 1 when the variable is unset -- confirm in env_int).
_MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)
# st.session_state key that session_ndif_api_key reads the visitor's NDIF key from
# (presumably written by a sidebar widget -- confirm against session_key's users).
_SESSION_NDIF_API_KEY = session_key("sidebar", "ndif_api_key")
def _iter_deployments(raw: object) -> Iterable[dict]:
if not isinstance(raw, dict):
return ()
deployments = raw.get("deployments", {})
if not isinstance(deployments, dict):
return ()
return (value for value in deployments.values() if isinstance(value, dict))
def _is_visible_deployment(deployment: dict) -> bool:
return deployment.get("deployment_level") in {"HOT", "WARM"} or (
"schedule" in deployment
)
def _repo_id_from_model_key(model_key: str) -> str:
try:
repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
except Exception:
return model_key
return repo_id if isinstance(repo_id, str) else model_key
def _running_language_model(deployment: dict) -> str | None:
    """Return the repo id of a visible, running language-model deployment.

    Args:
        deployment: One deployment entry from the raw NDIF status payload.

    Returns:
        The value produced by ``_repo_id_from_model_key`` for the deployment's
        ``model_key``, or ``None`` when the deployment is not visible, has a
        missing/non-string ``model_key``, is not one of the known language
        model classes, or is not in the ``RUNNING`` state.
    """
    if not _is_visible_deployment(deployment):
        return None
    model_key = deployment.get("model_key", "")
    # The payload is external JSON: a null or otherwise non-string key must not
    # raise here, or one malformed entry would abort the whole model listing.
    if not isinstance(model_key, str):
        return None
    # The class is the last dotted component of the text before the first ":".
    model_class = model_key.split(":", 1)[0].split(".")[-1]
    if model_class not in _LANGUAGE_MODEL_CLASSES:
        return None
    if deployment.get("application_state", "NOT DEPLOYED") != "RUNNING":
        return None
    return _repo_id_from_model_key(model_key)
def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
    """Describe a deployment whose ``application_state`` is not an expected one.

    Returns ``(repo_id_or_key, state)`` when the state (defaulting to
    ``"NOT DEPLOYED"`` if absent) is outside ``_EXPECTED_NDIF_STATES``, and
    ``None`` otherwise.
    """
    app_state = deployment.get("application_state", "NOT DEPLOYED")
    if app_state not in _EXPECTED_NDIF_STATES:
        label = _repo_id_from_model_key(deployment.get("model_key", ""))
        return label, app_state
    return None
@st.cache_data(show_spinner=False, ttl=30)
def list_remote_models() -> list[str]:
    """List the repo ids of NDIF language models that are running right now.

    The raw NDIF payload is inspected directly rather than the formatted
    ``nnsight.ndif.status()`` output: formatting crashes as soon as NDIF
    reports a deployment whose ``application_state`` is missing from nnsight's
    ``ModelStatus`` enum (e.g. ``UNHEALTHY``), so a single bad deployment would
    poison the whole response. See nnsight 0.6.3 ``ndif.py::status``.

    Returns:
        A sorted, de-duplicated list of model names, or an empty list when the
        status request fails (the failure is logged as a warning).
    """
    from nnsight.ndif import status

    try:
        payload = status(raw=True)
    except Exception:
        logger.warning("Failed to fetch NDIF status", exc_info=True)
        return []

    running: set[str] = set()
    unexpected: list[tuple[str, str]] = []  # (repo_id_or_key, application_state)
    for entry in _iter_deployments(payload):
        flagged = _unexpected_state(entry)
        if flagged:
            unexpected.append(flagged)
        name = _running_language_model(entry)
        if name:
            running.add(name)

    if unexpected:
        # Surface unknown states once per refresh instead of failing on them.
        logger.warning(
            "NDIF reported deployments with unexpected application_state values "
            "(nnsight's ModelStatus enum may not know about these): %s",
            unexpected,
        )
    return sorted(running)
def session_ndif_api_key() -> str | None:
    """Return the NDIF key stored in this Streamlit session, if there is one.

    Only ``st.session_state`` is read; a missing, empty, or non-string value
    yields ``None``.
    """
    stored = st.session_state.get(_SESSION_NDIF_API_KEY)
    if not isinstance(stored, str) or not stored:
        return None
    return stored
def configured_ndif_api_key() -> str | None:
"""Return an app-level NDIF key configured through the environment, if any."""
value = os.environ.get("NDIF_API_KEY")
return value if value else None
def remote_backend(model: object, api_key: str | None = None, *, on_status=None):
    """Create a ``RemoteBackend`` for ``model`` using an explicitly resolved key.

    The key is taken from ``api_key``, then from the current session, then
    from the environment; the first non-empty one wins.

    Args:
        model: Object exposing ``to_model_key()``.
        api_key: Optional key that overrides the session/environment lookup.
        on_status: Optional callable invoked as
            ``on_status(job_id, status_name, description)`` after each status
            display update that carries a non-empty ``status_name``.

    Returns:
        The configured backend.

    Raises:
        RuntimeError: If no key can be resolved from any source.
    """
    from nnsight.intervention.backends.remote import JobStatusDisplay, RemoteBackend

    resolved_key = api_key or session_ndif_api_key() or configured_ndif_api_key()
    if not resolved_key:
        raise RuntimeError("Enter your NDIF API key before using remote execution.")

    backend = RemoteBackend(model.to_model_key(), api_key=resolved_key)
    # Set on this instance only; presumably allows for slow cold loads -- confirm.
    backend.CONNECT_TIMEOUT = 300.0

    if on_status is not None:

        class _CallbackJobStatusDisplay(JobStatusDisplay):
            """Status display that also forwards updates to ``on_status``."""

            def update(
                self,
                job_id: str = "",
                status_name: str = "",
                description: str = "",
            ):
                # Let the stock display run first, then notify the caller.
                super().update(job_id, status_name, description)
                if status_name:
                    on_status(job_id, status_name, description)

        backend.status_display = _CallbackJobStatusDisplay(
            enabled=True,
            verbose=backend.verbose,
        )

    return backend
@st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
def cached_model(model_name: str):
    """Build a standardized nnterp model once and reuse it across reruns.

    Streamlit re-executes the script on every interaction; caching the
    resource keeps a single loaded instance rather than reloading weights each
    time a widget changes.

    ``remote`` is deliberately left out of the cache key. It only matters at
    generation/trace time, the current ``StandardizedTransformer`` constructor
    ignores it, and keying on it would load duplicate local model objects when
    NDIF is toggled. By default the cache holds one model so that several
    large models are never resident in RAM together.
    """
    import torch
    from nnterp import StandardizedTransformer

    # Gradient tracking is switched off before the model is constructed.
    # NOTE(review): PyTorch grad mode is documented as thread-local and this
    # body only runs on a cache miss -- confirm callers on other threads do
    # not rely on this call.
    torch.set_grad_enabled(False)
    model = StandardizedTransformer(model_name)
    return model