# History (f214af0): renamed experiment_4_30shape_autograd.py to
# experiment_1/experiment_4_30shape_autograd.py
# ============================================================================
# GEOMETRIC CONSTELLATION
#
# Pure abstract coordinate system on the unit hypersphere.
# No BERT. No semantics. No labels until assigned.
#
# Each anchor is a 4D local sphere (tangent frame at that point).
# The full constellation has 5D pentachoral structure.
# Conv4d ingests raw data → triangulation position → rigidity accumulation.
#
# The optimizer protects the constellation.
# The patchwork learns to navigate it.
# Rigidity crystallizes from the data itself.
# ============================================================================
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from dataclasses import dataclass | |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Allow faster reduced-precision internal algorithms (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')
# ──────────────────────────────────────────────────────────────────
# GEOMETRIC PRIMITIVES
# ──────────────────────────────────────────────────────────────────
| def cayley_menger_vol_sq(pts): | |
| """Pentachoron volumeΒ² from 5 points. (B, 5, D) β (B,)""" | |
| pts = pts.float() | |
| diff = pts.unsqueeze(-2) - pts.unsqueeze(-3) | |
| d2 = (diff * diff).sum(-1) | |
| B = pts.shape[0] | |
| cm = torch.zeros(B, 6, 6, device=pts.device, dtype=torch.float32) | |
| cm[:, 0, 1:] = 1.0; cm[:, 1:, 0] = 1.0; cm[:, 1:, 1:] = d2 | |
| return -torch.linalg.det(cm) / 9216.0 | |
def pentachoron_cv(emb, n_samples=100):
    """Coefficient of variation of pentachoron volumes over random 5-point subsets.

    Args:
        emb: (B, D) points; detached and cast to float32 internally.
        n_samples: number of random 5-point subsets to draw.

    Returns:
        std / mean (torch's default unbiased std) of the strictly positive
        sampled volumes as a Python float. Returns 0.0 when B < 5,
        n_samples <= 0, or fewer than 10 positive volumes were obtained -- so
        0.0 means "not measurable", not "perfectly regular".
    """
    B = emb.shape[0]
    if B < 5 or n_samples <= 0:
        return 0.0
    emb_f = emb.detach().float()
    # Draw every subset up front, then evaluate all determinants in one batched
    # call: a single host sync instead of one .item() round-trip per sample.
    idx = torch.stack(
        [torch.randperm(B, device=emb.device)[:5] for _ in range(n_samples)]
    )                                                   # (n_samples, 5)
    v2 = cayley_menger_vol_sq(emb_f[idx])               # (n_samples,)
    # Negative v2 is float round-off on (near-)degenerate simplices; treat as 0.
    vols = v2.clamp(min=0.0).sqrt()
    vols = vols[vols > 0]
    if vols.numel() < 10:
        return 0.0
    return float(vols.std() / (vols.mean() + 1e-8))
def tangential_projection(grad, embedding):
    """Split ``grad`` into components tangential and radial to ``embedding``.

    The radial direction is the L2-normalized, detached ``embedding`` along the
    last dim. Arithmetic is done in float32; both outputs are cast back to
    ``grad.dtype``.

    Returns:
        (tangential, radial) with tangential + radial == grad (up to rounding).
    """
    direction = F.normalize(embedding.detach().float(), dim=-1)
    g = grad.float()
    coeff = torch.sum(g * direction, dim=-1, keepdim=True)
    radial = coeff * direction
    tangential = g - radial
    return tangential.to(grad.dtype), radial.to(grad.dtype)
class TangentialGradFn(torch.autograd.Function):
    """Identity in the forward pass; re-weights the gradient in the backward pass.

    Call as ``TangentialGradFn.apply(x, emb, gate)`` where ``gate`` is a float.
    The backward pass splits the incoming gradient into components tangential
    and radial to ``emb`` (via ``tangential_projection``) and returns
    ``tangential + gate * radial`` as the gradient for ``x``. ``emb`` and
    ``gate`` receive no gradient.
    """

    # PyTorch's custom-Function API requires forward/backward to be static.
    @staticmethod
    def forward(ctx, x, emb, gate):
        ctx.save_for_backward(emb)
        ctx.gate = gate
        # Explicit identity view, the documented pattern for pass-through
        # Functions (e.g. gradient reversal).
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad):
        (emb,) = ctx.saved_tensors
        tang, radial = tangential_projection(grad, emb)
        return tang + ctx.gate * radial, None, None
