rnagabh committed
Commit 93f12f3 · verified · 1 Parent(s): f1ed3d8

Initial upload: Gemma 4 vision encoder (569.6M, 27-layer ViT with 2D RoPE)

Files changed (4)
  1. README.md +153 -0
  2. config.json +49 -0
  3. embed_vision.safetensors +3 -0
  4. model.safetensors +3 -0
README.md ADDED

---
language:
- en
- multilingual
license: apache-2.0
library_name: transformers
tags:
- feature-extraction
- image-feature-extraction
- vision
- vit
- gemma4
- google
- safetensors
pipeline_tag: image-feature-extraction
base_model: google/gemma-4-31B-it
model-index:
- name: gemma4-vision-encoder
  results: []
---

# Gemma 4 Vision Encoder (27-layer ViT with 2D RoPE)

Standalone extraction of the vision encoder from Google's [Gemma 4 31B](https://huggingface.co/google/gemma-4-31B-it) multimodal model. This is a 569.6M-parameter Vision Transformer with learned 2D positional embeddings, 2D RoPE, QK-norms, and a gated MLP, a significant upgrade from the SigLIP encoder used in Gemma 3.

**License:** Apache 2.0 (inherited from Gemma 4; no restrictions)

## Architecture

| Property | Value |
|---|---|
| Total parameters | 569.6M |
| Architecture | ViT with 2D RoPE + learned positional embeddings |
| Hidden dimension | 1152 |
| Encoder layers | 27 |
| Attention heads | 16 (72 dims per head) |
| KV heads | 16 (full MHA, no GQA) |
| MLP | Gated (gate_proj + up_proj + down_proj) |
| MLP intermediate size | 4304 |
| Activation | GELU (`gelu_pytorch_tanh` variant) |
| Normalization | RMSNorm (eps=1e-6) |
| Patch size | 16×16 |
| Pooling | 3×3 kernel (reduces token count by 9×) |
| Position embeddings | Learned 2D table (2, 10240, 1152) + RoPE (theta=100) |
| Q/K norms | Yes |
| Default output tokens | 280 |
| Configurable token budgets | 70, 140, 280, 560, 1120 |
| Input | Pre-patchified: `(batch, num_patches, 768)` where 768 = 3×16×16 |
| Output | `(num_valid_tokens, 1152)` after pooling + standardization |

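The token arithmetic in the table can be sanity-checked in a few lines. This is a plain-Python sketch (the function name is mine, not part of the model code): an H×W image yields (H/16)·(W/16) patches, and 3×3 pooling divides that count by 9.

```python
def pooled_token_count(height: int, width: int, patch: int = 16, pool: int = 3) -> int:
    """Visual tokens produced for a valid image size (each side divisible by 48)."""
    if height % (patch * pool) or width % (patch * pool):
        raise ValueError("each side must be divisible by patch_size * pooling_kernel (48)")
    num_patches = (height // patch) * (width // patch)
    return num_patches // (pool * pool)

print(pooled_token_count(864, 864))  # 2916 patches -> 324 tokens
print(pooled_token_count(768, 576))  # 1728 patches -> 192 tokens
```

The 864×864 case reproduces the `(324, 1152)` output shape used in the Usage example.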
### What's New vs Gemma 3 (SigLIP)

| | Gemma 3 Vision | Gemma 4 Vision (this model) |
|---|---|---|
| Architecture | SigLIP (ViT-SO400M) | Custom ViT with 2D RoPE |
| Layers | 27 | 27 |
| Hidden dim | 1152 | 1152 |
| Position encoding | Learned 1D | **Learned 2D + RoPE** |
| Attention | Standard | **QK-normed** |
| MLP | Standard (fc1 + fc2) | **Gated (gate + up + down)** |
| Aspect ratio | Fixed square (896×896) | **Variable aspect ratio** |
| Token budget | Fixed 256 | **Configurable (70–1120)** |
| Pooling | 4×4 average | **3×3** |

### Not Shared with E2B/E4B

Unlike the audio encoder (which is identical across E2B and E4B), the vision encoders differ:

| | E2B/E4B | 31B (this extraction) |
|---|---|---|
| Layers | 16 | **27** |
| Parameters | ~340M | **569.6M** |

## Usage

```python
import torch
from transformers import Gemma4VisionModel, Gemma4VisionConfig
from safetensors.torch import load_file

# Load vision encoder from this repo
cfg = Gemma4VisionConfig.from_pretrained("rnagabh/gemma4-vision-encoder")
vision_model = Gemma4VisionModel(cfg)
state_dict = load_file("path/to/model.safetensors")  # or download from repo
vision_model.load_state_dict(state_dict, strict=True)
vision_model = vision_model.to(dtype=torch.bfloat16, device="cuda")
vision_model.eval()

# Prepare image: patchify and create position IDs.
# Image sides must be divisible by patch_size (16) AND
# num_patches must be divisible by pooling_kernel^2 (9).
# Good sizes: 864 (54 patches/side), 768 (48), 576 (36)
P = 16
img_size = 864
patches_per_side = img_size // P  # 54

# Patchify: (B, C, H, W) → (B, num_patches, C*P*P)
img = torch.randn(1, 3, img_size, img_size, dtype=torch.bfloat16, device="cuda")
patches = img.unfold(2, P, P).unfold(3, P, P)  # (1, 3, 54, 54, 16, 16)
patches = patches.contiguous().view(1, 3, -1, P, P)
patches = patches.permute(0, 2, 1, 3, 4)       # (1, 2916, 3, 16, 16)
patches = patches.reshape(1, -1, 3 * P * P)    # (1, 2916, 768)

# Position IDs: (batch, num_patches, 2) as (x, y) coordinates
ys, xs = torch.meshgrid(
    torch.arange(patches_per_side),
    torch.arange(patches_per_side),
    indexing="ij",
)
position_ids = torch.stack([xs.flatten(), ys.flatten()], dim=-1)
position_ids = position_ids.unsqueeze(0).to(device="cuda")  # (1, 2916, 2)

with torch.no_grad():
    output = vision_model(pixel_values=patches, pixel_position_ids=position_ids)
embeddings = output.last_hidden_state  # (324, 1152), pooled tokens
```

> **Image size constraints:** The number of patches must be divisible by the pooling kernel² (9).
> This means each image dimension divided by patch_size (16) must be divisible by 3, i.e. each side must be a multiple of 48.
> Valid image sizes include 576, 768, 864, 960, and 1152.
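A quick way to pick a conforming size is to round each side to the nearest multiple of 48 (patch_size × pooling_kernel). The helper below is an illustrative suggestion, not part of the parent model's processor:

```python
def nearest_valid_side(side: int, patch: int = 16, pool: int = 3) -> int:
    """Round one image dimension to the nearest multiple of patch * pool (48)."""
    step = patch * pool
    return max(step, round(side / step) * step)

print(nearest_valid_side(900))   # -> 912
print(nearest_valid_side(1000))  # -> 1008
```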
> **Output shape:** The batch dimension is collapsed: the pooler strips padding and returns
> a flat `(num_valid_tokens, hidden_dim)` tensor. For a single 864×864 image, you get
> `(324, 1152)`: 324 pooled visual tokens at 1152 dimensions.

## Files in This Repo

| File | Description | Size |
|---|---|---|
| `config.json` | Vision encoder config (`Gemma4VisionConfig`) | <1 KB |
| `model.safetensors` | Vision encoder weights (569.6M params, BF16) | 1,139 MB |
| `embed_vision.safetensors` | Vision→text embedding projection (1152→5376) | 12.4 MB |

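The file sizes are consistent with the stated shapes. The byte arithmetic below is my own check derived from the numbers above: a BF16 (1152 → 5376) projection is 1152 × 5376 × 2 bytes, and the encoder holds 569,550,384 BF16 parameters; the small remainders in the actual files are the safetensors headers/metadata.

```python
# Expected payload bytes for the two weight files (BF16 = 2 bytes/param).
proj_bytes = 1152 * 5376 * 2      # embed_vision projection matrix
encoder_bytes = 569_550_384 * 2   # verified param count from config.json
print(proj_bytes)     # 12386304   (file is 12386408)
print(encoder_bytes)  # 1139100768 (file is 1139143360)
```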
## Limitations

- **Trained end-to-end for LLM decoding:** The encoder was trained to produce features for Gemma 4's text decoder. The 1152-dim output is the pure vision representation; the `embed_vision` projection maps it into the 31B's text hidden space (5376-dim).
- **Requires pre-patchified input:** Unlike standard ViT models that accept raw `(B, C, H, W)` images, this model expects pre-patchified `(B, num_patches, 768)` tensors with explicit position IDs.
- **Variable aspect ratio support:** The 2D position embeddings enable non-square images, but you must provide correct `pixel_position_ids` for each patch.
- **No built-in image preprocessing:** You must handle resizing, normalization (the model applies `2*(x-0.5)` internally), and patchification yourself, or use the parent model's processor.

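For non-square inputs, the `pixel_position_ids` follow directly from the patch grid. A minimal torch-only sketch, assuming the same row-major (x, y) layout as the square-image Usage example (the function name is mine):

```python
import torch

def make_position_ids(h_patches: int, w_patches: int) -> torch.Tensor:
    """Row-major (x, y) coordinates per patch: shape (1, h_patches*w_patches, 2)."""
    ys, xs = torch.meshgrid(
        torch.arange(h_patches), torch.arange(w_patches), indexing="ij"
    )
    return torch.stack([xs.flatten(), ys.flatten()], dim=-1).unsqueeze(0)

# e.g. a 576 (H) x 864 (W) image -> 36 x 54 patch grid
pos = make_position_ids(36, 54)
print(pos.shape)  # torch.Size([1, 1944, 2])
```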
## Extraction Details

- Extracted from `google/gemma-4-31B-it` by downloading only the shard containing the vision tower weights (`model-00001-of-00002.safetensors`)
- No full model load required; targeted tensor extraction
- Weights loaded with `strict=True`: perfect match
- Forward pass verified: 864×864 image → (324, 1152) output
- All architecture specs verified against the live model config

## References

- [Gemma 4 on HuggingFace](https://huggingface.co/google/gemma-4-31B-it)
- [Gemma 4 Blog Post](https://huggingface.co/blog/gemma4)
- [Gemma 4 Architecture Comparison](https://g4.si5.pl/)

config.json ADDED
{
  "_name_or_path": "",
  "architectures": [
    "Gemma4VisionModel"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "chunk_size_feed_forward": 0,
  "default_output_length": 280,
  "dtype": "bfloat16",
  "global_head_dim": 72,
  "head_dim": 72,
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 1152,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4304,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "max_position_embeddings": 131072,
  "model_type": "gemma4_vision",
  "num_attention_heads": 16,
  "num_hidden_layers": 27,
  "num_key_value_heads": 16,
  "output_attentions": false,
  "output_hidden_states": false,
  "patch_size": 16,
  "pooling_kernel_size": 3,
  "position_embedding_size": 10240,
  "problem_type": null,
  "return_dict": true,
  "rms_norm_eps": 1e-06,
  "rope_parameters": {
    "rope_theta": 100.0,
    "rope_type": "default"
  },
  "standardize": true,
  "use_clipped_linears": false,
  "torch_dtype": "bfloat16",
  "_source_model": "google/gemma-4-31B-it",
  "_extraction_note": "Vision tower extracted from 31B model",
  "_verified_total_params": 569550384
}
embed_vision.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:7de065bb74191a84c11c7a436eba29278ae1934b92c386f6fcfa515b2d5c0c7f
size 12386408
model.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:408421234ecaab33c9b641efe86d369d8398252d812c23faccfbd1cb1d744ccb
size 1139143360