""" nano GPT: A tiny GPT model built from scratch in pure PyTorch. This is a step-by-step tutorial implementation following Andrej Karpathy's build-nanogpt approach. Every piece is explicit and commented. """ import torch import torch.nn as nn from torch.nn import functional as F from dataclasses import dataclass # --------------------------------------------------------------------------- # Step 1: Configuration # --------------------------------------------------------------------------- # We define all hyperparameters in a single dataclass so they are easy to # tweak without hunting through the code. @dataclass class GPTConfig: block_size: int = 256 # maximum sequence length (context length) vocab_size: int = 65 # number of unique characters in our dataset n_layer: int = 4 # number of transformer blocks n_head: int = 4 # number of attention heads per block n_embd: int = 256 # embedding dimension (hidden size) dropout: float = 0.0 # dropout probability (0 for small overfit-prone runs) # --------------------------------------------------------------------------- # Step 2: Causal Self-Attention # --------------------------------------------------------------------------- # This is the heart of the transformer. For each token we compute three # vectors: Query, Key, and Value. # # Query: "What am I looking for?" # Key: "What do I contain?" # Value: "What information do I have?" # # We then compute attention scores = Q @ K.T, mask future tokens so the # model cannot "cheat" by looking ahead, and take a weighted sum of Values. class CausalSelfAttention(nn.Module): def __init__(self, config: GPTConfig): super().__init__() assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head" # One linear layer projects input into Q, K, V concatenated together. # Output shape: (B, T, 3 * n_embd) self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) # Output projection back to n_embd self.c_proj = nn.Linear(config.n_embd, config.n_embd) self.n_head = config.n_head self.n_embd = config.n_embd self.dropout = config.dropout # Register a causal mask (lower-triangular) so we never attend to future tokens. # We do this once at init instead of recomputing every forward pass. self.register_buffer( "bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size) ) def forward(self, x: torch.Tensor) -> torch.Tensor: B, T, C = x.size() # batch, sequence length, embedding dim # 1. Compute Q, K, V qkv = self.c_attn(x) # (B, T, 3*C) q, k, v = qkv.split(self.n_embd, dim=2) # each (B, T, C) # 2. Reshape into (B, n_head, T, head_size) for multi-head attention head_size = C // self.n_head q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs) # 3. Compute attention scores: (B, nh, T, hs) @ (B, nh, hs, T) -> (B, nh, T, T) # We scale by 1/sqrt(head_size) to keep gradients stable. att = (q @ k.transpose(-2, -1)) * (1.0 / (head_size ** 0.5)) # 4. Apply causal mask: set future positions to -inf so softmax gives 0 att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) # 5. Softmax to get probability distribution over past tokens att = F.softmax(att, dim=-1) # 6. Weighted sum of values: (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) y = att @ v # 7. Concatenate heads back together: (B, nh, T, hs) -> (B, T, nh*hs) = (B, T, C) y = y.transpose(1, 2).contiguous().view(B, T, C) # 8. 


# ---------------------------------------------------------------------------
# Step 3: Feed-Forward Network (MLP)
# ---------------------------------------------------------------------------
# After attention, each token gets its own private "thinking" step through
# a simple two-layer MLP with a GELU non-linearity.

class MLP(nn.Module):

    def __init__(self, config: GPTConfig):
        super().__init__()
        # Expand by 4x (common in transformers) then project back down
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


# ---------------------------------------------------------------------------
# Step 4: Transformer Block
# ---------------------------------------------------------------------------
# A block = Attention -> Add & Norm -> MLP -> Add & Norm
# We use **pre-norm**: normalize BEFORE applying attention/MLP.
# This is what modern models (GPT-2, GPT-3, Llama, etc.) use.

class Block(nn.Module):

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual connections
        x = x + self.attn(self.ln_1(x))  # attention branch
        x = x + self.mlp(self.ln_2(x))   # MLP branch
        return x
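
# ---------------------------------------------------------------------------
# Aside: pre-norm vs. post-norm
# ---------------------------------------------------------------------------
# For illustration only (nothing below uses this): the original "post-norm"
# transformer from Attention Is All You Need would order the block like
# this instead, normalizing AFTER each residual add:
#
#     x = self.ln_1(x + self.attn(x))
#     x = self.ln_2(x + self.mlp(x))
#
# Pre-norm, as in Block above, keeps an unnormalized residual path from
# input to output, which tends to make deep stacks train more stably.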


# ---------------------------------------------------------------------------
# Step 5: Full GPT Model
# ---------------------------------------------------------------------------
# Putting it all together:
#   1. Token embedding table (wte): maps character index -> vector
#   2. Position embedding table (wpe): maps position index -> vector
#   3. Stack of N transformer blocks
#   4. Final layer norm
#   5. Language model head: projects back to vocab_size logits

class GPT(nn.Module):

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            "wpe": nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            "h": nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            "ln_f": nn.LayerNorm(config.n_embd),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: share the token embedding weights with the output projection.
        # This saves parameters and often improves training.
        self.transformer.wte.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        idx:     (B, T) integer token indices
        targets: (B, T) integer targets for next-token prediction (optional)
        returns: logits (B, T, vocab_size), loss (scalar or None)
        """
        B, T = idx.size()
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block_size {self.config.block_size}"

        # 1. Token + position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # (T,)
        tok_emb = self.transformer.wte(idx)  # (B, T, C)
        pos_emb = self.transformer.wpe(pos)  # (T, C)
        x = tok_emb + pos_emb                # (B, T, C)

        # 2. Pass through transformer blocks
        for block in self.transformer.h:
            x = block(x)

        # 3. Final layer norm
        x = self.transformer.ln_f(x)

        # 4. Project to vocabulary logits
        logits = self.lm_head(x)  # (B, T, vocab_size)

        # 5. Compute cross-entropy loss if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.
        idx: (B, T) starting token indices
        """
        for _ in range(max_new_tokens):
            # Crop to block_size so we never exceed context length
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # Forward pass
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # take logits for the last token only: (B, vocab_size)
            # Optional top-k sampling (clamp k to vocab_size so topk never errors)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
                logits[logits < v[:, [-1]]] = float("-inf")
            # Apply temperature and softmax
            probs = F.softmax(logits / temperature, dim=-1)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
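

# ---------------------------------------------------------------------------
# Usage sketch
# ---------------------------------------------------------------------------
# A minimal, illustrative smoke test (added for demonstration; not part of
# the model definition above). It builds the model with the default config,
# runs one forward pass on random token indices, and samples a few tokens.
# A real run would first encode a text dataset into character indices and
# train with an optimizer, so the sampled ids here are meaningless noise.
if __name__ == "__main__":
    config = GPTConfig()
    model = GPT(config)
    print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Dummy batch: 4 sequences of 32 token indices in [0, vocab_size)
    idx = torch.randint(0, config.vocab_size, (4, 32))
    targets = torch.randint(0, config.vocab_size, (4, 32))
    logits, loss = model(idx, targets)
    print("logits shape:", tuple(logits.shape), "loss:", loss.item())

    # Generate 20 new tokens from a single start token
    start = torch.zeros((1, 1), dtype=torch.long)
    out = model.generate(start, max_new_tokens=20, temperature=1.0, top_k=10)
    print("generated token ids:", out[0].tolist())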