"""
Small transformer model for modular arithmetic experiments.
============================================================
A minimal GPT-style decoder-only transformer designed to:
1. Train from scratch in minutes on a single GPU
2. Expose all internal activations (hidden states, attention patterns)
3. Support checkpoint saving/loading for representation tracking

Architecture matches Nanda et al. 2023 (grokking) configuration
with adjustments for our two-task experiment.
"""
|
|
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
@dataclass
class TransformerConfig:
    """Configuration for the small transformer.

    Values are validated on construction so that an unusable configuration
    fails immediately with a clear message instead of surfacing later as a
    tensor-shape or indexing error inside the model.

    Raises:
        ValueError: if ``vocab_size``, ``d_model``, ``n_heads``, ``d_mlp`` or
            ``max_seq_len`` is not positive, ``n_layers`` is negative,
            ``d_model`` is not divisible by ``n_heads``, or ``dropout`` is
            outside ``[0, 1]``.
    """
    vocab_size: int = 101    # rows of the token embedding / width of the logits
    n_layers: int = 2        # number of transformer blocks
    d_model: int = 128       # width of the residual stream
    n_heads: int = 4         # attention heads per block
    d_mlp: int = 512         # hidden width of each MLP
    max_seq_len: int = 5     # number of learned position embeddings
    dropout: float = 0.0     # dropout probability (attention weights, MLP output)
    layer_norm: bool = True  # LayerNorm when True, nn.Identity when False

    def __post_init__(self) -> None:
        for field_name in ("vocab_size", "d_model", "n_heads", "d_mlp", "max_seq_len"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")
        if self.n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {self.n_layers}")
        # Attention splits d_model evenly across heads; a remainder would make
        # the per-head reshape impossible.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with bias-free Q/K/V/O projections."""

    def __init__(self, config: TransformerConfig):
        """Build the projections from ``config.d_model``, ``config.n_heads``
        and ``config.dropout``.

        Raises:
            ValueError: if ``d_model`` is not divisible by ``n_heads``.
        """
        super().__init__()
        # Floor division would otherwise hide the mismatch until the first
        # forward pass, where the per-head reshape fails with a shape error.
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.d_head = config.d_model // config.n_heads
        self.d_model = config.d_model

        self.W_Q = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_K = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_V = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_O = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor,
                return_attn: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply causal self-attention.

        Args:
            x: input of shape ``(batch, seq, d_model)``.
            return_attn: when True, also return the attention weights.

        Returns:
            ``(out, attn_weights)``. ``out`` has the same shape as ``x``.
            ``attn_weights`` has shape ``(batch, n_heads, seq, seq)`` and is
            the softmax output after dropout has been applied; it is ``None``
            when ``return_attn`` is False.
        """
        B, T, D = x.shape

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (B, T, d_model) -> (B, n_heads, T, d_head)
            return projected.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        Q = split_heads(self.W_Q(x))
        K = split_heads(self.W_K(x))
        V = split_heads(self.W_V(x))

        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)

        # future[i, j] is True when key position j lies after query position i.
        # Those entries get -inf so that softmax assigns them zero weight.
        positions = torch.arange(T, device=x.device)
        future = positions.unsqueeze(0) > positions.unsqueeze(1)
        scores = scores.masked_fill(future, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back: (B, n_heads, T, d_head) -> (B, T, d_model)
        out = (attn_weights @ V).transpose(1, 2).reshape(B, T, D)
        out = self.W_O(out)

        if return_attn:
            return out, attn_weights
        return out, None
|
|
|
|
class MLP(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: TransformerConfig):
        """Create the two linear maps (d_model -> d_mlp -> d_model) and dropout."""
        super().__init__()
        d_model, d_mlp = config.d_model, config.d_mlp
        self.W_in = nn.Linear(d_model, d_mlp)
        self.W_out = nn.Linear(d_mlp, d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(out, hidden)``.

        ``hidden`` is the post-GELU activation of width ``d_mlp``; ``out`` is
        its projection back to ``d_model`` with dropout applied.
        """
        pre_activation = self.W_in(x)
        hidden = F.gelu(pre_activation)
        projected = self.W_out(hidden)
        return self.dropout(projected), hidden
|
|
|
|
class TransformerBlock(nn.Module):
    """One transformer layer: attention then MLP, each normalised on the way
    in and added back onto the residual stream."""

    def __init__(self, config: TransformerConfig):
        """Create the attention and MLP sub-layers and their two norms."""
        super().__init__()

        def make_norm() -> nn.Module:
            # LayerNorm when enabled in the config, otherwise a pass-through.
            if config.layer_norm:
                return nn.LayerNorm(config.d_model)
            return nn.Identity()

        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)
        self.ln1 = make_norm()
        self.ln2 = make_norm()

    def forward(self, x: torch.Tensor,
                return_internals: bool = False) -> Dict[str, torch.Tensor]:
        """Run the block on ``x``.

        Returns a dict that always contains ``'hidden_state'`` (the residual
        stream after the MLP). When ``return_internals`` is True it also
        contains ``'attn_weights'``, ``'mlp_hidden'`` and
        ``'residual_post_attn'`` (the residual stream after attention only).
        """
        # Attention sub-layer: normalise, attend, add back to the residual.
        attn_update, attn_pattern = self.attn(self.ln1(x), return_attn=return_internals)
        residual_mid = x + attn_update

        # MLP sub-layer, same pattern.
        mlp_update, mlp_acts = self.mlp(self.ln2(residual_mid))
        residual_out = residual_mid + mlp_update

        if not return_internals:
            return {'hidden_state': residual_out}
        return {
            'hidden_state': residual_out,
            'attn_weights': attn_pattern,
            'mlp_hidden': mlp_acts,
            'residual_post_attn': residual_mid,
        }