# ──────────────────────────────────────────────────────────────────
# CONSTELLATION: 30 abstract anchors, each carrying a 4D local tangent frame
# ──────────────────────────────────────────────────────────────────
class Constellation(nn.Module):
    """
    N anchor points on the D-dim unit hypersphere, each with a local frame.

    Each anchor carries ``d_local`` orthonormal vectors tangent to the sphere
    at that anchor (enforced at construction and by ``orthogonalize_frames``;
    optimizer steps can make them drift in between). Embeddings are located
    relative to the constellation by:
      * global triangulation: cosine distance 1 - cos to every anchor;
      * local coordinates: tangential displacement from every anchor,
        expressed in that anchor's frame.

    Any 5 anchors form a pentachoron; ``constellation_health`` monitors the
    CV of such volumes.

    Anchors are normalized isotropic Gaussian samples (not Xavier), so in high
    dimension the expected pairwise cosine is ~0 with spread ~1/sqrt(D).

    Buffers (per anchor):
        rigidity:    1 / (local_cv + 0.01) once measured, 0 before.
        visit_count: number of samples seen by ``update_rigidity``.
        local_cv:    EMA of the pentachoron CV of the anchor's cluster
                     (0 means "not yet measured").
    """

    def __init__(self, n_anchors=30, d_embed=768, d_local=4):
        super().__init__()
        self.n_anchors = n_anchors
        self.d_embed = d_embed
        self.d_local = d_local

        # Anchor positions: Gaussian draws projected onto the unit sphere.
        anchor_init = F.normalize(torch.randn(n_anchors, d_embed), dim=-1)
        self.anchors = nn.Parameter(anchor_init)

        # Local frames, (n_anchors, d_local, d_embed). Raw Gaussian draws are
        # made tangent + orthonormal at the end of __init__.
        self.local_frames = nn.Parameter(torch.randn(n_anchors, d_local, d_embed))

        # Rigidity accumulator state (see update_rigidity).
        self.register_buffer("rigidity", torch.zeros(n_anchors))
        self.register_buffer("visit_count", torch.zeros(n_anchors))
        self.register_buffer("local_cv", torch.zeros(n_anchors))

        # Establish the documented frame invariant from the start instead of
        # only after the first explicit orthogonalize_frames() call.
        self.orthogonalize_frames()

    def orthogonalize_frames(self):
        """Project every frame tangent to its anchor, then Gram-Schmidt it.

        Runs under no_grad and writes ``local_frames`` in place (through
        ``.data``, as the training loop relies on parameter identity).
        """
        with torch.no_grad():
            anchors_n = F.normalize(self.anchors, dim=-1)                # (N, D)
            frames = self.local_frames.data                              # (N, K, D)
            # Remove each frame vector's component along its own anchor.
            radial = torch.einsum("nkd,nd->nk", frames, anchors_n)      # (N, K)
            frames = frames - radial.unsqueeze(-1) * anchors_n.unsqueeze(1)
            # Gram-Schmidt over the K vectors, for all anchors at once.
            # Linear combinations of tangent vectors stay tangent.
            basis = []
            for j in range(self.d_local):
                v = frames[:, j]                                         # (N, D)
                for u in basis:
                    v = v - (v * u).sum(dim=-1, keepdim=True) * u
                basis.append(F.normalize(v, dim=-1))
            self.local_frames.data.copy_(torch.stack(basis, dim=1))

    def triangulate(self, emb):
        """
        Locate embeddings relative to every anchor.

        Args:
            emb: (B, D) L2-normalized embeddings.
        Returns:
            tri_coords:   (B, N) cosine distances 1 - cos to each anchor.
            local_coords: (B, N, d_local) tangential displacement
                          emb - cos_i * a_i projected onto anchor i's frame.
            nearest:      (B,) index of the anchor with the highest cosine.
        """
        anchors_n = F.normalize(self.anchors, dim=-1)                    # (N, D)
        cos_sim = emb @ anchors_n.T                                      # (B, N)
        tri_coords = 1.0 - cos_sim

        # (emb - cos_i * a_i) @ F_i^T == emb @ F_i^T - cos_i * (a_i @ F_i^T),
        # evaluated for all anchors in two batched contractions. The second
        # term is ~0 right after orthogonalize_frames but not after drift.
        frames = self.local_frames                                       # (N, K, D)
        emb_in_frames = torch.einsum("bd,nkd->bnk", emb, frames)         # (B, N, K)
        anchor_in_frames = torch.einsum("nd,nkd->nk", anchors_n, frames)  # (N, K)
        local_coords = (
            emb_in_frames - cos_sim.unsqueeze(-1) * anchor_in_frames.unsqueeze(0)
        ).to(emb.dtype)

        nearest = cos_sim.argmax(dim=-1)
        return tri_coords, local_coords, nearest

    @torch.no_grad()
    def update_rigidity(self, emb, labels):
        """
        Accumulate per-anchor rigidity from a labelled batch.

        Samples with ``labels == i`` are treated as anchor i's cluster
        (assumes class ids and anchor ids coincide -- true for the model in
        this file, where n_classes == n_anchors). Anchors with fewer than 5
        samples in the batch are skipped.

        Args:
            emb: (B, D) embeddings (detached by the caller).
            labels: (B,) integer class ids.
        """
        for i in range(self.n_anchors):
            mask = labels == i
            count = int(mask.sum())
            if count < 5:
                continue
            self.visit_count[i] += count

            cv = pentachoron_cv(emb[mask], n_samples=50)
            if cv <= 0.0:
                # pentachoron_cv returns 0.0 when it could not measure; that is
                # "no data", not "perfectly regular", so keep the EMA unchanged.
                continue

            if self.local_cv[i] == 0:
                # Seed the EMA with the first real measurement; starting from 0
                # would bias local_cv low and rigidity high for many steps.
                self.local_cv[i] = cv
            else:
                alpha = min(0.1, 10.0 / (float(self.visit_count[i]) + 1.0))
                self.local_cv[i] = (1 - alpha) * self.local_cv[i] + alpha * cv
            # More regular (lower CV) => more rigid.
            self.rigidity[i] = 1.0 / (self.local_cv[i] + 0.01)

    def constellation_health(self):
        """Summary statistics of anchor geometry and accumulated rigidity.

        Returns a dict with off-diagonal pairwise-cosine stats (mean/std/min/
        max), the pentachoron CV of the anchors, and mean/max rigidity.
        """
        anchors_n = F.normalize(self.anchors.detach(), dim=-1)
        cos = anchors_n @ anchors_n.T
        off_diag = ~torch.eye(self.n_anchors, dtype=bool, device=anchors_n.device)
        pair_cos = cos[off_diag]
        return {
            "mean_cos": pair_cos.mean().item(),
            "std_cos": pair_cos.std().item(),
            "min_cos": pair_cos.min().item(),
            "max_cos": pair_cos.max().item(),
            "cv": pentachoron_cv(anchors_n, n_samples=200),
            "mean_rigidity": self.rigidity.mean().item(),
            "max_rigidity": self.rigidity.max().item(),
        }
