#!/usr/bin/env python3 """A GPT decoder for JavaScript autocomplete (plan.md). Architecture (modern defaults, all individually toggleable): Input ids -> token embedding (RoPE means no learned position table) -> N x { RMSNorm, RoPE causal self-attention, RMSNorm, SwiGLU } blocks -> final RMSNorm -> tied linear head -> logits over the vocab The ``rope`` / ``rmsnorm`` / ``swiglu`` flags on :class:`GPTConfig` select between the modern pieces (RoPE, RMSNorm, SwiGLU) and the original GPT-2 pieces (learned ``wpe`` position table, ``LayerNorm``, GELU MLP). Defaulting them all to ``True`` gives the configuration used for the ~300M target run; flipping them to ``False`` reproduces the original ~25M GPT-2-style model so old checkpoints still load. The implementation favours clarity over micro-optimisation but uses ``F.scaled_dot_product_attention`` (flash/efficient kernels when available) and weight-tied embeddings, so it trains comfortably on an M-series Mac (MPS) or CPU. """ from __future__ import annotations import math from dataclasses import dataclass from typing import Optional, Tuple import torch import torch.nn as nn import torch.utils.checkpoint from torch.nn import functional as F @dataclass class GPTConfig: """Hyper-parameters for the model. The architecture flags default to the modern stack (RoPE + RMSNorm + SwiGLU, no bias) which is what the ~300M target run uses. Set them all to ``False`` (and ``bias=True``) to recover the original GPT-2-style model. """ vocab_size: int = 8192 block_size: int = 512 n_layer: int = 6 n_head: int = 8 n_embd: int = 512 dropout: float = 0.1 bias: bool = False # bias in Linear / norm layers # Architecture upgrades (plan.md #3). Each is an independent toggle. rope: bool = True # rotary positional embeddings instead of learned wpe rope_theta: float = 10000.0 # RoPE base frequency rmsnorm: bool = True # RMSNorm instead of LayerNorm swiglu: bool = True # gated SwiGLU MLP instead of GELU MLP grad_checkpoint: bool = False # torch.utils.checkpoint on each block def head_dim(self) -> int: if self.n_embd % self.n_head != 0: raise ValueError( f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})" ) return self.n_embd // self.n_head class RMSNorm(nn.Module): """Root-mean-square layer normalisation (Zhang & Sennrich, 2019). Cheaper and more stable at depth than LayerNorm: no mean subtraction and no bias term, just a learned per-channel scale. """ def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = x.dtype x = x.float() x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) return (x.to(dtype)) * self.weight def build_rope_cache( seq_len: int, head_dim: int, theta: float, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor]: """Precompute the cos/sin tables for rotary embeddings. Returns two ``[seq_len, head_dim]`` tensors where each frequency is duplicated across the two halves so it can be applied with ``rotate_half``. """ half = head_dim // 2 inv_freq = 1.0 / (theta ** (torch.arange(0, half, device=device).float() / half)) t = torch.arange(seq_len, device=device).float() freqs = torch.outer(t, inv_freq) # [seq_len, half] emb = torch.cat((freqs, freqs), dim=-1) # [seq_len, head_dim] return emb.cos().to(dtype), emb.sin().to(dtype) def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate the two halves of the last dim: [-x2, x1].""" x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rope( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to q and k of shape [B, n_head, T, head_dim].""" # cos/sin are [T, head_dim]; broadcast over batch and heads. cos = cos[None, None, :, :] sin = sin[None, None, :, :] q_rot = (q * cos) + (rotate_half(q) * sin) k_rot = (k * cos) + (rotate_half(k) * sin) return q_rot.type_as(q), k_rot.type_as(k) class CausalSelfAttention(nn.Module): """Multi-head masked self-attention with a fused QKV projection. Supports rotary positional embeddings (RoPE): when enabled the per-position cos/sin tables are passed into :meth:`forward` and applied to q/k. """ def __init__(self, config: GPTConfig): super().__init__() self.n_head = config.n_head self.n_embd = config.n_embd self.dropout = config.dropout self.rope = config.rope self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) def forward( self, x: torch.Tensor, rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: B, T, C = x.shape q, k, v = self.c_attn(x).split(self.n_embd, dim=2) head_dim = C // self.n_head # [B, n_head, T, head_dim] q = q.view(B, T, self.n_head, head_dim).transpose(1, 2) k = k.view(B, T, self.n_head, head_dim).transpose(1, 2) v = v.view(B, T, self.n_head, head_dim).transpose(1, 2) if self.rope and rope is not None: cos, sin = rope q, k = apply_rope(q, k, cos[:T], sin[:T]) y = F.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True, ) y = y.transpose(1, 2).contiguous().view(B, T, C) return self.resid_dropout(self.c_proj(y)) class MLP(nn.Module): """Position-wise feed-forward network (4x expansion, GELU).""" def __init__(self, config: GPTConfig): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.c_fc(x) x = F.gelu(x) x = self.c_proj(x) return self.dropout(x) class SwiGLU(nn.Module): """Gated SwiGLU feed-forward network (Shazeer, 2020). Computes ``W_down(silu(W_gate x) * W_up x)``. The hidden dimension is 8/3 * n_embd (so the gate + up projections have ~the same parameter budget as a 4x GELU MLP) rounded to a multiple of 256 for hardware friendliness. """ def __init__(self, config: GPTConfig): super().__init__() hidden = int(8 * config.n_embd / 3) hidden = 256 * ((hidden + 255) // 256) # round up to multiple of 256 self.w_gate = nn.Linear(config.n_embd, hidden, bias=config.bias) self.w_up = nn.Linear(config.n_embd, hidden, bias=config.bias) self.w_down = nn.Linear(hidden, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.silu(self.w_gate(x)) * self.w_up(x) return self.dropout(self.w_down(x)) class Block(nn.Module): """A pre-norm transformer decoder block.""" def __init__(self, config: GPTConfig): super().__init__() norm = RMSNorm if config.rmsnorm else _layernorm(config) self.ln_1 = norm(config.n_embd) self.attn = CausalSelfAttention(config) self.ln_2 = norm(config.n_embd) self.mlp = SwiGLU(config) if config.swiglu else MLP(config) def forward( self, x: torch.Tensor, rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> torch.Tensor: x = x + self.attn(self.ln_1(x), rope=rope) x = x + self.mlp(self.ln_2(x)) return x def _layernorm(config: GPTConfig): """Return a LayerNorm factory honouring the config's bias setting.""" def make(dim: int) -> nn.Module: return nn.LayerNorm(dim, bias=config.bias) return make class GPT(nn.Module): """Decoder-only transformer language model.""" def __init__(self, config: GPTConfig): super().__init__() head_dim = config.head_dim() # validate divisibility early self.config = config modules = dict( wte=nn.Embedding(config.vocab_size, config.n_embd), drop=nn.Dropout(config.dropout), h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), ln_f=( RMSNorm(config.n_embd) if config.rmsnorm else nn.LayerNorm(config.n_embd, bias=config.bias) ), ) # RoPE removes the learned position table entirely. if not config.rope: modules["wpe"] = nn.Embedding(config.block_size, config.n_embd) self.transformer = nn.ModuleDict(modules) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # Weight tying: share the input embedding with the output projection. self.transformer.wte.weight = self.lm_head.weight if config.rope: cos, sin = build_rope_cache( config.block_size, head_dim, config.rope_theta, device="cpu", dtype=torch.float32, ) self.register_buffer("rope_cos", cos, persistent=False) self.register_buffer("rope_sin", sin, persistent=False) self.apply(self._init_weights) # Scaled init for residual projections (GPT-2 style). SwiGLU's output # projection is named ``w_down`` rather than ``c_proj``. for name, param in self.named_parameters(): if name.endswith("c_proj.weight") or name.endswith("w_down.weight"): nn.init.normal_( param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) ) def _init_weights(self, module: nn.Module) -> None: if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def num_params(self, non_embedding: bool = True) -> int: """Total parameter count. Excludes the positional table by default. The token embedding is tied to the head, so it is counted once. With RoPE there is no learned position table to subtract. """ n = sum(p.numel() for p in self.parameters()) if non_embedding and "wpe" in self.transformer: n -= self.transformer.wpe.weight.numel() return n def forward( self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None, ignore_index: int = -100, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: B, T = idx.shape if T > self.config.block_size: raise ValueError( f"sequence length {T} exceeds block_size {self.config.block_size}" ) tok_emb = self.transformer.wte(idx) # [B, T, n_embd] if self.config.rope: x = self.transformer.drop(tok_emb) rope = ( self.rope_cos.to(device=x.device, dtype=x.dtype), self.rope_sin.to(device=x.device, dtype=x.dtype), ) else: pos = torch.arange(T, dtype=torch.long, device=idx.device) pos_emb = self.transformer.wpe(pos) # [T, n_embd] x = self.transformer.drop(tok_emb + pos_emb) rope = None for block in self.transformer.h: if self.config.grad_checkpoint and self.training: x = torch.utils.checkpoint.checkpoint( block, x, rope, use_reentrant=False ) else: x = block(x, rope=rope) x = self.transformer.ln_f(x) if targets is not None: logits = self.lm_head(x) loss = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=ignore_index, ) return logits, loss # Inference shortcut: only compute logits for the final position. logits = self.lm_head(x[:, [-1], :]) return logits, None def configure_optimizers( self, weight_decay: float, learning_rate: float, betas: Tuple[float, float], device_type: str, ) -> torch.optim.Optimizer: """AdamW with decay on 2D+ weights only (no decay on biases/norms).""" decay_params = [p for p in self.parameters() if p.requires_grad and p.dim() >= 2] nodecay_params = [ p for p in self.parameters() if p.requires_grad and p.dim() < 2 ] optim_groups = [ {"params": decay_params, "weight_decay": weight_decay}, {"params": nodecay_params, "weight_decay": 0.0}, ] fused = device_type == "cuda" return torch.optim.AdamW( optim_groups, lr=learning_rate, betas=betas, fused=fused ) @torch.no_grad() def generate( self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None, eot_id: Optional[int] = None, ) -> torch.Tensor: """Autoregressively sample ``max_new_tokens`` continuations. Stops early if every sequence in the batch has emitted ``eot_id``. """ self.eval() finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device) for _ in range(max_new_tokens): idx_cond = ( idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] ) logits, _ = self(idx_cond) logits = logits[:, -1, :] / max(temperature, 1e-6) if top_k is not None: k = min(top_k, logits.size(-1)) vals, _ = torch.topk(logits, k) logits[logits < vals[:, [-1]]] = float("-inf") probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, next_token], dim=1) if eot_id is not None: finished = finished | (next_token.squeeze(1) == eot_id) if bool(finished.all()): break return idx