AbstractPhil commited on
Commit
9243884
Β·
verified Β·
1 Parent(s): d4d0a5d

Create constellation_cantor_routing.py

Browse files
Files changed (1) hide show
  1. constellation_cantor_routing.py +823 -0
constellation_cantor_routing.py ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Constellation-Cantor Relay β€” O(S) Cross-Token Routing
4
+
5
+ This is likely one of the most powerful routing mechanisms that can exist in current spectrum
6
+ until more formulas are resolved.
7
+
8
+ =======================================================
9
+ Replaces attention entirely with triangulation-mediated hierarchical routing.
10
+
11
+ Architecture:
12
+ per-token: constellation relay (triangulate β†’ patchwork β†’ gated residual)
13
+ cross-token: Cantor router (hierarchical scatter/gather through anchor tree)
14
+
15
+ The triangulation profile IS the routing key. Tokens near the same anchor
16
+ on S^(d-1) share information at level 0. Anchor pairs share at level 1.
17
+ Quads at level 2. Full global at level log2(A).
18
+
19
+ Total cross-token cost: O(S Γ— n_levels) = O(S Γ— 4) for 16 anchors.
20
+ Total per-token cost: O(S Γ— tri_dim Γ— pw_hidden).
21
+ No attention anywhere. Fully O(S).
22
+
23
+ Benchmarks:
24
+ 1. Throughput: cantor-relay vs hybrid vs pure relay vs attention
25
+ 2. Cross-token causal intervention at scale
26
+ 3. Geometric preservation
27
+ 4. Trained task requiring cross-token routing
28
+ """
29
+
30
+ import os
31
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
32
+
33
+ import torch
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ import numpy as np
37
+ import math
38
+ import time
39
+ import gc
40
+ from collections import OrderedDict
41
+
42
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
43
+ torch.backends.cuda.matmul.allow_tf32 = True
44
+ torch.backends.cudnn.allow_tf32 = True
45
+
46
+
47
+ # ══════════════════════════════════════════════════════════════════
48
+ # ACTIVATIONS
49
+ # ══════════════════════════════════════════════════════════════════
50
+
51
+ class SquaredReLU(nn.Module):
52
+ def forward(self, x): return F.relu(x) ** 2
53
+
54
+
55
+ # ══════════════════════════════════════════════════════════════════
56
+ # CONSTELLATION RELAY β€” per-token geometric layer
57
+ # ══════════════════════════════════════════════════════════════════
58
+
59
+ class ConstellationRelay(nn.Module):
60
+ """Per-token constellation triangulation + patchwork. O(S)."""
61
+
62
+ def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
63
+ super().__init__()
64
+ self.dim = dim
65
+ self.patch_dim = patch_dim
66
+ self.n_patches = dim // patch_dim
67
+ self.n_anchors = n_anchors
68
+ self.n_phases = n_phases
69
+ P, A, d = self.n_patches, n_anchors, patch_dim
70
+
71
+ self.ln = nn.LayerNorm(dim)
72
+
73
+ home = torch.empty(P, A, d)
74
+ nn.init.xavier_normal_(home.view(P * A, d))
75
+ home = F.normalize(home.view(P, A, d), dim=-1)
76
+ self.register_buffer('home', home)
77
+ self.anchors = nn.Parameter(home.clone())
78
+
79
+ tri_dim = P * A * n_phases
80
+ self.tri_dim = tri_dim
81
+ pw_hidden = tri_dim * 2
82
+
83
+ self.patchwork = nn.Sequential(
84
+ nn.Linear(tri_dim, pw_hidden),
85
+ SquaredReLU(),
86
+ nn.LayerNorm(pw_hidden),
87
+ nn.Linear(pw_hidden, dim),
88
+ )
89
+ self.gate = nn.Parameter(torch.tensor(-3.0))
90
+
91
+ def drift(self):
92
+ h = F.normalize(self.home.float(), dim=-1)
93
+ c = F.normalize(self.anchors.float(), dim=-1)
94
+ return torch.acos((h * c).sum(-1).clamp(-1 + 1e-6, 1 - 1e-6))
95
+
96
+ def at_phase(self, t):
97
+ h = F.normalize(self.home.float(), dim=-1)
98
+ c = F.normalize(self.anchors.float(), dim=-1)
99
+ omega = self.drift().unsqueeze(-1)
100
+ so = omega.sin().clamp(min=1e-6)
101
+ return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c
102
+
103
+ def triangulate(self, patches_n):
104
+ phases = torch.linspace(0, 1, self.n_phases, device=patches_n.device).tolist()
105
+ tris = []
106
+ for t in phases:
107
+ at = F.normalize(self.at_phase(t), dim=-1).to(patches_n.dtype)
108
+ tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
109
+ return torch.cat(tris, dim=-1).reshape(patches_n.shape[0], -1)
110
+
111
+ def forward(self, x):
112
+ """x: (B*S, D) or (B, S, D)"""
113
+ is_seq = x.dim() == 3
114
+ if is_seq:
115
+ B, S, D = x.shape
116
+ x_flat = x.reshape(B * S, D)
117
+ else:
118
+ x_flat = x
119
+
120
+ residual = x_flat
121
+ h = self.ln(x_flat)
122
+ patches = h.reshape(-1, self.n_patches, self.patch_dim)
123
+ patches_n = F.normalize(patches, dim=-1)
124
+ tri = self.triangulate(patches_n)
125
+ pw_out = self.patchwork(tri)
126
+ g = self.gate.sigmoid()
127
+ out = residual + g * pw_out
128
+
129
+ if is_seq:
130
+ return out.reshape(B, S, D), tri.reshape(B, S, -1)
131
+ return out, tri
132
+
133
+ def forward_no_tri(self, x):
134
+ """Original forward without returning tri β€” for compatibility."""
135
+ out, _ = self.forward(x)
136
+ return out
137
+
138
+
139
+ # ══════════════════════════════════════════════════════════════════
140
+ # CANTOR CONSTELLATION ROUTER β€” hierarchical cross-token, O(S)
141
+ # ══════════════════════════════════════════════════════════════════
142
+
143
+ class CantorConstellationRouter(nn.Module):
144
+ """
145
+ Hierarchical cross-token routing through the constellation anchor tree.
146
+
147
+ The triangulation profile assigns each token to a region on S^(d-1).
148
+ A binary tree over anchors defines the routing hierarchy:
149
+
150
+ Level 0: A groups (per-anchor, local neighbors)
151
+ Level 1: A/2 groups (anchor pairs, nearby interaction)
152
+ Level 2: A/4 groups (quads, medium range)
153
+ ...
154
+ Level L: 1 group (global summary)
155
+
156
+ At each level:
157
+ 1. Soft-assign tokens to groups via triangulation weights
158
+ 2. Weighted scatter: aggregate token representations per group
159
+ 3. Transform: per-level MLP on group summaries
160
+ 4. Weighted gather: distribute transformed summaries back to tokens
161
+ 5. Gated residual addition
162
+
163
+ Cost: O(S Γ— L Γ— D) where L = log2(A) + 1 = 5 for A=16.
164
+ Memory: O(S Γ— D + A Γ— D) β€” no SΒ² term anywhere.
165
+ """
166
+
167
+ def __init__(self, dim=256, n_anchors=16, n_patches=16):
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.n_anchors = n_anchors
171
+ self.n_patches = n_patches
172
+ self.n_levels = int(math.log2(n_anchors)) + 1 # 5 for A=16
173
+
174
+ # Build anchor hierarchy β€” which anchors merge at each level
175
+ # Level l: anchors are grouped into bins of size 2^l
176
+ # The ordering is determined at init from anchor geometry
177
+
178
+ # Per-level transforms: group_dim β†’ dim
179
+ self.level_mlps = nn.ModuleList()
180
+ self.level_gates = nn.ParameterList()
181
+ self.level_lns = nn.ModuleList()
182
+
183
+ for l in range(self.n_levels):
184
+ n_groups = max(1, n_anchors // (2 ** l))
185
+ self.level_mlps.append(nn.Sequential(
186
+ nn.Linear(dim, dim * 2),
187
+ SquaredReLU(),
188
+ nn.Linear(dim * 2, dim),
189
+ ))
190
+ self.level_lns.append(nn.LayerNorm(dim))
191
+ self.level_gates.append(nn.Parameter(torch.tensor(-3.0)))
192
+
193
+ # Projection from triangulation distances to routing weights
194
+ # Input: per-token distances to each anchor (n_patches Γ— n_anchors)
195
+ self.weight_proj = nn.Linear(n_patches * n_anchors, n_anchors)
196
+
197
+ def compute_routing_weights(self, tri, n_anchors):
198
+ """
199
+ Extract soft anchor assignment weights from triangulation profile.
200
+
201
+ tri: (BS, tri_dim) β€” full triangulation (n_patches Γ— n_anchors Γ— n_phases)
202
+ Returns: (BS, n_anchors) β€” soft assignment weights (sum to 1)
203
+ """
204
+ BS = tri.shape[0]
205
+ # Extract phase-0 distances: first n_patches * n_anchors values
206
+ # These are 1 - cos(token, anchor) for each patch Γ— anchor
207
+ phase0 = tri[:, :self.n_patches * n_anchors]
208
+
209
+ # Average over patches to get per-anchor proximity
210
+ # phase0: (BS, n_patches * n_anchors) β†’ reshape β†’ mean over patches
211
+ dists = phase0.reshape(BS, self.n_patches, n_anchors).mean(dim=1) # (BS, A)
212
+
213
+ # Convert distances to weights: closer = higher weight
214
+ # dists are in [0, 2] (1 - cos), so proximity = 2 - dists
215
+ proximity = (2.0 - dists).clamp(min=0)
216
+ weights = F.softmax(proximity * 5.0, dim=-1) # temperature-scaled
217
+ return weights
218
+
219
+ def forward(self, x, tri):
220
+ """
221
+ x: (B, S, D) token representations
222
+ tri: (B, S, tri_dim) triangulation profiles from constellation
223
+
224
+ Returns: (B, S, D) with cross-token information routed through anchor hierarchy
225
+ """
226
+ B, S, D = x.shape
227
+ x_flat = x.reshape(B * S, D)
228
+ tri_flat = tri.reshape(B * S, -1)
229
+
230
+ # Compute soft routing weights: (BS, A)
231
+ weights = self.compute_routing_weights(tri_flat, self.n_anchors)
232
+
233
+ h = x_flat # working copy
234
+
235
+ for level in range(self.n_levels):
236
+ group_size = 2 ** level
237
+ n_groups = max(1, self.n_anchors // group_size)
238
+
239
+ # Merge anchor weights into group weights
240
+ # Reshape weights (BS, A) β†’ (BS, n_groups, group_size) β†’ sum over group
241
+ if n_groups > 1:
242
+ group_weights = weights.reshape(B * S, n_groups, group_size).sum(dim=-1)
243
+ else:
244
+ group_weights = weights.sum(dim=-1, keepdim=True) # (BS, 1)
245
+
246
+ # Normalize group weights
247
+ group_weights = group_weights / (group_weights.sum(dim=-1, keepdim=True) + 1e-8)
248
+
249
+ # Weighted scatter: aggregate tokens into groups
250
+ # group_sums[g] = sum_s(group_weights[s, g] * h[s])
251
+ # Shape: (BS, n_groups, 1) Γ— (BS, 1, D) summed over BS
252
+ # But we need per-batch grouping. Reshape to (B, S, ...) for batched ops.
253
+
254
+ gw = group_weights.reshape(B, S, n_groups) # (B, S, G)
255
+ hh = h.reshape(B, S, D) # (B, S, D)
256
+
257
+ # Weighted sum: (B, G, S) @ (B, S, D) β†’ (B, G, D)
258
+ group_summary = torch.bmm(gw.transpose(1, 2), hh) # (B, G, D)
259
+
260
+ # Normalize by total weight per group
261
+ weight_sums = gw.sum(dim=1).unsqueeze(-1).clamp(min=1e-8) # (B, G, 1)
262
+ group_summary = group_summary / weight_sums
263
+
264
+ # Transform
265
+ gs_flat = group_summary.reshape(B * n_groups, D)
266
+ gs_flat = self.level_lns[level](gs_flat)
267
+ gs_transformed = self.level_mlps[level](gs_flat)
268
+ gs_transformed = gs_transformed.reshape(B, n_groups, D)
269
+
270
+ # Weighted gather: distribute back to tokens
271
+ # update[s] = sum_g(group_weights[s, g] * gs_transformed[g])
272
+ # (B, S, G) @ (B, G, D) β†’ (B, S, D)
273
+ token_update = torch.bmm(gw, gs_transformed).reshape(B * S, D)
274
+
275
+ # Gated residual
276
+ g = self.level_gates[level].sigmoid()
277
+ h = h + g * token_update
278
+
279
+ return h.reshape(B, S, D)
280
+
281
+
282
+ # ══════════════════════════════════════════════════════════════════
283
+ # CONSTELLATION-CANTOR RELAY β€” FULL O(S) TRANSFORMER LAYER
284
+ # ══════════════════════════════════════════════════════════════════
285
+
286
+ class ConstellationCantorRelay(nn.Module):
287
+ """
288
+ Complete O(S) transformer layer. No attention.
289
+
290
+ per-token: ConstellationRelay (triangulate β†’ patchwork β†’ gated residual)
291
+ cross-token: CantorConstellationRouter (hierarchical scatter/gather through anchors)
292
+
293
+ The triangulation from the per-token relay is reused as routing keys
294
+ for the cross-token path β€” no redundant computation.
295
+ """
296
+
297
+ def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
298
+ super().__init__()
299
+ self.relay = ConstellationRelay(
300
+ dim=dim, patch_dim=patch_dim, n_anchors=n_anchors, n_phases=n_phases)
301
+ self.router = CantorConstellationRouter(
302
+ dim=dim, n_anchors=n_anchors, n_patches=dim // patch_dim)
303
+ self.gate_relay = nn.Parameter(torch.tensor(-2.0))
304
+ self.gate_router = nn.Parameter(torch.tensor(-2.0))
305
+
306
+ def forward(self, x):
307
+ """x: (B, S, D)"""
308
+ B, S, D = x.shape
309
+
310
+ # Per-token relay β€” returns delta + triangulation
311
+ relay_out, tri = self.relay(x) # (B, S, D), (B, S, tri_dim)
312
+ relay_delta = relay_out - x
313
+
314
+ # Cross-token routing using triangulation as routing key
315
+ routed = self.router(x, tri) # (B, S, D)
316
+ router_delta = routed - x
317
+
318
+ # Gated combination
319
+ gr = self.gate_relay.sigmoid()
320
+ gc = self.gate_router.sigmoid()
321
+ return x + gr * relay_delta + gc * router_delta
322
+
323
+
324
+ # ══════════════════════════════════════════════════════════════════
325
+ # COMPARISON ARCHITECTURES
326
+ # ══════════════════════════════════════════════════════════════════
327
+
328
+ class VanillaAttention(nn.Module):
329
+ """Standard attention layer for comparison. O(SΒ²)."""
330
+ def __init__(self, dim=256, n_heads=4):
331
+ super().__init__()
332
+ self.n_heads = n_heads
333
+ self.head_dim = dim // n_heads
334
+ self.ln = nn.LayerNorm(dim)
335
+ self.qkv = nn.Linear(dim, 3 * dim)
336
+ self.proj = nn.Linear(dim, dim)
337
+
338
+ def forward(self, x):
339
+ B, S, D = x.shape
340
+ h = self.ln(x)
341
+ qkv = self.qkv(h).reshape(B, S, 3, self.n_heads, self.head_dim)
342
+ q, k, v = qkv.unbind(2)
343
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
344
+ attn = F.scaled_dot_product_attention(q, k, v)
345
+ return x + self.proj(attn.transpose(1, 2).reshape(B, S, D))
346
+
347
+
348
+ class HybridRelay(nn.Module):
349
+ """Constellation relay + vanilla attention. For comparison."""
350
+ def __init__(self, dim=256, n_heads=4):
351
+ super().__init__()
352
+ self.relay = ConstellationRelay(dim=dim)
353
+ self.n_heads = n_heads
354
+ self.head_dim = dim // n_heads
355
+ self.qkv = nn.Linear(dim, 3 * dim)
356
+ self.attn_proj = nn.Linear(dim, dim)
357
+ self.attn_ln = nn.LayerNorm(dim)
358
+ self.gate_relay = nn.Parameter(torch.tensor(-2.0))
359
+ self.gate_attn = nn.Parameter(torch.tensor(-2.0))
360
+
361
+ def forward(self, x):
362
+ B, S, D = x.shape
363
+ relay_out = self.relay.forward_no_tri(x)
364
+ relay_delta = relay_out - x
365
+
366
+ h = self.attn_ln(x)
367
+ qkv = self.qkv(h).reshape(B, S, 3, self.n_heads, self.head_dim)
368
+ q, k, v = qkv.unbind(2)
369
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
370
+ attn = F.scaled_dot_product_attention(q, k, v)
371
+ attn_out = self.attn_proj(attn.transpose(1, 2).reshape(B, S, D))
372
+
373
+ gr = self.gate_relay.sigmoid()
374
+ ga = self.gate_attn.sigmoid()
375
+ return x + gr * relay_delta + ga * attn_out
376
+
377
+
378
+ class PureRelayLayer(nn.Module):
379
+ """Relay-only, no cross-token. For comparison."""
380
+ def __init__(self, dim=256):
381
+ super().__init__()
382
+ self.relay = ConstellationRelay(dim=dim)
383
+
384
+ def forward(self, x):
385
+ return self.relay.forward_no_tri(x)
386
+
387
+
388
+ # ══════════════════════════════════════════════════════════════════
389
+ # UTILITIES
390
+ # ══════════════════════════════════════════════════════════════════
391
+
392
+ def reset_vram():
393
+ gc.collect()
394
+ torch.cuda.empty_cache()
395
+ torch.cuda.reset_peak_memory_stats()
396
+
397
+ def peak_mb():
398
+ return torch.cuda.max_memory_allocated() / 1e6
399
+
400
+ D = 256
401
+
402
+ print("=" * 80)
403
+ print("CONSTELLATION-CANTOR RELAY β€” O(S) CROSS-TOKEN ROUTING BENCHMARK")
404
+ print(f" Device: {torch.cuda.get_device_name()}")
405
+ print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
406
+ print(f" Dimension: {D}")
407
+ print("=" * 80)
408
+
409
+
410
+ # ══════════════════════════════════════════════════════════════════
411
+ # TEST 1: THROUGHPUT β€” ALL FOUR ARCHITECTURES
412
+ # ══════════════════════════════════════════════════════════════════
413
+
414
+ print(f"\n{'━'*80}")
415
+ print("TEST 1: Throughput Scaling β€” 4 architectures, S=64 to 131K")
416
+ print(" Single layer, B=1, fp16")
417
+ print(f"{'━'*80}")
418
+
419
+ SEQ_LENGTHS = [64, 256, 1024, 4096, 16384, 32768, 65536, 131072]
420
+
421
+ print(f"\n {'S':>8} {'relay':>9} {'cantor':>9} {'hybrid':>9} {'attn':>9} "
422
+ f"{'c/r':>6} {'c/a':>6} {'c_MB':>7}")
423
+
424
+ for S in SEQ_LENGTHS:
425
+ results = {}
426
+
427
+ for name, make_layer in [
428
+ ("relay", lambda: PureRelayLayer(D)),
429
+ ("cantor", lambda: ConstellationCantorRelay(D)),
430
+ ("hybrid", lambda: HybridRelay(D)),
431
+ ("attn", lambda: VanillaAttention(D)),
432
+ ]:
433
+ try:
434
+ reset_vram()
435
+ layer = make_layer().to(DEVICE).half()
436
+ x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
437
+
438
+ # Warmup
439
+ with torch.no_grad():
440
+ for _ in range(3):
441
+ _ = layer(x)
442
+ torch.cuda.synchronize()
443
+
444
+ t0 = time.perf_counter()
445
+ with torch.no_grad():
446
+ for _ in range(10):
447
+ _ = layer(x)
448
+ torch.cuda.synchronize()
449
+ ms = (time.perf_counter() - t0) / 10 * 1000
450
+ mb = peak_mb()
451
+ results[name] = (ms, mb)
452
+
453
+ del layer, x
454
+ reset_vram()
455
+
456
+ except (torch.cuda.OutOfMemoryError, RuntimeError):
457
+ results[name] = (float('inf'), float('inf'))
458
+ reset_vram()
459
+
460
+ r = results.get("relay", (0, 0))[0]
461
+ c = results.get("cantor", (0, 0))[0]
462
+ h = results.get("hybrid", (0, 0))[0]
463
+ a = results.get("attn", (0, 0))[0]
464
+ c_mb = results.get("cantor", (0, 0))[1]
465
+
466
+ def fmt(v):
467
+ return f"{v:>8.2f}ms" if v < float('inf') else " OOM"
468
+
469
+ cr_ratio = f"{c/r:>5.1f}Γ—" if r > 0 and c < float('inf') else " -"
470
+ ca_ratio = f"{c/a:>5.1f}Γ—" if a > 0 and a < float('inf') and c < float('inf') else " -"
471
+
472
+ print(f" {S:>8} {fmt(r)} {fmt(c)} {fmt(h)} {fmt(a)} "
473
+ f"{cr_ratio} {ca_ratio} {c_mb:>7.0f}")
474
+
475
+
476
+ # ══════════════════════════════════════════════════════════════════
477
+ # TEST 2: CROSS-TOKEN CAUSAL INTERVENTION β€” CANTOR vs HYBRID
478
+ # ══════════════════════════════════════════════════════════════════
479
+
480
+ print(f"\n{'━'*80}")
481
+ print("TEST 2: Cross-Token Causal Intervention")
482
+ print(" Modify token 0, measure effect on token S//2")
483
+ print(" 4 layers deep. Compare: cantor relay vs hybrid vs pure relay")
484
+ print(f"{'━'*80}")
485
+
486
+ N_LAYERS = 4
487
+
488
+ print(f"\n {'S':>8} {'arch':>10} {'Ξ”_mid':>10} {'Ξ”_last':>10} "
489
+ f"{'cos_orig':>10} {'time_ms':>10}")
490
+
491
+ for S in [64, 256, 1024, 4096, 16384]:
492
+ for arch_name, make_stack in [
493
+ ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(N_LAYERS)])),
494
+ ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(N_LAYERS)])),
495
+ ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(N_LAYERS)])),
496
+ ]:
497
+ try:
498
+ reset_vram()
499
+ torch.manual_seed(42)
500
+ stack = make_stack().to(DEVICE).half()
501
+
502
+ x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
503
+ x_mod = x.clone()
504
+ x_mod[:, 0] = F.normalize(torch.randn(1, D, device=DEVICE, dtype=torch.float16), dim=-1)
505
+
506
+ torch.cuda.synchronize()
507
+ t0 = time.perf_counter()
508
+
509
+ with torch.no_grad():
510
+ h = x.clone()
511
+ h_mod = x_mod.clone()
512
+ for layer in stack:
513
+ h = layer(h)
514
+ h_mod = layer(h_mod)
515
+
516
+ torch.cuda.synchronize()
517
+ elapsed = (time.perf_counter() - t0) * 1000
518
+
519
+ mid = S // 2
520
+ delta_mid = (h[0, mid].float() - h_mod[0, mid].float()).norm().item()
521
+ delta_last = (h[0, -1].float() - h_mod[0, -1].float()).norm().item()
522
+ cos_orig = F.cosine_similarity(
523
+ x[0, mid:mid+1].float(), h[0, mid:mid+1].float()).item()
524
+
525
+ print(f" {S:>8} {arch_name:>10} {delta_mid:>10.4f} {delta_last:>10.4f} "
526
+ f"{cos_orig:>10.4f} {elapsed:>10.1f}")
527
+
528
+ del stack, x, x_mod, h, h_mod
529
+ reset_vram()
530
+
531
+ except (torch.cuda.OutOfMemoryError, RuntimeError):
532
+ print(f" {S:>8} {arch_name:>10} OOM")
533
+ reset_vram()
534
+
535
+ print()
536
+
537
+
538
+ # ══════════════════════════════════════════════════════════════════
539
+ # TEST 3: GEOMETRIC PRESERVATION WITH CROSS-TOKEN ROUTING
540
+ # ══════════════════════════════════════════════════════════════════
541
+
542
+ print(f"\n{'━'*80}")
543
+ print("TEST 3: Geometric Preservation β€” does Cantor routing hurt geometry?")
544
+ print(" 8 layers, S=4096. Compare cos_to_orig, CV, eff_dim.")
545
+ print(f"{'━'*80}")
546
+
547
+ def compute_cv(points, n_samples=500):
548
+ N = points.shape[0]
549
+ if N < 5: return float('nan')
550
+ points = F.normalize(points.float(), dim=-1)
551
+ vols = []
552
+ for _ in range(n_samples):
553
+ idx = torch.randperm(min(N, 2000), device=points.device)[:5]
554
+ pts = points[idx].unsqueeze(0)
555
+ gram = torch.bmm(pts, pts.transpose(1, 2))
556
+ norms = torch.diagonal(gram, dim1=1, dim2=2)
557
+ d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
558
+ d2 = F.relu(d2)
559
+ cm = torch.zeros(1, 6, 6, device=points.device, dtype=torch.float32)
560
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
561
+ v2 = -torch.linalg.det(cm) / 9216
562
+ if v2[0].item() > 1e-20:
563
+ vols.append(v2[0].sqrt().cpu())
564
+ if len(vols) < 50: return float('nan')
565
+ vt = torch.stack(vols)
566
+ return (vt.std() / (vt.mean() + 1e-8)).item()
567
+
568
+ GEO_DEPTH = 8
569
+ GEO_S = 4096
570
+
571
+ print(f"\n {'arch':>10} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
572
+ f"{'eff_dim':>8} {'self_sim':>10}")
573
+
574
+ for arch_name, make_stack in [
575
+ ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(GEO_DEPTH)])),
576
+ ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(GEO_DEPTH)])),
577
+ ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(GEO_DEPTH)])),
578
+ ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(GEO_DEPTH)])),
579
+ ]:
580
+ try:
581
+ reset_vram()
582
+ torch.manual_seed(42)
583
+ stack = make_stack().to(DEVICE).half()
584
+
585
+ x = F.normalize(torch.randn(1, GEO_S, D, device=DEVICE, dtype=torch.float16), dim=-1)
586
+
587
+ with torch.no_grad():
588
+ h = x.clone()
589
+ for layer in stack:
590
+ h = layer(h)
591
+
592
+ x_s = x[0, :512].float()
593
+ h_s = h[0, :512].float()
594
+ cos = F.cosine_similarity(x_s, h_s).mean().item()
595
+ norm = h_s.norm(dim=-1).mean().item()
596
+ h_n = F.normalize(h_s, dim=-1)
597
+ sim = h_n @ h_n.T
598
+ mask = ~torch.eye(512, device=DEVICE, dtype=torch.bool)
599
+ self_sim = sim[mask].mean().item()
600
+ cv = compute_cv(h_n, 500)
601
+
602
+ _, S_vals, _ = torch.linalg.svd(h_n[:256], full_matrices=False)
603
+ p = S_vals / S_vals.sum()
604
+ ed = p.pow(2).sum().reciprocal().item()
605
+
606
+ print(f" {arch_name:>10} {cos:>10.4f} {norm:>8.4f} {cv:>8.4f} "
607
+ f"{ed:>8.1f} {self_sim:>10.6f}")
608
+
609
+ del stack, x, h
610
+ reset_vram()
611
+
612
+ except (torch.cuda.OutOfMemoryError, RuntimeError):
613
+ print(f" {arch_name:>10} OOM")
614
+ reset_vram()
615
+
616
+
617
+ # ═══��══════════════════════════════════════════════════════════════
618
+ # TEST 4: TRAINED CROSS-TOKEN TASK β€” ALL ARCHITECTURES
619
+ # ══════════════════════════════════════════════════════════════════
620
+
621
+ print(f"\n{'━'*80}")
622
+ print("TEST 4: Trained Cross-Token Task")
623
+ print(" Label = (token_0_class + token_1_class) % 10")
624
+ print(" Pure relay CANNOT solve this (zero cross-token info).")
625
+ print(" 4 layers, 500 steps, S=8.")
626
+ print(f"{'━'*80}")
627
+
628
+ S_TASK = 8
629
+ N_CLS = 10
630
+ N_SAMPLES = 4096
631
+ STEPS = 500
632
+
633
+ torch.manual_seed(777)
634
+ keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
635
+ keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
636
+
637
+ task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
638
+ label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
639
+ label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
640
+ task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
641
+ task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
642
+ task_x = F.normalize(task_x, dim=-1)
643
+ task_y = ((label_a + label_b) % N_CLS).long()
644
+
645
+ print(f"\n {'arch':>10} {'acc':>8} {'loss':>8} {'cross_Ξ”':>10} {'params':>10}")
646
+
647
+ for arch_name, make_stack in [
648
+ ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(4)])),
649
+ ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(4)])),
650
+ ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(4)])),
651
+ ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(4)])),
652
+ ]:
653
+ torch.manual_seed(42)
654
+
655
+ class TaskModel(nn.Module):
656
+ def __init__(self, stack):
657
+ super().__init__()
658
+ self.layers = stack
659
+ self.pool = nn.Linear(D * S_TASK, D)
660
+ self.head = nn.Linear(D, N_CLS)
661
+
662
+ def forward(self, x):
663
+ for layer in self.layers:
664
+ x = layer(x)
665
+ return self.head(F.gelu(self.pool(x.reshape(x.shape[0], -1))))
666
+
667
+ model = TaskModel(make_stack()).to(DEVICE)
668
+ n_params = sum(p.numel() for p in model.parameters())
669
+ opt = torch.optim.Adam(model.parameters(), lr=3e-4)
670
+
671
+ for step in range(STEPS):
672
+ idx = torch.randint(0, N_SAMPLES, (128,))
673
+ logits = model(task_x[idx])
674
+ loss = F.cross_entropy(logits, task_y[idx])
675
+ if torch.isnan(loss) or torch.isinf(loss):
676
+ break
677
+ opt.zero_grad()
678
+ loss.backward()
679
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
680
+ opt.step()
681
+
682
+ model.eval()
683
+ with torch.no_grad():
684
+ logits = model(task_x[:1024])
685
+ acc = (logits.argmax(-1) == task_y[:1024]).float().mean().item()
686
+ final_loss = F.cross_entropy(logits, task_y[:1024]).item()
687
+
688
+ # Cross-token intervention
689
+ h1 = task_x[:64].clone()
690
+ for layer in model.layers:
691
+ h1 = layer(h1)
692
+ h2 = task_x[:64].clone()
693
+ h2[:, 0] = F.normalize(torch.randn(64, D, device=DEVICE), dim=-1)
694
+ for layer in model.layers:
695
+ h2 = layer(h2)
696
+ cross_delta = (h1[:, 1] - h2[:, 1]).norm(dim=-1).mean().item()
697
+
698
+ print(f" {arch_name:>10} {acc:>8.1%} {final_loss:>8.4f} {cross_delta:>10.4f} {n_params:>10,}")
699
+
700
+ del model
701
+ reset_vram()
702
+
703
+
704
+ # ══════════════════════════════════════════════════════════════════
705
+ # TEST 5: THE O(SΒ²) WALL β€” CANTOR vs ATTENTION at depth 8
706
+ # ══════════════════════════════════════════════════════════════════
707
+
708
+ print(f"\n{'━'*80}")
709
+ print("TEST 5: The O(SΒ²) Wall β€” Cantor vs Attention, 8 layers deep")
710
+ print(f"{'━'*80}")
711
+
712
+ WALL_DEPTH = 8
713
+
714
+ print(f"\n {'S':>8} {'cantor_ms':>10} {'attn_ms':>10} {'speedup':>8} "
715
+ f"{'c_cos':>8} {'a_cos':>8} {'c_MB':>8} {'a_MB':>8}")
716
+
717
+ for S in [1024, 4096, 8192, 16384, 32768, 65536, 131072]:
718
+ c_result = None
719
+ a_result = None
720
+
721
+ # Cantor
722
+ try:
723
+ reset_vram()
724
+ torch.manual_seed(42)
725
+ c_stack = nn.ModuleList([
726
+ ConstellationCantorRelay(D) for _ in range(WALL_DEPTH)
727
+ ]).to(DEVICE).half()
728
+
729
+ x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
730
+ with torch.no_grad():
731
+ h = x.clone()
732
+ for layer in c_stack:
733
+ h = layer(h)
734
+ torch.cuda.synchronize()
735
+
736
+ t0 = time.perf_counter()
737
+ with torch.no_grad():
738
+ h = x.clone()
739
+ for layer in c_stack:
740
+ h = layer(h)
741
+ torch.cuda.synchronize()
742
+ c_ms = (time.perf_counter() - t0) * 1000
743
+ c_mb = peak_mb()
744
+ c_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
745
+ c_result = (c_ms, c_cos, c_mb)
746
+
747
+ del x, h, c_stack
748
+ reset_vram()
749
+ except (torch.cuda.OutOfMemoryError, RuntimeError):
750
+ reset_vram()
751
+
752
+ # Attention
753
+ try:
754
+ reset_vram()
755
+ torch.manual_seed(42)
756
+ a_stack = nn.ModuleList([
757
+ VanillaAttention(D) for _ in range(WALL_DEPTH)
758
+ ]).to(DEVICE).half()
759
+
760
+ x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
761
+ with torch.no_grad():
762
+ h = x.clone()
763
+ for layer in a_stack:
764
+ h = layer(h)
765
+ torch.cuda.synchronize()
766
+
767
+ t0 = time.perf_counter()
768
+ with torch.no_grad():
769
+ h = x.clone()
770
+ for layer in a_stack:
771
+ h = layer(h)
772
+ torch.cuda.synchronize()
773
+ a_ms = (time.perf_counter() - t0) * 1000
774
+ a_mb = peak_mb()
775
+ a_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
776
+ a_result = (a_ms, a_cos, a_mb)
777
+
778
+ del x, h, a_stack
779
+ reset_vram()
780
+ except (torch.cuda.OutOfMemoryError, RuntimeError):
781
+ reset_vram()
782
+
783
+ c_str = f"{c_result[0]:>9.1f}ms" if c_result else " OOM"
784
+ a_str = f"{a_result[0]:>9.1f}ms" if a_result else " OOM"
785
+ sp = f"{a_result[0]/c_result[0]:>7.1f}Γ—" if c_result and a_result else " -"
786
+ cc = f"{c_result[1]:>8.4f}" if c_result else " ---"
787
+ ac = f"{a_result[1]:>8.4f}" if a_result else " ---"
788
+ cm = f"{c_result[2]:>8.0f}" if c_result else " OOM"
789
+ am = f"{a_result[2]:>8.0f}" if a_result else " OOM"
790
+
791
+ print(f" {S:>8} {c_str} {a_str} {sp} {cc} {ac} {cm} {am}")
792
+
793
+ if c_result is None:
794
+ print(f" β†’ Cantor OOM at S={S}, stopping")
795
+ break
796
+
797
+
798
+ # ══════════════════════════════════════════════════════════════════
799
+ # SUMMARY
800
+ # ══════════════════════════════════════════════════════════════════
801
+
802
+ print(f"\n{'='*80}")
803
+ print("CONSTELLATION-CANTOR RELAY β€” BENCHMARK COMPLETE")
804
+ print(f"{'='*80}")
805
+ print(f"""
806
+ Architecture:
807
+ per-token: constellation relay (triangulate β†’ patchwork β†’ gated residual)
808
+ cross-token: cantor router (hierarchical scatter/gather through anchor tree)
809
+ total: O(S) time, O(S) memory, no attention
810
+
811
+ 5 tests:
812
+ T1: Throughput β€” relay vs cantor vs hybrid vs attention, S to 131K
813
+ T2: Cross-token causal intervention β€” who routes strongest?
814
+ T3: Geometric preservation β€” does cross-token routing hurt geometry?
815
+ T4: Trained cross-token task β€” accuracy on interaction-dependent labels
816
+ T5: O(SΒ²) wall β€” cantor vs attention at 8 layers to OOM
817
+
818
+ Key questions answered:
819
+ β€’ Is the cantor router faster than attention at all sequence lengths?
820
+ β€’ Does it provide meaningful cross-token interaction?
821
+ β€’ Does the routing hurt per-token geometric preservation?
822
+ β€’ Can it solve tasks that require cross-token information?
823
+ """)