""" Llama-Style Transformer Model ============================= Modern transformer architecture with all Tier 1 and Tier 2 optimizations: Architecture (Tier 1): - RMSNorm (faster than LayerNorm, no mean calculation) - RoPE (Rotary Position Embedding, better length generalization) - SwiGLU activation (gated FFN, consistently outperforms GELU) - Pre-norm (apply norm before attention/FFN, more stable training) Optimizations (Tier 2): - GQA (Grouped Query Attention, fewer KV heads = faster + less memory) - Weight tying (share embedding and output projection) - Flash Attention via F.scaled_dot_product_attention - Gradient checkpointing support (trade compute for memory) Compatible with: - liger-kernel (fused RMSNorm, SwiGLU, RoPE, cross-entropy) - bf16/fp16 mixed precision training - torch.compile for additional speedups Model Sizes: - tiny: ~15M params (for testing) - small: ~125M params - medium: ~350M params - large: ~760M params - 1B: ~1.1B params (Llama 3.2 1B style) """ import math from dataclasses import dataclass from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # ============================================================================ # Model Configuration # ============================================================================ @dataclass class ModelConfig: """Configuration for Llama-style transformer model.""" # Model architecture vocab_size: int = 32000 d_model: int = 2048 # Hidden dimension n_layers: int = 16 # Number of transformer blocks n_heads: int = 32 # Number of attention heads n_kv_heads: int = 8 # Number of KV heads (for GQA) d_ff: int = None # FFN intermediate dim (default: 8/3 * d_model) # Sequence max_seq_len: int = 2048 # Maximum sequence length # RoPE rope_theta: float = 500000.0 # RoPE base frequency # Regularization dropout: float = 0.0 # Dropout (0 for pretraining) # Options tie_weights: bool = True # Tie embedding and output weights use_flash_attn: bool = True # Use Flash 
Attention (SDPA) def __post_init__(self): # SwiGLU uses 8/3 * d_model for FFN, rounded to multiple of 256 if self.d_ff is None: self.d_ff = int(8 / 3 * self.d_model) self.d_ff = ((self.d_ff + 255) // 256) * 256 # Validate GQA configuration assert self.n_heads % self.n_kv_heads == 0, \ f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})" self.n_kv_groups = self.n_heads // self.n_kv_heads self.head_dim = self.d_model // self.n_heads # Predefined model configurations MODEL_CONFIGS = { "tiny": ModelConfig( d_model=256, n_layers=6, n_heads=8, n_kv_heads=4, max_seq_len=1024, ), "small": ModelConfig( d_model=768, n_layers=12, n_heads=12, n_kv_heads=4, max_seq_len=2048, ), "medium": ModelConfig( d_model=1024, n_layers=16, n_heads=16, n_kv_heads=4, max_seq_len=2048, ), "large": ModelConfig( d_model=1536, n_layers=20, n_heads=24, n_kv_heads=8, max_seq_len=2048, ), "1B": ModelConfig( d_model=2048, n_layers=16, n_heads=32, n_kv_heads=8, d_ff=8192, # Llama 3.2 1B uses 4x hidden, not 8/3x max_seq_len=2048, ), } def get_model_config(size: str, **overrides) -> ModelConfig: """Get a predefined model configuration with optional overrides.""" if size not in MODEL_CONFIGS: raise ValueError(f"Unknown model size: {size}. Choose from: {list(MODEL_CONFIGS.keys())}") config = MODEL_CONFIGS[size] # Apply overrides for key, value in overrides.items(): if hasattr(config, key): setattr(config, key, value) else: raise ValueError(f"Unknown config parameter: {key}") # Recompute derived values config.__post_init__() return config # ============================================================================ # RMSNorm (Tier 1) # ============================================================================ class RMSNorm(nn.Module): """ Root Mean Square Layer Normalization. Simpler and faster than LayerNorm - skips the mean calculation. Used in Llama, Mistral, and other modern LLMs. 
# ============================================================================
# RMSNorm (Tier 1)
# ============================================================================

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Simpler and faster than LayerNorm - skips the mean calculation.
    Used in Llama, Mistral, and other modern LLMs.

    Can be replaced with liger_kernel.transformers.LigerRMSNorm for
    additional speedup via kernel fusion.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root-mean-square over the last dimension.
        rms_inv = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms_inv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for numerical stability, cast back, then scale.
        return self._norm(x.float()).type_as(x) * self.weight


