File size: 27,227 Bytes
ea203cb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 | """
Gemma 4 E2B — clean PyTorch forward pass (text model only).
Architecture:
- 35 decoder layers, hidden_size=1536, vocab=262144
- 8 Q heads, 1 KV head (MQA)
- Sliding attention layers (0-3, 5-8, 10-13, 15-18, 20-23, 25-28, 30-33):
head_dim=256, sliding_window=512, rope_theta=10000
- Full attention layers (every 5th: 4,9,14,19,24,29,34):
head_dim=512, partial_rotary_factor=0.25 (only first 128 of 512 dims rotated),
rope_theta=1000000
- MLP (all layers): GeGLU, intermediate_size=6144 for layers 0-14, 12288 (double-wide) for layers 15-34
- Per-layer auxiliary stream (full details below)
- layer_scalar: per-layer learned scalar multiplied onto residual contributions
- QK RMSNorm before RoPE, attn_scale=1.0
- Final: RMSNorm + tied lm_head + logit softcapping at 30.0
Per-layer auxiliary stream:
Model-level (computed once, before all layers):
1. embed_tokens_per_layer(input_ids) → [B, T, 35*256] (vocab lookup)
2. per_layer_model_projection(x_embed) → [B, T, 35*256] (project hidden→aux)
scaled by hidden_size**-0.5
3. per_layer_projection_norm (RMSNorm(256)) on the projection slice per layer
4. Combine: per_layer_inputs = (embed_aux + proj_aux) * (1/sqrt(2))
reshaped to [B, T, 35, 256]
Per-layer (at layer i):
per_layer_input_i = per_layer_inputs[:, :, i, :] # [B, T, 256]
x_normed = input_layernorm(x)
gate = gelu_tanh(per_layer_input_gate(x_normed)) # [B, T, 256]
gated = gate * per_layer_input_i # [B, T, 256]
out = per_layer_projection(gated) # [B, T, 1536] (256→1536)
x = x + layer_scalar * post_per_layer_input_norm(out)
Weight shapes in checkpoint:
per_layer_model_projection.weight : [8960, 1536] (Linear 1536→8960)
per_layer_projection_norm.weight : [256] (RMSNorm on 256-dim slices)
layers.i.per_layer_input_gate.weight : [256, 1536] (Linear 1536→256)
layers.i.per_layer_projection.weight : [1536, 256] (Linear 256→1536)
layers.i.post_per_layer_input_norm.weight : [1536] (RMSNorm on 1536-dim output)
"""
import math
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
from transformers import AutoTokenizer
# ── device / dtype ────────────────────────────────────────────────────────────
# Prefer CUDA when torch reports it as available; otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16


# ── model path ────────────────────────────────────────────────────────────────
def _hub_cache_root(env=None) -> Path:
    """Return the Hugging Face hub cache directory.

    Resolution order mirrors huggingface_hub: ``HF_HUB_CACHE`` (or the legacy
    ``HUGGINGFACE_HUB_CACHE``), then ``$HF_HOME/hub``, then
    ``$XDG_CACHE_HOME/huggingface/hub``, and finally
    ``~/.cache/huggingface/hub``.

    Args:
        env: mapping to read variables from; defaults to ``os.environ``.
    """
    env = os.environ if env is None else env
    explicit = env.get("HF_HUB_CACHE") or env.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return Path(os.path.expanduser(explicit))
    hf_home = env.get("HF_HOME")
    if not hf_home:
        cache_home = env.get("XDG_CACHE_HOME") or os.path.join("~", ".cache")
        hf_home = os.path.join(cache_home, "huggingface")
    return Path(os.path.expanduser(hf_home)) / "hub"


