# gpt2-from-scratch/model.py
"""
Transformer Language Model Architecture
Modern architecture (GPT-style) scalable from tiny to large
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import math
class MultiHeadAttention(nn.Module):
"""Multi-head self-attention mechanism with Flash Attention support"""
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.dropout_p = dropout
# Q, K, V projections
self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
        # scaled_dot_product_attention exists in PyTorch 2.0+ (it dispatches to Flash Attention when the hardware supports it)
self.use_flash = hasattr(F, 'scaled_dot_product_attention')
# Fallback dropout for non-flash path
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, embed_dim = x.shape
# Compute Q, K, V
qkv = self.qkv(x) # (batch, seq, 3*embed_dim)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch, heads, seq, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
if self.use_flash:
# Use PyTorch's scaled_dot_product_attention (Flash Attention when available)
            # This path is typically faster and more memory-efficient than the manual fallback below
dropout_p = self.dropout_p if self.training else 0.0
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None, # We use is_causal instead
dropout_p=dropout_p,
is_causal=True # Causal mask for autoregressive generation
)
else:
# Fallback to manual attention for older PyTorch versions
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Apply causal mask (for autoregressive generation)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Attention weights
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
# Apply attention to values
out = torch.matmul(attn, v)
# Reshape: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, embed_dim)
# Output projection
out = self.out_proj(out)
return out
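
# Added sketch (not part of the original demo): MultiHeadAttention maps a
# (batch, seq, embed_dim) tensor to a tensor of the same shape, with causal
# masking handled internally on the flash path. The helper below is an
# illustrative shape check only.
def _demo_attention_shapes():
    """Sanity-check that attention preserves the input shape."""
    attn = MultiHeadAttention(embed_dim=128, num_heads=4, dropout=0.0)
    x = torch.randn(2, 16, 128)  # (batch=2, seq=16, embed_dim=128)
    with torch.no_grad():
        out = attn(x)
    assert out.shape == (2, 16, 128)
    return out
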
class FeedForward(nn.Module):
"""Position-wise feed-forward network"""
def __init__(self, embed_dim, ff_dim, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(embed_dim, ff_dim)
self.fc2 = nn.Linear(ff_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = F.gelu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
class TransformerBlock(nn.Module):
"""Single Transformer block (attention + feed-forward)"""
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
self.feed_forward = FeedForward(embed_dim, ff_dim, dropout)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Self-attention with residual connection
attn_out = self.attention(self.norm1(x), mask)
x = x + self.dropout(attn_out)
# Feed-forward with residual connection
ff_out = self.feed_forward(self.norm2(x))
x = x + self.dropout(ff_out)
return x
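
# Added sketch: the block uses pre-LayerNorm residual wiring (normalize, transform,
# then add back), which tends to keep activations well-scaled as blocks are stacked.
# The helper below is an illustrative shape check only.
def _demo_block_shapes():
    """Sanity-check that a TransformerBlock preserves the input shape."""
    block = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=512, dropout=0.0)
    x = torch.randn(2, 16, 128)
    with torch.no_grad():
        out = block(x)
    assert out.shape == x.shape
    return out
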
class TransformerLanguageModel(nn.Module):
"""
GPT-style Transformer Language Model
Scalable from tiny (CPU) to large (GPU cluster)
"""
def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=4,
ff_dim=None, max_seq_len=256, dropout=0.1):
"""
Initialize Transformer model
Args:
vocab_size: Number of tokens in vocabulary
embed_dim: Embedding dimension (must be divisible by num_heads)
num_heads: Number of attention heads
num_layers: Number of Transformer blocks
ff_dim: Feed-forward dimension (default: 4 * embed_dim)
max_seq_len: Maximum sequence length
dropout: Dropout probability
"""
super().__init__()
if ff_dim is None:
ff_dim = 4 * embed_dim
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.vocab_size = vocab_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_layers = num_layers
self.ff_dim = ff_dim
self.max_seq_len = max_seq_len
self.dropout = dropout
# Token embeddings
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
# Positional embeddings (learned)
self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
for _ in range(num_layers)
])
# Final layer norm
self.ln_f = nn.LayerNorm(embed_dim)
# Output projection
self.head = nn.Linear(embed_dim, vocab_size, bias=False)
# Dropout
self.dropout_layer = nn.Dropout(dropout)
# Initialize weights
self._init_weights()
# Create causal mask
self.register_buffer("causal_mask", self._create_causal_mask(max_seq_len))
def _init_weights(self):
"""Initialize weights"""
for module in self.modules():
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def _create_causal_mask(self, seq_len):
"""Create causal mask for autoregressive generation"""
mask = torch.tril(torch.ones(seq_len, seq_len))
mask = mask.view(1, 1, seq_len, seq_len)
return mask
def forward(self, x):
"""
Forward pass
Args:
x: Input tensor of shape (batch_size, seq_len)
Returns:
logits: Output logits of shape (batch_size, seq_len, vocab_size)
"""
batch_size, seq_len = x.shape
device = x.device
# Token embeddings
token_emb = self.token_embedding(x) # (batch, seq_len, embed_dim)
# Positional embeddings
positions = torch.arange(seq_len, device=device).unsqueeze(0)
pos_emb = self.positional_embedding(positions) # (1, seq_len, embed_dim)
# Combine embeddings
x = self.dropout_layer(token_emb + pos_emb)
        # Slice the causal mask to this sequence length (only used by the non-flash fallback path)
        mask = self.causal_mask[:, :, :seq_len, :seq_len]
# Apply Transformer blocks
for block in self.blocks:
x = block(x, mask)
# Final layer norm
x = self.ln_f(x)
# Output logits
logits = self.head(x) # (batch, seq_len, vocab_size)
return logits
def count_parameters(self):
"""Count trainable parameters"""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def get_config(self):
"""Get model configuration"""
return {
'model_type': 'Transformer',
'architecture': 'GPT-style (decoder-only)',
'vocab_size': self.vocab_size,
'embed_dim': self.embed_dim,
'num_heads': self.num_heads,
'num_layers': self.num_layers,
'ff_dim': self.ff_dim,
'max_seq_len': self.max_seq_len,
'dropout': self.dropout,
'total_parameters': self.count_parameters()
}
def save_config(self, filepath='models/model_config.json'):
"""Save model configuration"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
config = self.get_config()
with open(filepath, 'w') as f:
json.dump(config, f, indent=2)
print(f"Model config saved to: {filepath}")
return filepath
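
# Added sketch (not part of the original file): the model only returns next-token
# logits, so a simple greedy decoding loop like the one below turns it into a text
# generator. `prompt_ids` is assumed to be a (1, seq_len) LongTensor of token ids
# from the same tokenizer the model was trained with.
@torch.no_grad()
def greedy_generate(model, prompt_ids, max_new_tokens=50):
    """Greedily append the most likely next token, one step at a time."""
    model.eval()
    ids = prompt_ids
    for _ in range(max_new_tokens):
        context = ids[:, -model.max_seq_len:]   # crop to the model's context window
        logits = model(context)                  # (1, context_len, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        ids = torch.cat([ids, next_id], dim=1)
    return ids
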
def create_tiny_transformer(vocab_size):
"""Create a tiny Transformer (fastest on CPU)"""
return TransformerLanguageModel(
vocab_size=vocab_size,
embed_dim=128,
num_heads=4,
num_layers=2,
max_seq_len=128,
dropout=0.1
)
def create_small_transformer(vocab_size):
"""Create a small Transformer (recommended for first run)"""
return TransformerLanguageModel(
vocab_size=vocab_size,
embed_dim=256,
num_heads=4,
num_layers=4,
max_seq_len=256,
dropout=0.1
)
def create_medium_transformer(vocab_size):
"""Create a medium Transformer (GPU recommended)"""
return TransformerLanguageModel(
vocab_size=vocab_size,
embed_dim=512,
num_heads=8,
num_layers=6,
max_seq_len=512,
dropout=0.1
)
def create_large_transformer(vocab_size):
"""Create a large Transformer (GPU cluster)"""
return TransformerLanguageModel(
vocab_size=vocab_size,
embed_dim=1024,
num_heads=16,
num_layers=12,
max_seq_len=1024,
dropout=0.1
)
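
# Added sketch of the training step that main() points to as the "next step":
# a causal LM is trained to predict token t+1 from tokens up to t, scored with
# cross-entropy over the vocabulary. `batch` and `optimizer` are hypothetical
# stand-ins for whatever the eventual training loop provides.
def training_step_sketch(model, batch, optimizer):
    """One training step. `batch` is a (batch_size, seq_len) LongTensor of token ids."""
    inputs, targets = batch[:, :-1], batch[:, 1:]   # shift targets by one position
    logits = model(inputs)                          # (batch, seq_len - 1, vocab_size)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),        # flatten to (N, vocab_size)
        targets.reshape(-1),                        # flatten to (N,)
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
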
def main():
"""Test model creation"""
print("\n" + "="*80)
print("TRANSFORMER MODEL ARCHITECTURE")
print("="*80)
# Load tokenizer to get vocab size
tokenizer_path = 'models/tokenizer.json'
if not os.path.exists(tokenizer_path):
print(f"\nError: Tokenizer not found at {tokenizer_path}")
print("Please run tokenizer.py first.")
return
with open(tokenizer_path, 'r') as f:
tokenizer_data = json.load(f)
vocab_size = tokenizer_data['vocab_size']
print(f"\nVocabulary size: {vocab_size}")
print("Architecture: GPT-style Transformer (decoder-only)")
# Create models of different sizes
print("\n" + "-"*80)
print("TINY TRANSFORMER (fastest on CPU)")
print("-"*80)
tiny_model = create_tiny_transformer(vocab_size)
print(f"Parameters: {tiny_model.count_parameters():,}")
print(f"Embed dim: {tiny_model.embed_dim}")
print(f"Attention heads: {tiny_model.num_heads}")
print(f"Layers: {tiny_model.num_layers}")
print(f"Context length: {tiny_model.max_seq_len}")
print("\n" + "-"*80)
print("SMALL TRANSFORMER (recommended for first run)")
print("-"*80)
small_model = create_small_transformer(vocab_size)
print(f"Parameters: {small_model.count_parameters():,}")
print(f"Embed dim: {small_model.embed_dim}")
print(f"Attention heads: {small_model.num_heads}")
print(f"Layers: {small_model.num_layers}")
print(f"Context length: {small_model.max_seq_len}")
print("\n" + "-"*80)
print("MEDIUM TRANSFORMER (GPU recommended)")
print("-"*80)
medium_model = create_medium_transformer(vocab_size)
print(f"Parameters: {medium_model.count_parameters():,}")
print(f"Embed dim: {medium_model.embed_dim}")
print(f"Attention heads: {medium_model.num_heads}")
print(f"Layers: {medium_model.num_layers}")
print(f"Context length: {medium_model.max_seq_len}")
    # Use the small model for the rest of this demo
print("\n" + "="*80)
print("SELECTED MODEL: SMALL TRANSFORMER")
print("="*80)
print("Good balance for CPU training with modern architecture")
model = small_model
# Test forward pass
print("\nTesting forward pass...")
batch_size = 4
seq_len = 32
dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len))
with torch.no_grad():
logits = model(dummy_input)
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {logits.shape}")
print(f"Expected: (batch={batch_size}, seq_len={seq_len}, vocab={vocab_size})")
assert logits.shape == (batch_size, seq_len, vocab_size), "Shape mismatch!"
print("Forward pass test passed!")
# Save configuration
model.save_config()
print("\n" + "="*80)
print("MODEL CREATION COMPLETE")
print("="*80)
print(f"\nModel ready for training!")
print(f"Architecture: {model.get_config()['model_type']}")
print(f"Total parameters: {model.count_parameters():,}")
print(f"Configuration saved to: models/model_config.json")
print(f"\nNext step: Implement the training loop")
print("="*80 + "\n")
if __name__ == "__main__":
main()