| |
| """ |
| Constellation-Cantor Relay β O(S) Cross-Token Routing |
| |
| An experimental O(S) routing mechanism proposed as a full replacement for |
| attention; its practical strength is evaluated empirically by the benchmarks below. |
| |
| ======================================================= |
| Replaces attention entirely with triangulation-mediated hierarchical routing. |
| |
| Architecture: |
| per-token: constellation relay (triangulate β patchwork β gated residual) |
| cross-token: Cantor router (hierarchical scatter/gather through anchor tree) |
| |
| The triangulation profile IS the routing key. Tokens near the same anchor |
| on S^(d-1) share information at level 0. Anchor pairs share at level 1. |
| Quads at level 2. Full global at level log2(A). |
| |
| Total cross-token cost: O(S Γ n_levels) = O(S Γ 4) for 16 anchors. |
| Total per-token cost: O(S Γ tri_dim Γ pw_hidden). |
| No attention anywhere. Fully O(S). |
| |
| Benchmarks: |
| 1. Throughput: cantor-relay vs hybrid vs pure relay vs attention |
| 2. Cross-token causal intervention at scale |
| 3. Geometric preservation |
| 4. Trained task requiring cross-token routing |
| """ |
|
|
| import os |
| os.environ["CUDA_LAUNCH_BLOCKING"] = "1" |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| import time |
| import gc |
| from collections import OrderedDict |
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
|
|
| |
| |
| |
|
|
class SquaredReLU(nn.Module):
    """Squared ReLU activation: max(x, 0) squared element-wise."""

    def forward(self, x):
        rectified = F.relu(x)
        return rectified * rectified
|
|
|
|
| |
| |
| |
|
|
class ConstellationRelay(nn.Module):
    """Per-token constellation triangulation + patchwork. O(S).

    Each token vector is split into P patches of size `patch_dim`. Every
    patch is normalized onto the unit sphere and triangulated (cosine
    distance) against A anchors at `n_phases` slerp phases between the
    frozen `home` anchors and the trainable `anchors`. The resulting
    profile feeds a small MLP ("patchwork") whose output is added back as
    a gated residual.
    """

    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        super().__init__()
        # Guard against silent truncation: dim // patch_dim below would
        # otherwise quietly drop trailing features.
        if dim % patch_dim != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by patch_dim ({patch_dim})")
        self.dim = dim
        self.patch_dim = patch_dim
        self.n_patches = dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        self.ln = nn.LayerNorm(dim)

        # Frozen "home" anchor positions on S^(d-1); the trainable anchors
        # start from an exact copy and may drift away during training.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        tri_dim = P * A * n_phases
        self.tri_dim = tri_dim
        pw_hidden = tri_dim * 2

        self.patchwork = nn.Sequential(
            nn.Linear(tri_dim, pw_hidden),
            SquaredReLU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, dim),
        )
        # sigmoid(-3) ~ 0.047: the layer starts close to the identity.
        self.gate = nn.Parameter(torch.tensor(-3.0))

    def drift(self):
        """Geodesic angle between each trainable anchor and its home. (P, A)."""
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        # Clamp keeps acos differentiable at the +/-1 boundaries.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-6, 1 - 1e-6))

    def at_phase(self, t):
        """Slerp between home anchors (t=0) and current anchors (t=1). (P, A, d)."""
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-6)  # guards division when drift ~ 0
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, patches_n):
        """Cosine-distance profile of unit patches against anchors per phase.

        patches_n: (BS, P, d) unit-normalized patches.
        Returns (BS, P * A * n_phases), phase-major concatenation.
        """
        phases = torch.linspace(0, 1, self.n_phases, device=patches_n.device).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1).to(patches_n.dtype)
            # 1 - cos similarity = cosine distance in [0, 2].
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
        return torch.cat(tris, dim=-1).reshape(patches_n.shape[0], -1)

    def forward(self, x):
        """x: (B*S, D) or (B, S, D). Returns (out, tri) with matching leading dims."""
        is_seq = x.dim() == 3
        if is_seq:
            B, S, D = x.shape
            x_flat = x.reshape(B * S, D)
        else:
            x_flat = x

        residual = x_flat
        h = self.ln(x_flat)
        patches = h.reshape(-1, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)
        tri = self.triangulate(patches_n)
        pw_out = self.patchwork(tri)
        g = self.gate.sigmoid()
        out = residual + g * pw_out

        if is_seq:
            return out.reshape(B, S, D), tri.reshape(B, S, -1)
        return out, tri

    def forward_no_tri(self, x):
        """Forward that discards the triangulation output -- for compatibility."""
        out, _ = self.forward(x)
        return out
