File size: 9,013 Bytes
656f91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c50b87
656f91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44ce4f7
656f91e
44ce4f7
 
 
 
 
656f91e
44ce4f7
656f91e
44ce4f7
 
656f91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""Model registry + cached loader for serving real models behind the app (F006).

The Gradio app injects its policy through ``build_demo(policy_factory=...)``,
called once per ask. This module is the real-model side of that seam: a small
registry of selectable models and a cached loader that returns a ready ``Policy``.

FORMAT PARITY (non-negotiable): serving reuses ``ModelPolicy`` verbatim, which
routes every formatting decision through the single source of truth
``server/tooling.py`` — the same system prompt (``get_system_prompt``), the same
tool JSON schema (``get_tool_definitions``), the same chat-template rendering, and
the same ``<tool_call>{...}</tool_call>`` parser (``parse_action``). There is NO
second parser and NO alternate prompt here. Each model's ``enable_thinking`` is set
from its training config so ``/no_think`` matches training (all current: False).

CACHING (the pitfall): the factory is called once PER ASK. We must NOT reload
multi-GB weights per question. ``_LOADED`` caches ``(model, tokenizer)`` per key
(loaded at most once per process); ``get_policy`` returns a FRESH ``ModelPolicy``
each call wrapping the shared model — fresh per-episode transcript, cheap to build.

ZeroGPU: the load happens in the PARENT process (``preload_available`` at startup,
or lazily on first ask) so a per-call ``@spaces.GPU`` fork inherits the cache;
``_maybe_to_cuda`` is the only torch touch in the load path (see the serve plan's
ZeroGPU fork-and-cache note).

Dep-light at import: torch/transformers are pulled only when a real model is
actually loaded (``_load_model_and_tokenizer`` / ``_maybe_to_cuda`` import them
lazily), so importing this module — and building the selector UI — stays headless.
"""

from __future__ import annotations

from dataclasses import dataclass
import logging
from typing import Any

try:  # package import (canonical) then flat-layout / direct-run fallback
    from ..evaluation.model_policy import ModelPolicy
    from ..evaluation.policies import Policy
except ImportError:  # pragma: no cover - flat-layout fallback
    from evaluation.model_policy import ModelPolicy  # type: ignore[no-redef]
    from evaluation.policies import Policy  # type: ignore[no-redef]

# Module-level logger named after this module; used for load / preload diagnostics.
logger = logging.getLogger(__name__)

# Registry key of the no-model deterministic default. The demo policy is handled by
# the app (it owns _DemoScriptPolicy + the active_db_id plumbing); it is never routed
# through get_policy, which avoids a serving<->app_ui import cycle.
DEMO_KEY = "demo"


class ModelUnavailableError(RuntimeError):
    """A registered model is selected but not yet published/loadable.

    Raised by ``get_policy`` when the selected spec has ``available=False``.
    """


@dataclass(frozen=True)
class ModelSpec:
    """One selectable model (an immutable registry entry).

    ``source`` is an HF repo id or a local dir (empty for the demo entry).
    ``enable_thinking`` MUST match how the model was trained (parity). ``available``
    is False for a model not yet on the Hub; flip it once F008 pushes it.
    """

    label: str  # human-readable text used for the dropdown entry
    source: str  # HF repo id or local dir; "" for the demo entry
    enable_thinking: bool = False  # forwarded to ModelPolicy; must match training
    revision: str | None = None  # forwarded to the model loader as ``revision``
    is_demo: bool = False  # True for the scripted demo entry (built by the app)
    available: bool = True  # False: get_policy raises ModelUnavailableError


# Ordered: the demo is first (the default); dict insertion order is the dropdown
# order. A model that is registered but not yet on the Hub is wired in with
# ``available=False`` — flipping it to ``available=True`` lights up the dropdown entry
# with no other change. The fine-tuned GRPO v2 slot has been published, so it is
# enabled below. All current models were trained with enable_thinking=False
# (configs/modal_1p7b_fullft_v2.json: "enable_thinking": false), so /no_think matches.
MODEL_REGISTRY: dict[str, ModelSpec] = {
    DEMO_KEY: ModelSpec(
        label="Demo — scripted, no model (instant)",
        source="",
        is_demo=True,
    ),
    "qwen3-0.6b": ModelSpec(
        label="Qwen3-0.6B — vanilla (not fine-tuned)",
        source="Qwen/Qwen3-0.6B",
    ),
    "qwen3-1.7b": ModelSpec(
        label="Qwen3-1.7B — vanilla (not fine-tuned)",
        source="Qwen/Qwen3-1.7B",
    ),
    "sqlenv-1.7b-grpo-v2": ModelSpec(
        label="Qwen3-1.7B — fine-tuned on SQLEnv (GRPO v2)",
        source="hjerpe/sqlenv-qwen3-1.7b-grpo-v2",
        available=True,  # published 2026-06-14 to the Hub (public)
    ),
}

# Module-level cache: registry key -> (model, tokenizer). Filled by _ensure_loaded,
# either lazily on the first real request or eagerly via preload_available. The demo
# key is never stored here: get_policy and preload_available both filter it out
# before calling _ensure_loaded.
_LOADED: dict[str, tuple[Any, Any]] = {}


def default_model_key() -> str:
    """Return the dropdown's default selection — the instant, offline demo key."""
    return DEMO_KEY


