"""GPT-2 style transformer (BPE). Pre-norm, GELU, weight-tied head, causal attention. Trained fully from scratch.""" import math, torch, torch.nn as nn from torch.nn import functional as F class CausalSelfAttention(nn.Module): def __init__(self, cfg): super().__init__() assert cfg['n_embd'] % cfg['n_head'] == 0 self.c_attn = nn.Linear(cfg['n_embd'], 3 * cfg['n_embd']) self.c_proj = nn.Linear(cfg['n_embd'], cfg['n_embd']) self.n_head = cfg['n_head']; self.n_embd = cfg['n_embd'] self.drop = nn.Dropout(cfg['dropout']) self.resid_drop = nn.Dropout(cfg['dropout']) self.register_buffer('mask', torch.tril(torch.ones(cfg['block_size'], cfg['block_size'])) .view(1, 1, cfg['block_size'], cfg['block_size'])) def forward(self, x): B, T, C = x.size() q, k, v = self.c_attn(x).split(self.n_embd, dim=2) hs = C // self.n_head q = q.view(B, T, self.n_head, hs).transpose(1, 2) k = k.view(B, T, self.n_head, hs).transpose(1, 2) v = v.view(B, T, self.n_head, hs).transpose(1, 2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs)) att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf')) att = self.drop(F.softmax(att, dim=-1)) y = (att @ v).transpose(1, 2).contiguous().view(B, T, C) return self.resid_drop(self.c_proj(y)) class MLP(nn.Module): def __init__(self, cfg): super().__init__() self.c_fc = nn.Linear(cfg['n_embd'], 4 * cfg['n_embd']) self.c_proj = nn.Linear(4 * cfg['n_embd'], cfg['n_embd']) self.drop = nn.Dropout(cfg['dropout']) def forward(self, x): return self.drop(self.c_proj(F.gelu(self.c_fc(x)))) class Block(nn.Module): def __init__(self, cfg): super().__init__() self.ln1 = nn.LayerNorm(cfg['n_embd']); self.attn = CausalSelfAttention(cfg) self.ln2 = nn.LayerNorm(cfg['n_embd']); self.mlp = MLP(cfg) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class GPT2(nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg; self.block_size = cfg['block_size'] self.wte = nn.Embedding(cfg['vocab_size'], cfg['n_embd']) self.wpe = nn.Embedding(cfg['block_size'], cfg['n_embd']) self.drop = nn.Dropout(cfg['dropout']) self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg['n_layer'])]) self.ln_f = nn.LayerNorm(cfg['n_embd']) self.head = nn.Linear(cfg['n_embd'], cfg['vocab_size'], bias=False) self.wte.weight = self.head.weight # weight tying self.apply(self._init) for pn, p in self.named_parameters(): if pn.endswith('c_proj.weight'): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * cfg['n_layer'])) def _init(self, m): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): B, T = idx.size() pos = torch.arange(0, T, dtype=torch.long, device=idx.device) x = self.drop(self.wte(idx) + self.wpe(pos)) for b in self.blocks: x = b(x) x = self.ln_f(x) logits = self.head(x) loss = None if targets is not None: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100) return logits, loss @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, eot_id=None): for _ in range(max_new_tokens): idx_cond = idx[:, -self.block_size:] logits, _ = self(idx_cond) logits = logits[:, -1, :] / temperature if top_k: v, _ = torch.topk(logits, top_k) logits[logits < v[:, [-1]]] = float('-inf') probs = F.softmax(logits, dim=-1) nxt = torch.multinomial(probs, 1) if eot_id is not None and nxt.item() == eot_id: break idx = torch.cat((idx, nxt), dim=1) return idx