|
|
|
|
| |
| |
| |
|
|
class CantorConstellationRouter(nn.Module):
    """Hierarchical cross-token routing through the constellation anchor tree.

    The triangulation profile gives each token a soft position among the
    anchors on S^(d-1). A binary tree over anchors defines the routing
    hierarchy:

        Level 0: A groups   (per-anchor, local neighbors)
        Level 1: A/2 groups (anchor pairs, nearby interaction)
        Level 2: A/4 groups (quads, medium range)
        ...
        Level L: 1 group    (global summary)

    At each level:
        1. Soft-assign tokens to groups via triangulation weights
        2. Weighted scatter: aggregate token representations per group
        3. Transform: per-level MLP on group summaries
        4. Weighted gather: distribute transformed summaries back to tokens
        5. Gated residual addition

    Cost:   O(S x L x D) where L = log2(A) + 1 = 5 for A=16.
    Memory: O(S x D + A x D) -- no S^2 term anywhere.
    """

    def __init__(self, dim=256, n_anchors=16, n_patches=16):
        super().__init__()
        self.dim = dim
        self.n_anchors = n_anchors
        self.n_patches = n_patches
        # Levels 0..log2(A) inclusive: group sizes 1, 2, 4, ..., A.
        self.n_levels = int(math.log2(n_anchors)) + 1

        self.level_mlps = nn.ModuleList()
        self.level_gates = nn.ParameterList()
        self.level_lns = nn.ModuleList()

        for _ in range(self.n_levels):
            self.level_mlps.append(nn.Sequential(
                nn.Linear(dim, dim * 2),
                SquaredReLU(),
                nn.Linear(dim * 2, dim),
            ))
            self.level_lns.append(nn.LayerNorm(dim))
            # sigmoid(-3) ~ 0.047: routing starts close to the identity.
            self.level_gates.append(nn.Parameter(torch.tensor(-3.0)))

        # NOTE(review): not referenced by forward(); kept so existing
        # checkpoints and reported parameter counts stay compatible.
        self.weight_proj = nn.Linear(n_patches * n_anchors, n_anchors)

    def compute_routing_weights(self, tri, n_anchors):
        """Extract soft anchor-assignment weights from a triangulation profile.

        tri: (BS, tri_dim) full triangulation (n_patches x n_anchors x
             n_phases, phase-major).
        Returns: (BS, n_anchors) soft assignment weights (rows sum to 1).
        """
        BS = tri.shape[0]
        # Phase-0 slice: cosine distance of every patch to every anchor.
        phase0 = tri[:, :self.n_patches * n_anchors]
        # Average over patches -> one distance per anchor per token.
        dists = phase0.reshape(BS, self.n_patches, n_anchors).mean(dim=1)
        # Cosine distance lies in [0, 2]; flip to proximity, then sharpen
        # with a fixed temperature before the softmax.
        proximity = (2.0 - dists).clamp(min=0)
        weights = F.softmax(proximity * 5.0, dim=-1)
        return weights

    def forward(self, x, tri):
        """Route cross-token information through the anchor hierarchy.

        x:   (B, S, D) token representations
        tri: (B, S, tri_dim) triangulation profiles from the constellation
        Returns: (B, S, D) with cross-token information mixed in.
        """
        B, S, D = x.shape
        x_flat = x.reshape(B * S, D)
        tri_flat = tri.reshape(B * S, -1)

        weights = self.compute_routing_weights(tri_flat, self.n_anchors)

        h = x_flat

        for level in range(self.n_levels):
            group_size = 2 ** level
            n_groups = max(1, self.n_anchors // group_size)

            # Collapse per-anchor weights into per-group weights by summing
            # over each group's anchors.
            if n_groups > 1:
                group_weights = weights.reshape(B * S, n_groups, group_size).sum(dim=-1)
            else:
                group_weights = weights.sum(dim=-1, keepdim=True)

            # Renormalize so each token distributes exactly unit mass.
            group_weights = group_weights / (group_weights.sum(dim=-1, keepdim=True) + 1e-8)

            gw = group_weights.reshape(B, S, n_groups)
            hh = h.reshape(B, S, D)

            # Scatter: weighted sum of token states per group. (B, G, D)
            group_summary = torch.bmm(gw.transpose(1, 2), hh)

            # Divide by total group weight so summaries are weighted means.
            weight_sums = gw.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)
            group_summary = group_summary / weight_sums

            # Transform group summaries with this level's LN + MLP.
            gs_flat = group_summary.reshape(B * n_groups, D)
            gs_flat = self.level_lns[level](gs_flat)
            gs_transformed = self.level_mlps[level](gs_flat)
            gs_transformed = gs_transformed.reshape(B, n_groups, D)

            # Gather: each token reads back its weighted mix of summaries.
            token_update = torch.bmm(gw, gs_transformed).reshape(B * S, D)

            # Gated residual keeps early training near-identity.
            g = self.level_gates[level].sigmoid()
            h = h + g * token_update

        return h.reshape(B, S, D)