# Try known HF repo caches in order; first one that exists wins. Override with
# $GEMMA4_HF_REPO to point at an arbitrary repo cache (e.g., "google/gemma-4-e2b-it").
_HUB_ROOT = _hub_cache_root()
_REPO_CANDIDATES = (
    # Empty when the variable is unset; empty entries are skipped by the resolver.
    os.environ.get("GEMMA4_HF_REPO", ""),
    "gg-hf-gg/gemma-4-E2B",
    "google/gemma-4-e2b-it",
)
def _resolve_model_paths(hub_root=None, repo_candidates=None):
    """Locate a cached checkpoint and return ``(snapshot_dir, safetensors_path)``.

    Repos are tried in order; within a repo every snapshot directory is
    examined (in sorted order, because ``iterdir()`` order is not
    deterministic) before moving on, since HF may keep several snapshots of
    which only one has its weights resolved.

    Args:
        hub_root: hub cache directory; defaults to the module-level ``_HUB_ROOT``.
        repo_candidates: iterable of ``"org/name"`` repo ids; defaults to the
            module-level ``_REPO_CANDIDATES``. Empty entries are skipped.

    Returns:
        Tuple of the snapshot directory and the chosen ``.safetensors`` file.
        ``model.safetensors`` is preferred; otherwise the first
        ``*.safetensors`` file in sorted order is used.

    Raises:
        FileNotFoundError: if no candidate repo has a snapshot with weights.
    """
    root = _HUB_ROOT if hub_root is None else Path(hub_root)
    source = _REPO_CANDIDATES if repo_candidates is None else repo_candidates
    repos = [r for r in source if r]

    for repo in repos:
        snap_root = root / ("models--" + repo.replace("/", "--")) / "snapshots"
        if not snap_root.is_dir():
            continue
        for snap in sorted(p for p in snap_root.iterdir() if p.is_dir()):
            single = snap / "model.safetensors"
            if single.is_file():
                return snap, single
            # glob() also yields dangling symlinks (blob not downloaded), so
            # keep only entries that resolve to a real file.
            shards = sorted(p for p in snap.glob("*.safetensors") if p.is_file())
            if shards:
                return snap, shards[0]

    raise FileNotFoundError(
        "No Gemma-4 E2B HF cache found. Tried: " + ", ".join(repos)
        + ". Run `hf download google/gemma-4-e2b-it` or set GEMMA4_HF_REPO."
    )
# Resolved once at import time: importing this module raises FileNotFoundError
# when no cached checkpoint can be located by _resolve_model_paths().
MODEL_DIR, SAFETENSORS_BLOB = _resolve_model_paths()
# ── architecture constants ────────────────────────────────────────────────────
# Model depth and widths.
N_LAYERS = 35
HIDDEN_SIZE = 1536
VOCAB_SIZE = 262144
PER_LAYER_DIM = 256            # width of the per-layer auxiliary stream, per layer

# Attention geometry (multi-query: several Q heads share one KV head).
N_Q_HEADS = 8
N_KV_HEADS = 1
HEAD_DIM_SLIDE = 256           # head dim of sliding-window layers
HEAD_DIM_FULL = 512            # head dim of full-attention layers
SLIDING_WINDOW = 512
ATTN_SCALE = 1.0               # Q/K are RMS-normalised, so no 1/sqrt(d) factor

# Feed-forward widths: layers from DOUBLE_WIDE_START onward use the wide MLP
# (use_double_wide_mlp=True in config).
INTERMEDIATE = 6144
INTERMEDIATE_WIDE = 2 * INTERMEDIATE
DOUBLE_WIDE_START = 15

# Rotary embedding settings.
ROPE_THETA_SLIDE = 1e4
ROPE_THETA_FULL = 1e6
PARTIAL_ROT_FULL = 0.25        # only the first floor(512*0.25)=128 dims get RoPE

# Numerics.
RMS_EPS = 1e-6
LOGIT_CAP = 30.0

# Scale applied to the per_layer_model_projection output: hidden_size**-0.5.
PER_LAYER_PROJ_SCALE = HIDDEN_SIZE ** -0.5
# Scale used when mixing embedding aux with the projected aux: 1/sqrt(2).
PER_LAYER_INPUT_SCALE = math.sqrt(0.5)

