"""HuggingFace transformers runtime — implements matter.engine.Runtime.

Loads Gemma 4 lazily on first inference (so cold Spaces serve the demo-mode path
without ever paying the load cost) and wraps inference in @spaces.GPU so the
Space's ZeroGPU pool only spins up while we're actually generating.

Picks Gemma 4 E2B (5B, any-to-any, instruction-tuned) by default. Override via
the MATTER_MODEL_ID Space secret.
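
Minimal usage sketch (the image path here is illustrative):

    from pathlib import Path

    runtime = TransformersRuntime()
    print(runtime.infer("Describe this image.", Path("examples/photo.jpg")))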
"""
from __future__ import annotations

import os
import threading
from pathlib import Path
from typing import Literal

import torch
from PIL import Image

try:
    import spaces  # type: ignore

    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

DEFAULT_MODEL_ID = os.environ.get("MATTER_MODEL_ID", "google/gemma-4-E2B-it")
DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MATTER_MAX_NEW_TOKENS", "1024"))
DEFAULT_LORA_ID = os.environ.get("MATTER_LORA_ID", "").strip() or None

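# Illustrative Space configuration (the MATTER_LORA_ID value is a made-up
# example, not a real adapter):
#
#   MATTER_MODEL_ID=google/gemma-4-E2B-it
#   MATTER_MAX_NEW_TOKENS=512
#   MATTER_LORA_ID=your-org/your-adapter
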
# Module-level init lock (must NOT be an instance attribute — `self` gets
# pickled across the ZeroGPU process boundary, and threading.Lock can't
# pickle). Modules are imported per-process so this lock is per-process,
# which is exactly the granularity we want.
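# (Concretely: pickle.dumps(threading.Lock()) raises TypeError: cannot pickle
# '_thread.lock' object. That is standard CPython behavior, not anything
# ZeroGPU-specific.)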
_LOAD_LOCK = threading.Lock()


def _gpu_decorator(fn):
    """No-op when running locally (no `spaces` module), real decorator on HF."""
    if HAS_SPACES:
        return spaces.GPU(duration=90)(fn)
    return fn


class TransformersRuntime:
    """Implements matter.engine.Runtime over HF transformers + Gemma 4."""

    # Passport schema's provenance.runtime enum doesn't include "transformers"
    # — report as "other" and surface the actual stack via model_id.
    name: Literal["other"] = "other"

    def __init__(
        self,
        model: str = DEFAULT_MODEL_ID,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        lora_id: str | None = DEFAULT_LORA_ID,
    ):
        self.model_id = model
        self.lora_id = lora_id
        self.max_new_tokens = max_new_tokens
        self._model = None
        self._processor = None

    def _ensure_loaded(self) -> None:
        # Fast path: already loaded, no lock needed.
        if self._model is not None:
            return
        # Module-level lock guards against concurrent first-call races. Two
        # users hitting a cold Space simultaneously could both enter
        # from_pretrained without this lock and double-allocate, OOM'ing CUDA.
        with _LOAD_LOCK:
            # Double-checked locking: another thread may have completed the
            # load while we were waiting for the lock.
            if self._model is not None:
                return

            from transformers import AutoModelForImageTextToText, AutoProcessor

            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            device = "cuda" if torch.cuda.is_available() else "cpu"
            processor = AutoProcessor.from_pretrained(self.model_id)
            model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                device_map=device,
            )
            if self.lora_id:
                try:
                    from peft import PeftModel

                    model = PeftModel.from_pretrained(model, self.lora_id)
                except Exception as e:
                    print(f"[TransformersRuntime] LoRA load failed ({self.lora_id}): {e}")
            model.eval()
            # Publish atomically — readers without the lock should never see a
            # half-initialized state.
            self._processor = processor
            self._model = model

    def infer(self, prompt: str, image: Path | None) -> str:
        # Public entrypoint; hands off to the @spaces.GPU-wrapped worker with
        # the image path normalized to a plain str (or None).
        return self._infer_gpu(prompt, str(image) if image is not None else None)

    @_gpu_decorator
    def _infer_gpu(self, prompt: str, image_path: str | None) -> str:
        self._ensure_loaded()
        proc = self._processor
        model = self._model
        # Image first, then text — per the official google/gemma-4-E2B-it usage.
        content: list[dict] = []
        if image_path:
            content.append({"type": "image", "image": Image.open(image_path).convert("RGB")})
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        inputs = proc.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
            )
        # Per Gemma 4 docs: decode with special tokens, then let the processor
        # parse them out cleanly via parse_response().
        raw = proc.decode(outputs[0][input_len:], skip_special_tokens=False)
        if hasattr(proc, "parse_response"):
            parsed = proc.parse_response(raw)
            if isinstance(parsed, str):
                return parsed
            if isinstance(parsed, dict) and "content" in parsed:
                return parsed["content"] if isinstance(parsed["content"], str) else str(parsed["content"])
            return str(parsed)
        return proc.decode(outputs[0][input_len:], skip_special_tokens=True)


__all__ = ["TransformersRuntime"]
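
if __name__ == "__main__":
    # Ad-hoc local smoke test; a sketch, not part of the Space entrypoint.
    # The first call downloads and loads the model, so expect it to be slow.
    _rt = TransformersRuntime(max_new_tokens=32)
    print(_rt.infer("Reply with one short sentence.", None))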