"""
Constellation — Geometric Observer + Interpreter
=================================================
Aligned to the proven GeoLIP Core trainer (91.2% CIFAR-10 @ 1.65M params).

Architecture:
    emb @ anchors.T → 64 distances → 8 round-robin compartments → cat(pw, emb) → classifier

Key mechanisms:
    - Round-robin compartments: 8 groups of 8 anchors, diverse measurements per group
    - cat(patchwork, embedding): classifier sees both interpreted distances AND raw position
    - Anchor push: direct centroid placement every N batches (self-distillation across time)
    - Attraction loss: pulls embeddings toward nearest anchor
    - InfoNCE on two views: alignment force
    - Simple triangulation: emb @ anchors.T, no SLERP, no phases

Classes:
    Constellation      — triangulation against anchors on S^(d-1)
    Patchwork          — round-robin compartmentalized interpretation
    ConstellationCore  — full pipeline: constellation + patchwork + classifier
    GeometricOps       — CV, spread, Cayley-Menger utilities
    GeometricAutograd  — Form 12 manifold-aware gradient correction

Usage:
    from constellation import ConstellationCore

    model = ConstellationCore(num_classes=10, dim=192, n_anchors=64)
    out = model(emb)  # emb: (B, dim) L2-normalized; dict: logits, embedding, triangulation, nearest, patchwork
    loss, ld = model.compute_loss(out, targets, output_aug=out2)
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from dataclasses import dataclass |
| from typing import Optional, Dict, Any |
|
|
|
|
| |
| |
| |
|
|
class SquaredReLU(nn.Module):
    """Elementwise squared ReLU: ``x -> relu(x) ** 2``.

    Stateless (no parameters). The original author reports it as the best
    performer in their bulk activation tests.
    """

    def forward(self, x):
        rectified = F.relu(x)
        return rectified ** 2
|
|
|
|
class StarReLU(nn.Module):
    """Squared ReLU with a learnable affine: ``relu(x) ** 2 * scale + bias``.

    ``scale`` and ``bias`` are single-element parameters shared across all
    positions, initialised to 0.8944 and -0.4472 respectively. The original
    author reports it as the runner-up in their bulk activation tests.
    """

    def __init__(self):
        super().__init__()
        # Registration order (scale, then bias) is kept so state_dict /
        # parameters() ordering matches existing checkpoints.
        self.scale = nn.Parameter(torch.full((1,), 0.8944))
        self.bias = nn.Parameter(torch.full((1,), -0.4472))

    def forward(self, x):
        squared = F.relu(x) ** 2
        return squared * self.scale + self.bias
|
|
|
|
# Registry mapping an activation name to a zero-argument factory that
# returns a fresh nn.Module. Classes are themselves zero-argument callables,
# so they are registered directly.
ACTIVATIONS = {
    'squared_relu': SquaredReLU,
    'star_relu': StarReLU,
    'gelu': nn.GELU,
    'relu': nn.ReLU,
    'sigmoid': nn.Sigmoid,
}
|
|
|
|
def make_activation(name='squared_relu'):
    """Instantiate the activation module registered under ``name``.

    Args:
        name: key into the module-level ``ACTIVATIONS`` registry.

    Returns:
        A newly constructed activation module.

    Raises:
        ValueError: if ``name`` is not a registered activation.
    """
    factory = ACTIVATIONS.get(name)
    if factory is None:
        raise ValueError(f"Unknown activation '{name}'. Choose from: {list(ACTIVATIONS.keys())}")
    return factory()
|
|
|
|
| |
| |
| |
|
|
def init_anchors_xavier(n, d):
    """Draw an (n, d) Xavier-normal matrix and L2-normalize each row.

    Args:
        n: number of anchors (rows).
        d: dimensionality of each anchor.

    Returns:
        (n, d) tensor whose rows have unit L2 norm.
    """
    # xavier_normal_ fills in place and returns the same tensor.
    raw = nn.init.xavier_normal_(torch.empty(n, d))
    return F.normalize(raw, dim=-1)
|
|
|
|
def init_anchors_orthogonal(n, d):
    """Build n anchors in R^d from the Q factor of a random Gaussian matrix.

    When ``n <= d`` the rows are exactly orthonormal. When ``n > d`` only the
    first ``d`` rows form an orthonormal basis; the remaining ``n - d`` rows
    are independent random unit vectors.

    Args:
        n: number of anchors (rows).
        d: dimensionality of each anchor.

    Returns:
        (n, d) tensor of unit-norm rows.
    """
    if n > d:
        # A full basis only supplies d directions; pad with random unit vectors.
        q_full, _ = torch.linalg.qr(torch.randn(d, d))
        padding = F.normalize(torch.randn(n - d, d), dim=-1)
        return torch.cat([q_full.T, padding], dim=0)

    # Reduced QR of a (d, n) matrix yields n orthonormal columns.
    q_thin, _ = torch.linalg.qr(torch.randn(d, n))
    return q_thin.T.contiguous()
|
|
|
|
def init_anchors_repulsion(n, d, iters=200, lr=0.05):
    """Orthogonal (QR) init followed by iterative nearest-neighbour repulsion.

    On every iteration each anchor takes a step of size ``lr`` directly away
    from its currently most similar other anchor and is re-projected onto the
    unit sphere.

    Args:
        n: number of anchors.
        d: dimensionality of each anchor.
        iters: number of repulsion iterations.
        lr: step size of each repulsion move.

    Returns:
        (n, d) tensor of unit-norm anchors.
    """
    points = F.normalize(init_anchors_orthogonal(n, d), dim=-1)
    for _step in range(iters):
        gram = points @ points.T
        # -2 is below any cosine, so an anchor never selects itself.
        gram.fill_diagonal_(-2.0)
        closest = gram.argmax(dim=1)
        points = F.normalize(points - lr * points[closest], dim=-1)
    return points
|
|
|
|
# Registry of anchor initialisers, keyed by the ``anchor_init`` name accepted
# by Constellation. Every entry is callable as ``fn(n, d)``.
INIT_METHODS = dict(
    xavier=init_anchors_xavier,
    orthogonal=init_anchors_orthogonal,
    repulsion=init_anchors_repulsion,
)
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """Learnable anchors on S^(d-1) used to triangulate input embeddings.

    Triangulation is a single matmul: ``emb @ anchors.T`` gives cosines, and
    ``1 - cos`` is reported as the distance. No SLERP, no phases, no
    home/learned split.

    Args:
        n_anchors: number of reference points on S^(d-1)
        dim: dimensionality of the sphere
        anchor_drop: fraction of anchors dropped when triangulating with
            ``training=True`` (ConstellationCore defaults this to 0.15)
        anchor_init: key into ``INIT_METHODS`` ('repulsion', 'xavier', or
            'orthogonal')
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
        super().__init__()
        build_anchors = INIT_METHODS[anchor_init]
        self.anchors = nn.Parameter(build_anchors(n_anchors, dim))
        self.n_anchors = n_anchors
        self.dim = dim
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Measure ``emb`` against the (re-normalized) anchors.

        Args:
            emb: (B, D) embeddings, expected to be L2-normalized.
            training: when true and ``anchor_drop > 0``, a random subset of
                anchors is dropped for this call.

        Returns:
            (tri, nearest) where
            tri: (B, A) values ``1 - cos`` to each anchor. With anchor
                dropout active, only the kept anchors' columns are present.
            nearest: (B,) index of the most similar anchor, always expressed
                as an index into the full anchor set.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)

        if not (training and self.anchor_drop > 0):
            cos = emb @ unit_anchors.T
            return 1.0 - cos, cos.max(dim=-1).indices

        keep = torch.rand(unit_anchors.shape[0], device=unit_anchors.device) > self.anchor_drop
        if keep.sum() < 2:
            # Never triangulate against fewer than two anchors.
            keep[:2] = True
        cos = emb @ unit_anchors[keep].T
        local_best = cos.max(dim=-1).indices
        # Translate positions within the kept subset back to global indices.
        kept_idx = keep.nonzero(as_tuple=True)[0]
        return 1.0 - cos, kept_idx[local_best]

    def forward(self, emb, training=False):
        """Alias for :meth:`triangulate`."""
        return self.triangulate(emb, training=training)
