"""
Gemma 4 E2B — clean PyTorch forward pass (text model only).

Architecture:
  - 35 decoder layers, hidden_size=1536, vocab=262144
  - 8 Q heads, 1 KV head (MQA)
  - Sliding attention layers (0-3, 5-8, 10-13, 15-18, 20-23, 25-28, 30-33):
      head_dim=256, sliding_window=512, rope_theta=10000
  - Full attention layers (every 5th: 4,9,14,19,24,29,34):
      head_dim=512, partial_rotary_factor=0.25, rope_theta=1000000.
      Only 64 of the 256 rotate-half pairs get a non-zero frequency, i.e.
      dims 0-63 (paired with dims 256-319) are rotated; all other dims are
      left un-rotated (NoPE).
  - MLP: GeGLU (tanh-approximated GELU), intermediate_size=6144 for layers
    0-14 and 12288 (double-wide) for layers 15-34
  - Per-layer auxiliary stream (full details below)
  - layer_scalar: per-layer learned scalar multiplied onto every residual
    contribution of the layer (per-layer stream, attention, MLP)
  - QK RMSNorm before RoPE, attn_scale=1.0
  - Final: RMSNorm + tied lm_head + logit softcapping at 30.0

Per-layer auxiliary stream:
  Model-level (computed once, before all layers):
    1. embed_tokens_per_layer(input_ids) → [B, T, 35*256] (vocab lookup),
       scaled by sqrt(256) = 16
    2. per_layer_model_projection(x_embed) → [B, T, 35*256] (project hidden→aux),
       scaled by hidden_size**-0.5
    3. per_layer_projection_norm (RMSNorm(256)) on the projection slice per layer
    4. Combine: per_layer_inputs = (embed_aux + proj_aux) * (1/sqrt(2)),
       reshaped to [B, T, 35, 256]

  Per-layer (at layer i):
    per_layer_input_i = per_layer_inputs[:, :, i, :]      # [B, T, 256]
    x_normed = input_layernorm(x)
    gate = gelu_tanh(per_layer_input_gate(x_normed))      # [B, T, 256]
    gated = gate * per_layer_input_i                      # [B, T, 256]
    out = per_layer_projection(gated)                     # [B, T, 1536] (256→1536)
    x = x + layer_scalar * post_per_layer_input_norm(out)

Weight shapes in checkpoint:
  per_layer_model_projection.weight         : [8960, 1536] (Linear 1536→8960)
  per_layer_projection_norm.weight          : [256]        (RMSNorm on 256-dim slices)
  layers.i.per_layer_input_gate.weight      : [256, 1536]  (Linear 1536→256)
  layers.i.per_layer_projection.weight      : [1536, 256]  (Linear 256→1536)
  layers.i.post_per_layer_input_norm.weight : [1536]       (RMSNorm on 1536-dim output)
"""

import math
import os
import re
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
from transformers import AutoTokenizer

# ── device / dtype ────────────────────────────────────────────────────────────
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

# ── model path ────────────────────────────────────────────────────────────────
# Try known HF repo caches in order; first one that exists wins. Override with
# $GEMMA4_HF_REPO to point at an arbitrary repo cache (e.g., "google/gemma-4-e2b-it").


def _hub_cache_roots() -> tuple[Path, ...]:
    """Return candidate Hugging Face hub cache directories, most specific first.

    Mirrors huggingface_hub's environment precedence ($HF_HUB_CACHE,
    $HUGGINGFACE_HUB_CACHE, $HF_HOME/hub, $XDG_CACHE_HOME/huggingface/hub),
    and always ends with the historical default ~/.cache/huggingface/hub so
    setups that relied on that path keep working.
    """
    roots: list[Path] = []
    for var in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
        value = os.environ.get(var)
        if value:
            roots.append(Path(value).expanduser())
    hf_home = os.environ.get("HF_HOME")
    if hf_home:
        roots.append(Path(hf_home).expanduser() / "hub")
    xdg_cache = os.environ.get("XDG_CACHE_HOME")
    if xdg_cache:
        roots.append(Path(xdg_cache).expanduser() / "huggingface" / "hub")
    roots.append(Path(os.path.expanduser("~/.cache/huggingface/hub")))
    # De-duplicate while preserving order.
    return tuple(dict.fromkeys(roots))


