# LikeGPT2small / Architecture.py
# Zemulax — files for inference (revision ea2eee0, verified)
import torch
import torch.nn as nn
import math
class MaskedMultiHeadedSelfAttention(nn.Module):
    """Causal multi-head self-attention built on ``scaled_dot_product_attention``.

    Q, K and V are produced by one fused linear layer and split afterwards.
    Causality comes from ``is_causal=True`` instead of a stored mask, so
    ``context_length`` is accepted for interface compatibility but not used.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total attention width; must be divisible by ``num_heads``.
        context_length: unused (kept so existing callers keep working).
        dropout: attention-weight dropout probability, applied only while
            the module is in training mode.
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection carries a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out, self.num_heads = d_out, num_heads
        self.head_dim = d_out // num_heads
        # One matmul yields Q, K and V concatenated along the last dimension.
        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out, bias=False)
        # Only its probability ``p`` is read; SDPA applies the dropout itself.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend causally over ``x`` of shape (batch, seq_len, d_in).

        Returns a tensor of shape (batch, seq_len, d_out).
        """
        bsz, seq_len, _ = x.shape
        head_shape = (bsz, seq_len, self.num_heads, self.head_dim)
        # (B, T, 3*d_out) -> three tensors of shape (B, num_heads, T, head_dim)
        q, k, v = (
            part.view(head_shape).transpose(1, 2)
            for part in self.W_qkv(x).split(self.d_out, dim=-1)
        )
        attn_dropout = self.dropout.p if self.training else 0.0
        heads_out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=attn_dropout, is_causal=True
        )
        # (B, num_heads, T, head_dim) -> (B, T, d_out), then mix the heads.
        merged = heads_out.transpose(1, 2).reshape(bsz, seq_len, self.d_out)
        return self.out_proj(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> 4*dim), tanh-approximate GELU, Linear(4*dim -> dim).

    Args:
        configuration: dict; only ``configuration["embedding_dim"]`` is read.
    """

    def __init__(self, configuration):
        super().__init__()
        width = configuration["embedding_dim"]
        hidden = 4 * width
        # The Sequential indices (0, 1, 2) determine the state_dict keys,
        # e.g. ``layers.2.weight`` for the down-projection, so keep this layout.
        stages = [
            nn.Linear(width, hidden),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden, width),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        return self.layers(x)
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Each of the two sub-layers (causal attention, then the MLP) is applied
    as ``x + dropout(sublayer(norm(x)))``.

    Args:
        configuration: dict providing ``embedding_dim``, ``context_length``,
            ``n_heads``, ``dropout_rate`` and ``qkv_bias`` (``FeedForward``
            reads ``embedding_dim`` as well).
    """

    def __init__(self, configuration):
        super().__init__()
        width = configuration["embedding_dim"]
        # Submodules are registered in the same order as before so that
        # parameter naming and initialisation order stay unchanged.
        self.attention = MaskedMultiHeadedSelfAttention(
            d_in=width,
            d_out=width,
            context_length=configuration["context_length"],
            num_heads=configuration["n_heads"],
            dropout=configuration["dropout_rate"],
            qkv_bias=configuration["qkv_bias"],
        )
        self.feed_forward = FeedForward(configuration)
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        # A single Dropout module is reused on both residual branches.
        self.drop_shortcut = nn.Dropout(configuration["dropout_rate"])

    def forward(self, x):
        """Run the attention sub-layer, then the MLP sub-layer, each with a residual add."""
        x = x + self.drop_shortcut(self.attention(self.norm1(x)))
        return x + self.drop_shortcut(self.feed_forward(self.norm2(x)))
class LanguageModel(nn.Module):
    """GPT-2-style decoder-only language model.

    Token embeddings and learned positional embeddings are summed, passed
    through ``n_layers`` ``TransformerBlock``s and a final ``nn.LayerNorm``.
    The result is then projected to vocabulary logits by an output head
    whose weight is tied to the token-embedding matrix.

    Args:
        configuration: dict with ``vocab_size``, ``embedding_dim``,
            ``context_length``, ``dropout_rate`` and ``n_layers`` (plus the
            keys ``TransformerBlock`` reads, e.g. ``n_heads``, ``qkv_bias``).
    """

    def __init__(self, configuration):
        super().__init__()
        self.config = configuration
        self.token_embedding = nn.Embedding(configuration["vocab_size"], configuration["embedding_dim"])
        self.pos_embedding = nn.Embedding(configuration["context_length"], configuration["embedding_dim"])
        self.drop_embedding = nn.Dropout(configuration["dropout_rate"])
        self.transformer_blocks = nn.Sequential(
            *[TransformerBlock(configuration) for _ in range(configuration["n_layers"])]
        )
        self.final_norm = nn.LayerNorm(configuration["embedding_dim"])
        self.out_head = nn.Linear(configuration["embedding_dim"], configuration["vocab_size"], bias=False)
        # Weight tying — the output projection shares its weight tensor with the token embedding.
        self.out_head.weight = self.token_embedding.weight

        # GPT-2 style initialisation: N(0, 0.02) for every Linear/Embedding weight.
        self.apply(self._init_weights)

        # Re-initialise the residual-path projections with std 0.02 / sqrt(2 * n_layers).
        # This keeps the residual stream from growing with depth. It targets:
        #   - attention output projections ('...out_proj.weight')
        #   - the FFN down-projections     ('...layers.2.weight')
        residual_std = 0.02 / math.sqrt(2 * configuration["n_layers"])
        for pn, p in self.named_parameters():
            if pn.endswith(('out_proj.weight', 'layers.2.weight')):
                torch.nn.init.normal_(p, mean=0.0, std=residual_std)

    def _init_weights(self, module):
        """Initialize weights following GPT-2: N(0, 0.02) weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, in_idx):
        """Return logits for a batch of token ids.

        Args:
            in_idx: integer tensor of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the positional-embedding table
                size (``context_length``). This is checked up front because an
                out-of-range index otherwise surfaces as an opaque IndexError on
                CPU or an unrecoverable device-side assert on CUDA.
        """
        _, seq_len = in_idx.shape
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_length {max_len}; "
                "crop the input to the last context_length tokens"
            )
        tok_embeds = self.token_embedding(in_idx)
        pos_embeds = self.pos_embedding(torch.arange(seq_len, device=in_idx.device))
        x = self.drop_embedding(tok_embeds + pos_embeds)
        x = self.transformer_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)