""" Component 4: Transformer model architecture for code generation. This module defines a decoder-only transformer built from scratch in PyTorch. It is modular through configuration so model size can be scaled up/down. """ from __future__ import annotations import math from dataclasses import asdict, dataclass from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F @dataclass class ModelConfig: # Vocabulary size from tokenizer. vocab_size: int = 50_000 # Maximum context length in tokens. max_seq_len: int = 2048 # Core hidden size of transformer. d_model: int = 1152 # Number of transformer blocks. n_layers: int = 23 # Number of attention heads. n_heads: int = 16 # Feed-forward hidden size. d_ff: int = 4608 # Dropout for regularization. dropout: float = 0.1 # Whether to tie token embedding and LM head weights. tie_embeddings: bool = True # Enable gradient checkpointing to reduce VRAM usage during training. gradient_checkpointing: bool = False # Initialization standard deviation. init_std: float = 0.02 # Epsilon for layer normalization stability. rms_norm_eps: float = 1e-5 @property def head_dim(self) -> int: if self.d_model % self.n_heads != 0: raise ValueError("d_model must be divisible by n_heads.") return self.d_model // self.n_heads def get_model_presets() -> Dict[str, ModelConfig]: """ Returns standard size presets. """ return { "small_180m": ModelConfig(d_model=896, n_layers=18, n_heads=14, d_ff=3584), "medium_420m": ModelConfig(d_model=1152, n_layers=23, n_heads=16, d_ff=4608), "large_800m": ModelConfig(d_model=1536, n_layers=24, n_heads=16, d_ff=6144), } class RMSNorm(nn.Module): """ RMSNorm is a lightweight normalization layer used in many modern LLMs. """ def __init__(self, dim: int, eps: float = 1e-5) -> None: super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: norm = x.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(norm + self.eps) return self.weight * x class RotaryEmbedding(nn.Module): """ Rotary positional embedding. This injects token order information directly into query/key vectors. """ def __init__(self, head_dim: int, max_seq_len: int) -> None: super().__init__() if head_dim % 2 != 0: raise ValueError("head_dim must be even for rotary embeddings.") inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)) t = torch.arange(max_seq_len, dtype=torch.float32) freqs = torch.outer(t, inv_freq) self.register_buffer("cos_cached", torch.cos(freqs), persistent=False) self.register_buffer("sin_cached", torch.sin(freqs), persistent=False) def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0) # [1,1,S,H/2] sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0) # [1,1,S,H/2] q = self._apply_rotary(q, cos, sin) k = self._apply_rotary(k, cos, sin) return q, k @staticmethod def _apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: x1 = x[..., ::2] x2 = x[..., 1::2] x_rot_even = x1 * cos - x2 * sin x_rot_odd = x1 * sin + x2 * cos out = torch.stack((x_rot_even, x_rot_odd), dim=-1).flatten(-2) return out class CausalSelfAttention(nn.Module): """ Multi-head causal self-attention for autoregressive code generation. """ def __init__(self, config: ModelConfig) -> None: super().__init__() self.n_heads = config.n_heads self.head_dim = config.head_dim self.scale = self.head_dim ** -0.5 self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.dropout = nn.Dropout(config.dropout) self.rotary = RotaryEmbedding(head_dim=self.head_dim, max_seq_len=config.max_seq_len) def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: bsz, seq_len, _ = x.shape q = self.q_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) q, k = self.rotary(q, k, seq_len=seq_len) # Use PyTorch scaled dot-product attention with causal masking. out = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.dropout.p if self.training else 0.0, is_causal=True, scale=self.scale, ) out = out.transpose(1, 2).contiguous().view(bsz, seq_len, -1) return self.o_proj(out) class FeedForward(nn.Module): """ Two-layer feed-forward network with GELU activation. """ def __init__(self, config: ModelConfig) -> None: super().__init__() self.fc1 = nn.Linear(config.d_model, config.d_ff, bias=False) self.fc2 = nn.Linear(config.d_ff, config.d_model, bias=False) self.dropout = nn.Dropout(config.dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = F.gelu(x, approximate="tanh") x = self.fc2(x) x = self.dropout(x) return x class TransformerBlock(nn.Module): """ One transformer block: norm -> attention -> residual norm -> feed-forward -> residual """ def __init__(self, config: ModelConfig) -> None: super().__init__() self.norm1 = RMSNorm(config.d_model, eps=config.rms_norm_eps) self.attn = CausalSelfAttention(config) self.norm2 = RMSNorm(config.d_model, eps=config.rms_norm_eps) self.ffn = FeedForward(config) def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: x = x + self.attn(self.norm1(x), attn_mask=attn_mask) x = x + self.ffn(self.norm2(x)) return x class CodeTransformerLM(nn.Module): """ Full decoder-only language model for code generation. """ def __init__(self, config: ModelConfig) -> None: super().__init__() self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.dropout = nn.Dropout(config.dropout) self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) self.norm_final = RMSNorm(config.d_model, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) if config.tie_embeddings: self.lm_head.weight = self.embed_tokens.weight self.apply(self._init_weights) def _init_weights(self, module: nn.Module) -> None: # Keep initialization stable for deep networks. if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) def enable_gradient_checkpointing(self, enabled: bool = True) -> None: # Toggle gradient checkpointing mode. self.config.gradient_checkpointing = enabled def forward( self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: if input_ids.dim() != 2: raise ValueError("input_ids must be shape [batch, seq_len].") x = self.embed_tokens(input_ids) x = self.dropout(x) for block in self.blocks: if self.config.gradient_checkpointing and self.training: x = torch.utils.checkpoint.checkpoint(block, x, attn_mask, use_reentrant=False) else: x = block(x, attn_mask=attn_mask) x = self.norm_final(x) logits = self.lm_head(x) out: Dict[str, torch.Tensor] = {"logits": logits} if labels is not None: # Standard next-token cross entropy loss. shift_logits = logits[:, :-1, :].contiguous() shift_labels = labels[:, 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100, ) out["loss"] = loss return out def estimate_num_parameters(self) -> int: # Returns total trainable parameter count. return sum(p.numel() for p in self.parameters() if p.requires_grad) def summary(self) -> Dict[str, object]: # Returns a simple structured summary for logs/CLI. return { "config": asdict(self.config), "num_parameters": self.estimate_num_parameters(), }