---
language:
- en
- multilingual
license: apache-2.0
library_name: transformers
tags:
- feature-extraction
- image-feature-extraction
- vision
- vit
- gemma4
- google
- safetensors
pipeline_tag: image-feature-extraction
base_model: google/gemma-4-31B-it
model-index:
- name: gemma4-vision-encoder
  results:
  - task:
      type: image-classification
      name: CIFAR-10 (10-class)
    dataset:
      name: CIFAR-10
      type: cifar10
      split: test
    metrics:
    - type: accuracy
      value: 94.0
      name: Linear Probe Accuracy
---
# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)
Standalone extraction of the vision encoder from Google's [Gemma 4 31B](https://huggingface.co/google/gemma-4-31B-it) multimodal model. This 569.6M-parameter Vision Transformer combines learned 2D positional embeddings with RoPE, QK-norm, and a gated MLP, a substantial architectural change from the SigLIP encoder used in Gemma 3.
**License:** Apache 2.0 (inherited from Gemma 4; the standard Apache 2.0 attribution and notice terms apply)
## Architecture
| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dim per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate | 4304 |
| Activation | GELU, tanh approximation (`gelu_pytorch_tanh`) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: `(batch, num_patches, 768)` where 768 = 3×16×16 |
| Output | `(num_valid_tokens, 1152)` after pooling + standardization |
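The patch size and pooling kernel together fix how many output tokens an image yields. The sketch below works through that arithmetic. `count_vision_tokens` is a hypothetical helper, and it assumes the image has already been resized so both sides are multiples of 48 px (16-pixel patches × 3×3 pooling); how `Gemma4ImageProcessor` actually chooses the resize target for a given token budget is not covered here.

```python
PATCH_SIZE = 16  # pixels per patch side
POOL_KERNEL = 3  # 3x3 pooling window -> 9x fewer tokens


def count_vision_tokens(height: int, width: int) -> tuple[int, int]:
    """Return (num_patches, num_output_tokens) for an already-resized image.

    Assumes both sides are multiples of PATCH_SIZE * POOL_KERNEL (48 px), so the
    patch grid splits evenly into 3x3 pooling windows.
    """
    step = PATCH_SIZE * POOL_KERNEL
    if height % step or width % step:
        raise ValueError(f"height and width must be multiples of {step}")
    rows, cols = height // PATCH_SIZE, width // PATCH_SIZE
    return rows * cols, (rows // POOL_KERNEL) * (cols // POOL_KERNEL)


print(count_vision_tokens(864, 864))  # (2916, 324)

# Under the same assumption, a budget of N output tokens corresponds to 9 * N patches
for budget in (70, 140, 280, 560, 1120):
    print(f"{budget} tokens -> {budget * POOL_KERNEL ** 2} patches")
```

The 864×864 case (54×54 patches pooled to 18×18) matches the (324, 1152) output shape reported under Extraction Details.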
### What's New vs Gemma 3 (SigLIP)
| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | **Learned 2D + RoPE** |
| Attention | Standard | **QK-normed** |
| MLP | Standard (fc1 + fc2) | **Gated (gate + up + down)** |
| Aspect ratio | Fixed square (896×896) | **Variable aspect ratio** |
| Token budget | Fixed 256 | **Configurable (70–1120)** |
| Pooling | 4×4 average | **3×3** |
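For reference, here is a NumPy sketch of a gated MLP with the dimensions from the architecture table (1152 → 4304 → 1152). It assumes the usual Gemma-style arrangement, `down_proj(gelu_tanh(gate_proj(x)) * up_proj(x))` with no bias terms, which this card does not spell out; the weights are random placeholders, not the released ones.

```python
import numpy as np

HIDDEN, INTERMEDIATE = 1152, 4304


def gelu_tanh(x):
    """Tanh approximation of GELU (the formula behind PyTorch's approximate="tanh")."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def gated_mlp(x, w_gate, w_up, w_down):
    """x: (tokens, HIDDEN) -> (tokens, HIDDEN), assuming bias-free projections."""
    gate = gelu_tanh(x @ w_gate)  # (tokens, INTERMEDIATE)
    up = x @ w_up                 # (tokens, INTERMEDIATE)
    return (gate * up) @ w_down   # elementwise gating, then project back to HIDDEN


rng = np.random.default_rng(0)
scale = np.float32(0.02)  # small random init, placeholder only
w_gate = scale * rng.standard_normal((HIDDEN, INTERMEDIATE), dtype=np.float32)
w_up = scale * rng.standard_normal((HIDDEN, INTERMEDIATE), dtype=np.float32)
w_down = scale * rng.standard_normal((INTERMEDIATE, HIDDEN), dtype=np.float32)

x = rng.standard_normal((4, HIDDEN), dtype=np.float32)
out = gated_mlp(x, w_gate, w_up, w_down)
print(out.shape)

# Three bias-free matrices per layer: 3 * 1152 * 4304 weights
mlp_weights_per_layer = w_gate.size + w_up.size + w_down.size
print(mlp_weights_per_layer)
```

Compared with a two-matrix MLP (fc1 + fc2) of the same width, the gate adds a third projection, so each layer's MLP carries roughly 1.5× the weights.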
### Not Shared with E2B/E4B
Unlike the audio encoder (which is identical in E2B and E4B), the vision encoder is not shared across the Gemma 4 family: E2B and E4B use a smaller encoder than the 31B model.
| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | **27** |
| Parameters | ~340M | **569.6M** |
## Usage
```python
import torch
from transformers import Gemma4VisionModel, Gemma4ImageProcessor
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the vision encoder directly from this repo
vision_model = Gemma4VisionModel.from_pretrained(
    "rnagabh/gemma4-vision-encoder",
    torch_dtype=torch.bfloat16,
)
vision_model.to(device)
vision_model.eval()

# Load the image processor (saved in this repo)
image_processor = Gemma4ImageProcessor.from_pretrained("rnagabh/gemma4-vision-encoder")

# Preprocess an image (convert to RGB so grayscale/RGBA files have 3 channels)
img = Image.open("your_image.jpg").convert("RGB")
processed = image_processor(images=[img], return_tensors="pt")
pixel_values = processed["pixel_values"].to(dtype=torch.bfloat16, device=device)
position_ids = processed["image_position_ids"].to(device=device)
tokens_per_image = processed["num_soft_tokens_per_image"]  # for splitting batched output

with torch.no_grad():
    output = vision_model(pixel_values=pixel_values, pixel_position_ids=position_ids)
embeddings = output.last_hidden_state  # (num_tokens, 1152)

# Mean-pool into a single image vector
image_embedding = embeddings.float().mean(dim=0)  # (1152,)
```
> **Important:** Always use the `Gemma4ImageProcessor` included in this repo for preprocessing.
> It handles resizing, patchification, position ID generation, and pixel normalization.
> Manual patchification without this processor will produce significantly degraded results.
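To make the expected input layout concrete, here is a minimal NumPy sketch of what "pre-patchified with 2D position IDs" means for shapes. It is not a substitute for `Gemma4ImageProcessor`: it skips resizing and normalization, and `patchify_sketch`, the within-patch ordering (row, column, channel), and the (row, col) position convention are all assumptions that may not match what the processor produces.

```python
import numpy as np

PATCH = 16  # patch side in pixels; 3 * 16 * 16 = 768 values per patch


def patchify_sketch(image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: (H, W, 3) image -> (num_patches, 768) patches + (num_patches, 2) positions.

    Illustration only; ordering conventions are assumptions.
    """
    h, w, c = image.shape
    rows, cols = h // PATCH, w // PATCH  # any remainder pixels are cropped
    grid = image[: rows * PATCH, : cols * PATCH].reshape(rows, PATCH, cols, PATCH, c)
    # Bring the two grid axes together, then flatten each 16x16x3 patch into one vector
    patches = grid.transpose(0, 2, 1, 3, 4).reshape(rows * cols, PATCH * PATCH * c)
    # One (row, col) grid coordinate per patch, in the same row-major order as `patches`
    rr, cc = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
    positions = np.stack([rr.ravel(), cc.ravel()], axis=-1)
    return patches, positions


# A non-square 32x48 image gives a 2x3 grid of patches
img = np.random.default_rng(0).random((32, 48, 3), dtype=np.float32)
patches, positions = patchify_sketch(img)
pixel_values = patches[None]  # add a batch axis -> (1, 6, 768)
print(pixel_values.shape, positions.tolist())
```

Because each patch carries an explicit grid coordinate, nothing in this layout requires a square image, which is what lets the 2D position embeddings handle variable aspect ratios.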
## Benchmark Results (frozen 1152-dim embeddings, linear probe)
### CIFAR-10 Classification
| Metric | Value |
|---|---|
| Linear probe accuracy | **94.0%** |
| Random baseline | 10.0% |
| Improvement over chance | **9.4×** |
| Dataset | CIFAR-10 test set (1000 samples, 100 per class) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |
Per-class F1 is highest for airplane (0.98), ship (0.98), truck (0.97), and automobile (0.97). The weakest class is cat (0.86 F1); in CIFAR-10, cat is commonly confused with visually similar animal classes such as dog.
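The probe protocol (mean-pool, L2-normalize, fit logistic regression) can be sketched in NumPy. The solver, regularization strength, and train/test split behind the 94.0% figure are not documented here, so this version fits multinomial logistic regression by plain gradient descent with arbitrarily chosen hyperparameters (`lr`, `steps`, `l2`), and it uses synthetic clustered vectors in place of real encoder outputs.

```python
import numpy as np


def mean_pool_and_normalize(token_embeddings: np.ndarray) -> np.ndarray:
    """(num_tokens, dim) token embeddings -> one unit-length (dim,) image vector."""
    v = token_embeddings.mean(axis=0)
    return v / np.linalg.norm(v)


def train_linear_probe(x, y, num_classes, lr=0.5, steps=500, l2=1e-4):
    """L2-regularized multinomial logistic regression via full-batch gradient descent."""
    n, d = x.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[y]
    for _ in range(steps):
        logits = x @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        w -= lr * (x.T @ grad + l2 * w)
        b -= lr * grad.sum(axis=0)
    return w, b


def probe_accuracy(w, b, x, y):
    return float(((x @ w + b).argmax(axis=1) == y).mean())


# Synthetic stand-in for encoder output: each class is a noisy cluster of 1152-dim tokens
rng = np.random.default_rng(0)
num_classes, dim, tokens = 10, 1152, 16
centers = rng.normal(size=(num_classes, dim))


def make_split(per_class):
    labels = np.repeat(np.arange(num_classes), per_class)
    feats = np.stack([
        mean_pool_and_normalize(centers[c] + 0.5 * rng.normal(size=(tokens, dim)))
        for c in labels
    ])
    return feats, labels


x_train, y_train = make_split(20)
x_test, y_test = make_split(10)
w, b = train_linear_probe(x_train, y_train, num_classes)
print(f"test accuracy: {probe_accuracy(w, b, x_test, y_test):.3f}")
```

To run it on real features, replace `make_split` with `last_hidden_state` outputs from the Usage example, converted with `.float().cpu().numpy()` before pooling.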
## Files in This Repo
| File | Description | Size |
|---|---|---|
| `config.json` | Vision encoder config (Gemma4VisionConfig) | <1 KB |
| `model.safetensors` | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| `preprocessor_config.json` | Image processor config (Gemma4ImageProcessor) | <1 KB |
| `embed_vision.safetensors` | Vision→text embedding projection (1152→5376) | 12.4 MB |
## Limitations
- **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder, so its features are tuned for that decoder rather than for a standalone embedding objective. The 1152-dim output is the pure vision representation; the `embed_vision` projection maps it into the 31B's text hidden space (5376-dim).
- **Requires image processor:** Use the `Gemma4ImageProcessor` included in this repo for preprocessing. The model expects pre-patchified `(B, num_patches, 768)` tensors with explicit 2D position IDs — the processor handles this automatically.
- **Variable aspect ratio support:** The 2D position embeddings enable non-square images. The processor generates correct position IDs for any aspect ratio.
- **Output shape note:** The pooler strips padding and collapses the batch dimension, returning `(num_valid_tokens, 1152)`. For batched inference, use `num_soft_tokens_per_image` from the processor to split the output back into per-image embeddings.
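A minimal sketch of that split, assuming the pooled tokens are concatenated in the same order as the input images and that `num_soft_tokens_per_image` holds one count per image. `split_per_image` is a hypothetical helper; for a torch tensor, `torch.split(embeddings, counts, dim=0)` does the same job.

```python
import numpy as np


def split_per_image(embeddings, tokens_per_image):
    """Split a flat (total_tokens, dim) encoder output into one block per image.

    Uses plain slicing, so it works on NumPy arrays and torch tensors alike.
    Assumes tokens are concatenated in the same order as the input images.
    """
    counts = [int(n) for n in tokens_per_image]
    if sum(counts) != embeddings.shape[0]:
        raise ValueError("per-image token counts do not sum to the number of output rows")
    blocks, start = [], 0
    for n in counts:
        blocks.append(embeddings[start:start + n])
        start += n
    return blocks


# Example: two images with 280 and 70 tokens, concatenated along axis 0
flat = np.random.default_rng(0).standard_normal((280 + 70, 1152))
per_image = split_per_image(flat, [280, 70])
image_vectors = np.stack([block.mean(axis=0) for block in per_image])  # (2, 1152)
```

The sum check catches a mismatch between the processor's counts and the encoder output early, instead of silently misassigning tokens between images.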
## Extraction Details
- Extracted from `google/gemma-4-31B-it` by downloading only the shard containing vision tower weights (`model-00001-of-00002.safetensors`)
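- No full model load required: only the vision-tower tensors are read from the shard
- Weights loaded with `strict=True`, with no missing or unexpected keys
- Forward pass verified: 864×864 image → (324, 1152) output (54×54 = 2,916 patches, pooled 3×3 to 18×18 = 324 tokens)
- All architecture specs verified against the `google/gemma-4-31B-it` model config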
## References
- [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-31B-it)
- [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4)
- [Gemma 4 Architecture Comparison](https://g4.si5.pl/)