# tekkmaven's picture
# Upload model.py with huggingface_hub
# 3ec445e verified
"""
Small transformer model for modular arithmetic experiments.
============================================================
A minimal GPT-style decoder-only transformer designed to:
1. Train from scratch in minutes on a single GPU
2. Expose all internal activations (hidden states, attention patterns)
3. Support checkpoint saving/loading for representation tracking
Architecture matches Nanda et al. 2023 (grokking) configuration
with adjustments for our two-task experiment.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Dict, Optional, Tuple, List
from dataclasses import dataclass
@dataclass
class TransformerConfig:
    """Hyperparameters for the small transformer.

    Every field has a default, so ``TransformerConfig()`` yields a 2-layer,
    128-wide, 4-head model over a 101-token vocabulary and 5-token sequences.
    """

    # Vocabulary size: p + NUM_SPECIAL (97 + 4).
    vocab_size: int = 101
    # Number of stacked transformer blocks.
    n_layers: int = 2
    # Width of the residual stream / embeddings.
    d_model: int = 128
    # Number of attention heads; each head gets d_model // n_heads dims.
    n_heads: int = 4
    # Width of the MLP hidden layer.
    d_mlp: int = 512
    # Maximum sequence length: [a, op, b, =, c].
    max_seq_len: int = 5
    # Dropout probability handed to the nn.Dropout layers.
    dropout: float = 0.0
    # Use nn.LayerNorm when True, nn.Identity when False.
    layer_norm: bool = True
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with bias-free Q/K/V/O projections.

    The model width ``d_model`` is split evenly across ``n_heads`` heads of
    size ``d_head = d_model // n_heads``.
    """

    def __init__(self, config: TransformerConfig):
        """Build the four projection layers and the attention dropout.

        Args:
            config: Provides ``d_model``, ``n_heads`` and ``dropout``.

        Raises:
            ValueError: If ``n_heads`` is not positive or does not divide
                ``d_model``.
        """
        super().__init__()
        # d_head comes from floor division; without this check a non-divisible
        # width is silently truncated and only fails later inside ``view``.
        if config.n_heads <= 0 or config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by n_heads "
                f"({config.n_heads}), and n_heads must be positive"
            )
        self.n_heads = config.n_heads
        self.d_head = config.d_model // config.n_heads
        self.d_model = config.d_model
        self.W_Q = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_K = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_V = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_O = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor,
                return_attn: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape [B, T, D].
            return_attn: When True, also return the attention weights.

        Returns:
            ``(out, attn_weights)`` where ``out`` is [B, T, D] and
            ``attn_weights`` is [B, H, T, T] (taken after dropout) or
            ``None`` when ``return_attn`` is False.
        """
        B, T, D = x.shape
        # Project, then split the last dim into heads: [B, T, D] -> [B, H, T, d_head]
        Q = self.W_Q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_K(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_V(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention with causal mask
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)
        # True strictly above the diagonal, i.e. at future positions; the
        # [T, T] mask broadcasts over the batch and head dimensions.
        causal_mask = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
        scores = scores.masked_fill(causal_mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back: [B, H, T, d_head] -> [B, T, D]
        out = (attn_weights @ V).transpose(1, 2).reshape(B, T, D)
        out = self.W_O(out)
        if return_attn:
            return out, attn_weights  # [B, H, T, T]
        return out, None
class MLP(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config: TransformerConfig):
        """Create the up- and down-projections sized from ``config``."""
        super().__init__()
        width, inner = config.d_model, config.d_mlp
        self.W_in = nn.Linear(width, inner)
        self.W_out = nn.Linear(inner, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the MLP on ``x``.

        Returns:
            ``(out, hidden)``: the dropped-out down-projection, and the GELU
            activations that feed ``W_out`` (exposed for probing).
        """
        activations = F.gelu(self.W_in(x))
        projected = self.W_out(activations)
        return self.dropout(projected), activations
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then MLP, each with a residual."""

    def __init__(self, config: TransformerConfig):
        """Create the attention and MLP sub-layers plus their two norms."""
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.mlp = MLP(config)
        # With layer_norm disabled the norms become pass-through Identity layers.
        if config.layer_norm:
            self.ln1 = nn.LayerNorm(config.d_model)
            self.ln2 = nn.LayerNorm(config.d_model)
        else:
            self.ln1 = nn.Identity()
            self.ln2 = nn.Identity()

    def forward(self, x: torch.Tensor,
                return_internals: bool = False) -> Dict[str, torch.Tensor]:
        """Run the block on ``x``.

        Returns:
            A dict that always holds ``'hidden_state'`` (the block output).
            When ``return_internals`` is True it also holds
            ``'attn_weights'``, ``'mlp_hidden'`` and ``'residual_post_attn'``.
        """
        # Attention sub-layer: normalise first, then add back onto the residual.
        attn_out, attn_weights = self.attn(self.ln1(x), return_attn=return_internals)
        residual_post_attn = x + attn_out

        # MLP sub-layer, same pre-norm residual pattern.
        mlp_out, mlp_hidden = self.mlp(self.ln2(residual_post_attn))
        hidden_state = residual_post_attn + mlp_out

        if not return_internals:
            return {'hidden_state': hidden_state}
        return {
            'hidden_state': hidden_state,
            'attn_weights': attn_weights,
            'mlp_hidden': mlp_hidden,
            'residual_post_attn': residual_post_attn,
        }
class SmallTransformer(nn.Module):
    """Minimal GPT for modular arithmetic with full activation access.

    Token and learned positional embeddings feed ``n_layers`` transformer
    blocks, a final norm, and an output head whose weight is tied to the
    token embedding.
    """

    def __init__(self, config: TransformerConfig):
        """Build embeddings, blocks, final norm and the tied output head."""
        super().__init__()
        self.config = config
        self.tok_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.ln_final = (
            nn.LayerNorm(config.d_model) if config.layer_norm else nn.Identity()
        )
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying (embedding <-> output)
        self.lm_head.weight = self.tok_embed.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2); zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)

    def forward(self, input_ids: torch.Tensor,
                labels: Optional[torch.Tensor] = None,
                return_internals: bool = False) -> Dict[str, torch.Tensor]:
        """Compute logits, and optionally the loss and internal activations.

        Args:
            input_ids: Token ids of shape [B, T] with ``T <= max_seq_len``.
            labels: Optional targets with one entry per logit position.
                Entries equal to -100 are ignored. No shifting is applied
                here, so labels must already be aligned with the logits.
            return_internals: When True, also return detached hidden states,
                attention weights and MLP activations.

        Returns:
            A dict with ``'logits'``; ``'loss'`` when ``labels`` is given;
            and, when ``return_internals`` is True, ``'hidden_states'``
            (embedding output followed by each block's output, taken before
            the final norm), ``'attn_weights'`` and ``'mlp_hidden'`` (one
            entry per block).

        Raises:
            IndexError: If the sequence is longer than ``max_seq_len``.
        """
        _, T = input_ids.shape
        # Fail early with a readable message instead of an out-of-range
        # lookup in the positional embedding table.
        if T > self.config.max_seq_len:
            raise IndexError(
                f"sequence length {T} exceeds max_seq_len "
                f"{self.config.max_seq_len}"
            )
        device = input_ids.device

        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Collect internals only when the caller asked for them.
        all_hidden_states = [x.detach()] if return_internals else []
        all_attn_weights = []
        all_mlp_hidden = []

        for block in self.blocks:
            block_out = block(x, return_internals=return_internals)
            x = block_out['hidden_state']
            if return_internals:
                all_hidden_states.append(x.detach())
                all_attn_weights.append(block_out['attn_weights'].detach())
                all_mlp_hidden.append(block_out['mlp_hidden'].detach())

        x = self.ln_final(x)
        logits = self.lm_head(x)

        result = {'logits': logits}
        if labels is not None:
            # reshape (not view) so non-contiguous label tensors are accepted.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                   labels.reshape(-1), ignore_index=-100)
            result['loss'] = loss
        if return_internals:
            result['hidden_states'] = all_hidden_states  # List of [B, T, D]
            result['attn_weights'] = all_attn_weights  # List of [B, H, T, T]
            result['mlp_hidden'] = all_mlp_hidden  # List of [B, T, D_mlp]
        return result

    def get_representations(self, input_ids: torch.Tensor,
                            token_position: int = -1) -> List[torch.Tensor]:
        """Get the hidden state at each layer for a specific token position.

        Runs under ``torch.no_grad()``; the module's train/eval mode is left
        as it is.

        Returns:
            A list of [batch_size, d_model] tensors: the embedding output
            followed by the output of each block.
        """
        with torch.no_grad():
            # Call the module rather than .forward() so registered hooks run.
            out = self(input_ids, return_internals=True)
        return [hs[:, token_position, :] for hs in out['hidden_states']]

    def count_parameters(self) -> int:
        """Return the total number of elements across all model parameters."""
        return sum(p.numel() for p in self.parameters())