"""GPT model definition and checkpoint loading for collab-run-1."""

from __future__ import annotations

import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer

import config as cfg


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where a position only sees itself and earlier ones.

    Args:
        n_head: Number of attention heads; must divide ``n_embd``.
        n_embd: Width of the input and output embeddings.
        block_size: Longest sequence length the causal mask is built for.
        dropout: Dropout probability for the attention weights and the output.
        bias: Whether the linear projections carry a bias term.
    """

    def __init__(
        self,
        n_head: int,
        n_embd: int,
        block_size: int,
        dropout: float,
        bias: bool,
    ):
        super().__init__()
        assert (
            n_embd % n_head == 0
        ), f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        self.n_head = n_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head
        # One fused projection produces queries, keys and values together.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # Lower-triangular causal mask (this is NOT a learnable bias). It is a
        # registered buffer under the name "bias", so it moves with the module
        # and is part of the state dict; renaming it would break checkpoints.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(block_size, block_size)).view(
                1, 1, block_size, block_size
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to ``x`` of shape ``(b, t, n_embd)``.

        Returns a tensor of the same shape as ``x``.
        """
        b, t, c = x.size()

        # Split the fused projection, then reshape each part to
        # (b, n_head, t, head_dim) so every head attends independently.
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        q = q.view(b, t, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (b, n_head, t, t).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        # Future positions get -inf so softmax assigns them zero weight.
        att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        y = att @ v
        # Put the heads back side by side: (b, t, n_embd).
        y = y.transpose(1, 2).contiguous().view(b, t, c)
        return self.resid_dropout(self.c_proj(y))


class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, project back, dropout.

    Args:
        n_embd: Width of the input and output embeddings.
        dropout: Dropout probability applied to the output.
        bias: Whether the linear layers carry a bias term.
    """

    def __init__(self, n_embd: int, dropout: float, bias: bool):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform ``x`` without changing the size of its last dimension."""
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))