|
|
|
|
class SmallTransformer(nn.Module):
    """
    Minimal GPT for modular arithmetic with full activation access.

    Token and learned position embeddings feed a stack of ``TransformerBlock``
    layers, followed by a final norm and an output projection whose weight is
    shared with the token embedding.
    """

    def __init__(self, config: TransformerConfig):
        """Build embeddings, blocks, final norm and the tied output head."""
        super().__init__()
        self.config = config
        self.tok_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_final = nn.LayerNorm(config.d_model) if config.layer_norm else nn.Identity()
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: the output projection reuses the token-embedding matrix.
        self.lm_head.weight = self.tok_embed.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)

    def forward(self, input_ids: torch.Tensor,
                labels: Optional[torch.Tensor] = None,
                return_internals: bool = False) -> Dict[str, torch.Tensor]:
        """Run the model.

        Args:
            input_ids: integer token ids of shape ``(batch, seq)``.
            labels: optional targets with one entry per logit position;
                entries equal to ``-100`` are ignored by the loss. May be
                non-contiguous.
            return_internals: when True, also return detached activations.

        Returns:
            Dict with ``'logits'`` of shape ``(batch, seq, vocab_size)``, plus
            ``'loss'`` (scalar cross-entropy) when ``labels`` is given. With
            ``return_internals`` it also holds three lists of detached
            tensors: ``'hidden_states'`` (embedding output followed by each
            block's output, so ``n_layers + 1`` entries, all taken before the
            final norm), ``'attn_weights'`` and ``'mlp_hidden'`` (one entry
            per block).

        Raises:
            ValueError: if ``seq`` exceeds the number of position embeddings.
        """
        _, T = input_ids.shape
        max_positions = self.pos_embed.num_embeddings
        # Fail early with a readable message rather than an out-of-range
        # embedding lookup.
        if T > max_positions:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {max_positions}"
            )
        device = input_ids.device

        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Activation lists are only built when the caller asked for them.
        all_hidden_states = [x.detach()] if return_internals else None
        all_attn_weights = []
        all_mlp_hidden = []

        for block in self.blocks:
            block_out = block(x, return_internals=return_internals)
            x = block_out['hidden_state']
            if return_internals:
                all_hidden_states.append(x.detach())
                all_attn_weights.append(block_out['attn_weights'].detach())
                all_mlp_hidden.append(block_out['mlp_hidden'].detach())

        x = self.ln_final(x)
        logits = self.lm_head(x)

        result = {'logits': logits}

        if labels is not None:
            # reshape (not view) so sliced or transposed label tensors work.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                   labels.reshape(-1), ignore_index=-100)
            result['loss'] = loss

        if return_internals:
            result['hidden_states'] = all_hidden_states
            result['attn_weights'] = all_attn_weights
            result['mlp_hidden'] = all_mlp_hidden

        return result

    def get_representations(self, input_ids: torch.Tensor,
                            token_position: int = -1) -> List[torch.Tensor]:
        """
        Get hidden state at each layer for a specific token position.
        Returns list of [batch_size, d_model] tensors: the embedding output
        first, then one entry per block. Computed under ``torch.no_grad()``.
        """
        with torch.no_grad():
            # Call the module rather than .forward so registered hooks run.
            out = self(input_ids, return_internals=True)
        return [hs[:, token_position, :] for hs in out['hidden_states']]

    def count_parameters(self) -> int:
        """Return the total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())
|
|