ncylich's picture
Upload gemma4.py with huggingface_hub
ea203cb verified
"""
Gemma 4 E2B — clean PyTorch forward pass (text model only).
Architecture:
- 35 decoder layers, hidden_size=1536, vocab=262144
- 8 Q heads, 1 KV head (MQA)
- Sliding attention layers (0-3, 5-8, 10-13, 15-18, 20-23, 25-28, 30-33):
head_dim=256, sliding_window=512, rope_theta=10000
- Full attention layers (every 5th: 4,9,14,19,24,29,34):
head_dim=512, partial_rotary_factor=0.25 (only first 128 of 512 dims rotated),
rope_theta=1000000
- MLP (all layers): GeGLU; intermediate_size=6144 for layers 0-14, 12288 (double-wide) for layers 15-34
- Per-layer auxiliary stream (full details below)
- layer_scalar: per-layer learned scalar multiplied onto residual contributions
- QK RMSNorm before RoPE, attn_scale=1.0
- Final: RMSNorm + tied lm_head + logit softcapping at 30.0
Per-layer auxiliary stream:
Model-level (computed once, before all layers):
1. embed_tokens_per_layer(input_ids) → [B, T, 35*256] (vocab lookup)
2. per_layer_model_projection(x_embed) → [B, T, 35*256] (project hidden→aux)
scaled by hidden_size**-0.5
3. per_layer_projection_norm (RMSNorm(256)) on the projection slice per layer
4. Combine: per_layer_inputs = (embed_aux + proj_aux) * (1/sqrt(2))
reshaped to [B, T, 35, 256]
Per-layer (at layer i):
per_layer_input_i = per_layer_inputs[:, :, i, :] # [B, T, 256]
x_normed = input_layernorm(x)
gate = gelu_tanh(per_layer_input_gate(x_normed)) # [B, T, 256]
gated = gate * per_layer_input_i # [B, T, 256]
out = per_layer_projection(gated) # [B, T, 1536] (256→1536)
x = x + layer_scalar * post_per_layer_input_norm(out)
Weight shapes in checkpoint:
per_layer_model_projection.weight : [8960, 1536] (Linear 1536→8960)
per_layer_projection_norm.weight : [256] (RMSNorm on 256-dim slices)
layers.i.per_layer_input_gate.weight : [256, 1536] (Linear 1536→256)
layers.i.per_layer_projection.weight : [1536, 256] (Linear 256→1536)
layers.i.post_per_layer_input_norm.weight : [1536] (RMSNorm on 1536-dim output)
"""
import math
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
from transformers import AutoTokenizer
# ── device / dtype ────────────────────────────────────────────────────────────
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Weights are cast to bfloat16 after loading.
DTYPE = torch.bfloat16

# ── model path ────────────────────────────────────────────────────────────────
# Root of the local Hugging Face hub cache.
_HUB_ROOT = Path(os.path.expanduser("~")) / ".cache" / "huggingface" / "hub"
# Repo ids probed in order; the first one with a usable snapshot wins. The
# $GEMMA4_HF_REPO environment variable (e.g. "google/gemma-4-e2b-it") takes
# precedence when set; an empty entry is skipped by the resolver.
_REPO_CANDIDATES = (
    os.environ.get("GEMMA4_HF_REPO", ""),
    "gg-hf-gg/gemma-4-E2B",
    "google/gemma-4-e2b-it",
)
def _resolve_model_paths():
    """Locate a cached Gemma-4 snapshot that contains safetensors weights.

    Repos from ``_REPO_CANDIDATES`` are probed in order (empty entries are
    skipped). Within a repo every snapshot directory is examined, in sorted
    order, before moving on to the next repo, because a cache may hold several
    snapshots of which only some contain weight files.

    Returns:
        (snapshot_dir, safetensors_path). ``model.safetensors`` is preferred;
        otherwise the alphabetically first ``*.safetensors`` file is used.

    Raises:
        FileNotFoundError: no candidate repo has a snapshot with weights.
    """
    for repo in filter(None, _REPO_CANDIDATES):
        snapshots_dir = _HUB_ROOT / ("models--" + repo.replace("/", "--")) / "snapshots"
        if not snapshots_dir.is_dir():
            continue
        snapshot_dirs = sorted(entry for entry in snapshots_dir.iterdir() if entry.is_dir())
        for snapshot in snapshot_dirs:
            # Single-file checkpoint takes priority.
            weights = snapshot / "model.safetensors"
            if weights.exists():
                return snapshot, weights
            # Otherwise fall back to whichever .safetensors file sorts first.
            shards = sorted(snapshot.glob("*.safetensors"))
            if shards:
                return snapshot, shards[0]
    tried = ", ".join(r for r in _REPO_CANDIDATES if r)
    raise FileNotFoundError(
        "No Gemma-4 E2B HF cache found. Tried: " + tried
        + ". Run `hf download google/gemma-4-e2b-it` or set GEMMA4_HF_REPO."
    )