# ──────────────────────────────────────────────────────────────────
# CONV4D INGESTION (2-D convolutions over the anchors × local-frame grid)
# ──────────────────────────────────────────────────────────────────
class Conv4dBlock(nn.Module):
    """
    Read local-frame coordinates as a one-channel image.

    ``local_coords`` of shape (B, N, d_local) is viewed as (B, 1, N, d_local):
    rows are anchors, columns are local-frame axes. Three Conv2d + GELU
    stages (1 -> 32 -> 64 -> 128 channels; the last kernel is (3, 1), so it
    mixes across anchors only) are followed by global average pooling and a
    linear projection to ``d_out``.

    Despite the name, every convolution here is 2-D. ``n_anchors`` and
    ``d_local`` are accepted but unused: padding and adaptive pooling take
    the grid size from the input.
    """

    def __init__(self, n_anchors=30, d_local=4, d_out=256):
        super().__init__()
        stages = (
            (1, 32, (3, 3), (1, 1)),
            (32, 64, (3, 3), (1, 1)),
            (64, 128, (3, 1), (1, 0)),
        )
        layers = []
        for c_in, c_out, kernel, pad in stages:
            layers.extend([nn.Conv2d(c_in, c_out, kernel, padding=pad), nn.GELU()])
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.conv = nn.Sequential(*layers)
        self.proj = nn.Linear(128, d_out)

    def forward(self, local_coords):
        """Map (B, N, d_local) local coordinates to (B, d_out) features."""
        grid = local_coords[:, None, :, :]   # (B, 1, N, d_local)
        pooled = self.conv(grid)             # (B, 128, 1, 1)
        return self.proj(torch.flatten(pooled, start_dim=1))
# ──────────────────────────────────────────────────────────────────
# FULL MODEL
# ──────────────────────────────────────────────────────────────────
class ConstellationClassifier(nn.Module):
    """
    Image -> conv backbone -> hypersphere embedding -> triangulation
    -> local path (Conv4dBlock on local coords) + global path (MLP on
    triangulation distances) -> class logits.

    The constellation is the coordinate system; the two paths read the
    embedding's position in it and the classifier head combines them.
    """

    def __init__(self, n_classes=30, n_anchors=30, d_embed=768,
                 d_local=4, d_hidden=256):
        super().__init__()
        self.n_classes = n_classes

        # Backbone for 1-channel images: three conv+GELU stages, the first two
        # followed by 2x max-pooling, the last by global average pooling.
        channels = (1, 32, 64, 128)
        backbone_layers = []
        for k, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            backbone_layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.GELU()]
            backbone_layers.append(nn.MaxPool2d(2) if k < 2 else nn.AdaptiveAvgPool2d(1))
        self.backbone = nn.Sequential(*backbone_layers)
        self.embed_proj = nn.Sequential(nn.Linear(128, d_embed), nn.LayerNorm(d_embed))

        # Abstract coordinate system and the local-structure reader.
        self.constellation = Constellation(n_anchors, d_embed, d_local)
        self.conv4d = Conv4dBlock(n_anchors, d_local, d_hidden)

        # Global path over the (B, n_anchors) triangulation distances.
        self.tri_proj = nn.Sequential(
            nn.Linear(n_anchors, d_hidden), nn.GELU(), nn.LayerNorm(d_hidden),
        )

        # Head over concatenated [local, global] features.
        self.classifier = nn.Sequential(
            nn.Linear(d_hidden * 2, d_hidden), nn.GELU(), nn.LayerNorm(d_hidden),
            nn.Linear(d_hidden, n_classes),
        )

    def forward(self, x):
        """
        Args:
            x: (B, 1, H, W) grayscale images (32x32 in this file).
        Returns:
            logits:       (B, n_classes)
            emb:          (B, d_embed) L2-normalized embedding
            tri_coords:   (B, n_anchors) triangulation distances
            local_coords: (B, n_anchors, d_local) local frame projections
            nearest:      (B,) nearest anchor index
        """
        pooled = self.backbone(x).flatten(1)
        emb = F.normalize(self.embed_proj(pooled), dim=-1)
        tri_coords, local_coords, nearest = self.constellation.triangulate(emb)
        fused = torch.cat((self.conv4d(local_coords), self.tri_proj(tri_coords)), dim=-1)
        return self.classifier(fused), emb, tri_coords, local_coords, nearest