# ============================================================================
# Rotary Position Embedding (RoPE) (Tier 1)
# ============================================================================

def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the cos and sin frequencies for RoPE.

    Args:
        dim: Head dimension (d_model // n_heads)
        max_seq_len: Maximum sequence length
        theta: Base frequency (Llama 3 uses 500000)
        device: Target device

    Returns:
        cos, sin tensors of shape (max_seq_len, dim)
    """
    # Inverse frequencies: one per even channel index.
    exponents = torch.arange(0, dim, 2, device=device).float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Angle table via broadcasting: (seq_len, 1) * (1, dim/2) -> (seq_len, dim/2)
    positions = torch.arange(max_seq_len, device=device)
    angles = positions[:, None] * inv_freq[None, :]

    # Duplicate each frequency for its (even, odd) channel pair -> (seq_len, dim)
    cos = angles.cos().repeat_interleave(2, dim=-1)
    sin = angles.sin().repeat_interleave(2, dim=-1)
    return cos, sin
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary position embedding to input tensor.

    Args:
        x: Input tensor of shape (batch, n_heads, seq_len, head_dim)
        cos: Cosine frequencies of shape (seq_len, head_dim)
        sin: Sine frequencies of shape (seq_len, head_dim)

    Returns:
        Tensor with rotary embedding applied
    """
    seq_len = x.size(2)

    # Trim tables to the current length, then add broadcast dims for
    # (batch, heads): (seq_len, head_dim) -> (1, 1, seq_len, head_dim).
    cos = cos[:seq_len][None, None, :, :]
    sin = sin[:seq_len][None, None, :, :]

    # Rotated companion of each (even, odd) channel pair:
    # [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]
    rotated = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).reshape(x.shape)

    # Standard 2D rotation applied per pair.
    return x * cos + rotated * sin


# ============================================================================
# Grouped Query Attention (GQA) with Flash Attention (Tier 1 + Tier 2)
# ============================================================================

