# Source: jscoder-300m / model/gpt.py
# Uploaded by Shadid with huggingface_hub (commit 988c899, verified), 14.9 kB.
#!/usr/bin/env python3
"""A GPT decoder for JavaScript autocomplete (plan.md).
Architecture (modern defaults, all individually toggleable):
Input ids -> token embedding (RoPE means no learned position table)
-> N x { RMSNorm, RoPE causal self-attention, RMSNorm, SwiGLU } blocks
-> final RMSNorm -> tied linear head -> logits over the vocab
The ``rope`` / ``rmsnorm`` / ``swiglu`` flags on :class:`GPTConfig` select between
the modern pieces (RoPE, RMSNorm, SwiGLU) and the original GPT-2 pieces (learned
``wpe`` position table, ``LayerNorm``, GELU MLP). Defaulting them all to ``True``
gives the configuration used for the ~300M target run; flipping them to ``False``
reproduces the original ~25M GPT-2-style model so old checkpoints still load.
The implementation favours clarity over micro-optimisation but uses
``F.scaled_dot_product_attention`` (flash/efficient kernels when available) and
weight-tied embeddings, so it trains comfortably on an M-series Mac (MPS) or CPU.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import functional as F
@dataclass
class GPTConfig:
    """Bundle of hyper-parameters that describes one model variant.

    Out of the box every architecture switch is on (RoPE, RMSNorm, SwiGLU)
    and ``bias`` is off. Turning the three switches off and ``bias`` on
    selects the learned-position / LayerNorm / GELU variant instead.
    """

    # Size of the model.
    vocab_size: int = 8192
    block_size: int = 512
    n_layer: int = 6
    n_head: int = 8
    n_embd: int = 512
    dropout: float = 0.1
    bias: bool = False  # whether Linear / LayerNorm layers carry a bias term

    # Architecture switches; each one can be flipped on its own.
    rope: bool = True  # rotary position embeddings rather than a learned table
    rope_theta: float = 10000.0  # base used for the rotary frequencies
    rmsnorm: bool = True  # RMSNorm rather than LayerNorm
    swiglu: bool = True  # gated SwiGLU feed-forward rather than the GELU MLP
    grad_checkpoint: bool = False  # wrap each block in torch.utils.checkpoint

    def head_dim(self) -> int:
        """Return the width of one attention head.

        Raises:
            ValueError: when ``n_embd`` is not a whole multiple of ``n_head``.
        """
        per_head, leftover = divmod(self.n_embd, self.n_head)
        if leftover:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )
        return per_head
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation (Zhang & Sennrich, 2019).

    Each vector along the last dimension is divided by its root mean square
    and multiplied by a learned per-channel scale. There is no mean
    subtraction and no bias term.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-channel gain, starting at 1 so the layer is a pure normaliser.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the learned scale."""
        input_dtype = x.dtype
        # The statistics are computed in float32 and cast back before scaling.
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        normalised = x32 * torch.rsqrt(mean_square + self.eps)
        return normalised.to(input_dtype) * self.weight
