"""
components.py
=============
Architectural components for SmolLM2-135M implementation

Components:
- RMSNorm: Root Mean Square Layer Normalization
- RotaryEmbedding: Rotary Position Embeddings (RoPE)
- GroupedQueryAttention: Grouped Query Attention (9 Q heads, 3 KV heads)
- SwiGLU_FFN: SwiGLU Feed-Forward Network
- TransformerBlock: Complete transformer block with pre-norm architecture
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization

    Simpler than LayerNorm:
    - No mean centering
    - No bias term

    Formula:
        output = input * rsqrt(mean(input²) + eps) * weight
    """

    def __init__(self, hidden_size, eps=1e-5):
        """
        Args:
            hidden_size (int): Size of the last (normalized) dimension
            eps (float): Small constant added to the mean square for
                numerical stability
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor whose last dimension is
                hidden_size, e.g. [batch, seq_len, hidden_size]

        Returns:
            torch.Tensor: Normalized tensor of the same shape as the input
        """
        input_dtype = x.dtype
        # float16/bfloat16 cannot represent the squares of moderately large
        # activations (float16 overflows to inf above ~65504), which would
        # silently zero the output. Compute the statistics in float32.
        if input_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)

        # Mean of squares over the last dimension
        variance = x.pow(2).mean(-1, keepdim=True)

        # Normalize: x / sqrt(variance + eps)
        x = x * torch.rsqrt(variance + self.eps)

        # Cast back to the input dtype, then scale by the learned weight
        return self.weight * x.to(input_dtype)


class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE)

    Encodes position by rotating Q and K vectors in 2D subspaces, using a
    different rotation frequency for each pair of dimensions.

    This module only produces the cos/sin tables; they are applied to the
    query and key tensors by `apply_rotary_pos_emb`.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        """
        Args:
            dim (int): Dimension of each attention head
                (typically hidden_size / num_heads)
            max_position_embeddings (int): Maximum sequence length (stored
                on the module; not used to size any table here)
            base (float): Base for inverse frequency calculation (theta)
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Inverse frequencies for rotation:
        # inv_freq[i] = 1 / (base^(2i/dim)) for i in [0, dim/2)
        exponents = torch.arange(0, self.dim, 2).float() / self.dim
        inv_freq = 1.0 / (self.base ** exponents)
        # Non-persistent: recomputed at construction, not saved in state_dict
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """
        Args:
            x (torch.Tensor): Any tensor; only its dtype is used, as the
                dtype of the returned tables
            position_ids (torch.Tensor): Position indices of shape
                [batch, seq_len] or [seq_len]

        Returns:
            tuple: (cos, sin), each of shape [batch, seq_len, dim]
                (batch is 1 when position_ids is 1-D)
        """
        # Ensure position_ids has a batch dimension
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        # The buffer may have been cast by Module.half()/.to(dtype); compute
        # the angles in float32 regardless so both einsum operands agree.
        inv_freq = self.inv_freq.to(dtype=torch.float32)

        # Rotation angles: position_ids × inv_freq
        # Shape: [batch, seq_len, dim/2]
        freqs = torch.einsum("bi,j->bij", position_ids.float(), inv_freq)

        # Duplicate the angles so they cover both halves of the head dim
        # Shape: [batch, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        # Return cos and sin in the dtype of x
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


def rotate_half(x):
    """
    Rotate half the hidden dimensions

    Splits the last dimension into a first half x1 and a second half x2 and
    returns the concatenation (-x2, x1).

    Args:
        x (torch.Tensor): Input of shape [..., dim]

    Returns:
        torch.Tensor: Tensor of the same shape where the negated second half
            comes first, followed by the first half
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Apply rotary position embeddings to queries and keys

    Rotation formula:
        q_rotated = q * cos + rotate_half(q) * sin
        k_rotated = k * cos + rotate_half(k) * sin

    Args:
        q (torch.Tensor): Query tensor [batch, num_heads, seq_len, head_dim]
        k (torch.Tensor): Key tensor [batch, num_heads, seq_len, head_dim]
        cos (torch.Tensor): Cosine table, [batch, seq_len, head_dim] or
            [seq_len, head_dim]
        sin (torch.Tensor): Sine table, same shape as cos

    Returns:
        tuple: (q_rotated, k_rotated) with rotary embeddings applied
    """
    # Add dimensions for broadcasting against the head axis:
    # [seq_len, dim] -> [1, seq_len, dim] -> [1, 1, seq_len, dim]
    # [batch, seq_len, dim] -> [batch, 1, seq_len, dim]
    if cos.dim() == 2:
        cos = cos.unsqueeze(0)
        sin = sin.unsqueeze(0)
    if cos.dim() == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


def _merge_with_causal_mask(attention_mask, seq_len, dtype):
    """
    Combine a caller-supplied attention mask with a causal mask

    Args:
        attention_mask (torch.Tensor): Boolean mask (True = may attend) or
            floating-point additive bias, broadcastable to
            [batch, num_heads, seq_len, seq_len]
        seq_len (int): Sequence length of queries and keys
        dtype (torch.dtype): dtype of the query tensor; floating-point masks
            are cast to it

    Returns:
        torch.Tensor: Mask of the same kind (bool or float) that additionally
            blocks every position j > i for query position i

    Raises:
        TypeError: If attention_mask is neither boolean nor floating point
    """
    causal = torch.ones(
        seq_len, seq_len, dtype=torch.bool, device=attention_mask.device
    ).tril()

    if attention_mask.dtype == torch.bool:
        return attention_mask & causal

    if not attention_mask.is_floating_point():
        raise TypeError(
            "attention_mask must be a boolean or floating-point tensor, "
            f"got {attention_mask.dtype}"
        )

    # Additive form: 0 where attention is allowed, -inf above the diagonal.
    # Addition (rather than masked_fill on the user mask) lets both operands
    # broadcast, so masks such as [batch, 1, 1, seq_len] are supported.
    causal_bias = torch.zeros(
        seq_len, seq_len, dtype=dtype, device=attention_mask.device
    ).masked_fill(~causal, float("-inf"))
    return attention_mask.to(dtype) + causal_bias


class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)

    Attention where multiple query heads share each key/value head.
    SmolLM2-135M uses 9 query heads and 3 KV heads (3:1 ratio).

    Architecture (SmolLM2-135M values):
    - 9 query heads (each head_dim=64)
    - 3 KV heads (each head_dim=64)
    - Each KV head is repeated 3 times to serve 3 query heads

    Attention is always causal; an optional attention_mask is applied in
    addition to the causal mask.
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - num_attention_heads: Number of query heads (9)
                - num_key_value_heads: Number of KV heads (3)
                - max_position_embeddings: Max sequence length
                - rope_theta: RoPE base frequency
        """
        super().__init__()

        self.hidden_size = config.hidden_size  # 576
        self.num_heads = config.num_attention_heads  # 9
        self.num_kv_heads = config.num_key_value_heads  # 3

        # Validate before deriving sizes from the integer divisions below
        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        self.num_kv_groups = self.num_heads // self.num_kv_heads  # 3
        self.head_dim = self.hidden_size // self.num_heads  # 64

        # Projections (no bias in any linear layers)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Rotary embeddings
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass of grouped query attention

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Extra mask applied on
                top of the causal mask. Either boolean (True = may attend)
                or a floating-point additive bias, broadcastable to
                [batch, num_heads, seq_len, seq_len]
            position_ids (torch.Tensor, optional): Position indices,
                [batch, seq_len] or [seq_len]; defaults to 0..seq_len-1

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.size()

        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)

        # Q, K, V projections
        query_states = self.q_proj(hidden_states)  # [batch, seq_len, num_heads * head_dim]
        key_states = self.k_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]
        value_states = self.v_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]

        # Reshape to separate heads
        # Q: [batch, seq_len, heads, head_dim] -> [batch, heads, seq_len, head_dim]
        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        # K, V: [batch, seq_len, kv_heads, head_dim] -> [batch, kv_heads, seq_len, head_dim]
        key_states = key_states.view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_len, self.num_kv_heads, self.head_dim
        ).transpose(1, 2)

        # Apply RoPE to Q and K (value_states only supplies the dtype)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Repeat K and V for GQA so every query head has a matching KV head:
        # [batch, kv_heads, seq, head_dim] -> [batch, heads, seq, head_dim]
        key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1)

        # scaled_dot_product_attention rejects attn_mask together with
        # is_causal=True, so the causal constraint is folded into the
        # caller's mask whenever one is supplied.
        if attention_mask is None:
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                dropout_p=0.0,
                is_causal=True,  # Causal masking for autoregressive generation
            )
        else:
            merged_mask = _merge_with_causal_mask(
                attention_mask, seq_len, query_states.dtype
            )
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=merged_mask,
                dropout_p=0.0,
                is_causal=False,
            )

        # Reshape back: [batch, heads, seq_len, head_dim] -> [batch, seq_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)

        # Output projection
        attn_output = self.o_proj(attn_output)

        return attn_output


