| """Core transformer model components for the mini LLM.""" |
|
|
| import math |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def precompute_rope_freqs(
    head_dim: int,
    max_seq_len: int,
    base: float = 10_000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the cosine and sine lookup tables used by rotary embeddings.

    Args:
        head_dim: Per-head feature size; must be even because features are
            rotated in pairs.
        max_seq_len: Number of positions (rows) in each table.
        base: Base of the geometric progression of rotation frequencies.

    Returns:
        A ``(cos, sin)`` pair, each shaped ``[max_seq_len, head_dim // 2]``.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    if head_dim % 2:
        raise ValueError("RoPE requires an even head dimension.")

    # One inverse frequency per feature pair: base ** (-2i / head_dim).
    pair_index = torch.arange(0, head_dim, 2).float()
    inv_freq = 1.0 / (base ** (pair_index / head_dim))

    # Angle for (position p, pair i) is p * inv_freq[i].
    pos = torch.arange(max_seq_len).float()
    theta = torch.outer(pos, inv_freq)

    return theta.cos(), theta.sin()
|
|
|
|
def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Rotate adjacent feature pairs of ``x`` by position-dependent angles.

    Args:
        x: Tensor shaped ``[batch, heads, seq_len, head_dim]``.
        cos: Cosine table; its first ``seq_len`` rows are used and must
            broadcast against ``head_dim // 2`` features.
        sin: Sine table with the same layout as ``cos``.

    Returns:
        Tensor with the same shape as ``x`` where every (even, odd) feature
        pair has been rotated by the angle of its sequence position.
    """
    n_pos = x.shape[2]

    # Trim the tables to the current length and add batch/head broadcast axes.
    cos_t = cos[None, None, :n_pos]
    sin_t = sin[None, None, :n_pos]

    # Features are interleaved: (0, 1), (2, 3), ... form the rotated pairs.
    first = x[..., 0::2]
    second = x[..., 1::2]

    # Standard 2-D rotation applied independently to every pair.
    new_first = first * cos_t - second * sin_t
    new_second = first * sin_t + second * cos_t

    # Re-interleave: stack pairs on a trailing axis, then merge it away.
    paired = torch.stack((new_first, new_second), dim=-1)
    return paired.flatten(-2)
|
|
|
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each key/value head ``n_rep`` times along the head axis.

    Args:
        x: Tensor shaped ``[batch, n_kv_heads, seq_len, head_dim]``.
        n_rep: Number of copies per KV head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor shaped
        ``[batch, n_kv_heads * n_rep, seq_len, head_dim]`` in which the
        copies of each head are adjacent.
    """
    if n_rep == 1:
        return x

    bsz, kv_heads, length, dim = x.shape

    # Insert a repeat axis right after the head axis, broadcast it to n_rep,
    # then fold it into the head axis so copies of one head sit side by side.
    with_rep_axis = x.unsqueeze(2)
    broadcast = with_rep_axis.expand(bsz, kv_heads, n_rep, length, dim)
    return broadcast.reshape(bsz, kv_heads * n_rep, length, dim)
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square normalization used by modern LLMs.

    Each vector along the last dimension is divided by its root-mean-square
    and then multiplied by a learned per-feature gain.

    Args:
        dim: Size of the last (feature) dimension; length of the gain vector.
        eps: Constant added to the mean square before taking the square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Gain starts at one, so the layer initially only normalizes.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the gain.

        Args:
            x: Floating-point tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the shape of ``x``; its dtype is the promotion of
            ``x.dtype`` and the gain's dtype.
        """
        # Squaring in float16/bfloat16 overflows or loses precision quickly
        # (float16 tops out near 65504, so any |x| > ~255 squares to inf and
        # the whole row would normalize to 0). Compute the statistic in
        # float32 for those dtypes; wider dtypes are left untouched.
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        x_calc = x.float() if low_precision else x

        rms = torch.sqrt(x_calc.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        # Cast back so the output dtype matches what callers got before.
        normed = (x_calc / rms).to(x.dtype)
        return normed * self.weight
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))`` plus dropout.

    Args:
        d_model: Input and output feature size.
        hidden_dim: Width of the gated hidden representation.
        dropout: Dropout probability applied to the block output in training.
    """

    def __init__(self, d_model: int, hidden_dim: int, dropout: float = 0.2) -> None:
        super().__init__()
        self.dropout = dropout
        # All three projections are bias-free.
        self.w_gate = nn.Linear(in_features=d_model, out_features=hidden_dim, bias=False)
        self.w_up = nn.Linear(in_features=d_model, out_features=hidden_dim, bias=False)
        self.w_down = nn.Linear(in_features=hidden_dim, out_features=d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        # The SiLU-activated branch modulates the linear branch elementwise.
        gated = F.silu(self.w_gate(x)) * self.w_up(x)
        projected = self.w_down(gated)
        # Dropout is only active while the module is in training mode.
        return F.dropout(projected, p=self.dropout, training=self.training)
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Causal self-attention where groups of query heads share one KV head.

    Queries and keys receive rotary position embeddings; attention weights
    pass through dropout during training.

    Args:
        d_model: Model (input and output) feature size.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``, or
            ``n_heads`` is not divisible by ``n_kv_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads.")
        if n_heads % n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads.")

        head_dim = d_model // n_heads
        q_width = n_heads * head_dim
        kv_width = n_kv_heads * head_dim

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        # How many query heads share each key/value head.
        self.n_rep = n_heads // n_kv_heads
        self.head_dim = head_dim
        self.dropout = dropout

        # K/V projections are narrower than Q when n_kv_heads < n_heads.
        self.q_proj = nn.Linear(d_model, q_width, bias=False)
        self.k_proj = nn.Linear(d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(d_model, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
    ) -> torch.Tensor:
        """Run causal grouped-query attention over ``x``.

        Args:
            x: Input shaped ``[batch, seq_len, d_model]``.
            rope_cos: Rotary cosine table covering at least ``seq_len`` rows.
            rope_sin: Rotary sine table covering at least ``seq_len`` rows.

        Returns:
            Tensor shaped ``[batch, seq_len, d_model]``.
        """
        batch, seq_len, _ = x.shape

        def to_heads(t: torch.Tensor, heads: int) -> torch.Tensor:
            # [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
            return t.view(batch, seq_len, heads, self.head_dim).transpose(1, 2)

        queries = to_heads(self.q_proj(x), self.n_heads)
        keys = to_heads(self.k_proj(x), self.n_kv_heads)
        values = to_heads(self.v_proj(x), self.n_kv_heads)

        # Positions are injected into queries and keys only.
        queries = apply_rope(queries, rope_cos, rope_sin)
        keys = apply_rope(keys, rope_cos, rope_sin)

        # Expand the KV heads so every query head has a matching K/V head.
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        inv_sqrt_dim = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * inv_sqrt_dim

        # True strictly above the diagonal marks future positions to hide.
        future = torch.ones(
            seq_len, seq_len, device=x.device, dtype=torch.bool
        ).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))

        probs = F.softmax(scores, dim=-1)
        probs = F.dropout(probs, p=self.dropout, training=self.training)

        # Merge heads back: [batch, heads, seq, head_dim] -> [batch, seq, d_model]
        context = torch.matmul(probs, values)
        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.o_proj(merged)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm residual block: grouped-query attention, then SwiGLU.

    Args:
        d_model: Model feature size.
        n_heads: Number of query heads for the attention sub-layer.
        n_kv_heads: Number of key/value heads for the attention sub-layer.
        ffn_hidden_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        ffn_hidden_dim: int,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()

        # Attention sub-layer with its own input normalization.
        self.attn_norm = RMSNorm(d_model)
        self.attention = GroupedQueryAttention(
            d_model, n_heads, n_kv_heads, dropout=dropout
        )

        # Feed-forward sub-layer with its own input normalization.
        self.ffn_norm = RMSNorm(d_model)
        self.ffn = SwiGLU(d_model, ffn_hidden_dim, dropout=dropout)

    def forward(
        self,
        x: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply both residual sub-layers to ``x`` and return the result."""
        # Each sub-layer sees a normalized input; its output is added back
        # onto the un-normalized residual stream.
        attn_out = self.attention(self.attn_norm(x), rope_cos, rope_sin)
        residual = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(residual))
        return residual + ffn_out
|
|
|
|
class MiniLLM(nn.Module):
    """Small modern language model built from the extracted notebook blocks.

    Pipeline: token embedding -> ``n_layers`` ``TransformerBlock`` layers ->
    final ``RMSNorm`` -> linear head over the vocabulary. The head shares its
    weight tensor with the token embedding, and the rotary cos/sin tables for
    ``max_seq_len`` positions are registered as buffers.

    Args:
        vocab_size: Number of token ids (embedding rows / output logits).
        d_model: Model feature size.
        n_layers: Number of transformer blocks.
        n_heads: Number of query heads per block.
        n_kv_heads: Number of key/value heads per block.
        ffn_hidden_dim: Hidden width of each feed-forward sub-layer.
        max_seq_len: Longest sequence the model accepts.
        dropout: Dropout probability passed to every block.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        n_kv_heads: int,
        ffn_hidden_dim: int,
        max_seq_len: int,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads.")

        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    ffn_hidden_dim=ffn_hidden_dim,
                    dropout=dropout,
                )
                for _ in range(n_layers)
            ]
        )
        self.final_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the output head reuses the embedding matrix, so both
        # modules hold the very same parameter tensor.
        self.lm_head.weight = self.token_emb.weight

        # Rotary tables are buffers so they follow the module across
        # device/dtype moves without being trained.
        head_dim = d_model // n_heads
        rope_cos, rope_sin = precompute_rope_freqs(head_dim, max_seq_len)
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Draw Linear and Embedding weights from N(0, 0.02**2)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        idx: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Token ids shaped ``[batch, seq_len]``.
            targets: Optional target ids with one entry per position of
                ``idx``. May be a non-contiguous view (e.g. a slice).

        Returns:
            ``(logits, loss)`` where ``logits`` is shaped
            ``[batch, seq_len, vocab_size]`` and ``loss`` is the mean
            cross-entropy over all positions, or ``None`` when no targets
            were given.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = idx.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}."
            )

        x = self.token_emb(idx)

        for layer in self.layers:
            x = layer(x, self.rope_cos, self.rope_sin)

        logits = self.lm_head(self.final_norm(x))

        loss = None
        if targets is not None:
            # reshape, not view: targets are often a shifted slice such as
            # tokens[:, 1:], which is non-contiguous and makes view() raise.
            # reshape returns a view when it can and copies only when needed.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

        return logits, loss
|
|