# analyst-buddy/server/serving.py
# Hugging Face repo "analyst-buddy" (owner: hjerpe), commit 6c50b87:
# "Deploy analyst-buddy (Gradio app + serving)" - 9.01 kB.
"""Model registry + cached loader for serving real models behind the app (F006).
The Gradio app injects its policy through ``build_demo(policy_factory=...)``,
called once per ask. This module is the real-model side of that seam: a small
registry of selectable models and a cached loader that returns a ready ``Policy``.
FORMAT PARITY (non-negotiable): serving reuses ``ModelPolicy`` verbatim, which
routes every formatting decision through the single source of truth
``server/tooling.py`` — the same system prompt (``get_system_prompt``), the same
tool JSON schema (``get_tool_definitions``), the same chat-template rendering, and
the same ``<tool_call>{...}</tool_call>`` parser (``parse_action``). There is NO
second parser and NO alternate prompt here. Each model's ``enable_thinking`` is set
from its training config so ``/no_think`` matches training (all current: False).
CACHING (the pitfall): the factory is called once PER ASK. We must NOT reload
multi-GB weights per question. ``_LOADED`` caches ``(model, tokenizer)`` per key
(loaded at most once per process); ``get_policy`` returns a FRESH ``ModelPolicy``
each call wrapping the shared model — fresh per-episode transcript, cheap to build.
ZeroGPU: the load happens in the PARENT process (``preload_available`` at startup,
or lazily on first ask) so a per-call ``@spaces.GPU`` fork inherits the cache;
``_maybe_to_cuda`` is the only torch touch in the load path (see the serve plan's
ZeroGPU fork-and-cache note).
Dep-light at import: torch/transformers are pulled only when a real model is
actually loaded (``_load_model_and_tokenizer`` / ``_maybe_to_cuda`` import them
lazily), so importing this module — and building the selector UI — stays headless.
"""
from __future__ import annotations
from dataclasses import dataclass
import logging
from typing import Any
# ModelPolicy / Policy live in the sibling ``evaluation`` package. Try the
# package-relative import first; when this file is imported outside its package
# (flat layout or direct run) the relative form raises ImportError and the
# absolute form is used instead.
try:  # package import (canonical) then flat-layout / direct-run fallback
    from ..evaluation.model_policy import ModelPolicy
    from ..evaluation.policies import Policy
except ImportError:  # pragma: no cover - flat-layout fallback
    from evaluation.model_policy import ModelPolicy  # type: ignore[no-redef]
    from evaluation.policies import Policy  # type: ignore[no-redef]
# Module-level logger, used for model-load progress and preload failures.
logger = logging.getLogger(__name__)
# Registry key of the no-model deterministic default. The app handles this entry
# itself (it owns _DemoScriptPolicy + the active_db_id plumbing); it is never routed
# through get_policy, which rejects it. Keeping it out of this module avoids a
# serving<->app_ui import cycle.
DEMO_KEY = "demo"
class ModelUnavailableError(RuntimeError):
    """A registered model is selected but not yet published/loadable.

    Raised by ``get_policy`` when the selected ``ModelSpec`` has
    ``available=False``.
    """
@dataclass(frozen=True)
class ModelSpec:
    """One selectable model (immutable registry entry).

    ``source`` is an HF repo id or a local dir (empty for the demo entry).
    ``enable_thinking`` MUST match how the model was trained (parity). ``available``
    is False for a model not yet on the Hub; flip it once F008 pushes it.
    """

    # Text shown for this entry in the model dropdown.
    label: str
    # HF repo id or local directory handed to the loader; "" for the demo entry.
    source: str
    # Forwarded to ModelPolicy; must mirror the model's training config.
    enable_thinking: bool = False
    # Optional revision forwarded to the loader; None leaves it unspecified.
    revision: str | None = None
    # True only for the scripted demo entry, which the app builds itself.
    is_demo: bool = False
    # False marks a registered-but-unpublished model (get_policy refuses it).
    available: bool = True