|
|
|
|
| |
| |
| |
|
|
class Patchwork(nn.Module):
    """Round-robin compartments reading diverse anchor subsets.

    Anchor ``k`` is assigned to compartment ``k % n_comp``, so with 64
    anchors and 8 compartments each compartment reads 8 anchors.
    Each compartment is:
        Linear(anchors_per, d_comp * 2) → act → Linear(d_comp * 2, d_comp) → LayerNorm

    Args:
        n_anchors: total anchors (must be divisible by n_comp)
        n_comp: number of compartments
        d_comp: output dim per compartment
        activation: activation function name (see ``make_activation``)

    Raises:
        ValueError: if ``n_comp`` is not positive or does not divide
            ``n_anchors``. Without this check the mismatch only surfaces as a
            shape error inside ``forward``.
    """

    def __init__(self, n_anchors, n_comp=8, d_comp=64, activation='squared_relu'):
        super().__init__()
        if n_comp < 1 or n_anchors % n_comp != 0:
            raise ValueError(
                f"n_anchors ({n_anchors}) must be divisible by a positive "
                f"n_comp ({n_comp})")
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.output_dim = n_comp * d_comp

        # Kept as a buffer so existing state_dicts (which contain 'asgn')
        # continue to load; forward() relies on the same k % n_comp layout.
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        anchors_per = n_anchors // n_comp

        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(anchors_per, d_comp * 2),
                make_activation(activation),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ) for _ in range(n_comp)
        ])

    def forward(self, tri):
        """tri: (B, n_anchors) → (B, n_comp * d_comp)"""
        # Columns k, k + n_comp, k + 2 * n_comp, ... are exactly those with
        # asgn == k. A strided slice selects them without the nonzero() call
        # (and device sync) that boolean-mask indexing performs.
        return torch.cat([
            comp(tri[:, k::self.n_comp])
            for k, comp in enumerate(self.comps)
        ], dim=-1)