class Block(nn.Module):
    """Transformer block: attention and MLP, each behind a LayerNorm and a residual.

    The ``bias`` flag is forwarded to the attention and MLP linear layers only;
    both LayerNorms are built with their default settings.
    """

    def __init__(
        self,
        n_head: int,
        n_embd: int,
        block_size: int,
        dropout: float,
        bias: bool,
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_head, n_embd, block_size, dropout, bias)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd, dropout, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both sub-layers; normalisation happens before each sub-layer."""
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Args:
        vocab_size: Number of distinct token ids.
        n_layer: Number of stacked ``Block`` modules.
        n_head: Attention heads per block.
        n_embd: Embedding width.
        block_size: Maximum sequence length accepted by ``forward``.
        dropout: Dropout probability used throughout the model.
        bias: Whether linear layers inside the blocks carry a bias term.
    """

    def __init__(
        self,
        vocab_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        block_size: int,
        dropout: float,
        bias: bool,
    ):
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, n_embd),
                "wpe": nn.Embedding(block_size, n_embd),
                "drop": nn.Dropout(dropout),
                "h": nn.ModuleList(
                    [
                        Block(n_head, n_embd, block_size, dropout, bias)
                        for _ in range(n_layer)
                    ]
                ),
                "ln_f": nn.LayerNorm(n_embd),
            }
        )
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the token embedding and the output projection share a
        # single parameter tensor.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise linear/embedding weights from N(0, 0.02) and zero the biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Args:
            idx: Token ids of shape ``(b, t)`` with ``t <= block_size``.
            targets: Optional target ids; they are flattened and compared with
                the flattened logits using cross-entropy.

        Returns:
            ``(logits, loss)``. ``logits`` has shape ``(b, t, vocab_size)``;
            ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        b, t = idx.size()
        assert (
            t <= self.block_size
        ), f"sequence length {t} exceeds block_size {self.block_size}"

        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        # Position embeddings are (t, n_embd) and broadcast over the batch.
        x = self.transformer.drop(
            self.transformer.wte(idx) + self.transformer.wpe(pos)
        )
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: Prompt token ids of shape ``(b, t)``.
            max_new_tokens: Number of tokens to sample.
            temperature: Divisor applied to the logits; floored at ``1e-8`` to
                avoid dividing by zero.
            top_k: If given, only the ``top_k`` largest logits stay eligible
                (values tied with the k-th largest are kept as well).

        Returns:
            Tensor of shape ``(b, t + max_new_tokens)``.

        This method does not switch the module between train and eval mode.
        """
        for _ in range(max_new_tokens):
            # The model can only look at the most recent block_size tokens.
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)
            # Only the distribution at the last position is needed.
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # v[:, [-1]] is the smallest kept logit per row, shape (b, 1).
                logits = logits.masked_fill(logits < v[:, [-1]], -float("Inf"))
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


def resolve_checkpoint_paths(
    checkpoint_path: str | Path | None = None,
    tokenizer_path: str | Path | None = None,
) -> tuple[Path, Path]:
    """Return the checkpoint and tokenizer paths after checking both files exist.

    Args:
        checkpoint_path: Checkpoint location; falls back to
            ``cfg.OUTPUT_DIR / "checkpoint.pt"`` when omitted.
        tokenizer_path: Tokenizer location; falls back to
            ``cfg.OUTPUT_DIR / "tokenizer.json"`` when omitted.

    Returns:
        ``(checkpoint, tokenizer)`` as ``Path`` objects.

    Raises:
        FileNotFoundError: If either path is not an existing file.
    """
    # Coerce to Path so plain strings are accepted as well as Path objects.
    ckpt = (
        Path(checkpoint_path)
        if checkpoint_path
        else cfg.OUTPUT_DIR / "checkpoint.pt"
    )
    tok = (
        Path(tokenizer_path) if tokenizer_path else cfg.OUTPUT_DIR / "tokenizer.json"
    )

    if not ckpt.is_file():
        raise FileNotFoundError(
            f"Checkpoint not found at {ckpt}. Train first with train.ipynb."
        )
    if not tok.is_file():
        raise FileNotFoundError(
            f"Tokenizer not found at {tok}. Train first with train.ipynb."
        )
    return ckpt, tok


def load_model(
    checkpoint_path: str | Path | None = None,
    tokenizer_path: str | Path | None = None,
    device: str | None = None,
) -> tuple[GPT, Tokenizer, str]:
    """Load the tokenizer and a trained ``GPT`` ready for inference.

    The checkpoint is read as a mapping with a ``"model_config"`` entry (holding
    ``vocab_size``, ``n_layer``, ``n_head``, ``n_embd``, ``block_size``,
    ``dropout`` and ``bias``) and a ``"model_state_dict"`` entry.

    Args:
        checkpoint_path: Checkpoint file; defaults as in
            ``resolve_checkpoint_paths``.
        tokenizer_path: Tokenizer file; defaults as in
            ``resolve_checkpoint_paths``.
        device: Target device; when omitted, ``"cuda"`` is used if available,
            otherwise ``"cpu"``.

    Returns:
        ``(model, tokenizer, device)`` with the model on ``device`` in eval mode.

    Raises:
        FileNotFoundError: If the checkpoint or tokenizer file is missing.
        KeyError: If an expected checkpoint entry is absent.
    """
    ckpt_path, tok_path = resolve_checkpoint_paths(checkpoint_path, tokenizer_path)
    dev = device or ("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = Tokenizer.from_file(str(tok_path))

    # NOTE(security): weights_only=False makes torch.load unpickle arbitrary
    # objects, which can execute code embedded in the file. Only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=dev, weights_only=False)
    model_config = checkpoint["model_config"]

    model = GPT(
        vocab_size=model_config["vocab_size"],
        n_layer=model_config["n_layer"],
        n_head=model_config["n_head"],
        n_embd=model_config["n_embd"],
        block_size=model_config["block_size"],
        dropout=model_config["dropout"],
        bias=model_config["bias"],
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(dev)
    model.eval()
    return model, tokenizer, dev