OpenTransformer committed on
Commit
2b0bfd4
·
verified ·
1 Parent(s): 7c91b87

Add experiments/n_heavy.py

Browse files
Files changed (1) hide show
  1. experiments/n_heavy.py +466 -0
experiments/n_heavy.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ n_heavy.py — Iterative Refinement Transformer Experiment
4
+ Heavier-than-standard-attention: tokens get reprocessed based on uncertainty
5
+
6
+ Key idea: Instead of single-pass attention, run multiple iterations
7
+ where "hard" tokens (high uncertainty) get recomputed while "easy" tokens halt.
8
+
9
+ This is O(n² × k) where k = average iterations, vs standard O(n²).
10
+ """
11
+
12
+ from __future__ import annotations
13
+ import argparse, json, math, pathlib, random, time, os, sys
14
+ from contextlib import nullcontext
15
+ from typing import Dict, Any, List, Optional, Tuple
16
+ from datetime import datetime, timezone
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ # ─────────────────────────── Globals ───────────────────────────
22
# Device every helper in this file allocates on: CUDA when available, else CPU.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allow TF32 matmuls on hardware that supports them (speed over full fp32 precision).
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except (AttributeError, RuntimeError):
    # Best effort only: older PyTorch builds do not expose this setter.
    # Catch just the failures it can produce instead of a bare ``except``,
    # which would also swallow KeyboardInterrupt / SystemExit.
    pass

# NOTE(review): the original comment labels this "DeepSeek V3 vocab", but
# 128256 / 128001 look like the Llama 3 tokenizer sizes - confirm which
# tokenizer is actually intended before training on real data.
VOCAB = 128256  # DeepSeek V3 vocab
EOS = 128001
31
+
32
+ # ─────────────────────────── ALiBi ───────────────────────────
33
+ def _alibi_slopes(n_heads: int):
34
+ def pow2slopes(n):
35
+ start = 2 ** (-2 ** -(math.log2(n) - 3))
36
+ ratio = start
37
+ return [start * (ratio ** i) for i in range(n)]
38
+ if math.log2(n_heads).is_integer(): vals = pow2slopes(n_heads)
39
+ else:
40
+ closest = 2 ** math.floor(math.log2(n_heads))
41
+ vals = pow2slopes(closest)
42
+ extra = pow2slopes(2 * closest)
43
+ vals += extra[0::2][: n_heads - closest]
44
+ return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
45
+
46
+ def alibi_bias(n_heads: int, n_tokens: int):
47
+ i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
48
+ j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
49
+ dist = (j - i).clamp_min(0)
50
+ return -_alibi_slopes(n_heads) * dist
51
+
52
+ # ─────────────────────────── Standard Attention ───────────────────────────
53
class StandardAttention(nn.Module):
    """Baseline multi-head self-attention executed in a single pass.

    Scores are scaled dot products plus the ALiBi bias from ``alibi_bias`` and
    an optional additive mask; the attended values go through an output
    projection followed by dropout (p=0.1).
    """

    def __init__(self, d: int, h: int):
        super().__init__()
        assert d % h == 0
        self.h = h
        self.dk = d // h
        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq, d); ``mask`` is added to the scores."""
        batch, seq, _ = x.shape

        # (batch, seq, 3*d) -> (3, batch, heads, seq, dk), then split q / k / v.
        packed = self.qkv(x).reshape(batch, seq, 3, self.h, self.dk)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        scores = scores + alibi_bias(self.h, seq)
        if mask is not None:
            scores = scores + mask

        context = scores.softmax(-1) @ v
        merged = context.transpose(1, 2).reshape(batch, seq, -1)
        return self.drop(self.proj(merged))
