File size: 5,305 Bytes
9ba2da4
a89a7f1
ae347c6
9ba2da4
a89a7f1
 
 
ae347c6
c607869
a89a7f1
9ba2da4
 
c607869
ae347c6
9ba2da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a89a7f1
 
 
 
a9950fb
 
ae347c6
 
a9950fb
 
 
 
 
ae347c6
a89a7f1
 
ae347c6
a89a7f1
 
 
 
 
a9950fb
a89a7f1
9ba2da4
 
 
 
 
a89a7f1
a9950fb
 
 
 
 
 
 
a89a7f1
 
 
ae347c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1b8512
 
ae347c6
 
 
 
 
 
 
a1b8512
 
 
 
ae347c6
 
 
c607869
9ba2da4
a89a7f1
 
 
c607869
 
 
9ba2da4
c607869
 
a89a7f1
 
99c28ab
a89a7f1
 
99c28ab
 
a89a7f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import json
import logging
import os
from collections.abc import Iterable

import streamlit as st

from utils.helpers import env_int, session_key

logger = logging.getLogger(__name__)

# Class names (last dotted component of the model-key prefix before ":") that
# this UI treats as language models.
_LANGUAGE_MODEL_CLASSES = {"StandardizedTransformer", "LanguageModel"}
# application_state values we know how to interpret; any other value is
# collected and logged by list_remote_models.
_EXPECTED_NDIF_STATES = {"DELETING", "DEPLOYING", "NOT DEPLOYED", "RUNNING"}
# max_entries for the cached_model resource cache (how many models stay loaded).
_MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)
# Streamlit session_state key under which the visitor's NDIF API key is stored.
_SESSION_NDIF_API_KEY = session_key("sidebar", "ndif_api_key")


def _iter_deployments(raw: object) -> Iterable[dict]:
    if not isinstance(raw, dict):
        return ()
    deployments = raw.get("deployments", {})
    if not isinstance(deployments, dict):
        return ()
    return (value for value in deployments.values() if isinstance(value, dict))


def _is_visible_deployment(deployment: dict) -> bool:
    return deployment.get("deployment_level") in {"HOT", "WARM"} or (
        "schedule" in deployment
    )


def _repo_id_from_model_key(model_key: str) -> str:
    try:
        repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
    except Exception:
        return model_key
    return repo_id if isinstance(repo_id, str) else model_key


def _running_language_model(deployment: dict) -> str | None:
    """Return the repo id of *deployment* if it is a running language model.

    A deployment qualifies when it passes ``_is_visible_deployment``, the class
    name in its ``model_key`` is in ``_LANGUAGE_MODEL_CLASSES``, and its
    ``application_state`` is ``"RUNNING"``. In every other case ``None`` is
    returned.

    A missing or non-string ``model_key`` is skipped (returns ``None``) instead
    of raising. ``list_remote_models`` does not catch per-deployment errors,
    so one malformed entry would otherwise break the whole listing.
    """
    if not _is_visible_deployment(deployment):
        return None

    model_key = deployment.get("model_key", "")
    # A present-but-null (or otherwise non-string) model_key would make
    # ``.split`` raise AttributeError below.
    if not isinstance(model_key, str):
        return None
    # Class name = last dotted component of the prefix before the first ":".
    model_class = model_key.split(":", 1)[0].split(".")[-1]
    if model_class not in _LANGUAGE_MODEL_CLASSES:
        return None
    if deployment.get("application_state", "NOT DEPLOYED") != "RUNNING":
        return None
    return _repo_id_from_model_key(model_key)


def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
    """Report a deployment whose ``application_state`` is not in ``_EXPECTED_NDIF_STATES``.

    Returns ``(repo_id_or_model_key, state)`` for such a deployment, and ``None``
    otherwise. A missing state is treated as ``"NOT DEPLOYED"``, which is
    expected.
    """
    state = deployment.get("application_state", "NOT DEPLOYED")
    if state not in _EXPECTED_NDIF_STATES:
        return _repo_id_from_model_key(deployment.get("model_key", "")), state
    return None


