# geolip-hypersphere-experiments / constellation_cantor_routing.py
# Uploaded by AbstractPhil ("Create constellation_cantor_routing.py",
# revision 9243884, verified) β€” header converted to comments so the file parses.
#!/usr/bin/env python3
"""
Constellation-Cantor Relay β€” O(S) Cross-Token Routing
This is likely one of the most powerful routing mechanisms that can exist in current spectrum
until more formulas are resolved.
=======================================================
Replaces attention entirely with triangulation-mediated hierarchical routing.
Architecture:
per-token: constellation relay (triangulate β†’ patchwork β†’ gated residual)
cross-token: Cantor router (hierarchical scatter/gather through anchor tree)
The triangulation profile IS the routing key. Tokens near the same anchor
on S^(d-1) share information at level 0. Anchor pairs share at level 1.
Quads at level 2. Full global at level log2(A).
Total cross-token cost: O(S Γ— n_levels) = O(S Γ— 4) for 16 anchors.
Total per-token cost: O(S Γ— tri_dim Γ— pw_hidden).
No attention anywhere. Fully O(S).
Benchmarks:
1. Throughput: cantor-relay vs hybrid vs pure relay vs attention
2. Cross-token causal intervention at scale
3. Geometric preservation
4. Trained task requiring cross-token routing
"""
import os
# Serialize CUDA kernel launches so errors surface at the offending call.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 also slows GPU execution, which biases
# every timing measured below β€” consider removing for real throughput numbers.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import gc
from collections import OrderedDict
# Fall back to CPU when no GPU is present (the timing tests will be slow there).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Allow TF32 matmuls/convs: faster on Ampere+ GPUs at slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ══════════════════════════════════════════════════════════════════
# ACTIVATIONS
# ══════════════════════════════════════════════════════════════════
class SquaredReLU(nn.Module):
    """Squared-ReLU activation: max(x, 0)**2, a sparser alternative to GELU."""
    def forward(self, x):
        clipped = F.relu(x)
        return clipped * clipped
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY β€” per-token geometric layer
# ══════════════════════════════════════════════════════════════════
class ConstellationRelay(nn.Module):
    """Per-token constellation triangulation + patchwork. O(S)."""
    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        # dim: token width; patch_dim: width of each slice a token is cut into;
        # n_anchors: constellation points per patch-sphere; n_phases: number of
        # interpolation phases sampled between the home and trained anchors.
        super().__init__()
        self.dim = dim
        self.patch_dim = patch_dim
        self.n_patches = dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim
        self.ln = nn.LayerNorm(dim)
        # Fixed "home" constellation: random unit vectors on S^(d-1),
        # registered as a buffer (saved with the model, never trained).
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        # Trainable anchors start at the home positions and may drift away.
        self.anchors = nn.Parameter(home.clone())
        tri_dim = P * A * n_phases
        self.tri_dim = tri_dim
        pw_hidden = tri_dim * 2
        # "Patchwork" MLP maps the triangulation profile back to token space.
        self.patchwork = nn.Sequential(
            nn.Linear(tri_dim, pw_hidden),
            SquaredReLU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, dim),
        )
        # Gate starts near-closed (sigmoid(-3) ~ 0.047) so the relay fades in.
        self.gate = nn.Parameter(torch.tensor(-3.0))
    def drift(self):
        # Geodesic angle between each home anchor and its trained anchor,
        # shape (P, A); serves as the slerp angle in at_phase().
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-6, 1 - 1e-6))
    def at_phase(self, t):
        # Spherical interpolation between home (t=0) and trained anchors (t=1).
        # Returns (P, A, d) anchor positions at phase t (approximately unit
        # norm; callers re-normalize).
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        omega = self.drift().unsqueeze(-1)
        # Clamp sin(omega) away from zero to avoid 0/0 when home == anchor.
        so = omega.sin().clamp(min=1e-6)
        return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c
    def triangulate(self, patches_n):
        # Distance profile of each normalized patch to every anchor at each
        # phase: 1 - cos(patch, anchor). Output shape (B, P * A * n_phases).
        phases = torch.linspace(0, 1, self.n_phases, device=patches_n.device).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1).to(patches_n.dtype)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
        return torch.cat(tris, dim=-1).reshape(patches_n.shape[0], -1)
    def forward(self, x):
        """x: (B*S, D) or (B, S, D)"""
        # Returns (out, tri): gated-residual output plus the triangulation
        # profile, which downstream routers reuse as a routing key.
        is_seq = x.dim() == 3
        if is_seq:
            B, S, D = x.shape
            x_flat = x.reshape(B * S, D)
        else:
            x_flat = x
        residual = x_flat
        h = self.ln(x_flat)
        # Cut each token into P patches of width patch_dim and place each
        # patch on its unit sphere before triangulating.
        patches = h.reshape(-1, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)
        tri = self.triangulate(patches_n)
        pw_out = self.patchwork(tri)
        g = self.gate.sigmoid()
        out = residual + g * pw_out
        if is_seq:
            return out.reshape(B, S, D), tri.reshape(B, S, -1)
        return out, tri
    def forward_no_tri(self, x):
        """Original forward without returning tri β€” for compatibility."""
        out, _ = self.forward(x)
        return out