75
+
76
+
77
+ # ─────────────────────────── HEAVY: Iterative Refinement Attention ───────────────────────────
78
class IterativeAttention(nn.Module):
    """Self-attention that is re-applied up to ``max_iters`` times per call.

    Each token carries a cumulative halting probability.  Once it reaches
    ``halt_threshold`` the token stops updating, while the remaining tokens
    keep being refined with the same (shared) attention weights.  The output
    is a per-token weighted mixture of the intermediate states, with the
    weights given by the halting probabilities.

    Halted tokens keep their last state in ``x`` and therefore still act as
    keys/values for the tokens that continue.

    After every ``forward`` call two diagnostics are stored on the module:
    ``_last_iters`` (number of passes executed) and ``_last_compute_ratio``
    (token-passes executed on non-halted tokens divided by
    ``B * N * max_iters``).
    """

    def __init__(self, d: int, h: int, max_iters: int = 5, halt_threshold: float = 0.9):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.max_iters = max_iters
        self.halt_threshold = halt_threshold

        # Attention weights are shared across all iterations.
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.drop = nn.Dropout(0.1)

        # Per-token probability of "done processing", in (0, 1).
        self.halt_pred = nn.Sequential(
            nn.Linear(d, d // 4),
            nn.ReLU(),
            nn.Linear(d // 4, 1),
            nn.Sigmoid()
        )

        # Learned embedding that tells the model which iteration it is on.
        self.iter_emb = nn.Embedding(max_iters, d)

    def forward(self, x, mask=None):
        """Refine ``x`` of shape (batch, seq, d); ``mask`` is added to the scores."""
        B, N, _ = x.shape

        # Halting state per token.
        halted = torch.zeros(B, N, 1, device=x.device, dtype=torch.bool)
        cumulative_halt = torch.zeros(B, N, 1, device=x.device)

        # Output is accumulated as a weighted sum; ``remainder`` is the weight
        # each token has not yet spent.
        output = torch.zeros_like(x)
        remainder = torch.ones(B, N, 1, device=x.device)

        # The positional bias does not depend on the iteration: build it once.
        pos_bias = alibi_bias(self.h, N)

        # Kept as a tensor so the loop does not force a host sync per iteration.
        processed = torch.zeros((), device=x.device)
        iters_run = 0

        for i in range(self.max_iters):
            iters_run = i + 1

            # Tokens that take part in this pass.  Counting them *before* the
            # halting update is what makes the compute ratio reflect work
            # actually done (a single full pass is 1 / max_iters, not 0).
            active = (~halted).float()
            processed = processed + active.sum()

            x_iter = x + self.iter_emb.weight[i].unsqueeze(0).unsqueeze(0)

            # Standard attention on the current state.
            qkv = self.qkv(x_iter).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]

            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            att = att + pos_bias
            if mask is not None:
                att = att + mask

            z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            delta = self.drop(self.proj(z))

            # p(halt | refined state) for every token.
            halt_prob = self.halt_pred(x + delta)

            # Only tokens that are still running accumulate halting mass.
            new_cumulative = cumulative_halt + halt_prob * active

            # Tokens crossing the threshold in this iteration.
            should_halt = (new_cumulative >= self.halt_threshold) & (~halted)

            # Halting now -> spend the whole remainder; already halted -> 0;
            # still running -> spend halt_prob.
            contrib_weight = torch.where(
                should_halt,
                remainder,
                torch.where(halted, torch.zeros_like(halt_prob), halt_prob)
            )

            output = output + contrib_weight * (x + delta)
            remainder = remainder - contrib_weight

            halted = halted | should_halt
            cumulative_halt = new_cumulative

            # Only tokens that are still running move to the refined state.
            x = torch.where(halted.expand_as(x), x, x + delta)

            if halted.all():
                break

        # Tokens that never halted put their unspent weight on the last state.
        output = output + remainder * x

        # Diagnostics for later inspection.
        self._last_iters = iters_run
        denom = B * N * self.max_iters
        self._last_compute_ratio = processed.item() / denom if denom else 0.0

        return output
183
+
184
+
185
+ # ─────────────────────────── HEAVY: Triplet Attention ───────────────────────────
186
class TripletAttention(nn.Module):
    """Attention whose pairwise scores are modulated by 3-way interactions.

    For every (query i, key j) pair a small MLP scores the triple
    (q_i, k_j, k_c) for each context position c; the scores are averaged over
    c and added to the pairwise logits with weight 0.1.  The cost is O(n^3)
    in the sequence length, so inputs longer than ``max_triplet_n`` fall back
    to plain pairwise attention (``_standard_forward``).
    """

    def __init__(self, d: int, h: int, max_triplet_n: int = 64):
        super().__init__()
        # Same head-divisibility requirement as the other attention classes.
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.max_triplet_n = max_triplet_n

        # Fused projection for the pairwise attention.
        self.qkv = nn.Linear(d, 3 * d, bias=False)

        # Scores a concatenated (q_i, k_j, k_c) triple -> scalar logit modifier.
        self.triplet_score = nn.Sequential(
            nn.Linear(3 * d // h, d // h),
            nn.ReLU(),
            nn.Linear(d // h, 1)
        )

        self.proj = nn.Linear(d, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def _split_qkv(self, x):
        """Project ``x`` and return q, k, v, each of shape (B, H, N, dk)."""
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        return qkv[0], qkv[1], qkv[2]

    def _finish(self, att, v, mask):
        """Add ALiBi bias and mask to ``att``, then attend, project and drop out."""
        B, _, N, _ = v.shape
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.drop(self.proj(z))

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq, d); ``mask`` is added to the scores."""
        B, N, D = x.shape

        # Too long for the cubic path: use plain pairwise attention instead.
        if N > self.max_triplet_n:
            return self._standard_forward(x, mask)

        q, k, v = self._split_qkv(x)

        pairwise = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)  # (B, H, N, N)

        # These views do not depend on the context position: build them once.
        q_pairs = q.unsqueeze(3).expand(-1, -1, -1, N, -1)  # (B, H, N, N, dk), q_i
        k_pairs = k.unsqueeze(2).expand(-1, -1, N, -1, -1)  # (B, H, N, N, dk), k_j

        triplet_mod = torch.zeros_like(pairwise)
        for c in range(N):  # context position
            # k_c broadcast to every (i, j) pair: (B, H, N, N, dk)
            k_ctx = k[:, :, c:c + 1, :].unsqueeze(2).expand(-1, -1, N, N, -1)

            triplet_input = torch.cat([q_pairs, k_pairs, k_ctx], dim=-1)  # (B, H, N, N, 3*dk)

            # Score modification contributed by this context position.
            mod = self.triplet_score(triplet_input).squeeze(-1)  # (B, H, N, N)
            triplet_mod = triplet_mod + mod

        # Average over contexts; the triplet term is a small residual on the logits.
        triplet_mod = triplet_mod / N
        att = pairwise + 0.1 * triplet_mod

        return self._finish(att, v, mask)

    def _standard_forward(self, x, mask=None):
        """Plain pairwise attention, used when the sequence is too long for triplets."""
        q, k, v = self._split_qkv(x)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        return self._finish(att, v, mask)
270
+
271
+
272
+ # ─────────────────────────── Block Variants ───────────────────────────
273
class StandardBlock(nn.Module):
    """Pre-LayerNorm transformer block built around ``StandardAttention``."""

    def __init__(self, d: int, h: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.attn = StandardAttention(d, h)
        # Position-wise feed-forward with a 4x hidden expansion.
        self.ff = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.ReLU(),
            nn.Linear(4 * d, d),
        )

    def forward(self, x, mask=None):
        """Apply the attention and feed-forward sub-layers, each with a residual."""
        attended = self.attn(self.ln1(x), mask)
        x = x + attended
        hidden = self.ff(self.ln2(x))
        return x + hidden
283
+
284
+
285
class IterativeBlock(nn.Module):
    """Pre-LayerNorm transformer block built around ``IterativeAttention``."""

    def __init__(self, d: int, h: int, max_iters: int = 5):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.attn = IterativeAttention(d, h, max_iters=max_iters)
        # Position-wise feed-forward with a 4x hidden expansion.
        self.ff = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.ReLU(),
            nn.Linear(4 * d, d),
        )

    def forward(self, x, mask=None):
        """Apply the attention and feed-forward sub-layers, each with a residual."""
        attended = self.attn(self.ln1(x), mask)
        x = x + attended
        hidden = self.ff(self.ln2(x))
        return x + hidden
295
+
296
+
297
class TripletBlock(nn.Module):
    """Pre-LayerNorm transformer block built around ``TripletAttention``."""

    def __init__(self, d: int, h: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        self.attn = TripletAttention(d, h)
        # Position-wise feed-forward with a 4x hidden expansion.
        self.ff = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.ReLU(),
            nn.Linear(4 * d, d),
        )

    def forward(self, x, mask=None):
        """Apply the attention and feed-forward sub-layers, each with a residual."""
        attended = self.attn(self.ln1(x), mask)
        x = x + attended
        hidden = self.ff(self.ln2(x))
        return x + hidden
307
+
308
+
309
+ # ─────────────────────────── Models ───────────────────────────
310
class HeavyTransformer(nn.Module):
    """Token-level language model whose block type is selected by ``mode``.

    ``mode`` is one of ``"standard"``, ``"iterative"`` or ``"triplet"``; any
    other value raises ``ValueError``.  The output head shares its weight
    matrix with the input embedding.
    """

    def __init__(self, d: int, layers: int, heads: int, mode: str = "standard"):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)

        # Map each mode name onto the block class that implements it.
        block_types = {
            "standard": StandardBlock,
            "iterative": IterativeBlock,
            "triplet": TripletBlock,
        }
        block_cls = block_types.get(mode)
        if block_cls is None:
            raise ValueError(f"Unknown mode: {mode}")

        stack = []
        for _ in range(layers):
            stack.append(block_cls(d, heads))
        self.blocks = nn.ModuleList(stack)

        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB)
        self.mode = mode

        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, ids, mask=None):
        """Map token ids to vocabulary logits; ``mask`` is passed to every block."""
        hidden = self.emb(ids)
        for block in self.blocks:
            hidden = block(hidden, mask)
        normed = self.ln(hidden)
        return self.head(normed)

    def count_params(self):
        """Return the total number of elements across ``self.parameters()``."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
339
+
340
+
341
+ # ─────────────────────────── Experiment Runner ───────────────────────────
342
def causal_mask(n):
    """Return a (1, 1, n, n) additive mask: -inf above the diagonal, 0 elsewhere."""
    blocked = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    # Keeping only the strict upper triangle leaves 0 on and below the diagonal.
    return torch.triu(blocked, diagonal=1)
344
+
345
+
346
def run_experiment(mode: str, d: int, layers: int, heads: int,
                   batch_size: int, seq_len: int, num_steps: int):
    """Train a fresh model on random token batches and report loss and throughput.

    Args:
        mode: Block variant passed to ``HeavyTransformer``.
        d: Model dimension.
        layers: Number of blocks.
        heads: Number of attention heads.
        batch_size: Sequences per step.
        seq_len: Tokens drawn per sequence; the model sees ``seq_len - 1``
            inputs and predicts the token that follows each one.
        num_steps: Number of optimisation steps.

    Returns:
        Dict with ``mode``, ``final_loss``, ``avg_loss`` and
        ``avg_tok_per_sec`` (both averaged over the last up-to-20 steps) and
        ``params``.

    Raises:
        ValueError: If ``num_steps < 1`` or ``seq_len < 2``.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if seq_len < 2:
        raise ValueError(f"seq_len must be >= 2, got {seq_len}")

    print(f"\n{'='*60}")
    print(f"MODE: {mode.upper()}")
    print(f"Config: d={d}, layers={layers}, heads={heads}")
    print(f"{'='*60}")

    model = HeavyTransformer(d, layers, heads, mode=mode).to(DEV)
    print(f"Parameters: {model.count_params():,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    def sync():
        # CUDA kernels are launched asynchronously; without waiting for them
        # the wall-clock timer would stop before the work has finished.
        if DEV.type == "cuda":
            torch.cuda.synchronize()

    losses = []
    times = []

    for step in range(num_steps):
        # Random batch: inputs are tokens [0, n-1), targets are tokens [1, n).
        ids = torch.randint(0, VOCAB, (batch_size, seq_len), device=DEV)
        target = ids[:, 1:]
        input_ids = ids[:, :-1]
        mask = causal_mask(seq_len - 1)

        sync()
        start = time.perf_counter()

        optimizer.zero_grad()
        logits = model(input_ids, mask)
        loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
        loss.backward()
        optimizer.step()

        sync()
        elapsed = time.perf_counter() - start

        loss_val = loss.item()
        times.append(elapsed)
        losses.append(loss_val)

        tok_per_sec = (batch_size * seq_len) / elapsed

        if step % 10 == 0 or step == num_steps - 1:
            print(f"Step {step:3d} | Loss: {loss_val:.4f} | {tok_per_sec:.0f} tok/s | {elapsed*1000:.0f}ms")

        # For iterative attention, show the stats recorded by the first block.
        if mode == "iterative" and hasattr(model.blocks[0].attn, '_last_iters'):
            if step % 20 == 0:
                avg_iters = model.blocks[0].attn._last_iters
                compute_ratio = model.blocks[0].attn._last_compute_ratio
                print(f" └─ Avg iters: {avg_iters}, Compute ratio: {compute_ratio:.2%}")

    # Average over the last (up to) 20 steps so warm-up steps are excluded
    # whenever enough steps were run.
    avg_loss = sum(losses[-20:]) / min(20, len(losses))
    avg_time = sum(times[-20:]) / min(20, len(times))
    avg_toks = (batch_size * seq_len) / avg_time

    return {
        "mode": mode,
        "final_loss": losses[-1],
        "avg_loss": avg_loss,
        "avg_tok_per_sec": avg_toks,
        "params": model.count_params()
    }
404
+
405
+
406
def main():
    """Parse CLI options, run the selected experiment mode(s) and print a comparison."""
    parser = argparse.ArgumentParser(description="Heavy Attention Experiment")
    parser.add_argument("--d", type=int, default=256, help="Model dimension")
    parser.add_argument("--layers", type=int, default=4, help="Number of layers")
    parser.add_argument("--heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--seq", type=int, default=128, help="Sequence length")
    parser.add_argument("--steps", type=int, default=100, help="Training steps")
    parser.add_argument("--mode", type=str, default="all",
                        choices=["standard", "iterative", "triplet", "all"])
    args = parser.parse_args()

    print(f"Device: {DEV}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # "all" expands to every variant, baseline first.
    if args.mode == "all":
        modes = ["standard", "iterative", "triplet"]
    else:
        modes = [args.mode]

    banner = f"{'='*60}"
    results = []

    for mode in modes:
        # A failing variant is reported and skipped so the others still run.
        try:
            outcome = run_experiment(
                mode=mode,
                d=args.d,
                layers=args.layers,
                heads=args.heads,
                batch_size=args.batch,
                seq_len=args.seq,
                num_steps=args.steps
            )
            results.append(outcome)
        except Exception as e:
            print(f"ERROR in {mode}: {e}")
            import traceback
            traceback.print_exc()

    # Summary table.
    print(f"\n{banner}")
    print("SUMMARY")
    print(banner)
    for r in results:
        print(f"{r['mode']:12s} | Loss: {r['avg_loss']:.4f} | {r['avg_tok_per_sec']:6.0f} tok/s | {r['params']:,} params")

    # Relative comparison needs at least two completed runs.
    if len(results) < 2:
        return

    # Use the standard run as baseline; fall back to the first result if it is missing.
    baseline = results[0]
    for r in results:
        if r['mode'] == 'standard':
            baseline = r
            break

    print(f"\n{banner}")
    print("RELATIVE TO STANDARD:")
    print(banner)
    for r in results:
        if r['mode'] == 'standard':
            continue
        loss_diff = (baseline['avg_loss'] - r['avg_loss']) / baseline['avg_loss'] * 100
        speed_ratio = r['avg_tok_per_sec'] / baseline['avg_tok_per_sec']
        print(f"{r['mode']:12s} | Loss: {loss_diff:+.1f}% | Speed: {speed_ratio:.2f}x")


if __name__ == "__main__":
    main()