def get_spec(key: str) -> ModelSpec | None:
    """Return the ``ModelSpec`` registered under ``key``, or ``None`` if unknown.

    Pure lookup in ``MODEL_REGISTRY``; never loads a model.
    """
    return MODEL_REGISTRY.get(key)


def dropdown_choices(*, include_unavailable: bool = True) -> list[tuple[str, str]]:
    """Build the ``(label, key)`` pairs for ``gr.Dropdown(choices=...)``.

    Entries follow ``MODEL_REGISTRY`` order (the demo is registered first). The demo
    and every available model are listed with their plain label. Unavailable
    (not-yet-pushed) models are listed with a "coming soon" suffix when
    ``include_unavailable`` is true, so the roadmap is visible, and are omitted
    otherwise; selecting one is gated in the app (and ``get_policy`` raises
    ``ModelUnavailableError``).

    Args:
        include_unavailable: whether to list models whose spec has
            ``available=False``.

    Returns:
        A new list of ``(label, key)`` tuples.
    """
    choices: list[tuple[str, str]] = []
    for key, spec in MODEL_REGISTRY.items():
        if spec.is_demo or spec.available:
            choices.append((spec.label, key))
        elif include_unavailable:
            # User-facing text: a real em dash, not its mis-decoded byte sequence.
            choices.append((f"{spec.label} — coming soon", key))
    return choices


def _load_model_and_tokenizer(source: str, revision: str | None) -> tuple[Any, Any]:
    """Lazy wrapper over the shared loader (keeps training imports off the hot path).

    The loader is imported inside the function, first via the package-relative path
    and, if that raises ``ImportError``, via the flat-layout path.

    Args:
        source: passed through to the shared loader as its first argument
            (per ``ModelSpec``: an HF repo id or a local dir).
        revision: forwarded as ``revision=``; may be ``None``.

    Returns:
        Whatever the shared loader returns — unpacked by the caller as
        ``(model, tokenizer)``.
    """
    try:
        from ..training.data_loading import load_model_and_tokenizer
    except ImportError:  # pragma: no cover - flat-layout fallback
        from training.data_loading import (  # type: ignore[no-redef]
            load_model_and_tokenizer,
        )
    # "auto" loads in the checkpoint's native dtype (bf16 for our Qwen3 models) —
    # halves memory vs the fp32 default, which matters on the ZeroGPU parent that
    # holds every preloaded model. Tool-call FORMAT parity is unaffected by dtype.
    return load_model_and_tokenizer(source, revision=revision, torch_dtype="auto")


