Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| from collections.abc import Iterable | |
| import streamlit as st | |
logger = logging.getLogger(__name__)

# Class names (the last dotted component before the ":" in a deployment's
# model_key) that _running_language_model accepts as language models.
_LANGUAGE_MODEL_CLASSES = {
    "LanguageModel",
    "StandardizedTransformer",
}

# application_state values treated as normal; any other value is collected by
# _unexpected_state and logged as a warning by list_remote_models.
_EXPECTED_NDIF_STATES = {
    "RUNNING",
    "NOT DEPLOYED",
    "DEPLOYING",
    "DELETING",
}
def _iter_deployments(raw: object) -> Iterable[dict]:
    """Yield the dict-valued entries of ``raw["deployments"]``.

    ``raw`` is the untyped NDIF status payload. If it is not a dict, or its
    ``"deployments"`` entry is missing or not a dict, an empty tuple is
    returned. Non-dict deployment values are silently skipped.
    """
    table = raw.get("deployments", {}) if isinstance(raw, dict) else None
    if not isinstance(table, dict):
        return ()
    return (entry for entry in table.values() if isinstance(entry, dict))
def _is_visible_deployment(deployment: dict) -> bool:
    """Return True for HOT/WARM deployments or any deployment with a schedule."""
    # The level is checked first so evaluation order matches the original
    # short-circuiting ``or`` expression.
    level = deployment.get("deployment_level")
    if level in {"HOT", "WARM"}:
        return True
    return "schedule" in deployment
| def _repo_id_from_model_key(model_key: str) -> str: | |
| try: | |
| repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id") | |
| except Exception: | |
| return model_key | |
| return repo_id if isinstance(repo_id, str) else model_key | |
def _running_language_model(deployment: dict) -> str | None:
    """Return the repo id of a visible, RUNNING language-model deployment.

    Args:
        deployment: One deployment entry from the raw NDIF status payload.

    Returns:
        The repo id parsed from ``model_key``, or ``None`` when the deployment
        is not visible, has a missing/non-string ``model_key``, is not one of
        ``_LANGUAGE_MODEL_CLASSES``, or its ``application_state`` is not
        ``"RUNNING"``.
    """
    if not _is_visible_deployment(deployment):
        return None
    model_key = deployment.get("model_key", "")
    # The payload is external data: a null or non-string key must not raise
    # here, since the exception would abort list_remote_models' whole loop.
    if not isinstance(model_key, str):
        return None
    # The key is treated as "<dotted.class.path>:<json>"; keep the bare class.
    model_class = model_key.split(":", 1)[0].split(".")[-1]
    if model_class not in _LANGUAGE_MODEL_CLASSES:
        return None
    if deployment.get("application_state", "NOT DEPLOYED") != "RUNNING":
        return None
    return _repo_id_from_model_key(model_key)
def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
    """Report a deployment whose ``application_state`` is not expected.

    Returns ``(repo_id_or_model_key, state)`` when the state (defaulting to
    ``"NOT DEPLOYED"``) is outside ``_EXPECTED_NDIF_STATES``, else ``None``.
    """
    state = deployment.get("application_state", "NOT DEPLOYED")
    if state not in _EXPECTED_NDIF_STATES:
        return _repo_id_from_model_key(deployment.get("model_key", "")), state
    return None
def list_remote_models() -> list[str]:
    """List the NDIF language models that are currently running, sorted.

    The raw NDIF payload is parsed here rather than via ``nnsight.ndif_status()``
    without ``raw=True``. nnsight's own parsing fails whenever a deployment
    reports an ``application_state`` missing from its ``ModelStatus`` enum
    (e.g. ``UNHEALTHY``), so one bad deployment would break the entire
    response (see nnsight 0.6.3 ``ndif.py::status``). Such states are logged
    as a warning instead. If the status fetch itself fails, a warning is
    logged and an empty list is returned.
    """
    import nnsight

    try:
        raw = nnsight.ndif_status(raw=True)
    except Exception:
        logger.warning("Failed to fetch NDIF status", exc_info=True)
        return []

    running: set[str] = set()
    # (repo_id_or_key, application_state) pairs for states we don't recognize.
    unexpected: list[tuple[str, str]] = []
    for deployment in _iter_deployments(raw):
        state_issue = _unexpected_state(deployment)
        if state_issue:
            unexpected.append(state_issue)
        name = _running_language_model(deployment)
        if name:
            running.add(name)

    if unexpected:
        logger.warning(
            "NDIF reported deployments with unexpected application_state values "
            "(nnsight's ModelStatus enum may not know about these): %s",
            unexpected,
        )
    return sorted(running)
@st.cache_resource
def cached_model(model_name: str):
    """Load a standardized nnterp model, cached once per model name.

    Streamlit reruns the app on every interaction. ``st.cache_resource`` keeps
    a single loaded instance per ``model_name`` so weights are not reloaded on
    every widget change. Without the decorator, a fresh model would be built
    on each call.

    ``remote`` is intentionally not part of the cache key. It matters at
    generation/trace time, but the current ``StandardizedTransformer``
    constructor ignores it, and excluding it avoids loading duplicate local
    model objects when toggling NDIF.

    Note: ``st.cache_resource`` returns the same object to every caller, so
    avoid mutating the returned model in place.
    """
    from nnterp import StandardizedTransformer

    return StandardizedTransformer(model_name)