# geolip-autograd-induction-experiments / experiment_1 /experiment_4_30shape_autograd.py
# AbstractPhil's picture
# Rename experiment_4_30shape_autograd.py to experiment_1/experiment_4_30shape_autograd.py
# f214af0 verified
# ============================================================================
# GEOMETRIC CONSTELLATION
#
# Pure abstract coordinate system on the unit hypersphere.
# No BERT. No semantics. No labels until assigned.
#
# Each anchor is a 4D local sphere (tangent frame at that point).
# The full constellation has 5D pentachoral structure.
# Conv4d ingests raw data → triangulation position → rigidity accumulation.
#
# The optimizer protects the constellation.
# The patchwork learns to navigate it.
# Rigidity crystallizes from the data itself.
# ============================================================================
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
# Compute device used by train(): CUDA when available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level side effect: selects PyTorch's 'high' float32 matmul precision mode.
torch.set_float32_matmul_precision('high')
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def cayley_menger_vol_sq(pts):
    """Squared 4-volume of pentachora via the Cayley-Menger determinant.

    Args:
        pts: tensor of shape (B, 5, D) holding 5 vertices per simplex.

    Returns:
        Tensor of shape (B,), float32, with the squared volume of each
        simplex. Runs under ``torch.no_grad`` so no graph is recorded.
    """
    points = pts.float()

    # Pairwise squared Euclidean distances between the vertices: (B, 5, 5).
    delta = points.unsqueeze(-2) - points.unsqueeze(-3)
    sq_dist = (delta * delta).sum(-1)

    # Bordered matrix: ones in the first row/column, a zero in the corner and
    # the squared distances in the lower-right 5x5 block.
    n_batch = points.shape[0]
    bordered = torch.ones(n_batch, 6, 6, device=points.device, dtype=torch.float32)
    bordered[:, 0, 0] = 0.0
    bordered[:, 1:, 1:] = sq_dist

    # For a 4-simplex: V^2 = -det(CM) / (2^4 * (4!)^2) = -det(CM) / 9216.
    determinant = torch.linalg.det(bordered)
    return -determinant / 9216.0
@torch.no_grad()
def pentachoron_cv(emb, n_samples=100):
    """Coefficient of variation of random pentachoron volumes in ``emb``.

    Draws ``n_samples`` random 5-point subsets of the rows of ``emb``,
    measures each subset's 4-volume with ``cayley_menger_vol_sq`` and returns
    std / mean over the strictly positive volumes.

    All sampled simplices are evaluated in a single batched determinant call
    instead of one call plus one ``.item()`` host sync per sample. The random
    subsets are drawn with the same sequence of ``torch.randperm`` calls as
    the per-sample loop; the statistic itself is now reduced on the tensor's
    device, so results may differ from a host-side reduction in the last few
    float32 bits.

    Args:
        emb: (B, D) tensor of points.
        n_samples: number of random 5-subsets to draw.

    Returns:
        A Python float. ``0.0`` is returned when there are fewer than 5
        points, when ``n_samples`` is not positive, or when fewer than 10
        sampled simplices have a positive volume.
    """
    n_points = emb.shape[0]
    if n_points < 5 or n_samples <= 0:
        return 0.0

    emb_f = emb.detach().float()

    # (S, 5) indices: one random 5-subset per sample.
    picks = torch.stack(
        [torch.randperm(n_points, device=emb.device)[:5] for _ in range(n_samples)]
    )

    # (S, 5, D) -> (S,) squared volumes. Negative values are numerical noise
    # from near-degenerate simplices and are clamped to zero.
    vol_sq = cayley_menger_vol_sq(emb_f[picks])
    vols = vol_sq.clamp_min(0.0).sqrt()

    # Keep only strictly positive volumes (this also drops NaNs).
    vols = vols[vols > 0]
    if vols.numel() < 10:
        return 0.0
    return float(vols.std() / (vols.mean() + 1e-8))
def tangential_projection(grad, embedding):
    """Split ``grad`` into components tangent and radial to ``embedding``.

    The embedding is detached and L2-normalised along the last dimension; the
    radial part is the projection of ``grad`` onto that unit direction and the
    tangential part is the remainder. Arithmetic is done in float32 and both
    results are cast back to ``grad``'s dtype.

    Returns:
        ``(tangential, radial)`` tensors shaped like ``grad``.
    """
    unit_dir = F.normalize(embedding.detach().float(), dim=-1)
    grad32 = grad.float()

    along = (grad32 * unit_dir).sum(dim=-1, keepdim=True)
    radial_part = along * unit_dir
    tangent_part = grad32 - radial_part

    out_dtype = grad.dtype
    return tangent_part.to(out_dtype), radial_part.to(out_dtype)
