import math import torch import torch.nn as nn from torch.nn import functional as F from dataclasses import dataclass @dataclass class GPTConfig: vocab_size: int = 50304 context_length: int = 1024 n_layer: int = 12 n_head: int = 12 n_embed: int = 768 dropout: float = 0.0 class MultiHeadAttention(nn.Module): def __init__(self, config): super().__init__() self.n_head = config.n_head self.head_size = config.n_embed // config.n_head self.qkv = nn.Linear(config.n_embed, 3 * config.n_embed, bias=False) self.proj = nn.Linear(config.n_embed, config.n_embed, bias=False) self.attn_dropout = config.dropout self.resid_dropout = nn.Dropout(config.dropout) def forward(self, x): B, T, C = x.shape q, k, v = self.qkv(x).split(C, dim=2) q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2) k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2) v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2) x = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=self.attn_dropout if self.training else 0.0) x = x.transpose(1, 2).contiguous().view(B, T, C) x = self.resid_dropout(self.proj(x)) return x class FeedForward(nn.Module): def __init__(self, config): super().__init__() self.fc = nn.Linear(config.n_embed, 4 * config.n_embed, bias=False) self.gelu = nn.GELU() self.proj = nn.Linear(4 * config.n_embed, config.n_embed, bias=False) self.dropout = nn.Dropout(config.dropout) def forward(self, x): x = self.fc(x) x = self.gelu(x) x = self.proj(x) x = self.dropout(x) return x class Block(nn.Module): def __init__(self, config): super().__init__() self.sa_head = MultiHeadAttention(config) self.ffwd = FeedForward(config) self.ln1 = nn.LayerNorm(config.n_embed, bias=False) self.ln2 = nn.LayerNorm(config.n_embed, bias=False) def forward(self, x): x = x + self.sa_head(self.ln1(x)) x = x + self.ffwd(self.ln2(x)) return x class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config self.token_embedding = nn.Embedding(config.vocab_size, config.n_embed) self.pos_embedding = nn.Embedding(config.context_length, config.n_embed) self.drop = nn.Dropout(config.dropout) self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(config.n_embed, bias=False) self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False) # weight tying self.token_embedding.weight = self.lm_head.weight self.apply(self._init_weights) # scaled init for residual projections (GPT-2 paper) for pn, p in self.named_parameters(): if pn.endswith('proj.weight'): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None, loss_mask=None): B, T = idx.shape tok_emb = self.token_embedding(idx) pos_emb = self.pos_embedding(torch.arange(T, device=idx.device)) x = self.drop(tok_emb + pos_emb) x = self.blocks(x) x = self.ln_f(x) if targets is not None: logits = self.lm_head(x) logits_flat = logits.view(-1, logits.size(-1)) targets_flat = targets.view(-1) if loss_mask is not None: loss_mask_flat = loss_mask.view(-1).float() loss = F.cross_entropy(logits_flat, targets_flat, reduction='none') loss = (loss * loss_mask_flat).sum() / loss_mask_flat.sum() else: loss = F.cross_entropy(logits_flat, targets_flat) else: # inference: only compute logits for last position logits = self.lm_head(x[:, [-1], :]) loss = None return logits, loss def configure_optimizers(self, lr, betas, weight_decay, device_type): """Separate weight decay for 2D params (matmuls/embeds) vs 1D params (biases/norms).""" decay = [p for n, p in self.named_parameters() if p.requires_grad and p.dim() >= 2] no_decay = [p for n, p in self.named_parameters() if p.requires_grad and p.dim() < 2] groups = [ {'params': decay, 'weight_decay': weight_decay}, {'params': no_decay, 'weight_decay': 0.0}, ] fused = device_type == 'cuda' return torch.optim.AdamW(groups, lr=lr, betas=betas, fused=fused) @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40): for _ in range(max_new_tokens): idx_ctx = idx[:, -self.config.context_length:] logits, _ = self(idx_ctx) logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = float('-inf') probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, idx_next], dim=1) return idx