import math

import torch
import torch.nn as nn


class MaskedMultiHeadedSelfAttention(nn.Module):
    """Causal multi-head self-attention with a fused QKV projection.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: Accepted for interface compatibility. It is not used
            here because causal masking is delegated to
            ``scaled_dot_product_attention(is_causal=True)``.
        dropout: Attention dropout probability (applied only in training mode).
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Fused QKV projection: one matmul instead of three.
        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)
        # Only the probability ``p`` of this module is read in ``forward``;
        # the dropout itself is applied inside the attention kernel.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch, tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, tokens, d_out)``.
        """
        batch_size, num_tokens, _ = x.shape

        # Single fused projection, then split into q, k, v: each (B, T, d_out).
        qkv = self.W_qkv(x)
        q, k, v = qkv.split(self.d_out, dim=-1)

        # (B, T, d_out) -> (B, num_heads, T, head_dim)
        head_shape = (batch_size, num_tokens, self.num_heads, self.head_dim)
        q = q.view(head_shape).transpose(1, 2)
        k = k.view(head_shape).transpose(1, 2)
        v = v.view(head_shape).transpose(1, 2)

        # No manual mask: ``is_causal=True`` applies the causal masking.
        # Dropout is disabled outside training mode.
        context_vec = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=True,
        )

        # Merge heads back to (B, T, d_out) and apply the output projection.
        context_vec = context_vec.transpose(1, 2).reshape(
            batch_size, num_tokens, self.d_out
        )
        return self.out_proj(context_vec)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> 4*dim), tanh-approximated GELU, Linear(4*dim -> dim).

    Args:
        configuration: Mapping that provides ``"embedding_dim"``.
    """

    def __init__(self, configuration):
        super().__init__()
        dim = configuration["embedding_dim"]
        self.layers = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(approximate='tanh'),  # built-in tanh approximation of GELU
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on ``x``; the last dimension is preserved."""
        return self.layers(x)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Args:
        configuration: Mapping that provides ``"embedding_dim"``,
            ``"context_length"``, ``"n_heads"``, ``"dropout_rate"`` and
            ``"qkv_bias"``.
    """

    def __init__(self, configuration):
        super().__init__()
        embedding_dim = configuration["embedding_dim"]
        self.attention = MaskedMultiHeadedSelfAttention(
            d_in=embedding_dim,
            d_out=embedding_dim,
            context_length=configuration["context_length"],
            num_heads=configuration["n_heads"],
            dropout=configuration["dropout_rate"],
            qkv_bias=configuration["qkv_bias"],
        )
        self.feed_forward = FeedForward(configuration)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        # The same dropout module is applied to both sub-layer outputs.
        self.drop_shortcut = nn.Dropout(configuration["dropout_rate"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual sub-layers; output has the same shape as ``x``."""
        # Attention sub-layer with residual connection.
        shortcut = x
        x = self.norm1(x)
        x = self.attention(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        # Feed-forward sub-layer with residual connection.
        shortcut = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        return x


class LanguageModel(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Args:
        configuration: Mapping that provides ``"vocab_size"``,
            ``"embedding_dim"``, ``"context_length"``, ``"dropout_rate"``,
            ``"n_layers"``, ``"n_heads"`` and ``"qkv_bias"``.
    """

    def __init__(self, configuration):
        super().__init__()
        self.config = configuration
        embedding_dim = configuration["embedding_dim"]
        vocab_size = configuration["vocab_size"]
        n_layers = configuration["n_layers"]

        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(
            configuration["context_length"], embedding_dim
        )
        self.drop_embedding = nn.Dropout(configuration["dropout_rate"])

        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(configuration) for _ in range(n_layers)]
        )

        self.final_norm = nn.LayerNorm(embedding_dim)
        self.out_head = nn.Linear(embedding_dim, vocab_size, bias=False)

        # Weight tying: the output projection shares its weight with the
        # token embedding.
        self.out_head.weight = self.token_embedding.weight

        # GPT-2 style weight initialization.
        self.apply(self._init_weights)

        # Scale residual-path projections by 1/sqrt(2 * n_layers) to keep the
        # residual stream from blowing up in deep networks. Targets:
        #   - attention out_proj  (name ends with 'out_proj.weight')
        #   - FFN second linear   (name ends with 'layers.2.weight')
        residual_std = 0.02 / math.sqrt(2 * n_layers)
        residual_suffixes = ('out_proj.weight', 'layers.2.weight')
        for name, param in self.named_parameters():
            if name.endswith(residual_suffixes):
                torch.nn.init.normal_(param, mean=0.0, std=residual_std)

    def _init_weights(self, module):
        """Initialize one submodule following GPT-2.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std 0.02; ``nn.Linear`` biases, when
        present, are zeroed. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, in_idx: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            in_idx: Integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            IndexError: If ``seq_len`` exceeds the number of learned
                positions (``context_length``).
        """
        _, seq_len = in_idx.shape

        # Fail early with a clear message instead of an opaque out-of-range
        # lookup inside the positional embedding.
        max_positions = self.pos_embedding.num_embeddings
        if seq_len > max_positions:
            raise IndexError(
                f"sequence length {seq_len} exceeds context_length {max_positions}"
            )

        positions = torch.arange(seq_len, device=in_idx.device)
        # (B, T, D) token embeddings + (T, D) position embeddings (broadcast).
        x = self.token_embedding(in_idx) + self.pos_embedding(positions)
        x = self.drop_embedding(x)
        x = self.transformer_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)