class TangentialGradFn(torch.autograd.Function):
    """Identity in the forward pass; gates the radial gradient in backward.

    ``forward(x, emb, gate)`` returns ``x`` unchanged. In ``backward`` the
    incoming gradient is split by ``tangential_projection`` relative to
    ``emb``; the tangential part passes through untouched and the radial part
    is scaled by ``gate``. No gradient is returned for ``emb`` or ``gate``.
    """

    @staticmethod
    def forward(ctx, x, emb, gate):
        # Stash what backward needs: the gate scalar and the reference embedding.
        ctx.gate = gate
        ctx.save_for_backward(emb)
        return x

    @staticmethod
    def backward(ctx, grad):
        (reference,) = ctx.saved_tensors
        tangent, radial = tangential_projection(grad, reference)
        gated = tangent + ctx.gate * radial
        return gated, None, None
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION: 30 abstract anchors, each a 4D local sphere
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """N anchor points on the D-dimensional unit hypersphere.

    Each anchor carries a ``d_local``-dimensional local frame: an orthonormal
    set of vectors tangent to the sphere at that anchor. Embeddings are
    described both globally (cosine distance to every anchor) and locally
    (coordinates of the tangential displacement in each anchor's frame).

    Anchors are initialised by L2-normalising standard-normal samples, so in
    high dimension their pairwise cosines are expected to be close to zero
    (on the order of 1/sqrt(D)).

    Parameters:
        anchors:      (n_anchors, d_embed) anchor positions.
        local_frames: (n_anchors, d_local, d_embed) tangent frames.

    Buffers (all of shape (n_anchors,)):
        rigidity:    1 / (local_cv + 0.01), refreshed by ``update_rigidity``.
        visit_count: number of samples accumulated per anchor.
        local_cv:    running average of the per-anchor pentachoron CV.
    """

    def __init__(self, n_anchors=30, d_embed=768, d_local=4):
        super().__init__()
        self.n_anchors = n_anchors
        self.d_embed = d_embed
        self.d_local = d_local

        # Anchor positions: Gaussian samples projected onto the unit sphere.
        anchor_init = torch.randn(n_anchors, d_embed)
        anchor_init = F.normalize(anchor_init, dim=-1)
        self.anchors = nn.Parameter(anchor_init)

        # Local frames start as Gaussian samples and are made tangent and
        # orthonormal at the end of __init__.
        frames = torch.randn(n_anchors, d_local, d_embed)
        self.local_frames = nn.Parameter(frames)

        # Running per-anchor statistics.
        self.register_buffer("rigidity", torch.zeros(n_anchors))
        self.register_buffer("visit_count", torch.zeros(n_anchors))
        self.register_buffer("local_cv", torch.zeros(n_anchors))

        # Establish the documented invariant (orthonormal tangent frames) at
        # construction time rather than relying on the caller to do it.
        self.orthogonalize_frames()

    def orthogonalize_frames(self):
        """Project frames tangential to the sphere and orthonormalise them.

        Works in place on ``local_frames.data`` under ``torch.no_grad``. All
        anchors are processed together; Gram-Schmidt runs sequentially over
        the ``d_local`` frame vectors only.
        """
        with torch.no_grad():
            anchors_n = F.normalize(self.anchors.data, dim=-1)      # (N, D)
            frames = self.local_frames.data                         # (N, K, D)

            # Remove each frame vector's component along its own anchor.
            along = torch.einsum("nkd,nd->nk", frames, anchors_n)   # (N, K)
            frames = frames - along.unsqueeze(-1) * anchors_n.unsqueeze(1)

            # Gram-Schmidt: subtract projections onto the vectors already
            # accepted, then normalise.
            basis = []
            for j in range(self.d_local):
                v = frames[:, j]                                    # (N, D)
                for u in basis:
                    v = v - (v * u).sum(dim=-1, keepdim=True) * u
                basis.append(F.normalize(v, dim=-1))

            self.local_frames.data.copy_(torch.stack(basis, dim=1))

    def triangulate(self, emb):
        """Compute global and local coordinates of embeddings.

        Args:
            emb: (B, D) embeddings; the caller is expected to pass
                L2-normalised vectors.

        Returns:
            tri_coords:   (B, N) cosine distance ``1 - cos`` to each anchor.
            local_coords: (B, N, d_local) coordinates of the displacement
                ``emb - cos * anchor`` in each anchor's local frame.
            nearest:      (B,) index of the anchor with the largest cosine.
        """
        anchors_n = F.normalize(self.anchors, dim=-1)               # (N, D)

        # Global triangulation: cosine to every anchor.
        cos_sim = emb @ anchors_n.T                                 # (B, N)
        tri_coords = 1.0 - cos_sim

        # Local coordinates. By linearity
        #   (emb - cos * a_i) . f_ik = emb . f_ik - cos * (a_i . f_ik),
        # which avoids a Python loop over anchors and never materialises a
        # (B, N, D) displacement tensor.
        frames = self.local_frames                                  # (N, K, D)
        emb_in_frames = torch.einsum("bd,nkd->bnk", emb, frames)    # (B, N, K)
        anchor_in_frames = torch.einsum("nd,nkd->nk", anchors_n, frames)
        local_coords = emb_in_frames - cos_sim.unsqueeze(-1) * anchor_in_frames.unsqueeze(0)

        nearest = cos_sim.argmax(dim=-1)                            # (B,)
        return tri_coords, local_coords, nearest

    @torch.no_grad()
    def update_rigidity(self, emb, labels):
        """Accumulate per-anchor rigidity statistics from a batch.

        For every anchor index ``i`` with at least 5 samples labelled ``i``,
        the pentachoron CV of those samples is folded into ``local_cv[i]`` by
        an exponential moving average and ``rigidity[i]`` is refreshed as
        ``1 / (local_cv[i] + 0.01)``.

        Args:
            emb: (B, D) embeddings.
            labels: (B,) integer labels, compared against anchor indices.
        """
        for i in range(self.n_anchors):
            mask = labels == i
            count = mask.sum()
            if count < 5:
                # Too few points to form a pentachoron.
                continue
            cluster = emb[mask]
            self.visit_count[i] += count.float()

            # Local CV: pentachoron regularity within this cluster.
            cv = pentachoron_cv(cluster, n_samples=50)

            # Moving-average weight, capped at 0.1 and shrinking with visits.
            alpha = min(0.1, 10.0 / (self.visit_count[i] + 1))
            self.local_cv[i] = (1 - alpha) * self.local_cv[i] + alpha * cv

            # Lower CV (more regular) means higher rigidity.
            self.rigidity[i] = 1.0 / (self.local_cv[i] + 0.01)

    def constellation_health(self):
        """Summarise the anchor geometry as a dict of Python floats.

        Keys: ``mean_cos``, ``std_cos``, ``min_cos``, ``max_cos`` (statistics
        of the off-diagonal pairwise anchor cosines), ``cv`` (pentachoron CV
        of the anchors, 200 samples), ``mean_rigidity`` and ``max_rigidity``.
        """
        anchors_n = F.normalize(self.anchors.detach(), dim=-1)
        cos = anchors_n @ anchors_n.T
        off_diag_mask = ~torch.eye(self.n_anchors, dtype=torch.bool, device=anchors_n.device)
        off_diag = cos[off_diag_mask]
        return {
            "mean_cos": off_diag.mean().item(),
            "std_cos": off_diag.std().item(),
            "min_cos": off_diag.min().item(),
            "max_cos": off_diag.max().item(),
            "cv": pentachoron_cv(anchors_n, n_samples=200),
            "mean_rigidity": self.rigidity.mean().item(),
            "max_rigidity": self.rigidity.max().item(),
        }