@st.cache_data(show_spinner=False, ttl=30)
def list_remote_models() -> list[str]:
    """List the NDIF language models that are running right now.

    Reads the unformatted NDIF payload (``status(raw=True)``) rather than the
    formatted ``nnsight.ndif.status()`` output. The formatter raises as soon
    as any deployment reports an ``application_state`` missing from nnsight's
    ``ModelStatus`` enum (``UNHEALTHY``, for instance), so a single bad
    deployment would hide every model. See nnsight 0.6.3 ``ndif.py::status``.

    Returns a sorted, de-duplicated list of repo ids, or an empty list when
    the status request fails. Streamlit caches the result for 30 seconds.
    """

    from nnsight.ndif import status

    try:
        raw = status(raw=True)
    except Exception:
        logger.warning("Failed to fetch NDIF status", exc_info=True)
        return []

    running: set[str] = set()
    # (repo_id_or_key, application_state) pairs nnsight's enum may not know.
    anomalies: list[tuple[str, str]] = []

    for deployment in _iter_deployments(raw):
        anomaly = _unexpected_state(deployment)
        if anomaly:
            anomalies.append(anomaly)
        name = _running_language_model(deployment)
        if name:
            running.add(name)

    if anomalies:
        logger.warning(
            "NDIF reported deployments with unexpected application_state values "
            "(nnsight's ModelStatus enum may not know about these): %s",
            anomalies,
        )

    return sorted(running)


def session_ndif_api_key() -> str | None:
    """Read the current visitor's NDIF key from Streamlit session state.

    Only a non-empty string counts; any other stored value yields ``None``.
    No process-wide state (such as environment variables) is consulted or
    modified.
    """
    stored = st.session_state.get(_SESSION_NDIF_API_KEY)
    if not isinstance(stored, str) or not stored:
        return None
    return stored


def configured_ndif_api_key() -> str | None:
    """Return an app-level NDIF key configured through the environment, if any."""

    value = os.environ.get("NDIF_API_KEY")
    return value if value else None


def remote_backend(
    model: object,
    api_key: str | None = None,
    *,
    on_status=None,
    connect_timeout: float = 300.0,
):
    """Build an NDIF backend with credentials bound to one browser session.

    The key is passed directly to the ``RemoteBackend`` constructor rather than
    written to any global configuration. Keys are resolved in this order:
    explicit ``api_key``, then the visitor's session key, then the app-level
    ``NDIF_API_KEY`` environment variable.

    Args:
        model: Object exposing ``to_model_key()``, used to identify the remote
            model.
        api_key: Optional explicit NDIF key that overrides the session and
            environment keys.
        on_status: Optional callback. When given, it is wrapped in
            ``CallbackJobStatusDisplay`` and installed as the backend's
            ``status_display``.
        connect_timeout: Value assigned to ``backend.CONNECT_TIMEOUT``
            (presumably seconds; confirm against nnsight). Defaults to the
            previously hard-coded 300.0.

    Returns:
        The configured ``RemoteBackend`` instance.

    Raises:
        RuntimeError: If no NDIF API key is available from any source.
    """

    from nnsight.intervention.backends.remote import RemoteBackend
    from persona_vectors.activations import CallbackJobStatusDisplay

    active_key = api_key or session_ndif_api_key() or configured_ndif_api_key()
    if not active_key:
        raise RuntimeError("Enter your NDIF API key before using remote execution.")

    backend = RemoteBackend(model.to_model_key(), api_key=active_key)
    backend.CONNECT_TIMEOUT = connect_timeout
    if on_status is not None:
        # Route job-status updates to the caller, keeping the backend's verbosity.
        backend.status_display = CallbackJobStatusDisplay(
            on_status, enabled=True, verbose=backend.verbose
        )
    return backend


@st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
def cached_model(model_name: str):
    """Load a standardized nnterp model once and share it across reruns.

    Streamlit re-executes the script on every widget interaction. Keeping the
    model in ``st.cache_resource`` avoids reloading weights each time.

    The cache key is ``model_name`` alone. ``remote`` is deliberately left out:
    it only matters at generation/trace time, and the current
    ``StandardizedTransformer`` constructor ignores it. Excluding it prevents
    duplicate local model objects when NDIF is toggled.

    ``max_entries`` defaults to a single model so several large models are not
    held in RAM at once.
    """

    import torch
    from nnterp import StandardizedTransformer

    # NOTE(review): PyTorch documents grad mode as thread-local, and this body
    # only runs on a cache miss. If Streamlit serves sessions from separate
    # script threads, later reruns are not covered by this call. Confirm
    # callers disable gradients where inference actually runs.
    torch.set_grad_enabled(False)

    model = StandardizedTransformer(model_name)
    return model