---
language:
- en
- multilingual
license: apache-2.0
library_name: transformers
tags:
- feature-extraction
- image-feature-extraction
- vision
- vit
- gemma4
- google
- safetensors
pipeline_tag: image-feature-extraction
base_model: google/gemma-4-31B-it
model-index:
- name: gemma4-vision-encoder
  results:
  - task:
      type: image-classification
      name: CIFAR-10 (10-class)
    dataset:
      name: CIFAR-10
      type: cifar10
      split: test
    metrics:
    - type: accuracy
      value: 94.0
      name: Linear Probe Accuracy
---

# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)

Standalone extraction of the vision encoder from Google's [Gemma 4 31B](https://huggingface.co/google/gemma-4-31B-it) multimodal model. This is a 569.6M-parameter Vision Transformer with learned 2D positional embeddings, 2D RoPE, QK-norms, and a gated MLP — a significant upgrade over the SigLIP encoder used in Gemma 3.

**License:** Apache 2.0 (inherited from Gemma 4 — no restrictions)

## Architecture

| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dim per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate | 4304 |
| Activation | GELU (pytorch_tanh variant) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: `(batch, num_patches, 768)` where 768 = 3×16×16 |
| Output | `(num_valid_tokens, 1152)` after pooling + standardization |

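To make the input layout concrete, here is a minimal sketch of how an image maps to the pre-patchified `(num_patches, 768)` format, where 768 = 3×16×16. This is an illustration of the tensor layout only — in practice the bundled `Gemma4ImageProcessor` must be used, since it also handles resizing, normalization, and position IDs:

```python
import numpy as np

# Illustration only: flatten an image into (num_patches, 768) patches,
# where 768 = 3 channels * 16 * 16. Use the bundled Gemma4ImageProcessor
# for real preprocessing.
PATCH = 16
img = np.random.rand(3, 64, 64).astype(np.float32)  # (C, H, W)

c, h, w = img.shape
hp, wp = h // PATCH, w // PATCH  # patches per side: 4 x 4 here
patches = (
    img.reshape(c, hp, PATCH, wp, PATCH)
    .transpose(1, 3, 0, 2, 4)              # (hp, wp, C, 16, 16)
    .reshape(hp * wp, c * PATCH * PATCH)   # (num_patches, 768)
)
print(patches.shape)  # (16, 768)
```

Each row is one 16×16 patch flattened channel-first, which matches the `(batch, num_patches, 768)` input shape listed in the table above.
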
### What's New vs Gemma 3 (SigLIP)

| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | **Learned 2D + RoPE** |
| Attention | Standard | **QK-normed** |
| MLP | Standard (fc1 + fc2) | **Gated (gate + up + down)** |
| Aspect ratio | Fixed square (896×896) | **Variable aspect ratio** |
| Token budget | Fixed 256 | **Configurable (70–1120)** |
| Pooling | 4×4 average | **3×3** |

### Not Shared with E2B/E4B

Unlike the audio encoder (which is identical across E2B and E4B), the vision encoders differ:

| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | **27** |
| Parameters | ~340M | **569.6M** |

## Usage

```python
import torch
from transformers import Gemma4VisionModel, Gemma4ImageProcessor
from PIL import Image

# Load vision encoder directly from this repo
vision_model = Gemma4VisionModel.from_pretrained(
    "rnagabh/gemma4-vision-encoder",
    torch_dtype=torch.bfloat16,
)
vision_model.to("cuda")
vision_model.eval()

# Load image processor (saved in this repo)
image_processor = Gemma4ImageProcessor.from_pretrained("rnagabh/gemma4-vision-encoder")

# Process an image
img = Image.open("your_image.jpg")
processed = image_processor(images=[img], return_tensors="pt")

pixel_values = processed["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
position_ids = processed["image_position_ids"].to(device="cuda")
tokens_per_image = processed["num_soft_tokens_per_image"]  # for splitting batch output

with torch.no_grad():
    output = vision_model(pixel_values=pixel_values, pixel_position_ids=position_ids)
    embeddings = output.last_hidden_state  # (num_tokens, 1152)

    # Mean-pool for a single image vector
    image_embedding = embeddings.float().mean(dim=0)  # (1152,)
```

> **Important:** Always use the `Gemma4ImageProcessor` included in this repo for preprocessing.
> It handles resizing, patchification, position ID generation, and pixel normalization.
> Manual patchification without this processor will produce significantly degraded results.

## Benchmark Results (frozen 1152-dim embeddings, linear probe)

### CIFAR-10 Classification

| Metric | Value |
|---|---|
| Linear probe accuracy | **94.0%** |
| Random baseline | 10.0% |
| Improvement over chance | **9.4×** |
| Dataset | CIFAR-10 test set (1000 samples, 100 per class) |
| Probe | Logistic regression on L2-normalized mean-pooled embeddings |

Strong performance across all classes: airplane (0.98 F1), ship (0.98 F1), truck (0.97 F1), automobile (0.97 F1). Weakest class is cat (0.86 F1) — a fine-grained category that is inherently harder.
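The probe protocol described above can be sketched as follows. Random features stand in for the real mean-pooled 1152-dim embeddings, so this shows the pipeline shape, not the reported accuracy:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import normalize

# Sketch of the linear-probe setup: logistic regression on
# L2-normalized mean-pooled embeddings. Random vectors stand in
# for real encoder outputs.
rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 1152)).astype(np.float32)  # 1000 samples
y = rng.integers(0, 10, size=1000)                        # 10 CIFAR-10 classes

X = normalize(X)  # L2-normalize each embedding, as in the benchmark
probe = LogisticRegression(max_iter=1000).fit(X, y)
preds = probe.predict(X)
print(preds.shape)  # (1000,)
```
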

## Files in This Repo

| File | Description | Size |
|---|---|---|
| `config.json` | Vision encoder config (Gemma4VisionConfig) | <1 KB |
| `model.safetensors` | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| `preprocessor_config.json` | Image processor config (Gemma4ImageProcessor) | <1 KB |
| `embed_vision.safetensors` | Vision→text embedding projection (1152→5376) | 12.4 MB |
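The `embed_vision` projection is a single 1152→5376 linear map (the 12.4 MB file size is consistent with one 1152×5376 BF16 matrix). A sketch of applying it, with a random matrix standing in for the actual weights since the tensor names inside `embed_vision.safetensors` are not documented here:

```python
import numpy as np

# Sketch of the vision-to-text projection: 1152-dim vision embeddings
# mapped into the 31B decoder's 5376-dim hidden space. A random matrix
# stands in for the real weights in embed_vision.safetensors.
rng = np.random.default_rng(0)
W = rng.standard_normal((1152, 5376)).astype(np.float32)

vision_emb = rng.standard_normal((280, 1152)).astype(np.float32)  # default token budget
text_space = vision_emb @ W
print(text_space.shape)  # (280, 5376)
```
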

## Limitations

- **End-to-end trained for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder. The 1152-dim output is the pure vision representation; the `embed_vision` projection maps to the 31B's text hidden space (5376-dim).
- **Requires image processor:** Use the `Gemma4ImageProcessor` included in this repo for preprocessing. The model expects pre-patchified `(B, num_patches, 768)` tensors with explicit 2D position IDs — the processor handles this automatically.
- **Variable aspect ratio support:** The 2D position embeddings enable non-square images. The processor generates correct position IDs for any aspect ratio.
- **Output shape note:** The pooler strips padding and collapses the batch dimension, returning `(num_valid_tokens, 1152)`. For batched inference, use `num_soft_tokens_per_image` from the processor to split the output back into per-image embeddings.
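Splitting the collapsed batch output back into per-image embeddings can be sketched with `torch.split` (shapes and token counts here are illustrative; in practice the counts come from `num_soft_tokens_per_image` in the processor output):

```python
import torch

# The encoder returns one (num_valid_tokens, 1152) tensor for the whole
# batch; the per-image token counts from the processor say where each
# image's tokens begin and end. Illustrative shapes below.
embeddings = torch.randn(280 + 140, 1152)  # two images' tokens, concatenated
tokens_per_image = [280, 140]              # from num_soft_tokens_per_image

per_image = torch.split(embeddings, tokens_per_image, dim=0)
print([t.shape for t in per_image])  # [(280, 1152), (140, 1152)]
```
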

## Extraction Details

- Extracted from `google/gemma-4-31B-it` by downloading only the shard containing vision tower weights (`model-00001-of-00002.safetensors`)
- No full model load required — targeted tensor extraction
- Weights loaded with `strict=True` — perfect match
- Forward pass verified: 864×864 image → (324, 1152) output
- All architecture specs verified against the live model config

## References

- [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-31B-it)
- [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4)
- [Gemma 4 Architecture Comparison](https://g4.si5.pl/)