| """ |
| baseline_gpt2.py - Vanilla GPT-2 implementation for fair comparison. |
| |
| This is a minimal GPT-2 implementation with: |
| - Absolute positional embeddings |
| - Standard ReLU MLP (not gated) |
| - Standard multi-head attention |
| |
| Used as a baseline to compare against RippleGPT. |
| """ |

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class GPT2Config:
    """Configuration for the vanilla GPT-2 baseline."""
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 6
    n_embd: int = 384
    block_size: int = 256
    dropout: float = 0.1
    bias: bool = True
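
# Rough parameter count for the defaults above (weights + biases; the lm_head
# is tied to wte below, so the output head adds nothing):
#   wte:        50257 * 384          ≈ 19.30M
#   wpe:        256 * 384            ≈ 0.10M
#   per block:  ~1.77M  * 6 blocks   ≈ 10.65M
#   total                            ≈ 30.04M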


class MultiHeadSelfAttention(nn.Module):
    """Standard multi-head causal self-attention (positions come from the model's absolute embeddings)."""

    def __init__(self, config: GPT2Config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_size = config.n_embd // config.n_head
        self.dropout = config.dropout

        # Fused query/key/value projection, plus the output projection
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask: lower-triangular, so position i attends only to j <= i
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size)
        )
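        # e.g. for block_size=4 the buffer (before the 4-D view) would be:
        #   [[1, 0, 0, 0],
        #    [1, 1, 0, 0],
        #    [1, 1, 1, 0],
        #    [1, 1, 1, 1]]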

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        # Project to queries, keys, values in one matmul, then split
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=-1)

        # Reshape to (B, n_head, T, head_size)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # Scaled dot-product attention scores
        scale = 1.0 / math.sqrt(self.head_size)
        attn = (q @ k.transpose(-2, -1)) * scale

        # Apply the causal mask, normalize, regularize
        attn = attn.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        # Weighted sum of values
        y = attn @ v

        # Re-assemble the heads and project back to the residual stream
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        return y
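
    # Note: the explicit mask + softmax above is equivalent to PyTorch 2.x's
    # fused F.scaled_dot_product_attention. A sketch of the swap (not used
    # here, so the baseline stays explicit and easy to instrument):
    #
    #   y = F.scaled_dot_product_attention(
    #       q, k, v,
    #       dropout_p=self.dropout if self.training else 0.0,
    #       is_causal=True,
    #   )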


class MLP(nn.Module):
    """Standard two-layer GELU MLP with 4x expansion (not gated like SwiGLU)."""

    def __init__(self, config: GPT2Config):
        super().__init__()

        hidden_dim = 4 * config.n_embd

        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
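
# For contrast, the gated SwiGLU MLP this baseline deliberately avoids would
# look roughly like this sketch (hidden width 8/3x, three weight matrices;
# the details of RippleGPT's actual variant are assumed, not confirmed here):
#
#   hidden = int(8 * config.n_embd / 3)
#   w1 = nn.Linear(config.n_embd, hidden, bias=config.bias)  # gate projection
#   w3 = nn.Linear(config.n_embd, hidden, bias=config.bias)  # up projection
#   w2 = nn.Linear(hidden, config.n_embd, bias=config.bias)  # down projection
#   out = w2(F.silu(w1(x)) * w3(x))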


class Block(nn.Module):
    """Transformer block with pre-norm."""

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = MultiHeadSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class VanillaGPT2(nn.Module):
    """
    Vanilla GPT-2 baseline for comparison.

    Key differences from RippleGPT:
    1. Uses absolute positional embeddings (cannot extrapolate)
    2. Uses a standard GELU MLP (not gated SwiGLU)
    3. Uses standard attention (no decay bias)

    This should have MORE parameters than RippleGPT for the same
    layer/head/embedding config, due to the 4x MLP expansion vs SwiGLU's 8/3x.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config

        # Token embeddings plus learned absolute position embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)

        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)

        # Output head, weight-tied to the token embedding matrix
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self) -> int:
        """Returns the total number of parameters (tied weights counted once)."""
        return sum(p.numel() for p in self.parameters())

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        B, T = idx.shape
        device = idx.device

        # Absolute position embeddings are only defined up to block_size
        if T > self.config.block_size:
            raise ValueError(
                f"Sequence length {T} exceeds block_size {self.config.block_size}. "
                "VanillaGPT2 cannot extrapolate beyond training length!"
            )

        # Embed tokens and positions, then sum
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        tok_emb = self.wte(idx)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)

        # Transformer blocks and final layer norm
        x = self.blocks(x)
        x = self.ln_f(x)

        # Project to vocabulary logits
        logits = self.lm_head(x)

        # Optional language-modeling loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None
    ) -> torch.Tensor:
        """Generate tokens autoregressively."""
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens
            idx_cond = idx[:, -self.config.block_size:]

            # Forward pass; keep only the logits at the final position
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Optionally restrict sampling to the top-k most likely tokens
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Sample the next token and append it to the running sequence
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)

        return idx


def create_baseline_config(ripple_config) -> GPT2Config:
    """Create a VanillaGPT2 config matching a RippleConfig for fair comparison."""
    return GPT2Config(
        vocab_size=ripple_config.vocab_size,
        n_layer=ripple_config.n_layer,
        n_head=ripple_config.n_head,
        n_embd=ripple_config.n_embd,
        block_size=ripple_config.block_size,
        dropout=ripple_config.dropout,
        bias=ripple_config.bias
    )
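
# Usage sketch (RippleConfig is hypothetical here; field names assumed to
# match those read above):
#
#   ripple_cfg = RippleConfig(vocab_size=50257, n_layer=6, n_head=6,
#                             n_embd=384, block_size=256, dropout=0.1,
#                             bias=True)
#   baseline = VanillaGPT2(create_baseline_config(ripple_cfg))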


if __name__ == '__main__':
    print("🔧 Testing VanillaGPT2 Baseline...")

    config = GPT2Config(
        vocab_size=50257,
        n_layer=6,
        n_head=6,
        n_embd=384,
        block_size=256
    )

    model = VanillaGPT2(config)
    print(f"✅ Model created with {model.get_num_params():,} parameters")

    # Forward pass with targets
    x = torch.randint(0, 50257, (2, 64))
    y = torch.randint(0, 50257, (2, 64))

    logits, loss = model(x, y)
    print(f"✅ Forward pass: logits shape {logits.shape}, loss {loss.item():.4f}")

    # Autoregressive generation from a random prompt
    prompt = torch.randint(0, 50257, (1, 10))
    output = model.generate(prompt, max_new_tokens=20)
    print(f"✅ Generation: {prompt.shape} → {output.shape}")
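    # Extra sanity checks: these verify properties claimed in the docstrings
    # above (weight tying and the block_size guard), nothing more.
    assert model.lm_head.weight is model.wte.weight
    print("✅ Weight tying: lm_head shares the wte embedding matrix")

    try:
        model(torch.randint(0, 50257, (1, config.block_size + 1)))
    except ValueError:
        print("✅ Correctly rejects sequences longer than block_size")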
|