---
language:
- en
- multilingual
license: apache-2.0
library_name: transformers
tags:
- feature-extraction
- image-feature-extraction
- vision
- vit
- gemma4
- google
- safetensors
pipeline_tag: image-feature-extraction
base_model: google/gemma-4-31B-it
model-index:
- name: gemma4-vision-encoder
  results:
  - task:
      type: image-classification
      name: CIFAR-10 (10-class)
    dataset:
      name: CIFAR-10
      type: cifar10
      split: test
    metrics:
    - type: accuracy
      value: 94.0
      name: Linear Probe Accuracy
---

# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)

Standalone extraction of the vision encoder from Google's [Gemma 4 31B](https://huggingface.co/google/gemma-4-31B-it) multimodal model. This is a 569.6M-parameter Vision Transformer with learned 2D positional embeddings, RoPE, QK-norms, and a gated MLP — a significant upgrade from the SigLIP encoder used in Gemma 3.

**License:** Apache 2.0 (inherited from Gemma 4 — no restrictions)

## Architecture

| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dim per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate | 4304 |
| Activation | GELU (pytorch_tanh variant) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: `(batch, num_patches, 768)` where 768 = 3×16×16 |
| Output | `(num_valid_tokens, 1152)` after pooling + standardization |
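The gated MLP row above can be made concrete with a short PyTorch sketch. This illustrates the gate/up/down layout and the tanh-GELU activation using the table's dimensions; the class and variable names here are invented for illustration, not taken from the actual Gemma 4 implementation.

```python
import torch
import torch.nn as nn

class GatedMLP(nn.Module):
    """Illustrative gated MLP: gate_proj + up_proj + down_proj (names invented here)."""

    def __init__(self, hidden=1152, intermediate=4304):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.act = nn.GELU(approximate="tanh")  # the "pytorch_tanh" GELU variant

    def forward(self, x):
        # Gated activation: act(gate(x)) elementwise-scales up(x), then project back down
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))

tokens = torch.randn(2, 10, 1152)  # (batch, tokens, hidden)
out = GatedMLP()(tokens)
print(out.shape)                   # torch.Size([2, 10, 1152])
```

Note the gated variant uses three projections where a standard MLP uses two, which is where part of the parameter budget goes.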

### What's New vs Gemma 3 (SigLIP)

| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | **Learned 2D + RoPE** |
| Attention | Standard | **QK-normed** |
| MLP | Standard (fc1 + fc2) | **Gated (gate + up + down)** |
| Aspect ratio | Fixed square (896×896) | **Variable aspect ratio** |
| Token budget | Fixed 256 | **Configurable (70–1120)** |
| Pooling | 4×4 average | **3×3** |
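For intuition on the 2D RoPE entry, here is a minimal sketch assuming the common convention of rotating one half of each 72-dim head by the patch's row index and the other half by its column index, with the table's theta=100. The real implementation may differ in detail; this only shows the mechanism.

```python
import torch

def rope_1d(x, pos, theta=100.0):
    """Rotate channel pairs of x (seq, dim) by angles pos * theta^(-2i/dim)."""
    dim = x.shape[-1]
    freqs = theta ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim/2,)
    angles = pos.float()[:, None] * freqs[None, :]                          # (seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x, row_pos, col_pos, theta=100.0):
    """Rotate the first half of the head dim by row index, the second half by column index."""
    half = x.shape[-1] // 2
    return torch.cat(
        [rope_1d(x[..., :half], row_pos, theta), rope_1d(x[..., half:], col_pos, theta)],
        dim=-1,
    )

q = torch.randn(4, 72)            # 4 patches of a 2x2 grid, one 72-dim head
rows = torch.tensor([0, 0, 1, 1])
cols = torch.tensor([0, 1, 0, 1])
q_rot = rope_2d(q, rows, cols)    # same shape; norms preserved (pure rotation)
```

Because the encoding depends only on integer row/column indices, it extends naturally to the variable aspect ratios listed in the table.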

### Not Shared with E2B/E4B

Unlike the audio encoder (which is identical across E2B and E4B), the vision encoders differ:

| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | **27** |
| Parameters | ~340M | **569.6M** |

## Usage

```python
import torch
from transformers import Gemma4VisionModel, Gemma4ImageProcessor
from PIL import Image

# Load vision encoder directly from this repo
vision_model = Gemma4VisionModel.from_pretrained(
    "rnagabh/gemma4-vision-encoder",
    torch_dtype=torch.bfloat16,
)
vision_model.to("cuda")
vision_model.eval()

# Load image processor (saved in this repo)
image_processor = Gemma4ImageProcessor.from_pretrained("rnagabh/gemma4-vision-encoder")

# Process an image
img = Image.open("your_image.jpg")
processed = image_processor(images=[img], return_tensors="pt")

pixel_values = processed["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
position_ids = processed["image_position_ids"].to(device="cuda")
tokens_per_image = processed["num_soft_tokens_per_image"]  # for splitting batch output

with torch.no_grad():
    output = vision_model(pixel_values=pixel_values, pixel_position_ids=position_ids)
embeddings = output.last_hidden_state  # (num_tokens, 1152)

# Mean-pool for a single image vector
image_embedding = embeddings.float().mean(dim=0)  # (1152,)
```

> **Important:** Always use the `Gemma4ImageProcessor` included in this repo for preprocessing.
> It handles resizing, patchification, position ID generation, and pixel normalization.
> Manual patchification without this processor will produce significantly degraded results.
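For intuition only, the pre-patchified `(batch, num_patches, 768)` layout the processor produces can be approximated like this (the real `Gemma4ImageProcessor` also resizes, normalizes, and generates position IDs, so it should always be used in practice):

```python
import torch

def patchify(images, patch=16):
    """(B, 3, H, W) -> (B, num_patches, 3*patch*patch); H, W must be multiples of patch."""
    b, c, h, w = images.shape
    x = images.unfold(2, patch, patch).unfold(3, patch, patch)  # (B, C, H/p, W/p, p, p)
    x = x.permute(0, 2, 3, 1, 4, 5)                             # patch grid first, then (C, p, p)
    return x.reshape(b, (h // patch) * (w // patch), c * patch * patch)

imgs = torch.randn(1, 3, 864, 864)
patches = patchify(imgs)
print(patches.shape)  # torch.Size([1, 2916, 768])
```

For an 864×864 input this yields 54×54 = 2916 patches of 768 values each, which the 3×3 pooler then reduces 9× to the 324 output tokens mentioned under Extraction Details.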

## Benchmark Results (frozen 1152-dim embeddings, linear probe)

### CIFAR-10 Classification

| Metric | Value |
|---|---|
| Linear probe accuracy | **94.0%** |
| Random baseline | 10.0% |
| Improvement over chance | **9.4×** |
| Dataset | CIFAR-10 test set (1000 samples, 100 per class) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |

Strong performance across all classes: airplane (0.98 F1), ship (0.98 F1), truck (0.97 F1), automobile (0.97 F1). Weakest class is cat (0.86 F1) — a fine-grained category that is inherently harder.
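The probe protocol above can be reproduced in outline with scikit-learn. The features below are synthetic stand-ins; in practice `X` holds the frozen mean-pooled 1152-dim embeddings and `y` the CIFAR-10 labels.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
y = rng.integers(0, 10, size=1000)                # stand-in CIFAR-10 labels
class_means = rng.normal(size=(10, 1152))         # stand-in class structure
X = rng.normal(size=(1000, 1152)) + np.eye(10)[y] @ class_means  # stand-in embeddings

X = X / np.linalg.norm(X, axis=1, keepdims=True)  # L2-normalize, as in the probe
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
print(f"linear probe accuracy: {clf.score(X_te, y_te):.3f}")
```

Since the encoder stays frozen, only the logistic-regression weights are trained, so the probe score directly reflects embedding quality.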

## Files in This Repo

| File | Description | Size |
|---|---|---|
| `config.json` | Vision encoder config (Gemma4VisionConfig) | <1 KB |
| `model.safetensors` | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| `preprocessor_config.json` | Image processor config (Gemma4ImageProcessor) | <1 KB |
| `embed_vision.safetensors` | Vision→text embedding projection (1152→5376) | 12.4 MB |

## Limitations

- **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder. The 1152-dim output is the pure vision representation; the `embed_vision` projection maps to the 31B's text hidden space (5376-dim).
- **Requires image processor:** Use the `Gemma4ImageProcessor` included in this repo for preprocessing. The model expects pre-patchified `(B, num_patches, 768)` tensors with explicit 2D position IDs — the processor handles this automatically.
- **Variable aspect ratio support:** The 2D position embeddings enable non-square images. The processor generates correct position IDs for any aspect ratio.
- **Output shape note:** The pooler strips padding and collapses the batch dimension, returning `(num_valid_tokens, 1152)`. For batched inference, use `num_soft_tokens_per_image` from the processor to split the output back into per-image embeddings.
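The output-shape note above can be handled with `torch.split`; the sizes here are illustrative, with `tokens_per_image` standing in for the processor's `num_soft_tokens_per_image` values.

```python
import torch

# The pooler collapses the batch: (num_valid_tokens, 1152) for the whole batch
embeddings = torch.randn(420, 1152)
tokens_per_image = [280, 140]  # stand-in for the processor's num_soft_tokens_per_image

per_image = torch.split(embeddings, tokens_per_image, dim=0)
image_vectors = [e.float().mean(dim=0) for e in per_image]  # one (1152,) vector per image
print([tuple(e.shape) for e in per_image])  # [(280, 1152), (140, 1152)]
```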

## Extraction Details

- Extracted from `google/gemma-4-31B-it` by downloading only the shard containing vision tower weights (`model-00001-of-00002.safetensors`)
- No full model load required — targeted tensor extraction
- Weights loaded with `strict=True` — perfect match
- Forward pass verified: 864×864 image → (324, 1152) output
- All architecture specs verified against the live model config

## References

- [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-31B-it)
- [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4)
- [Gemma 4 Architecture Comparison](https://g4.si5.pl/)