# Resolved once, at import time: the snapshot directory (used for the tokenizer)
# and the safetensors file (used for the weights). _resolve_model_paths raises
# FileNotFoundError when no cached snapshot with weights exists, so importing
# this module fails in that case.
MODEL_DIR, SAFETENSORS_BLOB = _resolve_model_paths()
# ── architecture constants ────────────────────────────────────────────────────
# Overall model shape.
N_LAYERS = 35
HIDDEN_SIZE = 1536
VOCAB_SIZE = 262144

# Attention geometry: multi-query attention, with a different head width for
# sliding-window layers and full-attention layers.
N_Q_HEADS = 8
N_KV_HEADS = 1
HEAD_DIM_SLIDE = 256  # head dim of sliding-window layers
HEAD_DIM_FULL = 512  # head dim of full-attention layers
SLIDING_WINDOW = 512
ATTN_SCALE = 1.0  # Q and K are RMS-normalised, so no 1/sqrt(d) factor is applied

# Width (per layer) of the per-layer auxiliary stream.
PER_LAYER_DIM = 256

# Feed-forward widths: layers [0, DOUBLE_WIDE_START) use INTERMEDIATE, the
# remaining layers use the double-wide variant (use_double_wide_mlp in config).
INTERMEDIATE = 6144
INTERMEDIATE_WIDE = 2 * INTERMEDIATE  # 12288
DOUBLE_WIDE_START = 15

# Rotary embedding settings per layer type.
ROPE_THETA_SLIDE = 1e4
ROPE_THETA_FULL = 1e6
PARTIAL_ROT_FULL = 0.25  # fraction of the 512-dim full-attention head given RoPE (128 dims)

# Numerics.
RMS_EPS = 1e-6
LOGIT_CAP = 30.0

# Scale applied to the per_layer_model_projection output: hidden_size ** -0.5.
PER_LAYER_PROJ_SCALE = HIDDEN_SIZE ** -0.5
# Scale applied when mixing the embedding-based and projected auxiliary inputs: 1/sqrt(2).
PER_LAYER_INPUT_SCALE = math.sqrt(0.5)

# Every fifth layer (0-indexed 4, 9, 14, 19, 24, 29, 34) uses full attention.
FULL_ATTN_LAYERS = frozenset(i for i in range(N_LAYERS) if i % 5 == 4)


def is_full_attention(layer_idx: int) -> bool:
    """Tell whether layer ``layer_idx`` is a full (global) attention layer.

    Membership is looked up in ``FULL_ATTN_LAYERS``; any index outside that
    set, including out-of-range values, yields False.
    """
    return layer_idx in FULL_ATTN_LAYERS
# ── RMSNorm ───────────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The normalised activations are multiplied directly by a learned ``weight``
    vector (initialised to ones). The arithmetic runs in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out_dtype = x.dtype
        h = x.float()
        # Mean of squares over the feature axis, stabilised by RMS_EPS.
        mean_sq = h.pow(2).mean(-1, keepdim=True)
        h = h * torch.rsqrt(mean_sq + RMS_EPS)
        h = h * self.weight.float()
        return h.to(out_dtype)
