# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)
Standalone extraction of the vision encoder from Google's Gemma 4 31B multimodal model. This is a 569.6M parameter Vision Transformer with learned 2D positional embeddings, RoPE, QK-norms, and gated MLP — a significant upgrade from the SigLIP encoder used in Gemma 3.
**License:** Apache 2.0 (inherited from Gemma 4; no restrictions)
## Architecture
| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dim per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate | 4304 |
| Activation | GELU (pytorch_tanh variant) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: (batch, num_patches, 768) where 768 = 3×16×16 |
| Output | (num_valid_tokens, 1152) after pooling + standardization |
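For intuition about the pre-patchified input format, here is a minimal sketch of turning a `(3, H, W)` image into `(num_patches, 768)` patches. This is illustrative only: the channel-major flattening order is an assumption, and the shipped `Gemma4ImageProcessor` also handles resizing, normalization, and 2D position IDs, so use it in practice.

```python
import torch

def patchify(image: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Split a (3, H, W) image into flattened 16x16 patches.

    Illustrative sketch only -- the flattening order (channel, row, col)
    is an assumption; the real Gemma4ImageProcessor should be used.
    """
    c, h, w = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    # (3, H/16, 16, W/16, 16) -> (H/16, W/16, 3, 16, 16) -> (N, 768)
    return (
        image.reshape(c, h // patch_size, patch_size, w // patch_size, patch_size)
        .permute(1, 3, 0, 2, 4)
        .reshape(-1, c * patch_size * patch_size)
    )

print(patchify(torch.randn(3, 864, 864)).shape)  # torch.Size([2916, 768]), a 54x54 patch grid
```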
## What's New vs Gemma 3 (SigLIP)
| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | Learned 2D + RoPE |
| Attention | Standard | QK-normed |
| MLP | Standard (fc1 + fc2) | Gated (gate + up + down) |
| Aspect ratio | Fixed square (896×896) | Variable aspect ratio |
| Token budget | Fixed 256 | Configurable (70–1120) |
| Pooling | 4×4 average | 3×3 |
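The token counts follow from the patch grid and the 3×3 pooling kernel. The helper below sketches the arithmetic, assuming the pool evenly divides the patch grid; it is consistent with the 864×864 → 324-token forward-pass check reported in the extraction details.

```python
def vision_token_count(height: int, width: int, patch: int = 16, pool: int = 3) -> int:
    """Tokens after 16x16 patchification and 3x3 pooling.

    Sketch assuming the pooling kernel evenly divides the patch grid.
    """
    grid_h, grid_w = height // patch, width // patch   # patch grid
    return (grid_h // pool) * (grid_w // pool)         # pooled token grid

print(vision_token_count(864, 864))  # 324  (54x54 patches -> 18x18 tokens)
```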
## Not Shared with E2B/E4B
Unlike the audio encoder (which is identical across E2B and E4B), the vision encoders differ:
| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | 27 |
| Parameters | ~340M | 569.6M |
## Usage
```python
import torch
from transformers import Gemma4VisionModel, Gemma4ImageProcessor
from PIL import Image

# Load vision encoder directly from this repo
vision_model = Gemma4VisionModel.from_pretrained(
    "rnagabh/gemma4-vision-encoder",
    torch_dtype=torch.bfloat16,
)
vision_model.to("cuda")
vision_model.eval()

# Load image processor (saved in this repo)
image_processor = Gemma4ImageProcessor.from_pretrained("rnagabh/gemma4-vision-encoder")

# Process an image
img = Image.open("your_image.jpg")
processed = image_processor(images=[img], return_tensors="pt")
pixel_values = processed["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
position_ids = processed["image_position_ids"].to(device="cuda")
tokens_per_image = processed["num_soft_tokens_per_image"]  # for splitting batch output

with torch.no_grad():
    output = vision_model(pixel_values=pixel_values, pixel_position_ids=position_ids)
embeddings = output.last_hidden_state  # (num_tokens, 1152)

# Mean-pool for a single image vector
image_embedding = embeddings.float().mean(dim=0)  # (1152,)
```
> **Important:** Always use the `Gemma4ImageProcessor` included in this repo for preprocessing. It handles resizing, patchification, position ID generation, and pixel normalization. Manual patchification without this processor will produce significantly degraded results.
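For intuition about what 2D position IDs look like, the sketch below assigns a (row, column) pair to every patch in raster order. The exact layout `Gemma4ImageProcessor` emits may differ; this only illustrates the concept, which is why the processor should always be used.

```python
import torch

def patch_position_ids(grid_h: int, grid_w: int) -> torch.Tensor:
    """(row, col) index for each patch in raster order; shape (2, grid_h * grid_w).

    Illustrative of the 2D position ID concept only -- the real processor's
    layout is authoritative.
    """
    rows = torch.arange(grid_h).repeat_interleave(grid_w)  # 0,0,...,0,1,1,...
    cols = torch.arange(grid_w).repeat(grid_h)             # 0,1,...,W-1,0,1,...
    return torch.stack([rows, cols])

print(patch_position_ids(54, 54).shape)  # torch.Size([2, 2916])
```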
## Benchmark Results (frozen 1152-dim embeddings, linear probe)

### CIFAR-10 Classification
| Metric | Value |
|---|---|
| Linear probe accuracy | 94.0% |
| Random baseline | 10.0% |
| Improvement over chance | 9.4× |
| Dataset | CIFAR-10 test set (1000 samples, 100 per class) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |
Strong performance across all classes: airplane (0.98 F1), ship (0.98 F1), truck (0.97 F1), automobile (0.97 F1). Weakest class is cat (0.86 F1) — a fine-grained category that is inherently harder.
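The probe setup is a few lines of scikit-learn. The sketch below uses synthetic features in place of the real frozen embeddings, so it only demonstrates the pipeline (L2-normalize mean-pooled embeddings, fit logistic regression, score on a held-out split):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic stand-ins for the frozen 1152-dim mean-pooled embeddings.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 1152)).astype(np.float32)
y = rng.integers(0, 10, size=1000)
X /= np.linalg.norm(X, axis=1, keepdims=True)   # L2-normalize each embedding

probe = LogisticRegression(max_iter=1000).fit(X[:800], y[:800])
print(probe.score(X[800:], y[800:]))            # near chance on random features
```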
## Files in This Repo
| File | Description | Size |
|---|---|---|
| `config.json` | Vision encoder config (`Gemma4VisionConfig`) | <1 KB |
| `model.safetensors` | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| `preprocessor_config.json` | Image processor config (`Gemma4ImageProcessor`) | <1 KB |
| `embed_vision.safetensors` | Vision→text embedding projection (1152→5376) | 12.4 MB |
## Limitations
- **End-to-end trained for LLM decoding**: The encoder was trained to produce features for Gemma 4's text decoder. The 1152-dim output is the pure vision representation; the `embed_vision` projection maps to the 31B's text hidden space (5376-dim).
- **Requires image processor**: Use the `Gemma4ImageProcessor` included in this repo for preprocessing. The model expects pre-patchified `(B, num_patches, 768)` tensors with explicit 2D position IDs; the processor handles this automatically.
- **Variable aspect ratio support**: The 2D position embeddings enable non-square images. The processor generates correct position IDs for any aspect ratio.
- **Output shape note**: The pooler strips padding and collapses the batch dimension, returning `(num_valid_tokens, 1152)`. For batched inference, use `num_soft_tokens_per_image` from the processor to split the output back into per-image embeddings.
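Undoing the collapsed batch dimension is a single `torch.split`. The token counts below are illustrative stand-ins for the processor's `num_soft_tokens_per_image`:

```python
import torch

# Illustrative token counts standing in for num_soft_tokens_per_image.
num_soft_tokens_per_image = [324, 140]
flat = torch.randn(sum(num_soft_tokens_per_image), 1152)          # collapsed model output
per_image = torch.split(flat, num_soft_tokens_per_image, dim=0)   # one tensor per image
print([tuple(t.shape) for t in per_image])  # [(324, 1152), (140, 1152)]
```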
## Extraction Details
- Extracted from `google/gemma-4-31B-it` by downloading only the shard containing vision tower weights (`model-00001-of-00002.safetensors`)
- No full model load required; targeted tensor extraction
- Weights loaded with `strict=True`: perfect match
- Forward pass verified: 864×864 image → (324, 1152) output
- All architecture specs verified against the live model config