OpenTransformer commited on
Commit
2db758d
Β·
verified Β·
1 Parent(s): 2b0bfd4

Add experiments/n_heavy2.py

Browse files
Files changed (1) hide show
  1. experiments/n_heavy2.py +605 -0
experiments/n_heavy2.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ n_heavy2.py β€” Extended Heavy Attention Experiments
4
+ Testing mechanisms that use MORE compute than standard attention
5
+
6
+ Approaches:
7
+ 1. Multi-Hop: Explicit k-step reasoning chains
8
+ 2. Slot Attention: Competitive binding (from object-centric learning)
9
+ 3. Edge-Compute: Full pairwise MLP, not just weighted sum
10
+ 4. Memory-Aug: External memory bank with read/write
11
+ 5. Recurrent Depth: Same block applied k times (Universal Transformer)
12
+ """
13
+
14
+ from __future__ import annotations
15
+ import argparse, math, time
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ torch.backends.cuda.matmul.allow_tf32 = True
22
+ try:
23
+ torch.set_float32_matmul_precision("high")
24
+ except:
25
+ pass
26
+
27
+ VOCAB = 128256
28
+ EOS = 128001
29
+
30
+ # ─────────────────────────── ALiBi ───────────────────────────
31
+ def _alibi_slopes(n_heads: int):
32
+ def pow2slopes(n):
33
+ start = 2 ** (-2 ** -(math.log2(n) - 3))
34
+ return [start * (start ** i) for i in range(n)]
35
+ if math.log2(n_heads).is_integer():
36
+ vals = pow2slopes(n_heads)
37
+ else:
38
+ closest = 2 ** math.floor(math.log2(n_heads))
39
+ vals = pow2slopes(closest)
40
+ extra = pow2slopes(2 * closest)
41
+ vals += extra[0::2][:n_heads - closest]
42
+ return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
43
+
44
+ def alibi_bias(n_heads: int, n_tokens: int):
45
+ i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
46
+ j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
47
+ dist = (j - i).clamp_min(0).float()
48
+ slopes = _alibi_slopes(n_heads)
49
+ return -slopes * dist
50
+
51
+ def causal_mask(n):
52
+ return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
53
+
54
+
55
+ # ═══════════════════════════════════════════════════════════════
56
+ # BASELINE: Standard Attention
57
+ # ═══════════════════════════════════════════════════════════════
58
+ class StandardAttention(nn.Module):
59
+ def __init__(self, d: int, h: int):
60
+ super().__init__()
61
+ assert d % h == 0
62
+ self.h, self.dk = h, d // h
63
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
64
+ self.proj = nn.Linear(d, d, bias=False)
65
+
66
+ def forward(self, x, mask=None):
67
+ B, N, _ = x.shape
68
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
69
+ q, k, v = qkv[0], qkv[1], qkv[2]
70
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
71
+ att = att + alibi_bias(self.h, N)
72
+ if mask is not None:
73
+ att = att + mask
74
+ z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
75
+ return self.proj(z)
76
+
77
+
78
+ # ═══════════════════════════════════════════════════════════════
79
+ # HEAVY 1: Multi-Hop Attention
80
+ # Each "hop" attends to previous hop's output
81
+ # Simulates multi-step reasoning chains
82
+ # ═══════════════════════════════════════════════════════════════
83
+ class MultiHopAttention(nn.Module):
84
+ """
85
+ K explicit reasoning hops. Each hop:
86
+ 1. Attend to current state
87
+ 2. Update state with attended info
88
+ 3. Next hop attends to updated state
89
+
90
+ O(k * nΒ²) - linear in hops, quadratic in sequence
91
+ """
92
+ def __init__(self, d: int, h: int, num_hops: int = 3):
93
+ super().__init__()
94
+ self.h, self.dk = h, d // h
95
+ self.num_hops = num_hops
96
+
97
+ # Separate Q projection per hop (K,V shared)
98
+ self.q_projs = nn.ModuleList([nn.Linear(d, d, bias=False) for _ in range(num_hops)])
99
+ self.kv = nn.Linear(d, 2 * d, bias=False)
100
+ self.proj = nn.Linear(d, d, bias=False)
101
+
102
+ # Hop mixing: combine info from all hops
103
+ self.hop_gate = nn.Linear(d * num_hops, d)
104
+
105
+ def forward(self, x, mask=None):
106
+ B, N, D = x.shape
107
+
108
+ # Compute K, V once (shared across hops)
109
+ kv = self.kv(x).reshape(B, N, 2, self.h, self.dk).permute(2, 0, 3, 1, 4)
110
+ k, v = kv[0], kv[1]
111
+
112
+ bias = alibi_bias(self.h, N)
113
+ hop_outputs = []
114
+ state = x
115
+
116
+ for hop in range(self.num_hops):
117
+ # Query from current state
118
+ q = self.q_projs[hop](state).reshape(B, N, self.h, self.dk).transpose(1, 2)
119
+
120
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
121
+ att = att + bias
122
+ if mask is not None:
123
+ att = att + mask
124
+
125
+ hop_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
126
+ hop_outputs.append(hop_out)
127
+
128
+ # Update state for next hop
129
+ state = state + hop_out
130
+
131
+ # Combine all hops
132
+ combined = torch.cat(hop_outputs, dim=-1)
133
+ return self.proj(self.hop_gate(combined))
134
+
135
+
136
+ # ═══════════════════════════════════════════════════════════════
137
+ # HEAVY 2: Slot Attention
138
+ # From "Object-Centric Learning with Slot Attention"
139
+ # Slots compete to bind to input positions
140
+ # ═══════════════════════════════════════════════════════════════
141
+ class SlotAttention(nn.Module):
142
+ """
143
+ Competitive binding: K slots compete for N positions.
144
+ Unlike standard attention (N queries), we have K << N slots.
145
+
146
+ Each slot iteratively refines what it attends to.
147
+ Then we project slots back to sequence.
148
+
149
+ O(iterations * K * N) where K = num_slots
150
+ """
151
+ def __init__(self, d: int, num_slots: int = 8, num_iters: int = 3):
152
+ super().__init__()
153
+ self.num_slots = num_slots
154
+ self.num_iters = num_iters
155
+ self.d = d
156
+
157
+ # Learnable slot initializations
158
+ self.slots_mu = nn.Parameter(torch.randn(1, num_slots, d) * 0.02)
159
+ self.slots_sigma = nn.Parameter(torch.ones(1, num_slots, d) * 0.02)
160
+
161
+ # Attention
162
+ self.to_q = nn.Linear(d, d, bias=False)
163
+ self.to_k = nn.Linear(d, d, bias=False)
164
+ self.to_v = nn.Linear(d, d, bias=False)
165
+
166
+ # Slot update GRU
167
+ self.gru = nn.GRUCell(d, d)
168
+ self.mlp = nn.Sequential(
169
+ nn.Linear(d, d * 2),
170
+ nn.ReLU(),
171
+ nn.Linear(d * 2, d)
172
+ )
173
+ self.ln1 = nn.LayerNorm(d)
174
+ self.ln2 = nn.LayerNorm(d)
175
+
176
+ # Project slots back to sequence
177
+ self.slot_to_seq = nn.Linear(d, d)
178
+
179
+ def forward(self, x, mask=None):
180
+ B, N, D = x.shape
181
+
182
+ # Initialize slots with noise
183
+ slots = self.slots_mu + self.slots_sigma * torch.randn(B, self.num_slots, D, device=x.device)
184
+
185
+ # Pre-compute keys and values
186
+ k = self.to_k(x) # (B, N, D)
187
+ v = self.to_v(x) # (B, N, D)
188
+
189
+ for _ in range(self.num_iters):
190
+ slots_prev = slots
191
+ slots = self.ln1(slots)
192
+
193
+ # Slot attention: slots query, inputs are keys/values
194
+ q = self.to_q(slots) # (B, K, D)
195
+
196
+ # Attention: (B, K, D) @ (B, D, N) -> (B, K, N)
197
+ attn = torch.einsum('bkd,bnd->bkn', q, k) / math.sqrt(D)
198
+
199
+ # Softmax over SLOTS (competition) not positions
200
+ attn = F.softmax(attn, dim=1) # Slots compete for each position
201
+
202
+ # Weighted sum of values
203
+ updates = torch.einsum('bkn,bnd->bkd', attn, v) # (B, K, D)
204
+
205
+ # GRU update
206
+ slots = self.gru(
207
+ updates.reshape(B * self.num_slots, D),
208
+ slots_prev.reshape(B * self.num_slots, D)
209
+ ).reshape(B, self.num_slots, D)
210
+
211
+ # MLP residual
212
+ slots = slots + self.mlp(self.ln2(slots))
213
+
214
+ # Project slots back to sequence length
215
+ # Use attention from slots to positions
216
+ q_out = self.to_q(x) # (B, N, D)
217
+ k_slots = self.to_k(slots) # (B, K, D)
218
+
219
+ attn_out = torch.einsum('bnd,bkd->bnk', q_out, k_slots) / math.sqrt(D)
220
+ attn_out = F.softmax(attn_out, dim=-1) # (B, N, K)
221
+
222
+ output = torch.einsum('bnk,bkd->bnd', attn_out, slots)
223
+ return self.slot_to_seq(output)
224
+
225
+
226
+ # ═══════════════════════════════════════════════════════════════
227
+ # HEAVY 3: Edge-Compute Attention
228
+ # Instead of weighted sum, compute MLP on each (query, key) pair
229
+ # ═══════════════════════════════════════════════════════════════
230
+ class EdgeComputeAttention(nn.Module):
231
+ """
232
+ Standard attention: output = softmax(QK^T) @ V
233
+ This is just a weighted sum - no computation on relationships.
234
+
235
+ Edge-Compute: For each (i,j) pair, run MLP([q_i; k_j; v_j])
236
+ Then aggregate. Much heavier but captures richer interactions.
237
+
238
+ O(nΒ² * mlp_cost) - quadratic with multiplicative MLP factor
239
+
240
+ Note: Only practical for short sequences!
241
+ """
242
+ def __init__(self, d: int, h: int, max_seq: int = 128):
243
+ super().__init__()
244
+ self.h, self.dk = h, d // h
245
+ self.max_seq = max_seq
246
+
247
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
248
+
249
+ # Edge MLP: processes each (q_i, k_j, v_j) triple
250
+ self.edge_mlp = nn.Sequential(
251
+ nn.Linear(3 * self.dk, 2 * self.dk),
252
+ nn.ReLU(),
253
+ nn.Linear(2 * self.dk, self.dk)
254
+ )
255
+
256
+ # Attention for aggregation
257
+ self.score_mlp = nn.Sequential(
258
+ nn.Linear(2 * self.dk, self.dk),
259
+ nn.ReLU(),
260
+ nn.Linear(self.dk, 1)
261
+ )
262
+
263
+ self.proj = nn.Linear(d, d, bias=False)
264
+
265
+ def forward(self, x, mask=None):
266
+ B, N, D = x.shape
267
+
268
+ # For long sequences, fall back to standard
269
+ if N > self.max_seq:
270
+ return self._standard_forward(x, mask)
271
+
272
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk)
273
+ q, k, v = qkv[:,:,0], qkv[:,:,1], qkv[:,:,2] # Each: (B, N, H, dk)
274
+
275
+ outputs = []
276
+ for head in range(self.h):
277
+ q_h = q[:, :, head, :] # (B, N, dk)
278
+ k_h = k[:, :, head, :]
279
+ v_h = v[:, :, head, :]
280
+
281
+ # Expand for pairwise: (B, N, 1, dk) and (B, 1, N, dk)
282
+ q_exp = q_h.unsqueeze(2).expand(-1, -1, N, -1) # (B, N, N, dk)
283
+ k_exp = k_h.unsqueeze(1).expand(-1, N, -1, -1) # (B, N, N, dk)
284
+ v_exp = v_h.unsqueeze(1).expand(-1, N, -1, -1) # (B, N, N, dk)
285
+
286
+ # Concatenate for edge MLP
287
+ edge_input = torch.cat([q_exp, k_exp, v_exp], dim=-1) # (B, N, N, 3*dk)
288
+
289
+ # Compute edge features
290
+ edge_features = self.edge_mlp(edge_input) # (B, N, N, dk)
291
+
292
+ # Compute attention scores
293
+ score_input = torch.cat([q_exp, k_exp], dim=-1) # (B, N, N, 2*dk)
294
+ scores = self.score_mlp(score_input).squeeze(-1) # (B, N, N)
295
+
296
+ # Apply causal mask
297
+ if mask is not None:
298
+ scores = scores + mask.squeeze(1)
299
+
300
+ # Aggregate
301
+ weights = F.softmax(scores, dim=-1) # (B, N, N)
302
+ head_out = (weights.unsqueeze(-1) * edge_features).sum(dim=2) # (B, N, dk)
303
+ outputs.append(head_out)
304
+
305
+ out = torch.cat(outputs, dim=-1) # (B, N, D)
306
+ return self.proj(out)
307
+
308
+ def _standard_forward(self, x, mask=None):
309
+ B, N, _ = x.shape
310
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
311
+ q, k, v = qkv[0], qkv[1], qkv[2]
312
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
313
+ att = att + alibi_bias(self.h, N)
314
+ if mask is not None:
315
+ att = att + mask
316
+ z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
317
+ return self.proj(z)
318
+
319
+
320
+ # ═══════════════════════════════════════════════════════════════
321
+ # HEAVY 4: Memory-Augmented Attention
322
+ # External memory bank with read/write operations
323
+ # ═══════════════════════════════════════════════════════════════
324
+ class MemoryAugmentedAttention(nn.Module):
325
+ """
326
+ Maintain external memory bank M of size (mem_size, d).
327
+ Each forward:
328
+ 1. Read from memory using attention
329
+ 2. Standard self-attention augmented with memory content
330
+ 3. Write updated info back to memory
331
+
332
+ O(nΒ² + n*mem_size) - adds memory interaction cost
333
+ """
334
+ def __init__(self, d: int, h: int, mem_size: int = 64):
335
+ super().__init__()
336
+ self.h, self.dk = h, d // h
337
+ self.mem_size = mem_size
338
+
339
+ # Persistent memory (learned)
340
+ self.memory = nn.Parameter(torch.randn(1, mem_size, d) * 0.02)
341
+
342
+ # Standard attention
343
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
344
+ self.proj = nn.Linear(d, d, bias=False)
345
+
346
+ # Memory read/write
347
+ self.mem_q = nn.Linear(d, d, bias=False)
348
+ self.mem_k = nn.Linear(d, d, bias=False)
349
+ self.mem_v = nn.Linear(d, d, bias=False)
350
+
351
+ # Write gate
352
+ self.write_gate = nn.Sequential(
353
+ nn.Linear(d * 2, d),
354
+ nn.Sigmoid()
355
+ )
356
+
357
+ # Combine self-attention and memory
358
+ self.combine = nn.Linear(d * 2, d)
359
+
360
+ def forward(self, x, mask=None):
361
+ B, N, D = x.shape
362
+
363
+ # Expand memory for batch
364
+ mem = self.memory.expand(B, -1, -1) # (B, mem_size, D)
365
+
366
+ # 1. Read from memory
367
+ q_mem = self.mem_q(x) # (B, N, D)
368
+ k_mem = self.mem_k(mem) # (B, mem_size, D)
369
+ v_mem = self.mem_v(mem) # (B, mem_size, D)
370
+
371
+ mem_attn = torch.einsum('bnd,bmd->bnm', q_mem, k_mem) / math.sqrt(D)
372
+ mem_attn = F.softmax(mem_attn, dim=-1)
373
+ mem_read = torch.einsum('bnm,bmd->bnd', mem_attn, v_mem) # (B, N, D)
374
+
375
+ # 2. Standard self-attention
376
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
377
+ q, k, v = qkv[0], qkv[1], qkv[2]
378
+
379
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
380
+ att = att + alibi_bias(self.h, N)
381
+ if mask is not None:
382
+ att = att + mask
383
+ self_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
384
+
385
+ # 3. Combine self-attention and memory read
386
+ combined = self.combine(torch.cat([self_out, mem_read], dim=-1))
387
+
388
+ return self.proj(combined)
389
+
390
+
391
+ # ═══════════════════════════════════════════════════════════════
392
+ # HEAVY 5: Recurrent Depth (Universal Transformer)
393
+ # Same block applied k times with position-in-depth encoding
394
+ # ═══════════════════════════════════════════════════════════════
395
+ class RecurrentDepthAttention(nn.Module):
396
+ """
397
+ Instead of L different layers, use 1 layer L times.
398
+ Add depth embedding so model knows which iteration it's on.
399
+
400
+ O(k * nΒ²) where k = num_recurrences
401
+
402
+ Key insight: Weight sharing + depth embedding = potentially more
403
+ efficient use of parameters for complex reasoning.
404
+ """
405
+ def __init__(self, d: int, h: int, num_recur: int = 4):
406
+ super().__init__()
407
+ self.h, self.dk = h, d // h
408
+ self.num_recur = num_recur
409
+
410
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
411
+ self.proj = nn.Linear(d, d, bias=False)
412
+
413
+ # Depth embedding
414
+ self.depth_emb = nn.Embedding(num_recur, d)
415
+
416
+ # Transition function between recurrences
417
+ self.transition = nn.Sequential(
418
+ nn.LayerNorm(d),
419
+ nn.Linear(d, d * 2),
420
+ nn.GELU(),
421
+ nn.Linear(d * 2, d)
422
+ )
423
+
424
+ def forward(self, x, mask=None):
425
+ B, N, D = x.shape
426
+ bias = alibi_bias(self.h, N)
427
+
428
+ for r in range(self.num_recur):
429
+ # Add depth embedding
430
+ x_r = x + self.depth_emb.weight[r].unsqueeze(0).unsqueeze(0)
431
+
432
+ # Self-attention
433
+ qkv = self.qkv(x_r).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
434
+ q, k, v = qkv[0], qkv[1], qkv[2]
435
+
436
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
437
+ att = att + bias
438
+ if mask is not None:
439
+ att = att + mask
440
+
441
+ attn_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
442
+ attn_out = self.proj(attn_out)
443
+
444
+ # Residual + transition
445
+ x = x + attn_out
446
+ x = x + self.transition(x)
447
+
448
+ return x - x.detach() + x.detach() # Gradient trick for stability
449
+
450
+
451
+ # ═══════════════════════════════════════════════════════════════
452
+ # Block and Model wrappers
453
+ # ═══════════════════════════════════════════════════════════════
454
+ class Block(nn.Module):
455
+ def __init__(self, d: int, h: int, attn_type: str = "standard", **kwargs):
456
+ super().__init__()
457
+ self.ln1 = nn.LayerNorm(d)
458
+ self.ln2 = nn.LayerNorm(d)
459
+
460
+ if attn_type == "standard":
461
+ self.attn = StandardAttention(d, h)
462
+ elif attn_type == "multihop":
463
+ self.attn = MultiHopAttention(d, h, num_hops=kwargs.get('num_hops', 3))
464
+ elif attn_type == "slot":
465
+ self.attn = SlotAttention(d, num_slots=kwargs.get('num_slots', 8))
466
+ elif attn_type == "edge":
467
+ self.attn = EdgeComputeAttention(d, h)
468
+ elif attn_type == "memory":
469
+ self.attn = MemoryAugmentedAttention(d, h, mem_size=kwargs.get('mem_size', 64))
470
+ elif attn_type == "recurrent":
471
+ self.attn = RecurrentDepthAttention(d, h, num_recur=kwargs.get('num_recur', 4))
472
+ else:
473
+ raise ValueError(f"Unknown attn_type: {attn_type}")
474
+
475
+ self.ff = nn.Sequential(
476
+ nn.Linear(d, 4 * d),
477
+ nn.GELU(),
478
+ nn.Linear(4 * d, d)
479
+ )
480
+
481
+ def forward(self, x, mask=None):
482
+ x = x + self.attn(self.ln1(x), mask)
483
+ x = x + self.ff(self.ln2(x))
484
+ return x
485
+
486
+
487
+ class HeavyModel(nn.Module):
488
+ def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard", **kwargs):
489
+ super().__init__()
490
+ self.emb = nn.Embedding(VOCAB, d)
491
+ self.blocks = nn.ModuleList([Block(d, h, attn_type, **kwargs) for _ in range(layers)])
492
+ self.ln = nn.LayerNorm(d)
493
+ self.head = nn.Linear(d, VOCAB, bias=False)
494
+ self.head.weight = self.emb.weight # Tie weights
495
+
496
+ def forward(self, x, mask=None):
497
+ x = self.emb(x)
498
+ for blk in self.blocks:
499
+ x = blk(x, mask)
500
+ return self.head(self.ln(x))
501
+
502
+ def count_params(self):
503
+ return sum(p.numel() for p in self.parameters())
504
+
505
+
506
+ # ═══════════════════════════════════════════════════════════════
507
+ # Experiment Runner
508
+ # ═══════════════════════════════════════════════════════════════
509
+ def run_experiment(attn_type: str, d: int, layers: int, heads: int,
510
+ batch: int, seq: int, steps: int, **kwargs):
511
+ print(f"\n{'='*60}")
512
+ print(f"ATTENTION TYPE: {attn_type.upper()}")
513
+ print(f"Config: d={d}, layers={layers}, heads={heads}")
514
+ print(f"{'='*60}")
515
+
516
+ model = HeavyModel(d, layers, heads, attn_type, **kwargs).to(DEV)
517
+ print(f"Parameters: {model.count_params():,}")
518
+
519
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
520
+ mask = causal_mask(seq - 1)
521
+
522
+ losses, times = [], []
523
+
524
+ for step in range(steps):
525
+ ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)
526
+ target = ids[:, 1:]
527
+ input_ids = ids[:, :-1]
528
+
529
+ start = time.time()
530
+ optimizer.zero_grad()
531
+ logits = model(input_ids, mask)
532
+ loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
533
+ loss.backward()
534
+ optimizer.step()
535
+ elapsed = time.time() - start
536
+
537
+ losses.append(loss.item())
538
+ times.append(elapsed)
539
+ tok_s = (batch * seq) / elapsed
540
+
541
+ if step % 10 == 0 or step == steps - 1:
542
+ print(f"Step {step:3d} | Loss: {loss.item():.4f} | {tok_s:.0f} tok/s | {elapsed*1000:.0f}ms")
543
+
544
+ avg_loss = sum(losses[-20:]) / min(20, len(losses))
545
+ avg_time = sum(times[-20:]) / min(20, len(times))
546
+ avg_toks = (batch * seq) / avg_time
547
+
548
+ return {
549
+ "type": attn_type,
550
+ "loss": avg_loss,
551
+ "tok_s": avg_toks,
552
+ "params": model.count_params()
553
+ }
554
+
555
+
556
+ def main():
557
+ parser = argparse.ArgumentParser()
558
+ parser.add_argument("--d", type=int, default=256)
559
+ parser.add_argument("--layers", type=int, default=4)
560
+ parser.add_argument("--heads", type=int, default=8)
561
+ parser.add_argument("--batch", type=int, default=16)
562
+ parser.add_argument("--seq", type=int, default=128)
563
+ parser.add_argument("--steps", type=int, default=100)
564
+ parser.add_argument("--types", type=str, default="all",
565
+ help="Comma-separated: standard,multihop,slot,edge,memory,recurrent")
566
+ args = parser.parse_args()
567
+
568
+ print(f"Device: {DEV}")
569
+ if torch.cuda.is_available():
570
+ print(f"GPU: {torch.cuda.get_device_name()}")
571
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
572
+
573
+ if args.types == "all":
574
+ types = ["standard", "multihop", "slot", "edge", "memory", "recurrent"]
575
+ else:
576
+ types = [t.strip() for t in args.types.split(",")]
577
+
578
+ results = []
579
+ for t in types:
580
+ try:
581
+ r = run_experiment(t, args.d, args.layers, args.heads,
582
+ args.batch, args.seq, args.steps)
583
+ results.append(r)
584
+ except Exception as e:
585
+ print(f"ERROR in {t}: {e}")
586
+ import traceback
587
+ traceback.print_exc()
588
+
589
+ # Summary
590
+ print(f"\n{'='*60}")
591
+ print("SUMMARY")
592
+ print(f"{'='*60}")
593
+ baseline = next((r for r in results if r['type'] == 'standard'), None)
594
+
595
+ for r in results:
596
+ rel = ""
597
+ if baseline and r['type'] != 'standard':
598
+ loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
599
+ speed_ratio = r['tok_s'] / baseline['tok_s']
600
+ rel = f" | vs baseline: {loss_diff:+.1f}% loss, {speed_ratio:.2f}x speed"
601
+ print(f"{r['type']:12s} | Loss: {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s | {r['params']:,} params{rel}")
602
+
603
+
604
+ if __name__ == "__main__":
605
+ main()