# Every 5th layer (0-indexed 4, 9, 14, ...) uses full attention.
FULL_ATTN_LAYERS = frozenset(i for i in range(N_LAYERS) if i % 5 == 4)
def is_full_attention(layer_idx: int) -> bool:
    """Tell whether decoder layer ``layer_idx`` is a full (global) attention layer.

    Membership is looked up in the module-level ``FULL_ATTN_LAYERS`` set.
    """
    uses_global = layer_idx in FULL_ATTN_LAYERS
    return uses_global
# ── RMSNorm ───────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The gain (``weight``) starts at ones and multiplies the normalised input
    directly. Arithmetic runs in float32 and the result is cast back to the
    dtype of the input.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        h = x.float()
        mean_sq = h.pow(2).mean(-1, keepdim=True)
        h = h * torch.rsqrt(mean_sq + RMS_EPS)
        h = h * self.weight.float()
        return h.to(in_dtype)
# ββ RoPE βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def build_rope_freqs(
head_dim: int,
max_seq: int,
theta: float,
device: torch.device,
n_rot_pairs: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Build cos/sin tables of shape [max_seq, head_dim].
For full-attention layers with partial rotation, only the first
n_rot_pairs*2 positions carry actual frequencies; the rest are zeros
(NoPE β no positional encoding for those dims).
Args:
head_dim: total head dimension
max_seq: maximum sequence length to precompute
theta: RoPE base frequency
device: target device
n_rot_pairs: if set, only compute real freqs for this many pairs;
remaining dims get freq=0 (cos=1, sin=0 β identity).
"""
half = head_dim // 2
if n_rot_pairs is None:
n_rot_pairs = half
# Build frequencies only for the pairs that actually rotate
inv_freq = 1.0 / (theta ** (
torch.arange(0, n_rot_pairs, device=device).float() / half
)) # shape [n_rot_pairs]
# Pad with zeros for the remaining pairs (NoPE: cos=1, sin=0)
if n_rot_pairs < half:
inv_freq = torch.cat([
inv_freq,
torch.zeros(half - n_rot_pairs, device=device),
]) # [half]
t = torch.arange(max_seq, device=device).float()
freqs = torch.outer(t, inv_freq) # [T, half]
freqs = torch.cat([freqs, freqs], dim=-1) # [T, head_dim]
return freqs.cos(), freqs.sin()
def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Rotate ``x`` with precomputed rotary tables (half-split layout).

    Args:
        x:   [B, H, T, head_dim]
        cos: [>=T, head_dim]; only the first T rows are used
        sin: same layout as ``cos``

    Returns:
        Tensor shaped like ``x``.
    """
    seq_len = x.shape[2]
    split = x.shape[-1] // 2

    lo = x[..., :split]
    hi = x[..., split:]
    swapped = torch.cat((-hi, lo), dim=-1)

    # Add leading batch/head axes so the tables broadcast: [1, 1, T, D].
    cos_t = cos[:seq_len][None, None].to(x.dtype)
    sin_t = sin[:seq_len][None, None].to(x.dtype)
    return x * cos_t + swapped * sin_t
