multi-agent-lab / src /models /local_provider.py
agharsallah
feat(media): introduce MediaRouter and stubs for image and speech generation
8400d8c
Raw
History Blame Contribute Delete
18.3 kB
"""In-process transformers provider β€” the local-GPU transport for the ``local`` backend.
This is the *serving* side of the local backend whose catalogue lives in
:mod:`src.models.local_catalogue`. Where :class:`~src.models.litellm_provider.LiteLLMProvider`
calls a model over an OpenAI-compatible HTTP endpoint, this provider runs a small
``transformers`` model **in the same process, on the host's own GPU**, behind a
``@spaces.GPU`` function — so a Hugging Face Space serves the cast on its own hardware
with no endpoint to deploy and no token to hold.
It is hardware-agnostic (ADR-0033). ``@spaces.GPU`` is **effect-free off ZeroGPU**, so the
one decorated ``_generate`` covers every flavour:
* **ZeroGPU** — the decorator allocates a GPU for the call and releases it after.
* **Dedicated GPU / local CUDA** — the decorator is a passthrough; the model runs on
  the persistent GPU.
**Two phases, split across the ZeroGPU fork.** ZeroGPU grants a real GPU *only* for the
duration of a ``@spaces.GPU`` call (each call runs in a forked worker); the parent process
never gets one, and any low-level CUDA init outside such a call kills the process. So the
work is split:
* **Parent — download only, never CUDA** (:func:`_ensure_downloaded`): fetch the repo's
  weights to the on-disk HF cache with ``snapshot_download``. This pays the network cost
  once, in the resilient parent, so the short GPU window never spends its budget pulling
  gigabytes. It deliberately does **not** materialise the model in host RAM — a cast of
  four 8–12B models would otherwise pin ~60GB of parent RAM for the whole show.
* **Worker — load straight onto the GPU** (:func:`_ensure_loaded_on_device`): inside the
  granted window, ``from_pretrained(device_map={"": 0}, local_files_only=True)`` lets
  transformers + accelerate **materialise and place** every weight, tied head, and
  non-persistent buffer directly on the device in one atomic step, then caches the
  device-resident model per repo (a reused worker — and any dedicated GPU — keeps it
  resident across calls).
**Why ``device_map`` and not a manual ``.to("cuda")``.** transformers 5.x always builds the
model on the ``meta`` device and streams the checkpoint onto the target. A bare
``from_pretrained(...).to("cuda")`` leaves a model whose non-persistent buffers (e.g. a
rotary ``inv_freq``) or a tied/"missing" head can still sit on ``meta``, and the later
``.to("cuda")`` then dies with *"Cannot copy out of meta tensor; no data!"*
(transformers#41038/#30703) — and ``low_cpu_mem_usage`` no longer changes this (5.x drops
the kwarg outright). Handing transformers the device via ``device_map`` is the supported
path: ``_move_missing_keys_from_meta_to_device`` places the buffers and missing keys on the
mapped device and ``initialize_weights``/``tie_weights`` run there, so **nothing is ever
left on meta** and there is no fragile post-hoc move. This needs ``accelerate`` (a declared
dep); the kwarg-only fallbacks keep older transformers working.
Heavy imports (``torch`` / ``transformers``) are lazy — confined to the functions that
need them — so importing this module never initialises CUDA (which would trip ZeroGPU's
fork guard) and the offline path never pays for them. ``spaces`` itself is import-safe
everywhere. ``complete`` returns the failure sentinel on any error (never raises), exactly
like the HTTP provider, so the conductor's resilient loop treats a local-inference hiccup
the same as a flaky endpoint.
"""
from __future__ import annotations
from dataclasses import dataclass, field
import spaces # import-safe everywhere (effect-free off ZeroGPU); needed for @spaces.GPU
from src import observability as obs
from src.models.openai_compat import OpenAICompatProvider
from src.models.provider import ModelProvider, estimate_tokens, model_error
# Device-resident models, keyed by repo id: ``repo_id -> (tokenizer, model)``. Populated in
# the forked ``@spaces.GPU`` worker by :func:`_ensure_loaded_on_device`, so a reused worker
# (and any dedicated GPU) keeps the model resident across calls. Module-level so the cache
# survives across provider instances and across ticks of a show.
# :func:`_ensure_loaded_on_device` consults this mapping first and returns the stored pair
# as-is, so a repo already present here is never reloaded.
_LOADED: dict[str, tuple] = {}
# Repo ids whose weights have been fetched to the on-disk HF cache by :func:`_ensure_downloaded`
# in the *parent*. A set, not a model cache — the parent holds no weights in RAM (see the
# module docstring); it only records "this repo is on disk" so we skip the network re-check.
_DOWNLOADED: set[str] = set()
def _always_true(*_args, **_kwargs) -> bool:
return True
# v4-era capability predicates that transformers 5.x removed but Hub ``trust_remote_code``
# modelling files still import (e.g. MiniCPM's modeling_minicpm.py does
# ``from transformers.utils.import_utils import is_torch_fx_available``). All of these are
# unconditionally True at this project's torch>=2.8 floor — exactly the value the
# transformers maintainers say is now correct (transformers#44561) — so back-filling them
# lets such remote code import instead of crashing with ``cannot import name '…'``.
_REMOVED_TORCH_PREDICATES = ("is_torch_fx_available", "is_torch_sdpa_available")
def _ensure_transformers_v4_symbols() -> None:
    """Back-fill removed v4-era predicates on the ``transformers`` utility modules.

    For every name in :data:`_REMOVED_TORCH_PREDICATES`, both
    ``transformers.utils.import_utils`` and ``transformers.utils`` are checked; a name
    that is absent is set to :func:`_always_true` so older Hub remote code (loaded via
    ``trust_remote_code``) can still import it. A name the module already has is left
    alone, which makes the call idempotent and keeps it from shadowing anything
    transformers still ships. If ``transformers`` cannot be imported the function
    returns without doing anything.
    """
    try:
        import transformers.utils as hf_utils
        from transformers.utils import import_utils
    except Exception:  # pragma: no cover - transformers absent → offline path, nothing to do
        return

    # Lazy (module, symbol) pairs: the existence check and the patch stay interleaved,
    # visiting import_utils first and the package namespace second.
    patch_points = (
        (module, symbol)
        for module in (import_utils, hf_utils)
        for symbol in _REMOVED_TORCH_PREDICATES
    )
    for module, symbol in patch_points:
        if hasattr(module, symbol):
            continue  # already provided — never overwrite it
        setattr(module, symbol, _always_true)
