tekkmaven commited on
Commit
3ec445e
·
verified ·
1 Parent(s): ac2814e

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +186 -0
model.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Small transformer model for modular arithmetic experiments.
3
+ ============================================================
4
+ A minimal GPT-style decoder-only transformer designed to:
5
+ 1. Train from scratch in minutes on a single GPU
6
+ 2. Expose all internal activations (hidden states, attention patterns)
7
+ 3. Support checkpoint saving/loading for representation tracking
8
+
9
+ Architecture matches Nanda et al. 2023 (grokking) configuration
10
+ with adjustments for our two-task experiment.
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import math
17
+ from typing import Dict, Optional, Tuple, List
18
+ from dataclasses import dataclass
19
+
20
+
21
+ @dataclass
22
+ class TransformerConfig:
23
+ """Configuration for the small transformer."""
24
+ vocab_size: int = 101 # p + NUM_SPECIAL (97 + 4)
25
+ n_layers: int = 2
26
+ d_model: int = 128
27
+ n_heads: int = 4
28
+ d_mlp: int = 512
29
+ max_seq_len: int = 5 # [a, op, b, =, c]
30
+ dropout: float = 0.0
31
+ layer_norm: bool = True
32
+
33
+
34
+ class MultiHeadAttention(nn.Module):
35
+ def __init__(self, config: TransformerConfig):
36
+ super().__init__()
37
+ self.n_heads = config.n_heads
38
+ self.d_head = config.d_model // config.n_heads
39
+ self.d_model = config.d_model
40
+
41
+ self.W_Q = nn.Linear(config.d_model, config.d_model, bias=False)
42
+ self.W_K = nn.Linear(config.d_model, config.d_model, bias=False)
43
+ self.W_V = nn.Linear(config.d_model, config.d_model, bias=False)
44
+ self.W_O = nn.Linear(config.d_model, config.d_model, bias=False)
45
+ self.dropout = nn.Dropout(config.dropout)
46
+
47
+ def forward(self, x: torch.Tensor,
48
+ return_attn: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
49
+ B, T, D = x.shape
50
+
51
+ Q = self.W_Q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
52
+ K = self.W_K(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
53
+ V = self.W_V(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
54
+
55
+ # Scaled dot-product attention with causal mask
56
+ scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_head)
57
+ causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
58
+ scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
59
+ attn_weights = F.softmax(scores, dim=-1)
60
+ attn_weights = self.dropout(attn_weights)
61
+
62
+ out = (attn_weights @ V).transpose(1, 2).reshape(B, T, D)
63
+ out = self.W_O(out)
64
+
65
+ if return_attn:
66
+ return out, attn_weights # [B, H, T, T]
67
+ return out, None
68
+
69
+
70
+ class MLP(nn.Module):
71
+ def __init__(self, config: TransformerConfig):
72
+ super().__init__()
73
+ self.W_in = nn.Linear(config.d_model, config.d_mlp)
74
+ self.W_out = nn.Linear(config.d_mlp, config.d_model)
75
+ self.dropout = nn.Dropout(config.dropout)
76
+
77
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
78
+ hidden = F.gelu(self.W_in(x))
79
+ out = self.dropout(self.W_out(hidden))
80
+ return out, hidden # return pre-projection activations for probing
81
+
82
+
83
+ class TransformerBlock(nn.Module):
84
+ def __init__(self, config: TransformerConfig):
85
+ super().__init__()
86
+ self.attn = MultiHeadAttention(config)
87
+ self.mlp = MLP(config)
88
+ self.ln1 = nn.LayerNorm(config.d_model) if config.layer_norm else nn.Identity()
89
+ self.ln2 = nn.LayerNorm(config.d_model) if config.layer_norm else nn.Identity()
90
+
91
+ def forward(self, x: torch.Tensor,
92
+ return_internals: bool = False) -> Dict[str, torch.Tensor]:
93
+ # Pre-norm residual architecture
94
+ attn_out, attn_weights = self.attn(self.ln1(x), return_attn=return_internals)
95
+ x_post_attn = x + attn_out
96
+
97
+ mlp_out, mlp_hidden = self.mlp(self.ln2(x_post_attn))
98
+ x_post_mlp = x_post_attn + mlp_out
99
+
100
+ result = {'hidden_state': x_post_mlp}
101
+ if return_internals:
102
+ result['attn_weights'] = attn_weights
103
+ result['mlp_hidden'] = mlp_hidden
104
+ result['residual_post_attn'] = x_post_attn
105
+ return result
106
+
107
+
108
+ class SmallTransformer(nn.Module):
109
+ """
110
+ Minimal GPT for modular arithmetic with full activation access.
111
+ """
112
+
113
+ def __init__(self, config: TransformerConfig):
114
+ super().__init__()
115
+ self.config = config
116
+ self.tok_embed = nn.Embedding(config.vocab_size, config.d_model)
117
+ self.pos_embed = nn.Embedding(config.max_seq_len, config.d_model)
118
+ self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
119
+ self.ln_final = nn.LayerNorm(config.d_model) if config.layer_norm else nn.Identity()
120
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
121
+
122
+ # Weight tying (embedding ↔ output)
123
+ self.lm_head.weight = self.tok_embed.weight
124
+
125
+ self.apply(self._init_weights)
126
+
127
+ def _init_weights(self, module):
128
+ if isinstance(module, nn.Linear):
129
+ nn.init.normal_(module.weight, std=0.02)
130
+ if module.bias is not None:
131
+ nn.init.zeros_(module.bias)
132
+ elif isinstance(module, nn.Embedding):
133
+ nn.init.normal_(module.weight, std=0.02)
134
+
135
+ def forward(self, input_ids: torch.Tensor,
136
+ labels: Optional[torch.Tensor] = None,
137
+ return_internals: bool = False) -> Dict[str, torch.Tensor]:
138
+ B, T = input_ids.shape
139
+ device = input_ids.device
140
+
141
+ tok_emb = self.tok_embed(input_ids)
142
+ pos_emb = self.pos_embed(torch.arange(T, device=device))
143
+ x = tok_emb + pos_emb
144
+
145
+ # Collect internals
146
+ all_hidden_states = [x.detach()]
147
+ all_attn_weights = []
148
+ all_mlp_hidden = []
149
+
150
+ for block in self.blocks:
151
+ block_out = block(x, return_internals=return_internals)
152
+ x = block_out['hidden_state']
153
+ all_hidden_states.append(x.detach())
154
+ if return_internals:
155
+ all_attn_weights.append(block_out['attn_weights'].detach())
156
+ all_mlp_hidden.append(block_out['mlp_hidden'].detach())
157
+
158
+ x = self.ln_final(x)
159
+ logits = self.lm_head(x)
160
+
161
+ result = {'logits': logits}
162
+
163
+ if labels is not None:
164
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
165
+ labels.view(-1), ignore_index=-100)
166
+ result['loss'] = loss
167
+
168
+ if return_internals:
169
+ result['hidden_states'] = all_hidden_states # List of [B, T, D]
170
+ result['attn_weights'] = all_attn_weights # List of [B, H, T, T]
171
+ result['mlp_hidden'] = all_mlp_hidden # List of [B, T, D_mlp]
172
+
173
+ return result
174
+
175
+ def get_representations(self, input_ids: torch.Tensor,
176
+ token_position: int = -1) -> List[torch.Tensor]:
177
+ """
178
+ Get hidden state at each layer for a specific token position.
179
+ Returns list of [batch_size, d_model] tensors.
180
+ """
181
+ with torch.no_grad():
182
+ out = self.forward(input_ids, return_internals=True)
183
+ return [hs[:, token_position, :] for hs in out['hidden_states']]
184
+
185
+ def count_parameters(self) -> int:
186
+ return sum(p.numel() for p in self.parameters())