# model/transformer_explained.py
"""
Tiny Transformer language model (educational).

Components:
- PositionalEncoding: sinusoidal positional encodings (buffered)
- MultiHeadSelfAttention: returns attn weights optionally
- FeedForward: MLP with GELU
- TransformerBlock: pre-norm block (LayerNorm -> attention -> residual add,
  then LayerNorm -> FFN -> residual add)
- TinyTransformerLM: token embeddings, pos enc, stacked blocks, LM head
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding as in "Attention is All You Need".

    Stored as a buffer (not learned). Adds positional encodings to token
    embeddings. Even feature indices hold sines, odd feature indices hold
    cosines; odd ``d_model`` values are supported.

    Args:
        d_model: feature dimension of the embeddings the encoding is added to.
        max_len: maximum sequence length that can be encoded.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies feed the cosine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)  # not a parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x``.

        Args:
            x: (batch, seq_len, d_model)

        Returns:
            x + pe[:, :seq_len, :]

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the slice silently returns fewer rows and the
            # addition fails with an unhelpful broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :].to(x.device)


class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention. Optionally returns attention weights for visualization.

    Input shapes:
        x: (batch, seq_len, d_model)
    Output:
        out: (batch, seq_len, d_model)
        Optional: attn: (batch, num_heads, seq_len, seq_len)

    Raises:
        ValueError: if d_model is not divisible by num_heads.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # single linear for qkv then split
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (batch, seq_len, d_model)
        mask: any tensor broadcastable to (batch, num_heads, seq_len, seq_len);
            positions where ``mask == 0`` are blocked. Examples: a (seq_len, seq_len)
            causal mask, a (batch, 1, seq_len, seq_len) mask, or a
            (batch, 1, 1, seq_len) key-padding mask.
        return_attn: if True, also return attention weights

        Query rows whose keys are all blocked receive all-zero attention weights
        (and therefore a zero output before ``out_proj``) instead of NaN.
        """
        B, S, D = x.shape

        # project and split into q,k,v
        qkv = self.qkv_proj(x)  # (B, S, 3*D)
        qkv = qkv.view(B, S, 3, self.num_heads, self.d_k)
        q, k, v = qkv.unbind(dim=2)  # each: (B, S, num_heads, d_k)

        # transpose to (B, num_heads, S, d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # scaled dot-product attention
        # attn_scores: (B, num_heads, S, S)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        blocked = None
        if mask is not None:
            blocked = mask == 0
            attn_scores = attn_scores.masked_fill(blocked, float("-inf"))

        attn = self.softmax(attn_scores)  # (B, num_heads, S, S)
        if blocked is not None:
            # softmax over a row that is entirely -inf yields NaN; zeroing the
            # blocked entries turns such rows into all-zeros and leaves every
            # other row unchanged (its blocked entries are already exactly 0).
            attn = attn.masked_fill(blocked, 0.0)
        attn = self.attn_dropout(attn)

        # attn @ v -> (B, num_heads, S, d_k)
        out = torch.matmul(attn, v)

        # transpose & combine heads -> (B, S, D)
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        out = self.out_proj(out)  # final linear

        if return_attn:
            return out, attn
        return out, None


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP independently at each position; output shape equals input shape."""
        return self.net(x)


class TransformerBlock(nn.Module):
    """A single pre-norm Transformer block.

    x = x + Attention(LayerNorm(x))
    x = x + FFN(LayerNorm(x))
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, return_attn: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run one block; ``mask`` and ``return_attn`` are forwarded to the attention layer.

        Returns:
            (x, attn_weights) where attn_weights is None unless return_attn is True.
        """
        # Pre-norm style: ln -> attn -> add.
        # NOTE: the attention output is added without a separate residual
        # dropout; FeedForward applies its own dropout internally.
        z = self.ln1(x)
        attn_out, attn_weights = self.attn(z, mask=mask, return_attn=return_attn)
        x = x + attn_out

        # FFN
        z2 = self.ln2(x)
        ff_out = self.ff(z2)
        x = x + ff_out

        if return_attn:
            return x, attn_weights
        return x, None


class TinyTransformerLM(nn.Module):
    """
    Tiny Transformer language model for educational training/experiments.
    Not tokenizer-dependent; expects token ids.

    Note: no causal masking is applied unless the caller passes a causal
    ``mask`` or ``causal=True`` to ``forward``.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_layers: int = 4,
        num_heads: int = 4,
        d_ff: int = 1024,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len)
        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, dropout) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)  # logits head

    @staticmethod
    def causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return a (seq_len, seq_len) boolean lower-triangular attention mask.

        Entry [i, j] is True when query position i may attend to key position j,
        i.e. when j <= i. The result broadcasts against (batch, heads, seq, seq).
        """
        return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

    def forward(
        self,
        input_ids: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
        return_attn_layer: Optional[int] = None,
        *,
        causal: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        input_ids: (B, S)
        mask: optional mask broadcastable to (B, num_heads, S, S); entries == 0 are blocked.
        return_attn_layer: index of the layer whose attention weights to return
            (heads included). Negative indices count from the last layer; an
            out-of-range index returns None for the weights.
        causal: if True, additionally block attention to future positions by
            combining ``mask`` (if any) with a lower-triangular mask.

        returns: (logits (B, S, vocab_size), attention weights or None)
        """
        seq_len = input_ids.size(1)

        if causal:
            allowed = self.causal_mask(seq_len, device=input_ids.device)
            mask = allowed if mask is None else (mask != 0) & allowed

        # Resolve negative indices the way Python sequences do.
        target_layer = return_attn_layer
        if target_layer is not None and target_layer < 0:
            target_layer += len(self.layers)

        x = self.tok_emb(input_ids)  # (B, S, d_model)
        x = self.pos_enc(x)

        attn_weights = None
        for idx, layer in enumerate(self.layers):
            wanted = idx == target_layer
            x, layer_attn = layer(x, mask=mask, return_attn=wanted)
            if wanted:
                attn_weights = layer_attn

        x = self.ln_f(x)
        logits = self.head(x)  # (B, S, vocab_size)
        return logits, attn_weights