|
|
|
|
| |
| |
| |
|
|
class ConstellationCantorRelay(nn.Module):
    """Complete O(S) transformer-style layer. No attention.

    per-token:   ConstellationRelay (triangulate -> patchwork -> gated residual)
    cross-token: CantorConstellationRouter (hierarchical scatter/gather
                 through the anchor tree)

    The triangulation produced by the per-token relay is reused as the
    routing key for the cross-token path -- no redundant computation.
    """

    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        super().__init__()
        self.relay = ConstellationRelay(
            dim=dim, patch_dim=patch_dim, n_anchors=n_anchors, n_phases=n_phases)
        self.router = CantorConstellationRouter(
            dim=dim, n_anchors=n_anchors, n_patches=dim // patch_dim)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_router = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        """x: (B, S, D) -> (B, S, D)."""
        # Per-token path; keep only the delta so both paths gate off the
        # same input x.
        relay_out, tri = self.relay(x)
        relay_delta = relay_out - x

        # Cross-token path, keyed by the relay's triangulation profile.
        routed = self.router(x, tri)
        router_delta = routed - x

        # Renamed from `gc`, which shadowed the imported gc module.
        g_relay = self.gate_relay.sigmoid()
        g_router = self.gate_router.sigmoid()
        return x + g_relay * relay_delta + g_router * router_delta
|
|
|
|
| |
| |
| |
|
|
class VanillaAttention(nn.Module):
    """Pre-norm multi-head self-attention block for comparison. O(S^2)."""

    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.ln = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """x: (B, S, D) -> (B, S, D), residual added around attention."""
        batch, seq, dim = x.shape
        normed = self.ln(x)
        # Fused QKV projection, then split and move heads before sequence.
        packed = self.qkv(normed).reshape(batch, seq, 3, self.n_heads, self.head_dim)
        heads = [part.transpose(1, 2) for part in packed.unbind(2)]
        context = F.scaled_dot_product_attention(*heads)
        merged = context.transpose(1, 2).reshape(batch, seq, dim)
        return x + self.proj(merged)
|
|
|
|
class HybridRelay(nn.Module):
    """Constellation relay combined with vanilla attention. Comparison baseline."""

    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.relay = ConstellationRelay(dim=dim)
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_proj = nn.Linear(dim, dim)
        self.attn_ln = nn.LayerNorm(dim)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_attn = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        """x: (B, S, D). Gated sum of the relay delta and the attention output."""
        batch, seq, dim = x.shape

        # Per-token relay path, expressed as a delta relative to the input.
        relay_delta = self.relay.forward_no_tri(x) - x

        # Standard pre-norm multi-head attention path.
        normed = self.attn_ln(x)
        packed = self.qkv(normed).reshape(batch, seq, 3, self.n_heads, self.head_dim)
        heads = [part.transpose(1, 2) for part in packed.unbind(2)]
        context = F.scaled_dot_product_attention(*heads)
        attn_out = self.attn_proj(context.transpose(1, 2).reshape(batch, seq, dim))

        # Each path enters through its own learned gate.
        return (x
                + self.gate_relay.sigmoid() * relay_delta
                + self.gate_attn.sigmoid() * attn_out)
|
|
|
|
class PureRelayLayer(nn.Module):
    """Relay-only baseline: per-token path only, zero cross-token flow."""

    def __init__(self, dim=256):
        super().__init__()
        self.relay = ConstellationRelay(dim=dim)

    def forward(self, x):
        """Apply the per-token relay; the triangulation output is discarded."""
        out = self.relay.forward_no_tri(x)
        return out
|
|
|
|
| |
| |
| |
|
|
def reset_vram():
    """Collect garbage, free cached CUDA memory, and reset the peak counter.

    The CUDA calls are skipped on CPU-only machines, where the original
    version raised.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
|
|
def peak_mb():
    """Peak CUDA memory allocated since the last reset, in megabytes.

    Returns 0.0 on CPU-only machines, where the original version raised.
    """
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / 1e6
|
|
# Model dimension shared by every benchmark below.
D = 256

print("=" * 80)
print("CONSTELLATION-CANTOR RELAY β O(S) CROSS-TOKEN ROUTING BENCHMARK")
# Device queries are CUDA-only; the original crashed on CPU-only machines.
if torch.cuda.is_available():
    print(f" Device: {torch.cuda.get_device_name()}")
    print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print(" Device: CPU (no CUDA available)")
print(f" Dimension: {D}")
print("=" * 80)
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# TEST 1: forward throughput of a single layer, B=1, fp16, S = 64..131K.
# Each architecture is rebuilt per sequence length; failures (OOM) are
# recorded as inf so the table still prints a complete row.
# ---------------------------------------------------------------------------
print(f"\n{'β'*80}")
print("TEST 1: Throughput Scaling β 4 architectures, S=64 to 131K")
print(" Single layer, B=1, fp16")
print(f"{'β'*80}")

SEQ_LENGTHS = [64, 256, 1024, 4096, 16384, 32768, 65536, 131072]

# Table header: per-arch latency, cantor/relay + cantor/attn ratios, peak MB.
print(f"\n {'S':>8} {'relay':>9} {'cantor':>9} {'hybrid':>9} {'attn':>9} "
      f"{'c/r':>6} {'c/a':>6} {'c_MB':>7}")

for S in SEQ_LENGTHS:
    results = {}

    for name, make_layer in [
        ("relay", lambda: PureRelayLayer(D)),
        ("cantor", lambda: ConstellationCantorRelay(D)),
        ("hybrid", lambda: HybridRelay(D)),
        ("attn", lambda: VanillaAttention(D)),
    ]:
        try:
            reset_vram()
            layer = make_layer().to(DEVICE).half()
            x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)

            # Warmup: 3 untimed passes so kernel compilation/caches settle.
            with torch.no_grad():
                for _ in range(3):
                    _ = layer(x)
            torch.cuda.synchronize()

            # Timed: mean of 10 passes, synchronized once at the end.
            t0 = time.perf_counter()
            with torch.no_grad():
                for _ in range(10):
                    _ = layer(x)
            torch.cuda.synchronize()
            ms = (time.perf_counter() - t0) / 10 * 1000
            mb = peak_mb()
            results[name] = (ms, mb)

            del layer, x
            reset_vram()

        except (torch.cuda.OutOfMemoryError, RuntimeError):
            # Treat OOM (or any runtime failure) as "did not fit".
            results[name] = (float('inf'), float('inf'))
            reset_vram()

    r = results.get("relay", (0, 0))[0]
    c = results.get("cantor", (0, 0))[0]
    h = results.get("hybrid", (0, 0))[0]
    a = results.get("attn", (0, 0))[0]
    c_mb = results.get("cantor", (0, 0))[1]

    def fmt(v):
        # Render one latency cell; inf marks a run that OOMed.
        return f"{v:>8.2f}ms" if v < float('inf') else " OOM"

    cr_ratio = f"{c/r:>5.1f}Γ" if r > 0 and c < float('inf') else " -"
    ca_ratio = f"{c/a:>5.1f}Γ" if a > 0 and a < float('inf') and c < float('inf') else " -"

    print(f" {S:>8} {fmt(r)} {fmt(c)} {fmt(h)} {fmt(a)} "
          f"{cr_ratio} {ca_ratio} {c_mb:>7.0f}")
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# TEST 2: causal intervention. Replace token 0, run both inputs through the
# same stack, and measure how much tokens S//2 and S-1 change. An
# architecture with no cross-token path should show ~0 deltas.
# ---------------------------------------------------------------------------
print(f"\n{'β'*80}")
print("TEST 2: Cross-Token Causal Intervention")
print(" Modify token 0, measure effect on token S//2")
print(" 4 layers deep. Compare: cantor relay vs hybrid vs pure relay")
print(f"{'β'*80}")

N_LAYERS = 4

print(f"\n {'S':>8} {'arch':>10} {'Ξ_mid':>10} {'Ξ_last':>10} "
      f"{'cos_orig':>10} {'time_ms':>10}")

for S in [64, 256, 1024, 4096, 16384]:
    for arch_name, make_stack in [
        ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(N_LAYERS)])),
        ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(N_LAYERS)])),
        ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(N_LAYERS)])),
    ]:
        try:
            reset_vram()
            torch.manual_seed(42)  # identical weights for every architecture
            stack = make_stack().to(DEVICE).half()

            x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
            x_mod = x.clone()
            # Intervention: overwrite token 0 with a fresh random unit vector.
            x_mod[:, 0] = F.normalize(torch.randn(1, D, device=DEVICE, dtype=torch.float16), dim=-1)

            torch.cuda.synchronize()
            t0 = time.perf_counter()

            # Run original and modified inputs through the same stack.
            with torch.no_grad():
                h = x.clone()
                h_mod = x_mod.clone()
                for layer in stack:
                    h = layer(h)
                    h_mod = layer(h_mod)

            torch.cuda.synchronize()
            elapsed = (time.perf_counter() - t0) * 1000

            # Effect of the token-0 intervention on distant tokens.
            mid = S // 2
            delta_mid = (h[0, mid].float() - h_mod[0, mid].float()).norm().item()
            delta_last = (h[0, -1].float() - h_mod[0, -1].float()).norm().item()
            # How much the mid token still resembles its input.
            cos_orig = F.cosine_similarity(
                x[0, mid:mid+1].float(), h[0, mid:mid+1].float()).item()

            print(f" {S:>8} {arch_name:>10} {delta_mid:>10.4f} {delta_last:>10.4f} "
                  f"{cos_orig:>10.4f} {elapsed:>10.1f}")

            del stack, x, x_mod, h, h_mod
            reset_vram()

        except (torch.cuda.OutOfMemoryError, RuntimeError):
            print(f" {S:>8} {arch_name:>10} OOM")
            reset_vram()

    print()
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# TEST 3: geometric preservation after a deep stack.
# ---------------------------------------------------------------------------
print(f"\n{'β'*80}")
print("TEST 3: Geometric Preservation β does Cantor routing hurt geometry?")
print(" 8 layers, S=4096. Compare cos_to_orig, CV, eff_dim.")
print(f"{'β'*80}")

def compute_cv(points, n_samples=500):
    """Coefficient of variation of sampled 5-point simplex volumes.

    Draws `n_samples` random 5-point subsets (sampled from the first
    min(N, 2000) points only -- NOTE(review): rows past index 2000 are
    never sampled; fine for the N=512 use below), computes each subset's
    volume via the Cayley-Menger determinant (V^2 = -det/9216 for five
    points), and returns std/mean over the valid volumes. Returns NaN if
    there are fewer than 5 points or fewer than 50 valid samples.
    """
    N = points.shape[0]
    if N < 5: return float('nan')
    points = F.normalize(points.float(), dim=-1)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(min(N, 2000), device=points.device)[:5]
        pts = points[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Squared pairwise distances recovered from the Gram matrix.
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        d2 = F.relu(d2)  # clip tiny negatives from rounding
        # Cayley-Menger matrix: border of ones, squared distances inside.
        cm = torch.zeros(1, 6, 6, device=points.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
        v2 = -torch.linalg.det(cm) / 9216
        if v2[0].item() > 1e-20:  # skip degenerate / numerically negative sets
            vols.append(v2[0].sqrt().cpu())
    if len(vols) < 50: return float('nan')
    vt = torch.stack(vols)
    return (vt.std() / (vt.mean() + 1e-8)).item()
|
|
GEO_DEPTH = 8
GEO_S = 4096

print(f"\n {'arch':>10} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
      f"{'eff_dim':>8} {'self_sim':>10}")

for arch_name, make_stack in [
    ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(GEO_DEPTH)])),
    ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(GEO_DEPTH)])),
    ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(GEO_DEPTH)])),
    ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(GEO_DEPTH)])),
]:
    try:
        reset_vram()
        torch.manual_seed(42)  # same weights and input for every architecture
        stack = make_stack().to(DEVICE).half()

        x = F.normalize(torch.randn(1, GEO_S, D, device=DEVICE, dtype=torch.float16), dim=-1)

        with torch.no_grad():
            h = x.clone()
            for layer in stack:
                h = layer(h)

        # Metrics over the first 512 tokens, in fp32 for numerical stability.
        x_s = x[0, :512].float()
        h_s = h[0, :512].float()
        cos = F.cosine_similarity(x_s, h_s).mean().item()
        norm = h_s.norm(dim=-1).mean().item()
        h_n = F.normalize(h_s, dim=-1)
        sim = h_n @ h_n.T
        # Mean off-diagonal cosine similarity: high values indicate collapse.
        mask = ~torch.eye(512, device=DEVICE, dtype=torch.bool)
        self_sim = sim[mask].mean().item()
        cv = compute_cv(h_n, 500)

        # Effective dimension = participation ratio of the singular values.
        _, S_vals, _ = torch.linalg.svd(h_n[:256], full_matrices=False)
        p = S_vals / S_vals.sum()
        ed = p.pow(2).sum().reciprocal().item()

        print(f" {arch_name:>10} {cos:>10.4f} {norm:>8.4f} {cv:>8.4f} "
              f"{ed:>8.1f} {self_sim:>10.6f}")

        del stack, x, h
        reset_vram()

    except (torch.cuda.OutOfMemoryError, RuntimeError):
        print(f" {arch_name:>10} OOM")
        reset_vram()
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# TEST 4: trained task whose label depends on two tokens jointly.
# NOTE(review): the concat+pool head sees all tokens at once, so even the
# pure relay can mix tokens at the head -- interpret the claim below with
# that caveat in mind.
# ---------------------------------------------------------------------------
print(f"\n{'β'*80}")
print("TEST 4: Trained Cross-Token Task")
print(" Label = (token_0_class + token_1_class) % 10")
print(" Pure relay CANNOT solve this (zero cross-token info).")
print(" 4 layers, 500 steps, S=8.")
print(f"{'β'*80}")

S_TASK = 8
N_CLS = 10
N_SAMPLES = 4096
STEPS = 500

# Dataset: tokens 0 and 1 carry noisy class keys; the label mixes both.
torch.manual_seed(777)
keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)

task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x = F.normalize(task_x, dim=-1)
task_y = ((label_a + label_b) % N_CLS).long()

print(f"\n {'arch':>10} {'acc':>8} {'loss':>8} {'cross_Ξ':>10} {'params':>10}")

for arch_name, make_stack in [
    ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(4)])),
    ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(4)])),
    ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(4)])),
    ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(4)])),
]:
    torch.manual_seed(42)

    # NOTE(review): redefined on every loop iteration; harmless, but it
    # could live outside the loop.
    class TaskModel(nn.Module):
        def __init__(self, stack):
            super().__init__()
            self.layers = stack
            self.pool = nn.Linear(D * S_TASK, D)
            self.head = nn.Linear(D, N_CLS)

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            # Concat-pool all tokens, then classify.
            return self.head(F.gelu(self.pool(x.reshape(x.shape[0], -1))))

    model = TaskModel(make_stack()).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)

    for step in range(STEPS):
        idx = torch.randint(0, N_SAMPLES, (128,))
        logits = model(task_x[idx])
        loss = F.cross_entropy(logits, task_y[idx])
        if torch.isnan(loss) or torch.isinf(loss):
            break  # abort training on divergence
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

    # NOTE(review): the evaluation slice overlaps the training pool (no
    # held-out split), so acc is effectively train accuracy.
    model.eval()
    with torch.no_grad():
        logits = model(task_x[:1024])
        acc = (logits.argmax(-1) == task_y[:1024]).float().mean().item()
        final_loss = F.cross_entropy(logits, task_y[:1024]).item()

        # cross_delta: how much token 1's representation moves when token 0
        # is replaced -- a direct probe of learned cross-token flow.
        h1 = task_x[:64].clone()
        for layer in model.layers:
            h1 = layer(h1)
        h2 = task_x[:64].clone()
        h2[:, 0] = F.normalize(torch.randn(64, D, device=DEVICE), dim=-1)
        for layer in model.layers:
            h2 = layer(h2)
        cross_delta = (h1[:, 1] - h2[:, 1]).norm(dim=-1).mean().item()

    print(f" {arch_name:>10} {acc:>8.1%} {final_loss:>8.4f} {cross_delta:>10.4f} {n_params:>10,}")

    del model
    reset_vram()
|
|
|
|
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# TEST 5: 8-layer stacks, cantor vs attention, growing S until one OOMs.
# ---------------------------------------------------------------------------
print(f"\n{'β'*80}")
print("TEST 5: The O(SΒ²) Wall β Cantor vs Attention, 8 layers deep")
print(f"{'β'*80}")

WALL_DEPTH = 8

print(f"\n {'S':>8} {'cantor_ms':>10} {'attn_ms':>10} {'speedup':>8} "
      f"{'c_cos':>8} {'a_cos':>8} {'c_MB':>8} {'a_MB':>8}")

for S in [1024, 4096, 8192, 16384, 32768, 65536, 131072]:
    c_result = None
    a_result = None

    # Cantor stack: one warmup pass, one timed pass.
    try:
        reset_vram()
        torch.manual_seed(42)
        c_stack = nn.ModuleList([
            ConstellationCantorRelay(D) for _ in range(WALL_DEPTH)
        ]).to(DEVICE).half()

        x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
        with torch.no_grad():
            h = x.clone()
            for layer in c_stack:
                h = layer(h)
        torch.cuda.synchronize()

        t0 = time.perf_counter()
        with torch.no_grad():
            h = x.clone()
            for layer in c_stack:
                h = layer(h)
        torch.cuda.synchronize()
        c_ms = (time.perf_counter() - t0) * 1000
        c_mb = peak_mb()
        # Cosine between input and output tokens: preservation proxy.
        c_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
        c_result = (c_ms, c_cos, c_mb)

        del x, h, c_stack
        reset_vram()
    except (torch.cuda.OutOfMemoryError, RuntimeError):
        reset_vram()  # OOM: c_result stays None

    # Attention stack: identical protocol.
    try:
        reset_vram()
        torch.manual_seed(42)
        a_stack = nn.ModuleList([
            VanillaAttention(D) for _ in range(WALL_DEPTH)
        ]).to(DEVICE).half()

        x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
        with torch.no_grad():
            h = x.clone()
            for layer in a_stack:
                h = layer(h)
        torch.cuda.synchronize()

        t0 = time.perf_counter()
        with torch.no_grad():
            h = x.clone()
            for layer in a_stack:
                h = layer(h)
        torch.cuda.synchronize()
        a_ms = (time.perf_counter() - t0) * 1000
        a_mb = peak_mb()
        a_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
        a_result = (a_ms, a_cos, a_mb)

        del x, h, a_stack
        reset_vram()
    except (torch.cuda.OutOfMemoryError, RuntimeError):
        reset_vram()  # OOM: a_result stays None

    # Format the row; None results render as OOM markers.
    c_str = f"{c_result[0]:>9.1f}ms" if c_result else " OOM"
    a_str = f"{a_result[0]:>9.1f}ms" if a_result else " OOM"
    sp = f"{a_result[0]/c_result[0]:>7.1f}Γ" if c_result and a_result else " -"
    cc = f"{c_result[1]:>8.4f}" if c_result else " ---"
    ac = f"{a_result[1]:>8.4f}" if a_result else " ---"
    cm = f"{c_result[2]:>8.0f}" if c_result else " OOM"
    am = f"{a_result[2]:>8.0f}" if a_result else " OOM"

    print(f" {S:>8} {c_str} {a_str} {sp} {cc} {ac} {cm} {am}")

    # Once the cantor stack itself no longer fits, larger S is pointless.
    if c_result is None:
        print(f" β Cantor OOM at S={S}, stopping")
        break
|
|
|
|
| |
| |
| |
|
|
# Final summary banner recapping the architecture and the five tests.
print(f"\n{'='*80}")
print("CONSTELLATION-CANTOR RELAY β BENCHMARK COMPLETE")
print(f"{'='*80}")
print(f"""
Architecture:
per-token: constellation relay (triangulate β patchwork β gated residual)
cross-token: cantor router (hierarchical scatter/gather through anchor tree)
total: O(S) time, O(S) memory, no attention

5 tests:
T1: Throughput β relay vs cantor vs hybrid vs attention, S to 131K
T2: Cross-token causal intervention β who routes strongest?
T3: Geometric preservation β does cross-token routing hurt geometry?
T4: Trained cross-token task β accuracy on interaction-dependent labels
T5: O(SΒ²) wall β cantor vs attention at 8 layers to OOM

Key questions answered:
β’ Is the cantor router faster than attention at all sequence lengths?
β’ Does it provide meaningful cross-token interaction?
β’ Does the routing hurt per-token geometric preservation?
β’ Can it solve tasks that require cross-token information?
""")