| |
| """A GPT decoder for JavaScript autocomplete (plan.md). |
| |
| Architecture (modern defaults, all individually toggleable): |
| |
| Input ids -> token embedding (RoPE means no learned position table) |
| -> N x { RMSNorm, RoPE causal self-attention, RMSNorm, SwiGLU } blocks |
| -> final RMSNorm -> tied linear head -> logits over the vocab |
| |
| The ``rope`` / ``rmsnorm`` / ``swiglu`` flags on :class:`GPTConfig` select between |
| the modern pieces (RoPE, RMSNorm, SwiGLU) and the original GPT-2 pieces (learned |
| ``wpe`` position table, ``LayerNorm``, GELU MLP). Defaulting them all to ``True`` |
| gives the configuration used for the ~300M target run; flipping them to ``False`` |
| (together with ``bias=True``) reproduces the original ~25M GPT-2-style model so |
| old checkpoints still load. |
| |
| The implementation favours clarity over micro-optimisation but uses |
| ``F.scaled_dot_product_attention`` (flash/efficient kernels when available) and |
| weight-tied embeddings, so it trains comfortably on an M-series Mac (MPS) or CPU. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.utils.checkpoint |
| from torch.nn import functional as F |
|
|
|
|
@dataclass
class GPTConfig:
    """Model hyper-parameters and architecture switches.

    Out of the box the switches select the modern stack (rotary positions,
    RMSNorm, SwiGLU, bias-free layers) used for the ~300M target run. Turning
    ``rope`` / ``rmsnorm`` / ``swiglu`` off and setting ``bias=True`` gives
    back the original GPT-2-style model.
    """

    # Core sizes.
    vocab_size: int = 8192
    block_size: int = 512  # longest sequence the model accepts
    n_layer: int = 6
    n_head: int = 8
    n_embd: int = 512
    dropout: float = 0.1
    bias: bool = False  # bias terms in Linear / LayerNorm layers

    # Architecture switches.
    rope: bool = True  # rotary embeddings instead of a learned position table
    rope_theta: float = 10000.0  # base of the RoPE frequency schedule
    rmsnorm: bool = True  # RMSNorm instead of LayerNorm
    swiglu: bool = True  # SwiGLU feed-forward instead of the GELU MLP
    grad_checkpoint: bool = False  # checkpoint each block while training

    def head_dim(self) -> int:
        """Return the width of one attention head (``n_embd // n_head``).

        Raises:
            ValueError: if ``n_embd`` is not an exact multiple of ``n_head``.
        """
        width, leftover = divmod(self.n_embd, self.n_head)
        if leftover:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
        return width
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation (Zhang & Sennrich, 2019).

    Each vector along the last dimension is divided by its root-mean-square
    and multiplied by a learned per-channel ``weight``. There is no mean
    subtraction and no bias term.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the learned scale."""
        in_dtype = x.dtype
        # The statistics are computed in float32 regardless of the input dtype.
        x32 = x.float()
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.eps)
        # Cast back to the input dtype before scaling by the weight.
        return normed.to(in_dtype) * self.weight