def _ensure_downloaded(repo_id: str, trust_remote_code: bool) -> None:
    """Make sure *repo_id* is present in the on-disk HF cache, **in the parent**, without CUDA.

    Invoked by :meth:`LocalTransformersProvider.complete` before the ``@spaces.GPU`` call,
    so the network cost is paid once up front and the scarce GPU window loads from a warm
    local cache instead of downloading. The function only calls ``snapshot_download``: it
    **never touches CUDA** (under ZeroGPU the parent gets no GPU) and **never materialises
    the model** in host RAM, which for a multi-model cast would otherwise pin tens of GB of
    parent RAM for the whole show.

    Repo ids that were fetched successfully are remembered in :data:`_DOWNLOADED`, so
    repeated turns skip the network revision check entirely.

    :param repo_id: Hub repository id to fetch.
    :param trust_remote_code: accepted so the signature parallels
        :func:`_ensure_loaded_on_device`; the body does not read it (``snapshot_download``
        is called for the whole repo, so any modelling ``.py`` files come along with the
        weights).
    :raises Exception: whatever ``snapshot_download`` raises propagates to
        :meth:`complete`, which turns it into the resilient failure sentinel — so a GPU
        window is never spent on a model whose weights are missing. A failed download is
        not recorded in :data:`_DOWNLOADED` and is retried on the next call.
    """
    already_on_disk = repo_id in _DOWNLOADED
    if not already_on_disk:
        # Lazy import: keeps module import cheap for the offline path.
        from huggingface_hub import snapshot_download

        # No token is passed explicitly — presumably snapshot_download picks up the same
        # credentials transformers does (HF_TOKEN in the Space environment, needed for
        # gated repos like Aya). Once cached this is little more than a revision check.
        snapshot_download(repo_id)
        _DOWNLOADED.add(repo_id)
