"""
Constellation — Geometric Observer + Interpreter
===================================================
Aligned to the proven GeoLIP Core trainer (91.2% CIFAR-10 @ 1.65M params).

Architecture:
    emb @ anchors.T → 64 distances → 8 round-robin compartments
    → cat(pw, emb) → classifier

Key mechanisms:
    - Round-robin compartments: 8 groups of 8 anchors, diverse measurements per group
    - cat(patchwork, embedding): classifier sees both interpreted distances AND raw position
    - Anchor push: direct centroid placement every N batches (self-distillation across time)
    - Attraction loss: pulls embeddings toward nearest anchor
    - InfoNCE on two views: alignment force
    - Simple triangulation: emb @ anchors.T, no SLERP, no phases

Classes:
    Constellation      — triangulation against anchors on S^(d-1)
    Patchwork          — round-robin compartmentalized interpretation
    ConstellationCore  — full pipeline: constellation + patchwork + classifier
    ConstellationRelay — per-token geometric layer with gated residual
    GeometricOps       — CV, spread, Cayley-Menger utilities
    GeometricAutograd  — Form 12 manifold-aware gradient correction

Usage:
    from constellation import ConstellationCore

    model = ConstellationCore(num_classes=10, dim=192, n_anchors=64)
    # forward() consumes L2-normalized embeddings of shape (B, dim),
    # i.e. the (normalized) output of an encoder, not raw images.
    out = model(emb)    # dict: logits, embedding, triangulation, nearest, patchwork
    loss, ld = model.compute_loss(out, targets, output_aug=out2)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Dict, Any


# ══════════════════════════════════════════════════════════════════
# ACTIVATIONS
# ══════════════════════════════════════════════════════════════════

class SquaredReLU(nn.Module):
    """x → ReLU(x)². Proven #1 in bulk activation tests."""

    def forward(self, x):
        return F.relu(x) ** 2


class StarReLU(nn.Module):
    """x → (ReLU(x))² * scale + bias. Runner-up in bulk tests.

    ``scale`` and ``bias`` are learnable scalars initialized to
    0.8944 and -0.4472 respectively.
    """

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1) * 0.8944)
        self.bias = nn.Parameter(torch.zeros(1) - 0.4472)

    def forward(self, x):
        return F.relu(x) ** 2 * self.scale + self.bias


# Name → zero-argument factory. Module classes are themselves callable
# factories, so no lambda wrappers are needed.
ACTIVATIONS = {
    'squared_relu': SquaredReLU,
    'star_relu': StarReLU,
    'gelu': nn.GELU,
    'relu': nn.ReLU,
    'sigmoid': nn.Sigmoid,
}


def make_activation(name='squared_relu'):
    """Create activation by name.

    Args:
        name: key of ``ACTIVATIONS``.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        ValueError: if ``name`` is not a registered activation.
    """
    if name not in ACTIVATIONS:
        raise ValueError(f"Unknown activation '{name}'. Choose from: {list(ACTIVATIONS.keys())}")
    return ACTIVATIONS[name]()


# ══════════════════════════════════════════════════════════════════
# ANCHOR INITIALIZATION
# ══════════════════════════════════════════════════════════════════

def init_anchors_xavier(n, d):
    """Xavier normal → normalize. Near-orthogonal in high-d.

    Returns an (n, d) tensor with unit-norm rows.
    """
    w = torch.empty(n, d)
    nn.init.xavier_normal_(w)
    return F.normalize(w, dim=-1)


def init_anchors_orthogonal(n, d):
    """QR decomposition → exact orthonormal basis when n <= d.

    When n > d, the first d rows are an orthonormal basis and the remaining
    n - d rows are random unit vectors.

    Returns an (n, d) tensor with unit-norm rows.
    """
    if n <= d:
        M = torch.randn(d, n)
        Q, _ = torch.linalg.qr(M)
        return Q.T.contiguous()

    M = torch.randn(d, d)
    Q, _ = torch.linalg.qr(M)
    basis = Q.T
    extra = F.normalize(torch.randn(n - d, d), dim=-1)
    return torch.cat([basis, extra], dim=0)


def init_anchors_repulsion(n, d, iters=200, lr=0.05):
    """QR + iterative repulsion for even coverage. Used in proven Core.

    Each iteration pushes every vector away from its most similar neighbour
    by ``lr`` and re-projects onto the unit sphere.

    Args:
        n: number of anchors.
        d: dimensionality.
        iters: number of repulsion steps.
        lr: step size of each repulsion step.

    Returns an (n, d) tensor with unit-norm rows.
    """
    vecs = init_anchors_orthogonal(n, d)
    vecs = F.normalize(vecs, dim=-1)
    for _ in range(iters):
        sim = vecs @ vecs.T
        # Exclude self-similarity so argmax picks the nearest *other* vector.
        sim.fill_diagonal_(-2.0)
        nn_idx = sim.argmax(dim=1)
        vecs = F.normalize(vecs - lr * vecs[nn_idx], dim=-1)
    return vecs