_HUB_ROOTS = _hub_cache_roots()
_HUB_ROOT = _HUB_ROOTS[0]  # kept for backward compatibility: primary cache root
_REPO_CANDIDATES = (
    os.environ.get("GEMMA4_HF_REPO", ""),
    "gg-hf-gg/gemma-4-E2B",
    "google/gemma-4-e2b-it",
)


def _snapshot_dirs(repo_cache: Path) -> list[Path]:
    """Return the snapshot directories of one cached repo.

    Snapshots are sorted by name; if ``refs/main`` exists, the snapshot it
    names is moved to the front so the current revision is preferred.
    Returns an empty list when the repo has no ``snapshots`` directory.
    """
    snap_root = repo_cache / "snapshots"
    if not snap_root.is_dir():
        return []
    snaps = sorted(p for p in snap_root.iterdir() if p.is_dir())
    ref_main = repo_cache / "refs" / "main"
    if ref_main.is_file():
        head = ref_main.read_text(encoding="utf-8").strip()
        # Stable sort: the refs/main snapshot goes first, others keep order.
        snaps.sort(key=lambda p: p.name != head)
    return snaps


def _resolve_model_paths():
    """Return (snapshot_dir, safetensors_path).

    Picks the first available repo+snapshot that actually contains a readable
    .safetensors file. Iterates ALL snapshots per repo (and every candidate
    cache root) before moving to the next repo — iterdir() order is not
    deterministic and HF may keep multiple snapshots where only one has
    weights blob-resolved. Dangling symlinks (blob not downloaded) are skipped.

    For a sharded checkpoint the first shard is returned;
    ``Gemma4ForCausalLM.load_weights`` expands it to all sibling shards.

    Raises:
        FileNotFoundError: if no candidate repo has a usable snapshot.
    """
    for repo in _REPO_CANDIDATES:
        if not repo:
            continue
        cache_name = "models--" + repo.replace("/", "--")
        for hub_root in _HUB_ROOTS:
            for snap in _snapshot_dirs(hub_root / cache_name):
                # Prefer model.safetensors (single-file) else any .safetensors.
                # exists() follows symlinks, so un-downloaded blobs are rejected.
                sft = snap / "model.safetensors"
                if not sft.exists():
                    candidates = sorted(
                        p for p in snap.glob("*.safetensors") if p.exists()
                    )
                    if not candidates:
                        continue
                    sft = candidates[0]
                return snap, sft
    raise FileNotFoundError(
        "No Gemma-4 E2B HF cache found. Tried: "
        + ", ".join(r for r in _REPO_CANDIDATES if r)
        + " under "
        + ", ".join(str(root) for root in _HUB_ROOTS)
        + ". Run `hf download google/gemma-4-e2b-it` or set GEMMA4_HF_REPO."
    )


MODEL_DIR, SAFETENSORS_BLOB = _resolve_model_paths()