# ── Attention ─────────────────────────────────────────────────────────────────
class Attention(nn.Module):
    """
    Multi-query self-attention: N_Q_HEADS query heads share N_KV_HEADS KV head(s).

    The head dimension and masking depend on the layer type:
      - sliding layers use HEAD_DIM_SLIDE and a causal window of SLIDING_WINDOW
      - full layers use HEAD_DIM_FULL and plain causal masking
    Q and K are RMS-normalised per head before the rotary embedding is applied.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.full_attn = is_full_attention(layer_idx)
        self.head_dim = HEAD_DIM_FULL if self.full_attn else HEAD_DIM_SLIDE

        q_width = N_Q_HEADS * self.head_dim
        kv_width = N_KV_HEADS * self.head_dim
        self.q_proj = nn.Linear(HIDDEN_SIZE, q_width, bias=False)
        self.k_proj = nn.Linear(HIDDEN_SIZE, kv_width, bias=False)
        self.v_proj = nn.Linear(HIDDEN_SIZE, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, HIDDEN_SIZE, bias=False)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,      # [B, T, D]
        cos: torch.Tensor,    # [T, head_dim]
        sin: torch.Tensor,
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        d = self.head_dim

        def split_heads(t: torch.Tensor, n_heads: int) -> torch.Tensor:
            # [B, T, n_heads*d] -> [B, n_heads, T, d]
            return t.view(batch, seq_len, n_heads, d).transpose(1, 2)

        q = split_heads(self.q_proj(x), N_Q_HEADS)
        k = split_heads(self.k_proj(x), N_KV_HEADS)
        v = split_heads(self.v_proj(x), N_KV_HEADS)

        # Normalise Q/K per head, then rotate.
        q = apply_rope(self.q_norm(q), cos, sin)
        k = apply_rope(self.k_norm(k), cos, sin)

        # Broadcast the KV head(s) across all query heads (no copy).
        k = k.expand(batch, N_Q_HEADS, seq_len, d)
        v = v.expand(batch, N_Q_HEADS, seq_len, d)

        if self.full_attn:
            ctx = F.scaled_dot_product_attention(
                q, k, v,
                is_causal=True,
                scale=ATTN_SCALE,
            )
        else:
            # lag[i, j] = i - j for query position i and key position j.
            # A key is visible when it is not in the future (lag >= 0) and
            # lies inside the window (lag < SLIDING_WINDOW).
            pos = torch.arange(seq_len, device=x.device)
            lag = pos.unsqueeze(1) - pos.unsqueeze(0)
            visible = (lag >= 0) & (lag < SLIDING_WINDOW)        # [T_q, T_k]
            ctx = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=visible,
                scale=ATTN_SCALE,
            )

        merged = ctx.transpose(1, 2).contiguous().view(batch, seq_len, N_Q_HEADS * d)
        return self.o_proj(merged)
# ── MLP (GeGLU) ───────────────────────────────────────────────────────────────
class MLP(nn.Module):
    """
    Gated feed-forward block: down(gelu_tanh(gate(x)) * up(x)).

    The hidden width is INTERMEDIATE for layers below DOUBLE_WIDE_START and
    INTERMEDIATE_WIDE from that layer onward.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        width = INTERMEDIATE if layer_idx < DOUBLE_WIDE_START else INTERMEDIATE_WIDE
        self.gate_proj = nn.Linear(HIDDEN_SIZE, width, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, width, bias=False)
        self.down_proj = nn.Linear(width, HIDDEN_SIZE, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.gate_proj(x), approximate="tanh")
        lifted = self.up_proj(x)
        return self.down_proj(activated * lifted)