# ── RoPE ─────────────────────────────────────────────────────────────────────
def build_rope_freqs(
head_dim: int,
max_seq: int,
theta: float,
device: torch.device,
n_rot_pairs: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Build cos/sin tables of shape [max_seq, head_dim].
For full-attention layers with partial rotation, only the first
n_rot_pairs*2 positions carry actual frequencies; the rest are zeros
(NoPE β€” no positional encoding for those dims).
Args:
head_dim: total head dimension
max_seq: maximum sequence length to precompute
theta: RoPE base frequency
device: target device
n_rot_pairs: if set, only compute real freqs for this many pairs;
remaining dims get freq=0 (cos=1, sin=0 β†’ identity).
"""
half = head_dim // 2
if n_rot_pairs is None:
n_rot_pairs = half
# Build frequencies only for the pairs that actually rotate
inv_freq = 1.0 / (theta ** (
torch.arange(0, n_rot_pairs, device=device).float() / half
)) # shape [n_rot_pairs]
# Pad with zeros for the remaining pairs (NoPE: cos=1, sin=0)
if n_rot_pairs < half:
inv_freq = torch.cat([
inv_freq,
torch.zeros(half - n_rot_pairs, device=device),
]) # [half]
t = torch.arange(max_seq, device=device).float()
freqs = torch.outer(t, inv_freq) # [T, half]
freqs = torch.cat([freqs, freqs], dim=-1) # [T, head_dim]
return freqs.cos(), freqs.sin()
def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Rotate ``x`` with precomputed rotary tables (split-halves convention).

    Dim ``i`` of the first half is paired with dim ``i`` of the second half.
    Only the first ``x.shape[2]`` rows of the tables are used, and the tables
    are cast to ``x.dtype`` before the multiply.

    Args:
        x:   [B, H, T, head_dim]
        cos: [>=T, head_dim]
        sin: [>=T, head_dim]

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    seq_len = x.shape[2]
    split = x.shape[-1] // 2

    lower = x[..., :split]
    upper = x[..., split:]
    # The "rotate half" companion of x: (-upper, lower).
    x_rot = torch.cat((-upper, lower), dim=-1)

    # Slice to the current length and add broadcast axes -> [1, 1, T, D].
    cos_t = cos[None, None, :seq_len].to(x.dtype)
    sin_t = sin[None, None, :seq_len].to(x.dtype)
    return x * cos_t + x_rot * sin_t
# ── Attention ─────────────────────────────────────────────────────────────────
class Attention(nn.Module):
    """
    Multi-query self-attention: N_Q_HEADS query heads share N_KV_HEADS key/value head.

    Sliding layers use head_dim=HEAD_DIM_SLIDE and a causal local window of
    SLIDING_WINDOW positions; full layers use head_dim=HEAD_DIM_FULL and plain
    causal attention. Queries and keys are RMS-normalised per head before RoPE,
    and the attention scale is the fixed ATTN_SCALE.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.full_attn = is_full_attention(layer_idx)
        self.head_dim = HEAD_DIM_FULL if self.full_attn else HEAD_DIM_SLIDE

        q_width = N_Q_HEADS * self.head_dim
        kv_width = N_KV_HEADS * self.head_dim
        self.q_proj = nn.Linear(HIDDEN_SIZE, q_width, bias=False)
        self.k_proj = nn.Linear(HIDDEN_SIZE, kv_width, bias=False)
        self.v_proj = nn.Linear(HIDDEN_SIZE, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, HIDDEN_SIZE, bias=False)

        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,    # [B, T, D]
        cos: torch.Tensor,  # [>=T, head_dim] RoPE table matching this layer type
        sin: torch.Tensor,
    ) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        head_dim = self.head_dim

        def split_heads(t: torch.Tensor, n_heads: int) -> torch.Tensor:
            # [B, T, n_heads * hd] -> [B, n_heads, T, hd]
            return t.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x), N_Q_HEADS)
        k = split_heads(self.k_proj(x), N_KV_HEADS)
        v = split_heads(self.v_proj(x), N_KV_HEADS)

        # Per-head QK normalisation, then rotary position embedding.
        q = apply_rope(self.q_norm(q), cos, sin)
        k = apply_rope(self.k_norm(k), cos, sin)

        # Broadcast the single KV head across all query heads (MQA).
        k = k.expand(batch, N_Q_HEADS, seq_len, head_dim)
        v = v.expand(batch, N_Q_HEADS, seq_len, head_dim)

        if self.full_attn:
            # Plain causal attention over the whole prefix.
            attn_mask = None
            causal = True
        else:
            # Sliding-window causal mask, True where attention is allowed.
            # lag[i, j] = i - j is how far key j lies behind query i; a key is
            # visible when it is not in the future (lag >= 0) and falls inside
            # the window (lag < SLIDING_WINDOW).
            pos = torch.arange(seq_len, device=x.device)
            lag = pos[:, None] - pos[None, :]  # [T_q, T_k]
            attn_mask = (lag >= 0) & (lag < SLIDING_WINDOW)
            causal = False

        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            is_causal=causal,
            scale=ATTN_SCALE,
        )

        # [B, H, T, hd] -> [B, T, H * hd]
        out = out.transpose(1, 2).reshape(batch, seq_len, N_Q_HEADS * head_dim)
        return self.o_proj(out)
