Make repo self-contained: rewrite docs, single-model benchmark, remove external references
dbb5d78 verified | """GPT model definition and checkpoint loading for exported smartwatch LM.""" | |
| from __future__ import annotations | |
| import math | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from tokenizers import Tokenizer | |
| import config as cfg | |
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention restricted by a lower-triangular (causal) mask.

    Queries, keys and values are produced by one fused projection (``c_attn``)
    and the concatenated per-head outputs are mixed back by ``c_proj``.

    Args:
        n_head: Number of attention heads; must divide ``n_embd``.
        n_embd: Width of the input/output features.
        block_size: Largest sequence length the precomputed mask supports.
        dropout: Dropout probability on attention weights and on the output.
        bias: Whether the two Linear projections carry a bias term.
    """

    def __init__(self, n_head: int, n_embd: int, block_size: int, dropout: float, bias: bool):
        super().__init__()
        assert n_embd % n_head == 0
        self.head_dim = n_embd // n_head
        self.n_head = n_head
        self.n_embd = n_embd
        # Registration order below is kept deliberately: it fixes the order in
        # which weight initialisation consumes the RNG.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # The buffer is named "bias" for state-dict compatibility, but it is a
        # 0/1 causal mask of shape (1, 1, block_size, block_size), not a bias.
        lower_triangle = torch.ones(block_size, block_size).tril()
        self.register_buffer("bias", lower_triangle.view(1, 1, block_size, block_size))

    def _split_heads(self, tensor: torch.Tensor, batch: int, seq_len: int) -> torch.Tensor:
        """Reshape (batch, seq_len, n_embd) into (batch, n_head, seq_len, head_dim)."""
        return tensor.view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend each position only to itself and earlier positions.

        ``x`` is unpacked as (batch, seq_len, width); the result has the same shape.
        """
        batch, seq_len, width = x.size()
        fused = self.c_attn(x)
        query, key, value = (
            self._split_heads(part, batch, seq_len) for part in fused.split(self.n_embd, dim=2)
        )
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
        # Crop the precomputed mask to the current length; masked cells get -inf
        # so softmax assigns them zero weight.
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        # Undo the head split: (batch, n_head, seq_len, head_dim) -> (batch, seq_len, width).
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))
class MLP(nn.Module):
    """Feed-forward sub-layer: widen 4x, apply GELU, project back, then dropout.

    Args:
        n_embd: Input and output feature width.
        dropout: Dropout probability applied to the output.
        bias: Whether both Linear layers carry a bias term.
    """

    def __init__(self, n_embd: int, dropout: float, bias: bool):
        super().__init__()
        hidden = 4 * n_embd
        self.c_fc = nn.Linear(n_embd, hidden, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expand -> GELU -> project -> dropout pipeline to ``x``."""
        hidden_act = self.c_fc(x)
        hidden_act = self.gelu(hidden_act)
        projected = self.c_proj(hidden_act)
        return self.dropout(projected)
class Block(nn.Module):
    """One pre-LayerNorm transformer layer.

    Attention and the MLP each sit in a residual branch and see a
    LayerNorm-ed copy of the running activations.
    """

    def __init__(self, n_head: int, n_embd: int, block_size: int, dropout: float, bias: bool):
        super().__init__()
        # Order matters: it fixes state-dict keys and weight-init RNG order.
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(
            n_head=n_head, n_embd=n_embd, block_size=block_size, dropout=dropout, bias=bias
        )
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd=n_embd, dropout=dropout, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` updated by the attention residual, then the MLP residual."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))
class GPT(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Args:
        vocab_size: Number of token ids covered by the embedding and output head.
        n_layer: Number of stacked ``Block`` layers.
        n_head: Attention heads per layer.
        n_embd: Embedding / residual-stream width.
        block_size: Maximum context length (rows of the position embedding).
        dropout: Dropout probability used on embeddings and inside each block.
        bias: Whether the Linear layers inside each block carry a bias term.
    """

    def __init__(
        self,
        vocab_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        block_size: int,
        dropout: float,
        bias: bool,
    ):
        super().__init__()
        self.block_size = block_size
        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Embedding(vocab_size, n_embd),
                "wpe": nn.Embedding(block_size, n_embd),
                "drop": nn.Dropout(dropout),
                "h": nn.ModuleList(
                    [Block(n_head, n_embd, block_size, dropout, bias) for _ in range(n_layer)]
                ),
                "ln_f": nn.LayerNorm(n_embd),
            }
        )
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the token embedding and the output projection share one matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Draw Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets=None):
        """Compute next-token logits (and optionally the loss) for ``idx``.

        Args:
            idx: Token ids, unpacked as (batch, t); ``t`` must not exceed
                ``block_size``.
            targets: Optional token ids with ``batch * t`` elements. Any
                layout is accepted, including non-contiguous tensors.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, t, vocab_size).
            ``loss`` is the cross-entropy over all positions, or ``None``
            when ``targets`` is not given.

        Raises:
            ValueError: if the sequence is longer than ``block_size``.
        """
        _, t = idx.size()
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # leaving only an opaque embedding IndexError.
        if t > self.block_size:
            raise ValueError(
                f"Cannot forward sequence of length {t}; block size is only {self.block_size}"
            )
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        x = self.transformer.drop(
            self.transformer.wte(idx) + self.transformer.wpe(pos)
        )
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so sliced or transposed targets are accepted.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Runs under ``torch.no_grad()`` so no autograd graph is built while
        sampling. Dropout follows the module's current train/eval mode; call
        ``eval()`` first for inference.

        Args:
            idx: Prompt token ids of shape (batch, t).
            max_new_tokens: Number of tokens to sample.
            temperature: Divides the logits; values below 1e-8 are clamped to
                1e-8.
            top_k: If given, only the ``top_k`` highest logits (plus ties at
                the cut-off) remain eligible for sampling.

        Returns:
            ``idx`` extended to shape (batch, t + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Only the most recent block_size tokens fit the position embedding.
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / max(temperature, 1e-8)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
def resolve_checkpoint_paths(
    checkpoint_path: Path | str | None = None,
    tokenizer_path: Path | str | None = None,
) -> tuple[Path, Path]:
    """Return validated ``(checkpoint, tokenizer)`` file paths.

    A missing (falsy) argument falls back to ``checkpoint.pt`` /
    ``tokenizer.json`` inside ``cfg.OUTPUT_DIR``. String paths are accepted
    and converted to :class:`~pathlib.Path`.

    Raises:
        FileNotFoundError: if either resolved path is not an existing file.
    """
    # Path(...) normalises str inputs (e.g. from argparse); falsy values keep
    # the original "use the default location" behaviour.
    ckpt = Path(checkpoint_path) if checkpoint_path else Path(cfg.OUTPUT_DIR) / "checkpoint.pt"
    tok = Path(tokenizer_path) if tokenizer_path else Path(cfg.OUTPUT_DIR) / "tokenizer.json"
    if not ckpt.is_file():
        raise FileNotFoundError(
            f"Checkpoint not found at {ckpt}. Ensure checkpoint.pt is in this model folder."
        )
    if not tok.is_file():
        raise FileNotFoundError(
            f"Tokenizer not found at {tok}. Ensure tokenizer.json is in this model folder."
        )
    return ckpt, tok
def load_model(
    checkpoint_path: Path | None = None,
    tokenizer_path: Path | None = None,
    device: str | None = None,
) -> tuple[GPT, Tokenizer, str]:
    """Load the exported GPT checkpoint and its tokenizer, ready for inference.

    Args:
        checkpoint_path: Checkpoint file; defaults via ``resolve_checkpoint_paths``.
        tokenizer_path: Tokenizer JSON; defaults via ``resolve_checkpoint_paths``.
        device: Target device string; defaults to ``"cuda"`` when available,
            else ``"cpu"``.

    Returns:
        ``(model, tokenizer, device)``. The model is on ``device`` and in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint or tokenizer file is missing.
        KeyError: if the checkpoint lacks ``model_config`` / ``model_state_dict``
            or a required config entry.
    """
    ckpt_path, tok_path = resolve_checkpoint_paths(checkpoint_path, tokenizer_path)
    dev = device or ("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = Tokenizer.from_file(str(tok_path))
    # SECURITY: weights_only=False uses the full pickle loader, which can run
    # arbitrary code embedded in the file. Only load checkpoints you trust.
    # Deserialise onto the CPU: the model is built on the CPU and moved to
    # `dev` once, instead of staging every tensor on the accelerator twice.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model_config = checkpoint["model_config"]
    model = GPT(
        vocab_size=model_config["vocab_size"],
        n_layer=model_config["n_layer"],
        n_head=model_config["n_head"],
        n_embd=model_config["n_embd"],
        block_size=model_config["block_size"],
        dropout=model_config["dropout"],
        bias=model_config["bias"],
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(dev)
    # eval() disables dropout for inference.
    model.eval()
    return model, tokenizer, dev