def _ensure_loaded_on_device(repo_id: str, trust_remote_code: bool) -> tuple:
    """Return the ``(tokenizer, model)`` pair for *repo_id*, loading it onto the GPU on first use.

    Meant to run inside the decorated :func:`_generate`, where ZeroGPU has granted a real
    device. The first call for a repo loads the tokenizer and model and stores the pair in
    :data:`_LOADED`; later calls return the stored pair untouched.

    Loading details:

    * ``device_map={"": 0}`` hands the placement to transformers so it **materialises and
      places** every weight, tied head and non-persistent buffer on the device in one step
      — the supported path that leaves nothing on the ``meta`` device for a later move to
      choke on.
    * ``local_files_only=True`` keeps the GPU window off the network: the parent is
      expected to have fetched the repo already (:func:`_ensure_downloaded`), so a missing
      file fails fast here rather than burning the budget on a download.
    * ``dtype="auto"`` keeps the checkpoint's native precision; on a ``TypeError`` the load
      is retried with the legacy ``torch_dtype`` kwarg name for older transformers.
    * When ``torch.cuda.is_available()`` is false (a misconfigured call — :func:`_generate`
      is normally gated behind a device) the model is mapped to the CPU so the provider
      still answers rather than crashing.

    :param repo_id: Hub repository id; also the cache key.
    :param trust_remote_code: forwarded to both ``from_pretrained`` calls.
    :returns: the cached ``(tokenizer, model)`` tuple, with the model in eval mode.
    """
    try:
        return _LOADED[repo_id]
    except KeyError:
        pass  # not resident yet — fall through to the one-time load

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Back-fill v4-era symbols removed in transformers 5.x before any trust_remote_code
    # modelling file is imported (tokenizer or model), or it crashes at import time.
    _ensure_transformers_v4_symbols()

    hub_kwargs = {"trust_remote_code": trust_remote_code, "local_files_only": True}
    tokenizer = AutoTokenizer.from_pretrained(repo_id, **hub_kwargs)

    # The granted GPU is device 0; CPU is the degenerate off-GPU fallback.
    if torch.cuda.is_available():
        placement = {"": 0}
    else:
        placement = {"": "cpu"}
    model_kwargs = {"device_map": placement, **hub_kwargs}

    try:
        model = AutoModelForCausalLM.from_pretrained(repo_id, dtype="auto", **model_kwargs)
    except TypeError:  # pragma: no cover - older transformers use the torch_dtype kwarg name
        model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype="auto", **model_kwargs)
    model.eval()

    resident_pair = (tokenizer, model)
    _LOADED[repo_id] = resident_pair
    return resident_pair
