# Source: geolip-constellation-activations / constellation_relays_activation_effects_analysis.py
# Hugging Face listing metadata: uploaded by AbstractPhil, commit 84eb1a3 ("a couple slow ones"), verified.
#!/usr/bin/env python3
"""
Activation Effects on Constellation Relays
============================================
Systematic test of:
1. Activation function effects on geometric preservation through relay stacks
2. Activation placement: before triangulation, in patchwork, after patchwork
3. Hybrid relay information retention with different activations
4. Activation effects on anchor drift and CV stability
5. Cross-token routing strength under different activations
Each test uses the same random seed and input for fair comparison.
"""
import os
# Must be set before torch initializes CUDA: makes kernel launches synchronous
# so CUDA errors surface at the offending line instead of later.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
from collections import defaultdict, OrderedDict
# Single global device for all experiments below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Allow TF32 matmuls/convolutions: faster on Ampere+ GPUs at slightly
# reduced mantissa precision (acceptable for these comparative sweeps).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ══════════════════════════════════════════════════════════════════
# ACTIVATION REGISTRY
# ══════════════════════════════════════════════════════════════════
class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
class SquaredReLU(nn.Module):
    """Squared ReLU: max(x, 0)^2 — zero and smooth-gradient at the origin."""

    def forward(self, x):
        clipped = F.relu(x)
        return clipped * clipped
class StarReLU(nn.Module):
    """From MetaFormer — scaled squared ReLU with learnable scale and bias."""

    def __init__(self):
        super().__init__()
        # Init constants from the MetaFormer paper (≈ sqrt(0.8) and -sqrt(0.2)).
        self.scale = nn.Parameter(torch.tensor(0.8944))
        self.bias = nn.Parameter(torch.tensor(-0.4472))

    def forward(self, x):
        relu_sq = F.relu(x) ** 2
        return self.scale * relu_sq + self.bias
class SoftSign(nn.Module):
    """Softsign: x / (1 + |x|) — bounded in (-1, 1) with softer tails than tanh."""

    def forward(self, x):
        return x / (x.abs() + 1)
class Identity(nn.Module):
    """Pass-through activation; serves as the 'none' control condition."""

    def forward(self, x):
        return x