# ── Decoder layer ─────────────────────────────────────────────────────────────
class Gemma4TextLayer(nn.Module):
    """
    One decoder layer.

    Forward order:
      1. Per-layer auxiliary stream injection
      2. Self-attention block (pre/post norm)
      3. MLP block (pre/post norm)
    Each of the three residual contributions is multiplied by ``layer_scalar``.

    Injection step, given per_layer_input [B,T,256]:
        gate  = gelu_tanh(per_layer_input_gate(input_layernorm(x)))   # [B,T,256]
        out   = per_layer_projection(gate * per_layer_input)          # [B,T,1536]
        x     = x + layer_scalar * post_per_layer_input_norm(out)

    NOTE(review): the ordering of the injection relative to attention/MLP and
    the placement of layer_scalar should be confirmed against the reference
    implementation.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        # Sub-blocks (the MLP picks its own width from layer_idx).
        self.self_attn = Attention(layer_idx)
        self.mlp = MLP(layer_idx)

        # Norms around each sub-block, plus one on the injected aux output.
        self.input_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_attention_layernorm = RMSNorm(HIDDEN_SIZE)
        self.pre_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_per_layer_input_norm = RMSNorm(HIDDEN_SIZE)

        # Auxiliary stream: hidden -> aux gate, then gated aux -> hidden.
        self.per_layer_input_gate = nn.Linear(HIDDEN_SIZE, PER_LAYER_DIM, bias=False)
        self.per_layer_projection = nn.Linear(PER_LAYER_DIM, HIDDEN_SIZE, bias=False)

        # Learned scalar applied to every residual contribution of this layer.
        self.layer_scalar = nn.Parameter(torch.ones(1))

    def forward(
        self,
        x: torch.Tensor,                 # [B, T, D]
        cos: torch.Tensor,               # RoPE tables for this layer type
        sin: torch.Tensor,
        per_layer_input: torch.Tensor,   # [B, T, 256] aux input for this layer
    ) -> torch.Tensor:
        s = self.layer_scalar.to(x.dtype)

        # 1. Auxiliary stream injection (tanh-GELU gate on the normed hidden state).
        aux_gate = F.gelu(
            self.per_layer_input_gate(self.input_layernorm(x)),
            approximate="tanh",
        )                                                           # [B,T,256]
        aux_out = self.per_layer_projection(aux_gate * per_layer_input)  # [B,T,1536]
        x = x + s * self.post_per_layer_input_norm(aux_out)

        # 2. Self-attention; input_layernorm is re-applied to the updated x.
        attn_out = self.self_attn(self.input_layernorm(x), cos, sin)
        x = x + s * self.post_attention_layernorm(attn_out)

        # 3. Feed-forward.
        mlp_out = self.mlp(self.pre_feedforward_layernorm(x))
        return x + s * self.post_feedforward_layernorm(mlp_out)
# ── Full model ────────────────────────────────────────────────────────────────
class Gemma4ForCausalLM(nn.Module):
    """
    Gemma 4 E2B text model with a causal LM head (no vision/audio).

    The output projection reuses ``embed_tokens.weight`` (tied embeddings) and
    the logits are soft-capped: LOGIT_CAP * tanh(logits / LOGIT_CAP).

    The per-layer auxiliary stream is computed once, before the layer loop:
      - embed_tokens_per_layer lookup:   [B,T,35*256]
      - per_layer_model_projection:      Linear(1536 -> 35*256)
      - per_layer_projection_norm:       RMSNorm(256) on each layer slice
      - combine: (embed_aux + proj_normed) * (1/sqrt(2))
    """

    def __init__(self):
        super().__init__()
        # Token embeddings (main stream and per-layer auxiliary stream).
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.embed_tokens_per_layer = nn.Embedding(VOCAB_SIZE, N_LAYERS * PER_LAYER_DIM)
        # Final norm
        self.norm = RMSNorm(HIDDEN_SIZE)
        # Transformer layers
        self.layers = nn.ModuleList([Gemma4TextLayer(i) for i in range(N_LAYERS)])
        # Projects the hidden state to the aux dims of all layers at once;
        # weight shape [35*256, 1536] = [8960, 1536].
        self.per_layer_model_projection = nn.Linear(
            HIDDEN_SIZE, N_LAYERS * PER_LAYER_DIM, bias=False
        )
        # Norm applied to each 256-wide per-layer projection slice.
        self.per_layer_projection_norm = RMSNorm(PER_LAYER_DIM)
        # RoPE tables, built lazily by _ensure_rope. They are plain attributes
        # (not buffers), so Module.to() does not move them.
        self._rope_slide_cos: torch.Tensor | None = None
        self._rope_slide_sin: torch.Tensor | None = None
        self._rope_full_cos: torch.Tensor | None = None
        self._rope_full_sin: torch.Tensor | None = None
        self._rope_seq: int = 0

    @staticmethod
    def is_full_attention(layer_idx: int) -> bool:
        """Class-level alias of the module-level ``is_full_attention`` helper."""
        return is_full_attention(layer_idx)

    def _ensure_rope(self, seq_len: int, device: torch.device) -> None:
        """Build the RoPE tables on demand.

        The cached tables are reused only when they cover ``seq_len`` positions
        AND live on ``device``; otherwise they are rebuilt. Without the device
        check, tables cached on one device would be applied to activations on
        another after the model is moved.
        """
        cached = self._rope_slide_cos
        if cached is not None and self._rope_seq >= seq_len:
            want = torch.device(device)
            # An index-less request (e.g. "cuda") matches any index of that type.
            same_device = cached.device.type == want.type and (
                want.index is None or cached.device.index == want.index
            )
            if same_device:
                return

        max_seq = max(seq_len, 2048)

        # Sliding layers: head_dim=256, every pair rotates.
        cs, sn = build_rope_freqs(HEAD_DIM_SLIDE, max_seq, ROPE_THETA_SLIDE, device)
        self._rope_slide_cos = cs
        self._rope_slide_sin = sn

        # Full-attention layers: head_dim=512, partial_rotary_factor=0.25.
        # 512 * 0.25 = 128 rotated dims = 64 rotation pairs (of 256).
        n_rot = int(HEAD_DIM_FULL * PARTIAL_ROT_FULL) // 2
        cf, sf = build_rope_freqs(
            HEAD_DIM_FULL, max_seq, ROPE_THETA_FULL, device, n_rot_pairs=n_rot
        )
        self._rope_full_cos = cf
        self._rope_full_sin = sf
        self._rope_seq = max_seq

    def _compute_per_layer_inputs(
        self, input_ids: torch.Tensor, x_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Precompute the auxiliary inputs for every layer.

        Args:
            input_ids: [B, T] token ids
            x_embed:   [B, T, HIDDEN_SIZE] scaled token embeddings

        Returns:
            per_layer_inputs: [B, T, N_LAYERS, PER_LAYER_DIM]
        """
        B, T = input_ids.shape

        # 1. Token-based aux embeddings, scaled by sqrt(PER_LAYER_DIM) = 16.
        embed_aux = self.embed_tokens_per_layer(input_ids).to(x_embed.dtype)
        embed_aux = embed_aux * math.sqrt(PER_LAYER_DIM)
        embed_aux = embed_aux.view(B, T, N_LAYERS, PER_LAYER_DIM)

        # 2. Projection of the hidden state, scaled by hidden_size**-0.5 and
        #    then RMS-normalised per 256-wide layer slice.
        proj_all = self.per_layer_model_projection(x_embed)       # [B, T, 35*256]
        proj_all = proj_all * PER_LAYER_PROJ_SCALE
        proj_all = proj_all.view(B, T, N_LAYERS, PER_LAYER_DIM)
        proj_all = self.per_layer_projection_norm(proj_all)

        # 3. Combine the two sources with a 1/sqrt(2) factor.
        return (embed_aux + proj_all) * PER_LAYER_INPUT_SCALE

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: [B, T] long tensor

        Returns:
            logits: [B, T, vocab_size] with softcapping applied
        """
        B, T = input_ids.shape
        self._ensure_rope(T, input_ids.device)

        # Token embeddings scaled by sqrt(hidden_size).
        x = self.embed_tokens(input_ids) * math.sqrt(HIDDEN_SIZE)  # [B,T,D]

        # Aux inputs are derived from the embeddings before any layer runs.
        per_layer_inputs = self._compute_per_layer_inputs(input_ids, x)

        for i, layer in enumerate(self.layers):
            per_layer_i = per_layer_inputs[:, :, i, :]              # [B, T, 256]
            if is_full_attention(i):
                cos, sin = self._rope_full_cos, self._rope_full_sin
            else:
                cos, sin = self._rope_slide_cos, self._rope_slide_sin
            x = layer(x, cos, sin, per_layer_i)

        x = self.norm(x)
        # Tied lm_head: project with the input embedding matrix.
        logits = F.linear(x, self.embed_tokens.weight.to(x.dtype))  # [B,T,V]
        # Logit softcapping
        logits = LOGIT_CAP * torch.tanh(logits / LOGIT_CAP)
        return logits

    @classmethod
    def load_weights(
        cls,
        safetensors_path: str | Path,
        device: str = "cpu",
    ) -> "Gemma4ForCausalLM":
        """
        Build a model and populate it from a safetensors checkpoint.

        Only keys starting with ``model.language_model.`` are read; the prefix
        is stripped so that ``model.language_model.X`` maps to ``self.X``.
        Loading is non-strict: missing/unexpected keys are printed, not raised.
        The model is cast to DTYPE before being returned.

        Args:
            safetensors_path: checkpoint file to read
            device: device passed to ``safe_open`` for the loaded tensors
        """
        model = cls()
        path = str(safetensors_path)
        prefix = "model.language_model."
        state = {}
        with safe_open(path, framework="pt", device=device) as f:
            for key in f.keys():
                if not key.startswith(prefix):
                    continue
                state[key[len(prefix):]] = f.get_tensor(key)
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f"[load_weights] {len(missing)} missing keys (first 5): {missing[:5]}")
        if unexpected:
            print(f"[load_weights] {len(unexpected)} unexpected keys (first 5): {unexpected[:5]}")
        model = model.to(dtype=DTYPE)
        return model
# ── Convenience loader ────────────────────────────────────────────────────────
def load_gemma4(
    device: str | None = None,
) -> tuple[Gemma4ForCausalLM, AutoTokenizer]:
    """
    Load the model weights and the tokenizer from the resolved local cache.

    Args:
        device: target device; ``None`` selects the module-level DEVICE.

    Returns:
        (model, tokenizer) — the model is on ``device`` and in eval mode.
    """
    target = DEVICE if device is None else device

    print(f"Loading Gemma 4 E2B from {SAFETENSORS_BLOB} ...")
    model = Gemma4ForCausalLM.load_weights(SAFETENSORS_BLOB, device=target)
    model = model.to(target)
    model = model.eval()

    print(f"Loading tokenizer from {MODEL_DIR} ...")
    # local_files_only: the tokenizer is read from the snapshot directory.
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR), local_files_only=True)
    return model, tokenizer
# ββ PPL evaluation βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def ppl_on_text(
model: Gemma4ForCausalLM,
tokenizer: AutoTokenizer,
text: str,
device: str | None = None,
max_length: int = 1024,
) -> float:
"""
Compute token-level perplexity on `text`.
Args:
model: Gemma4ForCausalLM in eval mode
tokenizer: matching AutoTokenizer
text: input string
device: device for inference (defaults to DEVICE)
max_length: truncate to this many tokens
Returns:
perplexity (float)
"""
if device is None:
device = DEVICE
enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
input_ids = enc["input_ids"].to(device)
with torch.no_grad():
logits = model(input_ids) # [1, T, V]
# Shift: predict token t+1 from position t
shift_logits = logits[0, :-1, :] # [T-1, V]
shift_labels = input_ids[0, 1:] # [T-1]
log_probs = F.log_softmax(shift_logits.float(), dim=-1)
nll = -log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1).mean()
return nll.exp().item()
# ── main ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: load the model and report perplexity on a short passage.
    _WIKI_TEXT = (
        "The transformer architecture was introduced in the paper "
        "'Attention Is All You Need' by Vaswani et al. in 2017. "
        "It relies entirely on self-attention mechanisms, dispensing with "
        "recurrence and convolutions entirely. Transformers have since become "
        "the dominant architecture for natural language processing, powering "
        "models such as BERT, GPT, T5, and the Gemma family. "
        "The key innovation is the multi-head attention mechanism, which allows "
        "the model to jointly attend to information from different representation "
        "subspaces at different positions. This is complemented by position-wise "
        "feed-forward networks and residual connections with layer normalisation. "
        "Large language models built on this architecture are trained on massive "
        "corpora using next-token prediction (autoregressive language modelling) "
        "or masked language modelling. They exhibit emergent capabilities such as "
        "few-shot and zero-shot generalisation across a wide variety of tasks."
    )
    model, tokenizer = load_gemma4()
    ppl = ppl_on_text(model, tokenizer, _WIKI_TEXT)
    # The range separator was a mis-encoded character; restored as an en dash.
    print(f"\nPerplexity on sample text: {ppl:.2f} (target: ~17–18 for bfloat16)")
|