"""
baseline_gpt2.py - Vanilla GPT-2 implementation for fair comparison.
This is a minimal GPT-2 implementation with:
- Absolute positional embeddings
- Standard GELU MLP (not gated)
- Standard multi-head attention
Used as a baseline to compare against RippleGPT.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class GPT2Config:
"""Configuration for vanilla GPT-2 baseline."""
vocab_size: int = 50257
n_layer: int = 6
n_head: int = 6
n_embd: int = 384
block_size: int = 256
dropout: float = 0.1
bias: bool = True
class MultiHeadSelfAttention(nn.Module):
"""Standard multi-head self-attention with absolute positional encoding."""
def __init__(self, config: GPT2Config):
super().__init__()
assert config.n_embd % config.n_head == 0
self.n_head = config.n_head
self.n_embd = config.n_embd
self.head_size = config.n_embd // config.n_head
self.dropout = config.dropout
# Combined QKV projection for efficiency
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
# Causal mask
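        # Stored as (1, 1, block_size, block_size) so it broadcasts over the
        # (B, n_head, T, T) attention scores in forward().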
self.register_buffer(
"mask",
torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.shape
# Project to Q, K, V
qkv = self.c_attn(x)
q, k, v = qkv.split(self.n_embd, dim=-1)
# Reshape for multi-head attention
q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)
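        # Aside (a sketch, assuming PyTorch >= 2.0): the manual masked softmax below
        # is functionally equivalent to the fused kernel
        #   y = F.scaled_dot_product_attention(
        #       q, k, v, is_causal=True,
        #       dropout_p=self.dropout if self.training else 0.0)
        # which is usually faster; the explicit version spells out every step.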
# Compute attention scores
scale = 1.0 / math.sqrt(self.head_size)
attn = (q @ k.transpose(-2, -1)) * scale
# Apply causal mask
attn = attn.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
attn = F.softmax(attn, dim=-1)
attn = self.attn_dropout(attn)
# Apply attention to values
y = attn @ v
# Reshape back
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.resid_dropout(self.c_proj(y))
return y
class MLP(nn.Module):
"""Standard ReLU-based MLP (not gated like SwiGLU)."""
def __init__(self, config: GPT2Config):
super().__init__()
# Standard 4x expansion factor
hidden_dim = 4 * config.n_embd
self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
self.act = nn.GELU() # GPT-2 uses GELU, not ReLU
self.dropout = nn.Dropout(config.dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.c_fc(x)
x = self.act(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
class Block(nn.Module):
"""Transformer block with pre-norm."""
def __init__(self, config: GPT2Config):
super().__init__()
self.ln_1 = nn.LayerNorm(config.n_embd)
self.attn = MultiHeadSelfAttention(config)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.mlp = MLP(config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class VanillaGPT2(nn.Module):
"""
Vanilla GPT-2 baseline for comparison.
Key differences from RippleGPT:
1. Uses absolute positional embeddings (cannot extrapolate)
2. Uses standard MLP (not gated SwiGLU)
3. Uses standard attention (no decay bias)
    For the same layer/head/embedding config this should end up with at least as
    many parameters as RippleGPT: the 4x MLP (two weight matrices) costs roughly
    the same as SwiGLU's ~8/3x hidden width (three matrices), and the learned
    positional table adds block_size * n_embd parameters on top.
"""
def __init__(self, config: GPT2Config):
super().__init__()
self.config = config
# Token and position embeddings
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.block_size, config.n_embd)
self.drop = nn.Dropout(config.dropout)
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(config.n_embd)
# Language modeling head (weight tied with wte)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.lm_head.weight = self.wte.weight # Weight tying
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def get_num_params(self) -> int:
"""Returns number of parameters."""
return sum(p.numel() for p in self.parameters())
def forward(
self,
idx: torch.Tensor,
targets: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
B, T = idx.shape
device = idx.device
# Check sequence length
if T > self.config.block_size:
raise ValueError(
f"Sequence length {T} exceeds block_size {self.config.block_size}. "
"VanillaGPT2 cannot extrapolate beyond training length!"
)
# Token + positional embeddings
pos = torch.arange(0, T, dtype=torch.long, device=device)
tok_emb = self.wte(idx)
pos_emb = self.wpe(pos)
x = self.drop(tok_emb + pos_emb)
# Transformer blocks
x = self.blocks(x)
x = self.ln_f(x)
# Language modeling head
logits = self.lm_head(x)
# Compute loss if targets provided
loss = None
if targets is not None:
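            # Flatten (B, T, V) logits and (B, T) targets into a single
            # token-level cross-entropy over B*T positions.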
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1)
)
return logits, loss
@torch.no_grad()
def generate(
self,
idx: torch.Tensor,
max_new_tokens: int,
temperature: float = 1.0,
top_k: Optional[int] = None
) -> torch.Tensor:
"""Generate tokens autoregressively."""
for _ in range(max_new_tokens):
# Crop to block_size (MUST do for vanilla GPT-2)
idx_cond = idx[:, -self.config.block_size:]
# Forward pass
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
# Optional top-k filtering
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat([idx, idx_next], dim=1)
return idx
def create_baseline_config(ripple_config) -> GPT2Config:
"""Create a VanillaGPT2 config matching a RippleConfig for fair comparison."""
return GPT2Config(
vocab_size=ripple_config.vocab_size,
n_layer=ripple_config.n_layer,
n_head=ripple_config.n_head,
n_embd=ripple_config.n_embd,
block_size=ripple_config.block_size,
dropout=ripple_config.dropout,
bias=ripple_config.bias
)
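# Usage sketch (hedged): the RippleConfig import below is an assumption; adjust the
# module path to wherever RippleConfig is defined in this repo.
#   from model import RippleConfig                     # hypothetical import path
#   baseline = VanillaGPT2(create_baseline_config(RippleConfig()))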
if __name__ == '__main__':
# Test baseline model
print("🔧 Testing VanillaGPT2 Baseline...")
config = GPT2Config(
vocab_size=50257,
n_layer=6,
n_head=6,
n_embd=384,
block_size=256
)
model = VanillaGPT2(config)
print(f"✅ Model created with {model.get_num_params():,} parameters")
# Test forward pass
x = torch.randint(0, 50257, (2, 64))
y = torch.randint(0, 50257, (2, 64))
logits, loss = model(x, y)
print(f"✅ Forward pass: logits shape {logits.shape}, loss {loss.item():.4f}")
# Test generation
prompt = torch.randint(0, 50257, (1, 10))
output = model.generate(prompt, max_new_tokens=20)
print(f"✅ Generation: {prompt.shape}{output.shape}")