DepthLens / src /models /gemma4.py
Rishabh Jain
Initial upload — depth-aware scene description system
5412d82
"""
Gemma 4 E2B IT vision-language model loader.
Requirements
------------
- transformers >= 5.5.0 (NOT compatible with the academic env)
- HF_TOKEN env variable set to an account that has accepted the
model licence at https://huggingface.co/google/gemma-4-E2B-it
VRAM: ~2.8 GB at float16 — fits on 8 GB GPU alongside Depth + YOLO.
Usage::
from src.models.gemma4 import Gemma4VLM
vlm = Gemma4VLM()
answer = vlm.query_vlm(pil_image, "Describe this scene.")
"""
from __future__ import annotations
import os
from typing import Optional
import torch
from PIL import Image
from ..config import GEMMA4_ID, GEMMA4_MAX_NEW_TOKENS
class Gemma4VLM:
    """Gemma 4 E2B IT vision-language model wrapper.

    The processor and model are loaded once, at construction time, and kept
    on the instance. Generation runs under ``torch.inference_mode()``.

    Args:
        device: Target device string, e.g. ``"cuda"`` or ``"cpu"``.
            Defaults to ``"cuda"``. Whenever CUDA is not available the
            instance falls back to ``"cpu"``, whatever was requested.
        dtype: Torch dtype for model weights. ``torch.float16`` uses
            ~2.8 GB on an 8 GB GPU. Switch to ``torch.bfloat16`` if
            your hardware supports it (Ampere/Ada/Hopper).
    """

    def __init__(
        self,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ) -> None:
        # The default is "cuda" so that a bare ``Gemma4VLM()`` uses the GPU
        # when one exists; a "cpu" default would make the fallback below a
        # no-op and pin every default-constructed instance to the CPU.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.dtype = dtype
        self.model = None
        self.processor = None
        self._load()

    # ── Model loading ─────────────────────────────────────────────────────────
    @staticmethod
    def _major_minor(version: str) -> tuple[int, int]:
        """Return the leading ``(major, minor)`` integers of a version string.

        Trailing qualifiers (``rc1``, ``.dev0``, ``+local`` ...) are ignored
        and a missing minor component counts as ``0``, so ``"5.6rc1"`` gives
        ``(5, 6)`` and ``"6"`` gives ``(6, 0)``. A string that does not start
        with a number yields ``(0, 0)``, which fails any minimum-version test.
        """
        import re  # local import: only needed for this one-off parse

        found = re.match(r"\s*(\d+)(?:\.(\d+))?", version)
        if found is None:
            return (0, 0)
        return (int(found.group(1)), int(found.group(2) or 0))

    def _load(self) -> None:
        """Download (first run) and load Gemma 4 E2B IT.

        Populates ``self.processor`` and ``self.model`` and switches the
        model to eval mode. The Hugging Face token is read from the
        ``HF_TOKEN`` environment variable (``None`` when unset).

        Raises:
            ImportError: If transformers is missing or older than 5.5.
            OSError: If HF_TOKEN is not set and the model is still gated.
        """
        import transformers as _tf

        # Compare numerically on (major, minor) only; pre-release suffixes
        # must not make a new-enough build look older than it is.
        if self._major_minor(_tf.__version__) < (5, 5):
            raise ImportError(
                f"Gemma 4 requires transformers >= 5.5.0, "
                f"found {_tf.__version__}. "
                # Quoted so a shell does not treat ">=5.5.0" as a redirect.
                'Create a new env: pip install "transformers>=5.5.0"'
            )

        from transformers import AutoProcessor, Gemma4ForConditionalGeneration  # type: ignore[attr-defined]

        hf_token: Optional[str] = os.environ.get("HF_TOKEN")
        print(f"Loading {GEMMA4_ID} ({self.dtype}, device={self.device})...")

        self.processor = AutoProcessor.from_pretrained(
            GEMMA4_ID,
            token=hf_token,
        )
        self.model = Gemma4ForConditionalGeneration.from_pretrained(
            GEMMA4_ID,
            # The empty-string key places the whole model on one device.
            device_map={"": self.device},
            torch_dtype=self.dtype,
            token=hf_token,
        )
        self.model.eval()

        # Report GPU usage only when the model was actually placed on a
        # CUDA device, and query that device rather than the current one.
        if str(self.device).startswith("cuda"):
            alloc_mb = torch.cuda.memory_allocated(torch.device(self.device)) / 1024 ** 2
            print(f" GPU memory allocated: {alloc_mb:.0f} MB")

    # ── Inference ─────────────────────────────────────────────────────────────
    def query_vlm(self, image: Image.Image, question: str) -> str:
        """Query Gemma 4 with an image and a text prompt.

        Builds a single-turn user message, applies the chat template,
        runs ``model.generate()`` with greedy decoding, and strips the
        input tokens from the output before decoding.

        Args:
            image: PIL Image to analyse.
            question: Text question or prompt (may include the depth preamble).

        Returns:
            Generated answer text with input tokens stripped and
            surrounding whitespace removed.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device)

        # Remember the prompt length so it can be cut from the output below.
        n_input_tokens = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=GEMMA4_MAX_NEW_TOKENS,
                do_sample=False,
            )

        # Slice off the prompt tokens so we only decode the generated part.
        new_token_ids = output_ids[0][n_input_tokens:]
        return self.processor.decode(new_token_ids, skip_special_tokens=True).strip()