| """ |
| Component 4: Transformer model architecture for code generation. |
| |
| This module defines a decoder-only transformer built from scratch in PyTorch. |
| It is modular through configuration so model size can be scaled up/down. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import asdict, dataclass |
| from typing import Dict, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class ModelConfig:
    """
    Hyperparameters for the decoder-only code model.

    The field defaults coincide with the ``medium_420m`` preset
    (d_model=1152, n_layers=23, n_heads=16, d_ff=4608).
    """

    # Number of token ids: rows of the embedding table / outputs of the LM head.
    vocab_size: int = 50_000
    # Number of positions covered by the precomputed rotary tables.
    max_seq_len: int = 2048
    # Hidden (residual-stream) width.
    d_model: int = 1152
    # Number of stacked transformer blocks.
    n_layers: int = 23
    # Number of attention heads; must divide d_model (see head_dim).
    n_heads: int = 16
    # Inner width of the feed-forward network.
    d_ff: int = 4608
    # Dropout probability (embeddings, attention weights, FFN output).
    dropout: float = 0.1
    # Share the LM-head weight with the token embedding.
    tie_embeddings: bool = True
    # Checkpoint block activations during training to save memory.
    gradient_checkpointing: bool = False
    # Std-dev of the normal init used for Linear / Embedding weights.
    init_std: float = 0.02
    # Epsilon used by every RMSNorm layer.
    rms_norm_eps: float = 1e-5

    @property
    def head_dim(self) -> int:
        """
        Width of a single attention head, ``d_model // n_heads``.

        Raises:
            ValueError: If ``d_model`` is not an exact multiple of ``n_heads``.
        """
        per_head, leftover = divmod(self.d_model, self.n_heads)
        if leftover:
            raise ValueError("d_model must be divisible by n_heads.")
        return per_head
|
|
|
|
def get_model_presets() -> Dict[str, ModelConfig]:
    """
    Return the standard model-size presets.

    New ``ModelConfig`` objects are built on every call, so mutating a
    returned config does not affect later calls. Only the shape fields
    (d_model, n_layers, n_heads, d_ff) differ from the ``ModelConfig``
    defaults.
    """
    # preset name -> (d_model, n_layers, n_heads, d_ff)
    shapes = {
        "small_180m": (896, 18, 14, 3584),
        "medium_420m": (1152, 23, 16, 4608),
        "large_800m": (1536, 24, 16, 6144),
    }
    return {
        name: ModelConfig(d_model=width, n_layers=depth, n_heads=heads, d_ff=ffn_width)
        for name, (width, depth, heads, ffn_width) in shapes.items()
    }
|
|
|
|
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension.

    Each vector is divided by ``sqrt(mean(x**2) + eps)`` (no mean-centering,
    no bias) and then scaled by a learned per-feature gain initialised to ones.

    Args:
        dim: Size of the last (normalized) dimension.
        eps: Added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` over its last dimension and apply the learned gain.

        The mean square is computed in at least float32. In float16,
        ``x.pow(2)`` overflows to inf once |x| exceeds roughly 256, and
        ``rsqrt(inf)`` is 0, so the output would silently become zeros.
        Bfloat16 also loses precision there. The normalized values are cast
        back to ``x.dtype`` before the gain multiply. The output dtype is
        therefore unchanged: it is the promotion of ``weight`` and ``x``.
        """
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_c = x.to(compute_dtype)
        mean_sq = x_c.pow(2).mean(dim=-1, keepdim=True)
        normed = (x_c * torch.rsqrt(mean_sq + self.eps)).to(x.dtype)
        return self.weight * normed