# ══════════════════════════════════════════════════════════════════
# CANTOR CONSTELLATION ROUTER β€” hierarchical cross-token, O(S)
# ══════════════════════════════════════════════════════════════════
class CantorConstellationRouter(nn.Module):
    """
    Hierarchical cross-token routing through the constellation anchor tree.
    The triangulation profile assigns each token to a region on S^(d-1).
    A binary tree over anchors defines the routing hierarchy:
    Level 0: A groups (per-anchor, local neighbors)
    Level 1: A/2 groups (anchor pairs, nearby interaction)
    Level 2: A/4 groups (quads, medium range)
    ...
    Level L: 1 group (global summary)
    At each level:
    1. Soft-assign tokens to groups via triangulation weights
    2. Weighted scatter: aggregate token representations per group
    3. Transform: per-level MLP on group summaries
    4. Weighted gather: distribute transformed summaries back to tokens
    5. Gated residual addition
    Cost: O(S Γ— L Γ— D) where L = log2(A) + 1 = 5 for A=16.
    Memory: O(S Γ— D + A Γ— D) β€” no SΒ² term anywhere.
    """
    def __init__(self, dim=256, n_anchors=16, n_patches=16):
        # dim: token width D; n_anchors: anchors A (should be a power of two
        # for the binary-tree grouping in forward()); n_patches: patches per
        # token in the triangulation profile.
        super().__init__()
        self.dim = dim
        self.n_anchors = n_anchors
        self.n_patches = n_patches
        self.n_levels = int(math.log2(n_anchors)) + 1  # 5 for A=16
        # One (LayerNorm -> MLP -> gate) per tree level. The group count per
        # level is implied by the level index at forward time, so every level
        # shares the same dim -> 2*dim -> dim MLP shape.
        self.level_mlps = nn.ModuleList()
        self.level_gates = nn.ParameterList()
        self.level_lns = nn.ModuleList()
        for _ in range(self.n_levels):
            self.level_mlps.append(nn.Sequential(
                nn.Linear(dim, dim * 2),
                SquaredReLU(),
                nn.Linear(dim * 2, dim),
            ))
            self.level_lns.append(nn.LayerNorm(dim))
            # Gates start near-closed (sigmoid(-3) ~ 0.047): routing fades in.
            self.level_gates.append(nn.Parameter(torch.tensor(-3.0)))
        # Projection from triangulation distances to routing weights.
        # NOTE(review): currently unused β€” compute_routing_weights() derives
        # weights directly from the phase-0 distances instead. Kept so any
        # existing state_dicts still load; candidate for removal.
        self.weight_proj = nn.Linear(n_patches * n_anchors, n_anchors)
    def compute_routing_weights(self, tri, n_anchors):
        """
        Extract soft anchor assignment weights from triangulation profile.
        tri: (BS, tri_dim) β€” full triangulation (n_patches Γ— n_anchors Γ— n_phases)
        Returns: (BS, n_anchors) β€” soft assignment weights (sum to 1)
        """
        BS = tri.shape[0]
        # Phase-0 slice: the first n_patches * n_anchors entries hold
        # 1 - cos(patch, anchor) at interpolation phase t=0.
        phase0 = tri[:, :self.n_patches * n_anchors]
        # Average over patches to get one proximity score per anchor per token.
        dists = phase0.reshape(BS, self.n_patches, n_anchors).mean(dim=1)  # (BS, A)
        # dists live in [0, 2] (1 - cos), so 2 - dists is a proximity score.
        proximity = (2.0 - dists).clamp(min=0)
        # Temperature 5.0 sharpens the softmax toward the nearest anchor.
        weights = F.softmax(proximity * 5.0, dim=-1)
        return weights
    def forward(self, x, tri):
        """
        x: (B, S, D) token representations
        tri: (B, S, tri_dim) triangulation profiles from constellation
        Returns: (B, S, D) with cross-token information routed through anchor hierarchy
        """
        B, S, D = x.shape
        x_flat = x.reshape(B * S, D)
        tri_flat = tri.reshape(B * S, -1)
        # Soft routing weights per token: (BS, A).
        weights = self.compute_routing_weights(tri_flat, self.n_anchors)
        h = x_flat  # working representation, updated via gated residuals
        for level in range(self.n_levels):
            group_size = 2 ** level
            n_groups = max(1, self.n_anchors // group_size)
            # Merge per-anchor weights into per-group weights by summing
            # each bin of 2**level consecutive anchors.
            if n_groups > 1:
                group_weights = weights.reshape(B * S, n_groups, group_size).sum(dim=-1)
            else:
                group_weights = weights.sum(dim=-1, keepdim=True)  # (BS, 1)
            # Renormalize so each token distributes unit mass across groups.
            group_weights = group_weights / (group_weights.sum(dim=-1, keepdim=True) + 1e-8)
            # Weighted scatter: aggregate tokens into group summaries. Done in
            # (B, S, ...) form so tokens only mix within their own sequence.
            gw = group_weights.reshape(B, S, n_groups)  # (B, S, G)
            hh = h.reshape(B, S, D)  # (B, S, D)
            group_summary = torch.bmm(gw.transpose(1, 2), hh)  # (B, G, D)
            # Convert weighted sums into weighted means per group.
            weight_sums = gw.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)  # (B, G, 1)
            group_summary = group_summary / weight_sums
            # Transform the group summaries with this level's norm + MLP.
            gs_flat = group_summary.reshape(B * n_groups, D)
            gs_flat = self.level_lns[level](gs_flat)
            gs_transformed = self.level_mlps[level](gs_flat)
            gs_transformed = gs_transformed.reshape(B, n_groups, D)
            # Weighted gather: each token pulls back its share of every group.
            token_update = torch.bmm(gw, gs_transformed).reshape(B * S, D)
            # Gated residual: the level gate controls how much routed
            # information lands on the token stream.
            g = self.level_gates[level].sigmoid()
            h = h + g * token_update
        return h.reshape(B, S, D)
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION-CANTOR RELAY β€” FULL O(S) TRANSFORMER LAYER
# ══════════════════════════════════════════════════════════════════
class ConstellationCantorRelay(nn.Module):
    """
    Complete O(S) transformer layer. No attention.
    per-token: ConstellationRelay (triangulate β†’ patchwork β†’ gated residual)
    cross-token: CantorConstellationRouter (hierarchical scatter/gather through anchors)
    The triangulation from the per-token relay is reused as routing keys
    for the cross-token path β€” no redundant computation.
    """
    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        super().__init__()
        self.relay = ConstellationRelay(
            dim=dim, patch_dim=patch_dim, n_anchors=n_anchors, n_phases=n_phases)
        self.router = CantorConstellationRouter(
            dim=dim, n_anchors=n_anchors, n_patches=dim // patch_dim)
        # Both paths open near-closed (sigmoid(-2) ~ 0.12) and learn to mix in.
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_router = nn.Parameter(torch.tensor(-2.0))
    def forward(self, x):
        """x: (B, S, D) -> (B, S, D); both paths applied as gated deltas."""
        B, S, D = x.shape
        # Per-token relay β€” returns output plus the triangulation profile.
        relay_out, tri = self.relay(x)  # (B, S, D), (B, S, tri_dim)
        relay_delta = relay_out - x
        # Cross-token routing reuses the triangulation as its routing key.
        routed = self.router(x, tri)  # (B, S, D)
        router_delta = routed - x
        # Gated combination. The router gate local was renamed from `gc`,
        # which shadowed the imported gc module inside this function.
        g_rel = self.gate_relay.sigmoid()
        g_rt = self.gate_router.sigmoid()
        return x + g_rel * relay_delta + g_rt * router_delta
# ══════════════════════════════════════════════════════════════════
# COMPARISON ARCHITECTURES
# ══════════════════════════════════════════════════════════════════
class VanillaAttention(nn.Module):
    """Pre-norm multi-head self-attention block, O(S^2) β€” the baseline."""
    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.ln = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        batch, seq, width = x.shape
        normed = self.ln(x)
        # Single fused projection, then split into (B, H, S, hd) q/k/v views.
        packed = self.qkv(normed).reshape(batch, seq, 3, self.n_heads, self.head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        q, k, v = packed[0], packed[1], packed[2]
        ctx = F.scaled_dot_product_attention(q, k, v)
        # Merge heads back to (B, S, D) and add the residual.
        merged = ctx.transpose(1, 2).reshape(batch, seq, width)
        return x + self.proj(merged)
class HybridRelay(nn.Module):
    """Constellation relay + vanilla attention. For comparison."""
    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.relay = ConstellationRelay(dim=dim)
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_proj = nn.Linear(dim, dim)
        self.attn_ln = nn.LayerNorm(dim)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_attn = nn.Parameter(torch.tensor(-2.0))
    def forward(self, x):
        batch, seq, width = x.shape
        # Per-token path expressed as a delta from the input.
        delta_relay = self.relay.forward_no_tri(x) - x
        # Cross-token path: standard multi-head attention on the normed input.
        normed = self.attn_ln(x)
        packed = self.qkv(normed).reshape(batch, seq, 3, self.n_heads, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in packed.unbind(2))
        ctx = F.scaled_dot_product_attention(q, k, v)
        delta_attn = self.attn_proj(ctx.transpose(1, 2).reshape(batch, seq, width))
        # Each path enters through its own learned sigmoid gate.
        gate_r = self.gate_relay.sigmoid()
        gate_a = self.gate_attn.sigmoid()
        return x + gate_r * delta_relay + gate_a * delta_attn
class PureRelayLayer(nn.Module):
    """Relay-only layer with zero cross-token flow. Comparison baseline."""
    def __init__(self, dim=256):
        super().__init__()
        self.relay = ConstellationRelay(dim=dim)
    def forward(self, x):
        # Discard the triangulation profile; only the token output is needed.
        out, _tri = self.relay(x)
        return out
# ══════════════════════════════════════════════════════════════════
# UTILITIES
# ══════════════════════════════════════════════════════════════════
def reset_vram():
    """Collect garbage, release cached CUDA blocks, and zero the peak-memory
    counter so the next benchmark's peak_mb() reading starts clean.

    Safe on CPU-only hosts: the CUDA allocator calls are skipped when CUDA is
    unavailable (the originals raise without a CUDA build/device).
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
def peak_mb():
    """Peak CUDA memory allocated since the last reset, in megabytes.

    Returns 0.0 on CPU-only hosts, where no CUDA allocator stats exist
    (querying them there would raise).
    """
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / 1e6
# Model width shared by every benchmark below.
D = 256
print("=" * 80)
print("CONSTELLATION-CANTOR RELAY β€” O(S) CROSS-TOKEN ROUTING BENCHMARK")
# Guard the device queries: DEVICE above already falls back to "cpu", but the
# original unconditional torch.cuda.get_device_name() crashed on CPU hosts.
if torch.cuda.is_available():
    print(f" Device: {torch.cuda.get_device_name()}")
    print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print(" Device: cpu (CUDA unavailable)")
print(f" Dimension: {D}")
print("=" * 80)
# ══════════════════════════════════════════════════════════════════
# TEST 1: THROUGHPUT β€” ALL FOUR ARCHITECTURES
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print("TEST 1: Throughput Scaling β€” 4 architectures, S=64 to 131K")
print(" Single layer, B=1, fp16")
print(f"{'━'*80}")
# Sequence lengths from 64 up to 131K; the O(S^2) attention baseline is
# expected to OOM or slow dramatically at the top end.
SEQ_LENGTHS = [64, 256, 1024, 4096, 16384, 32768, 65536, 131072]
print(f"\n {'S':>8} {'relay':>9} {'cantor':>9} {'hybrid':>9} {'attn':>9} "
f"{'c/r':>6} {'c/a':>6} {'c_MB':>7}")
for S in SEQ_LENGTHS:
    results = {}
    # A fresh layer per architecture per S keeps peak-memory stats isolated.
    for name, make_layer in [
        ("relay", lambda: PureRelayLayer(D)),
        ("cantor", lambda: ConstellationCantorRelay(D)),
        ("hybrid", lambda: HybridRelay(D)),
        ("attn", lambda: VanillaAttention(D)),
    ]:
        try:
            reset_vram()
            layer = make_layer().to(DEVICE).half()
            x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
            # Warmup
            with torch.no_grad():
                for _ in range(3):
                    _ = layer(x)
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            with torch.no_grad():
                for _ in range(10):
                    _ = layer(x)
            torch.cuda.synchronize()
            # Mean wall-clock milliseconds over the 10 timed iterations.
            ms = (time.perf_counter() - t0) / 10 * 1000
            mb = peak_mb()
            results[name] = (ms, mb)
            del layer, x
            reset_vram()
        except (torch.cuda.OutOfMemoryError, RuntimeError):
            # inf marks "did not finish" (OOM or other CUDA failure) at this S.
            results[name] = (float('inf'), float('inf'))
            reset_vram()
    r = results.get("relay", (0, 0))[0]
    c = results.get("cantor", (0, 0))[0]
    h = results.get("hybrid", (0, 0))[0]
    a = results.get("attn", (0, 0))[0]
    c_mb = results.get("cantor", (0, 0))[1]
    def fmt(v):
        # inf sentinel from the except branch renders as OOM.
        return f"{v:>8.2f}ms" if v < float('inf') else " OOM"
    cr_ratio = f"{c/r:>5.1f}Γ—" if r > 0 and c < float('inf') else " -"
    ca_ratio = f"{c/a:>5.1f}Γ—" if a > 0 and a < float('inf') and c < float('inf') else " -"
    print(f" {S:>8} {fmt(r)} {fmt(c)} {fmt(h)} {fmt(a)} "
f"{cr_ratio} {ca_ratio} {c_mb:>7.0f}")
# ══════════════════════════════════════════════════════════════════
# TEST 2: CROSS-TOKEN CAUSAL INTERVENTION β€” CANTOR vs HYBRID
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print("TEST 2: Cross-Token Causal Intervention")
print(" Modify token 0, measure effect on token S//2")
print(" 4 layers deep. Compare: cantor relay vs hybrid vs pure relay")
print(f"{'━'*80}")
N_LAYERS = 4
print(f"\n {'S':>8} {'arch':>10} {'Ξ”_mid':>10} {'Ξ”_last':>10} "
f"{'cos_orig':>10} {'time_ms':>10}")
for S in [64, 256, 1024, 4096, 16384]:
    for arch_name, make_stack in [
        ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(N_LAYERS)])),
        ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(N_LAYERS)])),
        ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(N_LAYERS)])),
    ]:
        try:
            reset_vram()
            torch.manual_seed(42)
            stack = make_stack().to(DEVICE).half()
            x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
            # x_mod: identical sequence except token 0 is replaced with a fresh
            # random unit vector β€” any downstream difference at other positions
            # must have traveled through a cross-token path.
            x_mod = x.clone()
            x_mod[:, 0] = F.normalize(torch.randn(1, D, device=DEVICE, dtype=torch.float16), dim=-1)
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            with torch.no_grad():
                h = x.clone()
                h_mod = x_mod.clone()
                for layer in stack:
                    h = layer(h)
                    h_mod = layer(h_mod)
            torch.cuda.synchronize()
            elapsed = (time.perf_counter() - t0) * 1000
            # Effect of the token-0 edit on the middle and last positions.
            mid = S // 2
            delta_mid = (h[0, mid].float() - h_mod[0, mid].float()).norm().item()
            delta_last = (h[0, -1].float() - h_mod[0, -1].float()).norm().item()
            # How much the (unmodified-run) middle token still resembles its input.
            cos_orig = F.cosine_similarity(
                x[0, mid:mid+1].float(), h[0, mid:mid+1].float()).item()
            print(f" {S:>8} {arch_name:>10} {delta_mid:>10.4f} {delta_last:>10.4f} "
f"{cos_orig:>10.4f} {elapsed:>10.1f}")
            del stack, x, x_mod, h, h_mod
            reset_vram()
        except (torch.cuda.OutOfMemoryError, RuntimeError):
            print(f" {S:>8} {arch_name:>10} OOM")
            reset_vram()
    print()
# ══════════════════════════════════════════════════════════════════
# TEST 3: GEOMETRIC PRESERVATION WITH CROSS-TOKEN ROUTING
# ══════════════════════════════════════════════════════════════════
# Checks whether adding the cross-token path degrades the per-token geometry
# metrics (cosine to input, simplex-volume CV, effective dimensionality).
print(f"\n{'━'*80}")
print("TEST 3: Geometric Preservation β€” does Cantor routing hurt geometry?")
print(" 8 layers, S=4096. Compare cos_to_orig, CV, eff_dim.")
print(f"{'━'*80}")
def compute_cv(points, n_samples=500):
    """Coefficient of variation of random 5-point simplex volumes.

    Draws `n_samples` random 5-subsets of the unit-normalized `points`,
    computes each subset's 4-simplex volume via the Cayley-Menger
    determinant, and returns std/mean over the valid volumes. Returns NaN
    when fewer than 5 points are given or fewer than 50 volumes survive
    the numerical-validity cutoff.
    """
    count = points.shape[0]
    if count < 5:
        return float('nan')
    unit = F.normalize(points.float(), dim=-1)
    volumes = []
    for _ in range(n_samples):
        # Sample 5 distinct indices (from at most the first 2000 points).
        pick = torch.randperm(min(count, 2000), device=unit.device)[:5]
        simplex = unit[pick].unsqueeze(0)
        gram = torch.bmm(simplex, simplex.transpose(1, 2))
        sq_norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Pairwise squared distances from the Gram matrix, clipped at zero.
        sq_dist = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * gram)
        # Bordered Cayley-Menger matrix for 5 points.
        cayley = torch.zeros(1, 6, 6, device=unit.device, dtype=torch.float32)
        cayley[:, 0, 1:] = 1
        cayley[:, 1:, 0] = 1
        cayley[:, 1:, 1:] = sq_dist
        # For k=4 points-1 dims: V^2 = -det(CM) / (2^4 * (4!)^2) = -det / 9216.
        vol_sq = -torch.linalg.det(cayley) / 9216
        if vol_sq[0].item() > 1e-20:
            volumes.append(vol_sq[0].sqrt().cpu())
    if len(volumes) < 50:
        return float('nan')
    stacked = torch.stack(volumes)
    return (stacked.std() / (stacked.mean() + 1e-8)).item()
GEO_DEPTH = 8
GEO_S = 4096
print(f"\n {'arch':>10} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
f"{'eff_dim':>8} {'self_sim':>10}")
for arch_name, make_stack in [
    ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(GEO_DEPTH)])),
    ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(GEO_DEPTH)])),
    ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(GEO_DEPTH)])),
    ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(GEO_DEPTH)])),
]:
    try:
        reset_vram()
        torch.manual_seed(42)
        stack = make_stack().to(DEVICE).half()
        x = F.normalize(torch.randn(1, GEO_S, D, device=DEVICE, dtype=torch.float16), dim=-1)
        with torch.no_grad():
            h = x.clone()
            for layer in stack:
                h = layer(h)
        # Metrics computed on the first 512 tokens, upcast to fp32.
        x_s = x[0, :512].float()
        h_s = h[0, :512].float()
        cos = F.cosine_similarity(x_s, h_s).mean().item()
        norm = h_s.norm(dim=-1).mean().item()
        h_n = F.normalize(h_s, dim=-1)
        # Mean off-diagonal pairwise similarity β€” a token-collapse check.
        sim = h_n @ h_n.T
        mask = ~torch.eye(512, device=DEVICE, dtype=torch.bool)
        self_sim = sim[mask].mean().item()
        cv = compute_cv(h_n, 500)
        # Effective dimensionality: participation ratio of singular values.
        _, S_vals, _ = torch.linalg.svd(h_n[:256], full_matrices=False)
        p = S_vals / S_vals.sum()
        ed = p.pow(2).sum().reciprocal().item()
        print(f" {arch_name:>10} {cos:>10.4f} {norm:>8.4f} {cv:>8.4f} "
f"{ed:>8.1f} {self_sim:>10.6f}")
        del stack, x, h
        reset_vram()
    except (torch.cuda.OutOfMemoryError, RuntimeError):
        print(f" {arch_name:>10} OOM")
        reset_vram()
# ══════════════════════════════════════════════════════════════════
# TEST 4: TRAINED CROSS-TOKEN TASK β€” ALL ARCHITECTURES
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print("TEST 4: Trained Cross-Token Task")
print(" Label = (token_0_class + token_1_class) % 10")
print(" Pure relay CANNOT solve this (zero cross-token info).")
print(" 4 layers, 500 steps, S=8.")
print(f"{'━'*80}")
S_TASK = 8
N_CLS = 10
N_SAMPLES = 4096
STEPS = 500
torch.manual_seed(777)
# Class prototypes for slot 0 and slot 1 live on the unit sphere.
keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
# Tokens 0 and 1 carry noisy class keys; the remaining 6 are distractors.
task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x = F.normalize(task_x, dim=-1)
# The label mixes BOTH tokens, so cross-token information flow is mandatory.
task_y = ((label_a + label_b) % N_CLS).long()
print(f"\n {'arch':>10} {'acc':>8} {'loss':>8} {'cross_Ξ”':>10} {'params':>10}")
for arch_name, make_stack in [
    ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(4)])),
    ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(4)])),
    ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(4)])),
    ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(4)])),
]:
    torch.manual_seed(42)
    class TaskModel(nn.Module):
        """Layer stack + flatten-pool + linear classification head."""
        def __init__(self, stack):
            super().__init__()
            self.layers = stack
            self.pool = nn.Linear(D * S_TASK, D)
            self.head = nn.Linear(D, N_CLS)
        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.head(F.gelu(self.pool(x.reshape(x.shape[0], -1))))
    model = TaskModel(make_stack()).to(DEVICE)
    n_params = sum(p.numel() for p in model.parameters())
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)
    for step in range(STEPS):
        # Random minibatch of 128 samples per step (sampled with replacement).
        idx = torch.randint(0, N_SAMPLES, (128,))
        logits = model(task_x[idx])
        loss = F.cross_entropy(logits, task_y[idx])
        # Stop training early on numerical blow-up.
        if torch.isnan(loss) or torch.isinf(loss):
            break
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
    model.eval()
    with torch.no_grad():
        logits = model(task_x[:1024])
        acc = (logits.argmax(-1) == task_y[:1024]).float().mean().item()
        final_loss = F.cross_entropy(logits, task_y[:1024]).item()
        # Cross-token intervention
        h1 = task_x[:64].clone()
        for layer in model.layers:
            h1 = layer(h1)
        h2 = task_x[:64].clone()
        h2[:, 0] = F.normalize(torch.randn(64, D, device=DEVICE), dim=-1)
        for layer in model.layers:
            h2 = layer(h2)
        # How strongly replacing token 0 moves token 1 after the stack.
        cross_delta = (h1[:, 1] - h2[:, 1]).norm(dim=-1).mean().item()
    print(f" {arch_name:>10} {acc:>8.1%} {final_loss:>8.4f} {cross_delta:>10.4f} {n_params:>10,}")
    del model
    reset_vram()
# ══════════════════════════════════════════════════════════════════
# TEST 5: THE O(SΒ²) WALL β€” CANTOR vs ATTENTION at depth 8
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print("TEST 5: The O(SΒ²) Wall β€” Cantor vs Attention, 8 layers deep")
print(f"{'━'*80}")
WALL_DEPTH = 8
print(f"\n {'S':>8} {'cantor_ms':>10} {'attn_ms':>10} {'speedup':>8} "
f"{'c_cos':>8} {'a_cos':>8} {'c_MB':>8} {'a_MB':>8}")
for S in [1024, 4096, 8192, 16384, 32768, 65536, 131072]:
    # None means that architecture hit OOM (or another failure) at this S.
    c_result = None
    a_result = None
    # Cantor
    try:
        reset_vram()
        torch.manual_seed(42)
        c_stack = nn.ModuleList([
            ConstellationCantorRelay(D) for _ in range(WALL_DEPTH)
        ]).to(DEVICE).half()
        x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
        # One untimed warmup pass, then one timed pass through all 8 layers.
        with torch.no_grad():
            h = x.clone()
            for layer in c_stack:
                h = layer(h)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            h = x.clone()
            for layer in c_stack:
                h = layer(h)
        torch.cuda.synchronize()
        c_ms = (time.perf_counter() - t0) * 1000
        c_mb = peak_mb()
        c_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
        c_result = (c_ms, c_cos, c_mb)
        del x, h, c_stack
        reset_vram()
    except (torch.cuda.OutOfMemoryError, RuntimeError):
        reset_vram()
    # Attention
    try:
        reset_vram()
        torch.manual_seed(42)
        a_stack = nn.ModuleList([
            VanillaAttention(D) for _ in range(WALL_DEPTH)
        ]).to(DEVICE).half()
        x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
        with torch.no_grad():
            h = x.clone()
            for layer in a_stack:
                h = layer(h)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            h = x.clone()
            for layer in a_stack:
                h = layer(h)
        torch.cuda.synchronize()
        a_ms = (time.perf_counter() - t0) * 1000
        a_mb = peak_mb()
        a_cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
        a_result = (a_ms, a_cos, a_mb)
        del x, h, a_stack
        reset_vram()
    except (torch.cuda.OutOfMemoryError, RuntimeError):
        reset_vram()
    c_str = f"{c_result[0]:>9.1f}ms" if c_result else " OOM"
    a_str = f"{a_result[0]:>9.1f}ms" if a_result else " OOM"
    sp = f"{a_result[0]/c_result[0]:>7.1f}Γ—" if c_result and a_result else " -"
    cc = f"{c_result[1]:>8.4f}" if c_result else " ---"
    ac = f"{a_result[1]:>8.4f}" if a_result else " ---"
    cm = f"{c_result[2]:>8.0f}" if c_result else " OOM"
    am = f"{a_result[2]:>8.0f}" if a_result else " OOM"
    print(f" {S:>8} {c_str} {a_str} {sp} {cc} {ac} {cm} {am}")
    # If even the O(S) path no longer fits, larger S will not fit either.
    if c_result is None:
        print(f" β†’ Cantor OOM at S={S}, stopping")
        break
# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*80}")
print("CONSTELLATION-CANTOR RELAY β€” BENCHMARK COMPLETE")
print(f"{'='*80}")
# Closing recap of the architecture and what each of the five tests probed.
print(f"""
Architecture:
per-token: constellation relay (triangulate β†’ patchwork β†’ gated residual)
cross-token: cantor router (hierarchical scatter/gather through anchor tree)
total: O(S) time, O(S) memory, no attention
5 tests:
T1: Throughput β€” relay vs cantor vs hybrid vs attention, S to 131K
T2: Cross-token causal intervention β€” who routes strongest?
T3: Geometric preservation β€” does cross-token routing hurt geometry?
T4: Trained cross-token task β€” accuracy on interaction-dependent labels
T5: O(SΒ²) wall β€” cantor vs attention at 8 layers to OOM
Key questions answered:
β€’ Is the cantor router faster than attention at all sequence lengths?
β€’ Does it provide meaningful cross-token interaction?
β€’ Does the routing hurt per-token geometric preservation?
β€’ Can it solve tasks that require cross-token information?
""")