|
|
|
|
| |
| |
| |
|
|
class ConstellationCore(nn.Module):
    """Constellation + Patchwork + Classifier.

    Forward returns a dict with all outputs for downstream consumers.
    The classifier reads cat(patchwork, embedding).

    Args:
        num_classes: classification targets
        dim: embedding dimension (encoder output)
        n_anchors: anchors on S^(dim-1)
        n_comp: patchwork compartments
        d_comp: output dim per patchwork compartment
        anchor_drop: training dropout rate for anchors
        anchor_init: anchor initialization method
        activation: activation for patchwork compartments and classifier
        cv_target: target CV for the geometric loss
        infonce_temp: temperature for InfoNCE
    """

    def __init__(
        self,
        num_classes=10,
        dim=192,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        anchor_init='repulsion',
        activation='squared_relu',
        cv_target=0.22,
        infonce_temp=0.07,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp

        # Constructor arguments, recorded explicitly rather than scraped from
        # locals() so the contents do not depend on interpreter scoping rules.
        self.config = {
            'num_classes': num_classes,
            'dim': dim,
            'n_anchors': n_anchors,
            'n_comp': n_comp,
            'd_comp': d_comp,
            'anchor_drop': anchor_drop,
            'anchor_init': anchor_init,
            'activation': activation,
            'cv_target': cv_target,
            'infonce_temp': infonce_temp,
        }

        self.constellation = Constellation(
            n_anchors, dim, anchor_drop, anchor_init)

        self.patchwork = Patchwork(
            n_anchors, n_comp, d_comp, activation)

        pw_dim = self.patchwork.output_dim

        # Classifier input is cat(patchwork, embedding).
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + dim, pw_dim),
            make_activation(activation),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes),
        )

    def forward(self, emb_normalized):
        """Forward pass on L2-normalized embeddings.

        Args:
            emb_normalized: (B, D) already on S^(d-1)

        Returns:
            dict with: logits, embedding, triangulation, nearest, patchwork
        """
        emb = emb_normalized

        # The patchwork always reads the full, un-dropped triangulation.
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)

        # In training mode only the reported `nearest` index is recomputed
        # with anchor dropout; `tri` and `pw` above are left untouched.
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)

        logits = self.classifier(torch.cat([pw, emb], dim=-1))

        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
            'patchwork': pw,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Compute all losses.

        total = ce + 1.0 * nce (if a second view is given) + 0.5 * attract
                + 0.01 * cv + 0.001 * spread

        Args:
            output: dict from forward()
            targets: (B,) class indices
            output_aug: optional dict from forward() on a second view

        Returns:
            (total_loss, loss_dict). loss_dict holds the individual loss
            tensors plus float metrics 'acc', 'nearest_cos' and, with a
            second view, 'nce_acc'.
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # Classification.
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE between the two views: row i of `emb` should match row i
        # of `emb_aug`.
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc

        # Attraction: pull each embedding toward its most similar anchor.
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values
        l_attract = (1.0 - nearest_cos).mean()
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()

        # Coefficient-of-variation regulariser on simplex volumes.
        l_cv = GeometricOps.cv_loss(emb, target=self.cv_target)
        ld['cv'] = l_cv

        # Keep anchors spread apart.
        l_spread = GeometricOps.anchor_spread_loss(self.constellation.anchors)
        ld['spread'] = l_spread

        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """Push anchors toward class centroids — self-distillation across time.

        Phase 1: compute a unit-norm centroid per class present in the labels.
        Phase 2: greedily assign anchors to classes in order of decreasing
                 cosine, with at most ``n_anchors // n_classes + 1`` anchors
                 per class.
        Phase 3: move each anchor a fraction ``lr`` of the way toward its
                 class centroid (linear blend, then re-normalize). Every
                 anchor after the first in a class aims at a slightly
                 perturbed centroid (perturbation perpendicular to the
                 centroid) so co-class anchors don't collapse.

        Anchors are updated in place on ``self.constellation.anchors``.

        Args:
            emb_buffer: (N, D) accumulated embeddings
            label_buffer: (N,) class labels
            lr: blend rate toward centroid

        Returns:
            number of anchors moved (0 if the buffers are empty)
        """
        anchors = self.constellation.anchors.data
        n_a = anchors.shape[0]
        device = anchors.device

        # Buffers are often accumulated on CPU or in reduced precision; align
        # them with the anchors so the matmuls below cannot fail on a
        # device/dtype mismatch.
        emb_n = F.normalize(
            emb_buffer.to(device=device, dtype=anchors.dtype), dim=-1)
        label_buffer = label_buffer.to(device)

        # Phase 1: class centroids.
        classes = label_buffer.unique()
        n_cls = classes.shape[0]
        if n_cls == 0:
            return 0
        centroids = torch.stack([
            F.normalize(emb_n[label_buffer == c].mean(0), dim=-1)
            for c in classes
        ])

        # Phase 2: greedy capacity-limited assignment. The sorted order is
        # pulled to the host once so the loop runs on plain Python ints
        # instead of synchronising with the device on every step.
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T
        capacity = n_a // n_cls + 1
        assigned_class = [-1] * n_a
        class_count = [0] * n_cls
        remaining = n_a
        order = cos.flatten().sort(descending=True).indices.tolist()
        for flat in order:
            a, c = divmod(flat, n_cls)
            if assigned_class[a] >= 0 or class_count[c] >= capacity:
                continue
            assigned_class[a] = c
            class_count[c] += 1
            remaining -= 1
            if remaining == 0:
                break

        # Defensive fallback: any anchor still unassigned takes its most
        # similar centroid regardless of capacity.
        if remaining:
            for a in range(n_a):
                if assigned_class[a] < 0:
                    assigned_class[a] = int(cos[a].argmax())

        # Phase 3: blend toward targets. The perturbation is applied whenever
        # an earlier anchor already shares the class. It must not depend on
        # n_a // n_cls, because the +1 capacity lets two anchors share a
        # class even when that quotient is 1.
        seen_in_class = [0] * n_cls
        for a, c in enumerate(assigned_class):
            target = centroids[c]
            if seen_in_class[c] > 0:
                noise = torch.randn_like(target) * 0.05
                # Remove the component along the centroid.
                noise = noise - (noise * target).sum() * target
                target = F.normalize(target + noise, dim=-1)
            seen_in_class[c] += 1

            anchors[a] = F.normalize(
                anchors_n[a] + lr * (target - anchors_n[a]), dim=-1)

        return n_a
