"""Model registry + cached loader for serving real models behind the app (F006).

The Gradio app injects its policy through ``build_demo(policy_factory=...)``,
called once per ask. This module is the real-model side of that seam: a small
registry of selectable models and a cached loader that returns a ready
``Policy``.

FORMAT PARITY (non-negotiable): serving reuses ``ModelPolicy`` verbatim, which
routes every formatting decision through the single source of truth
``server/tooling.py`` — the same system prompt (``get_system_prompt``), the
same tool JSON schema (``get_tool_definitions``), the same chat-template
rendering, and the same ``{...}`` parser (``parse_action``). There is NO second
parser and NO alternate prompt here. Each model's ``enable_thinking`` is set
from its training config so ``/no_think`` matches training (all current:
False).

CACHING (the pitfall): the factory is called once PER ASK. We must NOT reload
multi-GB weights per question. ``_LOADED`` caches ``(model, tokenizer)`` per
key, so a model is loaded once per process and then reused. The cache is not
lock-guarded, so two concurrent *first* asks for the same key could each load
it; preloading at startup avoids that race. ``get_policy`` returns a FRESH
``ModelPolicy`` each call wrapping the shared model — fresh per-episode
transcript, cheap to build.

ZeroGPU: the load happens in the PARENT process (``preload_available`` at
startup, or lazily on first ask) so a per-call ``@spaces.GPU`` fork inherits
the cache. ``_maybe_to_cuda`` (a plain ``model.to("cuda")``) is the only
device-placement step in the load path (see the serve plan's ZeroGPU
fork-and-cache note).

Dep-light at import: torch/transformers are pulled in only when a real model
is actually loaded — ``_load_model_and_tokenizer`` imports the shared training
loader lazily, and ``_maybe_to_cuda`` imports nothing (it only calls
``model.to``). So importing this module — and building the selector UI — stays
headless.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any

try:  # package import (canonical) then flat-layout / direct-run fallback
    from ..evaluation.model_policy import ModelPolicy
    from ..evaluation.policies import Policy
except ImportError:  # pragma: no cover - flat-layout fallback
    from evaluation.model_policy import ModelPolicy  # type: ignore[no-redef]
    from evaluation.policies import Policy  # type: ignore[no-redef]

logger = logging.getLogger(__name__)

# The no-model deterministic default. Handled by the app (it owns _DemoScriptPolicy +
# the active_db_id plumbing); never routed through get_policy (avoids a serving<->app_ui
# import cycle).
DEMO_KEY = "demo"


class ModelUnavailableError(RuntimeError):
    """A registered model is selected but not yet published/loadable."""


@dataclass(frozen=True)
class ModelSpec:
    """One selectable model.

    ``source`` is an HF repo id or a local dir (empty for the demo entry).
    ``enable_thinking`` MUST match how the model was trained (parity).
    ``available`` is False for a model not yet on the Hub; flip it to True once
    the model is published.
    """

    label: str  # text shown in the model dropdown
    source: str  # HF repo id or local dir; "" for the demo entry
    enable_thinking: bool = False  # must match the model's training config (parity)
    revision: str | None = None  # forwarded as-is to the shared loader
    is_demo: bool = False  # True only for the scripted, no-model entry
    available: bool = True  # False -> shown as "coming soon"; get_policy refuses it


# Ordered: the demo is first (the default). An entry with ``available=False`` is still
# listed in the dropdown (by default) with a "coming soon" suffix and is refused by
# ``get_policy``; flipping it to ``available=True`` lights it up with no other change.
# The fine-tuned GRPO v2 slot was gated this way until F008 published it to the Hub;
# it is now live. All current models trained with enable_thinking=False
# (configs/modal_1p7b_fullft_v2.json: "enable_thinking": false), so /no_think matches.
MODEL_REGISTRY: dict[str, ModelSpec] = {
    DEMO_KEY: ModelSpec(
        label="Demo — scripted, no model (instant)",
        source="",
        is_demo=True,
    ),
    "qwen3-0.6b": ModelSpec(
        label="Qwen3-0.6B — vanilla (not fine-tuned)",
        source="Qwen/Qwen3-0.6B",
    ),
    "qwen3-1.7b": ModelSpec(
        label="Qwen3-1.7B — vanilla (not fine-tuned)",
        source="Qwen/Qwen3-1.7B",
    ),
    "sqlenv-1.7b-grpo-v2": ModelSpec(
        label="Qwen3-1.7B — fine-tuned on SQLEnv (GRPO v2)",
        source="hjerpe/sqlenv-qwen3-1.7b-grpo-v2",
        available=True,  # published 2026-06-14 to the Hub (public)
    ),
}

# Module-level cache: key -> (model, tokenizer). Populated lazily on first real request
# (or eagerly by preload_available). Never holds the demo key.
_LOADED: dict[str, tuple[Any, Any]] = {}


def default_model_key() -> str:
    """The dropdown's default selection — the instant, offline demo."""
    return DEMO_KEY


def get_spec(key: str) -> ModelSpec | None:
    """The ``ModelSpec`` for ``key`` (None if unknown)."""
    return MODEL_REGISTRY.get(key)


