AbstractPhil committed on
Commit
7cf3b75
·
verified ·
1 Parent(s): 1d7b73e

Create constellation_vs_rope.py

Browse files
Files changed (1) hide show
  1. constellation_vs_rope.py +477 -0
constellation_vs_rope.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RoPE Attention vs Constellation Relay
4
+ ========================================
5
+ Two RoPE variants:
6
+ 1. Standard RoPE (Su et al.) — fixed base frequency 10000
7
+ 2. NTK-aware RoPE — scaled base frequency for longer contexts
8
+
9
+ Same battery of tests: single pass, depth stability, interleaved.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ import math
17
+ import time
18
+
19
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed seed so the random point cloud and layer initialisations are repeatable.
torch.manual_seed(42)
# Allow TF32 in CUDA matmuls and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# True when this PyTorch build exposes the float8_e4m3fn dtype.
HAS_FP8 = hasattr(torch, 'float8_e4m3fn')
25
+
26
+
27
def compute_cv(points, n_samples=2000, n_points=5):
    """Return the coefficient of variation of random simplex volumes.

    The rows of ``points`` are L2-normalised, then ``n_samples`` random
    subsets of ``n_points`` rows are drawn. The volume of the simplex spanned
    by each subset is obtained from the Cayley-Menger determinant of its
    squared pairwise distances.

    Args:
        points: 2-D tensor with one point per row.
        n_samples: number of random simplices to draw.
        n_points: vertices per simplex (the default 5 gives a 4-simplex).

    Returns:
        ``std / mean`` of the non-degenerate volumes as a Python float, or
        ``nan`` when there are fewer than ``n_points`` rows or fewer than 50
        usable samples.

    Raises:
        ValueError: if ``n_points`` is smaller than 2.
    """
    if n_points < 2:
        raise ValueError("n_points must be at least 2")
    N = points.shape[0]
    if N < n_points:
        return float('nan')
    points = F.normalize(points.to(DEVICE).float(), dim=-1)

    # Cayley-Menger: for a k-simplex,
    #   V^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM).
    # For k = 4 (n_points = 5) this is -det(CM) / 9216.
    k = n_points - 1
    sign = -1.0 if k % 2 == 0 else 1.0
    denom = (2 ** k) * math.factorial(k) ** 2
    # Only the first 10000 rows are ever candidates for sampling.
    pool = min(N, 10000)

    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(pool, device=DEVICE)[:n_points]
        pts = points[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        # Round-off can push squared distances slightly below zero.
        d2 = F.relu(d2)
        cm = torch.zeros(1, n_points + 1, n_points + 1,
                         device=DEVICE, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        v2 = sign * torch.linalg.det(cm) / denom
        # Skip degenerate (or numerically negative) simplices.
        if v2[0].item() > 1e-20:
            vols.append(v2[0].sqrt().cpu())
    if len(vols) < 50:
        return float('nan')
    vt = torch.stack(vols)
    return (vt.std() / (vt.mean() + 1e-8)).item()
47
+
48
+
49
def eff_dim(x):
    """Return a participation-ratio effective dimension of ``x``.

    Rows are centred with the mean over *all* rows; the first 512 centred rows
    are then decomposed by SVD. The singular values ``S`` are turned into
    weights ``p = S / S.sum()`` and the result is ``1 / sum(p ** 2)``.

    Args:
        x: 2-D tensor with one sample per row.

    Returns:
        The effective dimension as a Python float.
    """
    centred = x - x.mean(0, keepdim=True)
    # Only the singular values are needed; cap the rows to keep the SVD cheap.
    _, S, _ = torch.linalg.svd(centred[:512].float(), full_matrices=False)
    p = S / S.sum()
    return p.pow(2).sum().reciprocal().item()
54
+
55
+
56
def uniform_sphere(n, d):
    """Sample ``n`` points from the unit sphere in ``d`` dimensions.

    Standard-normal vectors are rescaled to unit L2 norm, which yields a
    uniform distribution over the sphere.

    Returns:
        An ``(n, d)`` tensor whose rows have unit norm.
    """
    return F.normalize(torch.randn(n, d), dim=-1)
58
+
59
+
60
+ # ══════════════════════════════════════════════════════════════════
61
+ # RoPE IMPLEMENTATIONS
62
+ # ══════════════════════════════════════════════════════════════════
63
+
64
class RotaryEmbedding(nn.Module):
    """Standard RoPE — fixed sinusoidal rotation frequencies."""

    def __init__(self, dim, base=10000.0):
        """Precompute the inverse frequencies ``base ** (-2i / dim)``.

        Args:
            dim: width of the vectors that will be rotated.
            base: frequency base.
        """
        super().__init__()
        self.dim = dim
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        # Buffer, not parameter: moves with .to() but is never trained.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len, device):
        """Return ``(cos, sin)`` tables of shape ``(seq_len, dim)``.

        Row ``s`` holds the angles ``s * inv_freq`` repeated twice along the
        last axis, so both halves of the table are identical.
        """
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)  # (S, dim/2)
        emb = torch.cat([freqs, freqs], dim=-1)            # (S, dim)
        return emb.cos(), emb.sin()
77
+
78
+
79
class NTKRotaryEmbedding(nn.Module):
    """NTK-aware RoPE — scaled base for extended context."""

    def __init__(self, dim, base=10000.0, scale_factor=4.0):
        """Precompute inverse frequencies from the NTK-scaled base.

        Args:
            dim: width of the vectors that will be rotated.
            base: unscaled frequency base.
            scale_factor: context-extension factor; 1.0 reproduces the
                unscaled frequencies.

        Raises:
            ValueError: if ``dim`` is 2, where the scaling exponent
                ``dim / (dim - 2)`` is undefined.
        """
        super().__init__()
        if dim == 2:
            raise ValueError("NTK scaling is undefined for dim == 2")
        self.dim = dim
        # NTK scaling: base' = base * scale_factor ** (dim / (dim - 2))
        scaled_base = base * (scale_factor ** (dim / (dim - 2)))
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (scaled_base ** exponents)
        # Buffer, not parameter: moves with .to() but is never trained.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len, device):
        """Return ``(cos, sin)`` tables of shape ``(seq_len, dim)``.

        Row ``s`` holds the angles ``s * inv_freq`` repeated twice along the
        last axis.
        """
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()
94
+
95
+
96
def apply_rotary(x, cos, sin):
    """Apply rotary embeddings to Q or K: (B, H, S, d).

    The last axis is split into two halves ``(x1, x2)`` and feature ``i`` of
    ``x1`` is rotated together with feature ``i`` of ``x2``. Only the first
    ``S`` rows and first ``d // 2`` columns of ``cos`` / ``sin`` are used.

    Args:
        x: tensor whose last two axes are (sequence, feature).
        cos: cosine table with at least ``S`` rows and ``d // 2`` columns.
        sin: sine table of the same layout as ``cos``.

    Returns:
        The rotated tensor, same shape as ``x``.
    """
    half = x.shape[-1] // 2
    seq_len = x.shape[-2]
    x1 = x[..., :half]
    x2 = x[..., half:]
    cos = cos[:seq_len, :half].unsqueeze(0).unsqueeze(0)  # (1, 1, S, d/2)
    sin = sin[:seq_len, :half].unsqueeze(0).unsqueeze(0)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
104
+
105
+
106
+ # ══════════════════════════════════════════════════════════════════
107
+ # ATTENTION BLOCKS
108
+ # ══════════════════════════════════════════════════════════════════
109
+
110
class VanillaAttnBlock(nn.Module):
    """Standard self-attention — no position encoding.

    Pre-norm, multi-head softmax attention with a residual connection and no
    biases on the projections.
    """

    def __init__(self, dim, n_heads=4):
        """
        Args:
            dim: model width.
            n_heads: number of attention heads; must divide ``dim``.

        Raises:
            ValueError: if ``dim`` is not divisible by ``n_heads``.
        """
        super().__init__()
        # Fail here rather than with an opaque reshape error in forward().
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Attend over axis 1 of ``x`` (B, S, D) and return ``x + attn``."""
        B, S, D = x.shape
        x_n = self.norm(x)
        qkv = self.qkv(x_n).reshape(B, S, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, S, hd)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, D)
        return x + self.out_proj(out)
130
+
131
+
132
class RoPEAttnBlock(nn.Module):
    """Self-attention with Rotary Position Embeddings.

    Same pre-norm residual attention as the vanilla block, except that Q and K
    are rotated per head with position-dependent angles before the dot
    product. Positions are the indices along axis 1 of the input.
    """

    def __init__(self, dim, n_heads=4, rope_type='standard', rope_base=10000.0,
                 ntk_scale=4.0):
        """
        Args:
            dim: model width.
            n_heads: number of attention heads; must divide ``dim``.
            rope_type: ``'standard'`` or ``'ntk'``.
            rope_base: frequency base handed to the rotary embedding.
            ntk_scale: scale factor, used only when ``rope_type == 'ntk'``.

        Raises:
            ValueError: if ``dim`` is not divisible by ``n_heads`` or
                ``rope_type`` is not one of the supported names.
        """
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        # An unknown type used to leave self.rope undefined and only failed
        # later, inside forward(), with an AttributeError.
        if rope_type not in ('standard', 'ntk'):
            raise ValueError(
                f"unknown rope_type {rope_type!r}; expected 'standard' or 'ntk'")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.norm = nn.LayerNorm(dim)

        # The rotation acts per head, so it is sized by head_dim, not dim.
        if rope_type == 'standard':
            self.rope = RotaryEmbedding(self.head_dim, base=rope_base)
        else:
            self.rope = NTKRotaryEmbedding(self.head_dim, base=rope_base,
                                           scale_factor=ntk_scale)
        self.rope_type = rope_type

    def forward(self, x):
        """Attend over axis 1 of ``x`` (B, S, D) and return ``x + attn``."""
        B, S, D = x.shape
        x_n = self.norm(x)
        qkv = self.qkv(x_n).reshape(B, S, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, S, hd)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply RoPE to Q and K (V is left unrotated).
        cos, sin = self.rope(S, x.device)
        q = apply_rotary(q, cos, sin)
        k = apply_rotary(k, cos, sin)

        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, D)
        return x + self.out_proj(out)
166
+
167
+
168
+ # ══════════════════════════════════════════════════════════════════
169
+ # CONSTELLATION RELAY (copy from v2)
170
+ # ══════════════════════════════════════════════════════════════════
171
+
172
class ConstellationRelay(nn.Module):
    """Residual layer that describes each patch by its distances to anchors.

    The input vector is layer-normalised and cut into ``n_patches`` patches of
    width ``patch_dim``. Every patch owns ``n_anchors`` unit-norm anchors:
    ``home`` is a frozen copy of their initial positions and ``anchors`` is
    the trainable version. For ``n_phases`` points on the great-circle path
    from home to the current anchors, the cosine distance ``1 - cos`` between
    the patch and each interpolated anchor is computed. A per-patch two-layer
    MLP (the "patchwork") maps those distances back to ``patch_dim``, and a
    learned sigmoid gate blends that with the normalised patch before the
    residual add.
    """

    def __init__(self, input_dim, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0):
        """
        Args:
            input_dim: width of the input vectors; a multiple of ``patch_dim``.
            patch_dim: width of each patch.
            n_anchors: anchors per patch.
            n_phases: number of interpolation points between home and anchors.
            pw_hidden: hidden width of the per-patch MLP.
            gate_init: initial pre-sigmoid gate value for every patch.
        """
        super().__init__()
        assert input_dim % patch_dim == 0, \
            "input_dim must be a multiple of patch_dim"
        self.input_dim = input_dim
        self.patch_dim = patch_dim
        self.n_patches = input_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        # Initialise in place through a flat view, then project to the sphere.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Per-patch MLP: (n_phases * A) distances -> pw_hidden -> d.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(input_dim)

    @staticmethod
    def _slerp(h, c, omega, t):
        """Spherically interpolate from ``h`` (t=0) to ``c`` (t=1)."""
        # Clamp so coincident points do not divide by zero.
        sin_omega = omega.sin().clamp(min=1e-7)
        return (torch.sin((1 - t) * omega) / sin_omega * h +
                torch.sin(t * omega) / sin_omega * c)

    def drift(self):
        """Return the angle in radians between each anchor and its home, (P, A)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        # Keep the cosine strictly inside (-1, 1), away from acos' endpoints.
        cos = (h * c).sum(dim=-1).clamp(-1 + 1e-7, 1 - 1e-7)
        return torch.acos(cos)

    def at_phase(self, t):
        """Return the anchors interpolated a fraction ``t`` from home to current."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        return self._slerp(h, c, omega, t)

    def forward(self, x):
        """Map ``x`` of shape (B, input_dim) to a tensor of the same shape."""
        B, D = x.shape
        P, d = self.n_patches, self.patch_dim
        x_n = self.norm(x)
        patches = x_n.reshape(B, P, d)
        patches_n = F.normalize(patches, dim=-1)

        # These do not depend on the phase, so compute them once.
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)

        # Multi-phase triangulation: cosine distance to every anchor at each phase.
        tris = []
        for t in torch.linspace(0, 1, self.n_phases).tolist():
            anchors_t = F.normalize(self._slerp(h, c, omega, t), dim=-1)
            cos = torch.einsum('bpd,pad->bpa', patches_n, anchors_t)
            tris.append(1.0 - cos)
        tri = torch.cat(tris, dim=-1)

        # Patchwork: an independent two-layer MLP per patch.
        hidden = torch.einsum('bpt,pth->bph', tri, self.pw_w1) + self.pw_b1
        hidden = F.gelu(hidden)
        pw_out = torch.einsum('bph,phd->bpd', hidden, self.pw_w2) + self.pw_b2
        pw_out = self.pw_norm(pw_out)

        # Per-patch gate in (0, 1): 0 keeps the normalised patch, 1 the MLP output.
        gate = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        blended = gate * pw_out + (1 - gate) * patches
        return x + blended.reshape(B, D)
242
+
243
+
244
+ # ══════════════════════════════════════════════════════════════════
245
+ # TEST SUITE
246
+ # ══════════════════════════════════════════════════════════════════
247
+
248
N = 2000      # number of points (also used as the sequence length)
D = 128       # width of every point
N_CV = 2000   # simplices sampled per compute_cv call

print("=" * 90)
print("RoPE ATTENTION vs CONSTELLATION RELAY")
print(f" Input dim: {D}, Sequence length: {N}")
print(f" Device: {DEVICE}")
print("=" * 90)

# Reference cloud: points on the unit sphere, measured before any layer.
pts = uniform_sphere(N, D).to(DEVICE)
cv_base = compute_cv(pts, N_CV)
ed_base = eff_dim(pts)
print(f" Baseline: CV={cv_base:.4f} eff_dim={ed_base:.1f}")

# Build all architectures. Each entry is a factory so that every test gets
# freshly initialised modules.
configs = {
    'vanilla': lambda: VanillaAttnBlock(D, 8).to(DEVICE),
    'rope_std': lambda: RoPEAttnBlock(D, 8, 'standard', 10000).to(DEVICE),
    'rope_ntk': lambda: RoPEAttnBlock(D, 8, 'ntk', 10000, 4.0).to(DEVICE),
    'relay': lambda: ConstellationRelay(D, 16, 16, 3, 32).to(DEVICE),
}
270
+
271
+
272
# ── TEST 1: Single pass comparison ──
print(f"\n{'━'*90}")
print("TEST 1: Single pass — all architectures")
print(f"{'━'*90}")
print(f" {'arch':>12} {'params':>8} {'CV_out':>8} {'CV_norm':>8} "
      f"{'cos_orig':>10} {'eff_dim':>8}")

for name, builder in configs.items():
    module = builder()
    np_ = sum(p.numel() for p in module.parameters())
    with torch.no_grad():
        # The relay takes (N, D); the attention blocks take (1, N, D).
        if name == 'relay':
            out = module(pts)
        else:
            out = module(pts.unsqueeze(0)).squeeze(0)
    # NOTE: compute_cv L2-normalises its input itself, so CV_out and CV_norm
    # measure the same quantity and differ only by sampling noise.
    cv = compute_cv(out, N_CV)
    cv_n = compute_cv(F.normalize(out, dim=-1), N_CV)
    # Mean cosine between each input point and its output.
    cos = (F.normalize(pts, dim=-1) * F.normalize(out, dim=-1)).sum(-1).mean().item()
    ed = eff_dim(out)
    print(f" {name:>12} {np_:>8,} {cv:>8.4f} {cv_n:>8.4f} {cos:>10.6f} {ed:>8.1f}")
292
+
293
+
294
# ── TEST 2: Depth sweep — 16 layers each ──
print(f"\n{'━'*90}")
print("TEST 2: Depth sweep — 16 layers, all architectures")
print(f"{'━'*90}")

# Depths at which metrics are reported.
checkpoints = [1, 2, 4, 8, 12, 16]

for name, builder in configs.items():
    print(f"\n {name}:")
    print(f" {'depth':>6} {'CV':>8} {'CV_n':>8} {'eff_d':>8} {'cos_orig':>10}")

    # 16 independently initialised, untrained layers of the same type.
    stack = nn.ModuleList([builder() for _ in range(16)])
    x = pts.clone()
    for depth, layer in enumerate(stack, start=1):
        with torch.no_grad():
            if name == 'relay':
                x = layer(x)
            else:
                x = layer(x.unsqueeze(0)).squeeze(0)
        if depth in checkpoints:
            cv = compute_cv(x, N_CV)
            cv_n = compute_cv(F.normalize(x, dim=-1), N_CV)
            ed = eff_dim(x)
            cos = (F.normalize(pts, dim=-1) * F.normalize(x, dim=-1)).sum(-1).mean().item()
            print(f" {depth:>6} {cv:>8.4f} {cv_n:>8.4f} {ed:>8.1f} {cos:>10.6f}")
319
+
320
+
321
# ── TEST 3: Interleaved — each attention type + relay ──
print(f"\n{'━'*90}")
print("TEST 3: Interleaved — [attn type] → relay → [attn type] → relay → ...")
print(f"{'━'*90}")


def _report_interleaved(step, kind, x):
    """Print one metrics row for the state ``x`` after ``step`` layers."""
    cv_n = compute_cv(F.normalize(x, dim=-1), N_CV)
    ed = eff_dim(x)
    cos = (F.normalize(pts, dim=-1) * F.normalize(x, dim=-1)).sum(-1).mean().item()
    print(f" {step:>6} {kind:>8} {cv_n:>8.4f} {ed:>8.1f} {cos:>10.6f}")


for attn_name in ['vanilla', 'rope_std', 'rope_ntk']:
    print(f"\n {attn_name} + relay interleaved:")
    print(f" {'step':>6} {'type':>8} {'CV_n':>8} {'eff_d':>8} {'cos_orig':>10}")

    attn_builder = configs[attn_name]
    attn_layers = nn.ModuleList([attn_builder() for _ in range(8)])
    relay_layers = nn.ModuleList([
        ConstellationRelay(D, 16, 16, 3, 32).to(DEVICE) for _ in range(8)])

    x = pts.clone()
    step = 0
    # 8 (attention, relay) pairs = 16 steps, reported at the depths listed in
    # `checkpoints`.
    for attn_layer, relay_layer in zip(attn_layers, relay_layers):
        # Attention step
        with torch.no_grad():
            x = attn_layer(x.unsqueeze(0)).squeeze(0)
        step += 1
        if step in checkpoints:
            _report_interleaved(step, 'attn', x)

        # Relay step
        with torch.no_grad():
            x = relay_layer(x)
        step += 1
        if step in checkpoints:
            _report_interleaved(step, 'relay', x)
357
+
358
+
359
# ── TEST 4: Throughput comparison ──
print(f"\n{'━'*90}")
print("TEST 4: Throughput")
print(f"{'━'*90}")

print(f" {'arch':>12} {'ms':>8} {'params':>10}")


def _sync_device():
    """Wait for queued GPU work; a no-op when running on the CPU."""
    # torch.cuda.synchronize() raises when CUDA is unavailable, which used to
    # crash this test on CPU-only machines.
    if DEVICE == "cuda":
        torch.cuda.synchronize()


for name, builder in configs.items():
    module = builder()
    np_ = sum(p.numel() for p in module.parameters())
    # The relay takes (N, D); the attention blocks take (1, N, D).
    inp = pts if name == 'relay' else pts.unsqueeze(0)

    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = module(inp)
        _sync_device()

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        for _ in range(200):
            _ = module(inp)
        _sync_device()
    ms = (time.perf_counter() - t0) / 200 * 1000
    print(f" {name:>12} {ms:>8.2f} {np_:>10,}")
389
+
390
+
391
# ── TEST 5: Clustered input — all architectures ──
print(f"\n{'━'*90}")
print("TEST 5: Clustered input (10 clusters, d=128)")
print(f"{'━'*90}")

centroids = F.normalize(torch.randn(10, D), dim=-1).to(DEVICE)
# Drawn from the CPU generator as before (same random stream), then moved
# once so the gather below does not transfer the index on every use.
assignments = torch.randint(0, 10, (N,)).to(DEVICE)
# The centroid assigned to each point does not depend on the spread.
centers = centroids[assignments]

print(f" {'spread':>8} {'CV_base':>8} {'vanilla':>8} {'rope_std':>8} "
      f"{'rope_ntk':>8} {'relay':>8}")

for spread in [0.1, 0.3, 0.5, 1.0]:
    # Gaussian noise of scale `spread` around each centroid, back on the sphere.
    pts_c = F.normalize(centers +
                        torch.randn(N, D, device=DEVICE) * spread, dim=-1)
    cv_b = compute_cv(pts_c, N_CV)

    row = f" {spread:>8.1f} {cv_b:>8.4f}"
    for name, builder in configs.items():
        module = builder()
        with torch.no_grad():
            if name == 'relay':
                out = module(pts_c)
            else:
                out = module(pts_c.unsqueeze(0)).squeeze(0)
        cv = compute_cv(F.normalize(out, dim=-1), N_CV)
        row += f" {cv:>8.4f}"
    print(row)
418
+
419
+
420
# ── TEST 6: RoPE frequency analysis ──
print(f"\n{'━'*90}")
print("TEST 6: RoPE base frequency sweep")
print(" Does the rotation frequency affect geometric preservation?")
print(f"{'━'*90}")

print(f" {'base':>10} {'CV_n':>8} {'cos_orig':>10} {'eff_dim':>8}")

# One freshly initialised standard-RoPE block per base; a single pass each.
for base in [100, 500, 1000, 5000, 10000, 50000, 100000, 500000]:
    module = RoPEAttnBlock(D, 8, 'standard', base).to(DEVICE)
    with torch.no_grad():
        out = module(pts.unsqueeze(0)).squeeze(0)
    cv_n = compute_cv(F.normalize(out, dim=-1), N_CV)
    cos = (F.normalize(pts, dim=-1) * F.normalize(out, dim=-1)).sum(-1).mean().item()
    ed = eff_dim(out)
    print(f" {base:>10} {cv_n:>8.4f} {cos:>10.6f} {ed:>8.1f}")
436
+
437
+
438
# ── TEST 7: NTK scale factor sweep ──
print(f"\n{'━'*90}")
print("TEST 7: NTK scale factor sweep (base=10000)")
print(f"{'━'*90}")

print(f" {'scale':>8} {'CV_n':>8} {'cos_orig':>10} {'eff_dim':>8}")

# One freshly initialised NTK-RoPE block per scale factor; a single pass each.
for scale in [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0]:
    module = RoPEAttnBlock(D, 8, 'ntk', 10000, scale).to(DEVICE)
    with torch.no_grad():
        out = module(pts.unsqueeze(0)).squeeze(0)
    cv_n = compute_cv(F.normalize(out, dim=-1), N_CV)
    cos = (F.normalize(pts, dim=-1) * F.normalize(out, dim=-1)).sum(-1).mean().item()
    ed = eff_dim(out)
    print(f" {scale:>8.1f} {cv_n:>8.4f} {cos:>10.6f} {ed:>8.1f}")
453
+
454
+
455
+ # ══════════════════════════════════════════════════════════════════
456
+ # SUMMARY
457
+ # ══════════════════════════════════════════════════════════════════
458
+
459
# Closing banner. The summary holds no computed values; it only points the
# reader back to the tables printed by Test 2 and Test 3.
print(f"\n{'='*90}")
print("SUMMARY — cos_to_orig at depth 16")
print(f"{'='*90}")
print("""
Compare the depth-16 cos_to_orig from Test 2 across all architectures:

vanilla attention: (see Test 2)
RoPE standard: (see Test 2)
RoPE NTK: (see Test 2)
constellation relay: (see Test 2)

And the interleaved results from Test 3:
vanilla + relay: (see Test 3)
rope_std + relay: (see Test 3)
rope_ntk + relay: (see Test 3)
""")
print(f"{'='*90}")
print("DONE")
print(f"{'='*90}")