"""
components.py
=============
Architectural components for SmolLM2-135M implementation
Components:
- RMSNorm: Root Mean Square Layer Normalization
- RotaryEmbedding: Rotary Position Embeddings (RoPE)
- GroupedQueryAttention: Grouped Query Attention (9 Q heads, 3 KV heads)
- SwiGLU_FFN: SwiGLU Feed-Forward Network
- TransformerBlock: Complete transformer block with pre-norm architecture
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization
    Simpler and cheaper than LayerNorm:
    - No mean centering
    - No bias term
    - Fewer operations, so typically faster in practice
Formula: output = input * rsqrt(mean(input²) + eps) * weight
"""
def __init__(self, hidden_size, eps=1e-5):
"""
Args:
hidden_size (int): Dimension of the input
eps (float): Small constant for numerical stability
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor of shape [batch, seq_len, hidden_size]
Returns:
torch.Tensor: Normalized tensor of same shape as input
"""
# Calculate variance (mean of squares)
variance = x.pow(2).mean(-1, keepdim=True)
# Normalize: x / sqrt(variance + eps)
x = x * torch.rsqrt(variance + self.eps)
# Scale by learned weight
return self.weight * x
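
# Usage sketch (illustrative values, not part of the module API):
#
#     norm = RMSNorm(hidden_size=576)
#     x = torch.randn(2, 8, 576)
#     y = norm(x)  # same shape [2, 8, 576]; each vector rescaled to unit RMS,
#                  # then scaled element-wise by the learned weight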
class RotaryEmbedding(nn.Module):
"""
Rotary Position Embedding (RoPE)
Encodes position by rotating Q and K vectors in 2D subspaces.
    Attention scores then depend only on the relative offset between tokens.
    Key properties:
    - Applied only to Q and K, not V
    - Different rotation frequencies for different dimension pairs
    - Degrades more gracefully beyond the training length than learned
      absolute embeddings (frequency-scaling tricks extend this further)
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
"""
Args:
dim (int): Dimension of each attention head (typically hidden_size / num_heads)
max_position_embeddings (int): Maximum sequence length
base (float): Base for inverse frequency calculation (theta)
"""
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Calculate inverse frequencies for rotation
# inv_freq[i] = 1 / (base^(2i/dim)) for i in [0, dim/2)
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x, position_ids):
"""
Args:
x (torch.Tensor): Input tensor (used for device/dtype)
position_ids (torch.Tensor): Position indices [batch, seq_len] or [seq_len]
Returns:
tuple: (cos, sin) embeddings of shape [batch, seq_len, dim]
"""
# Ensure position_ids has batch dimension
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0)
# Calculate rotation angles: position_ids × inv_freq
# Shape: [batch, seq_len, dim/2]
freqs = torch.einsum('bi,j->bij', position_ids.float(), self.inv_freq)
# Duplicate frequencies for both sin and cos
# Shape: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
# Return cos and sin, preserving input dtype
return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
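
# Shape sketch (illustrative values; the first argument is read only for
# its dtype, so an empty tensor works):
#
#     rope = RotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000.0)
#     pos = torch.arange(8)                   # [seq_len]
#     cos, sin = rope(torch.empty(0), pos)    # each: [1, 8, 64]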
def rotate_half(x):
"""
Rotate half the hidden dimensions
For RoPE, we rotate pairs of dimensions. This function rearranges
the tensor to prepare for rotation.
Args:
x (torch.Tensor): Input of shape [..., dim]
Returns:
        torch.Tensor: [x1, x2] -> [-x2, x1], i.e. the second half negated
            and moved to the front
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
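
# Worked example: for head_dim = 4, rotate_half([a, b, c, d]) = [-c, -d, a, b],
# so q * cos + rotate_half(q) * sin rotates the pairs (a, c) and (b, d) in the
# half-split layout used by Llama-style RoPE.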
def apply_rotary_pos_emb(q, k, cos, sin):
"""
Apply rotary position embeddings to queries and keys
Rotation formula:
q_rotated = q * cos + rotate_half(q) * sin
k_rotated = k * cos + rotate_half(k) * sin
Args:
q (torch.Tensor): Query tensor [batch, num_heads, seq_len, head_dim]
k (torch.Tensor): Key tensor [batch, num_heads, seq_len, head_dim]
cos (torch.Tensor): Cosine embeddings [batch, seq_len, head_dim]
sin (torch.Tensor): Sine embeddings [batch, seq_len, head_dim]
Returns:
tuple: (q_rotated, k_rotated) with rotary embeddings applied
"""
    # Add dimensions for broadcasting over heads:
    # [seq_len, dim] or [batch, seq_len, dim] -> [batch, 1, seq_len, dim]
    # (for a 2-D input both ifs fire in sequence)
if cos.dim() == 2:
cos = cos.unsqueeze(0)
sin = sin.unsqueeze(0)
if cos.dim() == 3:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
# Apply rotation
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
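
# Usage sketch (hypothetical shapes): q/k are [batch, heads, seq, head_dim] and
# cos/sin come straight from RotaryEmbedding as [batch, seq, head_dim]:
#
#     q = torch.randn(2, 9, 8, 64)
#     k = torch.randn(2, 3, 8, 64)
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged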
class GroupedQueryAttention(nn.Module):
"""
Grouped Query Attention (GQA)
Memory-efficient attention where multiple query heads share KV heads.
SmolLM2-135M uses 9 query heads and 3 KV heads (3:1 ratio).
Benefits:
- Reduces KV cache memory by 66% vs full MHA
- Maintains most of multi-head attention's expressiveness
- Used in Llama 2, Mistral, and other modern LLMs
Architecture:
- 9 query heads (each head_dim=64)
- 3 KV heads (each head_dim=64)
- Each KV head is repeated 3 times to serve 3 query heads
"""
def __init__(self, config):
"""
Args:
config: Model configuration with attributes:
- hidden_size: Model dimension (576)
- num_attention_heads: Number of query heads (9)
- num_key_value_heads: Number of KV heads (3)
- max_position_embeddings: Max sequence length
- rope_theta: RoPE base frequency
"""
super().__init__()
self.hidden_size = config.hidden_size # 576
self.num_heads = config.num_attention_heads # 9
self.num_kv_heads = config.num_key_value_heads # 3
self.num_kv_groups = self.num_heads // self.num_kv_heads # 3
self.head_dim = self.hidden_size // self.num_heads # 64
assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
# Projections (no bias in any linear layers)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
# Rotary embeddings
self.rotary_emb = RotaryEmbedding(
self.head_dim,
max_position_embeddings=config.max_position_embeddings,
base=config.rope_theta
)
def forward(self, hidden_states, attention_mask=None, position_ids=None):
"""
Forward pass of grouped query attention
Args:
hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
attention_mask (torch.Tensor, optional): Attention mask
position_ids (torch.Tensor, optional): Position indices
Returns:
torch.Tensor: Output [batch, seq_len, hidden_size]
"""
batch_size, seq_len, _ = hidden_states.size()
# Create position IDs if not provided
if position_ids is None:
position_ids = torch.arange(seq_len, device=hidden_states.device)
# Q, K, V projections
query_states = self.q_proj(hidden_states) # [batch, seq_len, 576]
key_states = self.k_proj(hidden_states) # [batch, seq_len, 192]
value_states = self.v_proj(hidden_states) # [batch, seq_len, 192]
# Reshape to separate heads
# Q: [batch, seq_len, 9, 64] -> [batch, 9, seq_len, 64]
query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# K, V: [batch, seq_len, 3, 64] -> [batch, 3, seq_len, 64]
key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Apply RoPE to Q and K (value_states is passed only for its dtype)
        cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Repeat K and V for GQA (3 KV heads -> 9 to match Q heads)
# Each KV head is repeated 3 times: [batch, 3, seq, 64] -> [batch, 9, seq, 64]
key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1)
value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1)
        # Scaled dot-product attention; on PyTorch 2.0+ this dispatches to a
        # fused kernel (e.g. FlashAttention) when one is available.
        # SDPA rejects is_causal=True together with an explicit attn_mask, so
        # causal masking is only requested when no mask is provided.
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=attention_mask is None  # causal masking for autoregressive generation
        )
# Reshape back: [batch, 9, seq_len, 64] -> [batch, seq_len, 576]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
# Output projection
attn_output = self.o_proj(attn_output)
return attn_output
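
# Usage sketch with a minimal stand-in config (SimpleNamespace is an assumption
# for illustration; any object with these attributes works):
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(hidden_size=576, num_attention_heads=9,
#                           num_key_value_heads=3, max_position_embeddings=2048,
#                           rope_theta=10000.0)
#     attn = GroupedQueryAttention(cfg)
#     out = attn(torch.randn(2, 8, 576))  # [2, 8, 576]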
class SwiGLU_FFN(nn.Module):
"""
SwiGLU Feed-Forward Network
Uses Swish-Gated Linear Units instead of standard FFN.
Formula: FFN(x) = down_proj(SiLU(gate_proj(x)) ⊙ up_proj(x))
Key differences from standard FFN:
- 3 linear projections instead of 2 (gate, up, down)
- Element-wise gating mechanism (⊙)
    - 50% more parameters at a given intermediate size (often offset in
      practice by shrinking that size), with empirically better quality
- Used in Llama, PaLM, and most modern LLMs
"""
def __init__(self, config):
"""
Args:
config: Model configuration with attributes:
- hidden_size: Model dimension (576)
- intermediate_size: FFN intermediate dimension (1536)
"""
super().__init__()
self.hidden_size = config.hidden_size # 576
self.intermediate_size = config.intermediate_size # 1536
# Three projections (no bias)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
# Swish/SiLU activation
self.act_fn = nn.SiLU()
def forward(self, x):
"""
Forward pass: down(SiLU(gate) * up)
Args:
x (torch.Tensor): Input [batch, seq_len, hidden_size]
Returns:
torch.Tensor: Output [batch, seq_len, hidden_size]
"""
# Gate path: apply SiLU activation
gate = self.act_fn(self.gate_proj(x))
# Up path: linear transformation
up = self.up_proj(x)
# Element-wise multiplication (gating)
gated = gate * up
# Down projection
return self.down_proj(gated)
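
# Parameter-count sketch: with hidden_size=576 and intermediate_size=1536, the
# gate/up/down projections hold 3 * 576 * 1536 = 2,654,208 weights, versus
# 2 * 576 * 1536 = 1,769,472 for a two-projection FFN of the same width.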
class TransformerBlock(nn.Module):
"""
Complete Transformer Block with Pre-Norm Architecture
Architecture:
1. x -> RMSNorm -> Attention -> Add residual
2. x -> RMSNorm -> FFN -> Add residual
Pre-norm (norm before sublayer) is standard in modern transformers
as it provides better gradient flow in deep networks.
"""
def __init__(self, config):
"""
Args:
config: Model configuration
"""
super().__init__()
# Layer normalization (pre-norm)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Self-attention
self.self_attn = GroupedQueryAttention(config)
# Post-attention layer norm
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Feed-forward network
self.mlp = SwiGLU_FFN(config)
def forward(self, hidden_states, attention_mask=None, position_ids=None):
"""
Forward pass through transformer block
Args:
hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
attention_mask (torch.Tensor, optional): Attention mask
position_ids (torch.Tensor, optional): Position indices
Returns:
torch.Tensor: Output [batch, seq_len, hidden_size]
"""
# Self-attention with residual connection
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, attention_mask, position_ids)
hidden_states = residual + hidden_states
# FFN with residual connection
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
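

if __name__ == "__main__":
    # Minimal smoke test (a sketch: SimpleNamespace stands in for the real
    # model config, with SmolLM2-135M-like values assumed)
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=576,
        num_attention_heads=9,
        num_key_value_heads=3,
        intermediate_size=1536,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        rms_norm_eps=1e-5,
    )
    block = TransformerBlock(cfg)
    x = torch.randn(2, 8, cfg.hidden_size)
    y = block(x)
    assert y.shape == x.shape
    print("TransformerBlock forward OK:", tuple(y.shape))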