def dropdown_choices(*, include_unavailable: bool = True) -> list[tuple[str, str]]:
    """``(label, key)`` pairs for ``gr.Dropdown(choices=...)``.

    Demo first, then available models, in registry order. Unavailable
    (not-yet-pushed) models are shown with a "coming soon" suffix when
    ``include_unavailable`` so the roadmap is visible; selecting one is gated in
    the app (and ``get_policy`` raises ``ModelUnavailableError``).
    """
    choices: list[tuple[str, str]] = []
    for key, spec in MODEL_REGISTRY.items():
        if spec.is_demo or spec.available:
            choices.append((spec.label, key))
        elif include_unavailable:
            choices.append((f"{spec.label} — coming soon", key))
    return choices


def _load_model_and_tokenizer(source: str, revision: str | None) -> tuple[Any, Any]:
    """Lazy wrapper over the shared loader (keeps training imports off the hot path)."""
    try:
        from ..training.data_loading import load_model_and_tokenizer
    except ImportError:  # pragma: no cover - flat-layout fallback
        from training.data_loading import (  # type: ignore[no-redef]
            load_model_and_tokenizer,
        )

    # "auto" loads in the checkpoint's native dtype (bf16 for our Qwen3 models) —
    # halves memory vs the fp32 default, which matters on the ZeroGPU parent that
    # holds every preloaded model. Tool-call FORMAT parity is unaffected by dtype.
    return load_model_and_tokenizer(source, revision=revision, torch_dtype="auto")


def _maybe_to_cuda(model: Any) -> Any:
    """Place the model on CUDA — the GPU-placement step of the load path.

    On a ZeroGPU Space, ``import spaces`` (done first in ``app.py``) enables a
    CUDA emulation, so moving to ``cuda`` at load time is the REQUIRED pattern;
    the model is materialized on the real GPU inside ``@spaces.GPU``.
    ``cuda.is_available()`` is False on the startup container, so we must NOT
    guard on it. On a CPU-only box ``.to("cuda")`` raises, so we fall back to
    the model as loaded (CPU).

    Returns the model moved to CUDA, or the unmoved model if the move failed.
    """
    try:
        return model.to("cuda")
    except Exception as exc:  # CPU fallback when no CUDA is available (e.g. local dev)
        # Deliberate best-effort, but not silent: this branch catches ANY failure of
        # ``.to("cuda")``, not just "no CUDA here", so a genuine GPU problem would
        # otherwise degrade to CPU inference with no trace in the logs.
        logger.warning(
            "Could not move model to CUDA (%s: %s); keeping it on its current device.",
            type(exc).__name__,
            exc,
        )
        return model


def _ensure_loaded(key: str) -> tuple[Any, Any]:
    """Load ``(model, tokenizer)`` for ``key`` once; cache and return it.

    Raises ``KeyError`` if ``key`` is not in ``MODEL_REGISTRY``. Callers are
    responsible for excluding the demo key and unavailable models.
    """
    cached = _LOADED.get(key)
    if cached is not None:
        return cached
    # NOTE(review): check-then-load is not lock-guarded — two concurrent first calls
    # for the same key would both load the weights. Preloading at startup
    # (preload_available) sidesteps the race.
    spec = MODEL_REGISTRY[key]
    logger.info("Loading model %r from %s ...", key, spec.source)
    model, tokenizer = _load_model_and_tokenizer(spec.source, spec.revision)
    model = _maybe_to_cuda(model)
    model.eval()  # inference mode (disables dropout etc.)
    _LOADED[key] = (model, tokenizer)
    return _LOADED[key]


def get_policy(key: str) -> Policy:
    """Return a FRESH ``ModelPolicy`` for ``key`` (loads weights once, then caches).

    Cheap per call: ensures the shared ``(model, tokenizer)`` is loaded, then
    wraps it in a new ``ModelPolicy`` (fresh per-episode transcript).

    Raises:
        KeyError: ``key`` is not a registered model.
        ValueError: ``key`` is the demo entry (the app builds that policy itself).
        ModelUnavailableError: the model is registered but not yet published.
    """
    spec = MODEL_REGISTRY.get(key)
    if spec is None:
        raise KeyError(f"Unknown model key: {key!r}")
    if spec.is_demo:
        raise ValueError("the demo policy is built by the app, not serving.get_policy")
    if not spec.available:
        raise ModelUnavailableError(
            f"Model {key!r} ({spec.source}) is not published yet — "
            "push it to the Hub (F008) and set available=True."
        )
    model, tokenizer = _ensure_loaded(key)
    return ModelPolicy(model, tokenizer, enable_thinking=spec.enable_thinking)


def preload_available() -> list[str]:
    """Eager-load every available real model into the cache (call at Space startup).

    Runs in the PARENT process so a later per-ask ``@spaces.GPU`` fork inherits
    the loaded weights (ZeroGPU does not persist state created inside the GPU
    call). A failed preload is logged and skipped — that model simply loads
    lazily on first ask. Returns the keys successfully preloaded.
    """
    loaded: list[str] = []
    for key, spec in MODEL_REGISTRY.items():
        if spec.is_demo or not spec.available:
            continue
        try:
            _ensure_loaded(key)
            loaded.append(key)
        except Exception:  # never let a bad preload crash startup
            logger.warning(
                "Preload failed for %r (%s); it will load lazily on first ask.",
                key,
                spec.source,
                exc_info=True,
            )
    return loaded