# Ordered: the demo is first (the default); the dropdown lists entries in this order.
# A model that is not on the Hub yet is registered with ``available=False`` and
# shown as "coming soon"; flipping it to ``available=True`` lights up the dropdown
# entry with no other change. The fine-tuned GRPO v2 model has been published, so it
# is available. All current models trained with enable_thinking=False
# (configs/modal_1p7b_fullft_v2.json: "enable_thinking": false), so /no_think matches.
MODEL_REGISTRY: dict[str, ModelSpec] = {
    DEMO_KEY: ModelSpec(
        label="Demo — scripted, no model (instant)",
        source="",
        is_demo=True,
    ),
    "qwen3-0.6b": ModelSpec(
        label="Qwen3-0.6B — vanilla (not fine-tuned)",
        source="Qwen/Qwen3-0.6B",
    ),
    "qwen3-1.7b": ModelSpec(
        label="Qwen3-1.7B — vanilla (not fine-tuned)",
        source="Qwen/Qwen3-1.7B",
    ),
    "sqlenv-1.7b-grpo-v2": ModelSpec(
        label="Qwen3-1.7B — fine-tuned on SQLEnv (GRPO v2)",
        source="hjerpe/sqlenv-qwen3-1.7b-grpo-v2",
        available=True,  # published 2026-06-14 to the Hub (public)
    ),
}
# Module-level cache: key -> (model, tokenizer). Populated lazily on first real request
# (or eagerly by preload_available). Never holds the demo key.
# Entries are written only by _ensure_loaded and are never evicted in this module.
_LOADED: dict[str, tuple[Any, Any]] = {}
def default_model_key() -> str:
    """Return the registry key the model dropdown starts on.

    This is always the scripted demo entry, which needs no model weights.
    """
    return DEMO_KEY
def get_spec(key: str) -> ModelSpec | None:
    """Look up the registry entry for ``key``.

    Returns the matching ``ModelSpec``, or ``None`` when ``key`` is not registered.
    """
    if key in MODEL_REGISTRY:
        return MODEL_REGISTRY[key]
    return None
def dropdown_choices(*, include_unavailable: bool = True) -> list[tuple[str, str]]:
    """Build the ``(label, key)`` pairs for ``gr.Dropdown(choices=...)``.

    Entries follow ``MODEL_REGISTRY`` order (the demo is registered first). The demo
    and every available model are listed under their plain label. Unavailable
    (not-yet-pushed) models are listed with a " — coming soon" suffix so the roadmap
    is visible; selecting one is gated in the app (and ``get_policy`` raises
    ``ModelUnavailableError``).

    Args:
        include_unavailable: When False, unavailable models are left out entirely.

    Returns:
        A new list of ``(label, key)`` tuples.
    """
    choices: list[tuple[str, str]] = []
    for key, spec in MODEL_REGISTRY.items():
        selectable = spec.is_demo or spec.available
        if selectable:
            choices.append((spec.label, key))
        elif include_unavailable:
            # Real em dash: this suffix is rendered to users in the dropdown.
            choices.append((f"{spec.label} — coming soon", key))
    return choices
def _load_model_and_tokenizer(source: str, revision: str | None) -> tuple[Any, Any]:
    """Load ``(model, tokenizer)`` for ``source`` through the shared training loader.

    The loader is imported inside the function so the training package (and its
    heavy dependencies) is only pulled in when a real model is actually loaded.
    """
    try:
        from ..training.data_loading import load_model_and_tokenizer as shared_loader
    except ImportError:  # pragma: no cover - flat-layout fallback
        from training.data_loading import (  # type: ignore[no-redef]
            load_model_and_tokenizer as shared_loader,
        )

    # "auto" loads in the checkpoint's native dtype (bf16 for our Qwen3 models),
    # which halves memory vs the fp32 default. That matters on the ZeroGPU parent,
    # which holds every preloaded model. Tool-call FORMAT parity is unaffected by dtype.
    load_options = {"revision": revision, "torch_dtype": "auto"}
    return shared_loader(source, **load_options)