# ──────────────────────────────────────────────────────────────────
# SHAPE RENDERERS (reused from the trainer)
# ──────────────────────────────────────────────────────────────────
| def _draw(img, x0, y0, x1, y1, t=1): | |
| n = max(int(max(abs(x1-x0), abs(y1-y0))*2), 1); sz = img.shape[0] | |
| for s in np.linspace(0, 1, n): | |
| px, py = int(x0+s*(x1-x0)), int(y0+s*(y1-y0)) | |
| for dx in range(-t, t+1): | |
| for dy in range(-t, t+1): | |
| nx, ny = px+dx, py+dy | |
| if 0 <= nx < sz and 0 <= ny < sz: img[ny, nx] = 1.0 | |
def render_poly(nv, sz=32, p=0.15):
    """Closed ``nv``-gon outline with random rotation and per-vertex radius jitter.

    Base radius is 0.35*sz around the image centre; each vertex radius is
    scaled by (1 + N(0, p)).
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    centre = sz / 2
    base_r = sz * 0.35
    angles = np.linspace(0, 2 * np.pi, nv, endpoint=False) + np.random.uniform(0, 2 * np.pi)
    radii = base_r * (1 + np.random.normal(0, p, nv))
    verts = [
        (centre + rad * np.cos(ang), centre + rad * np.sin(ang))
        for ang, rad in zip(angles, radii)
    ]
    for k, start in enumerate(verts):
        _draw(img, *start, *verts[(k + 1) % nv])
    return img
def render_star(np_, sz=32, p=0.12):
    """Closed ``np_``-pointed star outline with random rotation.

    Vertices alternate between an outer (0.38*sz) and inner (0.15*sz)
    radius; each vertex coordinate gets its own (1 + N(0, p)) radius factor.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx = cy = sz / 2
    outer, inner = sz * 0.38, sz * 0.15
    angles = np.linspace(0, 2 * np.pi, np_ * 2, endpoint=False) + np.random.uniform(0, 2 * np.pi)
    verts = []
    for k, ang in enumerate(angles):
        rad = inner if k % 2 else outer
        x = cx + rad * (1 + np.random.normal(0, p)) * np.cos(ang)
        y = cy + rad * (1 + np.random.normal(0, p)) * np.sin(ang)
        verts.append((x, y))
    n_verts = len(verts)
    for k in range(n_verts):
        _draw(img, *verts[k], *verts[(k + 1) % n_verts])
    return img
def render_cross(sz=32, p=0.15):
    """Four thick arms from the centre at ~0/90/180/270 degrees.

    Each arm's angle is jittered by N(0, 0.3*p) and its length (base 0.3*sz)
    by (1 + N(0, p)); arms use brush half-width 2.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx = cy = sz / 2
    arm = sz * 0.3
    for base in (0, np.pi / 2, np.pi, 3 * np.pi / 2):
        theta = base + np.random.normal(0, p * 0.3)
        length = arm * (1 + np.random.normal(0, p))
        end = (cx + length * np.cos(theta), cy + length * np.sin(theta))
        _draw(img, cx, cy, *end, 2)
    return img
def render_spiral(sz=32, p=0.1):
    """Archimedean spiral over 2.5 turns, sampled at 200 points.

    Radius grows as 0.015*sz*theta with per-point (1 + N(0, 0.3*p)) jitter.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx = cy = sz / 2
    for theta in np.linspace(0, 5 * np.pi, 200):
        radius = sz * 0.015 * theta * (1 + np.random.normal(0, p * 0.3))
        col = int(cx + radius * np.cos(theta))
        row = int(cy + radius * np.sin(theta))
        if 0 <= col < sz and 0 <= row < sz:
            img[row, col] = 1.0
    return img