def _maybe_to_cuda(model: Any) -> Any:
    """Place the model on CUDA — the GPU-placement step of the load path.

    On a ZeroGPU Space, ``import spaces`` (done first in ``app.py``) enables a CUDA
    emulation, so moving to ``cuda`` at load time is the REQUIRED pattern; the model
    is materialized on the real GPU inside ``@spaces.GPU``. ``cuda.is_available()`` is
    False on the startup container, so we must NOT guard on it. On a CPU-only box
    ``.to("cuda")`` raises, so we fall back to the model as loaded.

    Args:
        model: any object exposing ``.to(device)``.

    Returns:
        The result of ``model.to("cuda")``, or ``model`` itself, unchanged, when
        that call raises.
    """
    try:
        return model.to("cuda")
    except Exception as exc:
        # Deliberate best-effort (e.g. local dev without CUDA) — but never silent:
        # a model unexpectedly left off the GPU must be visible in the logs.
        # getLogger(__name__) is the same logger object as the module-level one.
        logging.getLogger(__name__).warning(
            "Could not move model to CUDA (%s: %s); leaving it on its current device.",
            type(exc).__name__,
            exc,
        )
        return model


def _ensure_loaded(key: str) -> tuple[Any, Any]:
    """Return the cached ``(model, tokenizer)`` for ``key``, loading it on a miss.

    On a cache miss the spec is read from ``MODEL_REGISTRY`` (an unknown key raises
    ``KeyError``), the weights are loaded, moved via ``_maybe_to_cuda``, switched to
    eval mode, and the pair is stored in ``_LOADED`` before being returned.
    """
    entry = _LOADED.get(key)
    if entry is None:
        spec = MODEL_REGISTRY[key]
        logger.info("Loading model %r from %s ...", key, spec.source)
        loaded_model, tokenizer = _load_model_and_tokenizer(
            spec.source, spec.revision
        )
        loaded_model = _maybe_to_cuda(loaded_model)
        loaded_model.eval()  # inference only
        entry = (loaded_model, tokenizer)
        _LOADED[key] = entry
    return entry


def get_policy(key: str) -> Policy:
    """Return a FRESH ``ModelPolicy`` for ``key`` (loads weights once, then caches).

    Cheap per call: ensures the shared ``(model, tokenizer)`` is loaded, then wraps
    it in a new ``ModelPolicy`` (fresh per-episode transcript), configured with the
    spec's ``enable_thinking``.

    Args:
        key: a ``MODEL_REGISTRY`` key naming a real, available model.

    Returns:
        A new ``ModelPolicy`` wrapping the cached model and tokenizer.

    Raises:
        KeyError: ``key`` is not in ``MODEL_REGISTRY``.
        ValueError: ``key`` is the demo entry (the demo policy is built by the app).
        ModelUnavailableError: the model is registered with ``available=False``.
    """
    spec = MODEL_REGISTRY.get(key)
    if spec is None:
        raise KeyError(f"Unknown model key: {key!r}")
    if spec.is_demo:
        raise ValueError("the demo policy is built by the app, not serving.get_policy")
    if not spec.available:
        # This message can surface to the user: a real em dash, not mojibake.
        raise ModelUnavailableError(
            f"Model {key!r} ({spec.source}) is not published yet — "
            "push it to the Hub (F008) and set available=True."
        )
    model, tokenizer = _ensure_loaded(key)
    return ModelPolicy(model, tokenizer, enable_thinking=spec.enable_thinking)


def preload_available() -> list[str]:
    """Eager-load every available real model into the cache (call at Space startup).

    Runs in the PARENT process so a later per-ask ``@spaces.GPU`` fork inherits the
    loaded weights (ZeroGPU does not persist state created inside the GPU call). A
    failed preload is logged and skipped — that model simply loads lazily on first
    ask.

    Returns:
        The keys that were successfully preloaded, in registry order.
    """
    preloaded: list[str] = []
    for key, spec in MODEL_REGISTRY.items():
        is_real_and_available = spec.available and not spec.is_demo
        if not is_real_and_available:
            continue
        try:
            _ensure_loaded(key)
        except Exception:  # never let a bad preload crash startup
            logger.warning(
                "Preload failed for %r (%s); it will load lazily on first ask.",
                key,
                spec.source,
                exc_info=True,
            )
        else:
            preloaded.append(key)
    return preloaded
    return loaded