def _maybe_to_cuda(model: Any) -> Any:
    """Place the model on CUDA — the GPU-placement step of the load path.

    On a ZeroGPU Space, ``import spaces`` (done first in ``app.py``) enables a CUDA
    emulation, so moving to ``cuda`` at load time is the REQUIRED pattern; the model
    is materialized on the real GPU inside ``@spaces.GPU``. ``cuda.is_available()`` is
    False on the startup container, so we must NOT guard on it. On a CPU-only box
    ``.to("cuda")`` raises, so we fall back to the model as loaded.

    Args:
        model: Any object exposing ``.to(device)``.

    Returns:
        Whatever ``model.to("cuda")`` returns, or ``model`` itself when the move
        raises.
    """
    try:
        return model.to("cuda")
    except Exception:  # deliberate best-effort: never fail the load over placement
        # Log instead of swallowing silently. Without CUDA this is the expected
        # local-dev path, but on a GPU host it is the only signal that inference
        # will run on the fallback device.
        logger.warning(
            "Could not move model to CUDA; keeping it on its current device.",
            exc_info=True,
        )
        return model
def _ensure_loaded(key: str) -> tuple[Any, Any]:
    """Return the cached ``(model, tokenizer)`` for ``key``, loading it on a miss.

    On a cache miss the weights are loaded from the spec's source, moved by
    ``_maybe_to_cuda``, switched to eval mode, and stored in ``_LOADED`` so later
    calls reuse the same objects.

    Raises:
        KeyError: if ``key`` is not in ``MODEL_REGISTRY``.
    """
    entry = _LOADED.get(key)
    if entry is None:
        spec = MODEL_REGISTRY[key]
        logger.info("Loading model %r from %s ...", key, spec.source)
        model, tokenizer = _load_model_and_tokenizer(spec.source, spec.revision)
        model = _maybe_to_cuda(model)
        model.eval()
        entry = (model, tokenizer)
        _LOADED[key] = entry
    return entry
def get_policy(key: str) -> Policy:
    """Return a FRESH ``ModelPolicy`` for ``key`` (loads weights once, then caches).

    Cheap per call: ensures the shared ``(model, tokenizer)`` is loaded, then wraps
    it in a new ``ModelPolicy`` (fresh per-episode transcript).

    Args:
        key: A ``MODEL_REGISTRY`` key naming a real, available model.

    Returns:
        A new ``ModelPolicy`` built with the spec's ``enable_thinking`` setting.

    Raises:
        KeyError: ``key`` is not registered.
        ValueError: ``key`` is the demo entry (that policy is built by the app).
        ModelUnavailableError: the model is registered but ``available`` is False.
    """
    spec = MODEL_REGISTRY.get(key)
    if spec is None:
        raise KeyError(f"Unknown model key: {key!r}")
    if spec.is_demo:
        raise ValueError("the demo policy is built by the app, not serving.get_policy")
    if not spec.available:
        # This message can reach logs or the UI, so it carries a real em dash.
        raise ModelUnavailableError(
            f"Model {key!r} ({spec.source}) is not published yet — "
            "push it to the Hub (F008) and set available=True."
        )
    model, tokenizer = _ensure_loaded(key)
    return ModelPolicy(model, tokenizer, enable_thinking=spec.enable_thinking)
def preload_available() -> list[str]:
    """Eager-load every available real model into the cache (call at Space startup).

    Meant to run in the PARENT process so a later per-ask ``@spaces.GPU`` fork
    inherits the loaded weights (ZeroGPU does not persist state created inside the
    GPU call). The demo entry and unavailable models are skipped. A model whose
    preload raises is logged and skipped; it simply loads lazily on first ask.

    Returns:
        The keys that were preloaded successfully, in registry order.
    """
    preloaded: list[str] = []
    for key, spec in MODEL_REGISTRY.items():
        if not spec.is_demo and spec.available:
            try:
                _ensure_loaded(key)
            except Exception:  # never let a bad preload crash startup
                logger.warning(
                    "Preload failed for %r (%s); it will load lazily on first ask.",
                    key,
                    spec.source,
                    exc_info=True,
                )
            else:
                preloaded.append(key)
    return preloaded