""" GPT-300M Model Architecture ============================ A decoder-only transformer built entirely from scratch in PyTorch. Architecture features: - Pre-LayerNorm transformer blocks - Rotary Position Embeddings (RoPE) - Multi-Head Self-Attention with causal masking - GELU activation in feed-forward layers - Optional weight tying (token embeddings ↔ LM head) - KV-Cache for efficient autoregressive generation - Flash Attention support (PyTorch 2.0+) """ import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from config import GPT300MConfig # ═══════════════════════════════════════════════════════════════════════ # ROTARY POSITION EMBEDDINGS (RoPE) # ═══════════════════════════════════════════════════════════════════════ class RotaryEmbedding(nn.Module): """Rotary Position Embedding (Su et al., 2021).""" def __init__(self, dim: int, max_seq_len: int = 2048, theta: float = 10000.0): super().__init__() inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Precompute cos/sin tables t = torch.arange(max_seq_len, dtype=torch.float32) freqs = torch.outer(t, inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.register_buffer("cos_cached", emb.cos(), persistent=False) self.register_buffer("sin_cached", emb.sin(), persistent=False) def forward(self, seq_len: int, offset: int = 0): return ( self.cos_cached[offset : offset + seq_len], self.sin_cached[offset : offset + seq_len], ) def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotate the second half of the last dimension.""" x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def apply_rotary_emb( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to query and key tensors.""" # cos/sin shape: [seq_len, head_dim] → [1, 1, seq_len, head_dim] cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) q_rot = q * cos + rotate_half(q) * sin k_rot = k * cos + rotate_half(k) * sin return q_rot, k_rot # ═══════════════════════════════════════════════════════════════════════ # RMSNORM (faster alternative to LayerNorm) # ═══════════════════════════════════════════════════════════════════════ class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (x.float() * norm).type_as(x) * self.weight # ═══════════════════════════════════════════════════════════════════════ # MULTI-HEAD SELF-ATTENTION # ═══════════════════════════════════════════════════════════════════════ class MultiHeadAttention(nn.Module): """Multi-Head Self-Attention with causal masking and optional KV-cache.""" def __init__(self, config: GPT300MConfig): super().__init__() self.n_heads = config.n_heads self.head_dim = config.head_dim self.d_model = config.d_model self.dropout = config.dropout # Q, K, V projections (fused for efficiency) self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias) # Output projection self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias) self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) # Check for Flash Attention support self.flash_attn = hasattr(F, "scaled_dot_product_attention") def forward( self, x: torch.Tensor, cos: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: B, T, C = x.shape # Project to Q, K, V qkv = self.qkv_proj(x) q, k, v = qkv.split(self.d_model, dim=-1) # Reshape: [B, T, n_heads, head_dim] → [B, n_heads, T, head_dim] q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # Apply RoPE if cos is not None and sin is not None: q, k = apply_rotary_emb(q, k, cos, sin) # KV-Cache for generation if kv_cache is not None: k_prev, v_prev = kv_cache k = torch.cat([k_prev, k], dim=2) v = torch.cat([v_prev, v], dim=2) new_cache = (k, v) if use_cache else None # Attention if self.flash_attn and not use_cache: # Use PyTorch's efficient SDPA attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True if mask is None else False, ) else: # Manual attention for compatibility / KV-cache scale = 1.0 / math.sqrt(self.head_dim) scores = torch.matmul(q, k.transpose(-2, -1)) * scale if mask is not None: scores = scores.masked_fill(mask == 0, float("-inf")) else: # Causal mask T_q, T_k = q.size(2), k.size(2) causal = torch.tril(torch.ones(T_q, T_k, device=x.device, dtype=torch.bool)) # For KV-cache, the causal mask must align with key length causal = causal[-T:, :] # last T rows scores = scores.masked_fill(~causal.unsqueeze(0).unsqueeze(0), float("-inf")) attn_weights = F.softmax(scores, dim=-1) attn_weights = self.attn_dropout(attn_weights) attn_out = torch.matmul(attn_weights, v) # Reshape back and project attn_out = attn_out.transpose(1, 2).contiguous().view(B, -1, self.d_model) out = self.resid_dropout(self.out_proj(attn_out)) return out, new_cache # ═══════════════════════════════════════════════════════════════════════ # FEED-FORWARD NETWORK # ═══════════════════════════════════════════════════════════════════════ class FeedForward(nn.Module): """Position-wise Feed-Forward Network with GELU activation.""" def __init__(self, config: GPT300MConfig): super().__init__() self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias) self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=config.bias) self.dropout = nn.Dropout(config.dropout) if config.activation == "gelu": self.act = nn.GELU() elif config.activation == "swiglu": self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=config.bias) self.act = nn.SiLU() else: raise ValueError(f"Unknown activation: {config.activation}") self.use_swiglu = config.activation == "swiglu" def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_swiglu: return self.dropout(self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))) else: return self.dropout(self.down_proj(self.act(self.up_proj(x)))) # ═══════════════════════════════════════════════════════════════════════ # TRANSFORMER BLOCK # ═══════════════════════════════════════════════════════════════════════ class TransformerBlock(nn.Module): """Pre-norm Transformer block: LayerNorm → Attention → Residual → LayerNorm → FFN → Residual.""" def __init__(self, config: GPT300MConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.ln1 = RMSNorm(config.d_model, eps=config.norm_eps) self.attn = MultiHeadAttention(config) self.ln2 = RMSNorm(config.d_model, eps=config.norm_eps) self.ffn = FeedForward(config) def forward( self, x: torch.Tensor, cos: Optional[torch.Tensor] = None, sin: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Pre-norm attention with residual residual = x x = self.ln1(x) attn_out, new_cache = self.attn(x, cos, sin, mask, kv_cache, use_cache) x = residual + attn_out # Pre-norm FFN with residual x = x + self.ffn(self.ln2(x)) return x, new_cache # ═══════════════════════════════════════════════════════════════════════ # GPT-300M: THE FULL MODEL # ═══════════════════════════════════════════════════════════════════════ class GPT300M(nn.Module): """ GPT-300M: A 300-million parameter autoregressive language model. Architecture: Token Embedding → [Transformer Block × 24] → RMSNorm → LM Head Each Transformer Block: RMSNorm → Multi-Head Attention (+ RoPE) → Residual → RMSNorm → Feed-Forward (GELU) → Residual """ def __init__(self, config: GPT300MConfig): super().__init__() self.config = config # ── Embeddings ─────────────────────────────────────────────── self.token_emb = nn.Embedding(config.vocab_size, config.d_model) self.drop = nn.Dropout(config.dropout) # Rotary embeddings if config.rope: self.rotary = RotaryEmbedding( config.head_dim, config.max_seq_len, config.rope_theta ) else: self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model) # ── Transformer Blocks ─────────────────────────────────────── self.layers = nn.ModuleList([ TransformerBlock(config, layer_idx=i) for i in range(config.n_layers) ]) # ── Output ─────────────────────────────────────────────────── self.ln_f = RMSNorm(config.d_model, eps=config.norm_eps) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Weight tying if config.tie_weights: self.lm_head.weight = self.token_emb.weight # Initialize weights self.apply(self._init_weights) # Scale residual projections for pn, p in self.named_parameters(): if pn.endswith("out_proj.weight") or pn.endswith("down_proj.weight"): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers)) def _init_weights(self, module: nn.Module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward( self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None, kv_caches: Optional[list] = None, use_cache: bool = False, position_offset: int = 0, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[list]]: """ Forward pass. Args: input_ids: [B, T] token indices targets: [B, T] target token indices for loss computation kv_caches: List of KV-cache tuples, one per layer use_cache: Whether to return updated KV-caches position_offset: Offset for position embeddings (for KV-cache generation) Returns: logits: [B, T, vocab_size] loss: scalar loss if targets provided, else None new_caches: Updated KV-caches if use_cache=True """ B, T = input_ids.shape assert T <= self.config.max_seq_len, ( f"Sequence length {T} exceeds max {self.config.max_seq_len}" ) # Token embeddings x = self.token_emb(input_ids) # [B, T, d_model] # Position information if self.config.rope: cos, sin = self.rotary(T, offset=position_offset) else: positions = torch.arange(position_offset, position_offset + T, device=input_ids.device) x = x + self.pos_emb(positions) cos, sin = None, None x = self.drop(x) # Transformer blocks new_caches = [] if use_cache else None for i, layer in enumerate(self.layers): cache_i = kv_caches[i] if kv_caches is not None else None x, new_cache = layer(x, cos, sin, kv_cache=cache_i, use_cache=use_cache) if use_cache: new_caches.append(new_cache) # Final norm and LM head x = self.ln_f(x) logits = self.lm_head(x) # [B, T, vocab_size] # Loss loss = None if targets is not None: loss = F.cross_entropy( logits.view(-1, self.config.vocab_size), targets.view(-1), ignore_index=self.config.pad_token_id, ) return logits, loss, new_caches @torch.no_grad() def generate( self, input_ids: torch.Tensor, max_new_tokens: int = 256, temperature: float = 0.7, top_k: int = 50, top_p: float = 0.9, repetition_penalty: float = 1.1, eos_token_id: Optional[int] = None, ) -> torch.Tensor: """ Autoregressive generation with KV-cache. Args: input_ids: [B, T] prompt token IDs max_new_tokens: Maximum number of tokens to generate temperature: Sampling temperature top_k: Top-k sampling top_p: Nucleus sampling threshold repetition_penalty: Penalty for repeated tokens eos_token_id: Stop generation when this token is produced Returns: [B, T + max_new_tokens] generated token IDs """ self.eval() B, T = input_ids.shape device = input_ids.device # Initial forward pass to populate KV-cache logits, _, kv_caches = self.forward(input_ids, use_cache=True) generated = input_ids all_token_ids = input_ids.tolist()[0] if B == 1 else [] for step in range(max_new_tokens): # Get logits for the last token next_logits = logits[:, -1, :] # [B, vocab_size] # Repetition penalty if repetition_penalty != 1.0 and B == 1: for token_id in set(all_token_ids): if next_logits[0, token_id] > 0: next_logits[0, token_id] /= repetition_penalty else: next_logits[0, token_id] *= repetition_penalty # Temperature if temperature > 0: next_logits = next_logits / temperature # Top-k filtering if top_k > 0: topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1))) next_logits[next_logits < topk_vals[:, -1:]] = float("-inf") # Top-p (nucleus) filtering if top_p < 1.0: sorted_logits, sorted_idx = torch.sort(next_logits, descending=True) cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_mask = cumprobs - F.softmax(sorted_logits, dim=-1) >= top_p sorted_logits[sorted_mask] = float("-inf") next_logits = sorted_logits.scatter(1, sorted_idx, sorted_logits) probs = F.softmax(next_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: # Greedy next_token = next_logits.argmax(dim=-1, keepdim=True) generated = torch.cat([generated, next_token], dim=1) if B == 1: all_token_ids.append(next_token.item()) # Stop on EOS if eos_token_id is not None and next_token.item() == eos_token_id: break # Forward pass with KV-cache (only the new token) position_offset = generated.size(1) - 1 logits, _, kv_caches = self.forward( next_token, kv_caches=kv_caches, use_cache=True, position_offset=position_offset, ) return generated def count_parameters(self, trainable_only: bool = True) -> int: """Count model parameters.""" if trainable_only: return sum(p.numel() for p in self.parameters() if p.requires_grad) return sum(p.numel() for p in self.parameters()) def model_summary(self) -> str: """Print a human-readable model summary.""" total = self.count_parameters(trainable_only=False) trainable = self.count_parameters(trainable_only=True) lines = [ "=" * 60, " GPT-300M Model Summary", "=" * 60, f" Total parameters: {total:>15,}", f" Trainable parameters: {trainable:>15,}", f" d_model: {self.config.d_model:>15}", f" n_heads: {self.config.n_heads:>15}", f" n_layers: {self.config.n_layers:>15}", f" d_ff: {self.config.d_ff:>15}", f" vocab_size: {self.config.vocab_size:>15}", f" max_seq_len: {self.config.max_seq_len:>15}", f" RoPE: {'Yes':>15}", f" Weight tying: {'Yes' if self.config.tie_weights else 'No':>15}", f" Flash Attention: {'Yes' if self.layers[0].attn.flash_attn else 'No':>15}", "=" * 60, ] return "\n".join(lines) # ═══════════════════════════════════════════════════════════════════════ # QUICK TEST # ═══════════════════════════════════════════════════════════════════════ if __name__ == "__main__": from config import gpt_tiny # Use tiny config for testing cfg = gpt_tiny() model = GPT300M(cfg) print(model.model_summary()) # Test forward pass x = torch.randint(0, cfg.vocab_size, (2, 32)) targets = torch.randint(0, cfg.vocab_size, (2, 32)) logits, loss, _ = model(x, targets=targets) print(f"\nForward pass OK: logits={logits.shape}, loss={loss.item():.4f}") # Test generation prompt = torch.randint(0, cfg.vocab_size, (1, 8)) gen = model.generate(prompt, max_new_tokens=16, temperature=0.8) print(f"Generation OK: {gen.shape}")