---
language:
- en
- multilingual
license: apache-2.0
library_name: transformers
tags:
- feature-extraction
- image-feature-extraction
- vision
- vit
- gemma4
- google
- safetensors
pipeline_tag: image-feature-extraction
base_model: google/gemma-4-31B-it
model-index:
- name: gemma4-vision-encoder
  results:
  - task:
      type: image-classification
      name: CIFAR-10 (10-class)
    dataset:
      name: CIFAR-10
      type: cifar10
      split: test
    metrics:
    - type: accuracy
      value: 94.0
      name: Linear Probe Accuracy
---

# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)

Standalone extraction of the vision encoder from Google's [Gemma 4 31B](https://huggingface.co/google/gemma-4-31B-it) multimodal model. This is a 569.6M-parameter Vision Transformer with learned 2D positional embeddings, RoPE, QK-norms, and a gated MLP — a significant upgrade from the SigLIP encoder used in Gemma 3.

**License:** Apache 2.0 (inherited from Gemma 4 — no restrictions)

## Architecture

| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dim per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate | 4304 |
| Activation | GELU (pytorch_tanh variant) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: `(batch, num_patches, 768)` where 768 = 3×16×16 |
| Output | `(num_valid_tokens, 1152)` after pooling + standardization |

### What's New vs Gemma 3 (SigLIP)

| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | **Learned 2D + RoPE** |
| Attention | Standard | **QK-normed** |
| MLP | Standard (fc1 + fc2) | **Gated (gate + up + down)** |
| Aspect ratio | Fixed square (896×896) | **Variable aspect ratio** |
| Token budget | Fixed 256 | **Configurable (70–1120)** |
| Pooling | 4×4 average | **3×3** |

### Not Shared with E2B/E4B

Unlike the audio encoder (which is identical across E2B and E4B), the vision encoders differ:

| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | **27** |
| Parameters | ~340M | **569.6M** |

## Usage

```python
import torch
from transformers import Gemma4VisionModel, Gemma4ImageProcessor
from PIL import Image

# Load vision encoder directly from this repo
vision_model = Gemma4VisionModel.from_pretrained(
    "rnagabh/gemma4-vision-encoder",
    torch_dtype=torch.bfloat16,
)
vision_model.to("cuda")
vision_model.eval()

# Load image processor (saved in this repo)
image_processor = Gemma4ImageProcessor.from_pretrained("rnagabh/gemma4-vision-encoder")

# Process an image
img = Image.open("your_image.jpg")
processed = image_processor(images=[img], return_tensors="pt")
pixel_values = processed["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
position_ids = processed["image_position_ids"].to(device="cuda")
tokens_per_image = processed["num_soft_tokens_per_image"]  # for splitting batch output

with torch.no_grad():
    output = vision_model(pixel_values=pixel_values, pixel_position_ids=position_ids)

embeddings = output.last_hidden_state  # (num_tokens, 1152)

# Mean-pool for a single image vector
image_embedding = embeddings.float().mean(dim=0)  # (1152,)
```

> **Important:** Always use the `Gemma4ImageProcessor` included in this repo for preprocessing.
> It handles resizing, patchification, position ID generation, and pixel normalization.
> Manual patchification without this processor will produce significantly degraded results.
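The patch and pooling geometry determines the output token count: a 16×16 patch grid is pooled with a 3×3 kernel, a 9× reduction. A quick sanity check in plain Python (a sketch that assumes the image dimensions divide evenly by both the patch size and the pooling kernel; in practice the processor resizes to fit a token budget):

```python
def output_tokens(height: int, width: int, patch_size: int = 16, pool: int = 3) -> int:
    """Vision tokens remaining after patchification and pooling."""
    patches_h, patches_w = height // patch_size, width // patch_size
    # 3x3 pooling merges each 3x3 block of patches into one token (9x reduction)
    return (patches_h // pool) * (patches_w // pool)

print(output_tokens(864, 864))  # 54x54 patches -> 18x18 tokens = 324
```

This matches the 864×864 → 324-token forward pass noted under Extraction Details.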
## Benchmark Results (frozen 1152-dim embeddings, linear probe)

### CIFAR-10 Classification

| Metric | Value |
|---|---|
| Linear probe accuracy | **94.0%** |
| Random baseline | 10.0% |
| Improvement over chance | **9.4×** |
| Dataset | CIFAR-10 test set (1000 samples, 100 per class) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |

Strong performance across all classes: airplane (0.98 F1), ship (0.98 F1), truck (0.97 F1), automobile (0.97 F1). The weakest class is cat (0.86 F1) — a fine-grained category that is inherently harder.

## Files in This Repo

| File | Description | Size |
|---|---|---|
| `config.json` | Vision encoder config (Gemma4VisionConfig) | <1 KB |
| `model.safetensors` | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| `preprocessor_config.json` | Image processor config (Gemma4ImageProcessor) | <1 KB |
| `embed_vision.safetensors` | Vision→text embedding projection (1152→5376) | 12.4 MB |

## Limitations

- **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder. The 1152-dim output is the pure vision representation; the `embed_vision` projection maps to the 31B's text hidden space (5376-dim).
- **Requires image processor:** Use the `Gemma4ImageProcessor` included in this repo for preprocessing. The model expects pre-patchified `(B, num_patches, 768)` tensors with explicit 2D position IDs — the processor handles this automatically.
- **Variable aspect ratio support:** The 2D position embeddings enable non-square images. The processor generates correct position IDs for any aspect ratio.
- **Output shape note:** The pooler strips padding and collapses the batch dimension, returning `(num_valid_tokens, 1152)`. For batched inference, use `num_soft_tokens_per_image` from the processor to split the output back into per-image embeddings.
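Because the pooler flattens the batch dimension, recovering per-image embeddings is a `torch.split` along dim 0. A sketch with dummy tensors standing in for the model output (in real use, `tokens_per_image` comes from `processed["num_soft_tokens_per_image"]` as shown in Usage):

```python
import torch

# Dummy stand-ins: two images' pooled tokens, concatenated by the pooler
embeddings = torch.randn(280 + 324, 1152)  # (num_valid_tokens, 1152)
tokens_per_image = [280, 324]              # from the image processor

# Split back into one (tokens_i, 1152) tensor per image
per_image = torch.split(embeddings, tokens_per_image, dim=0)

# Mean-pool each to a single (1152,) image vector
image_vectors = [e.float().mean(dim=0) for e in per_image]

print([tuple(e.shape) for e in per_image])  # [(280, 1152), (324, 1152)]
```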
## Extraction Details

- Extracted from `google/gemma-4-31B-it` by downloading only the shard containing vision tower weights (`model-00001-of-00002.safetensors`)
- No full model load required — targeted tensor extraction
- Weights loaded with `strict=True` — perfect match
- Forward pass verified: 864×864 image → (324, 1152) output
- All architecture specs verified against the live model config

## References

- [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-31B-it)
- [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4)
- [Gemma 4 Architecture Comparison](https://g4.si5.pl/)