def render_wave(sz=32, p=0.1):
    """One-pixel sine wave across the image width.

    Frequency ~2 + N(0, 0.3) cycles per width; amplitude 0.15*sz*(1 + N(0, p)).
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    freq = 2 + np.random.normal(0, 0.3)
    amp = sz * 0.15 * (1 + np.random.normal(0, p))
    rows = [int(sz / 2 + amp * np.sin(2 * np.pi * freq * col / sz)) for col in range(sz)]
    for col, row in enumerate(rows):
        if 0 <= row < sz:
            img[row, col] = 1.0
    return img
def render_heart(sz=32, p=0.1):
    """Classic parametric heart curve centred at (sz/2, 0.45*sz).

    Scale is 0.017*sz*(1 + N(0, p)); sampled at 300 points.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz * 0.45
    scale = sz * 0.017 * (1 + np.random.normal(0, p))
    for t in np.linspace(0, 2 * np.pi, 300):
        hx = 16 * np.sin(t) ** 3
        hy = -(13 * np.cos(t) - 5 * np.cos(2 * t) - 2 * np.cos(3 * t) - np.cos(4 * t))
        col, row = int(cx + hx * scale), int(cy + hy * scale)
        if not (0 <= col < sz and 0 <= row < sz):
            continue
        img[row, col] = 1.0
    return img
def render_crescent(sz=32, p=0.1):
    """Crescent-like arc: the part of an outer circle lying outside an offset disc.

    Outer radius is 0.35*sz*(1 + N(0, p)); the excluded disc has radius
    0.7*r (tested at 0.9 of that) and is offset right by 0.3*r*(1 + N(0, p)).
    Only the outer boundary is drawn.

    ``p`` was previously accepted but ignored, so every sample of this class
    was the exact same image. With p = 0 the geometry matches that old
    fixed shape.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2
    r = sz * 0.35 * (1 + np.random.normal(0, p))
    r2 = r * 0.7
    off = r * 0.3 * (1 + np.random.normal(0, p))
    for a in np.linspace(0, 2 * np.pi, 300):
        x1, y1 = cx + r * np.cos(a), cy + r * np.sin(a)
        # Distance from this outer-circle point to the offset disc's centre.
        d2 = math.sqrt((x1 - cx - off) ** 2 + (y1 - cy) ** 2)
        if d2 >= r2 * 0.9:
            ix, iy = int(x1), int(y1)
            if 0 <= ix < sz and 0 <= iy < sz:
                img[iy, ix] = 1.0
    return img
def render_ellipse(sz=32, p=0.1):
    """Randomly rotated ellipse outline.

    Semi-axes are 0.38*sz and 0.22*sz, each scaled by (1 + N(0, p));
    rotation is uniform in [0, pi). Sampled at 200 points.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx = cy = sz / 2
    semi_a = sz * 0.38 * (1 + np.random.normal(0, p))
    semi_b = sz * 0.22 * (1 + np.random.normal(0, p))
    rot = np.random.uniform(0, np.pi)
    cr, sr = np.cos(rot), np.sin(rot)
    for t in np.linspace(0, 2 * np.pi, 200):
        x, y = semi_a * np.cos(t), semi_b * np.sin(t)
        col = int(cx + x * cr - y * sr)
        row = int(cy + x * sr + y * cr)
        if 0 <= col < sz and 0 <= row < sz:
            img[row, col] = 1.0
    return img
def render_ring(sz=32, p=0.1):
    """Two concentric circles (radii ~0.35*sz and ~0.22*sz, each jittered by (1 + N(0, p)))."""
    img = np.zeros((sz, sz), dtype=np.float32)
    cx = cy = sz / 2
    radii = (
        sz * 0.35 * (1 + np.random.normal(0, p)),
        sz * 0.22 * (1 + np.random.normal(0, p)),
    )
    for ang in np.linspace(0, 2 * np.pi, 300):
        ca, sa = np.cos(ang), np.sin(ang)
        for rad in radii:
            col, row = int(cx + rad * ca), int(cy + rad * sa)
            if 0 <= col < sz and 0 <= row < sz:
                img[row, col] = 1.0
    return img
