---
language:
- en
- multilingual
license: apache-2.0
library_name: transformers
tags:
- feature-extraction
- image-feature-extraction
- vision
- vit
- gemma4
- google
- safetensors
pipeline_tag: image-feature-extraction
base_model: google/gemma-4-31B-it
model-index:
- name: gemma4-vision-encoder
  results:
  - task:
      type: image-classification
      name: CIFAR-10 (10-class)
    dataset:
      name: CIFAR-10
      type: cifar10
      split: test
    metrics:
    - type: accuracy
      value: 94.0
      name: Linear Probe Accuracy
---
# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)
Standalone extraction of the vision encoder from Google's [Gemma 4 31B](https://huggingface.co/google/gemma-4-31B-it) multimodal model. This is a 569.6M-parameter Vision Transformer with learned 2D positional embeddings, 2D RoPE, QK-norms, and a gated MLP — a significant upgrade from the SigLIP encoder used in Gemma 3.
**License:** Apache 2.0 (inherited from Gemma 4 — no restrictions)
## Architecture
| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dim per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate | 4304 |
| Activation | GELU (pytorch_tanh variant) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: `(batch, num_patches, 768)` where 768 = 3×16×16 |
| Output | `(num_valid_tokens, 1152)` after pooling + standardization |
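
The patch size and pooling kernel in the table determine the output token count directly: a H×W image yields (H/16)×(W/16) patches, and the 3×3 pooling divides each grid dimension by 3. A minimal sketch of that arithmetic (it assumes dimensions divisible by 48; the real processor handles resizing and padding):

```python
def vision_token_count(height: int, width: int, patch: int = 16, pool: int = 3) -> int:
    """Output tokens for an image after 16x16 patchification and 3x3 pooling."""
    ph, pw = height // patch, width // patch  # patch grid
    return (ph // pool) * (pw // pool)        # pooled grid

print(vision_token_count(864, 864))  # 324
```

The 864×864 case (2916 patches pooled down to 324 tokens) matches the verified forward pass reported in the extraction details.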
### What's New vs Gemma 3 (SigLIP)
| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | **Learned 2D + RoPE** |
| Attention | Standard | **QK-normed** |
| MLP | Standard (fc1 + fc2) | **Gated (gate + up + down)** |
| Aspect ratio | Fixed square (896×896) | **Variable aspect ratio** |
| Token budget | Fixed 256 | **Configurable (70–1120)** |
| Pooling | 4×4 average | **3×3** |
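
To make the gated-MLP difference concrete, here is a minimal NumPy sketch of the gate/up/down pattern with the `pytorch_tanh` GELU variant. The weight names and shapes are illustrative stand-ins, not the checkpoint's actual tensors; the real layer lives inside the transformers implementation:

```python
import numpy as np

def gelu_tanh(x):
    # pytorch_tanh GELU approximation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def gated_mlp(x, w_gate, w_up, w_down):
    # SigLIP-style MLP is fc2(gelu(fc1(x))); the Gemma 4 encoder instead
    # computes down(gelu(gate(x)) * up(x)), letting the gate modulate the up path.
    return (gelu_tanh(x @ w_gate) * (x @ w_up)) @ w_down

rng = np.random.default_rng(0)
hidden, inter = 1152, 4304  # hidden dim and MLP intermediate from the table above
x = rng.standard_normal((4, hidden)).astype(np.float32)
w_gate = (rng.standard_normal((hidden, inter)) * 0.02).astype(np.float32)
w_up = (rng.standard_normal((hidden, inter)) * 0.02).astype(np.float32)
w_down = (rng.standard_normal((inter, hidden)) * 0.02).astype(np.float32)
out = gated_mlp(x, w_gate, w_up, w_down)
print(out.shape)  # (4, 1152)
```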
### Not Shared with E2B/E4B
Unlike the audio encoder (which is identical across E2B and E4B), the vision encoders differ:
| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | **27** |
| Parameters | ~340M | **569.6M** |
## Usage
```python
import torch
from transformers import Gemma4VisionModel, Gemma4ImageProcessor
from PIL import Image

# Load vision encoder directly from this repo
vision_model = Gemma4VisionModel.from_pretrained(
    "rnagabh/gemma4-vision-encoder",
    torch_dtype=torch.bfloat16,
)
vision_model.to("cuda")
vision_model.eval()

# Load image processor (saved in this repo)
image_processor = Gemma4ImageProcessor.from_pretrained("rnagabh/gemma4-vision-encoder")

# Process an image
img = Image.open("your_image.jpg")
processed = image_processor(images=[img], return_tensors="pt")
pixel_values = processed["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
position_ids = processed["image_position_ids"].to(device="cuda")
tokens_per_image = processed["num_soft_tokens_per_image"]  # for splitting batch output

with torch.no_grad():
    output = vision_model(pixel_values=pixel_values, pixel_position_ids=position_ids)
embeddings = output.last_hidden_state  # (num_tokens, 1152)

# Mean-pool for a single image vector
image_embedding = embeddings.float().mean(dim=0)  # (1152,)
```
> **Important:** Always use the `Gemma4ImageProcessor` included in this repo for preprocessing.
> It handles resizing, patchification, position ID generation, and pixel normalization.
> Manual patchification without this processor will produce significantly degraded results.
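
Because the encoder flattens the batch into a single `(num_valid_tokens, 1152)` tensor, per-image embeddings must be recovered by splitting on the token counts the processor reports. A NumPy sketch of that bookkeeping (substitute `torch.split` for a torch tensor; the token counts here are illustrative):

```python
import numpy as np

# Simulated flattened encoder output for 3 images at different token budgets
tokens_per_image = [280, 70, 324]
flat = np.random.default_rng(0).standard_normal((sum(tokens_per_image), 1152))

# Split back into one (n_tokens, 1152) array per image
per_image = np.split(flat, np.cumsum(tokens_per_image)[:-1])

pooled = [e.mean(axis=0) for e in per_image]  # one 1152-d vector per image
print([e.shape for e in per_image])  # [(280, 1152), (70, 1152), (324, 1152)]
```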
## Benchmark Results (frozen 1152-dim embeddings, linear probe)
### CIFAR-10 Classification
| Metric | Value |
|---|---|
| Linear probe accuracy | **94.0%** |
| Random baseline | 10.0% |
| Improvement over chance | **9.4×** |
| Dataset | CIFAR-10 test set (1000 samples, 100 per class) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |
Strong performance across all classes: airplane (0.98 F1), ship (0.98 F1), truck (0.97 F1), automobile (0.97 F1). Weakest class is cat (0.86 F1) — a fine-grained category that is inherently harder.
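
The probe setup (logistic regression on L2-normalized mean-pooled embeddings) is straightforward to reproduce with scikit-learn. A sketch on synthetic embeddings standing in for the real CIFAR-10 features — the data here is random, so the score does not reflect the reported 94.0%:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Stand-in for mean-pooled 1152-dim embeddings: two classes with shifted means
X = np.concatenate([rng.standard_normal((200, 1152)) + 0.2,
                    rng.standard_normal((200, 1152)) - 0.2])
y = np.array([0] * 200 + [1] * 200)

# L2-normalize, as in the reported probe
X = X / np.linalg.norm(X, axis=1, keepdims=True)

probe = LogisticRegression(max_iter=1000).fit(X, y)
print(probe.score(X, y))
```

For the real benchmark, replace `X` with the mean-pooled encoder outputs for the CIFAR-10 test images and `y` with their class labels.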
## Files in This Repo
| File | Description | Size |
|---|---|---|
| `config.json` | Vision encoder config (Gemma4VisionConfig) | <1 KB |
| `model.safetensors` | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| `preprocessor_config.json` | Image processor config (Gemma4ImageProcessor) | <1 KB |
| `embed_vision.safetensors` | Vision→text embedding projection (1152→5376) | 12.4 MB |
## Limitations
- **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder. The 1152-dim output is the pure vision representation; the `embed_vision` projection maps to the 31B's text hidden space (5376-dim).
- **Requires image processor:** Use the `Gemma4ImageProcessor` included in this repo for preprocessing. The model expects pre-patchified `(B, num_patches, 768)` tensors with explicit 2D position IDs — the processor handles this automatically.
- **Variable aspect ratio support:** The 2D position embeddings enable non-square images. The processor generates correct position IDs for any aspect ratio.
- **Output shape note:** The pooler strips padding and collapses the batch dimension, returning `(num_valid_tokens, 1152)`. For batched inference, use `num_soft_tokens_per_image` from the processor to split the output back into per-image embeddings.
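
Applying the `embed_vision` projection mentioned above is a single matmul from the 1152-dim vision space into the 5376-dim text hidden space. A NumPy sketch with a random stand-in for the real weight (in practice, load the matrix from `embed_vision.safetensors`):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the real (1152 -> 5376) projection from embed_vision.safetensors
w_proj = (rng.standard_normal((1152, 5376)) * 0.02).astype(np.float32)

vision_tokens = rng.standard_normal((324, 1152)).astype(np.float32)  # e.g. one 864x864 image
text_space = vision_tokens @ w_proj  # (324, 5376): ready for the text decoder
print(text_space.shape)  # (324, 5376)
```

As a sanity check on the file listing above: 1152 × 5376 parameters at 2 bytes each is ≈12.4 MB, matching the reported size of `embed_vision.safetensors`.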
## Extraction Details
- Extracted from `google/gemma-4-31B-it` by downloading only the shard containing vision tower weights (`model-00001-of-00002.safetensors`)
- No full model load required — targeted tensor extraction
- Weights loaded with `strict=True` — perfect match
- Forward pass verified: 864×864 image → (324, 1152) output
- All architecture specs verified against the live model config
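
The targeted extraction amounts to filtering a shard's tensors by key prefix and stripping that prefix so the standalone encoder loads with `strict=True`. A simplified sketch over an in-memory state dict — the prefix and key names here are hypothetical, as the 31B checkpoint's actual key layout may differ:

```python
# Hypothetical prefix; the real key layout in the Gemma 4 checkpoint may differ
PREFIX = "model.vision_tower."

def extract_vision_weights(state_dict: dict) -> dict:
    """Keep only vision-tower tensors and strip the prefix so the
    standalone encoder can load them with strict=True."""
    return {k[len(PREFIX):]: v for k, v in state_dict.items() if k.startswith(PREFIX)}

shard = {
    "model.vision_tower.patch_embed.weight": "tensor-a",
    "model.vision_tower.blocks.0.attn.q_norm.weight": "tensor-b",
    "model.language_model.layers.0.mlp.gate_proj.weight": "tensor-c",
}
print(extract_vision_weights(shard))
# {'patch_embed.weight': 'tensor-a', 'blocks.0.attn.q_norm.weight': 'tensor-b'}
```

With safetensors, the same filtering can be done lazily per tensor so the language-model weights are never materialized in memory.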
## References
- [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-31B-it)
- [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4)
- [Gemma 4 Architecture Comparison](https://g4.si5.pl/) |