def _gpu_duration(repo_id, trust_remote_code, use_cache, system, prompt, max_new_tokens, temperature, top_p) -> int:
"""Dynamic ``@spaces.GPU`` duration (seconds) for one generation.
Scales with the token budget and stays bounded so the Space keeps reasonable queue
priority on ZeroGPU (shorter declared durations are prioritised). The base covers a cold
device load: the first call for a model in a freshly forked worker materialises the
weights onto the GPU (from the parent-warmed disk cache), and that must finish inside the
granted window. Subsequent calls hit the resident cache and use only the forward-pass tail.
"""
return min(120, 60 + int(max_new_tokens) // 4)
@spaces.GPU(duration=_gpu_duration)
def _generate(repo_id, trust_remote_code, use_cache, system, prompt, max_new_tokens, temperature, top_p):
    """Run one chat completion and return ``(text, prompt_tokens, completion_tokens)``.

    Kept at module level and decorated with ``@spaces.GPU`` so ZeroGPU registers it and
    grants a GPU for the call; the declared duration comes from :func:`_gpu_duration`.
    The model is obtained from :func:`_ensure_loaded_on_device` (loaded on first use,
    returned from the cache afterwards) and the inputs are moved to whatever device the
    model's parameters live on, so there is no separate device move of the model here.

    :param repo_id: Hub repository id of the model to run.
    :param trust_remote_code: forwarded to the loader.
    :param use_cache: whether ``generate`` should use the KV cache (False for repos whose
        custom code mishandles the 5.x cache).
    :param system: content of the system message.
    :param prompt: content of the user message.
    :param max_new_tokens: generation budget, coerced with ``int``.
    :param temperature: sampling temperature; a falsy or non-positive value selects
        greedy decoding, in which case ``temperature`` and ``top_p`` are passed as ``None``.
    :param top_p: nucleus-sampling threshold, used only when sampling.
    :returns: the stripped decoded completion, the prompt length in tokens, and the number
        of newly generated tokens.
    """
    import torch

    tokenizer, model = _ensure_loaded_on_device(repo_id, trust_remote_code)
    target_device = next(model.parameters()).device

    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]

    # return_dict=True: ask explicitly for a BatchEncoding (input_ids + attention_mask)
    # rather than relying on the version-dependent default, so the ** unpacking below feeds
    # generate() both the ids AND the attention mask on every transformers version.
    # enable_thinking=False: a reasoning model (e.g. MiniCPM5) otherwise opens a <think>
    # block and can spend the whole token budget reasoning, leaving an empty spoken line
    # once the engine strips the trace (src/core/structured.py). The kwarg is forwarded to
    # the chat template and harmlessly ignored by non-reasoning templates.
    encoded = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        enable_thinking=False,
    )
    encoded = encoded.to(target_device)
    prompt_len = int(encoded["input_ids"].shape[-1])

    # Sample only for a strictly positive temperature; otherwise decode greedily and leave
    # the sampling knobs unset.
    sampling = bool(temperature and float(temperature) > 0)

    # Fall back to EOS as the pad token when the tokenizer defines none.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    generation_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sampling,
        "temperature": float(temperature) if sampling else None,
        "top_p": float(top_p) if sampling else None,
        # Per-model: False for repos whose custom code mishandles the 5.x KV cache.
        "use_cache": bool(use_cache),
        "pad_token_id": pad_id,
    }
    with torch.no_grad():
        sequences = model.generate(**encoded, **generation_kwargs)

    # The returned sequence starts with the prompt; keep only what was newly generated.
    completion_ids = sequences[0][prompt_len:]
    text = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return text, prompt_len, int(completion_ids.shape[-1])