# ── MLP (GeGLU) ───────────────────────────────────────────────────────────────
class MLP(nn.Module):
    """
    Gated feed-forward block (GeGLU, tanh-approximated GELU on the gate branch).

    The intermediate width is INTERMEDIATE for layers below DOUBLE_WIDE_START
    and INTERMEDIATE_WIDE from that layer onwards.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        if layer_idx >= DOUBLE_WIDE_START:
            width = INTERMEDIATE_WIDE
        else:
            width = INTERMEDIATE
        self.gate_proj = nn.Linear(HIDDEN_SIZE, width, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, width, bias=False)
        self.down_proj = nn.Linear(width, HIDDEN_SIZE, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.gelu(self.gate_proj(x), approximate="tanh")
        gated = activated * self.up_proj(x)
        return self.down_proj(gated)
# ── Decoder layer ─────────────────────────────────────────────────────────────
class Gemma4TextLayer(nn.Module):
    """
    One decoder layer of the Gemma 4 text model.

    Each forward call applies three residual updates in order, every one of
    them multiplied by the learned ``layer_scalar``:
        1. per-layer auxiliary stream injection
        2. self-attention (pre-norm ``input_layernorm``, post-norm
           ``post_attention_layernorm``)
        3. MLP (pre-norm ``pre_feedforward_layernorm``, post-norm
           ``post_feedforward_layernorm``)

    Auxiliary injection, given per_layer_input [B, T, 256]:
        x_normed = input_layernorm(x)
        gate     = gelu_tanh(per_layer_input_gate(x_normed))   # [B, T, 256]
        gated    = gate * per_layer_input                      # [B, T, 256]
        out      = per_layer_projection(gated)                 # [B, T, 1536]
        x        = x + layer_scalar * post_per_layer_input_norm(out)

    Note that ``input_layernorm`` is used twice: once for the gate above and
    once more, on the updated ``x``, as the attention pre-norm.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        # Sub-blocks; the MLP is double-wide from DOUBLE_WIDE_START onwards.
        self.self_attn = Attention(layer_idx)
        self.mlp = MLP(layer_idx)

        # Norms around each residual branch.
        self.input_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_attention_layernorm = RMSNorm(HIDDEN_SIZE)
        self.pre_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_feedforward_layernorm = RMSNorm(HIDDEN_SIZE)
        self.post_per_layer_input_norm = RMSNorm(HIDDEN_SIZE)

        # Auxiliary stream: gate maps hidden -> aux width (weight [256, 1536]),
        # projection maps aux width -> hidden (weight [1536, 256]).
        self.per_layer_input_gate = nn.Linear(HIDDEN_SIZE, PER_LAYER_DIM, bias=False)
        self.per_layer_projection = nn.Linear(PER_LAYER_DIM, HIDDEN_SIZE, bias=False)

        # Learned scalar applied to every residual contribution of this layer.
        self.layer_scalar = nn.Parameter(torch.ones(1))

    def _aux_contribution(
        self, x: torch.Tensor, per_layer_input: torch.Tensor
    ) -> torch.Tensor:
        """Compute the (unscaled) auxiliary-stream residual branch, [B, T, D]."""
        gate_logits = self.per_layer_input_gate(self.input_layernorm(x))
        # The gate uses tanh-approximated GELU, not a sigmoid.
        gate = F.gelu(gate_logits, approximate="tanh")  # [B, T, 256]
        projected = self.per_layer_projection(gate * per_layer_input)  # [B, T, 1536]
        return self.post_per_layer_input_norm(projected)

    def forward(
        self,
        x: torch.Tensor,    # [B, T, D]
        cos: torch.Tensor,  # RoPE tables for this layer type
        sin: torch.Tensor,
        per_layer_input: torch.Tensor,  # [B, T, 256] auxiliary input for this layer
    ) -> torch.Tensor:
        scale = self.layer_scalar.to(x.dtype)

        # 1. Per-layer auxiliary stream injection.
        x = x + scale * self._aux_contribution(x, per_layer_input)

        # 2. Self-attention; input_layernorm is re-applied to the updated x.
        attn_out = self.self_attn(self.input_layernorm(x), cos, sin)
        x = x + scale * self.post_attention_layernorm(attn_out)

        # 3. Feed-forward.
        mlp_out = self.mlp(self.pre_feedforward_layernorm(x))
        x = x + scale * self.post_feedforward_layernorm(mlp_out)
        return x