# Factory registry: each value is a zero-argument callable that returns a
# *fresh* module instance (fresh instances matter for activations with
# learnable parameters, e.g. StarReLU / PReLU). Bare class references are
# used where the constructor takes no arguments — wrapping them in a
# `lambda:` was redundant; a lambda remains only where an argument must be
# bound (LeakyReLU's negative slope).
ACTIVATIONS = OrderedDict([
    ("none", Identity),
    ("relu", nn.ReLU),
    ("gelu", nn.GELU),
    ("silu", nn.SiLU),
    ("mish", Mish),
    ("tanh", nn.Tanh),
    ("softplus", nn.Softplus),
    ("softsign", SoftSign),
    ("squared_relu", SquaredReLU),
    ("star_relu", StarReLU),
    ("leaky_relu", lambda: nn.LeakyReLU(0.1)),
    ("elu", nn.ELU),
    ("prelu", nn.PReLU),
])
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY — CONFIGURABLE ACTIVATIONS
# ══════════════════════════════════════════════════════════════════
class ConstellationRelay(nn.Module):
    """
    Relay with configurable activation at three positions:
    - pre_act: applied to embedding BEFORE L2 norm + triangulation
    - pw_act: activation inside patchwork MLP
    - post_act: applied to patchwork output BEFORE gated residual

    Forward pipeline: LayerNorm -> pre_act -> split into patches ->
    L2-normalize each patch -> cosine-distance triangulation against
    slerped anchor constellations -> patchwork MLP -> post_act ->
    gated residual add.

    NOTE: the experiments below compare models built under identical
    seeds, so the order of parameter/buffer creation in __init__ must
    not be changed (it determines RNG consumption).
    """
    def __init__(
        self,
        dim=256,           # embedding width; must be divisible by patch_dim
        patch_dim=16,      # width of each patch (points on S^{patch_dim-1})
        n_anchors=16,      # anchors per patch constellation
        n_phases=3,        # slerp phases sampled evenly in [0, 1]
        pre_act="none",    # ACTIVATIONS key, applied before sphere projection
        pw_act="gelu",     # ACTIVATIONS key, inside the patchwork MLP
        post_act="none",   # ACTIVATIONS key, on the patchwork output
        pw_hidden_mult=2,  # hidden-width multiplier for the patchwork MLP
    ):
        super().__init__()
        self.dim = dim
        self.patch_dim = patch_dim
        self.n_patches = dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim
        self.ln = nn.LayerNorm(dim)
        # Activations at each position
        self.pre_act = ACTIVATIONS[pre_act]()
        self.post_act = ACTIVATIONS[post_act]()
        # Constellation anchors: 'home' is a frozen unit-norm reference
        # buffer; 'anchors' is the trainable copy. drift() measures the
        # per-anchor angle between the two.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())
        # Patchwork MLP: maps flattened triangulation features back to dim
        tri_dim = P * A * n_phases
        pw_hidden = int(tri_dim * pw_hidden_mult)
        pw_act_fn = ACTIVATIONS[pw_act]()
        self.patchwork = nn.Sequential(
            nn.Linear(tri_dim, pw_hidden),
            pw_act_fn,
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, dim),
        )
        # Cold-init gate: sigmoid(-3) ~ 0.047, so the relay starts as a
        # small perturbation on the residual stream
        self.gate = nn.Parameter(torch.tensor(-3.0))

    def drift(self):
        """Per-anchor geodesic angle (radians) between home and current anchors."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        # clamp keeps acos numerically safe at cos = ±1
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp between home anchors (t=0) and current anchors (t=1)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard the omega -> 0 case
        return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c

    def triangulate(self, patches_n):
        """Cosine-distance features of unit patches vs anchors at each phase.

        patches_n: (B, P, d) unit vectors -> (B, P * A * n_phases) features.
        """
        phases = torch.linspace(0, 1, self.n_phases, device=patches_n.device).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            # 1 - cosine similarity against every anchor of every patch
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
        return torch.cat(tris, dim=-1).reshape(patches_n.shape[0], -1)

    def forward(self, x):
        """x: (B, dim) -> (B, dim) gated residual update."""
        B = x.shape[0]
        residual = x
        h = self.ln(x)
        # Pre-activation (before sphere projection)
        h = self.pre_act(h)
        # Project to sphere
        patches = h.reshape(B, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)
        # Triangulate
        tri = self.triangulate(patches_n)
        # Patchwork
        pw_out = self.patchwork(tri)
        # Post-activation (after patchwork, before residual)
        pw_out = self.post_act(pw_out)
        # Gated residual
        g = self.gate.sigmoid()
        return residual + g * pw_out
# ══════════════════════════════════════════════════════════════════
# HYBRID RELAY — CONFIGURABLE ACTIVATIONS
# ══════════════════════════════════════════════════════════════════
class HybridRelay(nn.Module):
    """
    Cross-token constellation relay + lightweight attention.
    Configurable activations at relay and attention paths.

    The relay path applies one shared per-token ConstellationRelay;
    the attention path is standard multi-head SDPA. Each path adds to
    the residual stream through its own cold-init sigmoid gate.
    """
    def __init__(
        self,
        dim=256,
        n_heads=4,
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
        relay_pw_act="gelu",
        attn_act="none",  # activation on attention output
    ):
        super().__init__()
        self.dim = dim
        # Relay path (module-creation order kept identical to the original
        # so seeded inits match across experiment runs)
        self.relay = ConstellationRelay(
            dim=dim, patch_dim=patch_dim, n_anchors=n_anchors,
            n_phases=n_phases, pre_act="none", pw_act=relay_pw_act,
            post_act="none")
        # Attention path
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_proj = nn.Linear(dim, dim)
        self.attn_ln = nn.LayerNorm(dim)
        self.attn_act = ACTIVATIONS[attn_act]()
        # Split gates (sigmoid(-2) ~ 0.12 each at init)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_attn = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        """x: (B, S, D) -> (B, S, D)."""
        B, S, D = x.shape
        residual = x
        # Relay path — the relay is strictly per-token, so fold the sequence
        # dim into the batch dim and call it once instead of looping over S
        # (the original per-token Python loop was the main bottleneck).
        flat = x.reshape(B * S, D)
        relay_out = (self.relay(flat) - flat).reshape(B, S, D)  # delta only
        # Attention path
        h = self.attn_ln(x)
        qkv = self.qkv(h).reshape(B, S, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).reshape(B, S, D)
        attn_out = self.attn_act(self.attn_proj(attn))
        gr = self.gate_relay.sigmoid()
        ga = self.gate_attn.sigmoid()
        return residual + gr * relay_out + ga * attn_out
# ══════════════════════════════════════════════════════════════════
# MEASUREMENT UTILITIES
# ══════════════════════════════════════════════════════════════════
def compute_cv(points, n_samples=1000, n_points=5):
N = points.shape[0]
if N < n_points: return float('nan')
points = F.normalize(points.float(), dim=-1)
vols = []
for _ in range(n_samples):
idx = torch.randperm(min(N, 2000), device=points.device)[:n_points]
pts = points[idx].unsqueeze(0)
gram = torch.bmm(pts, pts.transpose(1, 2))
norms = torch.diagonal(gram, dim1=1, dim2=2)
d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
d2 = F.relu(d2)
cm = torch.zeros(1, 6, 6, device=points.device, dtype=torch.float32)
cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
v2 = -torch.linalg.det(cm) / 9216
if v2[0].item() > 1e-20:
vols.append(v2[0].sqrt().cpu())
if len(vols) < 50: return float('nan')
vt = torch.stack(vols)
return (vt.std() / (vt.mean() + 1e-8)).item()
def eff_dim(x):
    """Participation-ratio effective dimensionality of a point cloud.

    Centers x, takes the singular values of (up to) the first 256 rows,
    and returns 1 / sum(p_i^2) where p_i are the normalized singular
    values — 1 for a single dominant direction, up to rank(x).
    """
    centered = x - x.mean(0, keepdim=True)
    rows = min(256, x.shape[0])
    sv = torch.linalg.svd(centered[:rows].float(), full_matrices=False).S
    weights = sv / sv.sum()
    return weights.pow(2).sum().reciprocal().item()
def measure_geometry(x):
    """Summary statistics of a batch of vectors.

    Returns a dict: mean/std of row norms, volume CV (compute_cv),
    effective dimensionality, and mean off-diagonal cosine similarity.
    """
    x = x.detach().float()
    unit = F.normalize(x, dim=-1)
    norms = x.norm(dim=-1)
    # Off-diagonal cosine similarities (diagonal zeroed before averaging).
    sims = (unit @ unit.T).fill_diagonal_(0)
    return {
        'norm_mean': norms.mean().item(),
        'norm_std': norms.std().item(),
        'cv': compute_cv(unit.to(DEVICE), 500),
        'eff_dim': eff_dim(unit),
        'self_sim': sims.mean().item(),
    }
# ══════════════════════════════════════════════════════════════════
# TEST 1: ACTIVATION EFFECTS ON GEOMETRIC PRESERVATION AT DEPTH
# ══════════════════════════════════════════════════════════════════
print("=" * 80)
print("ACTIVATION EFFECTS ON CONSTELLATION RELAYS")
print(f" Device: {DEVICE}")
print("=" * 80)
# Shared experiment constants (reused by the later tests)
D = 256
DEPTH = 16
B = 512
# Fixed input: the same B unit vectors for every activation sweep
torch.manual_seed(42)
x_input = F.normalize(torch.randn(B, D, device=DEVICE), dim=-1)
print(f"\n{'━'*80}")
print(f"TEST 1: Patchwork Activation — Geometric Preservation at Depth {DEPTH}")
print(f" Testing: which activation in the patchwork MLP best preserves geometry?")
print(f" Setup: {DEPTH} stacked relay layers, pre_act=none, post_act=none")
print(f" Input: {B} unit vectors in {D}d")
print(f"{'━'*80}")
print(f"\n {'activation':>14} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
      f"{'eff_dim':>8} {'self_sim':>10} {'gate':>8} {'drift':>8}")
# Sweep every registered activation as the patchwork nonlinearity
for act_name in ACTIVATIONS:
    torch.manual_seed(42)  # identical init for a fair comparison
    layers = nn.ModuleList([
        ConstellationRelay(dim=D, pw_act=act_name)
        for _ in range(DEPTH)
    ]).to(DEVICE)
    with torch.no_grad():
        h = x_input.clone()
        for layer in layers:
            h = layer(h)
        # How much of the original direction survives the full stack?
        cos = F.cosine_similarity(x_input, h).mean().item()
        geo = measure_geometry(h.cpu())
        gate = layers[0].gate.sigmoid().item()
        drift = layers[-1].drift().mean().item()
    print(f" {act_name:>14} {cos:>10.4f} {geo['norm_mean']:>8.4f} {geo['cv']:>8.4f} "
          f"{geo['eff_dim']:>8.1f} {geo['self_sim']:>10.6f} {gate:>8.4f} {drift:>8.4f}")
# ══════════════════════════════════════════════════════════════════
# TEST 2: PRE-ACTIVATION — BEFORE THE SPHERE
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print(f"TEST 2: Pre-Activation — what happens before L2 normalization?")
print(f" Setup: activation applied AFTER LayerNorm, BEFORE sphere projection")
print(f" This tests whether activations distort the pre-sphere distribution")
print(f"{'━'*80}")
print(f"\n {'pre_act':>14} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
      f"{'eff_dim':>8} {'self_sim':>10}")
# Sweep pre_act while keeping the patchwork activation fixed at gelu
for act_name in ACTIVATIONS:
    torch.manual_seed(42)
    layers = nn.ModuleList([
        ConstellationRelay(dim=D, pre_act=act_name, pw_act="gelu")
        for _ in range(DEPTH)
    ]).to(DEVICE)
    with torch.no_grad():
        h = x_input.clone()
        for layer in layers:
            h = layer(h)
        cos = F.cosine_similarity(x_input, h).mean().item()
        geo = measure_geometry(h.cpu())
    print(f" {act_name:>14} {cos:>10.4f} {geo['norm_mean']:>8.4f} {geo['cv']:>8.4f} "
          f"{geo['eff_dim']:>8.1f} {geo['self_sim']:>10.6f}")
# ══════════════════════════════════════════════════════════════════
# TEST 3: POST-ACTIVATION — AFTER THE PATCHWORK
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print(f"TEST 3: Post-Activation — applied to patchwork output before residual")
print(f" This tests whether activations on the relay's contribution help or hurt")
print(f"{'━'*80}")
print(f"\n {'post_act':>14} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
      f"{'eff_dim':>8} {'self_sim':>10}")
# Sweep post_act while keeping the patchwork activation fixed at gelu
for act_name in ACTIVATIONS:
    torch.manual_seed(42)
    layers = nn.ModuleList([
        ConstellationRelay(dim=D, pw_act="gelu", post_act=act_name)
        for _ in range(DEPTH)
    ]).to(DEVICE)
    with torch.no_grad():
        h = x_input.clone()
        for layer in layers:
            h = layer(h)
        cos = F.cosine_similarity(x_input, h).mean().item()
        geo = measure_geometry(h.cpu())
    print(f" {act_name:>14} {cos:>10.4f} {geo['norm_mean']:>8.4f} {geo['cv']:>8.4f} "
          f"{geo['eff_dim']:>8.1f} {geo['self_sim']:>10.6f}")
# ══════════════════════════════════════════════════════════════════
# TEST 4: DEPTH PROFILE — HOW FAST DOES EACH ACTIVATION DEGRADE?
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print(f"TEST 4: Depth Profile — cos_to_orig at each depth for key activations")
print(f"{'━'*80}")
key_acts = ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_relu"]
depths_to_check = [1, 2, 4, 8, 12, 16, 24, 32]
max_depth = max(depths_to_check)
# Table header: one column per activation
print(f"\n {'depth':>6}", end="")
for act in key_acts:
    print(f" {act:>12}", end="")
print()
results_by_act = {}
for act_name in key_acts:
    torch.manual_seed(42)
    layers = nn.ModuleList([
        ConstellationRelay(dim=D, pw_act=act_name)
        for _ in range(max_depth)
    ]).to(DEVICE)
    cos_at_depth = []
    with torch.no_grad():
        h = x_input.clone()
        for d in range(max_depth):
            h = layers[d](h)
            # Record cosine-to-input at the checkpoint depths
            if (d + 1) in depths_to_check:
                cos_at_depth.append(
                    (d + 1, F.cosine_similarity(x_input, h).mean().item()))
    results_by_act[act_name] = cos_at_depth
# Print one row per checkpoint depth
for i, depth in enumerate(depths_to_check):
    print(f" {depth:>6}", end="")
    for act in key_acts:
        val = results_by_act[act][i][1]
        print(f" {val:>12.4f}", end="")
    print()
# ══════════════════════════════════════════════════════════════════
# TEST 5: TRAINED RELAY — ACTIVATION EFFECT ON LEARNING
# ══════════════════════════════════════════════════════════════════
# Cleanup from Tests 1-4
torch.cuda.empty_cache()
print(f"\n{'━'*80}")
print(f"TEST 5: Trained Relay — does activation choice affect what the relay LEARNS?")
print(f" Setup: 4-layer relay trained to classify 256d embeddings into 10 classes")
print(f" 500 steps SGD, measure final accuracy + geometric health")
print(f"{'━'*80}")
N_TRAIN = 2048
N_CLASSES = 10
TRAIN_STEPS = 500
RELAY_DEPTH = 4
# Generate clustered data on the sphere (10 classes)
torch.manual_seed(123)
class_centers = F.normalize(torch.randn(N_CLASSES, D, device=DEVICE), dim=-1)
train_x = []
train_y = []
for c in range(N_CLASSES):
    # Gaussian cloud around each class center, re-projected to the sphere
    noise = torch.randn(N_TRAIN // N_CLASSES, D, device=DEVICE) * 0.3
    pts = F.normalize(class_centers[c].unsqueeze(0) + noise, dim=-1)
    train_x.append(pts)
    train_y.append(torch.full((N_TRAIN // N_CLASSES,), c, dtype=torch.long, device=DEVICE))
train_x = torch.cat(train_x)
train_y = torch.cat(train_y)
assert train_y.max() < N_CLASSES, f"Label OOB: max={train_y.max()}, n_classes={N_CLASSES}"
assert train_y.min() >= 0, f"Negative label: min={train_y.min()}"
# NOTE(review): raises on CPU-only machines — guard with torch.cuda.is_available() if needed
torch.cuda.synchronize()
print(f"\n {'pw_act':>14} {'acc':>8} {'loss':>8} {'cos_orig':>10} "
      f"{'CV':>8} {'eff_dim':>8} {'drift':>8} {'gate':>8}")
for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_relu"]:
    torch.manual_seed(42)
    # Fresh classifier per activation: relay stack + linear head
    class RelayClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.relays = nn.ModuleList([
                ConstellationRelay(dim=D, pw_act=act_name)
                for _ in range(RELAY_DEPTH)])
            self.head = nn.Linear(D, N_CLASSES)
        def forward(self, x):
            for r in self.relays:
                x = r(x)
            return self.head(x)
    model = RelayClassifier().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for step in range(TRAIN_STEPS):
        idx = torch.randint(0, len(train_x), (128,))
        logits = model(train_x[idx])
        loss = F.cross_entropy(logits, train_y[idx])
        # Bail out of training on numerical blow-up
        if torch.isnan(loss) or torch.isinf(loss):
            print(f" ⚠ Bad loss at step {step}, act={act_name}")
            break
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
    # Evaluate
    model.eval()
    with torch.no_grad():
        logits = model(train_x)
        acc = (logits.argmax(-1) == train_y).float().mean().item()
        final_loss = F.cross_entropy(logits, train_y).item()
        # Pass through relays only (no head) to measure geometry
        h = train_x.clone()
        for r in model.relays:
            h = r(h)
        cos = F.cosine_similarity(train_x, h).mean().item()
        cv = compute_cv(F.normalize(h, dim=-1), 500)
        ed = eff_dim(F.normalize(h, dim=-1))
        drift = model.relays[-1].drift().mean().item()
        gate = model.relays[-1].gate.sigmoid().item()
    print(f" {act_name:>14} {acc:>8.1%} {final_loss:>8.4f} {cos:>10.4f} "
          f"{cv:>8.4f} {ed:>8.1f} {drift:>8.4f} {gate:>8.4f}")
# ══════════════════════════════════════════════════════════════════
# TEST 6: HYBRID RELAY — INFORMATION RETENTION
# ══════════════════════════════════════════════════════════════════
torch.cuda.empty_cache()
print(f"\n{'━'*80}")
print(f"TEST 6: Hybrid Relay — Information Retention")
print(f" Setup: 8 layers of hybrid relay (attention + constellation)")
print(f" Sequence length 32, measure cos_to_orig, cross-token Δ, gate split")
print(f"{'━'*80}")
S = 32
HYBRID_DEPTH = 8
torch.manual_seed(42)
x_seq = F.normalize(torch.randn(64, S, D, device=DEVICE), dim=-1)
# Also prepare a causal intervention: modify token 0
x_mod = x_seq.clone()
x_mod[:, 0] = F.normalize(torch.randn(64, D, device=DEVICE), dim=-1)
print(f"\n {'relay_act':>14} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
      f"{'cross_Δ':>10} {'g_relay':>8} {'g_attn':>8}")
for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu"]:
    torch.manual_seed(42)
    layers = nn.ModuleList([
        HybridRelay(dim=D, relay_pw_act=act_name)
        for _ in range(HYBRID_DEPTH)
    ]).to(DEVICE)
    with torch.no_grad():
        # Forward original
        h = x_seq.clone()
        for layer in layers:
            h = layer(h)
        # Forward modified (token 0 changed)
        h_mod = x_mod.clone()
        for layer in layers:
            h_mod = layer(h_mod)
        # Geometry of token 1 (unmodified input)
        cos = F.cosine_similarity(
            x_seq[:, 1], h[:, 1]).mean().item()
        geo = measure_geometry(h[:, 1].cpu())
        # Cross-token effect: how much did modifying token 0 change token 1?
        cross_delta = (h[:, 1] - h_mod[:, 1]).norm(dim=-1).mean().item()
        gr = layers[0].gate_relay.sigmoid().item()
        ga = layers[0].gate_attn.sigmoid().item()
    print(f" {act_name:>14} {cos:>10.4f} {geo['norm_mean']:>8.4f} {geo['cv']:>8.4f} "
          f"{cross_delta:>10.4f} {gr:>8.4f} {ga:>8.4f}")
# ══════════════════════════════════════════════════════════════════
# TEST 7: HYBRID — TRAINED CROSS-TOKEN ROUTING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print(f"TEST 7: Hybrid Relay — Trained Cross-Token Routing")
print(f" Setup: classify sequences where the label depends on token interactions")
print(f" Token 0 = class signal, Token 1 = modifier, class = f(tok0, tok1)")
print(f" If hybrid can't route cross-token, it can't solve this")
print(f"{'━'*80}")
S_TASK = 8
N_CLS = 10
N_SAMPLES = 4096
STEPS = 300
# Generate task: class depends on interaction of first two tokens
torch.manual_seed(777)
keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
# Plant noisy class keys in the first two token slots
task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
task_x = F.normalize(task_x, dim=-1)
# Target depends on BOTH planted labels — forces token-0/token-1 interaction
task_y = ((label_a + label_b) % N_CLS).long()
assert task_y.max() < N_CLS and task_y.min() >= 0
torch.cuda.synchronize()
print(f"\n {'relay_act':>14} {'acc':>8} {'loss':>8} {'g_relay':>8} "
      f"{'g_attn':>8} {'cross_Δ':>10}")
for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu"]:
    torch.manual_seed(42)
    # Fresh classifier per activation: hybrid stack + flatten-pool + head
    class HybridClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                HybridRelay(dim=D, relay_pw_act=act_name)
                for _ in range(4)])
            self.pool = nn.Linear(D * S_TASK, D)
            self.head = nn.Linear(D, N_CLS)
        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.head(F.gelu(self.pool(x.reshape(x.shape[0], -1))))
    model = HybridClassifier().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=3e-4)
    for step in range(STEPS):
        idx = torch.randint(0, N_SAMPLES, (128,))
        logits = model(task_x[idx])
        loss = F.cross_entropy(logits, task_y[idx])
        if torch.isnan(loss) or torch.isinf(loss):
            print(f" ⚠ Bad loss at step {step}, act={act_name}")
            break
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
    model.eval()
    with torch.no_grad():
        logits = model(task_x[:1024])
        acc = (logits.argmax(-1) == task_y[:1024]).float().mean().item()
        final_loss = F.cross_entropy(logits, task_y[:1024]).item()
        gr = model.layers[-1].gate_relay.sigmoid().item()
        ga = model.layers[-1].gate_attn.sigmoid().item()
        # Cross-token effect: re-randomize token 0, see how token 1 moves
        h1 = task_x[:64].clone()
        for layer in model.layers:
            h1 = layer(h1)
        h2 = task_x[:64].clone()
        h2[:, 0] = F.normalize(torch.randn(64, D, device=DEVICE), dim=-1)
        for layer in model.layers:
            h2 = layer(h2)
        cross_delta = (h1[:, 1] - h2[:, 1]).norm(dim=-1).mean().item()
    print(f" {act_name:>14} {acc:>8.1%} {final_loss:>8.4f} {gr:>8.4f} "
          f"{ga:>8.4f} {cross_delta:>10.4f}")
# ══════════════════════════════════════════════════════════════════
# TEST 8: ACTIVATION EFFECT ON DRIFT UNDER TRAINING
# ══════════════════════════════════════════════════════════════════
# Tracks how the anchor drift of the LAST relay evolves during training
# for each patchwork activation (same task/data as Test 5).
print(f"\n{'━'*80}")
print(f"TEST 8: Anchor Drift Under Training — does activation affect convergence?")
print(f" Same classification task as Test 5, track drift trajectory")
print(f"{'━'*80}")
print(f"\n {'pw_act':>14}", end="")
for step in [50, 100, 200, 300, 500]:
    print(f" d@{step:>3}", end="")
print(f" {'final_drift':>12} {'near_0.29':>10}")
for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_relu"]:
    torch.manual_seed(42)
    # Fresh model per activation: relay stack + linear head (as in Test 5)
    class DriftTracker(nn.Module):
        def __init__(self):
            super().__init__()
            self.relays = nn.ModuleList([
                ConstellationRelay(dim=D, pw_act=act_name)
                for _ in range(RELAY_DEPTH)])
            self.head = nn.Linear(D, N_CLASSES)
        def forward(self, x):
            for r in self.relays:
                x = r(x)
            return self.head(x)
    model = DriftTracker().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    drift_log = {}
    for step in range(TRAIN_STEPS):
        idx = torch.randint(0, len(train_x), (128,))
        logits = model(train_x[idx])
        loss = F.cross_entropy(logits, train_y[idx])
        if torch.isnan(loss) or torch.isinf(loss):
            print(f" ⚠ Bad loss at step {step}, act={act_name}")
            break
        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        # Snapshot the last relay's mean drift at checkpoint steps
        if (step + 1) in [50, 100, 200, 300, 500]:
            with torch.no_grad():
                d = model.relays[-1].drift().mean().item()
                drift_log[step + 1] = d
    # Final stats over ALL relays: how much of the drift mass sits near
    # the hypothesized binding constant 0.29154 (+/- 0.05 rad)?
    with torch.no_grad():
        all_drift = torch.cat([r.drift().flatten() for r in model.relays])
        near_029 = (all_drift - 0.29154).abs().lt(0.05).float().mean().item()
    print(f" {act_name:>14}", end="")
    for step in [50, 100, 200, 300, 500]:
        # FIX: .get() with NaN default — an early break on bad loss leaves
        # drift_log partially filled, and a plain index would KeyError here.
        print(f" {drift_log.get(step, float('nan')):>5.3f}", end="")
    print(f" {all_drift.mean().item():>12.6f} {near_029:>10.1%}")
# ══════════════════════════════════════════════════════════════════
# TEST 9: ACTIVATION GRADIENT MAGNITUDE THROUGH RELAY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'━'*80}")
print(f"TEST 9: Gradient Magnitude Through Relay Stack")
print(f" How does each activation affect gradient flow through {DEPTH} layers?")
print(f"{'━'*80}")
print(f"\n {'pw_act':>14} {'grad_in':>10} {'grad_out':>10} {'ratio':>10} "
      f"{'anchor_grad':>12} {'gate_grad':>12}")
for act_name in ["none", "relu", "gelu", "silu", "tanh", "squared_relu", "star_relu"]:
    torch.manual_seed(42)
    layers = nn.ModuleList([
        ConstellationRelay(dim=D, pw_act=act_name)
        for _ in range(DEPTH)
    ]).to(DEVICE)
    x = x_input.clone().requires_grad_(True)
    h = x
    for layer in layers:
        h = layer(h)
    # h is a non-leaf; retain its grad so grad_out can be read after backward
    h.retain_grad()
    loss = h.sum()
    loss.backward()
    # grad at the input vs grad at the output: ratio > 1 means the stack
    # amplifies gradients on the way back, < 1 means it attenuates them
    grad_in = x.grad.norm().item()
    anchor_grads = [l.anchors.grad.norm().item() for l in layers if l.anchors.grad is not None]
    gate_grads = [l.gate.grad.item() for l in layers if l.gate.grad is not None]
    grad_out = h.grad.norm().item() if h.grad is not None else 0
    print(f" {act_name:>14} {grad_in:>10.4f} {grad_out:>10.4f} "
          f"{grad_in / (grad_out + 1e-8):>10.4f} "
          f"{np.mean(anchor_grads) if anchor_grads else 0:>12.6f} "
          f"{np.mean(gate_grads) if gate_grads else 0:>12.6f}")
# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*80}")
print("EXPERIMENT COMPLETE")
print(f"{'='*80}")
print(f"""
9 tests covering:
T1: Patchwork activation β†’ geometric preservation at depth 16
T2: Pre-activation (before sphere) β†’ distribution distortion
T3: Post-activation (after patchwork) β†’ residual contribution
T4: Depth profile β†’ degradation curves per activation
T5: Trained relay β†’ classification accuracy + geometry
T6: Hybrid relay β†’ information retention + cross-token effect
T7: Hybrid trained β†’ cross-token routing task accuracy
T8: Drift trajectory β†’ activation effect on 0.29154 convergence
T9: Gradient flow β†’ magnitude through relay stack
Key questions answered:
β€’ Which activation preserves geometry best in the patchwork?
β€’ Does pre-activation (before sphere) help or hurt?
β€’ Does post-activation (after patchwork) affect the residual?
β€’ How fast does each activation degrade with depth?
β€’ Does activation choice affect the binding constant convergence?
β€’ Does activation choice affect cross-token routing in hybrids?
""")