Commit ea203cb · verified · committed by ncylich
1 Parent(s): fc43e86

Upload gemma4.py with huggingface_hub

Files changed (1)
  1. gemma4.py +652 -0
gemma4.py ADDED
@@ -0,0 +1,652 @@
+ """
+ Gemma 4 E2B: clean PyTorch forward pass (text model only).
+
+ Architecture:
+   - 35 decoder layers, hidden_size=1536, vocab=262144
+   - 8 Q heads, 1 KV head (MQA)
+   - Sliding attention layers (0-3, 5-8, 10-13, 15-18, 20-23, 25-28, 30-33):
+       head_dim=256, sliding_window=512, rope_theta=10000
+   - Full attention layers (every 5th: 4,9,14,19,24,29,34):
+       head_dim=512, partial_rotary_factor=0.25 (128 of 512 dims rotated),
+       rope_theta=1000000
+   - MLP: GeGLU, intermediate_size=6144 (layers 0-14), 12288 (layers 15-34)
+   - Per-layer auxiliary stream (full details below)
+   - layer_scalar: per-layer learned scalar multiplied onto residual contributions
+   - QK RMSNorm before RoPE, attn_scale=1.0
+   - Final: RMSNorm + tied lm_head + logit softcapping at 30.0
+
+ Per-layer auxiliary stream:
+   Model-level (computed once, before all layers):
+     1. embed_tokens_per_layer(input_ids) → [B, T, 35*256]  (vocab lookup),
+        scaled by sqrt(256)
+     2. per_layer_model_projection(x_embed) → [B, T, 35*256]  (project hidden→aux),
+        scaled by hidden_size**-0.5
+     3. per_layer_projection_norm (RMSNorm(256)) on the projection slice per layer
+     4. Combine: per_layer_inputs = (embed_aux + proj_aux) * (1/sqrt(2)),
+        reshaped to [B, T, 35, 256]
+
+   Per-layer (at layer i):
+     per_layer_input_i = per_layer_inputs[:, :, i, :]       # [B, T, 256]
+     x_normed = input_layernorm(x)
+     gate = gelu_tanh(per_layer_input_gate(x_normed))       # [B, T, 256]
+     gated = gate * per_layer_input_i                       # [B, T, 256]
+     out = per_layer_projection(gated)                      # [B, T, 1536]  (256→1536)
+     x = x + layer_scalar * post_per_layer_input_norm(out)
+
+ Weight shapes in checkpoint:
+   per_layer_model_projection.weight         : [8960, 1536]  (Linear 1536→8960)
+   per_layer_projection_norm.weight          : [256]         (RMSNorm on 256-dim slices)
+   layers.i.per_layer_input_gate.weight      : [256, 1536]   (Linear 1536→256)
+   layers.i.per_layer_projection.weight      : [1536, 256]   (Linear 256→1536)
+   layers.i.post_per_layer_input_norm.weight : [1536]        (RMSNorm on 1536-dim output)
+ """
+
+ import math
+ import os
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from safetensors import safe_open
+ from transformers import AutoTokenizer
+
+ # ── device / dtype ────────────────────────────────────────────────────────────
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ DTYPE = torch.bfloat16
+
+ # ── model path ────────────────────────────────────────────────────────────────
+ # Try known HF repo caches in order; first one that exists wins. Override with
+ # $GEMMA4_HF_REPO to point at an arbitrary repo cache (e.g., "google/gemma-4-e2b-it").
+ _HUB_ROOT = Path(os.path.expanduser("~/.cache/huggingface/hub"))
+ _REPO_CANDIDATES = (
+     os.environ.get("GEMMA4_HF_REPO", ""),
+     "gg-hf-gg/gemma-4-E2B",
+     "google/gemma-4-e2b-it",
+ )
+
+
+ def _resolve_model_paths():
+     """Return (snapshot_dir, safetensors_path). Picks first available repo+snapshot
+     that actually contains a .safetensors file. Iterates ALL snapshots per repo
+     before moving to the next repo: iterdir() order is not deterministic and HF
+     may keep multiple snapshots where only one has weights blob-resolved.
+     """
+     for repo in _REPO_CANDIDATES:
+         if not repo:
+             continue
+         repo_cache = _HUB_ROOT / ("models--" + repo.replace("/", "--"))
+         snap_root = repo_cache / "snapshots"
+         if not snap_root.is_dir():
+             continue
+         for snap in sorted(p for p in snap_root.iterdir() if p.is_dir()):
+             # Prefer model.safetensors (single-file) else any .safetensors
+             sft = snap / "model.safetensors"
+             if not sft.exists():
+                 candidates = sorted(snap.glob("*.safetensors"))
+                 if not candidates:
+                     continue
+                 sft = candidates[0]
+             return snap, sft
+     raise FileNotFoundError(
+         "No Gemma-4 E2B HF cache found. Tried: " + ", ".join(r for r in _REPO_CANDIDATES if r)
+         + ". Run `hf download google/gemma-4-e2b-it` or set GEMMA4_HF_REPO."
+     )
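Review note: the lookup above depends on how a repo id is turned into a cache folder name. The sketch below isolates that mapping so it can be checked without touching the filesystem. `repo_cache_dirname` and `snapshots_dir` are hypothetical helper names used only for illustration; they are not part of the commit.

```python
from pathlib import PurePosixPath


def repo_cache_dirname(repo_id: str) -> str:
    """Mirror the diff's mapping: 'org/name' -> 'models--org--name'."""
    return "models--" + repo_id.replace("/", "--")


def snapshots_dir(hub_root: str, repo_id: str) -> PurePosixPath:
    """Build the snapshots directory that the resolver would scan (pure path math)."""
    return PurePosixPath(hub_root) / repo_cache_dirname(repo_id) / "snapshots"


print(repo_cache_dirname("google/gemma-4-e2b-it"))
print(snapshots_dir("/hub", "gg-hf-gg/gemma-4-E2B"))
```

Keeping the mapping in one small function would also make the error message easier to extend with the exact directories that were tried.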
+
+
+ MODEL_DIR, SAFETENSORS_BLOB = _resolve_model_paths()
+
+ # ── architecture constants ────────────────────────────────────────────────────
+ N_LAYERS = 35
+ HIDDEN_SIZE = 1536
+ VOCAB_SIZE = 262144
+ N_Q_HEADS = 8
+ N_KV_HEADS = 1
+ HEAD_DIM_SLIDE = 256       # sliding attention head dim
+ HEAD_DIM_FULL = 512        # full attention head dim
+ PER_LAYER_DIM = 256        # per-layer auxiliary stream width per layer
+ INTERMEDIATE = 6144        # MLP intermediate size (layers 0-14)
+ INTERMEDIATE_WIDE = 12288  # double-wide MLP intermediate size (layers 15-34)
+ # Layers 15-34 use double-wide MLP (use_double_wide_mlp=True in config)
+ DOUBLE_WIDE_START = 15
+ SLIDING_WINDOW = 512
+ ROPE_THETA_SLIDE = 10_000.0
+ ROPE_THETA_FULL = 1_000_000.0
+ PARTIAL_ROT_FULL = 0.25    # floor(512*0.25)=128 dims (64 rotation pairs) get RoPE
+ RMS_EPS = 1e-6
+ LOGIT_CAP = 30.0
+ ATTN_SCALE = 1.0           # QK are RMSNorm'd, so no sqrt(d) scaling needed
+
+ # Per-layer projection scale: hidden_size**-0.5 (applied to per_layer_model_projection output)
+ PER_LAYER_PROJ_SCALE = HIDDEN_SIZE ** -0.5
+ # Input combination scale: 1/sqrt(2) (mix embed aux + model projection)
+ PER_LAYER_INPUT_SCALE = math.sqrt(0.5)  # = 1/sqrt(2)
+
+ # Full-attention layers: every 5th layer (0-indexed: 4,9,14,19,24,29,34)
+ FULL_ATTN_LAYERS = frozenset(range(4, N_LAYERS, 5))
+
+
+ def is_full_attention(layer_idx: int) -> bool:
+     """Return True if layer_idx uses full (global) attention."""
+     return layer_idx in FULL_ATTN_LAYERS
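Review note: the layer schedule implied by `FULL_ATTN_LAYERS` can be enumerated directly. This is a small standalone sketch that re-declares the two constants from the diff; `layer_kinds` is a hypothetical helper introduced here for illustration.

```python
N_LAYERS = 35
FULL_ATTN_LAYERS = frozenset(range(4, N_LAYERS, 5))


def layer_kinds(n_layers: int = N_LAYERS) -> list[str]:
    """Label each layer index as 'full' or 'sliding' using the same rule as the diff."""
    return ["full" if i in FULL_ATTN_LAYERS else "sliding" for i in range(n_layers)]


kinds = layer_kinds()
print(sorted(FULL_ATTN_LAYERS))                      # [4, 9, 14, 19, 24, 29, 34]
print(kinds.count("sliding"), kinds.count("full"))   # 28 7
```

The pattern is four sliding layers followed by one full-attention layer, repeated seven times.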
+
+
+ # ── RMSNorm ───────────────────────────────────────────────────────────────────
+
+ class RMSNorm(nn.Module):
+     """RMSNorm with weight * normed, weight initialized to ones."""
+
+     def __init__(self, dim: int):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_f32 = x.float()
+         normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + RMS_EPS)
+         return (normed * self.weight.float()).to(x.dtype)
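Review note: the normalisation above divides by the root mean square of the last dimension and then multiplies by the learned weight. A dependency-free sketch of the same arithmetic on a plain list is shown below; `rms_norm` is a hypothetical name, and the list-based form is only for illustration, not a replacement for the tensor version.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """Scale x by 1/sqrt(mean(x^2) + eps), then multiply elementwise by weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


x = [3.0, -4.0, 0.0, 12.0]
y = rms_norm(x, [1.0] * len(x))

# With unit weights the output should have a root mean square very close to 1.
rms_out = math.sqrt(sum(v * v for v in y) / len(y))
print(rms_out)
```

With the weight at its initial value of ones, the output keeps the direction of the input and only its scale changes.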
+
+
+ # ── RoPE ─────────────────────────────────────────────────────────────────────
+
+ def build_rope_freqs(
+     head_dim: int,
+     max_seq: int,
+     theta: float,
+     device: torch.device,
+     n_rot_pairs: int | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Build cos/sin tables of shape [max_seq, head_dim].
+
+     For full-attention layers with partial rotation, only the first
+     n_rot_pairs pairs (dim i paired with dim i + head_dim//2) carry actual
+     frequencies; the remaining pairs get frequency zero
+     (NoPE: no positional encoding for those dims).
+
+     Args:
+         head_dim:    total head dimension
+         max_seq:     maximum sequence length to precompute
+         theta:       RoPE base frequency
+         device:      target device
+         n_rot_pairs: if set, only compute real freqs for this many pairs;
+                      remaining dims get freq=0 (cos=1, sin=0 → identity).
+     """
+     half = head_dim // 2
+     if n_rot_pairs is None:
+         n_rot_pairs = half
+
+     # Build frequencies only for the pairs that actually rotate
+     inv_freq = 1.0 / (theta ** (
+         torch.arange(0, n_rot_pairs, device=device).float() / half
+     ))  # shape [n_rot_pairs]
+
+     # Pad with zeros for the remaining pairs (NoPE: cos=1, sin=0)
+     if n_rot_pairs < half:
+         inv_freq = torch.cat([
+             inv_freq,
+             torch.zeros(half - n_rot_pairs, device=device),
+         ])  # [half]
+
+     t = torch.arange(max_seq, device=device).float()
+     freqs = torch.outer(t, inv_freq)            # [T, half]
+     freqs = torch.cat([freqs, freqs], dim=-1)   # [T, head_dim]
+     return freqs.cos(), freqs.sin()
+
+
+ def apply_rope(
+     x: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Apply rotary embeddings.
+
+     Args:
+         x:   [B, H, T, head_dim]
+         cos: [T, head_dim]  (broadcastable)
+         sin: [T, head_dim]
+     """
+     half = x.shape[-1] // 2
+     x1, x2 = x[..., :half], x[..., half:]
+     rotated = torch.cat([-x2, x1], dim=-1)
+     T = x.shape[2]
+     cos_ = cos[:T].unsqueeze(0).unsqueeze(0).to(x.dtype)  # [1,1,T,D]
+     sin_ = sin[:T].unsqueeze(0).unsqueeze(0).to(x.dtype)
+     return x * cos_ + rotated * sin_
+
+
+ # ── Attention ─────────────────────────────────────────────────────────────────
+
+ class Attention(nn.Module):
+     """
+     Multi-query attention (8 Q heads, 1 KV head).
+
+     Sliding layers: head_dim=256, local window=512.
+     Full layers:    head_dim=512, causal (no window restriction).
+     """
+
+     def __init__(self, layer_idx: int):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.full_attn = is_full_attention(layer_idx)
+         self.head_dim = HEAD_DIM_FULL if self.full_attn else HEAD_DIM_SLIDE
+         hd = self.head_dim
+
+         self.q_proj = nn.Linear(HIDDEN_SIZE, N_Q_HEADS * hd, bias=False)
+         self.k_proj = nn.Linear(HIDDEN_SIZE, N_KV_HEADS * hd, bias=False)
+         self.v_proj = nn.Linear(HIDDEN_SIZE, N_KV_HEADS * hd, bias=False)
+         self.o_proj = nn.Linear(N_Q_HEADS * hd, HIDDEN_SIZE, bias=False)
+
+         self.q_norm = RMSNorm(hd)
+         self.k_norm = RMSNorm(hd)
+
+     def forward(
+         self,
+         x: torch.Tensor,    # [B, T, D]
+         cos: torch.Tensor,  # [T, head_dim]
+         sin: torch.Tensor,
+     ) -> torch.Tensor:
+         B, T, _ = x.shape
+         hd = self.head_dim
+
+         q = self.q_proj(x).view(B, T, N_Q_HEADS, hd).transpose(1, 2)   # [B,Hq,T,hd]
+         k = self.k_proj(x).view(B, T, N_KV_HEADS, hd).transpose(1, 2)  # [B,1,T,hd]
+         v = self.v_proj(x).view(B, T, N_KV_HEADS, hd).transpose(1, 2)
+
+         # Per-head QK normalisation (before RoPE)
+         q = self.q_norm(q)
+         k = self.k_norm(k)
+
+         # Rotary position embeddings
+         q = apply_rope(q, cos, sin)
+         k = apply_rope(k, cos, sin)
+
+         # Expand KV to match Q heads (MQA)
+         k = k.expand(B, N_Q_HEADS, T, hd)
+         v = v.expand(B, N_Q_HEADS, T, hd)
+
+         if self.full_attn:
+             # Standard causal attention, no window restriction
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 is_causal=True,
+                 scale=ATTN_SCALE,
+             )
+         else:
+             # Sliding window causal attention.
+             # attn_mask[i, j] = True means query-position i CAN attend to key-position j.
+             # Causal: j <= i  (can only attend to past/current positions)
+             # Window: i - j < SLIDING_WINDOW
+             idx = torch.arange(T, device=x.device)
+             # idx.unsqueeze(0) = [1, T] broadcast as j (key) axis
+             # idx.unsqueeze(1) = [T, 1] broadcast as i (query) axis
+             # mask[i, j] = True iff j <= i AND i - j < SLIDING_WINDOW
+             attn_mask = (
+                 (idx.unsqueeze(0) <= idx.unsqueeze(1)) &                 # j <= i (causal)
+                 (idx.unsqueeze(1) - idx.unsqueeze(0) < SLIDING_WINDOW)   # i - j < W
+             )  # [T_q, T_k]
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask=attn_mask,
+                 scale=ATTN_SCALE,
+             )
+
+         out = out.transpose(1, 2).contiguous().view(B, T, N_Q_HEADS * hd)
+         return self.o_proj(out)
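Review note: the sliding-window mask combines a causal condition (`j <= i`) with a window condition (`i - j < W`). The sketch below builds the same boolean pattern with nested lists for a toy sequence, which makes the band structure visible. `sliding_causal_mask` is a hypothetical helper, and the sizes (5 positions, window 3) are chosen only for display.

```python
def sliding_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when query i may attend to key j."""
    return [
        [(j <= i) and (i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = sliding_causal_mask(seq_len=5, window=3)
rows = ["".join("x" if ok else "." for ok in row) for row in mask]
for row in rows:
    print(row)
# x....
# xx...
# xxx..
# .xxx.
# ..xxx
```

Each query sees itself plus at most `window - 1` earlier positions, so with `SLIDING_WINDOW = 512` the mask only differs from a plain causal mask once the sequence is longer than 512 tokens.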
+
+
+ # ── MLP (GeGLU) ───────────────────────────────────────────────────────────────
+
+ class MLP(nn.Module):
+     """
+     GeGLU feed-forward network.
+
+     Layers 0-14:  intermediate_size=6144
+     Layers 15-34: intermediate_size=12288 (double-wide)
+     """
+
+     def __init__(self, layer_idx: int):
+         super().__init__()
+         inter = INTERMEDIATE_WIDE if layer_idx >= DOUBLE_WIDE_START else INTERMEDIATE
+         self.gate_proj = nn.Linear(HIDDEN_SIZE, inter, bias=False)
+         self.up_proj = nn.Linear(HIDDEN_SIZE, inter, bias=False)
+         self.down_proj = nn.Linear(inter, HIDDEN_SIZE, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate = F.gelu(self.gate_proj(x), approximate="tanh")
+         return self.down_proj(gate * self.up_proj(x))
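Review note: the gating step in the MLP multiplies a tanh-approximated GELU of the gate projection by the up projection. A scalar sketch of that step follows, using the standard tanh approximation formula for GELU. `gelu_tanh` and `geglu` are hypothetical names, and the inputs stand in for the already-projected activations.

```python
import math


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def geglu(gate_pre: list[float], up: list[float]) -> list[float]:
    """Elementwise gelu_tanh(gate_pre) * up, i.e. the product fed to down_proj."""
    return [gelu_tanh(g) * u for g, u in zip(gate_pre, up)]


print(gelu_tanh(1.0))
print(geglu([0.0, 1.0, -10.0], [5.0, 2.0, 3.0]))
```

A gate pre-activation of zero or a large negative value suppresses the corresponding up-projection channel, while large positive values pass it through nearly unchanged.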
+
+
+ # ── Decoder layer ─────────────────────────────────────────────────────────────
+
+ class Gemma4TextLayer(nn.Module):
+     """
+     Single Gemma 4 decoder layer.
+
+     Execution order (per forward call):
+       1. Per-layer auxiliary stream injection
+       2. Self-attention block (pre/post norm, residual scaled by layer_scalar)
+       3. MLP block (pre/post norm, residual scaled by layer_scalar)
+
+     Per-layer auxiliary stream injection:
+       Receives per_layer_input [B,T,256] = combined embed+projection for this layer.
+       x_normed = input_layernorm(x)
+       gate     = gelu_tanh(per_layer_input_gate(x_normed))   # [B,T,256]
+       gated    = gate * per_layer_input                      # [B,T,256]
+       out_1536 = per_layer_projection(gated)                 # [B,T,1536]
+       x        = x + layer_scalar * post_per_layer_input_norm(out_1536)
+     """
+
+     def __init__(self, layer_idx: int):
+         super().__init__()
+         self.layer_idx = layer_idx
+
+         # Attention
+         self.self_attn = Attention(layer_idx)
+
+         # MLP (double-wide for layers 15+)
+         self.mlp = MLP(layer_idx)
+
+         # Layer norms
+         self.input_layernorm = RMSNorm(HIDDEN_SIZE)
+         self.post_attention_layernorm = RMSNorm(HIDDEN_SIZE)
+         self.pre_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
+         self.post_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
+         self.post_per_layer_input_norm = RMSNorm(HIDDEN_SIZE)
+
+         # Per-layer auxiliary stream weights:
+         #   per_layer_input_gate: Linear(1536→256), weight=[256, 1536]
+         #   per_layer_projection: Linear(256→1536), weight=[1536, 256]
+         self.per_layer_input_gate = nn.Linear(HIDDEN_SIZE, PER_LAYER_DIM, bias=False)
+         self.per_layer_projection = nn.Linear(PER_LAYER_DIM, HIDDEN_SIZE, bias=False)
+
+         # Scalar multiplier for per-layer, attention, and MLP residual contributions
+         self.layer_scalar = nn.Parameter(torch.ones(1))
+
+     def forward(
+         self,
+         x: torch.Tensor,                # [B, T, D]
+         cos: torch.Tensor,              # RoPE tables for this layer type
+         sin: torch.Tensor,
+         per_layer_input: torch.Tensor,  # [B, T, 256] combined embed+projection for this layer
+     ) -> torch.Tensor:
+
+         scalar = self.layer_scalar.to(x.dtype)
+
+         # ── 1. Per-layer auxiliary stream injection ──────────────────────────
+         # Gate uses the model's hidden activation (gelu_pytorch_tanh), matching
+         # the Gemma3n reference implementation.
+         # The layer_scalar multiplies all residual contributions (per-layer, attn, MLP).
+         x_normed = self.input_layernorm(x)
+         gate = F.gelu(self.per_layer_input_gate(x_normed), approximate="tanh")  # [B,T,256]
+         gated = gate * per_layer_input                                          # [B,T,256]
+         out = self.per_layer_projection(gated)                                  # [B,T,1536]
+         x = x + scalar * self.post_per_layer_input_norm(out)
+
+         # ── 2. Self-attention ────────────────────────────────────────────────
+         # Apply input_layernorm again after the per-layer injection
+         h = self.self_attn(self.input_layernorm(x), cos, sin)
+         x = x + scalar * self.post_attention_layernorm(h)
+
+         # ── 3. MLP ───────────────────────────────────────────────────────────
+         h = self.mlp(self.pre_feedforward_layernorm(x))
+         x = x + scalar * self.post_feedforward_layernorm(h)
+
+         return x
+
+
+ # ── Full model ─────────────────────────────────────────────────────────────────
+
+ class Gemma4ForCausalLM(nn.Module):
+     """
+     Gemma 4 E2B text model (causal LM head, no vision/audio).
+
+     Tied embeddings: embed_tokens.weight is shared with lm_head.
+     Output logits are softcapped: 30 * tanh(logits / 30).
+
+     Per-layer auxiliary stream is computed model-level before layer iteration:
+       - embed_tokens_per_layer lookup: [B,T,35*256]
+       - per_layer_model_projection: Linear(1536→35*256)
+       - per_layer_projection_norm: RMSNorm(256) per layer-slice
+       - combine: per_layer_inputs = (embed_aux + proj_scaled) * (1/sqrt(2))
+     """
+
+     def __init__(self):
+         super().__init__()
+
+         # Token embeddings
+         self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
+         self.embed_tokens_per_layer = nn.Embedding(VOCAB_SIZE, N_LAYERS * PER_LAYER_DIM)
+
+         # Final norm
+         self.norm = RMSNorm(HIDDEN_SIZE)
+
+         # Transformer layers
+         self.layers = nn.ModuleList([Gemma4TextLayer(i) for i in range(N_LAYERS)])
+
+         # Model-level per-layer projection (hidden → all layer aux dims at once)
+         # weight shape: [35*256, 1536] = [8960, 1536]
+         self.per_layer_model_projection = nn.Linear(
+             HIDDEN_SIZE, N_LAYERS * PER_LAYER_DIM, bias=False
+         )
+         # Norm applied to per-layer projection slices [256]
+         self.per_layer_projection_norm = RMSNorm(PER_LAYER_DIM)
+
+         # RoPE tables (computed lazily)
+         self._rope_slide_cos: torch.Tensor | None = None
+         self._rope_slide_sin: torch.Tensor | None = None
+         self._rope_full_cos: torch.Tensor | None = None
+         self._rope_full_sin: torch.Tensor | None = None
+         self._rope_seq: int = 0
+
+     @staticmethod
+     def is_full_attention(layer_idx: int) -> bool:
+         return is_full_attention(layer_idx)
+
+     def _ensure_rope(self, seq_len: int, device: torch.device) -> None:
+         """Precompute (or extend) RoPE tables on demand."""
+         if self._rope_slide_cos is not None and self._rope_seq >= seq_len:
+             return
+         max_seq = max(seq_len, 2048)
+
+         # Sliding layers: head_dim=256, full rotation
+         cs, sn = build_rope_freqs(HEAD_DIM_SLIDE, max_seq, ROPE_THETA_SLIDE, device)
+         self._rope_slide_cos = cs
+         self._rope_slide_sin = sn
+
+         # Full-attention layers: head_dim=512, partial_rotary_factor=0.25.
+         # 512 * 0.25 = 128 dims rotated = 64 rotation pairs (half=256, 64 of 256 pairs).
+         n_rot = int(HEAD_DIM_FULL * PARTIAL_ROT_FULL) // 2  # = 64
+         cf, sf = build_rope_freqs(
+             HEAD_DIM_FULL, max_seq, ROPE_THETA_FULL, device, n_rot_pairs=n_rot
+         )
+         self._rope_full_cos = cf
+         self._rope_full_sin = sf
+         self._rope_seq = max_seq
+
+     def _compute_per_layer_inputs(
+         self, input_ids: torch.Tensor, x_embed: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Precompute per-layer auxiliary inputs for all 35 layers.
+
+         Returns:
+             per_layer_inputs: [B, T, N_LAYERS, PER_LAYER_DIM]
+         """
+         B, T = input_ids.shape
+
+         # 1. Token-based per-layer embeddings (vocabulary lookup)
+         # Scaled by sqrt(PER_LAYER_DIM)=16, matching Gemma3n's ScaledWordEmbedding convention
+         embed_aux = self.embed_tokens_per_layer(input_ids).to(x_embed.dtype)
+         embed_aux = embed_aux * math.sqrt(PER_LAYER_DIM)  # scale by sqrt(256)=16
+         # embed_aux: [B, T, 35*256] reshape → [B, T, 35, 256]
+         embed_aux = embed_aux.view(B, T, N_LAYERS, PER_LAYER_DIM)
+
+         # 2. Hidden-state projection: project x_embed to [B, T, 35*256]
+         proj_all = self.per_layer_model_projection(x_embed)  # [B, T, 35*256]
+         proj_all = proj_all * PER_LAYER_PROJ_SCALE           # scale by 1/sqrt(hidden)
+         proj_all = proj_all.view(B, T, N_LAYERS, PER_LAYER_DIM)
+         # Apply RMSNorm(256) to each layer slice
+         proj_all = self.per_layer_projection_norm(proj_all)  # broadcast over [B,T,N]
+
+         # 3. Combine: (embed_aux + proj_normed) * (1/sqrt(2))
+         per_layer_inputs = (embed_aux + proj_all) * PER_LAYER_INPUT_SCALE
+
+         return per_layer_inputs  # [B, T, 35, 256]
+
+     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             input_ids: [B, T] long tensor
+
+         Returns:
+             logits: [B, T, vocab_size] with softcapping applied
+         """
+         B, T = input_ids.shape
+         self._ensure_rope(T, input_ids.device)
+
+         # Token embeddings scaled by sqrt(hidden_size)
+         x = self.embed_tokens(input_ids) * math.sqrt(HIDDEN_SIZE)  # [B,T,D]
+
+         # Compute per-layer auxiliary inputs (uses unmodified x_embed)
+         per_layer_inputs = self._compute_per_layer_inputs(input_ids, x)
+
+         for i, layer in enumerate(self.layers):
+             per_layer_i = per_layer_inputs[:, :, i, :]  # [B, T, 256]
+
+             if is_full_attention(i):
+                 cos, sin = self._rope_full_cos, self._rope_full_sin
+             else:
+                 cos, sin = self._rope_slide_cos, self._rope_slide_sin
+
+             x = layer(x, cos, sin, per_layer_i)
+
+         x = self.norm(x)
+
+         # Tied lm_head: F.linear(x, embed_tokens.weight)
+         logits = F.linear(x, self.embed_tokens.weight.to(x.dtype))  # [B,T,V]
+
+         # Logit softcapping
+         logits = LOGIT_CAP * torch.tanh(logits / LOGIT_CAP)
+         return logits
+
+     @classmethod
+     def load_weights(
+         cls,
+         safetensors_path: str | Path,
+         device: str = "cpu",
+     ) -> "Gemma4ForCausalLM":
+         """
+         Load from the safetensors checkpoint.
+
+         Weight names in the file follow the pattern:
+           model.language_model.X → self.X
+         """
+         model = cls()
+         path = str(safetensors_path)
+         prefix = "model.language_model."
+         state = {}
+
+         with safe_open(path, framework="pt", device=device) as f:
+             for key in f.keys():
+                 if not key.startswith(prefix):
+                     continue
+                 local_key = key[len(prefix):]  # strip "model.language_model."
+                 state[local_key] = f.get_tensor(key)
+
+         missing, unexpected = model.load_state_dict(state, strict=False)
+         if missing:
+             print(f"[load_weights] {len(missing)} missing keys (first 5): {missing[:5]}")
+         if unexpected:
+             print(f"[load_weights] {len(unexpected)} unexpected keys (first 5): {unexpected[:5]}")
+
+         model = model.to(dtype=DTYPE)
+         return model
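Review note: the logit softcapping used at the end of `forward` is `cap * tanh(logit / cap)`. The scalar sketch below shows its two regimes: nearly the identity for small logits and a smooth clamp near the cap for large ones. `softcap` is a hypothetical standalone helper; the default of 30.0 follows `LOGIT_CAP` in the diff.

```python
import math


def softcap(logit: float, cap: float = 30.0) -> float:
    """Squash a logit smoothly into the open interval (-cap, cap)."""
    return cap * math.tanh(logit / cap)


for value in (0.0, 1.0, 30.0, 1000.0, -1000.0):
    print(value, softcap(value))
```

Because `tanh` is monotonic, softcapping never changes the argmax over the vocabulary; it only limits how confident the softmax can become.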
+
+
+ # ── Convenience loader ─────────────────────────────────────────────────────────
+
+ def load_gemma4(
+     device: str | None = None,
+ ) -> tuple[Gemma4ForCausalLM, AutoTokenizer]:
+     """
+     Load the Gemma 4 E2B model and tokenizer.
+
+     Returns:
+         (model, tokenizer); model is in eval mode on `device`.
+     """
+     if device is None:
+         device = DEVICE
+
+     print(f"Loading Gemma 4 E2B from {SAFETENSORS_BLOB} ...")
+     model = Gemma4ForCausalLM.load_weights(SAFETENSORS_BLOB, device=device)
+     model = model.to(device).eval()
+
+     print(f"Loading tokenizer from {MODEL_DIR} ...")
+     tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR), local_files_only=True)
+
+     return model, tokenizer
+
+
+ # ── PPL evaluation ─────────────────────────────────────────────────────────────
+
+ def ppl_on_text(
+     model: Gemma4ForCausalLM,
+     tokenizer: AutoTokenizer,
+     text: str,
+     device: str | None = None,
+     max_length: int = 1024,
+ ) -> float:
+     """
+     Compute token-level perplexity on `text`.
+
+     Args:
+         model:      Gemma4ForCausalLM in eval mode
+         tokenizer:  matching AutoTokenizer
+         text:       input string
+         device:     device for inference (defaults to DEVICE)
+         max_length: truncate to this many tokens
+
+     Returns:
+         perplexity (float)
+     """
+     if device is None:
+         device = DEVICE
+
+     enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
+     input_ids = enc["input_ids"].to(device)
+
+     with torch.no_grad():
+         logits = model(input_ids)  # [1, T, V]
+
+     # Shift: predict token t+1 from position t
+     shift_logits = logits[0, :-1, :]  # [T-1, V]
+     shift_labels = input_ids[0, 1:]   # [T-1]
+
+     log_probs = F.log_softmax(shift_logits.float(), dim=-1)
+     nll = -log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1).mean()
+     return nll.exp().item()
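Review note: the perplexity above is the exponential of the mean negative log-likelihood of the target tokens. The sketch below checks that definition on a case with a known answer: a model that assigns uniform probability over a vocabulary of size V has perplexity exactly V. `perplexity` is a hypothetical helper operating on a plain list of per-token log-probabilities.

```python
import math


def perplexity(token_log_probs: list[float]) -> float:
    """exp(mean negative log-likelihood) over the scored tokens."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)


vocab = 8
uniform = [math.log(1.0 / vocab)] * 5   # five tokens, each with probability 1/8
print(perplexity(uniform))

confident = [math.log(0.9)] * 5         # a model that is right with p=0.9 each time
print(perplexity(confident))
```

The same sanity check applies to the tensor version: feeding constant logits should yield a perplexity equal to the vocabulary size.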
+
+
+ # ── main ──────────────────────────────────────────────────────────────────────
+
+ if __name__ == "__main__":
+     _WIKI_TEXT = (
+         "The transformer architecture was introduced in the paper "
+         "'Attention Is All You Need' by Vaswani et al. in 2017. "
+         "It relies entirely on self-attention mechanisms, dispensing with "
+         "recurrence and convolutions entirely. Transformers have since become "
+         "the dominant architecture for natural language processing, powering "
+         "models such as BERT, GPT, T5, and the Gemma family. "
+         "The key innovation is the multi-head attention mechanism, which allows "
+         "the model to jointly attend to information from different representation "
+         "subspaces at different positions. This is complemented by position-wise "
+         "feed-forward networks and residual connections with layer normalisation. "
+         "Large language models built on this architecture are trained on massive "
+         "corpora using next-token prediction (autoregressive language modelling) "
+         "or masked language modelling. They exhibit emergent capabilities such as "
+         "few-shot and zero-shot generalisation across a wide variety of tasks."
+     )
+
+     model, tokenizer = load_gemma4()
+
+     ppl = ppl_on_text(model, tokenizer, _WIKI_TEXT)
+     print(f"\nPerplexity on sample text: {ppl:.2f} (target: ~17–18 for bfloat16)")