INIT_METHODS = {
    'xavier': init_anchors_xavier,
    'orthogonal': init_anchors_orthogonal,
    'repulsion': init_anchors_repulsion,
}


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION — triangulation on S^(d-1)
# ══════════════════════════════════════════════════════════════════

class Constellation(nn.Module):
    """Anchors on S^(d-1). Triangulates input embeddings.

    Simple: emb @ anchors.T → cosines → distances.
    No SLERP, no phases, no home/learned split.

    Args:
        n_anchors: number of reference points on S^(d-1)
        dim: dimensionality of the sphere
        anchor_drop: fraction to drop during training (0.15 proven)
        anchor_init: 'repulsion', 'xavier', or 'orthogonal'

    Raises:
        KeyError: if ``anchor_init`` is not a key of ``INIT_METHODS``.
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
        super().__init__()
        try:
            init_fn = INIT_METHODS[anchor_init]
        except KeyError:
            # Same exception type as the plain dict lookup, but with a
            # message that tells the caller what the valid choices are.
            raise KeyError(
                f"Unknown anchor_init '{anchor_init}'. "
                f"Choose from: {list(INIT_METHODS.keys())}") from None
        self.anchors = nn.Parameter(init_fn(n_anchors, dim))
        self.anchor_drop = anchor_drop
        self.n_anchors = n_anchors
        self.dim = dim

    def triangulate(self, emb, training=False):
        """emb: (B, D) L2-normalized → (tri, nearest).

        tri:     (B, A) values ``1 - cos`` to all anchors. When ``training``
                 is true and ``anchor_drop > 0`` only the surviving anchors
                 are measured, so tri is (B, A_kept) in that case.
        nearest: (B,) index of closest anchor, always expressed in the full
                 anchor index space (0 .. A-1).
        """
        anchors = F.normalize(self.anchors, dim=-1)

        if training and self.anchor_drop > 0:
            mask = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
            # Guarantee at least two anchors survive the drop.
            if mask.sum() < 2:
                mask[:2] = True
            anchors_drop = anchors[mask]
            cos = emb @ anchors_drop.T
            tri = 1.0 - cos
            _, nearest_local = cos.max(dim=-1)
            # Map indices within the kept subset back to global anchor ids.
            nearest = mask.nonzero(as_tuple=True)[0][nearest_local]
        else:
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)

        return tri, nearest

    def forward(self, emb, training=False):
        """Alias of :meth:`triangulate`."""
        return self.triangulate(emb, training=training)


# ══════════════════════════════════════════════════════════════════
# PATCHWORK — round-robin compartmentalized interpretation
# ══════════════════════════════════════════════════════════════════

class Patchwork(nn.Module):
    """Round-robin compartments reading diverse anchor subsets.

    64 anchors, 8 compartments → each reads 8 anchors.
    Assignment: anchor k goes to compartment (k % n_comp).
    Each compartment: Linear(anchors_per, d_comp*2) → act → Linear → LN → d_comp

    Args:
        n_anchors: total anchors (must be divisible by n_comp)
        n_comp: number of compartments
        d_comp: output dim per compartment
        activation: activation function name

    Raises:
        ValueError: if ``n_comp`` is not positive or does not divide
            ``n_anchors``.
    """

    def __init__(self, n_anchors, n_comp=8, d_comp=64, activation='squared_relu'):
        super().__init__()
        # Every compartment's first Linear is sized n_anchors // n_comp; an
        # uneven split would hand some compartments more columns than that
        # and fail with an opaque shape error inside forward().
        if n_comp <= 0 or n_anchors % n_comp != 0:
            raise ValueError(
                f"n_anchors ({n_anchors}) must be divisible by n_comp ({n_comp})")

        self.n_comp = n_comp
        self.d_comp = d_comp
        self.output_dim = n_comp * d_comp

        # Round-robin assignment: anchor k → compartment (k % n_comp)
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)

        anchors_per = n_anchors // n_comp
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(anchors_per, d_comp * 2),
                make_activation(activation),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            )
            for _ in range(n_comp)
        ])

    def forward(self, tri):
        """tri: (B, n_anchors) → (B, n_comp * d_comp)"""
        return torch.cat([
            self.comps[k](tri[:, self.asgn == k])
            for k in range(self.n_comp)
        ], dim=-1)


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION CORE — full pipeline
# ══════════════════════════════════════════════════════════════════

class ConstellationCore(nn.Module):
    """Constellation + Patchwork + Classifier.

    Forward returns dict with all outputs for downstream consumers.
    Classifier reads cat(patchwork, embedding).

    Args:
        num_classes: classification targets
        dim: embedding dimension (encoder output)
        n_anchors: anchors on S^(dim-1)
        n_comp: patchwork compartments
        d_comp: hidden dim per compartment
        anchor_drop: training dropout rate for anchors
        anchor_init: initialization method
        activation: activation for patchwork compartments
        cv_target: target CV for geometric loss
        infonce_temp: temperature for InfoNCE
    """

    def __init__(
        self,
        num_classes=10,
        dim=192,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        anchor_init='repulsion',
        activation='squared_relu',
        cv_target=0.22,
        infonce_temp=0.07,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp

        # Constructor arguments, in signature order. Spelled out explicitly
        # rather than harvested from locals(), which also exposes
        # interpreter-internal names such as __class__.
        self.config = {
            'num_classes': num_classes,
            'dim': dim,
            'n_anchors': n_anchors,
            'n_comp': n_comp,
            'd_comp': d_comp,
            'anchor_drop': anchor_drop,
            'anchor_init': anchor_init,
            'activation': activation,
            'cv_target': cv_target,
            'infonce_temp': infonce_temp,
        }

        self.constellation = Constellation(
            n_anchors, dim, anchor_drop, anchor_init)
        self.patchwork = Patchwork(
            n_anchors, n_comp, d_comp, activation)

        pw_dim = self.patchwork.output_dim

        # Classifier reads cat(patchwork, embedding)
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + dim, pw_dim),
            make_activation(activation),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes),
        )

    def forward(self, emb_normalized):
        """Forward pass on L2-normalized embeddings.

        Args:
            emb_normalized: (B, D) already on S^(d-1)

        Returns:
            dict with: logits, embedding, triangulation, nearest, patchwork
        """
        emb = emb_normalized

        # Full triangulation for patchwork (no anchor drop, so the patchwork
        # always receives all n_anchors columns).
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)

        # Dropout version for nearest tracking only
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)

        # Classifier sees BOTH patchwork interpretation AND raw position
        logits = self.classifier(torch.cat([pw, emb], dim=-1))

        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
            'patchwork': pw,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Compute all losses.

        Total = CE + 1.0 * InfoNCE (if a second view is given)
                + 0.5 * attraction + 0.01 * CV + 0.001 * anchor spread.

        Args:
            output: dict from forward()
            targets: (B,) class indices
            output_aug: optional dict from forward() on second view

        Returns:
            (total_loss, loss_dict). ``loss_dict`` holds the individual loss
            tensors plus float metrics ('acc', 'nce_acc', 'nearest_cos').
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # CE classification
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE between augmented views
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            # Row i of view 1 should match row i of view 2.
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc

        # Anchor attraction: pull embeddings toward nearest anchor
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values
        l_attract = (1.0 - nearest_cos).mean()
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()

        # CV on embeddings
        l_cv = GeometricOps.cv_loss(emb, target=self.cv_target)
        ld['cv'] = l_cv

        # Anchor spread
        l_spread = GeometricOps.anchor_spread_loss(self.constellation.anchors)
        ld['spread'] = l_spread

        # Total
        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """Push anchors toward class centroids — self-distillation across time.

        Phase 1: Compute class centroids from labels
        Phase 2: Greedy-assign anchors to classes, highest cosine first, with
                 a per-class capacity of ``n_anchors // n_classes + 1``
        Phase 3: Blend each anchor linearly toward its class centroid and
                 re-normalize onto the sphere. Every anchor after the first
                 in a class aims at a slightly perturbed centroid
                 (perpendicular noise) so co-class anchors don't collapse.

        Anchors are updated in place on ``self.constellation.anchors.data``.

        Args:
            emb_buffer: (N, D) accumulated embeddings
            label_buffer: (N,) class labels
            lr: blend rate toward centroid

        Returns:
            number of anchors moved (0 if the buffers are empty)
        """
        anchors = self.constellation.anchors.data
        n_a = anchors.shape[0]
        device = anchors.device

        # Buffers are frequently accumulated on CPU or in reduced precision;
        # bring them to the anchors' device/dtype so the matmuls below work.
        emb_n = F.normalize(
            emb_buffer.to(device=device, dtype=anchors.dtype), dim=-1)
        labels = label_buffer.to(device)

        # Phase 1: class centroids
        classes = labels.unique()
        n_cls = classes.shape[0]
        if n_cls == 0:
            return 0
        centroids = torch.cat([
            F.normalize(emb_n[labels == c].mean(0, keepdim=True), dim=-1)
            for c in classes
        ], dim=0)

        # Phase 2: greedy anchor-to-class assignment
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T
        anchors_per_class = n_a // n_cls
        capacity = anchors_per_class + 1

        # Run the greedy pass on plain Python ints: one transfer up front
        # instead of a device sync on every (anchor, class) pair.
        _, flat_idx = cos.flatten().sort(descending=True)
        assigned = [-1] * n_a
        class_count = [0] * n_cls
        n_assigned = 0
        for idx in flat_idx.tolist():
            a, c = divmod(idx, n_cls)
            if assigned[a] >= 0 or class_count[c] >= capacity:
                continue
            assigned[a] = c
            class_count[c] += 1
            n_assigned += 1
            if n_assigned == n_a:
                break

        # Unassigned leftovers (defensive: fall back to the best centroid)
        leftovers = [a for a, c in enumerate(assigned) if c < 0]
        if leftovers:
            best = (anchors_n[leftovers] @ centroids.T).argmax(dim=1).tolist()
            for a, c in zip(leftovers, best):
                assigned[a] = c

        # Phase 3: push with perpendicular perturbation
        seen_in_class = [0] * n_cls
        moved = 0
        for a, c in enumerate(assigned):
            target = centroids[c]
            # Number of lower-indexed anchors already sent to this class.
            rank_in_class = seen_in_class[c]
            seen_in_class[c] += 1
            if anchors_per_class > 1 and rank_in_class > 0:
                noise = torch.randn_like(target) * 0.05
                # Remove the component along the centroid → tangent noise.
                noise = noise - (noise * target).sum() * target
                target = F.normalize(
                    (target + noise).unsqueeze(0), dim=-1).squeeze(0)
            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1

        return moved


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY — Form 5 (per-token geometric layer)
# ══════════════════════════════════════════════════════════════════

class ConstellationRelay(nn.Module):
    """Per-token geometric processing with gated residual.

    O(S) complexity. Preserves 99.4% cos similarity at depth 16.

    Pipeline:
        LayerNorm → L2 normalize → triangulate → patchwork → project
        → gated residual

    Args:
        dim: token dimension
        n_anchors: anchors on S^(dim-1)
        n_comp: patchwork compartments
        d_comp: hidden dim per compartment
        gate_init: initial gate bias (-3.0 → sigmoid ≈ 0.047)
        anchor_init: initialization method
        activation: activation function name
    """

    def __init__(
        self,
        dim,
        n_anchors=16,
        n_comp=8,
        d_comp=64,
        gate_init=-3.0,
        anchor_init='repulsion',
        activation='squared_relu',
    ):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.constellation = Constellation(
            n_anchors, dim, anchor_init=anchor_init)
        self.patchwork = Patchwork(
            n_anchors, n_comp, d_comp, activation)

        # Project patchwork back to token dim
        self.proj = nn.Linear(self.patchwork.output_dim, dim)

        # Gated residual (one learnable gate logit per channel)
        self.gate = nn.Parameter(torch.full((dim,), gate_init))

    def forward(self, x):
        """x: (B, S, D) or (B, D) → same shape."""
        squeeze = False
        if x.dim() == 2:
            # Treat a 2-D input as a sequence of length 1.
            x = x.unsqueeze(1)
            squeeze = True

        B, S, D = x.shape
        residual = x

        h = self.norm(x)
        h_flat = h.reshape(B * S, D)
        h_flat = F.normalize(h_flat, dim=-1)

        tri, _ = self.constellation.triangulate(h_flat)
        pw = self.patchwork(tri)
        update = self.proj(pw).reshape(B, S, D)

        g = torch.sigmoid(self.gate)
        out = residual + g * update

        if squeeze:
            out = out.squeeze(1)
        return out


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC OPS
# ══════════════════════════════════════════════════════════════════

class GeometricOps:
    """Static geometric utilities."""

    @staticmethod
    def cayley_menger_vol2(points):
        """Squared simplex volume. points: (B, N, D) → (B,).

        Builds the (N+1)x(N+1) Cayley-Menger matrix from pairwise squared
        distances; the determinant is evaluated in float32 and cast back to
        the input dtype.
        """
        B, N, D = points.shape
        gram = torch.bmm(points, points.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        # Clamp tiny negative values produced by round-off.
        d2 = F.relu(d2)

        cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2

        k = N - 1
        sign = (-1.0) ** (k + 1)
        fact = math.factorial(k)
        return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))

    @staticmethod
    @torch.no_grad()
    def cv_metric(emb, n_samples=200, n_points=5):
        """Non-differentiable CV for monitoring. Target band: 0.20–0.23.

        Samples ``n_samples`` random ``n_points``-subsets of ``emb``, takes
        the simplex volume of each, and returns std/mean of the volumes as a
        float. Returns 0.0 when fewer than 10 subsets have a non-degenerate
        volume.
        """
        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(emb.shape[0])[:n_points]
            v2 = GeometricOps.cayley_menger_vol2(emb[idx].unsqueeze(0))
            if v2[0] > 1e-20:
                vols.append(v2[0].sqrt())
        if len(vols) < 10:
            return 0.0
        vols_t = torch.stack(vols)
        return (vols_t.std() / (vols_t.mean() + 1e-8)).item()

    @staticmethod
    def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
        """Differentiable CV loss. Weight: 0.01 or below.

        Returns ``(cv - target)²`` where cv is std/mean of the volumes of
        ``n_samples`` random ``n_points``-simplices drawn from the first
        ``min(B, 512)`` rows of ``emb``. Returns a zero tensor when the batch
        is smaller than ``n_points`` or fewer than 5 simplices are
        non-degenerate.
        """
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)

        # Loop-invariant Cayley-Menger constants.
        N = n_points
        k = N - 1
        pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
        pool = min(B, 512)

        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(pool, device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)

            cm = torch.zeros(1, N + 1, N + 1, device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2

            v2 = pf * torch.linalg.det(cm.float())
            # Skip degenerate simplices: sqrt has an unbounded gradient at 0.
            if v2[0].item() > 1e-20:
                vols.append(v2[0].to(emb.dtype).sqrt())

        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)

        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - target).pow(2)

    @staticmethod
    def anchor_spread_loss(anchors, target_cos=0.0):
        """Repulsion loss keeping anchors spread.

        Mean over all ordered pairs of distinct anchors of
        ``relu(cos(a_i, a_j) - target_cos)``.
        """
        a = F.normalize(anchors, dim=-1)
        sim = a @ a.T
        # Ignore each anchor's similarity with itself.
        mask = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)
        return F.relu(sim[mask] - target_cos).mean()

    @staticmethod
    def diagnostics(constellation, emb):
        """Compute health metrics from a constellation and embeddings.

        Returns a dict with the number of anchors that are nearest to at
        least one embedding, the mean cosine to the nearest anchor, and the
        std/min/max of per-anchor assignment counts.
        """
        _, nearest = constellation.triangulate(emb, training=False)
        n_active = nearest.unique().numel()

        anchors_n = F.normalize(constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values.mean().item()

        counts = torch.bincount(nearest, minlength=constellation.n_anchors).float()
        return {
            'n_active': n_active,
            'nearest_cos': nearest_cos,
            'anchor_util_std': counts.std().item(),
            'anchor_util_min': counts.min().item(),
            'anchor_util_max': counts.max().item(),
        }


# ══════════════════════════════════════════════════════════════════
# GEOMETRIC AUTOGRAD — Form 12
# ══════════════════════════════════════════════════════════════════

class GeometricAutograd(torch.autograd.Function):
    """Manifold-aware gradient correction on S^(D-1).

    Forward: identity.
    Backward: tangential projection + separation from nearest anchor.

    Proven settings: tang=0.01, sep=1.0
    """

    @staticmethod
    def forward(ctx, emb, anchors, tang_strength, sep_strength):
        ctx.save_for_backward(emb, anchors)
        ctx.tang = tang_strength
        ctx.sep = sep_strength
        return emb

    @staticmethod
    def backward(ctx, grad):
        emb, anchors = ctx.saved_tensors
        tang = ctx.tang
        sep = ctx.sep

        # Split the incoming gradient into the component along emb (radial)
        # and the remainder (tangential); damp the radial part by `tang`.
        dot = (grad * emb).sum(dim=-1, keepdim=True)
        radial = dot * emb
        tangential = grad - radial
        corrected = tangential + (1.0 - tang) * radial

        if sep > 0:
            anchors_n = F.normalize(anchors.detach(), dim=-1)
            cos_to_anchors = emb @ anchors_n.T
            nearest_idx = cos_to_anchors.argmax(dim=-1)
            nearest = anchors_n[nearest_idx]
            # Remove (scaled by `sep`) the positive gradient component that
            # points along the nearest anchor's direction.
            toward = (corrected * nearest).sum(dim=-1, keepdim=True)
            corrected = corrected - sep * F.relu(toward) * nearest

        # Only `emb` receives a gradient; anchors and strengths do not.
        return corrected, None, None, None