""" Transformer Language Model Architecture Modern architecture (GPT-style) scalable from tiny to large """ import torch import torch.nn as nn import torch.nn.functional as F import json import os import math class MultiHeadAttention(nn.Module): """Multi-head self-attention mechanism with Flash Attention support""" def __init__(self, embed_dim, num_heads, dropout=0.1): super().__init__() assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.dropout_p = dropout # Q, K, V projections self.qkv = nn.Linear(embed_dim, 3 * embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) # Check if Flash Attention is available (PyTorch 2.0+) self.use_flash = hasattr(F, 'scaled_dot_product_attention') # Fallback dropout for non-flash path self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): batch_size, seq_len, embed_dim = x.shape # Compute Q, K, V qkv = self.qkv(x) # (batch, seq, 3*embed_dim) qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch, heads, seq, head_dim) q, k, v = qkv[0], qkv[1], qkv[2] if self.use_flash: # Use PyTorch's scaled_dot_product_attention (Flash Attention when available) # This is 1.5-2x faster and more memory efficient dropout_p = self.dropout_p if self.training else 0.0 out = F.scaled_dot_product_attention( q, k, v, attn_mask=None, # We use is_causal instead dropout_p=dropout_p, is_causal=True # Causal mask for autoregressive generation ) else: # Fallback to manual attention for older PyTorch versions scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # Apply causal mask (for autoregressive generation) if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf')) # Attention weights attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) # Apply attention to values out = torch.matmul(attn, v) # Reshape: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim) out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, embed_dim) # Output projection out = self.out_proj(out) return out class FeedForward(nn.Module): """Position-wise feed-forward network""" def __init__(self, embed_dim, ff_dim, dropout=0.1): super().__init__() self.fc1 = nn.Linear(embed_dim, ff_dim) self.fc2 = nn.Linear(ff_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): x = F.gelu(self.fc1(x)) x = self.dropout(x) x = self.fc2(x) return x class TransformerBlock(nn.Module): """Single Transformer block (attention + feed-forward)""" def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1): super().__init__() self.attention = MultiHeadAttention(embed_dim, num_heads, dropout) self.feed_forward = FeedForward(embed_dim, ff_dim, dropout) self.norm1 = nn.LayerNorm(embed_dim) self.norm2 = nn.LayerNorm(embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): # Self-attention with residual connection attn_out = self.attention(self.norm1(x), mask) x = x + self.dropout(attn_out) # Feed-forward with residual connection ff_out = self.feed_forward(self.norm2(x)) x = x + self.dropout(ff_out) return x class TransformerLanguageModel(nn.Module): """ GPT-style Transformer Language Model Scalable from tiny (CPU) to large (GPU cluster) """ def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=4, ff_dim=None, max_seq_len=256, dropout=0.1): """ Initialize Transformer model Args: vocab_size: Number of tokens in vocabulary embed_dim: 
class TransformerLanguageModel(nn.Module):
    """
    GPT-style Transformer Language Model
    Scalable from tiny (CPU) to large (GPU cluster)
    """

    def __init__(self, vocab_size, embed_dim=256, num_heads=4, num_layers=4,
                 ff_dim=None, max_seq_len=256, dropout=0.1):
        """
        Initialize Transformer model

        Args:
            vocab_size: Number of tokens in vocabulary
            embed_dim: Embedding dimension (must be divisible by num_heads)
            num_heads: Number of attention heads
            num_layers: Number of Transformer blocks
            ff_dim: Feed-forward dimension (default: 4 * embed_dim)
            max_seq_len: Maximum sequence length
            dropout: Dropout probability
        """
        super().__init__()

        if ff_dim is None:
            ff_dim = 4 * embed_dim

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.max_seq_len = max_seq_len
        self.dropout = dropout

        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)

        # Positional embeddings (learned)
        self.positional_embedding = nn.Embedding(max_seq_len, embed_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(embed_dim)

        # Output projection
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)

        # Dropout
        self.dropout_layer = nn.Dropout(dropout)

        # Initialize weights
        self._init_weights()

        # Create causal mask
        self.register_buffer("causal_mask", self._create_causal_mask(max_seq_len))

    def _init_weights(self):
        """Initialize weights"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _create_causal_mask(self, seq_len):
        """Create causal mask for autoregressive generation"""
        mask = torch.tril(torch.ones(seq_len, seq_len))
        mask = mask.view(1, 1, seq_len, seq_len)
        return mask

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input tensor of shape (batch_size, seq_len)

        Returns:
            logits: Output logits of shape (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = x.shape
        device = x.device

        # Token embeddings
        token_emb = self.token_embedding(x)  # (batch, seq_len, embed_dim)

        # Positional embeddings
        positions = torch.arange(seq_len, device=device).unsqueeze(0)
        pos_emb = self.positional_embedding(positions)  # (1, seq_len, embed_dim)

        # Combine embeddings
        x = self.dropout_layer(token_emb + pos_emb)

        # Get causal mask for this sequence length
        mask = self.causal_mask[:, :, :seq_len, :seq_len]

        # Apply Transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Final layer norm
        x = self.ln_f(x)

        # Output logits
        logits = self.head(x)  # (batch, seq_len, vocab_size)

        return logits

    def count_parameters(self):
        """Count trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_config(self):
        """Get model configuration"""
        return {
            'model_type': 'Transformer',
            'architecture': 'GPT-style (decoder-only)',
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'num_layers': self.num_layers,
            'ff_dim': self.ff_dim,
            'max_seq_len': self.max_seq_len,
            'dropout': self.dropout,
            'total_parameters': self.count_parameters()
        }

    def save_config(self, filepath='models/model_config.json'):
        """Save model configuration"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        config = self.get_config()
        with open(filepath, 'w') as f:
            json.dump(config, f, indent=2)
        print(f"Model config saved to: {filepath}")
        return filepath
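
# Illustrative sketch (an assumption, not part of the original training pipeline):
# a minimal greedy decoding loop showing how this causal model is used
# autoregressively. The helper name `greedy_generate` is hypothetical; it feeds
# the running sequence back through forward() and appends the argmax token,
# truncating the context to max_seq_len.
@torch.no_grad()
def greedy_generate(model, prompt_ids, max_new_tokens=20):
    """Greedily extend a (1, prompt_len) tensor of token ids."""
    model.eval()
    ids = prompt_ids
    for _ in range(max_new_tokens):
        # Keep only the last max_seq_len tokens as context
        context = ids[:, -model.max_seq_len:]
        logits = model(context)  # (1, seq_len, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        ids = torch.cat([ids, next_id], dim=1)
    return ids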
def create_tiny_transformer(vocab_size):
    """Create a tiny Transformer (fastest on CPU)"""
    return TransformerLanguageModel(
        vocab_size=vocab_size,
        embed_dim=128,
        num_heads=4,
        num_layers=2,
        max_seq_len=128,
        dropout=0.1
    )


def create_small_transformer(vocab_size):
    """Create a small Transformer (recommended for first run)"""
    return TransformerLanguageModel(
        vocab_size=vocab_size,
        embed_dim=256,
        num_heads=4,
        num_layers=4,
        max_seq_len=256,
        dropout=0.1
    )


def create_medium_transformer(vocab_size):
    """Create a medium Transformer (GPU recommended)"""
    return TransformerLanguageModel(
        vocab_size=vocab_size,
        embed_dim=512,
        num_heads=8,
        num_layers=6,
        max_seq_len=512,
        dropout=0.1
    )


def create_large_transformer(vocab_size):
    """Create a large Transformer (GPU cluster)"""
    return TransformerLanguageModel(
        vocab_size=vocab_size,
        embed_dim=1024,
        num_heads=16,
        num_layers=12,
        max_seq_len=1024,
        dropout=0.1
    )


def main():
    """Test model creation"""
    print("\n" + "="*80)
    print("TRANSFORMER MODEL ARCHITECTURE")
    print("="*80)

    # Load tokenizer to get vocab size
    tokenizer_path = 'models/tokenizer.json'
    if not os.path.exists(tokenizer_path):
        print(f"\nError: Tokenizer not found at {tokenizer_path}")
        print("Please run tokenizer.py first.")
        return

    with open(tokenizer_path, 'r') as f:
        tokenizer_data = json.load(f)
    vocab_size = tokenizer_data['vocab_size']

    print(f"\nVocabulary size: {vocab_size}")
    print("Architecture: GPT-style Transformer (decoder-only)")

    # Create models of different sizes
    print("\n" + "-"*80)
    print("TINY TRANSFORMER (fastest on CPU)")
    print("-"*80)
    tiny_model = create_tiny_transformer(vocab_size)
    print(f"Parameters: {tiny_model.count_parameters():,}")
    print(f"Embed dim: {tiny_model.embed_dim}")
    print(f"Attention heads: {tiny_model.num_heads}")
    print(f"Layers: {tiny_model.num_layers}")
    print(f"Context length: {tiny_model.max_seq_len}")

    print("\n" + "-"*80)
    print("SMALL TRANSFORMER (recommended for first run)")
    print("-"*80)
    small_model = create_small_transformer(vocab_size)
    print(f"Parameters: {small_model.count_parameters():,}")
    print(f"Embed dim: {small_model.embed_dim}")
    print(f"Attention heads: {small_model.num_heads}")
    print(f"Layers: {small_model.num_layers}")
    print(f"Context length: {small_model.max_seq_len}")

    print("\n" + "-"*80)
    print("MEDIUM TRANSFORMER (GPU recommended)")
    print("-"*80)
    medium_model = create_medium_transformer(vocab_size)
    print(f"Parameters: {medium_model.count_parameters():,}")
    print(f"Embed dim: {medium_model.embed_dim}")
    print(f"Attention heads: {medium_model.num_heads}")
    print(f"Layers: {medium_model.num_layers}")
    print(f"Context length: {medium_model.max_seq_len}")

    # Use small model for our tiny LM
    print("\n" + "="*80)
    print("SELECTED MODEL: SMALL TRANSFORMER")
    print("="*80)
    print("Good balance for CPU training with modern architecture")
    model = small_model

    # Test forward pass
    print("\nTesting forward pass...")
    batch_size = 4
    seq_len = 32
    dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len))

    with torch.no_grad():
        logits = model(dummy_input)

    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {logits.shape}")
    print(f"Expected: (batch={batch_size}, seq_len={seq_len}, vocab={vocab_size})")
    assert logits.shape == (batch_size, seq_len, vocab_size), "Shape mismatch!"
    print("Forward pass test passed!")

    # Save configuration
    model.save_config()

    print("\n" + "="*80)
    print("MODEL CREATION COMPLETE")
    print("="*80)
    print("\nModel ready for training!")
    print(f"Architecture: {model.get_config()['model_type']}")
    print(f"Total parameters: {model.count_parameters():,}")
    print("Configuration saved to: models/model_config.json")
    print("\nNext step: Implement the training loop")
    print("="*80 + "\n")


if __name__ == "__main__":
    main()