AbstractPhil commited on
Commit
09ae007
Β·
verified Β·
1 Parent(s): 5885f2d

Create constellation_cantor_routing_relay.py

Browse files
Files changed (1) hide show
  1. constellation_cantor_routing_relay.py +819 -0
constellation_cantor_routing_relay.py ADDED
@@ -0,0 +1,819 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Constellation-Cantor Relay β€” O(S) Cross-Token Routing
4
+ =======================================================
5
+ Replaces attention entirely with triangulation-mediated hierarchical routing.
6
+
7
+ Architecture:
8
+ per-token: constellation relay (triangulate β†’ patchwork β†’ gated residual)
9
+ cross-token: Cantor router (hierarchical scatter/gather through anchor tree)
10
+
11
+ The triangulation profile IS the routing key. Tokens near the same anchor
12
+ on S^(d-1) share information at level 0. Anchor pairs share at level 1.
13
+ Quads at level 2. Full global at level log2(A).
14
+
15
+ Total cross-token cost: O(S Γ— n_levels) = O(S Γ— 4) for 16 anchors.
16
+ Total per-token cost: O(S Γ— tri_dim Γ— pw_hidden).
17
+ No attention anywhere. Fully O(S).
18
+
19
+ Benchmarks:
20
+ 1. Throughput: cantor-relay vs hybrid vs pure relay vs attention
21
+ 2. Cross-token causal intervention at scale
22
+ 3. Geometric preservation
23
+ 4. Trained task requiring cross-token routing
24
+ """
25
+
26
+ import os
27
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ import numpy as np
33
+ import math
34
+ import time
35
+ import gc
36
+ from collections import OrderedDict
37
+
38
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
+ torch.backends.cuda.matmul.allow_tf32 = True
40
+ torch.backends.cudnn.allow_tf32 = True
41
+
42
+
43
+ # ══════════════════════════════════════════════════════════════════
44
+ # ACTIVATIONS
45
+ # ══════════════════════════════════════════════════════════════════
46
+
47
class SquaredReLU(nn.Module):
    """Squared ReLU activation: f(x) = max(x, 0)^2."""

    def forward(self, x):
        rectified = F.relu(x)
        return rectified * rectified
49
+
50
+
51
+ # ══════════════════════════════════════════════════════════════════
52
+ # CONSTELLATION RELAY β€” per-token geometric layer
53
+ # ══════════════════════════════════════════════════════════════════
54
+
55
class ConstellationRelay(nn.Module):
    """Per-token constellation triangulation + patchwork. O(S).

    Each token is normalized, split into `n_patches` patches of size
    `patch_dim`, and every patch is compared against a learned anchor set on
    the unit sphere at several slerp phases. The resulting distance profile
    ("triangulation") is fed through a small MLP (the "patchwork") whose
    output is added back to the token through a learned sigmoid gate.
    """

    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        super().__init__()
        self.dim = dim
        self.patch_dim = patch_dim
        self.n_patches = dim // patch_dim  # assumes patch_dim divides dim evenly
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        self.ln = nn.LayerNorm(dim)

        # 'home' is the frozen initial anchor constellation (buffer, not
        # trained); 'anchors' is its trainable copy. Both sets of points
        # are unit-normalized, i.e. they live on S^(d-1).
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Triangulation width: one distance per patch x anchor x phase.
        tri_dim = P * A * n_phases
        self.tri_dim = tri_dim
        pw_hidden = tri_dim * 2

        self.patchwork = nn.Sequential(
            nn.Linear(tri_dim, pw_hidden),
            SquaredReLU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, dim),
        )
        # Gate initialized at sigmoid(-3) ~ 0.047 so the layer starts close
        # to the identity mapping.
        self.gate = nn.Parameter(torch.tensor(-3.0))

    def drift(self):
        # Angle (radians) between each trained anchor and its home position,
        # clamped away from +/-1 before acos for numerical safety.
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-6, 1 - 1e-6))

    def at_phase(self, t):
        # Spherical interpolation (slerp) between home (t=0) and the current
        # anchors (t=1).
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-6)  # guard divide-by-zero when drift ~ 0
        return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c

    def triangulate(self, patches_n):
        # patches_n: (BS, P, d), unit-normalized patches.
        # Returns (BS, P*A*n_phases): cosine distances 1 - cos(patch, anchor)
        # against the anchor constellation at each slerp phase in [0, 1].
        phases = torch.linspace(0, 1, self.n_phases, device=patches_n.device).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1).to(patches_n.dtype)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
        return torch.cat(tris, dim=-1).reshape(patches_n.shape[0], -1)

    def forward(self, x):
        """x: (B*S, D) or (B, S, D).

        Returns (out, tri): the gated-residual output in the input's shape,
        and the triangulation profile (reused downstream as a routing key).
        """
        is_seq = x.dim() == 3
        if is_seq:
            B, S, D = x.shape
            x_flat = x.reshape(B * S, D)
        else:
            x_flat = x

        residual = x_flat
        h = self.ln(x_flat)
        patches = h.reshape(-1, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)
        tri = self.triangulate(patches_n)
        pw_out = self.patchwork(tri)
        g = self.gate.sigmoid()
        out = residual + g * pw_out

        if is_seq:
            return out.reshape(B, S, D), tri.reshape(B, S, -1)
        return out, tri

    def forward_no_tri(self, x):
        """Forward pass that discards the triangulation - for callers that
        only need the token output (compatibility shim)."""
        out, _ = self.forward(x)
        return out
133
+
134
+
135
+ # ══════════════════════════════════════════════════════════════════
136
+ # CANTOR CONSTELLATION ROUTER β€” hierarchical cross-token, O(S)
137
+ # ══════════════════════════════════════════════════════════════════
138
+
139
class CantorConstellationRouter(nn.Module):
    """
    Hierarchical cross-token routing through the constellation anchor tree.

    The triangulation profile assigns each token to a region on S^(d-1).
    A binary tree over anchors defines the routing hierarchy:

        Level 0: A groups   (per-anchor, local neighbors)
        Level 1: A/2 groups (anchor pairs, nearby interaction)
        Level 2: A/4 groups (quads, medium range)
        ...
        Level L: 1 group    (global summary)

    At each level:
        1. Soft-assign tokens to groups via triangulation weights
        2. Weighted scatter: aggregate token representations per group
        3. Transform: per-level MLP on group summaries
        4. Weighted gather: distribute transformed summaries back to tokens
        5. Gated residual addition

    Cost: O(S x L x D) where L = log2(A) + 1 (5 for A=16).
    Memory: O(S x D + A x D) - no S^2 term anywhere.
    """

    def __init__(self, dim=256, n_anchors=16, n_patches=16):
        super().__init__()
        self.dim = dim
        self.n_anchors = n_anchors
        self.n_patches = n_patches
        self.n_levels = int(math.log2(n_anchors)) + 1  # 5 for A=16

        # Per-level transforms: one (LayerNorm, MLP, gate) triple per level
        # of the anchor hierarchy.
        self.level_mlps = nn.ModuleList()
        self.level_gates = nn.ParameterList()
        self.level_lns = nn.ModuleList()

        for l in range(self.n_levels):
            # NOTE(review): n_groups is computed but unused in this loop -
            # the per-level modules do not depend on the group count.
            n_groups = max(1, n_anchors // (2 ** l))
            self.level_mlps.append(nn.Sequential(
                nn.Linear(dim, dim * 2),
                SquaredReLU(),
                nn.Linear(dim * 2, dim),
            ))
            self.level_lns.append(nn.LayerNorm(dim))
            self.level_gates.append(nn.Parameter(torch.tensor(-3.0)))

        # NOTE(review): weight_proj is registered but never called anywhere
        # in this class - routing weights come from compute_routing_weights
        # instead. These are dead (but trained-state-affecting) parameters;
        # confirm before removing, since deletion changes the state_dict.
        self.weight_proj = nn.Linear(n_patches * n_anchors, n_anchors)

    def compute_routing_weights(self, tri, n_anchors):
        """
        Extract soft anchor assignment weights from a triangulation profile.

        tri: (BS, tri_dim) - full triangulation (n_patches x n_anchors x n_phases)
        n_anchors: anchor count (expected to equal self.n_anchors)
        Returns: (BS, n_anchors) - soft assignment weights (sum to 1)
        """
        BS = tri.shape[0]
        # Phase-0 distances occupy the first n_patches * n_anchors slots of
        # the profile; each entry is 1 - cos(patch, anchor).
        phase0 = tri[:, :self.n_patches * n_anchors]

        # Average over patches to get a single per-anchor proximity per token.
        dists = phase0.reshape(BS, self.n_patches, n_anchors).mean(dim=1)  # (BS, A)

        # Convert distances to weights: closer = higher weight.
        # dists lie in [0, 2] (1 - cos), so proximity = 2 - dists.
        proximity = (2.0 - dists).clamp(min=0)
        weights = F.softmax(proximity * 5.0, dim=-1)  # temperature-scaled (T=1/5)
        return weights

    def forward(self, x, tri):
        """
        x:   (B, S, D) token representations
        tri: (B, S, tri_dim) triangulation profiles from the constellation

        Returns: (B, S, D) with cross-token information routed through the
        anchor hierarchy via per-level scatter / transform / gather.
        """
        B, S, D = x.shape
        x_flat = x.reshape(B * S, D)
        tri_flat = tri.reshape(B * S, -1)

        # Soft routing weights per token: (BS, A)
        weights = self.compute_routing_weights(tri_flat, self.n_anchors)

        h = x_flat  # working copy

        for level in range(self.n_levels):
            group_size = 2 ** level
            n_groups = max(1, self.n_anchors // group_size)

            # Merge anchor weights into group weights:
            # (BS, A) -> (BS, n_groups, group_size) -> sum over group members.
            if n_groups > 1:
                group_weights = weights.reshape(B * S, n_groups, group_size).sum(dim=-1)
            else:
                group_weights = weights.sum(dim=-1, keepdim=True)  # (BS, 1)

            # Renormalize so each token's group weights sum to 1.
            group_weights = group_weights / (group_weights.sum(dim=-1, keepdim=True) + 1e-8)

            # Weighted scatter: aggregate tokens into group summaries.
            # Grouping is per-batch, so reshape back to (B, S, ...) for bmm.
            gw = group_weights.reshape(B, S, n_groups)  # (B, S, G)
            hh = h.reshape(B, S, D)                     # (B, S, D)

            # Weighted sum: (B, G, S) @ (B, S, D) -> (B, G, D)
            group_summary = torch.bmm(gw.transpose(1, 2), hh)  # (B, G, D)

            # Normalize each summary by the total weight mass routed to it.
            weight_sums = gw.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)  # (B, G, 1)
            group_summary = group_summary / weight_sums

            # Per-level transform of the group summaries.
            gs_flat = group_summary.reshape(B * n_groups, D)
            gs_flat = self.level_lns[level](gs_flat)
            gs_transformed = self.level_mlps[level](gs_flat)
            gs_transformed = gs_transformed.reshape(B, n_groups, D)

            # Weighted gather: distribute transformed summaries back.
            # update[s] = sum_g(group_weights[s, g] * gs_transformed[g])
            # (B, S, G) @ (B, G, D) -> (B, S, D)
            token_update = torch.bmm(gw, gs_transformed).reshape(B * S, D)

            # Gated residual so the routing path starts near-identity.
            g = self.level_gates[level].sigmoid()
            h = h + g * token_update

        return h.reshape(B, S, D)
276
+
277
+
278
+ # ══════════════════════════════════════════════════════════════════
279
+ # CONSTELLATION-CANTOR RELAY β€” FULL O(S) TRANSFORMER LAYER
280
+ # ══════════════════════════════════════════════════════════════════
281
+
282
class ConstellationCantorRelay(nn.Module):
    """
    Complete O(S) transformer-style layer with no attention anywhere.

    Two paths share one residual stream:
      per-token:   ConstellationRelay (triangulate -> patchwork -> gate)
      cross-token: CantorConstellationRouter (hierarchical scatter/gather)

    The triangulation computed by the per-token relay is reused as the
    routing key for the cross-token path, so it is computed only once.
    """

    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        super().__init__()
        # Submodule creation order is kept stable so parameter registration
        # (and RNG consumption at init) matches the original layout.
        self.relay = ConstellationRelay(
            dim=dim, patch_dim=patch_dim, n_anchors=n_anchors, n_phases=n_phases)
        self.router = CantorConstellationRouter(
            dim=dim, n_anchors=n_anchors, n_patches=dim // patch_dim)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_router = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        """x: (B, S, D) -> (B, S, D)"""
        B, S, D = x.shape

        # Per-token path: relay output plus its triangulation profile.
        per_token, profile = self.relay(x)  # (B, S, D), (B, S, tri_dim)

        # Cross-token path, keyed on the triangulation computed above.
        cross_token = self.router(x, profile)  # (B, S, D)

        # Blend both deltas into the residual stream through sigmoid gates.
        relay_mix = self.gate_relay.sigmoid() * (per_token - x)
        router_mix = self.gate_router.sigmoid() * (cross_token - x)
        return x + relay_mix + router_mix
318
+
319
+
320
+ # ══════════════════════════════════════════════════════════════════
321
+ # COMPARISON ARCHITECTURES
322
+ # ══════════════════════════════════════════════════════════════════
323
+
324
class VanillaAttention(nn.Module):
    """Standard pre-norm multi-head self-attention block for comparison. O(S^2)."""

    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.ln = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, seq, width = x.shape
        # Single fused projection, then split into per-head Q/K/V.
        packed = self.qkv(self.ln(x)).reshape(
            batch, seq, 3, self.n_heads, self.head_dim)
        query, key, value = (t.transpose(1, 2) for t in packed.unbind(2))
        # Fused SDPA kernel; no causal mask (bidirectional benchmark layer).
        mixed = F.scaled_dot_product_attention(query, key, value)
        mixed = mixed.transpose(1, 2).reshape(batch, seq, width)
        return x + self.proj(mixed)
342
+
343
+
344
class HybridRelay(nn.Module):
    """Constellation relay combined with vanilla attention. Comparison arch."""

    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        # Attribute order preserved for stable parameter registration.
        self.relay = ConstellationRelay(dim=dim)
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_proj = nn.Linear(dim, dim)
        self.attn_ln = nn.LayerNorm(dim)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_attn = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        batch, seq, width = x.shape

        # Per-token branch: relay delta relative to the input.
        relay_shift = self.relay.forward_no_tri(x) - x

        # Cross-token branch: standard pre-norm multi-head self-attention.
        packed = self.qkv(self.attn_ln(x)).reshape(
            batch, seq, 3, self.n_heads, self.head_dim)
        query, key, value = (t.transpose(1, 2) for t in packed.unbind(2))
        mixed = F.scaled_dot_product_attention(query, key, value)
        attn_shift = self.attn_proj(mixed.transpose(1, 2).reshape(batch, seq, width))

        # Gated residual combination of the two branches.
        return (x
                + self.gate_relay.sigmoid() * relay_shift
                + self.gate_attn.sigmoid() * attn_shift)
372
+
373
+
374
class PureRelayLayer(nn.Module):
    """Per-token relay only - no cross-token mixing. Comparison baseline."""

    def __init__(self, dim=256):
        super().__init__()
        self.relay = ConstellationRelay(dim=dim)

    def forward(self, x):
        # Delegate to the relay; the triangulation byproduct is discarded.
        result = self.relay.forward_no_tri(x)
        return result
382
+
383
+
384
+ # ══════════════════════════════════════════════════════════════════
385
+ # UTILITIES
386
+ # ══════════════════════════════════════════════════════════════════
387
+
388
def reset_vram():
    # Drop Python garbage, release cached CUDA blocks, and zero the peak
    # allocation counter so the next benchmark measurement starts clean.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
392
+
393
def peak_mb():
    # Peak CUDA memory allocated since the last reset, in megabytes (1e6 bytes).
    return torch.cuda.max_memory_allocated() / 1e6
395
+
396
D = 256

print("=" * 80)
print("CONSTELLATION-CANTOR RELAY β€” O(S) CROSS-TOKEN ROUTING BENCHMARK")
print(f" Device: {torch.cuda.get_device_name()}")
print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f" Dimension: {D}")
print("=" * 80)


# ------------------------------------------------------------------
# TEST 1: THROUGHPUT — single layer, B=1, fp16, all four architectures.
# Each arch is timed with 3 warmup passes then 10 measured passes; OOM
# is recorded as inf so the table still prints.
# ------------------------------------------------------------------

print(f"\n{'━'*80}")
print("TEST 1: Throughput Scaling β€” 4 architectures, S=64 to 131K")
print(" Single layer, B=1, fp16")
print(f"{'━'*80}")

SEQ_LENGTHS = [64, 256, 1024, 4096, 16384, 32768, 65536, 131072]

print(f"\n {'S':>8} {'relay':>9} {'cantor':>9} {'hybrid':>9} {'attn':>9} "
      f"{'c/r':>6} {'c/a':>6} {'c_MB':>7}")

for S in SEQ_LENGTHS:
    results = {}

    for name, make_layer in [
        ("relay", lambda: PureRelayLayer(D)),
        ("cantor", lambda: ConstellationCantorRelay(D)),
        ("hybrid", lambda: HybridRelay(D)),
        ("attn", lambda: VanillaAttention(D)),
    ]:
        try:
            reset_vram()
            layer = make_layer().to(DEVICE).half()
            x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)

            # Warmup (kernel compilation / allocator steady state)
            with torch.no_grad():
                for _ in range(3):
                    _ = layer(x)
            torch.cuda.synchronize()

            t0 = time.perf_counter()
            with torch.no_grad():
                for _ in range(10):
                    _ = layer(x)
            torch.cuda.synchronize()
            ms = (time.perf_counter() - t0) / 10 * 1000
            mb = peak_mb()
            results[name] = (ms, mb)

            del layer, x
            reset_vram()

        except (torch.cuda.OutOfMemoryError, RuntimeError):
            # OOM (or CUDA error) counts as infinite time/memory in the table.
            results[name] = (float('inf'), float('inf'))
            reset_vram()

    r = results.get("relay", (0, 0))[0]
    c = results.get("cantor", (0, 0))[0]
    h = results.get("hybrid", (0, 0))[0]
    a = results.get("attn", (0, 0))[0]
    c_mb = results.get("cantor", (0, 0))[1]

    def fmt(v):
        # Format a per-pass time, or "OOM" when the run failed (inf).
        return f"{v:>8.2f}ms" if v < float('inf') else " OOM"

    cr_ratio = f"{c/r:>5.1f}Γ—" if r > 0 and c < float('inf') else " -"
    ca_ratio = f"{c/a:>5.1f}Γ—" if a > 0 and a < float('inf') and c < float('inf') else " -"

    print(f" {S:>8} {fmt(r)} {fmt(c)} {fmt(h)} {fmt(a)} "
          f"{cr_ratio} {ca_ratio} {c_mb:>7.0f}")
470
+
471
+
472
+ # ══════════════════════════════════════════════════════════════════
473
+ # TEST 2: CROSS-TOKEN CAUSAL INTERVENTION β€” CANTOR vs HYBRID
474
+ # ══════════════════════════════════════════════════════════════════
475
+
476
# ------------------------------------------------------------------
# TEST 2: CROSS-TOKEN CAUSAL INTERVENTION — perturb token 0, run the
# original and perturbed sequences through the same 4-layer stack, and
# measure how much distant tokens (middle and last) change. A pure
# per-token architecture should show delta ~ 0.
# ------------------------------------------------------------------

print(f"\n{'━'*80}")
print("TEST 2: Cross-Token Causal Intervention")
print(" Modify token 0, measure effect on token S//2")
print(" 4 layers deep. Compare: cantor relay vs hybrid vs pure relay")
print(f"{'━'*80}")

N_LAYERS = 4

print(f"\n {'S':>8} {'arch':>10} {'Ξ”_mid':>10} {'Ξ”_last':>10} "
      f"{'cos_orig':>10} {'time_ms':>10}")

for S in [64, 256, 1024, 4096, 16384]:
    for arch_name, make_stack in [
        ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(N_LAYERS)])),
        ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(N_LAYERS)])),
        ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(N_LAYERS)])),
    ]:
        try:
            reset_vram()
            # Same seed per arch so all stacks see identical weights/inputs.
            torch.manual_seed(42)
            stack = make_stack().to(DEVICE).half()

            x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
            x_mod = x.clone()
            # Intervention: replace token 0 with an unrelated random vector.
            x_mod[:, 0] = F.normalize(torch.randn(1, D, device=DEVICE, dtype=torch.float16), dim=-1)

            torch.cuda.synchronize()
            t0 = time.perf_counter()

            with torch.no_grad():
                h = x.clone()
                h_mod = x_mod.clone()
                for layer in stack:
                    h = layer(h)
                    h_mod = layer(h_mod)

            torch.cuda.synchronize()
            elapsed = (time.perf_counter() - t0) * 1000

            # Effect of the token-0 edit on the middle and last tokens.
            mid = S // 2
            delta_mid = (h[0, mid].float() - h_mod[0, mid].float()).norm().item()
            delta_last = (h[0, -1].float() - h_mod[0, -1].float()).norm().item()
            # How much the (unperturbed) middle token drifted from its input.
            cos_orig = F.cosine_similarity(
                x[0, mid:mid+1].float(), h[0, mid:mid+1].float()).item()

            print(f" {S:>8} {arch_name:>10} {delta_mid:>10.4f} {delta_last:>10.4f} "
                  f"{cos_orig:>10.4f} {elapsed:>10.1f}")

            del stack, x, x_mod, h, h_mod
            reset_vram()

        except (torch.cuda.OutOfMemoryError, RuntimeError):
            print(f" {S:>8} {arch_name:>10} OOM")
            reset_vram()

    print()
532
+
533
+
534
+ # ══════════════════════════════════════════════════════════════════
535
+ # TEST 3: GEOMETRIC PRESERVATION WITH CROSS-TOKEN ROUTING
536
+ # ══════════════════════════════════════════════════════════════════
537
+
538
+ print(f"\n{'━'*80}")
539
+ print("TEST 3: Geometric Preservation β€” does Cantor routing hurt geometry?")
540
+ print(" 8 layers, S=4096. Compare cos_to_orig, CV, eff_dim.")
541
+ print(f"{'━'*80}")
542
+
543
def compute_cv(points, n_samples=500):
    """Coefficient of variation of random 5-point simplex volumes.

    Unit-normalizes the points, draws `n_samples` random 5-point subsets
    (restricted to at most the first 2000 points), computes each subset's
    4-simplex volume via the Cayley-Menger determinant, and returns
    std / mean over the valid volumes. Returns NaN when fewer than 5
    points are supplied or fewer than 50 subsets yield a numerically
    positive squared volume.
    """
    n_points = points.shape[0]
    if n_points < 5:
        return float('nan')
    unit = F.normalize(points.float(), dim=-1)
    pool = min(n_points, 2000)  # sample from a bounded prefix for speed
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(pool, device=unit.device)[:5]
        subset = unit[pick].unsqueeze(0)
        gram = torch.bmm(subset, subset.transpose(1, 2))
        sq_norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Pairwise squared distances, clipped at zero against round-off.
        sq_dist = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * gram)
        # Cayley-Menger matrix: border of ones around the distance matrix.
        cayley = torch.zeros(1, 6, 6, device=unit.device, dtype=torch.float32)
        cayley[:, 0, 1:] = 1
        cayley[:, 1:, 0] = 1
        cayley[:, 1:, 1:] = sq_dist
        # Squared 4-simplex volume = -det(CM) / ((4!)^2 * 2^4) = -det / 9216.
        vol_sq = -torch.linalg.det(cayley) / 9216
        if vol_sq[0].item() > 1e-20:
            volumes.append(vol_sq[0].sqrt().cpu())
    if len(volumes) < 50:
        return float('nan')
    stacked = torch.stack(volumes)
    return (stacked.std() / (stacked.mean() + 1e-8)).item()
563
+
564
# TEST 3 body: push a fixed random batch through 8 layers of each arch
# and compare geometric statistics on the first 512 tokens:
#   cos   — mean cosine to the original input (preservation)
#   norm  — mean output norm
#   CV    — coefficient of variation of random simplex volumes
#   ed    — participation-ratio effective dimension from singular values
#   self_sim — mean off-diagonal pairwise cosine (collapse indicator)
GEO_DEPTH = 8
GEO_S = 4096

print(f"\n {'arch':>10} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
      f"{'eff_dim':>8} {'self_sim':>10}")

for arch_name, make_stack in [
    ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(GEO_DEPTH)])),
    ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(GEO_DEPTH)])),
    ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(GEO_DEPTH)])),
    ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(GEO_DEPTH)])),
]:
    try:
        reset_vram()
        torch.manual_seed(42)
        stack = make_stack().to(DEVICE).half()

        x = F.normalize(torch.randn(1, GEO_S, D, device=DEVICE, dtype=torch.float16), dim=-1)

        with torch.no_grad():
            h = x.clone()
            for layer in stack:
                h = layer(h)

        # Statistics on a 512-token slice, computed in fp32.
        x_s = x[0, :512].float()
        h_s = h[0, :512].float()
        cos = F.cosine_similarity(x_s, h_s).mean().item()
        norm = h_s.norm(dim=-1).mean().item()
        h_n = F.normalize(h_s, dim=-1)
        sim = h_n @ h_n.T
        mask = ~torch.eye(512, device=DEVICE, dtype=torch.bool)  # off-diagonal only
        self_sim = sim[mask].mean().item()
        cv = compute_cv(h_n, 500)

        # Effective dimension = 1 / sum(p_i^2) over normalized singular values.
        _, S_vals, _ = torch.linalg.svd(h_n[:256], full_matrices=False)
        p = S_vals / S_vals.sum()
        ed = p.pow(2).sum().reciprocal().item()

        print(f" {arch_name:>10} {cos:>10.4f} {norm:>8.4f} {cv:>8.4f} "
              f"{ed:>8.1f} {self_sim:>10.6f}")

        del stack, x, h
        reset_vram()

    except (torch.cuda.OutOfMemoryError, RuntimeError):
        print(f" {arch_name:>10} OOM")
        reset_vram()
611
+
612
+
613
+ # ════════════════════════════════════════════════���═════════════════
614
+ # TEST 4: TRAINED CROSS-TOKEN TASK β€” ALL ARCHITECTURES
615
+ # ══════════════════════════════════════════════════════════════════
616
+
617
# ------------------------------------------------------------------
# TEST 4: TRAINED CROSS-TOKEN TASK — label depends on the classes of
# tokens 0 AND 1, so any architecture with zero cross-token flow can do
# no better than chance. Each arch trains a fresh 4-layer stack with a
# shared pool+head classifier for 500 Adam steps.
# ------------------------------------------------------------------

print(f"\n{'━'*80}")
print("TEST 4: Trained Cross-Token Task")
print(" Label = (token_0_class + token_1_class) % 10")
print(" Pure relay CANNOT solve this (zero cross-token info).")
print(" 4 layers, 500 steps, S=8.")
print(f"{'━'*80}")

S_TASK = 8
N_CLS = 10
N_SAMPLES = 4096
STEPS = 500

# Fixed class-key dictionaries for token 0 and token 1.
torch.manual_seed(777)
keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)

# Sequences are noise except tokens 0 and 1, which carry noisy class keys.
task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x = F.normalize(task_x, dim=-1)
task_y = ((label_a + label_b) % N_CLS).long()

print(f"\n {'arch':>10} {'acc':>8} {'loss':>8} {'cross_Ξ”':>10} {'params':>10}")

for arch_name, make_stack in [
    ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(4)])),
    ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(4)])),
    ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(4)])),
    ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(4)])),
]:
    torch.manual_seed(42)

    # Thin wrapper: stack + flatten-pool + linear classification head.
    class TaskModel(nn.Module):
        def __init__(self, stack):
            super().__init__()
            self.layers = stack
            self.pool = nn.Linear(D * S_TASK, D)
            self.head = nn.Linear(D, N_CLS)

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.head(F.gelu(self.pool(x.reshape(x.shape[0], -1))))

    model = TaskModel(make_stack()).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)

    for step in range(STEPS):
        idx = torch.randint(0, N_SAMPLES, (128,))
        logits = model(task_x[idx])
        loss = F.cross_entropy(logits, task_y[idx])
        # Bail out if training diverges so later archs still run.
        if torch.isnan(loss) or torch.isinf(loss):
            break
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

    model.eval()
    with torch.no_grad():
        logits = model(task_x[:1024])
        acc = (logits.argmax(-1) == task_y[:1024]).float().mean().item()
        final_loss = F.cross_entropy(logits, task_y[:1024]).item()

        # Cross-token intervention on the trained stack: randomize token 0
        # and measure the induced change at token 1.
        # NOTE(review): indentation reconstructed from a whitespace-stripped
        # scrape — assumed inside torch.no_grad(); confirm against upstream.
        h1 = task_x[:64].clone()
        for layer in model.layers:
            h1 = layer(h1)
        h2 = task_x[:64].clone()
        h2[:, 0] = F.normalize(torch.randn(64, D, device=DEVICE), dim=-1)
        for layer in model.layers:
            h2 = layer(h2)
        cross_delta = (h1[:, 1] - h2[:, 1]).norm(dim=-1).mean().item()

    print(f" {arch_name:>10} {acc:>8.1%} {final_loss:>8.4f} {cross_delta:>10.4f} {n_params:>10,}")

    del model
    reset_vram()
698
+
699
+
700
+ # ══════════════════════════════════════════════════════════════════
701
+ # TEST 5: THE O(SΒ²) WALL β€” CANTOR vs ATTENTION at depth 8
702
+ # ══════════════════════════════════════════════════════════════════
703
+
704
# ------------------------------------------------------------------
# TEST 5: THE O(S^2) WALL — 8-layer Cantor stack vs 8-layer attention
# stack, growing S until one of them OOMs. Each side gets one warmup
# pass, then a single timed pass; results print side by side.
# ------------------------------------------------------------------

print(f"\n{'━'*80}")
print("TEST 5: The O(SΒ²) Wall β€” Cantor vs Attention, 8 layers deep")
print(f"{'━'*80}")

WALL_DEPTH = 8

print(f"\n {'S':>8} {'cantor_ms':>10} {'attn_ms':>10} {'speedup':>8} "
      f"{'c_cos':>8} {'a_cos':>8} {'c_MB':>8} {'a_MB':>8}")

for S in [1024, 4096, 8192, 16384, 32768, 65536, 131072]:
    c_result = None
    a_result = None

    # Cantor side
    try:
        reset_vram()
        torch.manual_seed(42)
        c_stack = nn.ModuleList([
            ConstellationCantorRelay(D) for _ in range(WALL_DEPTH)
        ]).to(DEVICE).half()

        x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
        # Warmup pass (kernel/allocator steady state).
        with torch.no_grad():
            h = x.clone()
            for layer in c_stack:
                h = layer(h)
        torch.cuda.synchronize()

        # Timed pass.
        t0 = time.perf_counter()
        with torch.no_grad():
            h = x.clone()
            for layer in c_stack:
                h = layer(h)
        torch.cuda.synchronize()
        c_ms = (time.perf_counter() - t0) * 1000
        c_mb = peak_mb()
        c_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
        c_result = (c_ms, c_cos, c_mb)

        del x, h, c_stack
        reset_vram()
    except (torch.cuda.OutOfMemoryError, RuntimeError):
        reset_vram()

    # Attention side (same protocol)
    try:
        reset_vram()
        torch.manual_seed(42)
        a_stack = nn.ModuleList([
            VanillaAttention(D) for _ in range(WALL_DEPTH)
        ]).to(DEVICE).half()

        x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
        with torch.no_grad():
            h = x.clone()
            for layer in a_stack:
                h = layer(h)
        torch.cuda.synchronize()

        t0 = time.perf_counter()
        with torch.no_grad():
            h = x.clone()
            for layer in a_stack:
                h = layer(h)
        torch.cuda.synchronize()
        a_ms = (time.perf_counter() - t0) * 1000
        a_mb = peak_mb()
        a_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
        a_result = (a_ms, a_cos, a_mb)

        del x, h, a_stack
        reset_vram()
    except (torch.cuda.OutOfMemoryError, RuntimeError):
        reset_vram()

    # Format the row; a None result means that side OOM'd.
    c_str = f"{c_result[0]:>9.1f}ms" if c_result else " OOM"
    a_str = f"{a_result[0]:>9.1f}ms" if a_result else " OOM"
    sp = f"{a_result[0]/c_result[0]:>7.1f}Γ—" if c_result and a_result else " -"
    cc = f"{c_result[1]:>8.4f}" if c_result else " ---"
    ac = f"{a_result[1]:>8.4f}" if a_result else " ---"
    cm = f"{c_result[2]:>8.0f}" if c_result else " OOM"
    am = f"{a_result[2]:>8.0f}" if a_result else " OOM"

    print(f" {S:>8} {c_str} {a_str} {sp} {cc} {ac} {cm} {am}")

    # Once the O(S) side itself OOMs there is nothing left to compare.
    if c_result is None:
        print(f" β†’ Cantor OOM at S={S}, stopping")
        break
792
+
793
+
794
+ # ══════════════════════════════════════════════════════════════════
795
+ # SUMMARY
796
+ # ══════════════════════════════════════════════════════════════════
797
+
798
+ print(f"\n{'='*80}")
799
+ print("CONSTELLATION-CANTOR RELAY β€” BENCHMARK COMPLETE")
800
+ print(f"{'='*80}")
801
+ print(f"""
802
+ Architecture:
803
+ per-token: constellation relay (triangulate β†’ patchwork β†’ gated residual)
804
+ cross-token: cantor router (hierarchical scatter/gather through anchor tree)
805
+ total: O(S) time, O(S) memory, no attention
806
+
807
+ 5 tests:
808
+ T1: Throughput β€” relay vs cantor vs hybrid vs attention, S to 131K
809
+ T2: Cross-token causal intervention β€” who routes strongest?
810
+ T3: Geometric preservation β€” does cross-token routing hurt geometry?
811
+ T4: Trained cross-token task β€” accuracy on interaction-dependent labels
812
+ T5: O(SΒ²) wall β€” cantor vs attention at 8 layers to OOM
813
+
814
+ Key questions answered:
815
+ β€’ Is the cantor router faster than attention at all sequence lengths?
816
+ β€’ Does it provide meaningful cross-token interaction?
817
+ β€’ Does the routing hurt per-token geometric preservation?
818
+ β€’ Can it solve tasks that require cross-token information?
819
+ """)