@dataclass
class LocalTransformersProvider(ModelProvider):
    """Serve one logical profile by running a ``transformers`` model on the host GPU.

    ``model`` is the bare ``transformers`` repo id (e.g. ``"openbmb/MiniCPM5-1B"``) —
    the same string :func:`src.models.local_catalogue.binding_for` returns. The decoding
    settings (``temperature`` / ``top_p`` / ``max_tokens``) are plain fields supplied by
    the router's per-profile spec. ``trust_remote_code`` and the KV-cache switch are not
    fields: they are looked up in the catalogue by repo id each time they are needed.
    """

    # Repo id of the model to run; also the catalogue lookup key.
    model: str
    # Decoding settings forwarded to the generation call.
    temperature: float = 0.7
    top_p: float = 0.95
    max_tokens: int = 256
    # Token counts of the most recent call (all zeros after a failure); not an init arg.
    _last_usage: dict = field(default_factory=dict, init=False, repr=False)

    def complete(self, role: str, prompt: str) -> str:
        """Generate a reply to *prompt* for *role*; never raises.

        Inside an ``llm.call`` span the method first makes sure the weights are on disk
        (in this process), then runs the GPU generation, records token usage and emits
        telemetry. Any exception raised along the way is caught: usage is zeroed, an
        ``llm.error`` warning is logged, and the value of ``model_error(exc)`` is returned
        in place of the text.

        :param role: logical role; selects the system prompt and is attached to telemetry.
        :param prompt: user prompt passed to the model.
        :returns: the generated text, or the failure sentinel from ``model_error``.
        """
        with obs.span(
            "llm.call",
            **{
                "gen_ai.system": "transformers-local",
                "gen_ai.request.model": self.model,
                "gen_ai.request.temperature": self.temperature,
                "gen_ai.request.max_tokens": self.max_tokens,
                "mal.role": role,
            },
        ):
            try:
                # Fetch the weights to disk in the PARENT (no CUDA, no RAM materialise) so
                # the forked @spaces.GPU call below loads from a warm cache.
                _ensure_downloaded(self.model, self._trust_remote_code())
                system_prompt = OpenAICompatProvider._system_for_role(role)
                outcome = _generate(
                    self.model,
                    self._trust_remote_code(),
                    self._use_cache(),
                    system_prompt,
                    prompt,
                    self.max_tokens,
                    self.temperature,
                    self.top_p,
                )
                text, n_prompt, n_completion = outcome
                self._record_usage(n_prompt, n_completion, prompt, text)
                self._emit_telemetry(role, prompt, text)
            except Exception as exc:
                # Resilience boundary: report the failure as a value, not an exception.
                self._zero_usage()
                obs.log("llm.error", level="warning", model=self.model, role=role, error=str(exc))
                return model_error(exc)
            return text

    # ── internals ───────────────────────────────────────────────────────────────
    def _trust_remote_code(self) -> bool:
        """Report whether the catalogue marks this repo as needing custom modelling code.

        The lookup is by repo id. An id the catalogue does not know (a hand-pinned repo)
        yields ``False`` — the safe choice, and the Lab only ever offers catalogue models.
        """
        from src.models import local_catalogue

        entry = local_catalogue.model_by_key(self.model)
        if entry is None:
            return False
        return bool(entry.trust_remote_code)

    def _use_cache(self) -> bool:
        """Report whether generation should use the KV cache for this repo.

        Taken from the catalogue entry's ``use_cache`` flag, which can be set False for a
        custom-code repo whose attention mishandles transformers 5.x's cache API. An id
        the catalogue does not know keeps the cache on (``True``, the fast path).
        """
        from src.models import local_catalogue

        entry = local_catalogue.model_by_key(self.model)
        if entry is None:
            return True
        return bool(entry.use_cache)

    def _record_usage(self, prompt_tokens: int, completion_tokens: int, prompt: str, text: str) -> None:
        """Store the token counts of the latest call in ``_last_usage``.

        Generation returns exact counts; a count that comes back as zero (e.g. an empty
        decode) is replaced by ``estimate_tokens`` of the corresponding text, so the
        Governor always sees a budget hit.
        """
        counted_in = int(prompt_tokens)
        if not counted_in:
            counted_in = estimate_tokens(prompt)
        counted_out = int(completion_tokens)
        if not counted_out:
            counted_out = estimate_tokens(text)
        self._last_usage = {
            "prompt_tokens": counted_in,
            "completion_tokens": counted_out,
            "total_tokens": counted_in + counted_out,
        }

    def _zero_usage(self) -> None:
        """Reset ``_last_usage`` to all-zero counts (used when a call fails)."""
        self._last_usage = dict.fromkeys(("prompt_tokens", "completion_tokens", "total_tokens"), 0)

    def _emit_telemetry(self, role: str, prompt: str, text: str) -> None:
        """Publish the latest call's usage to the span, the metrics recorder and the log.

        Reads the counts from ``_last_usage`` (treating missing or falsy values as 0) and
        reports a cost of ``0.0`` everywhere: local inference has no per-call price.
        The full prompt and completion go on the span and into a debug-level log event.
        """
        usage = self._last_usage or {}
        tokens_in = int(usage.get("prompt_tokens", 0) or 0)
        tokens_out = int(usage.get("completion_tokens", 0) or 0)

        obs.add_span_attrs(
            **{
                "gen_ai.usage.input_tokens": tokens_in,
                "gen_ai.usage.output_tokens": tokens_out,
                "llm.cost_usd": 0.0,  # local inference has no per-call price (GPU is the cost)
                "llm.structured": False,
                "llm.prompt": prompt,
                "llm.completion": text,
            }
        )

        # Same three figures feed both the metrics recorder and the summary log line.
        figures = {"prompt_tokens": tokens_in, "completion_tokens": tokens_out, "cost_usd": 0.0}
        obs.record_llm_call(self.model, **figures)
        obs.log("llm.call", role=role, model=self.model, structured=False, **figures)
        obs.log("llm.exchange", level="debug", role=role, model=self.model, prompt=prompt, completion=text)