# ── architecture constants ────────────────────────────────────────────────────
N_LAYERS = 35
HIDDEN_SIZE = 1536
VOCAB_SIZE = 262144
N_Q_HEADS = 8
N_KV_HEADS = 1
HEAD_DIM_SLIDE = 256       # sliding attention head dim
HEAD_DIM_FULL = 512        # full attention head dim
PER_LAYER_DIM = 256        # per-layer auxiliary stream width per layer
INTERMEDIATE = 6144        # MLP intermediate size (layers 0-14)
INTERMEDIATE_WIDE = 12288  # double-wide MLP intermediate size (layers 15-34)
# Layers 15-34 use double-wide MLP (use_double_wide_mlp=True in config)
DOUBLE_WIDE_START = 15
SLIDING_WINDOW = 512
ROPE_THETA_SLIDE = 10_000.0
ROPE_THETA_FULL = 1_000_000.0
PARTIAL_ROT_FULL = 0.25    # 512*0.25 = 128 rotated dims = 64 rotate-half pairs
RMS_EPS = 1e-6
LOGIT_CAP = 30.0
ATTN_SCALE = 1.0           # QK are RMSNorm'd, so no sqrt(d) scaling needed
# Per-layer projection scale: hidden_size**-0.5 (applied to per_layer_model_projection output)
PER_LAYER_PROJ_SCALE = HIDDEN_SIZE ** -0.5
# Input combination scale: 1/sqrt(2) (mix embed aux + model projection)
PER_LAYER_INPUT_SCALE = math.sqrt(0.5)  # = 1/sqrt(2)

# Full-attention layers: every 5th layer (0-indexed: 4,9,14,19,24,29,34)
FULL_ATTN_LAYERS = frozenset(range(4, N_LAYERS, 5))


def is_full_attention(layer_idx: int) -> bool:
    """Return True if layer_idx uses full (global) attention."""
    return layer_idx in FULL_ATTN_LAYERS


# ── RMSNorm ───────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """RMSNorm computing ``weight * normed`` in float32; weight initialised to ones."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_f32 = x.float()
        normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + RMS_EPS)
        # Cast back to the input dtype only after applying the weight.
        return (normed * self.weight.float()).to(x.dtype)


