""" Gemma 4 E2B IT vision-language model loader. Requirements ------------ - transformers >= 5.5.0 (NOT compatible with the academic env) - HF_TOKEN env variable set to an account that has accepted the model licence at https://huggingface.co/google/gemma-4-E2B-it VRAM: ~2.8 GB at float16 — fits on 8 GB GPU alongside Depth + YOLO. Usage:: from src.models.gemma4 import Gemma4VLM vlm = Gemma4VLM() answer = vlm.query_vlm(pil_image, "Describe this scene.") """ from __future__ import annotations import os from typing import Optional import torch from PIL import Image from ..config import GEMMA4_ID, GEMMA4_MAX_NEW_TOKENS class Gemma4VLM: """Gemma 4 E2B IT vision-language model. Loads the model once and caches it. All inference runs under ``torch.inference_mode()`` for speed. Args: device: Target device string, e.g. ``"cuda"`` or ``"cpu"``. Defaults to ``"cuda"`` when available. dtype: Torch dtype for model weights. ``torch.float16`` uses ~2.8 GB on an 8 GB GPU. Switch to ``torch.bfloat16`` if your hardware supports it (Ampere/Ada/Hopper). """ def __init__( self, device: str = "cpu", dtype: torch.dtype = torch.float16, ) -> None: self.device = device if torch.cuda.is_available() else "cpu" self.dtype = dtype self.model = None self.processor = None self._load() # ── Model loading ───────────────────────────────────────────────────────── def _load(self) -> None: """Download (first run) and load Gemma 4 E2B IT. Raises: ImportError: If transformers < 5.5.0 is installed. OSError: If HF_TOKEN is not set and the model is still gated. """ try: from transformers import AutoProcessor # noqa: F401 — probe version import transformers as _tf ver = tuple(int(x) for x in _tf.__version__.split(".")[:2] if x.isdigit()) if ver < (5, 5): raise ImportError( f"Gemma 4 requires transformers >= 5.5.0, " f"found {_tf.__version__}. " "Create a new env: pip install transformers>=5.5.0" ) except ImportError as exc: raise ImportError(str(exc)) from exc from transformers import AutoProcessor, Gemma4ForConditionalGeneration # type: ignore[attr-defined] hf_token: Optional[str] = os.environ.get("HF_TOKEN") print(f"Loading {GEMMA4_ID} ({self.dtype}, device={self.device})...") self.processor = AutoProcessor.from_pretrained( GEMMA4_ID, token=hf_token, ) self.model = Gemma4ForConditionalGeneration.from_pretrained( GEMMA4_ID, device_map={"": self.device}, torch_dtype=self.dtype, token=hf_token, ) self.model.eval() if torch.cuda.is_available(): alloc_mb = torch.cuda.memory_allocated() / 1024 ** 2 print(f" GPU memory allocated: {alloc_mb:.0f} MB") # ── Inference ───────────────────────────────────────────────────────────── def query_vlm(self, image: Image.Image, question: str) -> str: """Query Gemma 4 with an image and a text prompt. Builds a single-turn user message, applies the chat template, runs ``model.generate()``, and strips the input tokens from the output before decoding. Args: image: PIL Image to analyse. question: Text question or prompt (may include the depth preamble). Returns: Generated answer text with input tokens stripped. """ messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": question}, ], } ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.device) n_input_tokens = inputs["input_ids"].shape[1] with torch.inference_mode(): output_ids = self.model.generate( **inputs, max_new_tokens=GEMMA4_MAX_NEW_TOKENS, do_sample=False, ) # Slice off the prompt tokens so we only decode the generated part. new_token_ids = output_ids[0][n_input_tokens:] return self.processor.decode(new_token_ids, skip_special_tokens=True).strip()