# ══════════════════════════════════════════════════════════════════
# CONV4D INGESTION
# ══════════════════════════════════════════════════════════════════
class Conv4dBlock(nn.Module):
    """Convolutional reader for per-anchor local coordinates.

    The (B, N, d_local) coordinate tensor is viewed as a one-channel 2D image
    of height N (anchors) and width d_local (frame dimensions), passed through
    three Conv2d + GELU stages, globally average-pooled and projected to
    ``d_out`` features. ``n_anchors`` and ``d_local`` are accepted for
    signature symmetry but are not used to size any layer.
    """

    def __init__(self, n_anchors=30, d_local=4, d_out=256):
        super().__init__()
        stages = [
            nn.Conv2d(1, 32, (3, 3), padding=(1, 1)),
            nn.GELU(),
            nn.Conv2d(32, 64, (3, 3), padding=(1, 1)),
            nn.GELU(),
            # Last stage mixes across anchors only (kernel width 1).
            nn.Conv2d(64, 128, (3, 1), padding=(1, 0)),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.conv = nn.Sequential(*stages)
        self.proj = nn.Linear(128, d_out)

    def forward(self, local_coords):
        """Map (B, N, d_local) local coordinates to (B, d_out) features."""
        as_image = local_coords.unsqueeze(dim=1)        # (B, 1, N, d_local)
        pooled = self.conv(as_image)                    # (B, 128, 1, 1)
        flat = pooled.flatten(start_dim=1)              # (B, 128)
        return self.proj(flat)                          # (B, d_out)
# ══════════════════════════════════════════════════════════════════
# FULL MODEL
# ══════════════════════════════════════════════════════════════════
class ConstellationClassifier(nn.Module):
    """Image classifier that reads embeddings through a Constellation.

    Pipeline: conv backbone -> linear + LayerNorm -> L2-normalised embedding
    -> triangulation against the constellation -> two feature paths
    (``Conv4dBlock`` on local coordinates, an MLP on the global distances)
    -> concatenation -> MLP classifier producing ``n_classes`` logits.
    """

    def __init__(self, n_classes=30, n_anchors=30, d_embed=768,
                 d_local=4, d_hidden=256):
        super().__init__()
        self.n_classes = n_classes

        # Small convolutional backbone for 1-channel images, ending in a
        # global average pool so the output is 128 features per image.
        backbone_layers = [
            nn.Conv2d(1, 32, 3, padding=1),
            nn.GELU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.GELU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.backbone = nn.Sequential(*backbone_layers)

        # Projection into the embedding space (normalised in forward()).
        self.embed_proj = nn.Sequential(
            nn.Linear(128, d_embed),
            nn.LayerNorm(d_embed),
        )

        # Anchor coordinate system.
        self.constellation = Constellation(n_anchors, d_embed, d_local)

        # Local path: convolution over per-anchor local coordinates.
        self.conv4d = Conv4dBlock(n_anchors, d_local, d_hidden)

        # Global path: MLP over the distances to all anchors.
        self.tri_proj = nn.Sequential(
            nn.Linear(n_anchors, d_hidden),
            nn.GELU(),
            nn.LayerNorm(d_hidden),
        )

        # Classification head over the concatenated local + global features.
        self.classifier = nn.Sequential(
            nn.Linear(d_hidden * 2, d_hidden),
            nn.GELU(),
            nn.LayerNorm(d_hidden),
            nn.Linear(d_hidden, n_classes),
        )

    def forward(self, x):
        """Classify a batch of single-channel images.

        Args:
            x: (B, 1, H, W) images (the training script uses 32x32).

        Returns:
            logits:       (B, n_classes)
            emb:          (B, d_embed) L2-normalised embedding
            tri_coords:   (B, n_anchors) distances to the anchors
            local_coords: (B, n_anchors, d_local) local frame projections
            nearest:      (B,) nearest anchor index
        """
        pooled = torch.flatten(self.backbone(x), 1)
        emb = F.normalize(self.embed_proj(pooled), dim=-1)

        tri_coords, local_coords, nearest = self.constellation.triangulate(emb)

        # Local features first, then global features, joined on the last axis.
        fused = torch.cat(
            (self.conv4d(local_coords), self.tri_proj(tri_coords)),
            dim=-1,
        )
        logits = self.classifier(fused)
        return logits, emb, tri_coords, local_coords, nearest
# ══════════════════════════════════════════════════════════════════
# SHAPE RENDERERS (reuse from trainer)
# ══════════════════════════════════════════════════════════════════
def _draw(img, x0, y0, x1, y1, t=1):
n = max(int(max(abs(x1-x0), abs(y1-y0))*2), 1); sz = img.shape[0]
for s in np.linspace(0, 1, n):
px, py = int(x0+s*(x1-x0)), int(y0+s*(y1-y0))
for dx in range(-t, t+1):
for dy in range(-t, t+1):
nx, ny = px+dx, py+dy
if 0 <= nx < sz and 0 <= ny < sz: img[ny, nx] = 1.0
def render_poly(nv, sz=32, p=0.15):
    """Render the outline of a randomly rotated, radially jittered polygon.

    Args:
        nv: number of vertices.
        sz: side length of the square canvas.
        p: standard deviation of the relative per-vertex radius noise.

    Returns:
        (sz, sz) float32 array with outline pixels set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2
    base_radius = sz * 0.35

    # Random draws happen in this order: rotation first, then the radii.
    angles = np.linspace(0, 2 * np.pi, nv, endpoint=False) + np.random.uniform(0, 2 * np.pi)
    radii = base_radius * (1 + np.random.normal(0, p, nv))

    vertices = [
        (cx + rad * np.cos(ang), cy + rad * np.sin(ang))
        for ang, rad in zip(angles, radii)
    ]
    # Connect every vertex to its successor, closing the loop at the end.
    for start, end in zip(vertices, vertices[1:] + vertices[:1]):
        _draw(canvas, *start, *end)
    return canvas
def render_star(np_, sz=32, p=0.12):
    """Render the outline of a star with ``np_`` points.

    Vertices alternate between an outer radius (0.38 * sz) and an inner
    radius (0.15 * sz). The x and y coordinates of each vertex receive
    independent relative radius noise with standard deviation ``p``.

    Returns:
        (sz, sz) float32 array with outline pixels set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2
    outer_r, inner_r = sz * 0.38, sz * 0.15
    angles = np.linspace(0, 2 * np.pi, np_ * 2, endpoint=False) + np.random.uniform(0, 2 * np.pi)

    vertices = []
    for k, ang in enumerate(angles):
        base = outer_r if k % 2 == 0 else inner_r
        # Two separate noise draws per vertex: x first, then y.
        vx = cx + base * (1 + np.random.normal(0, p)) * np.cos(ang)
        vy = cy + base * (1 + np.random.normal(0, p)) * np.sin(ang)
        vertices.append((vx, vy))

    for start, end in zip(vertices, vertices[1:] + vertices[:1]):
        _draw(canvas, *start, *end)
    return canvas
def render_cross(sz=32, p=0.15):
    """Render a four-armed cross with jittered arm angles and lengths.

    Each arm starts at the canvas centre, points roughly along one of the
    four axis directions and is drawn with thickness parameter 2.

    Returns:
        (sz, sz) float32 array with the arms set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2
    arm = sz * 0.3

    for base_angle in (0, np.pi / 2, np.pi, 3 * np.pi / 2):
        # Angle noise is drawn before length noise for every arm.
        angle = base_angle + np.random.normal(0, p * 0.3)
        reach = arm * (1 + np.random.normal(0, p))
        tip_x = cx + reach * np.cos(angle)
        tip_y = cy + reach * np.sin(angle)
        _draw(canvas, cx, cy, tip_x, tip_y, 2)
    return canvas
def render_spiral(sz=32, p=0.1):
    """Render a noisy Archimedean-style spiral as isolated pixels.

    The radius grows linearly with the angle over 2.5 turns (200 samples),
    with per-sample relative noise of standard deviation ``0.3 * p``.

    Returns:
        (sz, sz) float32 array with the sampled pixels set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2
    growth = sz * 0.015

    for theta in np.linspace(0, 5 * np.pi, 200):
        radius = growth * theta * (1 + np.random.normal(0, p * 0.3))
        col = int(cx + radius * np.cos(theta))
        row = int(cy + radius * np.sin(theta))
        if not (0 <= col < sz and 0 <= row < sz):
            continue
        canvas[row, col] = 1.0
    return canvas
def render_wave(sz=32, p=0.1):
    """Render one pixel per column along a sine wave.

    The frequency is about 2 cycles across the canvas (noise std 0.3) and
    the amplitude about 0.15 * sz (relative noise std ``p``).

    Returns:
        (sz, sz) float32 array with the wave pixels set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    # Frequency is drawn before amplitude.
    freq = 2 + np.random.normal(0, 0.3)
    amplitude = sz * 0.15 * (1 + np.random.normal(0, p))

    for col in range(sz):
        row = int(sz / 2 + amplitude * np.sin(2 * np.pi * freq * col / sz))
        if row < 0 or row >= sz:
            continue
        canvas[row, col] = 1.0
    return canvas
def render_heart(sz=32, p=0.1):
    """Render a parametric heart curve as isolated pixels.

    Uses the curve x = 16 sin^3(t),
    y = -(13 cos t - 5 cos 2t - 2 cos 3t - cos 4t), sampled at 300 values of
    t, scaled by about 0.017 * sz (relative noise std ``p``) and centred at
    (sz / 2, 0.45 * sz).

    Returns:
        (sz, sz) float32 array with the sampled pixels set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz * 0.45
    scale = sz * 0.017 * (1 + np.random.normal(0, p))

    for t in np.linspace(0, 2 * np.pi, 300):
        hx = 16 * np.sin(t) ** 3
        hy = -(13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t))
        col = int(cx + hx * scale)
        row = int(cy + hy * scale)
        if 0 <= col < sz and 0 <= row < sz:
            canvas[row, col] = 1.0
    return canvas
def render_crescent(sz=32, p=0.1):
    """Render the outline of a crescent (a disc minus an offset inner disc).

    The outer circle has radius r = 0.35 * sz around the canvas centre; the
    inner circle has radius 0.7 * r and is shifted right by 0.3 * r, so it
    touches the outer circle from the inside at angle 0.

    Previously only the outer circle was sampled, and because of that
    tangency its rejection test (distance to the inner centre >= 0.9 * inner
    radius) never rejected a point, so the result was a plain full circle.
    The inner arc is now drawn as well, which gives the crescent outline.

    Args:
        sz: side length of the square canvas.
        p: accepted for signature parity with the other renderers; this
            shape is deterministic and does not consume random numbers.

    Returns:
        (sz, sz) float32 array with outline pixels set to 1.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx, cy, r = sz / 2, sz / 2, sz * 0.35
    r2 = r * 0.7
    off = r * 0.3

    def _plot(x, y):
        ix, iy = int(x), int(y)
        if 0 <= ix < sz and 0 <= iy < sz:
            img[iy, ix] = 1.0

    for a in np.linspace(0, 2 * np.pi, 300):
        cos_a, sin_a = np.cos(a), np.sin(a)

        # Outer arc: keep points that are not inside the inner disc.
        x1, y1 = cx + r * cos_a, cy + r * sin_a
        if math.sqrt((x1 - cx - off) ** 2 + (y1 - cy) ** 2) >= r2 * 0.9:
            _plot(x1, y1)

        # Inner arc: keep points that lie within the outer disc (10% slack
        # so the tangent point is not lost to rounding).
        x2, y2 = cx + off + r2 * cos_a, cy + r2 * sin_a
        if math.sqrt((x2 - cx) ** 2 + (y2 - cy) ** 2) <= r * 1.1:
            _plot(x2, y2)
    return img
def render_ellipse(sz=32, p=0.1):
    """Render a randomly rotated ellipse outline as isolated pixels.

    Semi-axes are about 0.38 * sz and 0.22 * sz (relative noise std ``p``
    each); the rotation is uniform in [0, pi). 200 points are sampled.

    Returns:
        (sz, sz) float32 array with the sampled pixels set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2

    # Draw order: major axis, minor axis, rotation.
    semi_major = sz * 0.38 * (1 + np.random.normal(0, p))
    semi_minor = sz * 0.22 * (1 + np.random.normal(0, p))
    rot = np.random.uniform(0, np.pi)
    cos_r, sin_r = np.cos(rot), np.sin(rot)

    for t in np.linspace(0, 2 * np.pi, 200):
        ex = semi_major * np.cos(t)
        ey = semi_minor * np.sin(t)
        col = int(cx + ex * cos_r - ey * sin_r)
        row = int(cy + ex * sin_r + ey * cos_r)
        if 0 <= col < sz and 0 <= row < sz:
            canvas[row, col] = 1.0
    return canvas
def render_ring(sz=32, p=0.1):
    """Render two concentric circles as isolated pixels.

    Radii are about 0.35 * sz and 0.22 * sz, each with relative noise of
    standard deviation ``p``; 300 angles are sampled per circle.

    Returns:
        (sz, sz) float32 array with the sampled pixels set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2

    # Outer radius is drawn before the inner one.
    outer = sz * 0.35 * (1 + np.random.normal(0, p))
    inner = sz * 0.22 * (1 + np.random.normal(0, p))

    for ang in np.linspace(0, 2 * np.pi, 300):
        cos_a, sin_a = np.cos(ang), np.sin(ang)
        for radius in (outer, inner):
            col = int(cx + radius * cos_a)
            row = int(cy + radius * sin_a)
            if 0 <= col < sz and 0 <= row < sz:
                canvas[row, col] = 1.0
    return canvas
def render_arrow(sz=32, p=0.12):
    """Render a randomly oriented arrow: a shaft plus two head strokes.

    The shaft runs through the canvas centre with half-length about
    0.35 * sz (relative noise std ``p``); the head strokes are 0.35 times
    that length and are angled +/-0.7 rad off the shaft direction.

    Returns:
        (sz, sz) float32 array with the strokes set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2

    # Length is drawn before heading.
    half_len = sz * 0.35 * (1 + np.random.normal(0, p))
    head_len = half_len * 0.35
    heading = np.random.uniform(0, 2 * np.pi)

    tail_x, tail_y = cx - half_len * np.cos(heading), cy - half_len * np.sin(heading)
    tip_x, tip_y = cx + half_len * np.cos(heading), cy + half_len * np.sin(heading)
    _draw(canvas, tail_x, tail_y, tip_x, tip_y)

    for spread in (0.7, -0.7):
        barb_x = tip_x - head_len * np.cos(heading + spread)
        barb_y = tip_y - head_len * np.sin(heading + spread)
        _draw(canvas, tip_x, tip_y, barb_x, barb_y)
    return canvas
def render_chevron(sz=32, p=0.12):
    """Render an upward-pointing chevron made of two strokes.

    Half-width is about 0.3 * sz and half-height about 0.25 * sz, each with
    relative noise of standard deviation ``p``.

    Returns:
        (sz, sz) float32 array with the strokes set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2

    # Width is drawn before height.
    half_w = sz * 0.3 * (1 + np.random.normal(0, p))
    half_h = sz * 0.25 * (1 + np.random.normal(0, p))

    apex = (cx, cy - half_h)
    _draw(canvas, cx - half_w, cy + half_h, *apex)
    _draw(canvas, *apex, cx + half_w, cy + half_h)
    return canvas
def render_semicircle(sz=32, p=0.1):
    """Render a half circle with its diameter drawn as a straight stroke.

    The arc has radius 0.35 * sz around (sz / 2, 0.6 * sz) and covers the
    angles pi..2*pi (150 samples). ``p`` is accepted but not used; the shape
    is deterministic.

    Returns:
        (sz, sz) float32 array with the outline set to 1.
    """
    canvas = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz * 0.6
    radius = sz * 0.35

    for ang in np.linspace(np.pi, 2 * np.pi, 150):
        col = int(cx + radius * np.cos(ang))
        row = int(cy + radius * np.sin(ang))
        if 0 <= col < sz and 0 <= row < sz:
            canvas[row, col] = 1.0

    # Close the shape along its diameter.
    _draw(canvas, cx - radius, cy, cx + radius, cy)
    return canvas
# Display names of the 30 shape classes, in class-index order (index c is the
# class rendered by gen_one(c)); train() uses them to label per-anchor rigidity.
# NOTE(review): gen_one renders "square", "diamond", "trapezoid",
# "parallelogram" and "rhombus" all through render_poly(4, ...), differing
# only in the jitter amount - confirm that these are meant to be separate classes.
SHAPE_NAMES = [
    "triangle","square","pentagon","hexagon","heptagon",
    "octagon","nonagon","decagon","dodecagon",
    "circle","ellipse","spiral","wave","crescent",
    "star3","star4","star5","star6","star7","star8",
    "cross","diamond","arrow","heart","ring",
    "semicircle","trapezoid","parallelogram","rhombus","chevron",
]
def gen_one(c, sz=32):
    """Render one random sample image for class index ``c``.

    Args:
        c: class index; 0..29 select the shapes listed below in order.
        sz: side length of the square canvas.

    Returns:
        (sz, sz) float32 array. Any ``c`` that matches no table entry falls
        back to a triangle rendered with jitter 0.15.
    """
    # One row per class index: (renderer, leading positional args, jitter).
    # Each renderer is called as renderer(*leading, sz, jitter).
    table = (
        (render_poly, (3,), 0.20),      # 0
        (render_poly, (4,), 0.12),      # 1
        (render_poly, (5,), 0.15),      # 2
        (render_poly, (6,), 0.10),      # 3
        (render_poly, (7,), 0.10),      # 4
        (render_poly, (8,), 0.08),      # 5
        (render_poly, (9,), 0.08),      # 6
        (render_poly, (10,), 0.07),     # 7
        (render_poly, (12,), 0.06),     # 8
        (render_poly, (32,), 0.03),     # 9
        (render_ellipse, (), 0.10),     # 10
        (render_spiral, (), 0.10),      # 11
        (render_wave, (), 0.10),        # 12
        (render_crescent, (), 0.10),    # 13
        (render_star, (3,), 0.12),      # 14
        (render_star, (4,), 0.12),      # 15
        (render_star, (5,), 0.12),      # 16
        (render_star, (6,), 0.10),      # 17
        (render_star, (7,), 0.10),      # 18
        (render_star, (8,), 0.08),      # 19
        (render_cross, (), 0.15),       # 20
        (render_poly, (4,), 0.10),      # 21
        (render_arrow, (), 0.12),       # 22
        (render_heart, (), 0.10),       # 23
        (render_ring, (), 0.10),        # 24
        (render_semicircle, (), 0.10),  # 25
        (render_poly, (4,), 0.15),      # 26
        (render_poly, (4,), 0.18),      # 27
        (render_poly, (4,), 0.10),      # 28
        (render_chevron, (), 0.12),     # 29
    )
    # Match with == in ascending order, exactly like an if-chain would.
    for key, (renderer, leading, jitter) in enumerate(table):
        if c == key:
            return renderer(*leading, sz, jitter)
    return render_poly(3, sz, 0.15)
def gen_dataset(n_per=500, sz=32):
    """Generate a shuffled synthetic dataset of the 30 shape classes.

    Samples are rendered round-robin (one image of every class per pass,
    ``n_per`` passes) and then shuffled with ``torch.randperm``.

    Returns:
        imgs:   tensor of shape (30 * n_per, 1, sz, sz)
        labels: int64 tensor of shape (30 * n_per,)
    """
    samples = [
        (gen_one(cls, sz), cls)
        for _ in range(n_per)
        for cls in range(30)
    ]
    pictures = [picture for picture, _ in samples]
    targets = [cls for _, cls in samples]

    # Stack to one array, then add the channel axis.
    image_t = torch.tensor(np.array(pictures)).unsqueeze(1)
    label_t = torch.tensor(targets, dtype=torch.long)

    order = torch.randperm(len(label_t))
    return image_t[order], label_t[order]
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
def train():
    """Train a ConstellationClassifier on synthetic shapes and print a report.

    Steps:
      1. Seed torch and numpy, generate train/val sets with ``gen_dataset``
         and move them to ``DEVICE``.
      2. Build the model, wrap it with ``torch.compile`` and print parameter
         counts and the initial constellation health.
      3. For 40 epochs of batch size 256: re-orthogonalise the local frames
         every 5 epochs; every 10th batch re-measure the embedding CV and
         update the radial gradient gate; classify through
         ``TangentialGradFn`` so the gate shapes the embedding gradient;
         update the per-anchor rigidity statistics.
      4. After every epoch evaluate on the validation set; print a summary
         line on the first and on every fifth epoch.
      5. Print the final constellation health and a per-anchor rigidity bar
         chart.

    Returns nothing; all results go to stdout.
    """
    torch.manual_seed(42)
    np.random.seed(42)

    print(f"\n{'='*65}")
    print("CONSTELLATION CLASSIFIER: Pure Geometric Coordinates")
    print(f"{'='*65}")
    print(f" Device: {DEVICE}")

    print(f"\n Generating data...")
    train_imgs, train_labels = gen_dataset(n_per=500)
    val_imgs, val_labels = gen_dataset(n_per=100)
    train_imgs, train_labels = train_imgs.to(DEVICE), train_labels.to(DEVICE)
    val_imgs, val_labels = val_imgs.to(DEVICE), val_labels.to(DEVICE)
    n_train, n_val = len(train_labels), len(val_labels)
    print(f" Train: {n_train:,} Val: {n_val:,} Classes: 30")

    model = ConstellationClassifier(
        n_classes=30, n_anchors=30, d_embed=768,
        d_local=4, d_hidden=256,
    ).to(DEVICE)
    model = torch.compile(model, mode="default")

    n_params = sum(p.numel() for p in model.parameters())
    n_constellation = sum(p.numel() for p in model.constellation.parameters())
    print(f" Total params: {n_params:,}")
    print(f" Constellation params: {n_constellation:,}")

    # Check initial constellation health
    health = model.constellation.constellation_health()
    print(f"\n Initial constellation:")
    print(f" Mean cos: {health['mean_cos']:.4f} (want ≈0)")
    print(f" Std cos: {health['std_cos']:.4f}")
    print(f" Min cos: {health['min_cos']:.4f}")
    print(f" Max cos: {health['max_cos']:.4f}")
    print(f" CV: {health['cv']:.4f}")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    BATCH, EPOCHS = 256, 40

    # CV gate state: the first positive CV measurement becomes the target.
    cv_target = None
    gate_val = 0.0

    # Class-index groups for the per-type validation accuracy (loop-invariant).
    type_groups = {"polygon": list(range(9)), "curve": list(range(9, 14)),
                   "star": list(range(14, 20)), "structure": list(range(20, 30))}

    for epoch in range(EPOCHS):
        model.train()

        # Orthogonalize frames periodically
        if epoch % 5 == 0:
            model.constellation.orthogonalize_frames()

        perm = torch.randperm(n_train, device=DEVICE)
        total_correct, n = 0, 0
        for i in range(0, n_train, BATCH):
            idx = perm[i:i+BATCH]
            if len(idx) < 4:
                continue

            # Only the embedding of this pass is used; logits are recomputed
            # below through the gated embedding.
            _, emb, _, _, _ = model(train_imgs[idx])
            labels = train_labels[idx]

            # CV gate: re-measured every 10th batch of the epoch.
            if n % 10 == 0:
                cv_now = pentachoron_cv(emb, n_samples=30)
                if cv_target is None and cv_now > 0:
                    cv_target = cv_now
                if cv_target:
                    delta = cv_now - cv_target
                    if abs(delta) <= 0.02:
                        # Within tolerance of the target: block radial gradient.
                        gate_val = 0.0
                    elif delta < 0:
                        gate_val = min(abs(delta)/(cv_target+1e-8), 1.0) * 0.3
                    else:
                        gate_val = max(0.0, 0.1*(1-min(delta/(cv_target+1e-8),1.0)))

            # Apply tangential gate
            emb_gated = TangentialGradFn.apply(emb, emb, gate_val)

            # Recompute logits through gated embedding
            tri_g, local_g, _ = model.constellation.triangulate(emb_gated)
            local_feat = model.conv4d(local_g)
            global_feat = model.tri_proj(tri_g)
            logits = model.classifier(torch.cat([local_feat, global_feat], dim=-1))

            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            # Update rigidity
            model.constellation.update_rigidity(emb.detach(), labels)

            total_correct += (logits.argmax(-1) == labels).sum().item()
            n += 1

        train_acc = total_correct / n_train

        # Validation
        model.eval()
        with torch.no_grad():
            v_logits, v_emb, _, _, _ = model(val_imgs)
            v_pred = v_logits.argmax(-1)
            v_acc = (v_pred == val_labels).float().mean().item()
            v_cv = pentachoron_cv(v_emb, n_samples=100)
        health = model.constellation.constellation_health()

        # Per-type accuracy
        ta = {}
        for tname, tids in type_groups.items():
            tmask = torch.zeros(n_val, dtype=torch.bool, device=DEVICE)
            for tid in tids:
                tmask |= (val_labels == tid)
            if tmask.sum() > 0:
                ta[tname] = (v_pred[tmask] == val_labels[tmask]).float().mean().item()

        if (epoch + 1) % 5 == 0 or epoch == 0:
            ta_str = " ".join(f"{t}={a:.2f}" for t, a in ta.items())
            rig = model.constellation.rigidity
            print(f" E{epoch+1:2d}: t_acc={train_acc:.3f} v_acc={v_acc:.3f} "
                  f"cv={v_cv:.4f} gate={gate_val:.3f} "
                  f"cos={health['mean_cos']:.4f} "
                  f"rig={rig.mean():.1f}/{rig.max():.1f} "
                  f"[{ta_str}]")

    # Final constellation state
    health = model.constellation.constellation_health()
    print(f"\n Final constellation:")
    print(f" Mean cos: {health['mean_cos']:.4f}")
    print(f" CV: {health['cv']:.4f}")
    print(f" Rigidity: mean={health['mean_rigidity']:.1f} max={health['max_rigidity']:.1f}")

    # Rigidity per class: one bar block per whole unit of rigidity.
    print(f"\n Per-anchor rigidity:")
    rig = model.constellation.rigidity.cpu()
    for name, value in zip(SHAPE_NAMES, rig.tolist()):
        bar = "█" * int(value)
        print(f" {name:15s}: {value:.1f} {bar}")

    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
# Run the training experiment only when this file is executed as a script.
if __name__ == "__main__":
    train()