""" |
|
|
components.py |
|
|
============= |
|
|
Architectural components for SmolLM2-135M implementation |
|
|
|
|
|
Components: |
|
|
- RMSNorm: Root Mean Square Layer Normalization |
|
|
- RotaryEmbedding: Rotary Position Embeddings (RoPE) |
|
|
- GroupedQueryAttention: Grouped Query Attention (9 Q heads, 3 KV heads) |
|
|
- SwiGLU_FFN: SwiGLU Feed-Forward Network |
|
|
- TransformerBlock: Complete transformer block with pre-norm architecture |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization

    Simpler and faster than LayerNorm:
    - No mean centering
    - No bias term
    - 10-15% faster than LayerNorm

    Formula: output = input * rsqrt(mean(input²) + eps) * weight
    """

    def __init__(self, hidden_size, eps=1e-5):
        """
        Args:
            hidden_size (int): Dimension of the input
            eps (float): Small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape [batch, seq_len, hidden_size]

        Returns:
            torch.Tensor: Normalized tensor of same shape as input
        """
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x


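def _rmsnorm_reference_check():
    """
    Minimal illustrative sketch: checks the RMSNorm module above against the
    formula in its docstring, written out explicitly. Shapes and values here
    are arbitrary and chosen only for the demonstration.
    """
    x = torch.randn(2, 4, 576)
    norm = RMSNorm(576)
    expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps) * norm.weight
    return torch.allclose(norm(x), expected)  # True

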
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE)

    Encodes position by rotating Q and K vectors in 2D subspaces.
    Enables relative position encoding and extrapolation to longer sequences.

    Key properties:
    - Applied only to Q and K, not V
    - Different rotation frequencies for different dimension pairs
    - Enables length extrapolation beyond training sequences
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000.0):
        """
        Args:
            dim (int): Dimension of each attention head (typically hidden_size / num_heads)
            max_position_embeddings (int): Maximum sequence length
            base (float): Base for inverse frequency calculation (theta)
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Inverse frequencies base^(-2i/dim), one per dimension pair
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """
        Args:
            x (torch.Tensor): Input tensor (used for device/dtype)
            position_ids (torch.Tensor): Position indices [batch, seq_len] or [seq_len]

        Returns:
            tuple: (cos, sin) embeddings of shape [batch, seq_len, dim]
        """
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)

        # Outer product of positions and inverse frequencies: [batch, seq_len, dim/2]
        freqs = torch.einsum('bi,j->bij', position_ids.float(), self.inv_freq)
        # Duplicate so cos/sin cover the full head dimension
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)


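def _rope_frequency_example():
    """
    Minimal illustrative sketch: the geometric frequency spectrum RoPE uses for
    head_dim=64. The first dimension pair advances ~1 radian per position, the
    last pair only ~1.3e-4 radians, so different pairs capture position
    information at different scales.
    """
    rope = RotaryEmbedding(dim=64, base=10000.0)
    return rope.inv_freq[0].item(), rope.inv_freq[-1].item()  # 1.0 and ~1.33e-4

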
def rotate_half(x):
    """
    Rotate half the hidden dimensions

    For RoPE, we rotate pairs of dimensions. This function rearranges
    the tensor to prepare for rotation.

    Args:
        x (torch.Tensor): Input of shape [..., dim]

    Returns:
        torch.Tensor: Rearranged tensor where the second half is negated and moved to the front
    """
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


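def _rotate_half_example():
    """
    Minimal illustrative sketch: rotate_half on a tiny 4-dim vector. The second
    half is negated and moved to the front: [1, 2, 3, 4] -> [-3, -4, 1, 2].
    """
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    return rotate_half(x)  # tensor([-3., -4., 1., 2.])

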
def apply_rotary_pos_emb(q, k, cos, sin):
    """
    Apply rotary position embeddings to queries and keys

    Rotation formula:
        q_rotated = q * cos + rotate_half(q) * sin
        k_rotated = k * cos + rotate_half(k) * sin

    Args:
        q (torch.Tensor): Query tensor [batch, num_heads, seq_len, head_dim]
        k (torch.Tensor): Key tensor [batch, num_kv_heads, seq_len, head_dim]
        cos (torch.Tensor): Cosine embeddings [batch, seq_len, head_dim]
        sin (torch.Tensor): Sine embeddings [batch, seq_len, head_dim]

    Returns:
        tuple: (q_rotated, k_rotated) with rotary embeddings applied
    """
    # Broadcast cos/sin over the head dimension: [batch, 1, seq_len, head_dim]
    if cos.dim() == 2:
        cos = cos.unsqueeze(0)
        sin = sin.unsqueeze(0)
    if cos.dim() == 3:
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


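def _rope_shapes_example():
    """
    Minimal illustrative sketch: runs the RoPE helpers above on a single
    query/key head, assuming head_dim=64 and a short sequence, to show the
    shapes involved at each step.
    """
    head_dim, seq_len = 64, 8
    rope = RotaryEmbedding(head_dim)
    q = torch.randn(1, 1, seq_len, head_dim)  # [batch, heads, seq_len, head_dim]
    k = torch.randn(1, 1, seq_len, head_dim)
    cos, sin = rope(q, torch.arange(seq_len))  # each [1, seq_len, head_dim]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    return q_rot.shape, k_rot.shape  # both torch.Size([1, 1, 8, 64])

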
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)

    Memory-efficient attention where multiple query heads share KV heads.
    SmolLM2-135M uses 9 query heads and 3 KV heads (3:1 ratio).

    Benefits:
    - Reduces KV cache memory by 66% vs full MHA
    - Maintains most of multi-head attention's expressiveness
    - Used in Llama 2, Mistral, and other modern LLMs

    Architecture:
    - 9 query heads (each head_dim=64)
    - 3 KV heads (each head_dim=64)
    - Each KV head is repeated 3 times to serve 3 query heads
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - num_attention_heads: Number of query heads (9)
                - num_key_value_heads: Number of KV heads (3)
                - max_position_embeddings: Max sequence length
                - rope_theta: RoPE base frequency
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads

        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        assert self.num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

        # Projections: Q keeps all heads, K/V only num_kv_heads heads (no biases)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass of grouped query attention

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Attention mask
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)

        # Project to Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to [batch, num_heads, seq_len, head_dim]
        query_states = query_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary position embeddings to Q and K (not V)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Repeat each KV head num_kv_groups times so every query head has a K/V head
        key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1)

        # scaled_dot_product_attention does not allow an explicit attn_mask together
        # with is_causal=True, so causal masking is only enabled when no mask is given
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=attention_mask is None,
        )

        # Merge heads back to [batch, seq_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output


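def _kv_head_grouping_example():
    """
    Minimal illustrative sketch: how repeat_interleave expands 3 KV heads to
    serve 9 query heads in the attention above. KV head 0 serves query heads
    0-2, head 1 serves 3-5, head 2 serves 6-8.
    """
    kv_head_ids = torch.tensor([0, 1, 2])               # 3 KV heads
    expanded = kv_head_ids.repeat_interleave(3, dim=0)  # one entry per query head
    return expanded  # tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

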
class SwiGLU_FFN(nn.Module):
    """
    SwiGLU Feed-Forward Network

    Uses Swish-Gated Linear Units instead of a standard FFN.
    Formula: FFN(x) = down_proj(SiLU(gate_proj(x)) ⊙ up_proj(x))

    Key differences from a standard FFN:
    - 3 linear projections instead of 2 (gate, up, down)
    - Element-wise gating mechanism (⊙)
    - ~50% more parameters at the same intermediate size, but better performance
    - Used in Llama, PaLM, and most modern LLMs
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration with attributes:
                - hidden_size: Model dimension (576)
                - intermediate_size: FFN intermediate dimension (1536)
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        self.act_fn = nn.SiLU()

    def forward(self, x):
        """
        Forward pass: down(SiLU(gate) * up)

        Args:
            x (torch.Tensor): Input [batch, seq_len, hidden_size]

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        gated = gate * up
        return self.down_proj(gated)


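def _swiglu_param_count_example():
    """
    Minimal illustrative sketch: parameter count of the SwiGLU FFN above at the
    SmolLM2-135M sizes (hidden 576, intermediate 1536) versus a standard
    2-matrix FFN at the same sizes -- three weight matrices instead of two.
    """
    hidden, intermediate = 576, 1536
    swiglu_params = 3 * hidden * intermediate    # gate_proj + up_proj + down_proj, no biases
    standard_params = 2 * hidden * intermediate  # up + down in a classic FFN
    return swiglu_params / standard_params  # 1.5

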
class TransformerBlock(nn.Module):
    """
    Complete Transformer Block with Pre-Norm Architecture

    Architecture:
    1. x -> RMSNorm -> Attention -> Add residual
    2. x -> RMSNorm -> FFN -> Add residual

    Pre-norm (norm before sublayer) is standard in modern transformers
    as it provides better gradient flow in deep networks.
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration
        """
        super().__init__()

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = GroupedQueryAttention(config)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SwiGLU_FFN(config)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """
        Forward pass through transformer block

        Args:
            hidden_states (torch.Tensor): Input [batch, seq_len, hidden_size]
            attention_mask (torch.Tensor, optional): Attention mask
            position_ids (torch.Tensor, optional): Position indices

        Returns:
            torch.Tensor: Output [batch, seq_len, hidden_size]
        """
        # Attention sub-layer: pre-norm -> attention -> residual add
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask, position_ids)
        hidden_states = residual + hidden_states

        # FFN sub-layer: pre-norm -> SwiGLU FFN -> residual add
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
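

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumes a SimpleNamespace stand-in for the
    # project's config object, with SmolLM2-135M-like values for the attribute
    # names used above.
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=576,
        num_attention_heads=9,
        num_key_value_heads=3,
        intermediate_size=1536,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        rms_norm_eps=1e-5,
    )
    block = TransformerBlock(config)
    x = torch.randn(2, 16, config.hidden_size)
    out = block(x)
    print(out.shape)  # expected: torch.Size([2, 16, 576])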