Scholar-Sage / model /transformer_explained.py
TheCodeKat's picture
Update model/transformer_explained.py
5129187 verified
# model/transformer_explained.py
"""
Tiny Transformer language model (educational).
Components:
- PositionalEncoding: sinusoidal positional encodings (buffered)
- MultiHeadSelfAttention: optionally returns attention weights
- FeedForward: MLP with GELU
- TransformerBlock: pre-norm block (LayerNorm -> attention -> residual add, then LayerNorm -> FFN -> residual add)
- TinyTransformerLM: token embeddings, pos enc, stacked blocks, LM head
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding as in "Attention is All You Need".

    The (1, max_len, d_model) table is computed once in ``__init__`` and stored
    as a buffer (not a learned parameter). ``forward`` adds its first
    ``seq_len`` rows to the input embeddings.

    Args:
        d_model: embedding width. Odd widths are supported: even-indexed
            columns hold sines and odd-indexed columns hold cosines, so the
            last column of an odd width is a sine with no cosine partner.
        max_len: longest sequence length the table covers.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies 10000^(-2i / d_model) for every even column index 2i.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # trim div_term to keep the cos assignment shape-compatible.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)  # not a parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        returns: x + pe[:, :seq_len, :]
        raises: ValueError if seq_len exceeds the max_len the table was built with.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the slice below silently comes back shorter than
            # seq_len and the addition fails with an opaque broadcast error.
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :].to(x.device)
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention.
    Optionally returns attention weights for visualization.

    Input shapes:
        x: (batch, seq_len, d_model)
        mask (optional): nonzero / True = may attend, zero / False = blocked.
            (batch, seq_len)           key-padding mask -> (batch, 1, 1, seq_len)
            (batch, seq_len, seq_len)  per-example mask -> (batch, 1, seq_len, seq_len)
            4-D                        must broadcast to (batch, num_heads, seq_len, seq_len),
                                       e.g. a (1, 1, seq_len, seq_len) causal mask
    Output:
        out: (batch, seq_len, d_model)
    Optional:
        attn: (batch, num_heads, seq_len, seq_len). Query rows whose keys are
            all blocked get all-zero weights instead of NaN.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # single linear for qkv then split
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _keep_mask(mask: torch.Tensor) -> torch.Tensor:
        """Turn a user mask into a boolean keep-mask broadcastable to (B, H, S, S)."""
        keep = mask != 0
        if keep.dim() == 2:
            # (B, S) key padding: the same keys are visible to every query and head.
            keep = keep[:, None, None, :]
        elif keep.dim() == 3:
            # (B, S, S): insert the head axis so batch lines up with batch, not heads.
            keep = keep.unsqueeze(1)
        return keep

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (batch, seq_len, d_model)
        mask: optional attention mask; see the class docstring for accepted shapes
        return_attn: if True, also return attention weights (after dropout)
        returns: (out, attn or None)
        """
        B, S, D = x.shape
        # project, split into q, k, v and move heads forward: each (B, H, S, d_k)
        qkv = self.qkv_proj(x).view(B, S, 3, self.num_heads, self.d_k)
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
        # scaled dot-product attention scores: (B, H, S, S)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        keep = None
        if mask is not None:
            keep = self._keep_mask(mask)
            # Most negative finite value instead of -inf: a row with every key
            # blocked would otherwise softmax to NaN. For partially blocked rows
            # exp() still underflows to exactly 0, same as with -inf.
            attn_scores = attn_scores.masked_fill(~keep, torch.finfo(attn_scores.dtype).min)
        attn = self.softmax(attn_scores)  # (B, H, S, S)
        if keep is not None:
            # Fully blocked rows came out uniform; make them attend to nothing.
            attn = attn.masked_fill(~keep.any(dim=-1, keepdim=True), 0.0)
        attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)  # (B, H, S, d_k)
        out = out.transpose(1, 2).contiguous().view(B, S, D)  # merge heads
        out = self.out_proj(out)  # final linear
        return out, (attn if return_attn else None)
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer.

    Applied independently at every position:
    ``Dropout(Linear_out(Dropout(GELU(Linear_in(x)))))``, widening from
    ``d_model`` to ``d_ff`` and projecting back to ``d_model``.

    Args:
        d_model: input and output width.
        d_ff: hidden (expansion) width.
        dropout: drop probability used after the activation and after the
            output projection.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Create the two projections in the same order as the layer stack so
        # parameter initialisation consumes the RNG identically.
        widen = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        # Keep one Sequential named `net` so state_dict keys stay net.0.* / net.3.*.
        self.net = nn.Sequential(
            widen,
            nn.GELU(),
            nn.Dropout(dropout),
            narrow,
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP over the last dimension of ``x``; the output has x's shape."""
        projected = self.net(x)
        return projected
class TransformerBlock(nn.Module):
    """A single pre-norm Transformer block.

        x = x + Dropout(MHSA(LN1(x)))
        x = x + FFN(LN2(x))          # FeedForward ends in its own Dropout

    LayerNorm is applied to each sub-layer's input (pre-norm), not after the
    residual add (post-norm). Both residual branches are dropout-regularized.

    Args:
        d_model: model width.
        num_heads: attention heads (must divide d_model).
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: probability used for attention weights, the attention
            residual branch and inside the feed-forward sub-layer.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        # Residual dropout on the attention output. The FFN output already passes
        # through FeedForward's final Dropout, so this makes the two branches match.
        self.resid_dropout = nn.Dropout(dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (batch, seq_len, d_model)
        mask: forwarded unchanged to the attention sub-layer
        return_attn: if True, also return the attention weights
        returns: (x, attn weights or None)
        """
        # Pre-norm attention sub-layer with residual connection.
        attn_out, attn_weights = self.attn(self.ln1(x), mask=mask, return_attn=return_attn)
        x = x + self.resid_dropout(attn_out)
        # Pre-norm feed-forward sub-layer with residual connection.
        x = x + self.ff(self.ln2(x))
        return x, (attn_weights if return_attn else None)
class TinyTransformerLM(nn.Module):
    """
    Tiny Transformer language model for educational training/experiments.
    Not tokenizer-dependent; expects token ids.

    Pipeline: token embedding -> sinusoidal positional encoding ->
    ``n_layers`` TransformerBlocks -> final LayerNorm -> linear LM head.

    Args:
        vocab_size: number of token ids (embedding rows and logits width).
        d_model: model width.
        n_layers: number of stacked TransformerBlocks.
        num_heads: attention heads per block.
        d_ff: feed-forward hidden width.
        max_len: longest input sequence the positional encoding covers.
        dropout: dropout probability passed to every block.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_layers: int = 4,
        num_heads: int = 4,
        d_ff: int = 1024,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len  # longest sequence forward() accepts
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)  # logits head

    def _resolve_layer_index(self, layer: Optional[int]) -> Optional[int]:
        """Map a possibly negative block index into 0..n_layers-1 (None passes through)."""
        if layer is None:
            return None
        n_layers = len(self.layers)
        idx = layer + n_layers if layer < 0 else layer
        if not 0 <= idx < n_layers:
            raise IndexError(f"return_attn_layer={layer} is out of range for {n_layers} layers")
        return idx

    @staticmethod
    def _with_causal(
        mask: Optional[torch.Tensor], seq_len: int, device: torch.device
    ) -> torch.Tensor:
        """AND a lower-triangular causal mask into ``mask`` (nonzero = may attend)."""
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
        causal = causal[None, None]  # (1, 1, S, S)
        if mask is None:
            return causal
        keep = mask != 0
        if keep.dim() == 2:  # (B, S) key-padding mask -> (B, 1, 1, S)
            keep = keep[:, None, None, :]
        elif keep.dim() == 3:  # (B, S, S) -> (B, 1, S, S)
            keep = keep.unsqueeze(1)
        return keep & causal

    def forward(
        self,
        input_ids: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
        return_attn_layer: Optional[int] = None,
        causal: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        input_ids: (B, S) token ids with S <= max_len
        mask: optional attention mask (nonzero = may attend), passed to every block
        return_attn_layer: index of the block whose attention weights are returned;
            negative values count from the end (-1 = last block)
        causal: if True, also block attention to future positions (combined with
            ``mask`` by logical AND). Defaults to False for backward compatibility.
        returns: (logits (B, S, vocab_size), attention weights or None)
        raises: ValueError if S > max_len; IndexError if return_attn_layer is out of range
        """
        seq_len = input_ids.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.max_len}")
        target = self._resolve_layer_index(return_attn_layer)
        if causal:
            mask = self._with_causal(mask, seq_len, input_ids.device)

        x = self.pos_enc(self.tok_emb(input_ids))  # (B, S, d_model)
        attn_weights = None
        for idx, layer in enumerate(self.layers):
            want = idx == target
            x, weights = layer(x, mask=mask, return_attn=want)
            if want:
                attn_weights = weights
        logits = self.head(self.ln_f(x))  # (B, S, vocab_size)
        return logits, attn_weights