class Attention(nn.Module):
    """
    Multi-head attention with Grouped Query Attention (GQA) and Flash Attention.

    GQA uses fewer key-value heads than query heads, reducing memory and
    compute while maintaining quality. For example, with 32 query heads and
    8 KV heads, each KV head is shared by 4 query heads.

    Flash Attention is used via PyTorch's scaled_dot_product_attention,
    which provides O(N) memory complexity instead of O(N^2).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_kv_groups = config.n_kv_groups
        self.head_dim = config.head_dim

        # Queries keep the full head count.
        self.wq = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)
        # Keys/values use the reduced KV head count (GQA).
        self.wk = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        self.wv = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        # Projection back to the model dimension.
        self.wo = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.use_flash_attn = config.use_flash_attn

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bsz, seq_len, _ = x.shape

        # Project and split into heads: (B, T, H*D) -> (B, H, T, D).
        q = self.wq(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on queries and keys only.
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # GQA: replicate each KV head across its query-head group so shapes
        # match: (B, n_kv_heads, T, D) -> (B, n_heads, T, D).
        if self.n_kv_groups > 1:
            k = k.repeat_interleave(self.n_kv_groups, dim=1)
            v = v.repeat_interleave(self.n_kv_groups, dim=1)

        if self.use_flash_attn:
            # SDPA dispatches to Flash Attention when available. With no
            # explicit mask we ask SDPA for the causal mask directly; an
            # explicit mask is assumed to already encode causality.
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=mask is None,
            )
        else:
            # Reference implementation (debugging / SDPA unavailable).
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            if mask is not None:
                # Caller-provided additive mask.
                scores = scores + mask
            else:
                # Additive causal mask: -inf strictly above the diagonal.
                causal = torch.triu(
                    torch.full((seq_len, seq_len), float('-inf'), device=x.device),
                    diagonal=1,
                )
                scores = scores + causal
            probs = self.dropout(F.softmax(scores, dim=-1))
            out = torch.matmul(probs, v)

        # Merge heads: (B, H, T, D) -> (B, T, H*D), then project out.
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.wo(out)
""" def __init__(self, config: ModelConfig): super().__init__() hidden_dim = config.d_ff # Gate and up projections (can be fused) self.w_gate = nn.Linear(config.d_model, hidden_dim, bias=False) self.w_up = nn.Linear(config.d_model, hidden_dim, bias=False) # Down projection self.w_down = nn.Linear(hidden_dim, config.d_model, bias=False) self.dropout = nn.Dropout(config.dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: # SwiGLU: SiLU(gate) * up, then project down return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))) # ============================================================================ # Transformer Block (Pre-norm) # ============================================================================ class TransformerBlock(nn.Module): """ Single transformer block with pre-norm architecture. Pre-norm applies normalization BEFORE attention/FFN (not after), which provides more stable gradients at scale. Structure: x = x + Attention(RMSNorm(x)) x = x + FFN(RMSNorm(x)) """ def __init__(self, config: ModelConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx # Pre-norm layers self.attn_norm = RMSNorm(config.d_model) self.ffn_norm = RMSNorm(config.d_model) # Attention and FFN self.attn = Attention(config) self.ffn = FeedForward(config) def forward( self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: # Pre-norm attention with residual x = x + self.attn(self.attn_norm(x), cos, sin, mask) # Pre-norm FFN with residual x = x + self.ffn(self.ffn_norm(x)) return x # ============================================================================ # Complete Llama Model # ============================================================================ class LlamaModel(nn.Module): """ Complete Llama-style transformer model for language modeling. 
# ============================================================================
# Complete Llama Model
# ============================================================================

