Spaces:
Sleeping
Sleeping
| """ | |
| Gemma 4 E2B IT vision-language model loader. | |
| Requirements | |
| ------------ | |
| - transformers >= 5.5.0 (NOT compatible with the academic env) | |
| - HF_TOKEN env variable set to an account that has accepted the | |
| model licence at https://huggingface.co/google/gemma-4-E2B-it | |
| VRAM: ~2.8 GB at float16 — fits on 8 GB GPU alongside Depth + YOLO. | |
| Usage:: | |
| from src.models.gemma4 import Gemma4VLM | |
| vlm = Gemma4VLM() | |
| answer = vlm.query_vlm(pil_image, "Describe this scene.") | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from typing import Optional | |
| import torch | |
| from PIL import Image | |
| from ..config import GEMMA4_ID, GEMMA4_MAX_NEW_TOKENS | |
class Gemma4VLM:
    """Gemma 4 E2B IT vision-language model.

    The processor and model are loaded once, at construction time, and kept
    on the instance. Generation runs under ``torch.inference_mode()``.

    Args:
        device: Target device string, e.g. ``"cuda"`` or ``"cpu"``.
            Defaults to ``"cpu"``. A CUDA device is replaced by ``"cpu"``
            when CUDA is not available; any other device is used as given.
        dtype: Torch dtype for model weights. ``torch.float16`` uses
            ~2.8 GB on an 8 GB GPU. Switch to ``torch.bfloat16`` if
            your hardware supports it (Ampere/Ada/Hopper).
    """

    # Oldest (major, minor) transformers release accepted by ``_load``.
    _MIN_TRANSFORMERS = (5, 5)

    def __init__(
        self,
        device: str = "cpu",
        dtype: torch.dtype = torch.float16,
    ) -> None:
        self.device = self._resolve_device(device)
        self.dtype = dtype
        self.model = None
        self.processor = None
        self._load()

    # ── Helpers ───────────────────────────────────────────────────────────────

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Return *device*, or ``"cpu"`` if it names CUDA and CUDA is missing.

        Only CUDA requests are downgraded; devices such as ``"mps"`` are
        passed through untouched instead of being silently replaced.
        """
        wants_cuda = torch.device(device).type == "cuda"
        if wants_cuda and not torch.cuda.is_available():
            return "cpu"
        return device

    @staticmethod
    def _parse_major_minor(version: str) -> tuple[int, int]:
        """Extract ``(major, minor)`` from a version string.

        Each of the first two dot-separated components contributes its
        leading run of ASCII digits, so suffixes such as ``rc1`` or ``dev0``
        do not shorten the result. A missing or non-numeric component
        counts as ``0``.
        """
        numbers = []
        for component in version.split(".")[:2]:
            end = 0
            while end < len(component) and component[end] in "0123456789":
                end += 1
            numbers.append(int(component[:end]) if end else 0)
        # Pad so that a bare "5" compares as (5, 0) rather than (5,).
        numbers += [0] * (2 - len(numbers))
        return numbers[0], numbers[1]

    # ── Model loading ─────────────────────────────────────────────────────────

    def _load(self) -> None:
        """Download (first run) and load Gemma 4 E2B IT.

        Populates ``self.processor`` and ``self.model`` and puts the model
        in eval mode.

        Raises:
            ImportError: If transformers is not installed or is older
                than 5.5.0.
            OSError: If HF_TOKEN is not set and the model is still gated
                (presumably raised by ``from_pretrained``).
        """
        # Imported lazily so that a missing or outdated transformers install
        # surfaces as an ImportError from here, not at module import time.
        import transformers as _tf

        if self._parse_major_minor(_tf.__version__) < self._MIN_TRANSFORMERS:
            raise ImportError(
                f"Gemma 4 requires transformers >= 5.5.0, "
                f"found {_tf.__version__}. "
                # Quoted so the shell does not treat ">=" as a redirect.
                'Create a new env: pip install "transformers>=5.5.0"'
            )

        from transformers import AutoProcessor, Gemma4ForConditionalGeneration  # type: ignore[attr-defined]

        # An empty HF_TOKEN is treated the same as an unset one.
        hf_token: Optional[str] = os.environ.get("HF_TOKEN") or None

        print(f"Loading {GEMMA4_ID} ({self.dtype}, device={self.device})...")
        self.processor = AutoProcessor.from_pretrained(
            GEMMA4_ID,
            token=hf_token,
        )
        self.model = Gemma4ForConditionalGeneration.from_pretrained(
            GEMMA4_ID,
            # The empty key maps the whole model onto one device.
            device_map={"": self.device},
            # ``dtype`` replaces the deprecated ``torch_dtype`` keyword.
            dtype=self.dtype,
            token=hf_token,
        )
        self.model.eval()

        # Report memory only when the model was actually placed on a GPU,
        # and for that GPU rather than whichever one is current.
        target = torch.device(self.device)
        if target.type == "cuda":
            alloc_mb = torch.cuda.memory_allocated(target) / 1024 ** 2
            print(f" GPU memory allocated: {alloc_mb:.0f} MB")

    # ── Inference ─────────────────────────────────────────────────────────────

    def query_vlm(self, image: Image.Image, question: str) -> str:
        """Query Gemma 4 with an image and a text prompt.

        Builds a single-turn user message, applies the chat template,
        runs ``model.generate()`` with greedy decoding, and strips the
        input tokens from the output before decoding.

        Args:
            image: PIL Image to analyse.
            question: Text question or prompt (may include the depth preamble).

        Returns:
            Generated answer text with input tokens stripped and surrounding
            whitespace removed.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device)
        n_input_tokens = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=GEMMA4_MAX_NEW_TOKENS,
                do_sample=False,
            )

        # Slice off the prompt tokens so we only decode the generated part.
        new_token_ids = output_ids[0][n_input_tokens:]
        return self.processor.decode(new_token_ids, skip_special_tokens=True).strip()