OpenTransformer commited on
Commit
8a88d0a
·
verified ·
1 Parent(s): a1e7fdb

Add experiments/n_flex.py

Browse files
Files changed (1) hide show
  1. experiments/n_flex.py +665 -0
experiments/n_flex.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ n_flex.py — Flexible Attention Mechanisms
4
+ Constraint: Must support AR (causal), SAT (block), and NAR (bidirectional)
5
+
6
+ Testing:
7
+ 1. Linear Attention - O(n) instead of O(n²)
8
+ 2. Cosine Attention - Different similarity metric
9
+ 3. Differential Attention - Noise cancellation (Microsoft 2024)
10
+ 4. Local + Global - Sparse hybrid
11
+ 5. Multi-Query Attention (MQA) - Inference efficient
12
+ 6. Grouped Query Attention (GQA) - Between MHA and MQA
13
+ 7. Retention - RetNet style (recurrent + parallel)
14
+ 8. Gated Linear Attention - Recent efficient attention
15
+ 9. ReLU Attention - Simpler activation
16
+ 10. Sigmoid Attention - Bounded attention
17
+ """
18
+
19
+ from __future__ import annotations
20
+ import argparse, math, time
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from typing import Optional, Literal
25
+
26
+ DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ torch.backends.cuda.matmul.allow_tf32 = True
28
+ VOCAB = 128256
29
+
30
+ # ═══════════════════════════════════════════════════════════════
31
+ # Masking utilities for AR/SAT/NAR
32
+ # ═══════════════════════════════════════════════════════════════
33
+ def get_mask(n: int, mode: str = "ar", block_size: int = 2):
34
+ """
35
+ AR (autoregressive): causal, see only past
36
+ SAT (semi-autoregressive): see within block + all past blocks
37
+ NAR (non-autoregressive): bidirectional, see everything
38
+ """
39
+ if mode == "nar":
40
+ return None # No mask
41
+ elif mode == "ar":
42
+ return torch.triu(torch.full((n, n), float("-inf"), device=DEV), 1)
43
+ elif mode == "sat":
44
+ # Block-wise: can see within same block and all previous blocks
45
+ idx = torch.arange(n, device=DEV)
46
+ block_idx = idx // block_size
47
+ # Allow if same block OR target block is earlier
48
+ mask = torch.where(
49
+ (block_idx.unsqueeze(0) <= block_idx.unsqueeze(1)),
50
+ torch.tensor(0.0, device=DEV),
51
+ torch.tensor(float("-inf"), device=DEV)
52
+ )
53
+ return mask
54
+ else:
55
+ raise ValueError(f"Unknown mode: {mode}")
56
+
57
+
58
+ def alibi_bias(n_heads: int, n_tokens: int):
59
+ def slopes(n):
60
+ start = 2 ** (-2 ** -(math.log2(n) - 3))
61
+ return [start * (start ** i) for i in range(n)]
62
+ if n_heads > 0 and math.log2(n_heads).is_integer():
63
+ s = slopes(n_heads)
64
+ else:
65
+ closest = 2 ** math.floor(math.log2(max(1, n_heads)))
66
+ s = slopes(closest)[:n_heads]
67
+ s = torch.tensor(s, device=DEV).view(1, n_heads, 1, 1)
68
+ i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
69
+ j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
70
+ return -s * (j - i).clamp_min(0).float()
71
+
72
+
73
+ # ═══════════════════════════════════════════════════════════════
74
+ # 1. STANDARD (baseline)
75
+ # ═══════════════════════════════════════════════════════════════
76
+ class StandardAttention(nn.Module):
77
+ """Standard multi-head attention - O(n²)"""
78
+ def __init__(self, d: int, h: int):
79
+ super().__init__()
80
+ self.h, self.dk = h, d // h
81
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
82
+ self.proj = nn.Linear(d, d, bias=False)
83
+
84
+ def forward(self, x, mask=None):
85
+ B, N, _ = x.shape
86
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
87
+ q, k, v = qkv[0], qkv[1], qkv[2]
88
+
89
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
90
+ att = att + alibi_bias(self.h, N)
91
+ if mask is not None:
92
+ att = att + mask.unsqueeze(0).unsqueeze(0)
93
+
94
+ z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
95
+ return self.proj(z)
96
+
97
+
98
+ # ═══════════════════════════════════════════════════════════════
99
+ # 2. LINEAR ATTENTION - O(n) via kernel trick
100
+ # ═══════════════════════════════════════════════════════════════
101
+ class LinearAttention(nn.Module):
102
+ """
103
+ Linear attention: O(n) instead of O(n²)
104
+ Uses feature map φ(x) so that φ(q)φ(k)^T ≈ softmax(qk^T)
105
+
106
+ Key insight: (QK^T)V = Q(K^TV) - compute K^TV first for O(n)
107
+
108
+ Works with AR/SAT/NAR via cumsum tricks for causal
109
+ """
110
+ def __init__(self, d: int, h: int, feature_map: str = "elu"):
111
+ super().__init__()
112
+ self.h, self.dk = h, d // h
113
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
114
+ self.proj = nn.Linear(d, d, bias=False)
115
+ self.feature_map = feature_map
116
+ self.eps = 1e-6
117
+
118
+ def _phi(self, x):
119
+ """Feature map for linear attention"""
120
+ if self.feature_map == "elu":
121
+ return F.elu(x) + 1
122
+ elif self.feature_map == "relu":
123
+ return F.relu(x)
124
+ elif self.feature_map == "softmax":
125
+ return F.softmax(x, dim=-1)
126
+ else: # identity
127
+ return x
128
+
129
+ def forward(self, x, mask=None):
130
+ B, N, _ = x.shape
131
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
132
+ q, k, v = qkv[0], qkv[1], qkv[2] # (B, H, N, dk)
133
+
134
+ # Apply feature map
135
+ q = self._phi(q)
136
+ k = self._phi(k)
137
+
138
+ if mask is None:
139
+ # NAR: Full bidirectional - O(n) via associativity
140
+ # (Q @ K^T) @ V = Q @ (K^T @ V)
141
+ kv = torch.einsum('bhnd,bhnv->bhdv', k, v) # (B, H, dk, dv)
142
+ out = torch.einsum('bhnd,bhdv->bhnv', q, kv) # (B, H, N, dv)
143
+
144
+ # Normalize
145
+ k_sum = k.sum(dim=2, keepdim=True) # (B, H, 1, dk)
146
+ normalizer = torch.einsum('bhnd,bhkd->bhnk', q, k_sum).clamp(min=self.eps)
147
+ out = out / normalizer
148
+ else:
149
+ # AR/SAT: Causal via cumulative sum
150
+ # This is still O(n) but needs sequential computation
151
+ kv_cumsum = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
152
+ k_cumsum = torch.cumsum(k, dim=2)
153
+
154
+ out = torch.einsum('bhnd,bhndv->bhnv', q, kv_cumsum)
155
+ normalizer = torch.einsum('bhnd,bhnd->bhn', q, k_cumsum).unsqueeze(-1).clamp(min=self.eps)
156
+ out = out / normalizer
157
+
158
+ return self.proj(out.transpose(1, 2).reshape(B, N, -1))
159
+
160
+
161
+ # ═══════════════════════════════════════════════════════════════
162
+ # 3. COSINE ATTENTION - Different similarity metric
163
+ # ═══════════════════════════════════════════════════════════════
164
+ class CosineAttention(nn.Module):
165
+ """
166
+ Use cosine similarity instead of dot product.
167
+ More stable, bounded [-1, 1] before scaling.
168
+ """
169
+ def __init__(self, d: int, h: int, temp: float = 10.0):
170
+ super().__init__()
171
+ self.h, self.dk = h, d // h
172
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
173
+ self.proj = nn.Linear(d, d, bias=False)
174
+ self.temp = nn.Parameter(torch.tensor(temp))
175
+
176
+ def forward(self, x, mask=None):
177
+ B, N, _ = x.shape
178
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
179
+ q, k, v = qkv[0], qkv[1], qkv[2]
180
+
181
+ # Normalize for cosine similarity
182
+ q = F.normalize(q, dim=-1)
183
+ k = F.normalize(k, dim=-1)
184
+
185
+ att = self.temp * (q @ k.transpose(-1, -2)) # Cosine sim scaled by temp
186
+ if mask is not None:
187
+ att = att + mask.unsqueeze(0).unsqueeze(0)
188
+
189
+ z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
190
+ return self.proj(z)
191
+
192
+
193
+ # ═══════════════════════════════════════════════════════════════
194
+ # 4. DIFFERENTIAL ATTENTION - Noise cancellation
195
+ # ═══════════════════════════════════════════════════════════════
196
+ class DifferentialAttention(nn.Module):
197
+ """
198
+ From Microsoft's "Differential Transformer" (2024)
199
+
200
+ Compute two attention patterns and subtract:
201
+ Attn = softmax(Q1 K1^T) - λ * softmax(Q2 K2^T)
202
+
203
+ Cancels noise, improves signal.
204
+ """
205
+ def __init__(self, d: int, h: int):
206
+ super().__init__()
207
+ self.h, self.dk = h, d // h
208
+
209
+ # Two sets of Q, K projections
210
+ self.q1 = nn.Linear(d, d, bias=False)
211
+ self.k1 = nn.Linear(d, d, bias=False)
212
+ self.q2 = nn.Linear(d, d, bias=False)
213
+ self.k2 = nn.Linear(d, d, bias=False)
214
+ self.v = nn.Linear(d, d, bias=False)
215
+
216
+ # Learnable lambda for subtraction weight
217
+ self.lambda_param = nn.Parameter(torch.tensor(0.5))
218
+
219
+ self.proj = nn.Linear(d, d, bias=False)
220
+
221
+ def forward(self, x, mask=None):
222
+ B, N, _ = x.shape
223
+
224
+ q1 = self.q1(x).view(B, N, self.h, self.dk).transpose(1, 2)
225
+ k1 = self.k1(x).view(B, N, self.h, self.dk).transpose(1, 2)
226
+ q2 = self.q2(x).view(B, N, self.h, self.dk).transpose(1, 2)
227
+ k2 = self.k2(x).view(B, N, self.h, self.dk).transpose(1, 2)
228
+ v = self.v(x).view(B, N, self.h, self.dk).transpose(1, 2)
229
+
230
+ scale = math.sqrt(self.dk)
231
+
232
+ # First attention
233
+ att1 = (q1 @ k1.transpose(-1, -2)) / scale
234
+ if mask is not None:
235
+ att1 = att1 + mask.unsqueeze(0).unsqueeze(0)
236
+ att1 = att1.softmax(-1)
237
+
238
+ # Second attention
239
+ att2 = (q2 @ k2.transpose(-1, -2)) / scale
240
+ if mask is not None:
241
+ att2 = att2 + mask.unsqueeze(0).unsqueeze(0)
242
+ att2 = att2.softmax(-1)
243
+
244
+ # Differential: subtract weighted second from first
245
+ lam = torch.sigmoid(self.lambda_param)
246
+ att = att1 - lam * att2
247
+
248
+ # ReLU to ensure non-negative (optional, can remove)
249
+ att = F.relu(att)
250
+ att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)
251
+
252
+ z = (att @ v).transpose(1, 2).reshape(B, N, -1)
253
+ return self.proj(z)
254
+
255
+
256
+ # ═══════════════════════════════════════════════════════════════
257
+ # 5. MULTI-QUERY ATTENTION (MQA) - Inference efficient
258
+ # ═══════════════════════════════════════════════════════════════
259
+ class MultiQueryAttention(nn.Module):
260
+ """
261
+ MQA: Multiple query heads, single K/V head.
262
+ Massive inference speedup (smaller KV cache).
263
+ Same training cost as standard.
264
+ """
265
+ def __init__(self, d: int, h: int):
266
+ super().__init__()
267
+ self.h, self.dk = h, d // h
268
+
269
+ # H query heads, but only 1 K and 1 V head
270
+ self.q = nn.Linear(d, d, bias=False) # H heads
271
+ self.k = nn.Linear(d, self.dk, bias=False) # 1 head
272
+ self.v = nn.Linear(d, self.dk, bias=False) # 1 head
273
+ self.proj = nn.Linear(d, d, bias=False)
274
+
275
+ def forward(self, x, mask=None):
276
+ B, N, _ = x.shape
277
+
278
+ q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2) # (B, H, N, dk)
279
+ k = self.k(x).view(B, N, 1, self.dk).transpose(1, 2) # (B, 1, N, dk)
280
+ v = self.v(x).view(B, N, 1, self.dk).transpose(1, 2) # (B, 1, N, dk)
281
+
282
+ # K, V broadcast across heads
283
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
284
+ att = att + alibi_bias(self.h, N)
285
+ if mask is not None:
286
+ att = att + mask.unsqueeze(0).unsqueeze(0)
287
+
288
+ z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
289
+ return self.proj(z)
290
+
291
+
292
+ # ═══════════════════════════════════════════════════════════════
293
+ # 6. GROUPED QUERY ATTENTION (GQA) - Between MHA and MQA
294
+ # ═══════════════════════════════════════════════════════════════
295
+ class GroupedQueryAttention(nn.Module):
296
+ """
297
+ GQA: Groups of query heads share K/V heads.
298
+ Llama 2 uses this. Balance between quality and inference speed.
299
+ """
300
+ def __init__(self, d: int, h: int, num_kv_heads: int = 2):
301
+ super().__init__()
302
+ self.h = h
303
+ self.num_kv_heads = num_kv_heads
304
+ self.dk = d // h
305
+ self.heads_per_group = h // num_kv_heads
306
+
307
+ self.q = nn.Linear(d, d, bias=False)
308
+ self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False)
309
+ self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False)
310
+ self.proj = nn.Linear(d, d, bias=False)
311
+
312
+ def forward(self, x, mask=None):
313
+ B, N, _ = x.shape
314
+
315
+ q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
316
+ k = self.k(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
317
+ v = self.v(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
318
+
319
+ # Repeat K, V for each group
320
+ k = k.repeat_interleave(self.heads_per_group, dim=1)
321
+ v = v.repeat_interleave(self.heads_per_group, dim=1)
322
+
323
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
324
+ att = att + alibi_bias(self.h, N)
325
+ if mask is not None:
326
+ att = att + mask.unsqueeze(0).unsqueeze(0)
327
+
328
+ z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
329
+ return self.proj(z)
330
+
331
+
332
+ # ═══════════════════════════════════════════════════════════════
333
+ # 7. RETENTION - RetNet style
334
+ # ═══════════════════════════════════════════════════════════════
335
+ class RetentionAttention(nn.Module):
336
+ """
337
+ From RetNet: Retentive Network
338
+
339
+ Parallel mode (training): Like linear attention
340
+ Recurrent mode (inference): O(1) per step
341
+
342
+ Key: exponential decay instead of softmax
343
+ """
344
+ def __init__(self, d: int, h: int, gamma: float = 0.9):
345
+ super().__init__()
346
+ self.h, self.dk = h, d // h
347
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
348
+ self.proj = nn.Linear(d, d, bias=False)
349
+
350
+ # Per-head decay rates
351
+ self.gamma = nn.Parameter(torch.ones(h) * gamma)
352
+
353
+ def forward(self, x, mask=None):
354
+ B, N, _ = x.shape
355
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
356
+ q, k, v = qkv[0], qkv[1], qkv[2]
357
+
358
+ # Build decay matrix D[i,j] = gamma^(i-j) for i >= j
359
+ gamma = torch.sigmoid(self.gamma).view(1, self.h, 1, 1)
360
+ positions = torch.arange(N, device=x.device).float()
361
+ decay = gamma ** (positions.unsqueeze(0) - positions.unsqueeze(1)).clamp(min=0)
362
+
363
+ # Apply causal mask via decay (future positions get 0)
364
+ causal = torch.tril(torch.ones(N, N, device=x.device))
365
+ decay = decay * causal.unsqueeze(0).unsqueeze(0)
366
+
367
+ # If SAT/NAR mask provided, incorporate it
368
+ if mask is not None:
369
+ mask_binary = (mask == 0).float().unsqueeze(0).unsqueeze(0)
370
+ decay = decay * mask_binary
371
+
372
+ # Retention = (Q @ K^T) * D @ V
373
+ att = (q @ k.transpose(-1, -2)) * decay
374
+
375
+ # Normalize per row
376
+ att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)
377
+
378
+ z = (att @ v).transpose(1, 2).reshape(B, N, -1)
379
+ return self.proj(z)
380
+
381
+
382
+ # ═══════════════════════════════════════════════════════════════
383
+ # 8. GATED LINEAR ATTENTION
384
+ # ═══════════════════════════════════════════════════════════════
385
+ class GatedLinearAttention(nn.Module):
386
+ """
387
+ Linear attention with gating for better gradient flow.
388
+ From "Gated Linear Attention Transformers" (2024)
389
+ """
390
+ def __init__(self, d: int, h: int):
391
+ super().__init__()
392
+ self.h, self.dk = h, d // h
393
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
394
+ self.gate = nn.Linear(d, d)
395
+ self.proj = nn.Linear(d, d, bias=False)
396
+ self.eps = 1e-6
397
+
398
+ def forward(self, x, mask=None):
399
+ B, N, _ = x.shape
400
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
401
+ q, k, v = qkv[0], qkv[1], qkv[2]
402
+
403
+ # Feature map (ELU + 1 for positivity)
404
+ q = F.elu(q) + 1
405
+ k = F.elu(k) + 1
406
+
407
+ if mask is None:
408
+ # Bidirectional
409
+ kv = torch.einsum('bhnd,bhnv->bhdv', k, v)
410
+ out = torch.einsum('bhnd,bhdv->bhnv', q, kv)
411
+ normalizer = torch.einsum('bhnd,bhd->bhn', q, k.sum(dim=2)).unsqueeze(-1).clamp(min=self.eps)
412
+ else:
413
+ # Causal
414
+ kv_cumsum = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
415
+ k_cumsum = torch.cumsum(k, dim=2)
416
+ out = torch.einsum('bhnd,bhndv->bhnv', q, kv_cumsum)
417
+ normalizer = torch.einsum('bhnd,bhnd->bhn', q, k_cumsum).unsqueeze(-1).clamp(min=self.eps)
418
+
419
+ out = out / normalizer
420
+ out = out.transpose(1, 2).reshape(B, N, -1)
421
+
422
+ # Gating
423
+ gate = torch.sigmoid(self.gate(x))
424
+ out = out * gate
425
+
426
+ return self.proj(out)
427
+
428
+
429
+ # ═══════════════════════════════════════════════════════════════
430
+ # 9. RELU ATTENTION - Simpler activation
431
+ # ═══════════════════════════════════════════════════════════════
432
+ class ReLUAttention(nn.Module):
433
+ """
434
+ Replace softmax with ReLU + normalization.
435
+ Simpler, faster, sometimes works as well.
436
+ From "ReLU Attention" papers.
437
+ """
438
+ def __init__(self, d: int, h: int):
439
+ super().__init__()
440
+ self.h, self.dk = h, d // h
441
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
442
+ self.proj = nn.Linear(d, d, bias=False)
443
+
444
+ def forward(self, x, mask=None):
445
+ B, N, _ = x.shape
446
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
447
+ q, k, v = qkv[0], qkv[1], qkv[2]
448
+
449
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
450
+ att = att + alibi_bias(self.h, N)
451
+
452
+ if mask is not None:
453
+ att = att + mask.unsqueeze(0).unsqueeze(0)
454
+
455
+ # ReLU instead of softmax
456
+ att = F.relu(att)
457
+ att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)
458
+
459
+ z = (att @ v).transpose(1, 2).reshape(B, N, -1)
460
+ return self.proj(z)
461
+
462
+
463
+ # ═══════════════════════════════════════════════════════════════
464
+ # 10. SIGMOID ATTENTION - Bounded
465
+ # ══���════════════════════════════════════════════════════════════
466
+ class SigmoidAttention(nn.Module):
467
+ """
468
+ Sigmoid attention: each position independently decides attention weight.
469
+ Not normalized to sum to 1 - allows variable "total attention".
470
+ """
471
+ def __init__(self, d: int, h: int):
472
+ super().__init__()
473
+ self.h, self.dk = h, d // h
474
+ self.qkv = nn.Linear(d, 3 * d, bias=False)
475
+ self.proj = nn.Linear(d, d, bias=False)
476
+ self.bias = nn.Parameter(torch.zeros(h, 1, 1))
477
+
478
+ def forward(self, x, mask=None):
479
+ B, N, _ = x.shape
480
+ qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
481
+ q, k, v = qkv[0], qkv[1], qkv[2]
482
+
483
+ att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) + self.bias
484
+
485
+ if mask is not None:
486
+ att = att + mask.unsqueeze(0).unsqueeze(0)
487
+
488
+ # Sigmoid instead of softmax - each weight independent
489
+ att = torch.sigmoid(att)
490
+
491
+ # Optional: mask out future for AR
492
+ if mask is not None:
493
+ att = att * (mask == 0).float().unsqueeze(0).unsqueeze(0)
494
+
495
+ z = (att @ v).transpose(1, 2).reshape(B, N, -1)
496
+ return self.proj(z)
497
+
498
+
499
+ # ═══════════════════════════════════════════════════════════════
500
+ # Block and Model
501
+ # ═══════════════════════════════════════════════════════════════
502
+ ATTN_REGISTRY = {
503
+ "standard": StandardAttention,
504
+ "linear": LinearAttention,
505
+ "cosine": CosineAttention,
506
+ "differential": DifferentialAttention,
507
+ "mqa": MultiQueryAttention,
508
+ "gqa": GroupedQueryAttention,
509
+ "retention": RetentionAttention,
510
+ "gated_linear": GatedLinearAttention,
511
+ "relu": ReLUAttention,
512
+ "sigmoid": SigmoidAttention,
513
+ }
514
+
515
+
516
+ class Block(nn.Module):
517
+ def __init__(self, d: int, h: int, attn_type: str = "standard"):
518
+ super().__init__()
519
+ self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
520
+ self.attn = ATTN_REGISTRY[attn_type](d, h)
521
+ self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
522
+
523
+ def forward(self, x, mask=None):
524
+ x = x + self.attn(self.ln1(x), mask)
525
+ return x + self.ff(self.ln2(x))
526
+
527
+
528
+ class FlexModel(nn.Module):
529
+ def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard"):
530
+ super().__init__()
531
+ self.emb = nn.Embedding(VOCAB, d)
532
+ self.blocks = nn.ModuleList([Block(d, h, attn_type) for _ in range(layers)])
533
+ self.ln = nn.LayerNorm(d)
534
+ self.head = nn.Linear(d, VOCAB, bias=False)
535
+ self.head.weight = self.emb.weight
536
+
537
+ def forward(self, x, mask=None):
538
+ x = self.emb(x)
539
+ for b in self.blocks:
540
+ x = b(x, mask)
541
+ return self.head(self.ln(x))
542
+
543
+ def count_params(self):
544
+ return sum(p.numel() for p in self.parameters())
545
+
546
+
547
+ # ═══════════════════════════════════════════════════════════════
548
+ # Training with AR/SAT/NAR modes
549
+ # ═══════════════════════════════════════════════════════════════
550
+ def train(attn_type: str, mode: str, d: int, layers: int, h: int,
551
+ batch: int, seq: int, steps: int, block_size: int = 4):
552
+
553
+ print(f"\n{'='*60}")
554
+ print(f"ATTENTION: {attn_type.upper()} | MODE: {mode.upper()}")
555
+ print(f"{'='*60}")
556
+
557
+ model = FlexModel(d, layers, h, attn_type).to(DEV)
558
+ print(f"Parameters: {model.count_params():,}")
559
+
560
+ opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
561
+
562
+ losses, times = [], []
563
+
564
+ for step in range(steps):
565
+ ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)
566
+
567
+ if mode == "ar":
568
+ # Standard AR: predict next token
569
+ target = ids[:, 1:]
570
+ input_ids = ids[:, :-1]
571
+ mask = get_mask(seq - 1, "ar")
572
+ elif mode == "sat":
573
+ # SAT: predict within blocks
574
+ target = ids[:, 1:]
575
+ input_ids = ids[:, :-1]
576
+ mask = get_mask(seq - 1, "sat", block_size)
577
+ else: # nar
578
+ # NAR: predict all from [MASK] or noisy input
579
+ target = ids
580
+ # Add noise to input for NAR (simple version)
581
+ noise_mask = torch.rand(batch, seq, device=DEV) < 0.15
582
+ input_ids = ids.clone()
583
+ input_ids[noise_mask] = torch.randint(0, VOCAB, (noise_mask.sum().item(),), device=DEV)
584
+ mask = get_mask(seq, "nar")
585
+
586
+ start = time.time()
587
+ opt.zero_grad()
588
+
589
+ try:
590
+ logits = model(input_ids, mask)
591
+ loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
592
+ loss.backward()
593
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
594
+ opt.step()
595
+ except Exception as e:
596
+ print(f"Step {step} failed: {e}")
597
+ return None
598
+
599
+ elapsed = time.time() - start
600
+ losses.append(loss.item())
601
+ times.append(elapsed)
602
+
603
+ if step % 20 == 0 or step == steps - 1:
604
+ tok_s = batch * seq / elapsed
605
+ print(f"Step {step:3d} | Loss {loss.item():.4f} | {tok_s:.0f} tok/s")
606
+
607
+ avg_loss = sum(losses[-20:]) / min(20, len(losses))
608
+ avg_toks = batch * seq / (sum(times[-20:]) / min(20, len(times)))
609
+
610
+ return {"attn": attn_type, "mode": mode, "loss": avg_loss, "tok_s": avg_toks}
611
+
612
+
613
+ def main():
614
+ parser = argparse.ArgumentParser()
615
+ parser.add_argument("--d", type=int, default=256)
616
+ parser.add_argument("--layers", type=int, default=4)
617
+ parser.add_argument("--heads", type=int, default=8)
618
+ parser.add_argument("--batch", type=int, default=16)
619
+ parser.add_argument("--seq", type=int, default=128)
620
+ parser.add_argument("--steps", type=int, default=100)
621
+ parser.add_argument("--mode", type=str, default="ar", choices=["ar", "sat", "nar", "all"])
622
+ parser.add_argument("--types", type=str, default="all")
623
+ args = parser.parse_args()
624
+
625
+ print(f"Device: {DEV}")
626
+ if torch.cuda.is_available():
627
+ print(f"GPU: {torch.cuda.get_device_name()}")
628
+
629
+ if args.types == "all":
630
+ types = list(ATTN_REGISTRY.keys())
631
+ else:
632
+ types = [t.strip() for t in args.types.split(",")]
633
+
634
+ modes = ["ar", "sat", "nar"] if args.mode == "all" else [args.mode]
635
+
636
+ results = []
637
+ for mode in modes:
638
+ for attn_type in types:
639
+ r = train(attn_type, mode, args.d, args.layers, args.heads,
640
+ args.batch, args.seq, args.steps)
641
+ if r:
642
+ results.append(r)
643
+ torch.cuda.empty_cache()
644
+
645
+ # Summary
646
+ print(f"\n{'='*60}")
647
+ print("SUMMARY")
648
+ print(f"{'='*60}")
649
+
650
+ for mode in modes:
651
+ print(f"\n--- MODE: {mode.upper()} ---")
652
+ mode_results = [r for r in results if r['mode'] == mode]
653
+ baseline = next((r for r in mode_results if r['attn'] == 'standard'), None)
654
+
655
+ for r in sorted(mode_results, key=lambda x: x['loss']):
656
+ rel = ""
657
+ if baseline and r['attn'] != 'standard':
658
+ loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
659
+ speed_ratio = r['tok_s'] / baseline['tok_s']
660
+ rel = f" | vs std: {loss_diff:+.1f}%, {speed_ratio:.2f}x"
661
+ print(f"{r['attn']:15s} | Loss {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s{rel}")
662
+
663
+
664
+ if __name__ == "__main__":
665
+ main()