|
|
|
|
| |
| |
| |
|
|
class ConstellationRelay(nn.Module):
    """Per-token geometric processing with a gated residual.

    Every token is processed independently, so cost is linear in sequence
    length. (The original author reports 99.4% cosine similarity preserved
    at depth 16.)

    Pipeline:
        LayerNorm → L2 normalize → triangulate → patchwork → project → gated residual

    Args:
        dim: token dimension
        n_anchors: anchors on S^(dim-1)
        n_comp: patchwork compartments
        d_comp: output dim per patchwork compartment
        gate_init: initial per-channel gate logit (-3.0 → sigmoid ≈ 0.047)
        anchor_init: anchor initialization method
        activation: activation function name
    """

    def __init__(
        self,
        dim,
        n_anchors=16,
        n_comp=8,
        d_comp=64,
        gate_init=-3.0,
        anchor_init='repulsion',
        activation='squared_relu',
    ):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.constellation = Constellation(n_anchors, dim, anchor_init=anchor_init)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)
        # Maps the concatenated compartment outputs back to token width.
        self.proj = nn.Linear(self.patchwork.output_dim, dim)
        # One gate logit per channel; passed through a sigmoid in forward().
        self.gate = nn.Parameter(torch.full((dim,), gate_init))

    def forward(self, x):
        """x: (B, S, D) or (B, D) → same shape."""
        was_2d = x.dim() == 2
        tokens = x.unsqueeze(1) if was_2d else x
        B, S, D = tokens.shape

        # Flatten tokens so each one is triangulated on its own.
        unit = F.normalize(self.norm(tokens).reshape(B * S, D), dim=-1)
        tri, _ = self.constellation.triangulate(unit)
        delta = self.proj(self.patchwork(tri)).reshape(B, S, D)

        mixed = tokens + torch.sigmoid(self.gate) * delta
        return mixed.squeeze(1) if was_2d else mixed