# ── RoPE ──────────────────────────────────────────────────────────────────────
def build_rope_freqs(
    head_dim: int,
    max_seq: int,
    theta: float,
    device: torch.device,
    n_rot_pairs: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build cos/sin tables of shape [max_seq, head_dim] (rotate-half layout).

    Pair ``j`` couples dims ``j`` and ``j + head_dim//2``. With partial
    rotation only pairs ``0 .. n_rot_pairs-1`` get a real frequency, i.e.
    dims ``[0, n_rot_pairs)`` and ``[half, half + n_rot_pairs)``; all other
    dims get freq=0 (cos=1, sin=0 → identity, NoPE).

    Args:
        head_dim: total head dimension
        max_seq: maximum sequence length to precompute
        theta: RoPE base frequency
        device: target device
        n_rot_pairs: if set, only this many pairs rotate. Their frequencies
            use exponent i / (head_dim//2), i.e. they are the first
            n_rot_pairs frequencies of a full-width RoPE.

    Returns:
        (cos, sin), each float32 of shape [max_seq, head_dim].
    """
    half = head_dim // 2
    if n_rot_pairs is None:
        n_rot_pairs = half

    # Build frequencies only for the pairs that actually rotate.
    inv_freq = 1.0 / (theta ** (
        torch.arange(0, n_rot_pairs, device=device).float() / half
    ))  # shape [n_rot_pairs]

    # Pad with zeros for the remaining pairs (NoPE: cos=1, sin=0).
    if n_rot_pairs < half:
        inv_freq = torch.cat([
            inv_freq,
            torch.zeros(half - n_rot_pairs, device=device),
        ])  # [half]

    t = torch.arange(max_seq, device=device).float()
    freqs = torch.outer(t, inv_freq)             # [T, half]
    freqs = torch.cat([freqs, freqs], dim=-1)    # [T, head_dim]
    return freqs.cos(), freqs.sin()


def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings (rotate-half convention).

    Args:
        x:   [B, H, T, head_dim]
        cos: [>=T, head_dim]; rows 0..T-1 are used
        sin: [>=T, head_dim]
    """
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat([-x2, x1], dim=-1)
    T = x.shape[2]
    cos_ = cos[:T].unsqueeze(0).unsqueeze(0).to(x.dtype)  # [1,1,T,D]
    sin_ = sin[:T].unsqueeze(0).unsqueeze(0).to(x.dtype)
    return x * cos_ + rotated * sin_


# ── Attention ─────────────────────────────────────────────────────────────────
class Attention(nn.Module):
    """
    Multi-query attention (8 Q heads, 1 KV head).
    Sliding layers: head_dim=256, local window=512.
    Full layers:    head_dim=512, causal (no window restriction).
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.full_attn = is_full_attention(layer_idx)
        self.head_dim = HEAD_DIM_FULL if self.full_attn else HEAD_DIM_SLIDE
        hd = self.head_dim

        self.q_proj = nn.Linear(HIDDEN_SIZE, N_Q_HEADS * hd, bias=False)
        self.k_proj = nn.Linear(HIDDEN_SIZE, N_KV_HEADS * hd, bias=False)
        self.v_proj = nn.Linear(HIDDEN_SIZE, N_KV_HEADS * hd, bias=False)
        self.o_proj = nn.Linear(N_Q_HEADS * hd, HIDDEN_SIZE, bias=False)

        self.q_norm = RMSNorm(hd)
        self.k_norm = RMSNorm(hd)

    def forward(
        self,
        x: torch.Tensor,    # [B, T, D]
        cos: torch.Tensor,  # [>=T, head_dim]
        sin: torch.Tensor,
    ) -> torch.Tensor:
        B, T, _ = x.shape
        hd = self.head_dim

        q = self.q_proj(x).view(B, T, N_Q_HEADS, hd).transpose(1, 2)   # [B,Hq,T,hd]
        k = self.k_proj(x).view(B, T, N_KV_HEADS, hd).transpose(1, 2)  # [B,1,T,hd]
        v = self.v_proj(x).view(B, T, N_KV_HEADS, hd).transpose(1, 2)

        # Per-head QK normalisation (before RoPE)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Rotary position embeddings
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Expand KV to match Q heads (MQA)
        k = k.expand(B, N_Q_HEADS, T, hd)
        v = v.expand(B, N_Q_HEADS, T, hd)

        if self.full_attn:
            # Standard causal attention, no window restriction
            out = F.scaled_dot_product_attention(
                q, k, v,
                is_causal=True,
                scale=ATTN_SCALE,
            )
        else:
            # Sliding window causal attention.
            # attn_mask[i, j] = True means query-position i CAN attend to key-position j.
            #   Causal: j <= i (can only attend to past/current positions)
            #   Window: i - j < SLIDING_WINDOW
            idx = torch.arange(T, device=x.device)
            # idx.unsqueeze(0) = [1, T] broadcast as j (key) axis
            # idx.unsqueeze(1) = [T, 1] broadcast as i (query) axis
            attn_mask = (
                (idx.unsqueeze(0) <= idx.unsqueeze(1))                        # j <= i (causal)
                & (idx.unsqueeze(1) - idx.unsqueeze(0) < SLIDING_WINDOW)      # i - j < W
            )  # [T_q, T_k]
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                scale=ATTN_SCALE,
            )

        out = out.transpose(1, 2).contiguous().view(B, T, N_Q_HEADS * hd)
        return self.o_proj(out)


# ── MLP (GeGLU) ───────────────────────────────────────────────────────────────
class MLP(nn.Module):
    """
    GeGLU feed-forward network.
    Layers 0-14:  intermediate_size=6144
    Layers 15-34: intermediate_size=12288 (double-wide)
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        inter = INTERMEDIATE_WIDE if layer_idx >= DOUBLE_WIDE_START else INTERMEDIATE
        self.gate_proj = nn.Linear(HIDDEN_SIZE, inter, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, inter, bias=False)
        self.down_proj = nn.Linear(inter, HIDDEN_SIZE, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.gelu(self.gate_proj(x), approximate="tanh")
        return self.down_proj(gate * self.up_proj(x))


# ── Decoder layer ─────────────────────────────────────────────────────────────
class Gemma4TextLayer(nn.Module):
    """
    Single Gemma 4 decoder layer.

    Execution order (per forward call):
      1. Per-layer auxiliary stream injection
      2. Self-attention block (pre/post norm, residual scaled by layer_scalar)
      3. MLP block (pre/post norm, residual scaled by layer_scalar)

    Per-layer auxiliary stream injection:
      Receives per_layer_input [B,T,256] = combined embed+projection for this layer.
        x_normed  = input_layernorm(x)
        gate      = gelu_tanh(per_layer_input_gate(x_normed))   # [B,T,256]
        gated     = gate * per_layer_input                      # [B,T,256]
        out_1536  = per_layer_projection(gated)                 # [B,T,1536]
        x         = x + layer_scalar * post_per_layer_input_norm(out_1536)

    NOTE(review): in Hugging Face's Gemma3n decoder layer the per-layer gate
    is applied at the END of the layer (after attention and MLP) and its
    input is not passed through input_layernorm. Confirm the ordering and
    gate-input normalisation used here against the Gemma 4 reference.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        # Attention
        self.self_attn = Attention(layer_idx)

        # MLP (double-wide for layers 15+)
        self.mlp = MLP(layer_idx)

        # Layer norms
        self.input_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_attention_layernorm = RMSNorm(HIDDEN_SIZE)
        self.pre_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_per_layer_input_norm = RMSNorm(HIDDEN_SIZE)

        # Per-layer auxiliary stream weights:
        #   per_layer_input_gate:  Linear(1536→256), weight=[256, 1536]
        #   per_layer_projection:  Linear(256→1536), weight=[1536, 256]
        self.per_layer_input_gate = nn.Linear(HIDDEN_SIZE, PER_LAYER_DIM, bias=False)
        self.per_layer_projection = nn.Linear(PER_LAYER_DIM, HIDDEN_SIZE, bias=False)

        # Scalar multiplier for per-layer, attention and MLP residual contributions
        self.layer_scalar = nn.Parameter(torch.ones(1))

    def forward(
        self,
        x: torch.Tensor,                # [B, T, D]
        cos: torch.Tensor,              # RoPE tables for this layer type
        sin: torch.Tensor,
        per_layer_input: torch.Tensor,  # [B, T, 256] combined embed+projection for this layer
    ) -> torch.Tensor:
        scalar = self.layer_scalar.to(x.dtype)

        # ── 1. Per-layer auxiliary stream injection ──────────────────────────
        # Gate uses the model's hidden activation (gelu_pytorch_tanh), matching
        # the Gemma3n reference implementation's activation choice.
        # The layer_scalar multiplies all residual contributions (per-layer, attn, MLP).
        x_normed = self.input_layernorm(x)
        gate = F.gelu(self.per_layer_input_gate(x_normed), approximate="tanh")  # [B,T,256]
        gated = gate * per_layer_input                                          # [B,T,256]
        out = self.per_layer_projection(gated)                                  # [B,T,1536]
        x = x + scalar * self.post_per_layer_input_norm(out)

        # ── 2. Self-attention ────────────────────────────────────────────────
        # input_layernorm is applied again, now to the post-injection stream.
        h = self.self_attn(self.input_layernorm(x), cos, sin)
        x = x + scalar * self.post_attention_layernorm(h)

        # ── 3. MLP ───────────────────────────────────────────────────────────
        h = self.mlp(self.pre_feedforward_layernorm(x))
        x = x + scalar * self.post_feedforward_layernorm(h)

        return x


# ── Checkpoint helpers ─────────────────────────────────────────────────────────
# Matches HF shard names such as "model-00001-of-00003.safetensors".
_SHARD_RE = re.compile(r"(?P<stem>.+)-\d+-of-(?P<total>\d+)\.safetensors")


def _expand_safetensors_shards(path: Path) -> list[Path]:
    """Return every safetensors file that makes up the checkpoint at *path*.

    - A directory → all existing ``*.safetensors`` files inside it (sorted).
    - A shard name like ``model-00001-of-00003.safetensors`` → all existing
      sibling shards with the same stem and shard count (sorted); a warning
      is printed if fewer than the advertised count are present.
    - Anything else → ``[path]`` unchanged.

    Raises:
        FileNotFoundError: if *path* is a directory with no safetensors files.
    """
    if path.is_dir():
        files = sorted(p for p in path.glob("*.safetensors") if p.exists())
        if not files:
            raise FileNotFoundError(f"No .safetensors files in {path}")
        return files
    match = _SHARD_RE.fullmatch(path.name)
    if match is None:
        return [path]
    stem, total = match.group("stem"), match.group("total")
    sibling_re = re.compile(rf"{re.escape(stem)}-\d+-of-{total}\.safetensors")
    shards = sorted(
        p for p in path.parent.iterdir()
        if sibling_re.fullmatch(p.name) and p.exists()
    )
    if len(shards) != int(total):
        print(
            f"[load_weights] expected {int(total)} shards next to {path.name}, "
            f"found {len(shards)}"
        )
    return shards or [path]


# ── Full model ────────────────────────────────────────────────────────────────
class Gemma4ForCausalLM(nn.Module):
    """
    Gemma 4 E2B text model (causal LM head, no vision/audio).

    Tied embeddings: embed_tokens.weight is shared with lm_head.
    Output logits are softcapped: 30 * tanh(logits / 30).

    Per-layer auxiliary stream is computed model-level before layer iteration:
      - embed_tokens_per_layer lookup: [B,T,35*256], scaled by sqrt(256)
      - per_layer_model_projection: Linear(1536→35*256), scaled by 1/sqrt(1536)
      - per_layer_projection_norm: RMSNorm(256) per layer-slice
      - combine: per_layer_inputs = (embed_aux + proj_normed) * (1/sqrt(2))
    """

    def __init__(self):
        super().__init__()

        # Token embeddings
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.embed_tokens_per_layer = nn.Embedding(VOCAB_SIZE, N_LAYERS * PER_LAYER_DIM)

        # Final norm
        self.norm = RMSNorm(HIDDEN_SIZE)

        # Transformer layers
        self.layers = nn.ModuleList([Gemma4TextLayer(i) for i in range(N_LAYERS)])

        # Model-level per-layer projection (hidden → all layer aux dims at once)
        # weight shape: [35*256, 1536] = [8960, 1536]
        self.per_layer_model_projection = nn.Linear(
            HIDDEN_SIZE, N_LAYERS * PER_LAYER_DIM, bias=False
        )
        # Norm applied to per-layer projection slices [256]
        self.per_layer_projection_norm = RMSNorm(PER_LAYER_DIM)

        # RoPE tables (computed lazily; plain attributes, so .to() does not
        # move them — _ensure_rope rebuilds them on a device change).
        self._rope_slide_cos: torch.Tensor | None = None
        self._rope_slide_sin: torch.Tensor | None = None
        self._rope_full_cos: torch.Tensor | None = None
        self._rope_full_sin: torch.Tensor | None = None
        self._rope_seq: int = 0

    @staticmethod
    def is_full_attention(layer_idx: int) -> bool:
        return is_full_attention(layer_idx)

    def _ensure_rope(self, seq_len: int, device: torch.device) -> None:
        """Precompute (or extend / relocate) RoPE tables on demand.

        Tables are rebuilt when none exist yet, when they are shorter than
        ``seq_len``, or when they live on a different device than ``device``.
        """
        device = torch.device(device)
        if (
            self._rope_slide_cos is not None
            and self._rope_seq >= seq_len
            and self._rope_slide_cos.device == device
        ):
            return
        max_seq = max(seq_len, 2048)

        # Sliding layers: head_dim=256, full rotation
        cs, sn = build_rope_freqs(HEAD_DIM_SLIDE, max_seq, ROPE_THETA_SLIDE, device)
        self._rope_slide_cos = cs
        self._rope_slide_sin = sn

        # Full-attention layers: head_dim=512, partial_rotary_factor=0.25.
        # 512 * 0.25 = 128 dims rotated = 64 rotation pairs (half=256, 64 of 256 pairs).
        n_rot = int(HEAD_DIM_FULL * PARTIAL_ROT_FULL) // 2  # = 64
        cf, sf = build_rope_freqs(
            HEAD_DIM_FULL, max_seq, ROPE_THETA_FULL, device, n_rot_pairs=n_rot
        )
        self._rope_full_cos = cf
        self._rope_full_sin = sf

        self._rope_seq = max_seq

    def _compute_per_layer_inputs(
        self, input_ids: torch.Tensor, x_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Precompute per-layer auxiliary inputs for all 35 layers.

        Returns:
            per_layer_inputs: [B, T, N_LAYERS, PER_LAYER_DIM]
        """
        B, T = input_ids.shape

        # 1. Token-based per-layer embeddings (vocabulary lookup)
        # Scaled by sqrt(PER_LAYER_DIM)=16, matching Gemma3n's ScaledWordEmbedding convention
        embed_aux = self.embed_tokens_per_layer(input_ids).to(x_embed.dtype)
        embed_aux = embed_aux * math.sqrt(PER_LAYER_DIM)  # scale by sqrt(256)=16
        embed_aux = embed_aux.view(B, T, N_LAYERS, PER_LAYER_DIM)

        # 2. Hidden-state projection: project x_embed to [B, T, 35*256]
        proj_all = self.per_layer_model_projection(x_embed)  # [B, T, 35*256]
        proj_all = proj_all * PER_LAYER_PROJ_SCALE            # scale by 1/sqrt(hidden)
        proj_all = proj_all.view(B, T, N_LAYERS, PER_LAYER_DIM)
        # Apply RMSNorm(256) to each layer slice (normalises the last dim)
        proj_all = self.per_layer_projection_norm(proj_all)

        # 3. Combine: (embed_aux + proj_normed) * (1/sqrt(2))
        return (embed_aux + proj_all) * PER_LAYER_INPUT_SCALE  # [B, T, 35, 256]

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: [B, T] long tensor
        Returns:
            logits: [B, T, vocab_size] with softcapping applied
        """
        B, T = input_ids.shape
        self._ensure_rope(T, input_ids.device)

        # Token embeddings scaled by sqrt(hidden_size)
        x = self.embed_tokens(input_ids) * math.sqrt(HIDDEN_SIZE)  # [B,T,D]

        # Compute per-layer auxiliary inputs (uses unmodified x_embed)
        per_layer_inputs = self._compute_per_layer_inputs(input_ids, x)

        for i, layer in enumerate(self.layers):
            per_layer_i = per_layer_inputs[:, :, i, :]  # [B, T, 256]
            if is_full_attention(i):
                cos, sin = self._rope_full_cos, self._rope_full_sin
            else:
                cos, sin = self._rope_slide_cos, self._rope_slide_sin
            x = layer(x, cos, sin, per_layer_i)

        x = self.norm(x)

        # Tied lm_head: F.linear(x, embed_tokens.weight)
        logits = F.linear(x, self.embed_tokens.weight.to(x.dtype))  # [B,T,V]

        # Logit softcapping
        return LOGIT_CAP * torch.tanh(logits / LOGIT_CAP)

    @classmethod
    def load_weights(
        cls,
        safetensors_path: str | Path,
        device: str = "cpu",
    ) -> "Gemma4ForCausalLM":
        """
        Load from a safetensors checkpoint.

        ``safetensors_path`` may be a single file, one shard of a sharded
        checkpoint (all sibling shards are then loaded too), or a directory
        containing the safetensors files.

        Weight names in the file follow the pattern:
          model.language_model.X → self.X
        Keys without that prefix are ignored. Missing/unexpected keys are
        reported (first 5 of each) but do not raise (strict=False).

        NOTE: the model is first constructed with PyTorch's default dtype
        (normally float32) on CPU, then cast to DTYPE after loading.
        """
        model = cls()
        prefix = "model.language_model."

        state = {}
        for shard in _expand_safetensors_shards(Path(safetensors_path)):
            with safe_open(str(shard), framework="pt", device=device) as f:
                for key in f.keys():
                    if key.startswith(prefix):
                        state[key.removeprefix(prefix)] = f.get_tensor(key)

        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f"[load_weights] {len(missing)} missing keys (first 5): {missing[:5]}")
        if unexpected:
            print(f"[load_weights] {len(unexpected)} unexpected keys (first 5): {unexpected[:5]}")

        return model.to(dtype=DTYPE)


# ── Convenience loader ────────────────────────────────────────────────────────
def load_gemma4(
    device: str | None = None,
) -> tuple[Gemma4ForCausalLM, AutoTokenizer]:
    """
    Load the Gemma 4 E2B model and tokenizer.

    Returns:
        (model, tokenizer) — model is in eval mode on `device`.
    """
    if device is None:
        device = DEVICE

    print(f"Loading Gemma 4 E2B from {SAFETENSORS_BLOB} ...")
    model = Gemma4ForCausalLM.load_weights(SAFETENSORS_BLOB, device=device)
    model = model.to(device).eval()

    print(f"Loading tokenizer from {MODEL_DIR} ...")
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR), local_files_only=True)

    return model, tokenizer


# ── PPL evaluation ────────────────────────────────────────────────────────────
def ppl_on_text(
    model: Gemma4ForCausalLM,
    tokenizer: AutoTokenizer,
    text: str,
    device: str | None = None,
    max_length: int = 1024,
) -> float:
    """
    Compute token-level perplexity on `text`.

    Args:
        model:      Gemma4ForCausalLM in eval mode
        tokenizer:  matching AutoTokenizer
        text:       input string
        device:     device for inference (defaults to DEVICE)
        max_length: truncate to this many tokens

    Returns:
        perplexity (float): exp of the mean next-token NLL.

    Raises:
        ValueError: if `text` tokenizes to fewer than 2 tokens (no
            next-token prediction exists, the mean would be NaN).
    """
    if device is None:
        device = DEVICE

    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    input_ids = enc["input_ids"].to(device)
    n_tokens = input_ids.shape[1]
    if n_tokens < 2:
        raise ValueError(
            f"ppl_on_text needs at least 2 tokens after tokenization, got {n_tokens}"
        )

    with torch.no_grad():
        logits = model(input_ids)  # [1, T, V]

    # Shift: predict token t+1 from position t; compute the loss in float32.
    shift_logits = logits[0, :-1, :].float()  # [T-1, V]
    shift_labels = input_ids[0, 1:]           # [T-1]
    nll = F.cross_entropy(shift_logits, shift_labels)  # mean NLL
    return nll.exp().item()


# ── main ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    _WIKI_TEXT = (
        "The transformer architecture was introduced in the paper "
        "'Attention Is All You Need' by Vaswani et al. in 2017. "
        "It relies entirely on self-attention mechanisms, dispensing with "
        "recurrence and convolutions entirely. Transformers have since become "
        "the dominant architecture for natural language processing, powering "
        "models such as BERT, GPT, T5, and the Gemma family. "
        "The key innovation is the multi-head attention mechanism, which allows "
        "the model to jointly attend to information from different representation "
        "subspaces at different positions. This is complemented by position-wise "
        "feed-forward networks and residual connections with layer normalisation. "
        "Large language models built on this architecture are trained on massive "
        "corpora using next-token prediction (autoregressive language modelling) "
        "or masked language modelling. They exhibit emergent capabilities such as "
        "few-shot and zero-shot generalisation across a wide variety of tasks."
    )

    model, tokenizer = load_gemma4()
    ppl = ppl_on_text(model, tokenizer, _WIKI_TEXT)
    print(f"\nPerplexity on sample text: {ppl:.2f}  (target: ~17–18 for bfloat16)")