class SwiGLU_FFN(nn.Module):
    """
    SwiGLU Feed-Forward Network

    Uses Swish-Gated Linear Units instead of a standard FFN.

    Formula:
        FFN(x) = down_proj(SiLU(gate_proj(x)) ⊙ up_proj(x))

    Key differences from a standard FFN:
    - 3 linear projections instead of 2 (gate, up, down)
    - Element-wise gating mechanism (⊙)
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - intermediate_size: FFN intermediate dimension (1536)
        """
        super().__init__()

        self.hidden_size = config.hidden_size  # 576
        self.intermediate_size = config.intermediate_size  # 1536

        # Three projections (no bias)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        # Swish/SiLU activation
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """
        Forward pass: down(SiLU(gate) * up)

        Args:
            x (torch.Tensor): Input [batch, seq_len, hidden_size]

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        # Gate path: apply SiLU activation
        gate = self.act_fn(self.gate_proj(x))

        # Up path: linear transformation
        up = self.up_proj(x)

        # Element-wise multiplication (gating)
        gated = gate * up

        # Down projection
        return self.down_proj(gated)


class TransformerBlock(nn.Module):
    """
    Complete Transformer Block with Pre-Norm Architecture

    Architecture:
        1. x -> RMSNorm -> Attention -> Add residual
        2. x -> RMSNorm -> FFN -> Add residual

    Each sublayer sees a normalized input, while the residual stream itself
    is left un-normalized.
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration; must provide rms_norm_eps in
                addition to the attributes read by GroupedQueryAttention
                and SwiGLU_FFN
        """
        super().__init__()

        # Layer normalization (pre-norm)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Self-attention
        self.self_attn = GroupedQueryAttention(config)

        # Post-attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Feed-forward network
        self.mlp = SwiGLU_FFN(config)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass through transformer block

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Attention mask, passed
                through to the self-attention layer
            position_ids (torch.Tensor, optional): Position indices, passed
                through to the self-attention layer

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        # Self-attention with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask, position_ids)
        hidden_states = residual + hidden_states

        # FFN with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states