|
|
|
|
def build_rope_cache(
    seq_len: int, head_dim: int, theta: float, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cos/sin tables for rotary embeddings.

    Args:
        seq_len: number of positions (rows) to precompute.
        head_dim: per-head width; must be even because RoPE rotates the
            channels in pairs.
        theta: base of the geometric frequency schedule.
        device: device on which the tables are built.
        dtype: dtype of the returned tables (the angles themselves are
            computed in float32 before the final cast).

    Returns:
        ``(cos, sin)``, two ``[seq_len, head_dim]`` tensors where each
        frequency is duplicated across the two halves so the tables can be
        applied with ``rotate_half``.

    Raises:
        ValueError: if ``head_dim`` is odd. Without this check the tables
            would silently come out ``head_dim - 1`` wide and only fail later
            as a broadcasting error when multiplied with q/k.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"RoPE requires an even head_dim, got {head_dim}")

    half = head_dim // 2
    # One inverse frequency per channel pair: theta ** (-i / half).
    exponents = torch.arange(0, half, device=device).float() / half
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    # Duplicate so that channel i and channel i + half share an angle.
    table = torch.cat((angles, angles), dim=-1)
    return table.cos().to(dtype), table.sin().to(dtype)
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves ``(a, b)`` and return ``[-b, a]``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_rope(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate q and k (shape ``[B, n_head, T, head_dim]``) by the RoPE tables.

    ``cos`` and ``sin`` are 2-D tables; two leading singleton axes are added
    so they broadcast over the batch and head dimensions. Each result is
    cast back to the dtype of its own input.
    """
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]

    def rotate(t: torch.Tensor) -> torch.Tensor:
        # Standard RoPE form: t * cos + rotate_half(t) * sin.
        return ((t * cos_b) + (rotate_half(t) * sin_b)).type_as(t)

    return rotate(q), rotate(k)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention built on one fused QKV projection.

    Rotary positional embeddings are optional: when the config enables RoPE
    and :meth:`forward` receives the cos/sin tables, they are applied to the
    queries and keys before attention.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        use_bias = config.bias

        self.n_head = config.n_head
        self.n_embd = width
        self.dropout = config.dropout
        self.rope = config.rope

        # A single matmul produces q, k and v; c_proj mixes the heads back.
        self.c_attn = nn.Linear(width, 3 * width, bias=use_bias)
        self.c_proj = nn.Linear(width, width, bias=use_bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape ``[B, T, C]`` and return the same shape.

        ``rope`` is an optional ``(cos, sin)`` pair; only the first ``T`` rows
        of each table are used.
        """
        batch, seq_len, width = x.shape
        per_head = width // self.n_head

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, T, C] -> [B, n_head, T, head_dim]
            return t.view(batch, seq_len, self.n_head, per_head).transpose(1, 2)

        query, key, value = map(
            to_heads, self.c_attn(x).split(self.n_embd, dim=2)
        )

        if self.rope and rope is not None:
            cos, sin = rope
            query, key = apply_rope(query, key, cos[:seq_len], sin[:seq_len])

        # Attention dropout is only active in training mode.
        drop_p = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=drop_p,
            is_causal=True,
        )
        # [B, n_head, T, head_dim] -> [B, T, C]
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))
|
|
|
|
class MLP(nn.Module):
    """GPT-2-style feed-forward network: expand 4x, GELU, project back."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        inner = 4 * width
        self.c_fc = nn.Linear(width, inner, bias=config.bias)
        self.c_proj = nn.Linear(inner, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``dropout(c_proj(gelu(c_fc(x))))`` position-wise."""
        return self.dropout(self.c_proj(F.gelu(self.c_fc(x))))
|
|
|
|
class SwiGLU(nn.Module):
    """Gated SwiGLU feed-forward network (Shazeer, 2020).

    Computes ``W_down(silu(W_gate x) * W_up x)``. The hidden width starts at
    ``8/3 * n_embd`` (truncated to an integer) and is rounded up to the next
    multiple of 256.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        target = int(8 * config.n_embd / 3)
        # Round up to a whole number of 256-wide chunks.
        chunks, remainder = divmod(target, 256)
        if remainder:
            chunks += 1
        hidden = 256 * chunks

        self.w_gate = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.w_up = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.w_down = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate the up-projection with ``silu`` and project back down."""
        gate = F.silu(self.w_gate(x))
        gated = gate * self.w_up(x)
        return self.dropout(self.w_down(gated))
|
|
|
|
class Block(nn.Module):
    """One pre-norm decoder block: attention and feed-forward, each residual."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        # Pick the norm factory and feed-forward class once from the config.
        make_norm = RMSNorm if config.rmsnorm else _layernorm(config)
        ffn_cls = SwiGLU if config.swiglu else MLP

        self.ln_1 = make_norm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = make_norm(config.n_embd)
        self.mlp = ffn_cls(config)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run both sub-layers; ``rope`` is forwarded to the attention layer."""
        attn_out = self.attn(self.ln_1(x), rope=rope)
        x = x + attn_out
        ffn_out = self.mlp(self.ln_2(x))
        return x + ffn_out