|
|
|
|
class RotaryEmbedding(nn.Module):
    """
    Rotary positional embedding (RoPE).

    Each adjacent feature pair ``(x[..., 2i], x[..., 2i + 1])`` of the query
    and key vectors is rotated by the angle ``pos * base ** (-2i / head_dim)``.
    This injects token-order information directly into the attention scores.

    Args:
        head_dim: Per-head feature size; must be even.
        max_seq_len: Number of positions whose cos/sin tables are precomputed.
        base: Frequency base of the rotation angles. The default 10000.0 is
            the value this module has always used.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """

    def __init__(self, head_dim: int, max_seq_len: int, base: float = 10000.0) -> None:
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError("head_dim must be even for rotary embeddings.")
        self.max_seq_len = max_seq_len
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # (max_seq_len, head_dim // 2)
        # Non-persistent: the tables are derived from the constructor
        # arguments, so they are rebuilt rather than saved in state_dict.
        self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_len: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rotate ``q`` and ``k`` for positions ``0 .. seq_len - 1``.

        The tables are unsqueezed to ``(1, 1, seq_len, head_dim // 2)``.
        ``q``/``k`` must therefore have the sequence on dim -2 and features
        on dim -1; in this file they are ``(batch, n_heads, seq_len, head_dim)``.

        The tables are built in float32 but cast to ``q.dtype`` here. Without
        the cast, half-precision q/k would be promoted to float32 and no
        longer match the dtype of the values.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``. Slicing the
                tables would otherwise silently truncate them and fail later
                with an opaque broadcasting error.
        """
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"seq_len={seq_len} exceeds the rotary cache size "
                f"max_seq_len={self.max_seq_len}."
            )
        cos = self.cos_cached[:seq_len].to(dtype=q.dtype).unsqueeze(0).unsqueeze(0)
        sin = self.sin_cached[:seq_len].to(dtype=q.dtype).unsqueeze(0).unsqueeze(0)
        q = self._apply_rotary(q, cos, sin)
        k = self._apply_rotary(k, cos, sin)
        return q, k

    @staticmethod
    def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate the interleaved (even, odd) feature pairs of ``x`` by the given angles."""
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        x_rot_even = x1 * cos - x2 * sin
        x_rot_odd = x1 * sin + x2 * cos
        # Re-interleave so each rotated pair returns to its original slots.
        return torch.stack((x_rot_even, x_rot_odd), dim=-1).flatten(-2)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention for autoregressive code generation.

    Queries and keys receive rotary position embeddings. The attention
    itself is computed by ``F.scaled_dot_product_attention``. Causality is
    always enforced, whether or not an extra ``attn_mask`` is supplied.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        # Only the probability ``p`` is used: it is passed to SDPA as the
        # attention-weight dropout rate during training.
        self.dropout = nn.Dropout(config.dropout)
        self.rotary = RotaryEmbedding(head_dim=self.head_dim, max_seq_len=config.max_seq_len)

    @staticmethod
    def _merge_with_causal_mask(attn_mask: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Fold the lower-triangular causal constraint into ``attn_mask``.

        ``F.scaled_dot_product_attention`` raises an error when ``attn_mask``
        is combined with ``is_causal=True``. When a caller supplies a mask,
        the causal triangle therefore has to be merged into it explicitly.

        SDPA mask conventions apply:

        - A boolean mask marks positions that may be attended (``True``), so
          it is AND-ed with the triangle.
        - A floating mask is added to the scores, so ``-inf`` is added above
          the diagonal.

        The result has the broadcast shape of ``attn_mask`` and
        ``(seq_len, seq_len)``.
        """
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_mask.device).tril()
        if attn_mask.dtype == torch.bool:
            return attn_mask & causal
        future = torch.zeros(seq_len, seq_len, dtype=attn_mask.dtype, device=attn_mask.device)
        future = future.masked_fill(~causal, float("-inf"))
        return attn_mask + future

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Attend over ``x`` of shape ``(batch, seq_len, d_model)``.

        Args:
            x: Hidden states.
            attn_mask: Optional extra mask in SDPA format (bool "may attend",
                or additive float), broadcastable to
                ``(batch, n_heads, seq_len, seq_len)``. It is combined with
                the causal mask; it never replaces it. NOTE(review): query
                rows whose every key is masked out (e.g. left padding) may
                come out as NaN from SDPA — confirm callers avoid this.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k = self.rotary(q, k, seq_len=seq_len)

        if attn_mask is None:
            # Let SDPA build the causal mask itself (enables fused kernels).
            mask, is_causal = None, True
        else:
            mask, is_causal = self._merge_with_causal_mask(attn_mask, seq_len), False

        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=is_causal,
            scale=self.scale,
        )
        # (batch, n_heads, seq_len, head_dim) -> (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.o_proj(out)
|
|
|
|
class FeedForward(nn.Module):
    """
    Position-wise MLP: ``d_model -> d_ff -> d_model`` with tanh-approximated GELU.

    Neither projection has a bias; dropout is applied only to the output.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        self.fc1 = nn.Linear(width, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate, project back, then apply dropout."""
        activated = F.gelu(self.fc1(x), approximate="tanh")
        return self.dropout(self.fc2(activated))
|
|
|
|
class TransformerBlock(nn.Module):
    """
    Pre-norm decoder layer.

    Computes ``h = x + attn(norm1(x))`` and then ``h + ffn(norm2(h))``.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        eps = config.rms_norm_eps
        # Construction order is kept fixed so parameter init and
        # state_dict ordering stay stable.
        self.norm1 = RMSNorm(config.d_model, eps=eps)
        self.attn = CausalSelfAttention(config)
        self.norm2 = RMSNorm(config.d_model, eps=eps)
        self.ffn = FeedForward(config)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the attention sub-layer and then the feed-forward sub-layer, each with a residual."""
        attended = self.attn(self.norm1(x), attn_mask=attn_mask)
        hidden = x + attended
        mixed = self.ffn(self.norm2(hidden))
        return hidden + mixed
|
|
|
|
class CodeTransformerLM(nn.Module):
    """
    Full decoder-only language model for code generation.

    Pipeline: token embedding -> dropout -> ``config.n_layers`` transformer
    blocks -> final RMSNorm -> bias-free linear LM head. When
    ``config.tie_embeddings`` is set, the LM head reuses the token-embedding
    weight matrix.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm_final = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        if config.tie_embeddings:
            # One Parameter shared by the input embedding and the output projection.
            self.lm_head.weight = self.embed_tokens.weight

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """
        Draw ``nn.Linear`` / ``nn.Embedding`` weights from N(0, init_std^2).

        Every other module is left as constructed (e.g. RMSNorm gains stay
        at ones). With tied embeddings the shared matrix is drawn twice, once
        per owning module, from the same distribution.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)

    def enable_gradient_checkpointing(self, enabled: bool = True) -> None:
        """
        Toggle activation checkpointing of each block (training mode only).

        NOTE: this writes to ``self.config``. Any other object holding the
        same ``ModelConfig`` instance will see the change too.
        """
        self.config.gradient_checkpointing = enabled

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Run the model and optionally compute the next-token loss.

        Args:
            input_ids: Token ids of shape ``[batch, seq_len]``.
            labels: Optional targets aligned with ``input_ids``. They are
                shifted internally, so position ``t`` is scored against
                ``labels[:, t + 1]``. Entries equal to -100 are ignored.
            attn_mask: Optional mask forwarded unchanged to every block.

        Returns:
            ``{"logits": [batch, seq_len, vocab_size]}``, plus ``"loss"``
            (a scalar) when ``labels`` is given.

        Raises:
            ValueError: If ``input_ids`` is not 2-D, or if its length exceeds
                ``config.max_seq_len``.
        """
        if input_ids.dim() != 2:
            raise ValueError("input_ids must be shape [batch, seq_len].")
        seq_len = input_ids.size(1)
        if seq_len > self.config.max_seq_len:
            # Fail early with a clear message: the per-layer rotary tables
            # only cover config.max_seq_len positions.
            raise ValueError(
                f"input_ids has seq_len={seq_len}, which exceeds "
                f"max_seq_len={self.config.max_seq_len}."
            )

        x = self.embed_tokens(input_ids)
        x = self.dropout(x)

        if self.config.gradient_checkpointing and self.training:
            # Import the submodule explicitly: ``import torch`` alone does not
            # guarantee that ``torch.utils.checkpoint`` has been loaded.
            from torch.utils.checkpoint import checkpoint

            for block in self.blocks:
                # Recompute this block's activations during backward
                # instead of storing them.
                x = checkpoint(block, x, attn_mask, use_reentrant=False)
        else:
            for block in self.blocks:
                x = block(x, attn_mask=attn_mask)

        x = self.norm_final(x)
        logits = self.lm_head(x)

        out: Dict[str, torch.Tensor] = {"logits": logits}
        if labels is not None:
            # Next-token objective: drop the last prediction and the first label.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
            out["loss"] = loss
        return out

    def estimate_num_parameters(self) -> int:
        """
        Count trainable parameters.

        ``nn.Module.parameters()`` yields each shared Parameter once, so
        tied embedding weights are counted a single time.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def summary(self) -> Dict[str, object]:
        """Return the config as a plain dict together with the trainable-parameter count."""
        return {
            "config": asdict(self.config),
            "num_parameters": self.estimate_num_parameters(),
        }
|
|
|
|