"""HuggingFace transformers runtime — implements matter.engine.Runtime.

Loads Gemma 3n lazily on first inference (so cold Spaces serve the demo-mode path
without ever paying the load cost) and wraps inference in @spaces.GPU so the
Space's ZeroGPU pool only spins up while we're actually generating.

Picks Gemma 3n E2B (5B raw parameters, multimodal input, instruction-tuned) by
default. Override via the MATTER_MODEL_ID Space secret.
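
Typical usage (a minimal sketch; the image path is illustrative):

    runtime = TransformersRuntime()
    caption = runtime.infer("Describe this image.", Path("example.jpg"))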
"""

from __future__ import annotations

import os
import threading
from pathlib import Path
from typing import Literal

import torch
from PIL import Image

try:
    import spaces  # type: ignore
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

DEFAULT_MODEL_ID = os.environ.get("MATTER_MODEL_ID", "google/gemma-3n-E2B-it")
DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MATTER_MAX_NEW_TOKENS", "1024"))
DEFAULT_LORA_ID = os.environ.get("MATTER_LORA_ID", "").strip() or None

# Module-level init lock (must NOT be an instance attribute — `self` gets
# pickled across the ZeroGPU process boundary, and a threading.Lock can't be
# pickled). Modules are imported per-process, so this lock is per-process,
# which is exactly the granularity we want.
_LOAD_LOCK = threading.Lock()


def _gpu_decorator(fn):
    """No-op when running locally (no `spaces` module), real decorator on HF."""
    if HAS_SPACES:
        return spaces.GPU(duration=90)(fn)
    return fn


class TransformersRuntime:
    """Implements matter.engine.Runtime over HF transformers + Gemma 4."""

    # Passport schema's provenance.runtime enum doesn't include "transformers"
    # — report as "other" and surface the actual stack via model_id.
    name: Literal["other"] = "other"

    def __init__(
        self,
        model: str = DEFAULT_MODEL_ID,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
        lora_id: str | None = DEFAULT_LORA_ID,
    ):
        self.model_id = model
        self.lora_id = lora_id
        self.max_new_tokens = max_new_tokens
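        # Lazily populated by _ensure_loaded() on the first infer() call.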
        self._model = None
        self._processor = None

    def _ensure_loaded(self) -> None:
        # Fast path: already loaded, no lock needed.
        if self._model is not None:
            return
        # Module-level lock guards against concurrent first-call races. Two
        # users hitting a cold Space simultaneously could both enter
        # from_pretrained without this lock and double-allocate, OOM'ing CUDA.
        with _LOAD_LOCK:
            # Double-checked locking: another thread may have completed the
            # load while we were waiting for the lock.
            if self._model is not None:
                return
            from transformers import AutoModelForImageTextToText, AutoProcessor

            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            device = "cuda" if torch.cuda.is_available() else "cpu"

            processor = AutoProcessor.from_pretrained(self.model_id)
            model = AutoModelForImageTextToText.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                device_map=device,
            )
            if self.lora_id:
                try:
                    from peft import PeftModel
                    model = PeftModel.from_pretrained(model, self.lora_id)
                except Exception as e:
                    print(f"[TransformersRuntime] LoRA load failed ({self.lora_id}): {e}")
            model.eval()
            # Publish atomically — readers without the lock should never see a
            # half-initialized state.
            self._processor = processor
            self._model = model

    def infer(self, prompt: str, image: Path | None) -> str:
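        # Pass the image as a plain str: the arguments to _infer_gpu cross the
        # @spaces.GPU process boundary, so keep the payload primitive.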
        return self._infer_gpu(prompt, str(image) if image is not None else None)

    @_gpu_decorator
    def _infer_gpu(self, prompt: str, image_path: str | None) -> str:
        self._ensure_loaded()
        proc = self._processor
        model = self._model

        # Image first, then text — per the official google/gemma-3n-E2B-it usage.
        content: list[dict] = []
        if image_path:
            content.append({"type": "image", "image": Image.open(image_path).convert("RGB")})
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]

        inputs = proc.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
            )

        # Per Gemma 3n docs: decode with special tokens, then let the processor
        # parse them out cleanly via parse_response().
        raw = proc.decode(outputs[0][input_len:], skip_special_tokens=False)
        if hasattr(proc, "parse_response"):
            parsed = proc.parse_response(raw)
            if isinstance(parsed, str):
                return parsed
            if isinstance(parsed, dict) and "content" in parsed:
                return parsed["content"] if isinstance(parsed["content"], str) else str(parsed["content"])
            return str(parsed)
        return proc.decode(outputs[0][input_len:], skip_special_tokens=True)


__all__ = ["TransformersRuntime"]
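

# Local smoke test (a sketch, not part of the module's API). Assumes the
# default model fits on the local machine and the Hub is reachable; pass an
# image path as a second argument to exercise the multimodal path.
if __name__ == "__main__":
    import sys

    _prompt = sys.argv[1] if len(sys.argv) > 1 else "Say hello in one sentence."
    _image = Path(sys.argv[2]) if len(sys.argv) > 2 else None
    print(TransformersRuntime().infer(_prompt, _image))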