|
|
|
|
| def _layernorm(config: GPTConfig): |
| """Return a LayerNorm factory honouring the config's bias setting.""" |
|
|
| def make(dim: int) -> nn.Module: |
| return nn.LayerNorm(dim, bias=config.bias) |
|
|
| return make |
|
|
|
|
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Layout: token embedding (plus a learned ``wpe`` position table when RoPE
    is off) -> ``n_layer`` blocks -> final norm -> linear head whose weight is
    shared with the token embedding.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        # Called first so an invalid n_embd / n_head pair fails before any
        # module is allocated.
        head_dim = config.head_dim()
        self.config = config

        modules = dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=(
                RMSNorm(config.n_embd)
                if config.rmsnorm
                else nn.LayerNorm(config.n_embd, bias=config.bias)
            ),
        )
        # A learned position table is only created when RoPE is disabled.
        if not config.rope:
            modules["wpe"] = nn.Embedding(config.block_size, config.n_embd)
        self.transformer = nn.ModuleDict(modules)

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one matrix.
        self.transformer.wte.weight = self.lm_head.weight

        if config.rope:
            cos, sin = build_rope_cache(
                config.block_size, head_dim, config.rope_theta,
                device="cpu", dtype=torch.float32,
            )
            # Non-persistent buffers are left out of the state_dict.
            self.register_buffer("rope_cos", cos, persistent=False)
            self.register_buffer("rope_sin", sin, persistent=False)

        self.apply(self._init_weights)

        # Projections that write into the residual stream are re-initialised
        # with a std scaled down by sqrt(2 * n_layer).
        for name, param in self.named_parameters():
            if name.endswith(("c_proj.weight", "w_down.weight")):
                nn.init.normal_(
                    param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise Linear/Embedding weights from N(0, 0.02^2); zero the biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def num_params(self, non_embedding: bool = True) -> int:
        """Total parameter count. Excludes the positional table by default.

        The token embedding is tied to the head, so it is counted once. With
        RoPE there is no learned position table to subtract.
        """
        n = sum(p.numel() for p in self.parameters())
        if non_embedding and "wpe" in self.transformer:
            n -= self.transformer.wpe.weight.numel()
        return n

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        ignore_index: int = -100,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the model on token ids ``idx`` of shape ``[B, T]``.

        Args:
            idx: integer token ids.
            targets: optional target ids with one entry per position of
                ``idx``; when given, the cross-entropy loss is computed.
            ignore_index: target value excluded from the loss.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every
            position and ``loss`` is a scalar tensor. Without ``targets`` only
            the last position is projected (logits are ``[B, 1, vocab]``) and
            ``loss`` is ``None``.

        Raises:
            ValueError: if ``T`` exceeds ``config.block_size``.
        """
        B, T = idx.shape
        if T > self.config.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.config.block_size}"
            )

        tok_emb = self.transformer.wte(idx)
        if self.config.rope:
            x = self.transformer.drop(tok_emb)
            # The tables are built in float32 on CPU at construction time;
            # match the activations' device and dtype here.
            rope = (
                self.rope_cos.to(device=x.device, dtype=x.dtype),
                self.rope_sin.to(device=x.device, dtype=x.dtype),
            )
        else:
            pos = torch.arange(T, dtype=torch.long, device=idx.device)
            pos_emb = self.transformer.wpe(pos)
            x = self.transformer.drop(tok_emb + pos_emb)
            rope = None

        for block in self.transformer.h:
            if self.config.grad_checkpoint and self.training:
                # Recompute this block's activations in the backward pass
                # instead of storing them.
                x = torch.utils.checkpoint.checkpoint(
                    block, x, rope, use_reentrant=False
                )
            else:
                x = block(x, rope=rope)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                # reshape (not view): targets sliced out of a larger batch,
                # e.g. batch[:, 1:], are non-contiguous and view() would raise.
                targets.reshape(-1),
                ignore_index=ignore_index,
            )
            return logits, loss

        # Inference: only the final position is needed to pick the next token.
        logits = self.lm_head(x[:, [-1], :])
        return logits, None

    def configure_optimizers(
        self,
        weight_decay: float,
        learning_rate: float,
        betas: Tuple[float, float],
        device_type: str,
    ) -> torch.optim.Optimizer:
        """AdamW with decay on 2D+ weights only (no decay on biases/norms).

        The fused AdamW implementation is requested when ``device_type`` is
        ``"cuda"``.
        """
        decay_params = []
        nodecay_params = []
        for p in self.parameters():
            if not p.requires_grad:
                continue
            (decay_params if p.dim() >= 2 else nodecay_params).append(p)
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        fused = device_type == "cuda"
        return torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas, fused=fused
        )

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        eot_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Autoregressively sample up to ``max_new_tokens`` continuations.

        Args:
            idx: prompt token ids of shape ``[B, T]``.
            max_new_tokens: upper bound on the number of sampled tokens.
            temperature: logit divisor; values below ``1e-6`` are clamped.
            top_k: if given, sample only from the ``top_k`` highest logits.
            eot_id: if given, stop once every row has emitted this id.

        Returns:
            ``idx`` with the sampled tokens appended along dim 1. Rows that
            emit ``eot_id`` early keep receiving sampled tokens until every
            row has finished or the budget is used up.

        Raises:
            ValueError: if ``top_k`` is given but smaller than 1.

        Sampling runs in eval mode. The train/eval flag of every submodule is
        restored afterwards, so sampling in the middle of training does not
        leave dropout switched off.
        """
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be a positive integer or None, got {top_k}")

        prior_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
            for _ in range(max_new_tokens):
                # Keep at most the last block_size tokens as context.
                idx_cond = (
                    idx
                    if idx.size(1) <= self.config.block_size
                    else idx[:, -self.config.block_size :]
                )
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / max(temperature, 1e-6)
                if top_k is not None:
                    k = min(top_k, logits.size(-1))
                    vals, _ = torch.topk(logits, k)
                    # Drop everything below the k-th largest logit of each row.
                    logits[logits < vals[:, [-1]]] = float("-inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_token], dim=1)
                if eot_id is not None:
                    finished = finished | (next_token.squeeze(1) == eot_id)
                    if bool(finished.all()):
                        break
        finally:
            for module, was_training in prior_modes:
                module.training = was_training
        return idx
|
|