| """ | |
| components.py | |
| ============= | |
| Architectural components for SmolLM2-135M implementation | |
| Components: | |
| - RMSNorm: Root Mean Square Layer Normalization | |
| - RotaryEmbedding: Rotary Position Embeddings (RoPE) | |
| - GroupedQueryAttention: Grouped Query Attention (9 Q heads, 3 KV heads) | |
| - SwiGLU_FFN: SwiGLU Feed-Forward Network | |
| - TransformerBlock: Complete transformer block with pre-norm architecture | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| class RMSNorm(nn.Module): | |
| """ | |
| Root Mean Square Layer Normalization | |
| Simpler and faster than LayerNorm: | |
| - No mean centering | |
| - No bias term | |
| - 10-15% faster than LayerNorm | |
| Formula: output = input * rsqrt(mean(input²) + eps) * weight | |
| """ | |
| def __init__(self, hidden_size, eps=1e-5): | |
| """ | |
| Args: | |
| hidden_size (int): Dimension of the input | |
| eps (float): Small constant for numerical stability | |
| """ | |
| super().__init__() | |
| self.eps = eps | |
| self.weight = nn.Parameter(torch.ones(hidden_size)) | |
| def forward(self, x): | |
| """ | |
| Args: | |
| x (torch.Tensor): Input tensor of shape [batch, seq_len, hidden_size] | |
| Returns: | |
| torch.Tensor: Normalized tensor of same shape as input | |
| """ | |
| # Calculate variance (mean of squares) | |
| variance = x.pow(2).mean(-1, keepdim=True) | |
| # Normalize: x / sqrt(variance + eps) | |
| x = x * torch.rsqrt(variance + self.eps) | |
| # Scale by learned weight | |
| return self.weight * x | |
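
# --- Usage sketch (illustrative; not part of the original module) ---
# Minimal sanity check: RMSNorm preserves shape and matches the docstring formula.
def _demo_rmsnorm():
    norm = RMSNorm(576)
    x = torch.randn(2, 10, 576)
    y = norm(x)
    # Recompute the formula by hand and compare
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps) * norm.weight
    assert y.shape == x.shape
    assert torch.allclose(y, ref)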
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE)

    Encodes position by rotating Q and K vectors in 2D subspaces.
    Enables relative position encoding and extrapolation to longer sequences.

    Key properties:
    - Applied only to Q and K, not V
    - Different rotation frequencies for different dimension pairs
    - Enables length extrapolation beyond training sequences
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        """
        Args:
            dim (int): Dimension of each attention head (typically hidden_size / num_heads)
            max_position_embeddings (int): Maximum sequence length
            base (float): Base for inverse frequency calculation (theta)
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Calculate inverse frequencies for rotation:
        # inv_freq[i] = 1 / (base^(2i/dim)) for i in [0, dim/2)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """
        Args:
            x (torch.Tensor): Input tensor (used only for device/dtype)
            position_ids (torch.Tensor): Position indices [batch, seq_len] or [seq_len]

        Returns:
            tuple: (cos, sin) embeddings of shape [batch, seq_len, dim]
        """
        # Ensure position_ids has a batch dimension
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)
        # Rotation angles: position_ids × inv_freq
        # Shape: [batch, seq_len, dim/2]
        freqs = torch.einsum('bi,j->bij', position_ids.float(), self.inv_freq)
        # Duplicate frequencies so cos/sin cover the full head dimension
        # Shape: [batch, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        # Return cos and sin, preserving input dtype
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
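
# Shape sketch (illustrative): cos/sin carry one angle per head dimension.
def _demo_rotary_shapes():
    rope = RotaryEmbedding(dim=64)
    cos, sin = rope(torch.empty(1), torch.arange(10))  # first arg only supplies dtype
    assert cos.shape == (1, 10, 64) and sin.shape == (1, 10, 64)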
def rotate_half(x):
    """
    Rotate half the hidden dimensions

    For RoPE, we rotate pairs of dimensions. This function rearranges
    the tensor to prepare for rotation.

    Args:
        x (torch.Tensor): Input of shape [..., dim]

    Returns:
        torch.Tensor: Tensor with the second half negated and moved to the front
    """
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
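
# Tiny illustration (sketch): rotate_half maps [x1, x2] -> [-x2, x1].
def _demo_rotate_half():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    assert torch.equal(rotate_half(x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))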
def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Apply rotary position embeddings to queries and keys

    Rotation formula:
        q_rotated = q * cos + rotate_half(q) * sin
        k_rotated = k * cos + rotate_half(k) * sin

    Args:
        q (torch.Tensor): Query tensor [batch, num_heads, seq_len, head_dim]
        k (torch.Tensor): Key tensor [batch, num_heads, seq_len, head_dim]
        cos (torch.Tensor): Cosine embeddings [batch, seq_len, head_dim]
        sin (torch.Tensor): Sine embeddings [batch, seq_len, head_dim]

    Returns:
        tuple: (q_rotated, k_rotated) with rotary embeddings applied
    """
    # Add dimensions for broadcasting:
    # cos/sin: [batch, seq_len, dim] -> [batch, 1, seq_len, dim]
    if cos.dim() == 2:
        cos = cos.unsqueeze(0)
        sin = sin.unsqueeze(0)
    if cos.dim() == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
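
# Property sketch (illustrative): after RoPE, the q·k score depends only on
# the relative offset between positions, not on the absolute positions.
def _demo_rope_is_relative():
    rope = RotaryEmbedding(dim=8)
    q = torch.randn(1, 1, 1, 8)  # [batch, heads, seq=1, head_dim]
    k = torch.randn(1, 1, 1, 8)

    def score(pos_q, pos_k):
        # Rotate q to position pos_q and k to position pos_k, then take q·k
        cos_q, sin_q = rope(q, torch.tensor([pos_q]))
        cos_k, sin_k = rope(k, torch.tensor([pos_k]))
        q_rot, _ = apply_rotary_pos_emb(q, q, cos_q, sin_q)
        k_rot, _ = apply_rotary_pos_emb(k, k, cos_k, sin_k)
        return (q_rot * k_rot).sum()

    # Same offset (2) in both cases -> identical scores at different positions
    assert torch.allclose(score(0, 2), score(5, 7), atol=1e-5)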
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)

    Memory-efficient attention where multiple query heads share KV heads.
    SmolLM2-135M uses 9 query heads and 3 KV heads (3:1 ratio).

    Benefits:
    - Shrinks the KV cache by two-thirds vs full multi-head attention
      (3 KV heads instead of 9)
    - Maintains most of multi-head attention's expressiveness
    - Used in Llama 2, Mistral, and other modern LLMs

    Architecture:
    - 9 query heads (each head_dim=64)
    - 3 KV heads (each head_dim=64)
    - Each KV head is repeated 3 times to serve 3 query heads
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - num_attention_heads: Number of query heads (9)
                - num_key_value_heads: Number of KV heads (3)
                - max_position_embeddings: Max sequence length
                - rope_theta: RoPE base frequency
        """
        super().__init__()
        self.hidden_size = config.hidden_size                     # 576
        self.num_heads = config.num_attention_heads               # 9
        self.num_kv_heads = config.num_key_value_heads            # 3
        self.num_kv_groups = self.num_heads // self.num_kv_heads  # 3
        self.head_dim = self.hidden_size // self.num_heads        # 64
        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        # Projections (no bias in any linear layers)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # Rotary embeddings
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass of grouped query attention

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Attention mask; if given,
                it must already encode causality
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.size()
        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
        # Q, K, V projections
        query_states = self.q_proj(hidden_states)  # [batch, seq_len, 576]
        key_states = self.k_proj(hidden_states)    # [batch, seq_len, 192]
        value_states = self.v_proj(hidden_states)  # [batch, seq_len, 192]
        # Reshape to separate heads
        # Q: [batch, seq_len, 9, 64] -> [batch, 9, seq_len, 64]
        query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # K, V: [batch, seq_len, 3, 64] -> [batch, 3, seq_len, 64]
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE to Q and K
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # Repeat K and V for GQA (3 KV heads -> 9 to match Q heads)
        # Each KV head is repeated 3 times: [batch, 3, seq, 64] -> [batch, 9, seq, 64]
        key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1)
        # Scaled dot-product attention; PyTorch 2.0+ dispatches to fused
        # kernels (including FlashAttention) when available.
        # Note: SDPA forbids an explicit attn_mask together with is_causal=True,
        # so causal masking is only requested when no mask is supplied.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=attention_mask is None
        )
        # Reshape back: [batch, 9, seq_len, 64] -> [batch, seq_len, 576]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
        # Output projection
        attn_output = self.o_proj(attn_output)
        return attn_output
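
# Usage sketch (illustrative): the config object below is a SimpleNamespace
# stand-in for the model's real config class, which is not defined in this file.
def _demo_gqa():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        hidden_size=576,
        num_attention_heads=9,
        num_key_value_heads=3,
        max_position_embeddings=2048,
        rope_theta=10000.0,
    )
    attn = GroupedQueryAttention(cfg)
    x = torch.randn(2, 16, 576)
    assert attn(x).shape == (2, 16, 576)
    # K/V projections output 3 * 64 = 192 dims vs 9 * 64 = 576 for Q
    assert attn.k_proj.out_features == 192 and attn.q_proj.out_features == 576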
class SwiGLU_FFN(nn.Module):
    """
    SwiGLU Feed-Forward Network

    Uses Swish-Gated Linear Units instead of a standard FFN.

    Formula: FFN(x) = down_proj(SiLU(gate_proj(x)) ⊙ up_proj(x))

    Key differences from a standard FFN:
    - 3 linear projections instead of 2 (gate, up, down)
    - Element-wise gating mechanism (⊙)
    - 50% more parameters at the same intermediate size, with better
      quality in practice
    - Used in Llama, PaLM, and most modern LLMs
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - intermediate_size: FFN intermediate dimension (1536)
        """
        super().__init__()
        self.hidden_size = config.hidden_size              # 576
        self.intermediate_size = config.intermediate_size  # 1536
        # Three projections (no bias)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Swish/SiLU activation
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """
        Forward pass: down(SiLU(gate) * up)

        Args:
            x (torch.Tensor): Input [batch, seq_len, hidden_size]

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        # Gate path: apply SiLU activation
        gate = self.act_fn(self.gate_proj(x))
        # Up path: linear transformation
        up = self.up_proj(x)
        # Element-wise multiplication (gating)
        gated = gate * up
        # Down projection
        return self.down_proj(gated)
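
# Usage sketch (illustrative): shape check plus the parameter-count claim
# from the docstring, again using a SimpleNamespace stand-in for the config.
def _demo_swiglu():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=576, intermediate_size=1536)
    ffn = SwiGLU_FFN(cfg)
    x = torch.randn(2, 16, 576)
    assert ffn(x).shape == x.shape
    # 3 projections of 576 x 1536 each; a standard 2-projection FFN at the
    # same intermediate size would have 2 * 576 * 1536 weights (+50%).
    assert sum(p.numel() for p in ffn.parameters()) == 3 * 576 * 1536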
class TransformerBlock(nn.Module):
    """
    Complete Transformer Block with Pre-Norm Architecture

    Architecture:
        1. x -> RMSNorm -> Attention -> Add residual
        2. x -> RMSNorm -> FFN -> Add residual

    Pre-norm (norm before each sublayer) is standard in modern transformers
    as it provides better gradient flow in deep networks.
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration
        """
        super().__init__()
        # Layer normalization (pre-norm)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Self-attention
        self.self_attn = GroupedQueryAttention(config)
        # Post-attention layer norm
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Feed-forward network
        self.mlp = SwiGLU_FFN(config)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass through transformer block

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Attention mask
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        # Self-attention with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask, position_ids)
        hidden_states = residual + hidden_states
        # FFN with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
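
# End-to-end smoke test (illustrative). Run this file directly to execute all
# demo sketches; the config is again a SimpleNamespace stand-in for the real one.
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        hidden_size=576,
        num_attention_heads=9,
        num_key_value_heads=3,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        intermediate_size=1536,
        rms_norm_eps=1e-5,
    )
    block = TransformerBlock(cfg)
    x = torch.randn(2, 16, 576)
    assert block(x).shape == x.shape
    for demo in (_demo_rmsnorm, _demo_rotary_shapes, _demo_rotate_half,
                 _demo_rope_is_relative, _demo_gqa, _demo_swiglu):
        demo()
    print("All component checks passed.")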