def render_arrow(sz=32, p=0.12):
    """Arrow through the centre in a uniformly random direction.

    Shaft half-length is 0.35*sz*(1 + N(0, p)); two barbs of length
    0.35*half-length leave the tip at +/-0.7 rad.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx = cy = sz / 2
    half_len = sz * 0.35 * (1 + np.random.normal(0, p))
    head = half_len * 0.35
    ang = np.random.uniform(0, 2 * np.pi)
    ux, uy = np.cos(ang), np.sin(ang)
    tail = (cx - half_len * ux, cy - half_len * uy)
    tip = (cx + half_len * ux, cy + half_len * uy)
    _draw(img, *tail, *tip)
    for barb in (0.7, -0.7):
        _draw(img, *tip, tip[0] - head * np.cos(ang + barb), tip[1] - head * np.sin(ang + barb))
    return img
def render_chevron(sz=32, p=0.12):
    """Upward-pointing chevron ("^") centred in the image.

    Half-width 0.3*sz and half-height 0.25*sz, each scaled by (1 + N(0, p)).
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx = cy = sz / 2
    half_w = sz * 0.3 * (1 + np.random.normal(0, p))
    half_h = sz * 0.25 * (1 + np.random.normal(0, p))
    apex = (cx, cy - half_h)
    _draw(img, cx - half_w, cy + half_h, *apex)
    _draw(img, *apex, cx + half_w, cy + half_h)
    return img
