"""HuggingFace transformers runtime — implements matter.engine.Runtime. Loads Gemma 4 lazily on first inference (so cold Spaces serve the demo-mode path without ever paying the load cost) and wraps inference in @spaces.GPU so the Space's ZeroGPU pool only spins up while we're actually generating. Picks Gemma 4 E2B (5B, any-to-any, instruction-tuned) by default. Override via the MATTER_MODEL_ID Space secret. """ from __future__ import annotations import os import threading from pathlib import Path from typing import Literal import torch from PIL import Image try: import spaces # type: ignore HAS_SPACES = True except ImportError: HAS_SPACES = False DEFAULT_MODEL_ID = os.environ.get("MATTER_MODEL_ID", "google/gemma-4-E2B-it") DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MATTER_MAX_NEW_TOKENS", "1024")) DEFAULT_LORA_ID = os.environ.get("MATTER_LORA_ID", "").strip() or None # Module-level init lock (must NOT be an instance attribute — `self` gets # pickled across the ZeroGPU process boundary, and threading.Lock can't # pickle). Modules are imported per-process so this lock is per-process, # which is exactly the granularity we want. _LOAD_LOCK = threading.Lock() def _gpu_decorator(fn): """No-op when running locally (no `spaces` module), real decorator on HF.""" if HAS_SPACES: return spaces.GPU(duration=90)(fn) return fn class TransformersRuntime: """Implements matter.engine.Runtime over HF transformers + Gemma 4.""" # Passport schema's provenance.runtime enum doesn't include "transformers" # — report as "other" and surface the actual stack via model_id. name: Literal["other"] = "other" def __init__( self, model: str = DEFAULT_MODEL_ID, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, lora_id: str | None = DEFAULT_LORA_ID, ): self.model_id = model self.lora_id = lora_id self.max_new_tokens = max_new_tokens self._model = None self._processor = None def _ensure_loaded(self) -> None: # Fast path: already loaded, no lock needed. if self._model is not None: return # Module-level lock guards against concurrent first-call races. Two # users hitting a cold Space simultaneously could both enter # from_pretrained without this lock and double-allocate, OOM'ing CUDA. with _LOAD_LOCK: # Double-checked locking: another thread may have completed the # load while we were waiting for the lock. if self._model is not None: return from transformers import AutoModelForImageTextToText, AutoProcessor dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 device = "cuda" if torch.cuda.is_available() else "cpu" processor = AutoProcessor.from_pretrained(self.model_id) model = AutoModelForImageTextToText.from_pretrained( self.model_id, torch_dtype=dtype, device_map=device, ) if self.lora_id: try: from peft import PeftModel model = PeftModel.from_pretrained(model, self.lora_id) except Exception as e: print(f"[TransformersRuntime] LoRA load failed ({self.lora_id}): {e}") model.eval() # Publish atomically — readers without the lock should never see a # half-initialized state. self._processor = processor self._model = model def infer(self, prompt: str, image: Path | None) -> str: return self._infer_gpu(prompt, str(image) if image is not None else None) @_gpu_decorator def _infer_gpu(self, prompt: str, image_path: str | None) -> str: self._ensure_loaded() proc = self._processor model = self._model # Image first, then text — per the official google/gemma-4-E2B-it usage. 
        content: list[dict] = []
        if image_path:
            content.append(
                {"type": "image", "image": Image.open(image_path).convert("RGB")}
            )
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        inputs = proc.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
            )
        # Decode with special tokens intact, then let the processor strip
        # them via parse_response() where that API exists; otherwise fall
        # back to a plain skip-special-tokens decode.
        raw = proc.decode(outputs[0][input_len:], skip_special_tokens=False)
        if hasattr(proc, "parse_response"):
            parsed = proc.parse_response(raw)
            if isinstance(parsed, str):
                return parsed
            if isinstance(parsed, dict) and "content" in parsed:
                return str(parsed["content"])
            return str(parsed)
        return proc.decode(outputs[0][input_len:], skip_special_tokens=True)


__all__ = ["TransformersRuntime"]
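
# For reference, a minimal sketch of the interface this class satisfies,
# ASSUMING matter.engine.Runtime is (or is duck-type compatible with) a
# typing.Protocol exposing the two members used above. The canonical
# definition lives in matter.engine and may carry more; the name
# `RuntimeSketch` is ours, kept distinct so it can't shadow the real one.
from typing import Protocol


class RuntimeSketch(Protocol):
    # A value from the passport schema's provenance.runtime enum ("other" here).
    name: str

    def infer(self, prompt: str, image: Path | None) -> str: ...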
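
# Hand-run smoke test, a sketch only: it assumes the current machine can hold
# the checkpoint (or that MATTER_MODEL_ID points at something small) and that
# an optional second argv entry is a readable image path. Outside a Space
# there is no `spaces` module, so _gpu_decorator is a no-op and this runs on
# whatever device torch finds.
if __name__ == "__main__":
    import sys

    runtime = TransformersRuntime()
    prompt_arg = sys.argv[1] if len(sys.argv) > 1 else "Describe this material."
    image_arg = Path(sys.argv[2]) if len(sys.argv) > 2 else None
    print(runtime.infer(prompt_arg, image_arg))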