# ── Full model ─────────────────────────────────────────────────────────────────
class Gemma4ForCausalLM(nn.Module):
    """
    Gemma 4 E2B text model with a causal LM head (no vision/audio).

    The output projection reuses ``embed_tokens.weight`` (tied embeddings) and
    the logits are soft-capped: LOGIT_CAP * tanh(logits / LOGIT_CAP).

    The per-layer auxiliary stream is computed once, before the layer loop:
        - embed_tokens_per_layer lookup:  [B, T, N_LAYERS * PER_LAYER_DIM]
        - per_layer_model_projection:     Linear(HIDDEN_SIZE → N_LAYERS * PER_LAYER_DIM)
        - per_layer_projection_norm:      RMSNorm(PER_LAYER_DIM) on each layer slice
        - combine: (embed_aux + proj_normed) * PER_LAYER_INPUT_SCALE

    RoPE tables are built lazily and cached as plain attributes (they are not
    parameters or buffers, so ``.to()`` does not move them); the cache is
    rebuilt when a longer sequence or a different device is requested.
    """

    def __init__(self):
        super().__init__()
        # Token embeddings (main stream and per-layer auxiliary stream).
        self.embed_tokens = nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
        self.embed_tokens_per_layer = nn.Embedding(VOCAB_SIZE, N_LAYERS * PER_LAYER_DIM)
        # Final norm.
        self.norm = RMSNorm(HIDDEN_SIZE)
        # Decoder stack.
        self.layers = nn.ModuleList([Gemma4TextLayer(i) for i in range(N_LAYERS)])
        # Hidden → all per-layer auxiliary dims at once; weight is
        # [N_LAYERS * PER_LAYER_DIM, HIDDEN_SIZE] = [8960, 1536].
        self.per_layer_model_projection = nn.Linear(
            HIDDEN_SIZE, N_LAYERS * PER_LAYER_DIM, bias=False
        )
        # Norm applied to each PER_LAYER_DIM-wide slice of that projection.
        self.per_layer_projection_norm = RMSNorm(PER_LAYER_DIM)
        # Lazily built RoPE tables; _rope_seq is the number of cached positions.
        self._rope_slide_cos: torch.Tensor | None = None
        self._rope_slide_sin: torch.Tensor | None = None
        self._rope_full_cos: torch.Tensor | None = None
        self._rope_full_sin: torch.Tensor | None = None
        self._rope_seq: int = 0

    @staticmethod
    def is_full_attention(layer_idx: int) -> bool:
        """Return True if ``layer_idx`` is a full-attention layer."""
        return is_full_attention(layer_idx)

    @staticmethod
    def _same_device(cached: torch.device, wanted: torch.device) -> bool:
        """Return True if a table on ``cached`` can be used for ``wanted``.

        A missing index on either side (e.g. plain "cuda") matches any index
        of the same device type.
        """
        if cached.type != wanted.type:
            return False
        return cached.index is None or wanted.index is None or cached.index == wanted.index

    def _ensure_rope(self, seq_len: int, device: torch.device) -> None:
        """Build the RoPE tables on demand, or rebuild them if unusable.

        The cached tables are reused only when they cover ``seq_len`` positions
        AND live on ``device``. Without the device check, tables built during
        a first forward pass on one device would be multiplied against
        activations on another device after the model is moved.
        """
        wanted = torch.device(device)
        cached = self._rope_slide_cos
        if (
            cached is not None
            and self._rope_seq >= seq_len
            and self._same_device(cached.device, wanted)
        ):
            return

        # Never shrink below what was cached before, and keep a 2048 floor so
        # short prompts do not trigger repeated rebuilds.
        max_seq = max(seq_len, self._rope_seq, 2048)

        # Sliding layers: head_dim=HEAD_DIM_SLIDE, every pair rotates.
        self._rope_slide_cos, self._rope_slide_sin = build_rope_freqs(
            HEAD_DIM_SLIDE, max_seq, ROPE_THETA_SLIDE, wanted
        )

        # Full-attention layers: only PARTIAL_ROT_FULL of the head is rotated.
        # 512 * 0.25 = 128 dims = 64 rotation pairs (out of 256 pairs).
        n_rot = int(HEAD_DIM_FULL * PARTIAL_ROT_FULL) // 2
        self._rope_full_cos, self._rope_full_sin = build_rope_freqs(
            HEAD_DIM_FULL, max_seq, ROPE_THETA_FULL, wanted, n_rot_pairs=n_rot
        )
        self._rope_seq = max_seq

    def _compute_per_layer_inputs(
        self, input_ids: torch.Tensor, x_embed: torch.Tensor
    ) -> torch.Tensor:
        """
        Precompute the auxiliary inputs for every layer.

        Args:
            input_ids: [B, T] token ids
            x_embed:   [B, T, HIDDEN_SIZE] scaled token embeddings

        Returns:
            per_layer_inputs: [B, T, N_LAYERS, PER_LAYER_DIM]
        """
        B, T = input_ids.shape

        # 1. Token-based auxiliary embeddings, scaled by sqrt(PER_LAYER_DIM) = 16.
        embed_aux = self.embed_tokens_per_layer(input_ids).to(x_embed.dtype)
        embed_aux = embed_aux * math.sqrt(PER_LAYER_DIM)
        embed_aux = embed_aux.view(B, T, N_LAYERS, PER_LAYER_DIM)

        # 2. Projection of the hidden state, scaled by hidden_size ** -0.5, then
        #    RMS-normalised per layer slice (the norm acts on the last axis).
        proj_all = self.per_layer_model_projection(x_embed)  # [B, T, N_LAYERS * PER_LAYER_DIM]
        proj_all = proj_all * PER_LAYER_PROJ_SCALE
        proj_all = proj_all.view(B, T, N_LAYERS, PER_LAYER_DIM)
        proj_all = self.per_layer_projection_norm(proj_all)

        # 3. Mix both sources with a 1/sqrt(2) factor.
        return (embed_aux + proj_all) * PER_LAYER_INPUT_SCALE

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Run the full forward pass.

        Args:
            input_ids: [B, T] long tensor

        Returns:
            logits: [B, T, VOCAB_SIZE] with soft-capping applied
        """
        _, T = input_ids.shape
        self._ensure_rope(T, input_ids.device)

        # Token embeddings scaled by sqrt(hidden_size).
        x = self.embed_tokens(input_ids) * math.sqrt(HIDDEN_SIZE)  # [B, T, D]

        # Auxiliary inputs are derived from the embeddings before any layer runs.
        per_layer_inputs = self._compute_per_layer_inputs(input_ids, x)

        for i, layer in enumerate(self.layers):
            if is_full_attention(i):
                cos, sin = self._rope_full_cos, self._rope_full_sin
            else:
                cos, sin = self._rope_slide_cos, self._rope_slide_sin
            x = layer(x, cos, sin, per_layer_inputs[:, :, i, :])

        x = self.norm(x)

        # Tied lm_head: project with the input embedding matrix.
        logits = F.linear(x, self.embed_tokens.weight.to(x.dtype))  # [B, T, V]

        # Logit soft-capping.
        return LOGIT_CAP * torch.tanh(logits / LOGIT_CAP)

    @classmethod
    def load_weights(
        cls,
        safetensors_path: str | Path,
        device: str = "cpu",
    ) -> "Gemma4ForCausalLM":
        """
        Build a model and fill it from a safetensors checkpoint.

        Only tensors whose name starts with ``model.language_model.`` are used;
        that prefix is stripped to obtain the local parameter name. Loading is
        non-strict: missing/unexpected keys are reported on stdout rather than
        raising. The returned model is cast to DTYPE.

        Args:
            safetensors_path: path to the checkpoint file
            device: device on which safetensors materialises the tensors
                before they are copied into the model

        Returns:
            The populated model.
        """
        model = cls()
        prefix = "model.language_model."
        with safe_open(str(safetensors_path), framework="pt", device=device) as f:
            state = {
                key[len(prefix):]: f.get_tensor(key)
                for key in f.keys()
                if key.startswith(prefix)
            }
        missing, unexpected = model.load_state_dict(state, strict=False)
        if missing:
            print(f"[load_weights] {len(missing)} missing keys (first 5): {missing[:5]}")
        if unexpected:
            print(f"[load_weights] {len(unexpected)} unexpected keys (first 5): {unexpected[:5]}")
        model = model.to(dtype=DTYPE)
        return model
# ── Convenience loader ─────────────────────────────────────────────────────────
def load_gemma4(
    device: str | None = None,
) -> tuple[Gemma4ForCausalLM, AutoTokenizer]:
    """
    Load the Gemma 4 E2B weights and the matching tokenizer from the local cache.

    Args:
        device: target device; the module-level DEVICE is used when None.

    Returns:
        (model, tokenizer) — the model is in eval mode and lives on the target device.
    """
    target = DEVICE if device is None else device

    print(f"Loading Gemma 4 E2B from {SAFETENSORS_BLOB} ...")
    model = Gemma4ForCausalLM.load_weights(SAFETENSORS_BLOB, device=target)
    model = model.to(target)
    model.eval()

    print(f"Loading tokenizer from {MODEL_DIR} ...")
    # local_files_only: the tokenizer is read from the resolved snapshot, never downloaded.
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR), local_files_only=True)
    return model, tokenizer
# ── PPL evaluation ─────────────────────────────────────────────────────────────
def ppl_on_text(
model: Gemma4ForCausalLM,
tokenizer: AutoTokenizer,
text: str,
device: str | None = None,
max_length: int = 1024,
) -> float:
"""
Compute token-level perplexity on `text`.
Args:
model: Gemma4ForCausalLM in eval mode
tokenizer: matching AutoTokenizer
text: input string
device: device for inference (defaults to DEVICE)
max_length: truncate to this many tokens
Returns:
perplexity (float)
"""
if device is None:
device = DEVICE
enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
input_ids = enc["input_ids"].to(device)
with torch.no_grad():
logits = model(input_ids) # [1, T, V]
# Shift: predict token t+1 from position t
shift_logits = logits[0, :-1, :] # [T-1, V]
shift_labels = input_ids[0, 1:] # [T-1]
log_probs = F.log_softmax(shift_logits.float(), dim=-1)
nll = -log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1).mean()
return nll.exp().item()
# ── main ──────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: load the cached model and report perplexity on a short
    # passage about the transformer architecture.
    _WIKI_TEXT = "".join([
        "The transformer architecture was introduced in the paper ",
        "'Attention Is All You Need' by Vaswani et al. in 2017. ",
        "It relies entirely on self-attention mechanisms, dispensing with ",
        "recurrence and convolutions entirely. Transformers have since become ",
        "the dominant architecture for natural language processing, powering ",
        "models such as BERT, GPT, T5, and the Gemma family. ",
        "The key innovation is the multi-head attention mechanism, which allows ",
        "the model to jointly attend to information from different representation ",
        "subspaces at different positions. This is complemented by position-wise ",
        "feed-forward networks and residual connections with layer normalisation. ",
        "Large language models built on this architecture are trained on massive ",
        "corpora using next-token prediction (autoregressive language modelling) ",
        "or masked language modelling. They exhibit emergent capabilities such as ",
        "few-shot and zero-shot generalisation across a wide variety of tasks.",
    ])

    model, tokenizer = load_gemma4()
    ppl = ppl_on_text(model, tokenizer, _WIKI_TEXT)
    print(f"\nPerplexity on sample text: {ppl:.2f} (target: ~17–18 for bfloat16)")