class LlamaModel(nn.Module):
    """
    Complete Llama-style transformer model for language modeling.

    Features:
    - RMSNorm, RoPE, SwiGLU, GQA (Tier 1)
    - Weight tying, Flash Attention (Tier 2)
    - Gradient checkpointing support
    - Compatible with liger-kernel fused ops

    Usage:
        config = get_model_config("1B", vocab_size=32000)
        model = LlamaModel(config)

        # Enable gradient checkpointing for memory savings
        model.gradient_checkpointing_enable()

        # Forward pass
        logits = model(input_ids)
        loss = model(input_ids, targets=targets)
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(config, layer_idx=i) for i in range(config.n_layers)
        ])

        # Final normalization
        self.norm = RMSNorm(config.d_model)

        # Output projection (language model head)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: share embedding and output weights
        if config.tie_weights:
            self.lm_head.weight = self.tok_emb.weight

        # RoPE tables are registered as non-persistent placeholder buffers
        # and filled lazily on the first forward pass so they land on the
        # right device. The _rope_initialized flag replaces the previous
        # per-step `rope_cos.sum() == 0` probe, which forced a full
        # reduction (and a host sync on GPU) on every forward call.
        self.register_buffer(
            "rope_cos",
            torch.zeros(config.max_seq_len, config.head_dim),
            persistent=False,
        )
        self.register_buffer(
            "rope_sin",
            torch.zeros(config.max_seq_len, config.head_dim),
            persistent=False,
        )
        self._rope_initialized = False

        # Gradient checkpointing flag
        self._gradient_checkpointing = False

        # Initialize weights (runs after weight tying; the shared tensor is
        # simply initialized twice, which is harmless).
        self.apply(self._init_weights)
        # Apply special scaled initialization for residual output projections
        self._init_output_weights()

    def _init_weights(self, module: nn.Module):
        """Llama-style init: N(0, 0.02) for linear/embedding weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _init_output_weights(self):
        """Scale residual-branch output projections by 1/sqrt(2*n_layers).

        Keeps the residual stream's variance stable as depth grows (GPT-2 /
        Llama convention).
        """
        scale = (2 * self.config.n_layers) ** -0.5
        for layer in self.layers:
            torch.nn.init.normal_(layer.attn.wo.weight, mean=0.0, std=0.02 * scale)
            torch.nn.init.normal_(layer.ffn.w_down.weight, mean=0.0, std=0.02 * scale)

    def _init_rope(self, device: torch.device):
        """(Re)compute the RoPE cos/sin tables on the given device."""
        cos, sin = precompute_rope_freqs(
            dim=self.config.head_dim,
            max_seq_len=self.config.max_seq_len,
            theta=self.config.rope_theta,
            device=device,
        )
        # Assigning to a registered buffer name replaces the buffer in place.
        self.rope_cos = cos
        self.rope_sin = sin
        self._rope_initialized = True

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory-efficient training."""
        self._gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self._gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Token IDs of shape (batch_size, seq_len)
            targets: Optional target IDs for loss computation
            mask: Optional attention mask (additive; must encode causality
                itself, since passing a mask disables SDPA's causal flag)

        Returns:
            If targets provided: scalar loss
            Otherwise: logits of shape (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Lazily (re)build RoPE tables on first use or after a device move.
        if not self._rope_initialized or self.rope_cos.device != device:
            self._init_rope(device)

        # Token embeddings
        x = self.tok_emb(input_ids)

        # Get RoPE frequencies for this sequence length
        cos = self.rope_cos[:seq_len]
        sin = self.rope_sin[:seq_len]

        # Transformer blocks
        for layer in self.layers:
            if self._gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, cos, sin, mask,
                    use_reentrant=False
                )
            else:
                x = layer(x, cos, sin, mask)

        # Final norm
        x = self.norm(x)

        # Compute logits
        logits = self.lm_head(x)

        # Compute loss if targets provided
        if targets is not None:
            # NOTE: No shift here — the DataLoader already provides
            # pre-shifted targets (x = tokens[:-1], y = tokens[1:]),
            # so logits[k] should predict targets[k] directly.
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                targets.view(-1),
                ignore_index=-100,  # Ignore padding
            )
            return loss

        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Generate tokens autoregressively.

        NOTE: switches the model to eval mode and leaves it there; callers
        that resume training must call .train() themselves.

        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (1.0 = neutral)
            top_k: If set, only sample from top k tokens
            top_p: If set, use nucleus sampling with this probability mass

        Returns:
            Generated token IDs (batch_size, seq_len + max_new_tokens)
        """
        self.eval()

        for _ in range(max_new_tokens):
            # Crop the context window to max_seq_len if needed
            idx_cond = input_ids if input_ids.size(1) <= self.config.max_seq_len else \
                input_ids[:, -self.config.max_seq_len:]

            # Forward pass; keep only the last position's logits
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Apply top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Remove tokens with cumulative probability above threshold;
                # shift right so the first token over the threshold is kept.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                # Scatter back to the original (unsorted) token order
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample from the filtered distribution
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append and continue
            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids

    def count_parameters(self, trainable_only: bool = True) -> int:
        """Count model parameters (tied weights are counted once)."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def estimate_flops(self, seq_len: int, batch_size: int = 1) -> int:
        """
        Estimate FLOPs for a forward pass.

        Uses the approximation: FLOPs ≈ 2 * params * tokens
        (multiply-add counts as 2 ops)
        """
        params = self.count_parameters(trainable_only=False)
        tokens = batch_size * seq_len
        return 2 * params * tokens


# ============================================================================
# Utility Functions
# ============================================================================

def create_model(
    size: str = "1B",
    vocab_size: int = 32000,
    max_seq_len: int = 2048,
    **kwargs
) -> LlamaModel:
    """
    Create a Llama model with the specified configuration.

    Args:
        size: Model size ("tiny", "small", "medium", "large", "1B")
        vocab_size: Vocabulary size
        max_seq_len: Maximum sequence length
        **kwargs: Additional config overrides

    Returns:
        Initialized LlamaModel
    """
    config = get_model_config(
        size,
        vocab_size=vocab_size,
        max_seq_len=max_seq_len,
        **kwargs
    )
    return LlamaModel(config)
def print_model_summary(model: LlamaModel):
    """Print a human-readable summary of the model's architecture,
    optimizations, parameter count, and estimated weight memory.

    Only reads model.config and model.count_parameters(); purely
    informational, returns None.
    """
    config = model.config
    params = model.count_parameters()

    print("\n" + "=" * 60)
    print("LLAMA MODEL SUMMARY")
    print("=" * 60)
    print(f"\nArchitecture:")
    print(f"  Hidden dim:        {config.d_model}")
    print(f"  Layers:            {config.n_layers}")
    print(f"  Attention heads:   {config.n_heads}")
    print(f"  KV heads (GQA):    {config.n_kv_heads}")
    print(f"  Head dim:          {config.head_dim}")
    print(f"  FFN dim:           {config.d_ff}")
    print(f"  Vocab size:        {config.vocab_size}")
    print(f"  Max seq len:       {config.max_seq_len}")
    print(f"\nOptimizations:")
    print(f"  RMSNorm:           Yes")
    print(f"  RoPE:              Yes (theta={config.rope_theta})")
    print(f"  SwiGLU:            Yes")
    print(f"  GQA:               Yes ({config.n_heads}/{config.n_kv_heads} = {config.n_kv_groups}x)")
    print(f"  Weight tying:      {config.tie_weights}")
    print(f"  Flash Attention:   {config.use_flash_attn}")
    print(f"\nParameters:")
    print(f"  Total:             {params:,}")
    print(f"  Size: ~{params / 1e9:.2f}B" if params > 1e9 else f"  Size: ~{params / 1e6:.0f}M")

    # Estimate weight memory (parameters only, excludes activations/optimizer)
    param_bytes = params * 4  # fp32
    print(f"  FP32 memory:       ~{param_bytes / 1e9:.2f} GB")
    print(f"  BF16 memory:       ~{param_bytes / 2 / 1e9:.2f} GB")
    print("=" * 60 + "\n")


# ============================================================================
# Main (for testing)
# ============================================================================

if __name__ == "__main__":
    # Test model creation across all preset sizes
    print("Testing Llama model creation...\n")
    for size in ["tiny", "small", "medium", "large", "1B"]:
        model = create_model(size)
        params = model.count_parameters()
        print(f"{size:8s}: {params:>12,} parameters ({params/1e6:>7.1f}M)")

    print("\n" + "-" * 60)

    # Detailed summary for 1B
    model = create_model("1B")
    print_model_summary(model)

    # Test forward pass
    print("Testing forward pass...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    batch_size = 2
    seq_len = 128
    input_ids = torch.randint(0, 32000, (batch_size, seq_len), device=device)

    # Forward without targets (returns logits)
    logits = model(input_ids)
    print(f"Logits shape: {logits.shape}")

    # Forward with targets (returns loss)
    targets = torch.randint(0, 32000, (batch_size, seq_len), device=device)
    loss = model(input_ids, targets=targets)
    print(f"Loss: {loss.item():.4f}")

    # Test gradient checkpointing
    print("\nTesting gradient checkpointing...")
    model.gradient_checkpointing_enable()
    loss = model(input_ids, targets=targets)
    loss.backward()
    print(f"Gradient checkpointing loss: {loss.item():.4f}")

    print("\nAll tests passed!")