|
|
|
|
| |
| |
| |
|
|
class GeometricOps:
    """Static geometric utilities: simplex volumes, CV, anchor spread, diagnostics."""

    @staticmethod
    def cayley_menger_vol2(points):
        """Squared simplex volume via the Cayley-Menger determinant.

        Args:
            points: (B, N, D) — B simplices of N vertices each.

        Returns:
            (B,) squared volumes in ``points.dtype``. The determinant itself
            is evaluated in float32.
        """
        B, N, D = points.shape
        gram = torch.bmm(points, points.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Pairwise squared distances; clamp tiny negatives from round-off.
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
        cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        k = N - 1
        sign = (-1.0) ** (k + 1)
        fact = math.factorial(k)
        return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))

    @staticmethod
    @torch.no_grad()
    def cv_metric(emb, n_samples=200, n_points=5):
        """Non-differentiable CV (std / mean) of random simplex volumes.

        Intended for monitoring; the original author's target band is
        0.20–0.23.

        Args:
            emb: (N, D) embeddings.
            n_samples: number of random simplices to draw.
            n_points: vertices per simplex.

        Returns:
            float CV, or 0.0 when ``emb`` has fewer than ``n_points`` rows or
            fewer than 10 sampled simplices have non-degenerate volume.
        """
        n = emb.shape[0]
        # Same guard as cv_loss: with fewer rows than vertices every draw is
        # the same point set (and an empty input would hit factorial(-1)).
        if n < n_points or n_samples < 1:
            return 0.0
        idx = torch.stack(
            [torch.randperm(n)[:n_points] for _ in range(n_samples)])
        # All simplices are evaluated in one batched determinant.
        v2 = GeometricOps.cayley_menger_vol2(emb[idx.to(emb.device)])
        vols = v2[v2 > 1e-20].sqrt()
        if vols.numel() < 10:
            return 0.0
        return (vols.std() / (vols.mean() + 1e-8)).item()

    @staticmethod
    def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
        """Differentiable CV loss: ``(CV - target) ** 2``.

        ConstellationCore applies it with weight 0.01.

        Args:
            emb: (B, D) embeddings; simplices are sampled from the first
                ``min(B, 512)`` rows.
            target: desired coefficient of variation.
            n_samples: number of random simplices to draw.
            n_points: vertices per simplex.

        Returns:
            scalar tensor; a constant 0 when ``B < n_points`` or fewer than
            5 sampled simplices have non-degenerate volume.
        """
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)
        pool = min(B, 512)
        vols = []
        # Sampled one simplex at a time on purpose: degenerate simplices are
        # discarded before they enter the autograd graph, so the determinant
        # backward is never evaluated on a (near-)singular matrix.
        for _ in range(n_samples):
            idx = torch.randperm(pool, device=emb.device)[:n_points]
            v2 = GeometricOps.cayley_menger_vol2(emb[idx].unsqueeze(0))[0]
            if v2.item() > 1e-20:
                vols.append(v2.sqrt())
        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)
        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - target).pow(2)

    @staticmethod
    def anchor_spread_loss(anchors, target_cos=0.0):
        """Repulsion loss keeping anchors spread.

        Mean over all ordered anchor pairs of ``relu(cos - target_cos)``.

        Args:
            anchors: (A, D) anchors (normalized internally).
            target_cos: pairwise cosine above which a pair is penalised.

        Returns:
            scalar tensor; 0 when there are fewer than two anchors.
        """
        a = F.normalize(anchors, dim=-1)
        if a.shape[0] < 2:
            # No pairs exist; mean() over an empty selection would be NaN.
            # Multiplying by zero keeps dtype, device and graph connectivity.
            return a.sum() * 0.0
        sim = a @ a.T
        off_diag = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)
        return F.relu(sim[off_diag] - target_cos).mean()

    @staticmethod
    def diagnostics(constellation, emb):
        """Compute health metrics from a constellation and embeddings.

        Args:
            constellation: object exposing ``triangulate``, ``anchors`` and
                ``n_anchors`` (e.g. a Constellation).
            emb: (B, D) embeddings.

        Returns:
            dict with
            n_active: number of distinct anchors that are nearest to at
                least one embedding
            nearest_cos: mean cosine to the nearest anchor
            anchor_util_std / _min / _max: statistics of how many embeddings
                each anchor is nearest to
        """
        tri, nearest = constellation.triangulate(emb, training=False)
        n_active = nearest.unique().numel()
        anchors_n = F.normalize(constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values.mean().item()
        counts = torch.bincount(nearest, minlength=constellation.n_anchors).float()
        return {
            'n_active': n_active,
            'nearest_cos': nearest_cos,
            'anchor_util_std': counts.std().item(),
            'anchor_util_min': counts.min().item(),
            'anchor_util_max': counts.max().item(),
        }
|
|
|
|
| |
| |
| |
|
|
class GeometricAutograd(torch.autograd.Function):
    """Manifold-aware gradient correction on S^(D-1).

    Forward: identity on ``emb``.
    Backward: the radial part of the incoming gradient (its component along
    ``emb``) is scaled by ``1 - tang_strength``; then, if ``sep_strength > 0``,
    any positive component along the nearest anchor is reduced by
    ``sep_strength`` times its size.

    Only ``emb`` receives a gradient; ``anchors`` and both strengths get None.

    Settings reported as proven by the original author: tang=0.01, sep=1.0
    """

    @staticmethod
    def forward(ctx, emb, anchors, tang_strength, sep_strength):
        ctx.save_for_backward(emb, anchors)
        ctx.tang, ctx.sep = tang_strength, sep_strength
        return emb

    @staticmethod
    def backward(ctx, grad):
        emb, anchors = ctx.saved_tensors

        # Split the gradient into radial (along emb) and tangential parts.
        radial = (grad * emb).sum(dim=-1, keepdim=True) * emb
        out = (grad - radial) + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            unit_anchors = F.normalize(anchors.detach(), dim=-1)
            closest_idx = (emb @ unit_anchors.T).argmax(dim=-1)
            closest = unit_anchors[closest_idx]
            along = (out * closest).sum(dim=-1, keepdim=True)
            # Only a positive component along the nearest anchor is removed.
            out = out - ctx.sep * F.relu(along) * closest

        return out, None, None, None