def render_semicircle(sz=32, p=0.1):
    """Half-circle arc closed by its diameter.

    Centre (sz/2, 0.6*sz); the arc spans angles pi..2*pi, i.e. the rows
    above the chord (smaller row indices). The radius is 0.35*sz*(1 + N(0, p)).

    ``p`` was previously accepted but ignored, so every sample of this class
    was the exact same image. With p = 0 the geometry matches that old
    fixed shape.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz * 0.6
    r = sz * 0.35 * (1 + np.random.normal(0, p))
    for a in np.linspace(np.pi, 2 * np.pi, 150):
        x, y = int(cx + r * np.cos(a)), int(cy + r * np.sin(a))
        if 0 <= x < sz and 0 <= y < sz:
            img[y, x] = 1.0
    # Diameter closing the arc.
    _draw(img, cx - r, cy, cx + r, cy)
    return img
# Display name for each shape class id: SHAPE_NAMES[c] is the name of class c,
# the id passed to gen_one(c) and used as the training label. The comments
# mark the id ranges that train() groups for per-type accuracy.
SHAPE_NAMES = [
    # 0-8: "polygon"
    "triangle","square","pentagon","hexagon","heptagon",
    "octagon","nonagon","decagon","dodecagon",
    # 9-13: "curve"
    "circle","ellipse","spiral","wave","crescent",
    # 14-19: "star"
    "star3","star4","star5","star6","star7","star8",
    # 20-29: "structure"
    "cross","diamond","arrow","heart","ring",
    "semicircle","trapezoid","parallelogram","rhombus","chevron",
]
# Canonical corner layouts -- (x, y) offsets from the image centre in units of
# sz -- for quadrilateral classes that used to fall back to a noisy
# render_poly(4, ...). Classes 21 and 28 were even produced by the identical
# call, so their labels were indistinguishable.
#   id: (corners, jitter p, full random rotation?)
_QUAD_CLASSES = {
    # diamond: vertex-up rhombus; keeps its orientation (small tilt only).
    21: (((0.0, -0.38), (0.26, 0.0), (0.0, 0.38), (-0.26, 0.0)), 0.10, False),
    # trapezoid: isosceles, one pair of parallel sides of unequal length.
    26: (((-0.15, -0.18), (0.15, -0.18), (0.32, 0.18), (-0.32, 0.18)), 0.15, True),
    # parallelogram: sheared rectangle, unequal adjacent sides.
    27: (((-0.30, -0.15), (0.15, -0.15), (0.30, 0.15), (-0.15, 0.15)), 0.18, True),
    # rhombus: equal sides, ~60 degree acute angle, random orientation.
    28: (((-0.26, 0.0), (0.0, -0.15), (0.26, 0.0), (0.0, 0.15)), 0.10, True),
}


def _render_quad(corners, sz=32, p=0.1, rotate=True):
    """Draw a closed quadrilateral from canonical corners.

    Each corner coordinate is scaled by sz and by its own (1 + N(0, p)).
    With ``rotate`` the shape gets a uniform random rotation; otherwise only
    an N(0, p)-radian tilt, preserving its canonical orientation.
    """
    img = np.zeros((sz, sz), dtype=np.float32)
    cx, cy = sz / 2, sz / 2
    theta = np.random.uniform(0, 2 * np.pi) if rotate else np.random.normal(0, p)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    pts = []
    for x, y in corners:
        x = x * sz * (1 + np.random.normal(0, p))
        y = y * sz * (1 + np.random.normal(0, p))
        pts.append((cx + x * cos_t - y * sin_t, cy + x * sin_t + y * cos_t))
    for i in range(len(pts)):
        _draw(img, *pts[i], *pts[(i + 1) % len(pts)])
    return img


def gen_one(c, sz=32):
    """Render one random sample of shape class ``c`` as a (sz, sz) float32 image.

    Class ids follow SHAPE_NAMES order (0..29).

    Raises:
        ValueError: if ``c`` is not a known class id (previously an unknown
            id silently produced a triangle, corrupting labels).
    """
    dispatch = {
        0: lambda: render_poly(3, sz, 0.20),
        1: lambda: render_poly(4, sz, 0.12),
        2: lambda: render_poly(5, sz, 0.15),
        3: lambda: render_poly(6, sz, 0.10),
        4: lambda: render_poly(7, sz, 0.10),
        5: lambda: render_poly(8, sz, 0.08),
        6: lambda: render_poly(9, sz, 0.08),
        7: lambda: render_poly(10, sz, 0.07),
        8: lambda: render_poly(12, sz, 0.06),
        9: lambda: render_poly(32, sz, 0.03),      # "circle": a fine 32-gon
        10: lambda: render_ellipse(sz, 0.10),
        11: lambda: render_spiral(sz, 0.10),
        12: lambda: render_wave(sz, 0.10),
        13: lambda: render_crescent(sz, 0.10),
        14: lambda: render_star(3, sz, 0.12),
        15: lambda: render_star(4, sz, 0.12),
        16: lambda: render_star(5, sz, 0.12),
        17: lambda: render_star(6, sz, 0.10),
        18: lambda: render_star(7, sz, 0.10),
        19: lambda: render_star(8, sz, 0.08),
        20: lambda: render_cross(sz, 0.15),
        22: lambda: render_arrow(sz, 0.12),
        23: lambda: render_heart(sz, 0.10),
        24: lambda: render_ring(sz, 0.10),
        25: lambda: render_semicircle(sz, 0.10),
        29: lambda: render_chevron(sz, 0.12),
    }
    if c in _QUAD_CLASSES:
        corners, p, rotate = _QUAD_CLASSES[c]
        return _render_quad(corners, sz, p, rotate)
    try:
        render = dispatch[c]
    except KeyError:
        raise ValueError(f"unknown shape class {c!r}; expected an id in 0..29") from None
    return render()
def gen_dataset(n_per=500, sz=32):
    """Build a shuffled dataset with ``n_per`` samples of each of the 30 classes.

    Samples are generated class-interleaved (0..29, repeated n_per times),
    then shuffled with torch's global RNG.

    Returns:
        (images, labels): images of shape (30*n_per, 1, sz, sz) and int64
        labels of shape (30*n_per,), in the same shuffled order.
    """
    samples = [(gen_one(cls, sz), cls) for _ in range(n_per) for cls in range(30)]
    images = torch.tensor(np.array([img for img, _ in samples])).unsqueeze(1)
    targets = torch.tensor([cls for _, cls in samples], dtype=torch.long)
    order = torch.randperm(len(targets))
    return images[order], targets[order]
# ──────────────────────────────────────────────────────────────────
# TRAINING
# ──────────────────────────────────────────────────────────────────
def _cv_gate(cv_now, cv_target):
    """Map the batch CV's deviation from ``cv_target`` to a gradient gate value.

    - |delta| <= 0.02: 0.0 (dead band).
    - CV below target: relative shortfall (capped at 1) times 0.3.
    - CV above target: 0.1 * (1 - relative excess capped at 1), floored at 0.
    """
    delta = cv_now - cv_target
    if abs(delta) <= 0.02:
        return 0.0
    if delta < 0:
        return min(abs(delta) / (cv_target + 1e-8), 1.0) * 0.3
    return max(0.0, 0.1 * (1 - min(delta / (cv_target + 1e-8), 1.0)))


def train():
    """Run the constellation-classifier experiment end to end and print a report.

    Generates the 30-class synthetic shape data (500 / 100 samples per class
    for train / val), trains a torch.compile'd ConstellationClassifier with
    Adam for 40 epochs, and prints initial/per-epoch/final constellation
    statistics and per-anchor rigidity. Output goes to stdout; returns None.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    print(f"\n{'='*65}")
    print("CONSTELLATION CLASSIFIER: Pure Geometric Coordinates")
    print(f"{'='*65}")
    print(f" Device: {DEVICE}")

    print("\n Generating data...")
    train_imgs, train_labels = gen_dataset(n_per=500)
    val_imgs, val_labels = gen_dataset(n_per=100)
    train_imgs, train_labels = train_imgs.to(DEVICE), train_labels.to(DEVICE)
    val_imgs, val_labels = val_imgs.to(DEVICE), val_labels.to(DEVICE)
    n_train, n_val = len(train_labels), len(val_labels)
    print(f" Train: {n_train:,} Val: {n_val:,} Classes: 30")

    model = ConstellationClassifier(
        n_classes=30, n_anchors=30, d_embed=768, d_local=4, d_hidden=256,
    ).to(DEVICE)
    model = torch.compile(model, mode="default")
    n_params = sum(p.numel() for p in model.parameters())
    n_constellation = sum(p.numel() for p in model.constellation.parameters())
    print(f" Total params: {n_params:,}")
    print(f" Constellation params: {n_constellation:,}")

    health = model.constellation.constellation_health()
    print("\n Initial constellation:")
    print(f" Mean cos: {health['mean_cos']:.4f} (want ≈0)")
    print(f" Std cos: {health['std_cos']:.4f}")
    print(f" Min cos: {health['min_cos']:.4f}")
    print(f" Max cos: {health['max_cos']:.4f}")
    print(f" CV: {health['cv']:.4f}")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    BATCH, EPOCHS = 256, 40
    cv_target = None
    gate_val = 0.0
    # Class-id groups for per-type validation accuracy (see SHAPE_NAMES order).
    type_ids = {
        name: torch.tensor(ids, device=DEVICE)
        for name, ids in {
            "polygon": list(range(9)),
            "curve": list(range(9, 14)),
            "star": list(range(14, 20)),
            "structure": list(range(20, 30)),
        }.items()
    }

    for epoch in range(EPOCHS):
        model.train()
        # Re-orthonormalize the tangent frames every 5 epochs; optimizer steps
        # let them drift in between.
        if epoch % 5 == 0:
            model.constellation.orthogonalize_frames()

        perm = torch.randperm(n_train, device=DEVICE)
        total_loss, total_correct, n_seen, n = 0.0, 0, 0, 0
        for i in range(0, n_train, BATCH):
            idx = perm[i:i + BATCH]
            if len(idx) < 4:
                continue
            # Only the embedding of this full forward is used; the head is
            # re-run below on the gradient-gated embedding.
            _, emb, _, _, _ = model(train_imgs[idx])
            labels = train_labels[idx]

            # CV gate: refreshed every 10th batch, target = first measurable CV.
            if n % 10 == 0:
                cv_now = pentachoron_cv(emb, n_samples=30)
                if cv_target is None and cv_now > 0:
                    cv_target = cv_now
                if cv_target:
                    gate_val = _cv_gate(cv_now, cv_target)

            # NOTE(review): emb already went through F.normalize in forward,
            # whose backward projects out the radial direction, so the radial
            # share scaled by gate_val is likely discarded before it reaches
            # the backbone -- confirm whether the gate has any effect.
            emb_gated = TangentialGradFn.apply(emb, emb, gate_val)
            tri_g, local_g, _ = model.constellation.triangulate(emb_gated)
            local_feat = model.conv4d(local_g)
            global_feat = model.tri_proj(tri_g)
            logits = model.classifier(torch.cat([local_feat, global_feat], dim=-1))
            loss = F.cross_entropy(logits, labels)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            model.constellation.update_rigidity(emb.detach(), labels)

            total_correct += (logits.argmax(-1) == labels).sum().item()
            total_loss += loss.item()
            n_seen += len(idx)
            n += 1
        # Normalize by samples actually used (skipped tiny batches excluded).
        train_acc = total_correct / max(n_seen, 1)
        train_loss = total_loss / max(n, 1)

        # Validation (single full-batch pass).
        model.eval()
        with torch.no_grad():
            v_logits, v_emb, _, _, _ = model(val_imgs)
            v_pred = v_logits.argmax(-1)
            v_acc = (v_pred == val_labels).float().mean().item()
            v_cv = pentachoron_cv(v_emb, n_samples=100)
            health = model.constellation.constellation_health()

        ta = {}
        for tname, ids in type_ids.items():
            tmask = torch.isin(val_labels, ids)
            if tmask.any():
                ta[tname] = (v_pred[tmask] == val_labels[tmask]).float().mean().item()

        if (epoch + 1) % 5 == 0 or epoch == 0:
            ta_str = " ".join(f"{t}={a:.2f}" for t, a in ta.items())
            rig = model.constellation.rigidity
            print(f" E{epoch+1:2d}: loss={train_loss:.4f} t_acc={train_acc:.3f} v_acc={v_acc:.3f} "
                  f"cv={v_cv:.4f} gate={gate_val:.3f} "
                  f"cos={health['mean_cos']:.4f} "
                  f"rig={rig.mean().item():.1f}/{rig.max().item():.1f} "
                  f"[{ta_str}]")

    # Final constellation state.
    health = model.constellation.constellation_health()
    print("\n Final constellation:")
    print(f" Mean cos: {health['mean_cos']:.4f}")
    print(f" CV: {health['cv']:.4f}")
    print(f" Rigidity: mean={health['mean_rigidity']:.1f} max={health['max_rigidity']:.1f}")

    print("\n Per-anchor rigidity:")
    rig = model.constellation.rigidity.cpu()
    for i, name in enumerate(SHAPE_NAMES):
        value = rig[i].item()
        bar = "█" * int(value)
        print(f" {name:15s}: {value:.1f} {bar}")
    print(f"\n{'='*65}")
    print("DONE")
    print(f"{'='*65}")
# Script entry point: running this file directly starts the experiment;
# importing it does not start training.
if __name__ == "__main__":
    train()