File size: 5,305 Bytes
9ba2da4 a89a7f1 ae347c6 9ba2da4 a89a7f1 ae347c6 c607869 a89a7f1 9ba2da4 c607869 ae347c6 9ba2da4 a89a7f1 a9950fb ae347c6 a9950fb ae347c6 a89a7f1 ae347c6 a89a7f1 a9950fb a89a7f1 9ba2da4 a89a7f1 a9950fb a89a7f1 ae347c6 a1b8512 ae347c6 a1b8512 ae347c6 c607869 9ba2da4 a89a7f1 c607869 9ba2da4 c607869 a89a7f1 99c28ab a89a7f1 99c28ab a89a7f1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import json
import logging
import os
from collections.abc import Iterable
import streamlit as st
from utils.helpers import env_int, session_key
logger = logging.getLogger(__name__)
_LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
_EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
_MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)
_SESSION_NDIF_API_KEY = session_key("sidebar", "ndif_api_key")
def _iter_deployments(raw: object) -> Iterable[dict]:
if not isinstance(raw, dict):
return ()
deployments = raw.get("deployments", {})
if not isinstance(deployments, dict):
return ()
return (value for value in deployments.values() if isinstance(value, dict))
def _is_visible_deployment(deployment: dict) -> bool:
return deployment.get("deployment_level") in {"HOT", "WARM"} or (
"schedule" in deployment
)
def _repo_id_from_model_key(model_key: str) -> str:
try:
repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
except Exception:
return model_key
return repo_id if isinstance(repo_id, str) else model_key
def _running_language_model(deployment: dict) -> str | None:
if not _is_visible_deployment(deployment):
return None
model_key = deployment.get("model_key", "")
model_class = model_key.split(":", 1)[0].split(".")[-1]
if model_class not in _LANGUAGE_MODEL_CLASSES:
return None
if deployment.get("application_state", "NOT DEPLOYED") != "RUNNING":
return None
return _repo_id_from_model_key(model_key)
def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
state = deployment.get("application_state", "NOT DEPLOYED")
if state in _EXPECTED_NDIF_STATES:
return None
model_key = deployment.get("model_key", "")
return _repo_id_from_model_key(model_key), state
@st.cache_data(show_spinner=False, ttl=30)
def list_remote_models() -> list[str]:
"""Return the NDIF language models that are currently running.
Parses the raw NDIF response directly instead of going through the formatted
``nnsight.ndif.status()`` response because formatting crashes whenever NDIF reports
any deployment with an ``application_state`` that isn't in nnsight's
``ModelStatus`` enum (e.g. ``UNHEALTHY``) — one bad deployment poisons
the whole response. See nnsight 0.6.3 ``ndif.py::status``.
"""
from nnsight.ndif import status
try:
raw = status(raw=True)
except Exception:
logger.warning("Failed to fetch NDIF status", exc_info=True)
return []
model_names: list[str] = []
bad_states: list[tuple[str, str]] = [] # (repo_id_or_key, application_state)
for deployment in _iter_deployments(raw):
if bad_state := _unexpected_state(deployment):
bad_states.append(bad_state)
if model_name := _running_language_model(deployment):
model_names.append(model_name)
if bad_states:
logger.warning(
"NDIF reported deployments with unexpected application_state values "
"(nnsight's ModelStatus enum may not know about these): %s",
bad_states,
)
return sorted(set(model_names))
def session_ndif_api_key() -> str | None:
"""Return this visitor's NDIF key without touching process globals."""
value = st.session_state.get(_SESSION_NDIF_API_KEY)
return value if isinstance(value, str) and value else None
def configured_ndif_api_key() -> str | None:
"""Return an app-level NDIF key configured through the environment, if any."""
value = os.environ.get("NDIF_API_KEY")
return value if value else None
def remote_backend(model: object, api_key: str | None = None, *, on_status=None):
"""Build an NDIF backend with credentials bound to one browser session."""
from nnsight.intervention.backends.remote import RemoteBackend
from persona_vectors.activations import CallbackJobStatusDisplay
active_key = api_key or session_ndif_api_key() or configured_ndif_api_key()
if not active_key:
raise RuntimeError("Enter your NDIF API key before using remote execution.")
backend = RemoteBackend(model.to_model_key(), api_key=active_key)
backend.CONNECT_TIMEOUT = 300.0
if on_status is not None:
backend.status_display = CallbackJobStatusDisplay(
on_status, enabled=True, verbose=backend.verbose
)
return backend
@st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
def cached_model(model_name: str):
"""Load and cache a standardized nnterp model.
Streamlit reruns this app on every interaction, so caching keeps one loaded
model instance instead of reloading weights on every widget change.
``remote`` is intentionally not part of the cache key: it matters at
generation/trace time, but the current ``StandardizedTransformer``
constructor ignores it, and excluding it avoids loading duplicate local
model objects when toggling NDIF. The cache defaults to one model to avoid
keeping multiple large models in RAM.
"""
import torch
from nnterp import StandardizedTransformer
torch.set_grad_enabled(False)
return StandardizedTransformer(model_name)
|