---
language:
  - en
  - multilingual
license: apache-2.0
library_name: transformers
tags:
  - feature-extraction
  - image-feature-extraction
  - vision
  - vit
  - gemma4
  - google
  - safetensors
pipeline_tag: image-feature-extraction
base_model: google/gemma-4-31B-it
model-index:
  - name: gemma4-vision-encoder
    results:
      - task:
          type: image-classification
          name: CIFAR-10 (10-class)
        dataset:
          name: CIFAR-10
          type: cifar10
          split: test
        metrics:
          - type: accuracy
            value: 94
            name: Linear Probe Accuracy
---

# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)

Standalone extraction of the vision encoder from Google's Gemma 4 31B multimodal model. This is a 569.6M-parameter Vision Transformer with 2D RoPE, learned 2D positional embeddings, QK-norms, and a gated MLP — a significant upgrade from the SigLIP encoder used in Gemma 3.

License: Apache 2.0 (inherited from Gemma 4 — no restrictions)

## Architecture

| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dim per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate | 4304 |
| Activation | GELU (pytorch_tanh variant) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: (batch, num_patches, 768) where 768 = 3×16×16 |
| Output | (num_valid_tokens, 1152) after pooling + standardization |
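The shape arithmetic in this table can be sanity-checked in a few lines (a sketch, assuming the input dimensions divide evenly by the patch and pool sizes):

```python
def vision_token_count(height: int, width: int, patch: int = 16, pool: int = 3) -> int:
    """Output tokens for one image: patch grid, then the 3x3-pooled token grid."""
    patches_h, patches_w = height // patch, width // patch
    tokens_h, tokens_w = patches_h // pool, patches_w // pool
    return tokens_h * tokens_w

assert 3 * 16 * 16 == 768                  # per-patch input dim from the table
print(vision_token_count(864, 864))        # 54x54 patches -> 18x18 = 324 tokens
```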

## What's New vs Gemma 3 (SigLIP)

| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | Learned 2D + RoPE |
| Attention | Standard | QK-normed |
| MLP | Standard (fc1 + fc2) | Gated (gate + up + down) |
| Aspect ratio | Fixed square (896×896) | Variable aspect ratio |
| Token budget | Fixed 256 | Configurable (70–1120) |
| Pooling | 4×4 average | 3×3 |

## Not Shared with E2B/E4B

Unlike the audio encoder (which is identical across E2B and E4B), the vision encoders differ:

| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | 27 |
| Parameters | ~340M | 569.6M |

## Usage

```python
import torch
from transformers import Gemma4VisionModel, Gemma4ImageProcessor
from PIL import Image

# Load vision encoder directly from this repo
vision_model = Gemma4VisionModel.from_pretrained(
    "rnagabh/gemma4-vision-encoder",
    torch_dtype=torch.bfloat16,
)
vision_model.to("cuda")
vision_model.eval()

# Load image processor (saved in this repo)
image_processor = Gemma4ImageProcessor.from_pretrained("rnagabh/gemma4-vision-encoder")

# Process an image
img = Image.open("your_image.jpg")
processed = image_processor(images=[img], return_tensors="pt")

pixel_values = processed["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
position_ids = processed["image_position_ids"].to(device="cuda")
tokens_per_image = processed["num_soft_tokens_per_image"]  # for splitting batch output

with torch.no_grad():
    output = vision_model(pixel_values=pixel_values, pixel_position_ids=position_ids)
    embeddings = output.last_hidden_state  # (num_tokens, 1152)

    # Mean-pool for a single image vector
    image_embedding = embeddings.float().mean(dim=0)  # (1152,)
```

**Important:** Always use the Gemma4ImageProcessor included in this repo for preprocessing. It handles resizing, patchification, position ID generation, and pixel normalization. Manual patchification without this processor will produce significantly degraded results.
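For batched inputs, the collapsed output can be split back into per-image blocks with `num_soft_tokens_per_image`. A minimal sketch with stand-in tensors (shapes are illustrative, not produced by the model):

```python
import torch

# Stand-ins for vision_model output and the processor's per-image token counts
embeddings = torch.randn(324 + 280, 1152)   # collapsed (num_valid_tokens, 1152)
tokens_per_image = [324, 280]               # processor's num_soft_tokens_per_image

# Split the collapsed batch back into one token block per image
per_image = torch.split(embeddings, tokens_per_image, dim=0)

# Mean-pool each block to a single 1152-dim vector per image
image_vectors = [block.float().mean(dim=0) for block in per_image]
```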

## Benchmark Results (frozen 1152-dim embeddings, linear probe)

### CIFAR-10 Classification

| Metric | Value |
|---|---|
| Linear probe accuracy | 94.0% |
| Random baseline | 10.0% |
| Improvement over chance | 9.4× |
| Dataset | CIFAR-10 test set (1000 samples, 100 per class) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |

Strong performance across all classes: airplane (0.98 F1), ship (0.98 F1), truck (0.97 F1), automobile (0.97 F1). Weakest class is cat (0.86 F1) — a fine-grained category that is inherently harder.
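The probe protocol above takes only a few lines to reproduce. The sketch below uses random stand-in features (so accuracy is near chance, unlike the real embeddings) and assumes scikit-learn is available:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 1152)).astype(np.float32)  # stand-in mean-pooled embeddings
y = rng.integers(0, 10, size=200)                    # 10 CIFAR-10 classes

X /= np.linalg.norm(X, axis=1, keepdims=True)        # L2-normalize each embedding
probe = LogisticRegression(max_iter=1000)            # the linear probe
probe.fit(X, y)
print(f"train accuracy: {probe.score(X, y):.2f}")
```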

## Files in This Repo

| File | Description | Size |
|---|---|---|
| config.json | Vision encoder config (Gemma4VisionConfig) | <1 KB |
| model.safetensors | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| preprocessor_config.json | Image processor config (Gemma4ImageProcessor) | <1 KB |
| embed_vision.safetensors | Vision→text embedding projection (1152→5376) | 12.4 MB |

## Limitations

- **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder. The 1152-dim output is the pure vision representation; the embed_vision projection maps to the 31B's text hidden space (5376-dim).
- **Requires image processor:** Use the Gemma4ImageProcessor included in this repo for preprocessing. The model expects pre-patchified (B, num_patches, 768) tensors with explicit 2D position IDs — the processor handles this automatically.
- **Variable aspect ratio support:** The 2D position embeddings enable non-square images. The processor generates correct position IDs for any aspect ratio.
- **Output shape note:** The pooler strips padding and collapses the batch dimension, returning (num_valid_tokens, 1152). For batched inference, use num_soft_tokens_per_image from the processor to split the output back into per-image embeddings.
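Applying the embed_vision projection is a single matrix multiply. The sketch below uses a random stand-in weight; the tensor key inside `embed_vision.safetensors` is an assumption, not verified:

```python
import torch

vision_tokens = torch.randn(324, 1152)   # stand-in encoder output (num_tokens, 1152)
proj_weight = torch.randn(5376, 1152)    # stand-in for the embed_vision weight
# Real use would load the weight instead, e.g.:
#   from safetensors.torch import load_file
#   proj_weight = load_file("embed_vision.safetensors")["weight"]  # key is assumed

text_space = vision_tokens @ proj_weight.T   # (324, 5376), the 31B text hidden space
```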

## Extraction Details

- Extracted from google/gemma-4-31B-it by downloading only the shard containing vision tower weights (model-00001-of-00002.safetensors)
- No full model load required — targeted tensor extraction
- Weights loaded with strict=True — perfect match
- Forward pass verified: 864×864 image → (324, 1152) output
- All architecture specs verified against the live model config

## References