def build_rope_cache(
    seq_len: int, head_dim: int, theta: float, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cos/sin tables for rotary embeddings.

    Args:
        seq_len: number of positions (rows) to tabulate.
        head_dim: width of one attention head; must be even because the
            frequencies are laid out as two identical halves.
        theta: base of the geometric frequency progression.
        device: device on which the tables are created.
        dtype: dtype of the returned tables (the math itself runs in float32).

    Returns:
        ``(cos, sin)``, two ``[seq_len, head_dim]`` tensors where each
        frequency is duplicated across the two halves so it can be applied
        with ``rotate_half``.

    Raises:
        ValueError: if ``head_dim`` is odd. Without this check the tables
            would come out ``head_dim - 1`` wide and only fail later, as a
            broadcast error when multiplied with q/k.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"RoPE requires an even head_dim, got {head_dim}")
    half = head_dim // 2
    # inv_freq[i] = theta ** (-i / half), one frequency per channel pair.
    exponents = torch.arange(0, half, device=device).float() / half
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)  # [seq_len, half]
    emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dimension, negating the second.

    For ``x = [x1, x2]`` split along the last axis this returns ``[-x2, x1]``.
    """
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
def apply_rope(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with the given cos/sin tables.

    ``q`` and ``k`` are expected as ``[B, n_head, T, head_dim]`` and the
    tables as ``[T, head_dim]``. Each result is cast back to the dtype of the
    tensor it was derived from.
    """
    # Prepend singleton batch and head axes so the tables broadcast.
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]

    rotated = []
    for tensor in (q, k):
        mixed = (tensor * cos_b) + (rotate_half(tensor) * sin_b)
        rotated.append(mixed.type_as(tensor))
    q_rot, k_rot = rotated
    return q_rot, k_rot
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention driven by one fused QKV projection.

    When the config enables RoPE and :meth:`forward` receives cos/sin tables,
    they are applied to the queries and keys before attention.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.rope = config.rope
        # A single matmul produces q, k and v side by side on the last axis.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Attend over ``x`` (``[B, T, C]``) and return a tensor of the same shape.

        ``rope`` is an optional ``(cos, sin)`` pair; only its first ``T`` rows
        are used.
        """
        batch, seq_len, width = x.shape
        per_head = width // self.n_head

        # Split the fused projection and move heads ahead of time:
        # each of q, k, v becomes [B, n_head, T, head_dim].
        q, k, v = (
            part.view(batch, seq_len, self.n_head, per_head).transpose(1, 2)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        if self.rope and rope is not None:
            cos, sin = rope
            q, k = apply_rope(q, k, cos[:seq_len], sin[:seq_len])

        # Attention dropout is only active while training.
        drop_p = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True
        )

        # Fold the heads back into the channel axis before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))
class MLP(nn.Module):
    """GELU feed-forward network applied to each position independently.

    Expands to four times ``n_embd``, applies GELU, projects back down to
    ``n_embd`` and finishes with dropout.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        expanded = 4 * width
        self.c_fc = nn.Linear(width, expanded, bias=config.bias)
        self.c_proj = nn.Linear(expanded, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(c_proj(gelu(c_fc(x))))``."""
        hidden = F.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
class SwiGLU(nn.Module):
    """Gated SwiGLU feed-forward network (Shazeer, 2020).

    The output is ``W_down(silu(W_gate x) * W_up x)`` followed by dropout.
    The hidden width starts at ``int(8 * n_embd / 3)``, which gives the three
    projections roughly the parameter count of a 4x GELU MLP, and is then
    rounded up to the next multiple of 256.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        width = config.n_embd
        target = int(8 * width / 3)
        # Ceiling division by 256, then scale back up.
        hidden = -(-target // 256) * 256
        self.w_gate = nn.Linear(width, hidden, bias=config.bias)
        self.w_up = nn.Linear(width, hidden, bias=config.bias)
        self.w_down = nn.Linear(hidden, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate the up-projection with SiLU, project down, apply dropout."""
        gate = F.silu(self.w_gate(x))
        gated = gate * self.w_up(x)
        return self.dropout(self.w_down(gated))
class Block(nn.Module):
    """One pre-norm decoder layer: attention, then feed-forward.

    Each sub-layer sees a normalised copy of the residual stream and its
    output is added back onto that stream.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        # Choose the norm constructor and the feed-forward class from the config.
        make_norm = RMSNorm if config.rmsnorm else _layernorm(config)
        feed_forward = SwiGLU if config.swiglu else MLP
        self.ln_1 = make_norm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = make_norm(config.n_embd)
        self.mlp = feed_forward(config)

    def forward(
        self,
        x: torch.Tensor,
        rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run both residual sub-layers; ``rope`` is passed on to attention."""
        after_attn = x + self.attn(self.ln_1(x), rope=rope)
        return after_attn + self.mlp(self.ln_2(after_attn))
def _layernorm(config: GPTConfig):
"""Return a LayerNorm factory honouring the config's bias setting."""
def make(dim: int) -> nn.Module:
return nn.LayerNorm(dim, bias=config.bias)
return make
class GPT(nn.Module):
"""Decoder-only transformer language model."""
def __init__(self, config: GPTConfig):
super().__init__()
head_dim = config.head_dim() # validate divisibility early
self.config = config
modules = dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
drop=nn.Dropout(config.dropout),
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f=(
RMSNorm(config.n_embd)
if config.rmsnorm
else nn.LayerNorm(config.n_embd, bias=config.bias)
),
)
# RoPE removes the learned position table entirely.
if not config.rope:
modules["wpe"] = nn.Embedding(config.block_size, config.n_embd)
self.transformer = nn.ModuleDict(modules)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Weight tying: share the input embedding with the output projection.
self.transformer.wte.weight = self.lm_head.weight
if config.rope:
cos, sin = build_rope_cache(
config.block_size, head_dim, config.rope_theta,
device="cpu", dtype=torch.float32,
)
self.register_buffer("rope_cos", cos, persistent=False)
self.register_buffer("rope_sin", sin, persistent=False)
self.apply(self._init_weights)
# Scaled init for residual projections (GPT-2 style). SwiGLU's output
# projection is named ``w_down`` rather than ``c_proj``.
for name, param in self.named_parameters():
if name.endswith("c_proj.weight") or name.endswith("w_down.weight"):
nn.init.normal_(
param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
)
def _init_weights(self, module: nn.Module) -> None:
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
def num_params(self, non_embedding: bool = True) -> int:
"""Total parameter count. Excludes the positional table by default.
The token embedding is tied to the head, so it is counted once. With
RoPE there is no learned position table to subtract.
"""
n = sum(p.numel() for p in self.parameters())
if non_embedding and "wpe" in self.transformer:
n -= self.transformer.wpe.weight.numel()
return n
def forward(
self,
idx: torch.Tensor,
targets: Optional[torch.Tensor] = None,
ignore_index: int = -100,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
B, T = idx.shape
if T > self.config.block_size:
raise ValueError(
f"sequence length {T} exceeds block_size {self.config.block_size}"
)
tok_emb = self.transformer.wte(idx) # [B, T, n_embd]
if self.config.rope:
x = self.transformer.drop(tok_emb)
rope = (
self.rope_cos.to(device=x.device, dtype=x.dtype),
self.rope_sin.to(device=x.device, dtype=x.dtype),
)
else:
pos = torch.arange(T, dtype=torch.long, device=idx.device)
pos_emb = self.transformer.wpe(pos) # [T, n_embd]
x = self.transformer.drop(tok_emb + pos_emb)
rope = None
for block in self.transformer.h:
if self.config.grad_checkpoint and self.training:
x = torch.utils.checkpoint.checkpoint(
block, x, rope, use_reentrant=False
)
else:
x = block(x, rope=rope)
x = self.transformer.ln_f(x)
if targets is not None:
logits = self.lm_head(x)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=ignore_index,
)
return logits, loss
# Inference shortcut: only compute logits for the final position.
logits = self.lm_head(x[:, [-1], :])
return logits, None
def configure_optimizers(
self,
weight_decay: float,
learning_rate: float,
betas: Tuple[float, float],
device_type: str,
) -> torch.optim.Optimizer:
"""AdamW with decay on 2D+ weights only (no decay on biases/norms)."""
decay_params = [p for p in self.parameters() if p.requires_grad and p.dim() >= 2]
nodecay_params = [
p for p in self.parameters() if p.requires_grad and p.dim() < 2
]
optim_groups = [
{"params": decay_params, "weight_decay": weight_decay},
{"params": nodecay_params, "weight_decay": 0.0},
]
fused = device_type == "cuda"
return torch.optim.AdamW(
optim_groups, lr=learning_rate, betas=betas, fused=fused
)
@torch.no_grad()
def generate(
self,
idx: torch.Tensor,
max_new_tokens: int,
temperature: float = 1.0,
top_k: Optional[int] = None,
eot_id: Optional[int] = None,
) -> torch.Tensor:
"""Autoregressively sample ``max_new_tokens`` continuations.
Stops early if every sequence in the batch has emitted ``eot_id``.
"""
self.eval()
finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
for _ in range(max_new_tokens):
idx_cond = (
idx
if idx.size(1) <= self.config.block_size
else idx[:, -self.config.block_size :]
)
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / max(temperature, 1e-6)
if top_k is not None:
k = min(top_k, logits.size(-1))
vals, _ = torch.topk(logits, k)
logits[logits < vals[:, [-1]]] = float("-inf")
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
idx = torch.cat([idx, next_token], dim=1)
if eot_id is not None:
finished = finished | (next_token.squeeze(1) == eot_id)
if bool(finished.all()):
break
return idx