OpenTransformer commited on
Commit
a1e7fdb
Β·
verified Β·
1 Parent(s): 2db758d

Add experiments/n_ultra.py

Browse files
Files changed (1) hide show
  1. experiments/n_ultra.py +715 -0
experiments/n_ultra.py ADDED
@@ -0,0 +1,715 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ n_ultra.py β€” ULTRA Heavy Attention Experiments
4
+ Mechanisms that are borderline impractical but theoretically interesting
5
+
6
+ 1. Neural Turing Machine (NTM) - Full differentiable computer
7
+ 2. Energy-Based Attention - Iterative energy minimization
8
+ 3. Cross-Layer Attention Lattice - Every layer attends to all others
9
+ 4. Continuous Depth (Neural ODE) - Infinite depth limit
10
+ 5. Full N-Body Dynamics - Physics-inspired message passing
11
+ 6. Hypernetwork Attention - Generate attention weights with another network
12
+ """
13
+
14
+ from __future__ import annotations
15
+ import argparse, math, time
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ torch.backends.cuda.matmul.allow_tf32 = True
22
+ try:
23
+ torch.set_float32_matmul_precision("high")
24
+ except:
25
+ pass
26
+
27
+ VOCAB = 128256
28
+
29
+ def _alibi_slopes(n_heads: int):
30
+ def pow2slopes(n):
31
+ start = 2 ** (-2 ** -(math.log2(n) - 3))
32
+ return [start * (start ** i) for i in range(n)]
33
+ if n_heads > 0 and math.log2(n_heads).is_integer():
34
+ vals = pow2slopes(n_heads)
35
+ else:
36
+ closest = 2 ** math.floor(math.log2(max(1, n_heads)))
37
+ vals = pow2slopes(closest)
38
+ extra = pow2slopes(2 * closest)
39
+ vals += extra[0::2][:n_heads - closest]
40
+ return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
41
+
42
+ def alibi_bias(n_heads: int, n_tokens: int):
43
+ i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
44
+ j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
45
+ dist = (j - i).clamp_min(0).float()
46
+ slopes = _alibi_slopes(n_heads)
47
+ return -slopes * dist
48
+
49
+ def causal_mask(n):
50
+ return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
51
+
52
+
53
+ # ═══════════════════════════════════════════════════════════════
54
+ # BASELINE
55
+ # ═══════════════════════════════════════════════════════════════
56
+ class StandardAttention(nn.Module):
57
+ def __init__(self, d: int, h: int):
58
+ super().__init__()
59
+ self.h, self.dk = h, d // h
60
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
61
+ self.proj = nn.Linear(d, d, bias=False)
62
+
63
+ def forward(self, x, mask=None, **kwargs):
64
+ B, N, _ = x.shape
65
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
66
+ q, k, v = qkv[0], qkv[1], qkv[2]
67
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
68
+ att = att + alibi_bias(self.h, N)
69
+ if mask is not None:
70
+ att = att + mask
71
+ z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
72
+ return self.proj(z)
73
+
74
+
75
+ # ═══════════════════════════════════════════════════════════════
76
+ # ULTRA 1: Neural Turing Machine (NTM)
77
+ # Full differentiable computer with external memory + read/write heads
78
+ # ═══════════════════════════════════════════════════════════════
79
+ class NTMAttention(nn.Module):
80
+ """
81
+ Neural Turing Machine: external memory matrix with content + location addressing.
82
+
83
+ Each forward pass:
84
+ 1. Read from memory using attention over memory slots
85
+ 2. Process with self-attention augmented by memory
86
+ 3. Write to memory using learned write weights
87
+
88
+ Memory operations are fully differentiable.
89
+ O(nΒ² + n*M*read_heads + M*write_ops)
90
+ """
91
+ def __init__(self, d: int, h: int, mem_slots: int = 128, num_heads: int = 4):
92
+ super().__init__()
93
+ self.d = d
94
+ self.h, self.dk = h, d // h
95
+ self.mem_slots = mem_slots
96
+ self.num_read_heads = num_heads
97
+
98
+ # Memory (persistent across sequence, reset per batch)
99
+ self.mem_init = nn.Parameter(torch.randn(1, mem_slots, d) * 0.01)
100
+
101
+ # Read heads - content-based addressing
102
+ self.read_key = nn.Linear(d, d * num_heads)
103
+ self.read_beta = nn.Linear(d, num_heads) # Sharpening
104
+ self.read_gate = nn.Linear(d, num_heads) # Interpolation gate
105
+ self.read_shift = nn.Linear(d, num_heads * 3) # Location shift (-1, 0, +1)
106
+
107
+ # Write head
108
+ self.write_key = nn.Linear(d, d)
109
+ self.write_beta = nn.Linear(d, 1)
110
+ self.erase_vec = nn.Linear(d, d)
111
+ self.add_vec = nn.Linear(d, d)
112
+
113
+ # Standard attention components
114
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
115
+ self.proj = nn.Linear(d * 2, d, bias=False) # Concat self-attn + read
116
+
117
+ def _content_addressing(self, memory, keys, betas):
118
+ """Compute attention weights based on content similarity"""
119
+ # memory: (B, M, D), keys: (B, N, H, D), betas: (B, N, H)
120
+ B, M, D = memory.shape
121
+ _, N, H, _ = keys.shape
122
+
123
+ # Cosine similarity
124
+ mem_norm = F.normalize(memory, dim=-1) # (B, M, D)
125
+ key_norm = F.normalize(keys, dim=-1) # (B, N, H, D)
126
+
127
+ # (B, N, H, D) @ (B, D, M) -> (B, N, H, M)
128
+ sim = torch.einsum('bnhd,bmd->bnhm', key_norm, mem_norm)
129
+
130
+ # Sharpen with beta
131
+ weights = F.softmax(betas.unsqueeze(-1) * sim, dim=-1) # (B, N, H, M)
132
+ return weights
133
+
134
+ def _location_shift(self, weights, shift_logits):
135
+ """Convolutional shift for location-based addressing"""
136
+ B, N, H, M = weights.shape
137
+ shift = F.softmax(shift_logits.view(B, N, H, 3), dim=-1) # (B, N, H, 3)
138
+
139
+ # Manual circular shift instead of padding
140
+ shifted = torch.zeros_like(weights)
141
+ shifted += shift[:, :, :, 0:1] * torch.roll(weights, 1, dims=-1) # left
142
+ shifted += shift[:, :, :, 1:2] * weights # center
143
+ shifted += shift[:, :, :, 2:3] * torch.roll(weights, -1, dims=-1) # right
144
+ return shifted
145
+
146
+ def forward(self, x, mask=None, **kwargs):
147
+ B, N, D = x.shape
148
+
149
+ # Initialize memory for this batch
150
+ memory = self.mem_init.expand(B, -1, -1).clone() # (B, M, D)
151
+
152
+ # === READ OPERATION ===
153
+ read_keys = self.read_key(x).view(B, N, self.num_read_heads, D)
154
+ read_betas = F.softplus(self.read_beta(x)) # (B, N, H)
155
+ read_gates = torch.sigmoid(self.read_gate(x)) # (B, N, H)
156
+ read_shifts = self.read_shift(x) # (B, N, H*3)
157
+
158
+ # Content-based weights
159
+ content_weights = self._content_addressing(memory, read_keys, read_betas)
160
+
161
+ # Location-based shift
162
+ shifted_weights = self._location_shift(content_weights, read_shifts)
163
+
164
+ # Interpolate (simplified - just use content weights)
165
+ read_weights = content_weights # (B, N, H, M)
166
+
167
+ # Read from memory
168
+ # (B, N, H, M) @ (B, M, D) -> (B, N, H, D)
169
+ read_vectors = torch.einsum('bnhm,bmd->bnhd', read_weights, memory)
170
+ read_out = read_vectors.mean(dim=2) # Average across heads (B, N, D)
171
+
172
+ # === SELF-ATTENTION ===
173
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
174
+ q, k, v = qkv[0], qkv[1], qkv[2]
175
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
176
+ att = att + alibi_bias(self.h, N)
177
+ if mask is not None:
178
+ att = att + mask
179
+ self_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
180
+
181
+ # === WRITE OPERATION ===
182
+ write_key = self.write_key(x[:, -1:, :]) # Use last position (B, 1, D)
183
+ write_beta = F.softplus(self.write_beta(x[:, -1:, :]))
184
+ write_weights = self._content_addressing(
185
+ memory,
186
+ write_key.unsqueeze(2), # (B, 1, 1, D)
187
+ write_beta.squeeze(-1).unsqueeze(-1) # (B, 1, 1)
188
+ ).squeeze(2) # (B, 1, M)
189
+
190
+ # Erase and add
191
+ erase = torch.sigmoid(self.erase_vec(x[:, -1:, :])) # (B, 1, D)
192
+ add = self.add_vec(x[:, -1:, :]) # (B, 1, D)
193
+
194
+ # Memory update (for next call - not used in this forward)
195
+ # memory = memory * (1 - write_weights.transpose(-1,-2) @ erase)
196
+ # memory = memory + write_weights.transpose(-1,-2) @ add
197
+
198
+ # Combine self-attention and memory read
199
+ combined = torch.cat([self_out, read_out], dim=-1)
200
+ return self.proj(combined)
201
+
202
+
203
+ # ═══════════════════════════════════════════════════════════════
204
+ # ULTRA 2: Energy-Based Attention
205
+ # Iterative energy minimization instead of single softmax
206
+ # ═══════════════════════════════════════════════════════════════
207
+ class EnergyAttention(nn.Module):
208
+ """
209
+ Energy-based model for attention: find attention weights that minimize energy.
210
+
211
+ E(a, q, k, v) = -sum(a_ij * sim(q_i, k_j)) + entropy(a) + prior
212
+
213
+ Iterate gradient descent on attention weights until convergence.
214
+ Much heavier than softmax but potentially more expressive.
215
+
216
+ O(iters * nΒ²)
217
+ """
218
+ def __init__(self, d: int, h: int, num_iters: int = 10, step_size: float = 0.5):
219
+ super().__init__()
220
+ self.h, self.dk = h, d // h
221
+ self.num_iters = num_iters
222
+ self.step_size = step_size
223
+
224
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
225
+ self.proj = nn.Linear(d, d, bias=False)
226
+
227
+ # Learnable energy function parameters
228
+ self.energy_scale = nn.Parameter(torch.ones(h))
229
+ self.temperature = nn.Parameter(torch.ones(h) * 0.1)
230
+
231
+ def _compute_energy(self, attn_logits, attn_weights, mask):
232
+ """
233
+ Energy = -similarity + temperature * entropy
234
+ Lower energy = better attention pattern
235
+ """
236
+ # Similarity term (want to maximize, so negate)
237
+ sim_energy = -attn_logits * attn_weights
238
+
239
+ # Entropy regularization (encourage sharpness)
240
+ entropy = -attn_weights * torch.log(attn_weights + 1e-10)
241
+
242
+ # Total energy per head
243
+ temp = self.temperature.view(1, -1, 1, 1)
244
+ energy = sim_energy.sum(dim=-1) + temp * entropy.sum(dim=-1)
245
+
246
+ return energy.mean()
247
+
248
+ def forward(self, x, mask=None, **kwargs):
249
+ B, N, _ = x.shape
250
+
251
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
252
+ q, k, v = qkv[0], qkv[1], qkv[2]
253
+
254
+ # Initial attention logits
255
+ scale = self.energy_scale.view(1, -1, 1, 1)
256
+ attn_logits = scale * (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
257
+ attn_logits = attn_logits + alibi_bias(self.h, N)
258
+
259
+ if mask is not None:
260
+ attn_logits = attn_logits + mask
261
+
262
+ # Initialize attention weights with softmax
263
+ attn_weights = F.softmax(attn_logits, dim=-1)
264
+
265
+ # Iterative refinement via energy minimization
266
+ for _ in range(self.num_iters):
267
+ # Compute gradient of energy w.r.t. attention weights
268
+ # Simplified: use attention logits as gradient signal
269
+
270
+ # Energy gradient approximation
271
+ with torch.enable_grad():
272
+ attn_weights_param = attn_weights.detach().requires_grad_(True)
273
+ energy = self._compute_energy(attn_logits, attn_weights_param, mask)
274
+ grad = torch.autograd.grad(energy, attn_weights_param)[0]
275
+
276
+ # Gradient step in logit space
277
+ attn_logits_new = attn_logits - self.step_size * grad
278
+
279
+ # Project back to valid distribution
280
+ if mask is not None:
281
+ attn_logits_new = attn_logits_new + mask
282
+ attn_weights = F.softmax(attn_logits_new, dim=-1)
283
+
284
+ z = (attn_weights @ v).transpose(1, 2).reshape(B, N, -1)
285
+ return self.proj(z)
286
+
287
+
288
+ # ═══════════════════════════════════════════════════════════════
289
+ # ULTRA 3: Cross-Layer Attention Lattice
290
+ # Every layer can attend to outputs of ALL other layers
291
+ # ═══════════════════════════════════════════════════════════════
292
+ class LatticeAttention(nn.Module):
293
+ """
294
+ Instead of sequential layers, create a lattice where each layer
295
+ can attend to all other layers' outputs.
296
+
297
+ Requires storing all layer outputs and recomputing.
298
+ O(LΒ² * nΒ²) where L = number of layers
299
+
300
+ This is implemented at the model level, not attention level.
301
+ """
302
+ def __init__(self, d: int, h: int, cross_layers: int = 4):
303
+ super().__init__()
304
+ self.h, self.dk = h, d // h
305
+ self.cross_layers = cross_layers
306
+
307
+ # Self-attention
308
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
309
+
310
+ # Cross-layer attention (query current, key/value from other layers)
311
+ self.cross_q = nn.Linear(d, d, bias=False)
312
+ self.cross_kv = nn.Linear(d, 2 * d, bias=False)
313
+
314
+ # Combine self and cross
315
+ self.proj = nn.Linear(d * 2, d, bias=False)
316
+
317
+ # Store for lattice
318
+ self.layer_outputs = None
319
+
320
+ def forward(self, x, mask=None, layer_idx=0, all_layers=None, **kwargs):
321
+ B, N, _ = x.shape
322
+
323
+ # Self-attention
324
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
325
+ q, k, v = qkv[0], qkv[1], qkv[2]
326
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
327
+ att = att + alibi_bias(self.h, N)
328
+ if mask is not None:
329
+ att = att + mask
330
+ self_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
331
+
332
+ # Cross-layer attention (if we have other layer outputs)
333
+ if all_layers is not None and len(all_layers) > 0:
334
+ # Stack all previous layer outputs
335
+ stacked = torch.stack(all_layers, dim=2) # (B, N, L, D)
336
+ B, N, L, D = stacked.shape
337
+
338
+ # Query from current, key/value from all layers
339
+ cross_q = self.cross_q(x).view(B, N, self.h, self.dk) # (B, N, H, dk)
340
+
341
+ # Reshape for cross attention
342
+ stacked_flat = stacked.view(B, N * L, D)
343
+ cross_kv = self.cross_kv(stacked_flat).view(B, N * L, 2, self.h, self.dk)
344
+ cross_k, cross_v = cross_kv[:, :, 0], cross_kv[:, :, 1]
345
+
346
+ # Cross attention
347
+ cross_q = cross_q.transpose(1, 2) # (B, H, N, dk)
348
+ cross_k = cross_k.view(B, N * L, self.h, self.dk).transpose(1, 2)
349
+ cross_v = cross_v.view(B, N * L, self.h, self.dk).transpose(1, 2)
350
+
351
+ cross_att = (cross_q @ cross_k.transpose(-1, -2)) / math.sqrt(self.dk)
352
+ cross_out = (cross_att.softmax(-1) @ cross_v).transpose(1, 2).reshape(B, N, -1)
353
+ else:
354
+ cross_out = torch.zeros_like(self_out)
355
+
356
+ combined = torch.cat([self_out, cross_out], dim=-1)
357
+ return self.proj(combined)
358
+
359
+
360
+ # ═══════════════════════════════════════════════════════════════
361
+ # ULTRA 4: N-Body Dynamics Attention
362
+ # Treat tokens as particles with forces between them
363
+ # ═══════════════════════════════════════════════════════════════
364
+ class NBodyAttention(nn.Module):
365
+ """
366
+ Physics-inspired: tokens are particles with forces.
367
+ Simplified version that avoids shape complexity.
368
+ """
369
+ def __init__(self, d: int, h: int, num_steps: int = 5, dt: float = 0.1):
370
+ super().__init__()
371
+ self.d = d
372
+ self.num_steps = num_steps
373
+ self.dt = dt
374
+
375
+ self.to_pos = nn.Linear(d, d)
376
+ self.to_vel = nn.Linear(d, d)
377
+
378
+ # Simplified force: pairwise similarity drives attraction
379
+ self.force_scale = nn.Parameter(torch.ones(1) * 0.1)
380
+
381
+ self.out_proj = nn.Linear(d * 2, d)
382
+
383
+ def forward(self, x, mask=None, **kwargs):
384
+ B, N, D = x.shape
385
+
386
+ pos = self.to_pos(x)
387
+ vel = self.to_vel(x)
388
+
389
+ # Causal mask
390
+ causal = torch.triu(torch.ones(N, N, device=x.device), diagonal=1)
391
+ causal_mask = 1.0 - causal # (N, N) lower triangular
392
+
393
+ for _ in range(self.num_steps):
394
+ # Pairwise distances
395
+ pos_diff = pos.unsqueeze(2) - pos.unsqueeze(1) # (B, N, N, D)
396
+ dist_sq = (pos_diff ** 2).sum(-1, keepdim=True) + 1e-6 # (B, N, N, 1)
397
+
398
+ # Force proportional to 1/distance (like gravity)
399
+ force_dir = pos_diff / (dist_sq.sqrt() + 1e-6) # (B, N, N, D)
400
+ force_mag = self.force_scale / dist_sq # (B, N, N, 1)
401
+ forces = force_dir * force_mag # (B, N, N, D)
402
+
403
+ # Apply causal mask
404
+ forces = forces * causal_mask.view(1, N, N, 1)
405
+
406
+ # Sum forces
407
+ total_force = forces.sum(dim=2) # (B, N, D)
408
+
409
+ # Update
410
+ vel = vel + self.dt * total_force
411
+ pos = pos + self.dt * vel
412
+
413
+ out = torch.cat([pos, vel], dim=-1)
414
+ return self.out_proj(out)
415
+
416
+
417
+ # ═══════════════════════════════════════════════════════════════
418
+ # ULTRA 5: Hypernetwork Attention
419
+ # A separate network generates the attention weights
420
+ # ═══════════════════════════════════════════════════════════════
421
+ class HyperAttention(nn.Module):
422
+ """
423
+ Instead of QK^T -> softmax, use a hypernetwork to generate attention.
424
+
425
+ The hypernetwork takes (query_token, key_token) and outputs attention weight.
426
+ Much more expressive but O(nΒ² * hypernetwork_cost).
427
+ """
428
+ def __init__(self, d: int, h: int, hyper_hidden: int = 64):
429
+ super().__init__()
430
+ self.h, self.dk = h, d // h
431
+
432
+ self.to_q = nn.Linear(d, d, bias=False)
433
+ self.to_k = nn.Linear(d, d, bias=False)
434
+ self.to_v = nn.Linear(d, d, bias=False)
435
+
436
+ # Hypernetwork: generates attention weight from (q, k) pair
437
+ self.hypernet = nn.Sequential(
438
+ nn.Linear(self.dk * 2, hyper_hidden),
439
+ nn.SiLU(),
440
+ nn.Linear(hyper_hidden, hyper_hidden),
441
+ nn.SiLU(),
442
+ nn.Linear(hyper_hidden, 1)
443
+ )
444
+
445
+ self.proj = nn.Linear(d, d, bias=False)
446
+
447
+ def forward(self, x, mask=None, **kwargs):
448
+ B, N, _ = x.shape
449
+
450
+ q = self.to_q(x).view(B, N, self.h, self.dk) # (B, N, H, dk)
451
+ k = self.to_k(x).view(B, N, self.h, self.dk)
452
+ v = self.to_v(x).view(B, N, self.h, self.dk)
453
+
454
+ # Compute attention via hypernetwork
455
+ # Need to process all (i, j) pairs
456
+ attn_logits = torch.zeros(B, self.h, N, N, device=x.device)
457
+
458
+ for head in range(self.h):
459
+ q_h = q[:, :, head, :] # (B, N, dk)
460
+ k_h = k[:, :, head, :]
461
+
462
+ # Expand for pairwise
463
+ q_exp = q_h.unsqueeze(2).expand(-1, -1, N, -1) # (B, N, N, dk)
464
+ k_exp = k_h.unsqueeze(1).expand(-1, N, -1, -1) # (B, N, N, dk)
465
+
466
+ # Concatenate and run through hypernetwork
467
+ pair_input = torch.cat([q_exp, k_exp], dim=-1) # (B, N, N, 2*dk)
468
+ attn_logits[:, head] = self.hypernet(pair_input).squeeze(-1) # (B, N, N)
469
+
470
+ # Add ALiBi bias
471
+ attn_logits = attn_logits + alibi_bias(self.h, N)
472
+
473
+ if mask is not None:
474
+ attn_logits = attn_logits + mask
475
+
476
+ attn_weights = F.softmax(attn_logits, dim=-1) # (B, H, N, N)
477
+
478
+ # Apply attention
479
+ v = v.transpose(1, 2) # (B, H, N, dk)
480
+ out = (attn_weights @ v).transpose(1, 2).reshape(B, N, -1)
481
+
482
+ return self.proj(out)
483
+
484
+
485
+ # ═══════════════════════════════════════════════════════════════
486
+ # ULTRA 6: Differentiable Sorting Attention
487
+ # Sort tokens by relevance, attend in sorted order
488
+ # ═══════════════════════════════════════════════════════════════
489
+ class SortingAttention(nn.Module):
490
+ """
491
+ Differentiable sorting: learn to reorder tokens by importance,
492
+ then apply attention in sorted space.
493
+
494
+ Uses Sinkhorn operator for soft permutation matrices.
495
+ O(sinkhorn_iters * nΒ² + nΒ²)
496
+ """
497
+ def __init__(self, d: int, h: int, sinkhorn_iters: int = 10, temp: float = 0.1):
498
+ super().__init__()
499
+ self.h, self.dk = h, d // h
500
+ self.sinkhorn_iters = sinkhorn_iters
501
+ self.temp = temp
502
+
503
+ # Scoring network for sorting
504
+ self.score = nn.Linear(d, 1)
505
+
506
+ # Standard attention
507
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
508
+ self.proj = nn.Linear(d, d, bias=False)
509
+
510
+ def _sinkhorn(self, log_alpha, iters):
511
+ """Sinkhorn normalization for soft permutation"""
512
+ for _ in range(iters):
513
+ log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
514
+ log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)
515
+ return torch.exp(log_alpha)
516
+
517
+ def forward(self, x, mask=None, **kwargs):
518
+ B, N, D = x.shape
519
+
520
+ # Compute sorting scores
521
+ scores = self.score(x).squeeze(-1) # (B, N)
522
+
523
+ # Create soft permutation matrix via Sinkhorn
524
+ # log_alpha[i,j] = score[i] (want row i to go to position based on score)
525
+ log_alpha = scores.unsqueeze(-1) - scores.unsqueeze(-2) # (B, N, N)
526
+ log_alpha = log_alpha / self.temp
527
+
528
+ perm = self._sinkhorn(log_alpha, self.sinkhorn_iters) # (B, N, N)
529
+
530
+ # Apply permutation to get sorted tokens
531
+ x_sorted = torch.einsum('bnm,bmd->bnd', perm, x) # (B, N, D)
532
+
533
+ # Standard attention on sorted tokens
534
+ qkv = self.qkv(x_sorted).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
535
+ q, k, v = qkv[0], qkv[1], qkv[2]
536
+
537
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
538
+ att = att + alibi_bias(self.h, N)
539
+ if mask is not None:
540
+ att = att + mask
541
+
542
+ out_sorted = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
543
+
544
+ # Inverse permutation to restore order
545
+ perm_inv = perm.transpose(-1, -2)
546
+ out = torch.einsum('bnm,bmd->bnd', perm_inv, out_sorted)
547
+
548
+ return self.proj(out)
549
+
550
+
551
+ # ═══════════════════════════════════════════════════════════════
552
+ # Block and Model
553
+ # ═══════════════════════════════════════════════════════════════
554
+ class Block(nn.Module):
555
+ def __init__(self, d: int, h: int, attn_type: str = "standard", **kwargs):
556
+ super().__init__()
557
+ self.ln1 = nn.LayerNorm(d)
558
+ self.ln2 = nn.LayerNorm(d)
559
+
560
+ attn_map = {
561
+ "standard": StandardAttention,
562
+ "ntm": NTMAttention,
563
+ "energy": EnergyAttention,
564
+ "lattice": LatticeAttention,
565
+ "nbody": NBodyAttention,
566
+ "hyper": HyperAttention,
567
+ "sorting": SortingAttention,
568
+ }
569
+
570
+ if attn_type not in attn_map:
571
+ raise ValueError(f"Unknown: {attn_type}")
572
+
573
+ self.attn = attn_map[attn_type](d, h, **kwargs)
574
+ self.attn_type = attn_type
575
+
576
+ self.ff = nn.Sequential(
577
+ nn.Linear(d, 4 * d),
578
+ nn.GELU(),
579
+ nn.Linear(4 * d, d)
580
+ )
581
+
582
+ def forward(self, x, mask=None, **kwargs):
583
+ x = x + self.attn(self.ln1(x), mask, **kwargs)
584
+ x = x + self.ff(self.ln2(x))
585
+ return x
586
+
587
+
588
+ class UltraModel(nn.Module):
589
+ def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard", **kwargs):
590
+ super().__init__()
591
+ self.emb = nn.Embedding(VOCAB, d)
592
+ self.blocks = nn.ModuleList([Block(d, h, attn_type, **kwargs) for _ in range(layers)])
593
+ self.ln = nn.LayerNorm(d)
594
+ self.head = nn.Linear(d, VOCAB, bias=False)
595
+ self.head.weight = self.emb.weight
596
+ self.attn_type = attn_type
597
+
598
+ def forward(self, x, mask=None):
599
+ x = self.emb(x)
600
+
601
+ if self.attn_type == "lattice":
602
+ all_layers = []
603
+ for blk in self.blocks:
604
+ x = blk(x, mask, all_layers=all_layers)
605
+ all_layers.append(x.detach())
606
+ else:
607
+ for blk in self.blocks:
608
+ x = blk(x, mask)
609
+
610
+ return self.head(self.ln(x))
611
+
612
+ def count_params(self):
613
+ return sum(p.numel() for p in self.parameters())
614
+
615
+
616
+ # ═══════════════════════════════════════════════════════════════
617
+ # Experiment Runner
618
+ # ═══════════════════════════════════════════════════════════════
619
+ def run_experiment(attn_type, d, layers, heads, batch, seq, steps, **kwargs):
620
+ print(f"\n{'='*60}")
621
+ print(f"ULTRA ATTENTION: {attn_type.upper()}")
622
+ print(f"{'='*60}")
623
+
624
+ try:
625
+ model = UltraModel(d, layers, heads, attn_type, **kwargs).to(DEV)
626
+ except Exception as e:
627
+ print(f"Failed to create model: {e}")
628
+ return None
629
+
630
+ print(f"Parameters: {model.count_params():,}")
631
+
632
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
633
+ mask = causal_mask(seq - 1)
634
+
635
+ losses, times = [], []
636
+
637
+ for step in range(steps):
638
+ ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)
639
+ target = ids[:, 1:]
640
+ input_ids = ids[:, :-1]
641
+
642
+ start = time.time()
643
+ optimizer.zero_grad()
644
+
645
+ try:
646
+ logits = model(input_ids, mask)
647
+ loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
648
+ loss.backward()
649
+ optimizer.step()
650
+ except RuntimeError as e:
651
+ print(f"Step {step} failed: {e}")
652
+ break
653
+
654
+ elapsed = time.time() - start
655
+ losses.append(loss.item())
656
+ times.append(elapsed)
657
+ tok_s = (batch * seq) / elapsed
658
+
659
+ if step % 10 == 0 or step == steps - 1:
660
+ print(f"Step {step:3d} | Loss: {loss.item():.4f} | {tok_s:.0f} tok/s | {elapsed*1000:.0f}ms")
661
+
662
+ if not losses:
663
+ return None
664
+
665
+ avg_loss = sum(losses[-20:]) / min(20, len(losses))
666
+ avg_time = sum(times[-20:]) / min(20, len(times))
667
+ avg_toks = (batch * seq) / avg_time
668
+
669
+ return {"type": attn_type, "loss": avg_loss, "tok_s": avg_toks, "params": model.count_params()}
670
+
671
+
672
+ def main():
673
+ parser = argparse.ArgumentParser()
674
+ parser.add_argument("--d", type=int, default=256)
675
+ parser.add_argument("--layers", type=int, default=4)
676
+ parser.add_argument("--heads", type=int, default=8)
677
+ parser.add_argument("--batch", type=int, default=8)
678
+ parser.add_argument("--seq", type=int, default=64) # Shorter for ultra-heavy
679
+ parser.add_argument("--steps", type=int, default=50)
680
+ parser.add_argument("--types", type=str, default="all")
681
+ args = parser.parse_args()
682
+
683
+ print(f"Device: {DEV}")
684
+ if torch.cuda.is_available():
685
+ print(f"GPU: {torch.cuda.get_device_name()}")
686
+
687
+ if args.types == "all":
688
+ types = ["standard", "ntm", "energy", "nbody", "hyper", "sorting"]
689
+ else:
690
+ types = [t.strip() for t in args.types.split(",")]
691
+
692
+ results = []
693
+ for t in types:
694
+ r = run_experiment(t, args.d, args.layers, args.heads,
695
+ args.batch, args.seq, args.steps)
696
+ if r:
697
+ results.append(r)
698
+ torch.cuda.empty_cache()
699
+
700
+ print(f"\n{'='*60}")
701
+ print("SUMMARY")
702
+ print(f"{'='*60}")
703
+
704
+ baseline = next((r for r in results if r['type'] == 'standard'), None)
705
+ for r in results:
706
+ rel = ""
707
+ if baseline and r['type'] != 'standard':
708
+ loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
709
+ speed_ratio = r['tok_s'] / baseline['tok_s']
710
+ rel = f" | vs std: {loss_diff:+.1f}% loss, {speed_ratio:.2f}x speed"
711
+ print(f"{r['type']:12s} | Loss: {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s{rel}")
712
+
713
+
714
+ if __name__ == "__main__":
715
+ main()