geolip-hypersphere-experiments / constellation.py
AbstractPhil's picture
Update constellation.py
e50ad6c verified
"""
Constellation — Geometric Observer + Interpreter
===================================================
Aligned to the proven GeoLIP Core trainer (91.2% CIFAR-10 @ 1.65M params).

Architecture:
    emb @ anchors.T → 64 distances → 8 round-robin compartments → cat(pw, emb) → classifier

Key mechanisms:
    - Round-robin compartments: 8 groups of 8 anchors, diverse measurements per group
    - cat(patchwork, embedding): classifier sees both interpreted distances AND raw position
    - Anchor push: direct centroid placement every N batches (self-distillation across time)
    - Attraction loss: pulls embeddings toward nearest anchor
    - InfoNCE on two views: alignment force
    - Simple triangulation: emb @ anchors.T, no SLERP, no phases

Classes:
    Constellation      — triangulation against anchors on S^(d-1)
    Patchwork          — round-robin compartmentalized interpretation
    ConstellationCore  — full pipeline: constellation + patchwork + classifier
    ConstellationRelay — Form 5 per-token geometric layer with gated residual
    GeometricOps       — CV, spread, Cayley-Menger utilities
    GeometricAutograd  — Form 12 manifold-aware gradient correction

Usage:
    from constellation import ConstellationCore
    model = ConstellationCore(num_classes=10, dim=192, n_anchors=64)
    # emb: (B, 192) L2-normalized encoder output (forward does not normalize)
    out = model(emb)  # dict: logits, embedding, triangulation, nearest, patchwork
    loss, ld = model.compute_loss(out, targets, output_aug=out2)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Dict, Any
# ══════════════════════════════════════════════════════════════════
# ACTIVATIONS
# ══════════════════════════════════════════════════════════════════
class SquaredReLU(nn.Module):
    """Squared ReLU activation: x → ReLU(x)², applied elementwise.

    The original notes describe it as the #1 performer in bulk activation
    tests. It has no parameters.
    """

    def forward(self, x):
        rectified = F.relu(x)
        return rectified ** 2
class StarReLU(nn.Module):
    """StarReLU activation: x → ReLU(x)² * scale + bias.

    ``scale`` and ``bias`` are learnable single-element parameters,
    initialised to 0.8944 and -0.4472. The original notes describe it as
    the runner-up in bulk activation tests.
    """

    def __init__(self):
        super().__init__()
        # Registration order (scale, then bias) is kept so state_dict layout matches.
        self.scale = nn.Parameter(torch.full((1,), 0.8944))
        self.bias = nn.Parameter(torch.full((1,), -0.4472))

    def forward(self, x):
        squared = F.relu(x) ** 2
        return squared * self.scale + self.bias
# Registry: activation name → zero-argument factory returning a fresh nn.Module.
# Looked up by make_activation().
ACTIVATIONS = dict(
    squared_relu=SquaredReLU,
    star_relu=StarReLU,
    gelu=lambda: nn.GELU(),
    relu=lambda: nn.ReLU(),
    sigmoid=lambda: nn.Sigmoid(),
)
def make_activation(name='squared_relu'):
    """Instantiate the activation module registered under ``name``.

    Args:
        name: key in ``ACTIVATIONS``.

    Returns:
        A new module produced by the registered factory.

    Raises:
        ValueError: if ``name`` is not a registered activation.
    """
    factory = ACTIVATIONS.get(name)
    if factory is None:
        raise ValueError(f"Unknown activation '{name}'. Choose from: {list(ACTIVATIONS.keys())}")
    return factory()
# ══════════════════════════════════════════════════════════════════
# ANCHOR INITIALIZATION
# ══════════════════════════════════════════════════════════════════
def init_anchors_xavier(n, d):
    """Draw an (n, d) Xavier-normal matrix and L2-normalize each row.

    The original notes describe the result as near-orthogonal in high
    dimension.

    Returns:
        (n, d) tensor whose rows have unit norm.
    """
    # xavier_normal_ fills in place and returns the same tensor.
    raw = nn.init.xavier_normal_(torch.empty(n, d))
    return F.normalize(raw, dim=-1)
def init_anchors_orthogonal(n, d):
    """Build n unit vectors in R^d from the QR factorization of a random matrix.

    For n <= d the rows are exactly orthonormal. For n > d the first d rows
    form an orthonormal basis and the remaining n - d rows are independent
    normalized Gaussian vectors.

    Returns:
        (n, d) tensor of unit-norm rows.
    """
    # Reduced QR of a (d, min(n, d)) Gaussian yields orthonormal columns.
    rank = min(n, d)
    q, _ = torch.linalg.qr(torch.randn(d, rank))
    basis = q.T
    if n <= d:
        return basis.contiguous()
    surplus = F.normalize(torch.randn(n - d, d), dim=-1)
    return torch.cat([basis, surplus], dim=0)
def init_anchors_repulsion(n, d, iters=200, lr=0.05):
    """Start from the QR-based init and iteratively push neighbours apart.

    Each iteration finds, for every vector, the other vector with the
    highest cosine similarity, steps away from it by ``lr`` and projects
    back to the unit sphere. The original notes say this is the init used
    in the proven Core.

    Args:
        n: number of vectors.
        d: dimensionality.
        iters: number of repulsion iterations.
        lr: step size away from the nearest neighbour.

    Returns:
        (n, d) tensor of unit-norm rows.
    """
    points = F.normalize(init_anchors_orthogonal(n, d), dim=-1)
    for _step in range(iters):
        gram = points @ points.T
        # Mask self-similarity so argmax picks a *different* vector.
        gram.fill_diagonal_(-2.0)
        closest = gram.argmax(dim=1)
        points = F.normalize(points - lr * points[closest], dim=-1)
    return points
# Registry: anchor-init name → function (n, d) -> (n, d) tensor of anchors.
# Looked up by Constellation via its ``anchor_init`` argument.
INIT_METHODS = dict(
    xavier=init_anchors_xavier,
    orthogonal=init_anchors_orthogonal,
    repulsion=init_anchors_repulsion,
)
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION — triangulation on S^(d-1)
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """Learnable anchors on S^(d-1) that triangulate input embeddings.

    Triangulation is a single matmul: emb @ anchors.T → cosines → (1 - cos)
    distances. No SLERP, no phases, no home/learned split.

    Args:
        n_anchors: number of reference points on S^(d-1)
        dim: dimensionality of the sphere
        anchor_drop: fraction of anchors to drop during training
            (the original notes cite 0.15 as proven)
        anchor_init: key into INIT_METHODS: 'repulsion', 'xavier', or 'orthogonal'
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):
        super().__init__()
        make_anchors = INIT_METHODS[anchor_init]
        self.anchors = nn.Parameter(make_anchors(n_anchors, dim))
        self.anchor_drop = anchor_drop
        self.n_anchors = n_anchors
        self.dim = dim

    def triangulate(self, emb, training=False):
        """Measure ``emb`` against the (re-normalized) anchors.

        Args:
            emb: (B, D) embeddings; the caller is expected to pass them
                L2-normalized, they are not normalized here.
            training: when true and ``anchor_drop > 0``, a random subset of
                anchors is dropped for this call.

        Returns:
            (tri, nearest):
                tri: (B, A) values of 1 - cos to every anchor. Under anchor
                    dropout it only has one column per *kept* anchor, so its
                    width is smaller than A.
                nearest: (B,) index of the highest-cosine anchor, always
                    expressed in the full anchor numbering.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)
        use_dropout = training and self.anchor_drop > 0

        if not use_dropout:
            cos = emb @ unit_anchors.T
            return 1.0 - cos, cos.max(dim=-1).indices

        keep = torch.rand(unit_anchors.shape[0], device=unit_anchors.device) > self.anchor_drop
        # Never let the draw remove (almost) everything: force the first two on.
        if keep.sum() < 2:
            keep[:2] = True
        kept_ids = keep.nonzero(as_tuple=True)[0]
        cos = emb @ unit_anchors[keep].T
        local = cos.max(dim=-1).indices
        # Map positions within the kept subset back to global anchor indices.
        return 1.0 - cos, kept_ids[local]

    def forward(self, emb, training=False):
        """Alias for :meth:`triangulate`."""
        return self.triangulate(emb, training=training)
# ══════════════════════════════════════════════════════════════════
# PATCHWORK — round-robin compartmentalized interpretation
# ══════════════════════════════════════════════════════════════════
class Patchwork(nn.Module):
    """Round-robin compartments reading diverse anchor subsets.

    Example: 64 anchors, 8 compartments → each compartment reads 8 anchors.
    Assignment: anchor k goes to compartment (k % n_comp).
    Each compartment: Linear(anchors_per, d_comp*2) → act → Linear → LN → d_comp

    Args:
        n_anchors: total anchors (must be a positive multiple of n_comp)
        n_comp: number of compartments
        d_comp: output dim per compartment
        activation: activation function name (see make_activation)

    Raises:
        ValueError: if n_anchors is not a positive multiple of n_comp.
    """

    def __init__(self, n_anchors, n_comp=8, d_comp=64, activation='squared_relu'):
        super().__init__()
        # Every compartment's first Linear is sized n_anchors // n_comp, so an
        # uneven split would only surface later as a shape error in forward().
        if n_comp <= 0 or n_anchors <= 0 or n_anchors % n_comp != 0:
            raise ValueError(
                f"n_anchors ({n_anchors}) must be a positive multiple of "
                f"n_comp ({n_comp})")
        self.n_comp = n_comp
        self.d_comp = d_comp
        self.output_dim = n_comp * d_comp
        # Round-robin assignment: anchor k → compartment (k % n_comp).
        # Kept as a buffer so existing state_dicts still load.
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        anchors_per = n_anchors // n_comp
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(anchors_per, d_comp * 2),
                make_activation(activation),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ) for _ in range(n_comp)
        ])

    def forward(self, tri):
        """tri: (B, n_anchors) → (B, n_comp * d_comp)"""
        # Columns k, k + n_comp, k + 2*n_comp, ... are exactly those with
        # asgn == k; a strided slice selects them without building a boolean
        # mask on every call.
        return torch.cat([
            comp(tri[:, k::self.n_comp])
            for k, comp in enumerate(self.comps)
        ], dim=-1)
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION CORE — full pipeline
# ══════════════════════════════════════════════════════════════════
class ConstellationCore(nn.Module):
    """Constellation + Patchwork + Classifier.

    Forward returns a dict with all outputs for downstream consumers.
    The classifier reads cat(patchwork, embedding).

    Args:
        num_classes: classification targets
        dim: embedding dimension (encoder output)
        n_anchors: anchors on S^(dim-1)
        n_comp: patchwork compartments
        d_comp: hidden dim per compartment
        anchor_drop: training dropout rate for anchors
        anchor_init: initialization method
        activation: activation for patchwork compartments and classifier
        cv_target: target CV for geometric loss
        infonce_temp: temperature for InfoNCE
    """

    def __init__(
        self,
        num_classes=10,
        dim=192,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        anchor_init='repulsion',
        activation='squared_relu',
        cv_target=0.22,
        infonce_temp=0.07,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        # Explicit snapshot of the constructor arguments (same keys and order
        # as the signature) instead of filtering locals().
        self.config = {
            'num_classes': num_classes,
            'dim': dim,
            'n_anchors': n_anchors,
            'n_comp': n_comp,
            'd_comp': d_comp,
            'anchor_drop': anchor_drop,
            'anchor_init': anchor_init,
            'activation': activation,
            'cv_target': cv_target,
            'infonce_temp': infonce_temp,
        }
        self.constellation = Constellation(
            n_anchors, dim, anchor_drop, anchor_init)
        self.patchwork = Patchwork(
            n_anchors, n_comp, d_comp, activation)
        pw_dim = self.patchwork.output_dim
        # Classifier reads cat(patchwork, embedding)
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + dim, pw_dim),
            make_activation(activation),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes),
        )

    def forward(self, emb_normalized):
        """Forward pass on L2-normalized embeddings.

        Args:
            emb_normalized: (B, D) already on S^(d-1); not normalized here.

        Returns:
            dict with: logits, embedding, triangulation, nearest, patchwork.
            'triangulation' always covers all anchors; in training mode
            'nearest' comes from a second, anchor-dropout triangulation.
        """
        emb = emb_normalized
        # Full triangulation for patchwork (its Linear layers need every anchor).
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)
        # Dropout version for nearest tracking only
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)
        # Classifier sees BOTH patchwork interpretation AND raw position
        logits = self.classifier(torch.cat([pw, emb], dim=-1))
        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
            'patchwork': pw,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Compute all losses.

        Total = CE + 1.0 * InfoNCE (if a second view is given)
                + 0.5 * attraction + 0.01 * CV + 0.001 * anchor spread.

        Args:
            output: dict from forward()
            targets: (B,) class indices
            output_aug: optional dict from forward() on second view

        Returns:
            (total_loss, loss_dict). loss_dict holds tensors under 'ce',
            'nce' (only with output_aug), 'attract', 'cv', 'spread', 'total'
            and Python floats under 'acc', 'nce_acc' (only with output_aug)
            and 'nearest_cos'.
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]
        # CE classification
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()
        # InfoNCE between augmented views: row i of one view should match
        # row i of the other.
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc
        # Anchor attraction: pull embeddings toward nearest anchor
        # (computed against all anchors, independent of output['nearest']).
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values
        l_attract = (1.0 - nearest_cos).mean()
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()
        # CV on embeddings
        l_cv = GeometricOps.cv_loss(emb, target=self.cv_target)
        ld['cv'] = l_cv
        # Anchor spread
        l_spread = GeometricOps.anchor_spread_loss(self.constellation.anchors)
        ld['spread'] = l_spread
        # Total
        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """Push anchors toward class centroids — self-distillation across time.

        Phase 1: Compute class centroids from labels
        Phase 2: Greedy-assign anchors to classes, highest cosine first, with
                 at most (n_anchors // n_classes) + 1 anchors per class
        Phase 3: Blend each anchor linearly toward its class centroid and
                 re-normalize (a normalized lerp, not a true SLERP), adding a
                 perpendicular perturbation to the target so co-class anchors
                 don't collapse

        The anchor parameter is updated in place. Buffers are moved to the
        anchors' device and dtype first, so CPU-accumulated buffers work.

        Args:
            emb_buffer: (N, D) accumulated embeddings
            label_buffer: (N,) class labels
            lr: blend rate toward centroid

        Returns:
            number of anchors moved (0 if the buffers are empty)
        """
        anchors = self.constellation.anchors.data
        n_a = anchors.shape[0]
        device = anchors.device
        emb_n = F.normalize(
            emb_buffer.to(device=device, dtype=anchors.dtype), dim=-1)
        labels = label_buffer.to(device)

        # Phase 1: class centroids
        classes = labels.unique()
        n_cls = classes.shape[0]
        centroids = []
        for c in classes:
            mask = labels == c
            if mask.sum() > 0:
                centroids.append(
                    F.normalize(emb_n[mask].mean(0, keepdim=True), dim=-1))
        if len(centroids) == 0:
            return 0
        centroids = torch.cat(centroids, dim=0)

        # Phase 2: greedy anchor-to-class assignment.
        # The sorted order is pulled to the host once and the bookkeeping is
        # done on Python lists, avoiding a device sync per (anchor, class) pair.
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T
        anchors_per_class = n_a // n_cls
        capacity = anchors_per_class + 1
        order = cos.flatten().sort(descending=True).indices.tolist()
        assigned = [-1] * n_a
        class_count = [0] * n_cls
        remaining = n_a
        for flat in order:
            a, c = divmod(flat, n_cls)
            if assigned[a] >= 0 or class_count[c] >= capacity:
                continue
            assigned[a] = c
            class_count[c] += 1
            remaining -= 1
            if remaining == 0:
                break

        # Unassigned leftovers: fall back to the best-matching class.
        leftovers = [a for a, c in enumerate(assigned) if c < 0]
        if leftovers:
            left_idx = torch.tensor(leftovers, dtype=torch.long, device=device)
            best = (anchors_n[left_idx] @ centroids.T).argmax(dim=1).tolist()
            for a, c in zip(leftovers, best):
                assigned[a] = c

        # Phase 3: push with perpendicular perturbation
        seen = [0] * n_cls  # anchors already processed per class
        moved = 0
        for a, c in enumerate(assigned):
            target = centroids[c]
            rank_in_class = seen[c]
            seen[c] += 1
            # NOTE(review): with anchors_per_class == 1 a class can still hold
            # two anchors (capacity is anchors_per_class + 1); those share an
            # unperturbed target. Behavior kept as in the original — confirm
            # whether the first condition is intended.
            if anchors_per_class > 1 and rank_in_class > 0:
                noise = torch.randn_like(target) * 0.05
                # Remove the component along the centroid so the noise is
                # tangent to the sphere at the target.
                noise = noise - (noise * target).sum() * target
                target = F.normalize(
                    (target + noise).unsqueeze(0), dim=-1).squeeze(0)
            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1
        return moved
# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY — Form 5 (per-token geometric layer)
# ══════════════════════════════════════════════════════════════════
class ConstellationRelay(nn.Module):
    """Per-token geometric processing with gated residual.

    Every token is processed independently. The original notes report O(S)
    complexity and 99.4% cosine similarity preserved at depth 16.

    Pipeline:
        LayerNorm → L2 normalize → triangulate → patchwork → project → gated residual

    Args:
        dim: token dimension
        n_anchors: anchors on S^(dim-1)
        n_comp: patchwork compartments
        d_comp: hidden dim per compartment
        gate_init: initial gate value before the sigmoid (-3.0 → sigmoid ≈ 0.047)
        anchor_init: initialization method
        activation: activation function name
    """

    def __init__(
        self,
        dim,
        n_anchors=16,
        n_comp=8,
        d_comp=64,
        gate_init=-3.0,
        anchor_init='repulsion',
        activation='squared_relu',
    ):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.constellation = Constellation(n_anchors, dim, anchor_init=anchor_init)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)
        # Maps the concatenated compartment outputs back to the token width.
        self.proj = nn.Linear(self.patchwork.output_dim, dim)
        # One learnable gate per channel, squashed by a sigmoid in forward().
        self.gate = nn.Parameter(torch.full((dim,), gate_init))

    def forward(self, x):
        """x: (B, S, D) or (B, D) → tensor of the same shape."""
        was_2d = x.dim() == 2
        tokens = x.unsqueeze(1) if was_2d else x
        B, S, D = tokens.shape

        # Flatten tokens so each one is triangulated on its own.
        unit = F.normalize(self.norm(tokens).reshape(B * S, D), dim=-1)
        tri, _nearest = self.constellation.triangulate(unit)
        delta = self.proj(self.patchwork(tri)).reshape(B, S, D)

        out = tokens + torch.sigmoid(self.gate) * delta
        return out.squeeze(1) if was_2d else out
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC OPS
# ══════════════════════════════════════════════════════════════════
class GeometricOps:
    """Static geometric utilities (simplex volumes, CV, anchor spread, diagnostics)."""

    @staticmethod
    def _cayley_menger_matrix(points):
        """Bordered squared-distance matrix. points: (B, N, D) → (B, N+1, N+1).

        Row/column 0 is the border of ones (with a 0 in the corner); the
        remaining N x N block holds pairwise squared distances.
        """
        B, N, _ = points.shape
        gram = torch.bmm(points, points.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        # Clamp small negative values produced by round-off.
        d2 = F.relu(d2)
        cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        return cm

    @staticmethod
    def cayley_menger_vol2(points):
        """Squared simplex volume. points: (B, N, D) → (B,)."""
        N = points.shape[1]
        cm = GeometricOps._cayley_menger_matrix(points)
        k = N - 1
        sign = (-1.0) ** (k + 1)
        fact = math.factorial(k)
        # The determinant is taken in float32, then cast back to the input dtype.
        return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))

    @staticmethod
    @torch.no_grad()
    def cv_metric(emb, n_samples=200, n_points=5):
        """Non-differentiable CV for monitoring. Target band: 0.20–0.23.

        Samples ``n_samples`` random ``n_points``-subsets of ``emb`` and
        returns std / mean of their simplex volumes as a Python float.
        Returns 0.0 when ``emb`` has fewer than ``n_points`` rows or fewer
        than 10 sampled volumes are non-degenerate.
        """
        # Too few rows to form a single simplex (same convention as cv_loss).
        if emb.shape[0] < n_points:
            return 0.0
        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(emb.shape[0])[:n_points]
            v2 = GeometricOps.cayley_menger_vol2(emb[idx].unsqueeze(0))
            if v2[0] > 1e-20:
                vols.append(v2[0].sqrt())
        if len(vols) < 10:
            return 0.0
        vols_t = torch.stack(vols)
        return (vols_t.std() / (vols_t.mean() + 1e-8)).item()

    @staticmethod
    def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
        """Differentiable CV loss. Weight: 0.01 or below.

        Returns (cv - target)² where cv = std / mean of sampled simplex
        volumes. Only the first 512 rows of ``emb`` are sampled from. Returns
        a constant zero tensor when the batch has fewer than ``n_points``
        rows or fewer than 5 sampled volumes are non-degenerate.
        """
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)
        pool = min(B, 512)
        # Cayley-Menger prefactor; depends only on n_points, so computed once.
        k = n_points - 1
        pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(pool, device=emb.device)[:n_points]
            cm = GeometricOps._cayley_menger_matrix(emb[idx].unsqueeze(0))
            v2 = pf * torch.linalg.det(cm.float())
            # Degenerate samples are skipped before sqrt so they never enter
            # the autograd graph.
            if v2[0].item() > 1e-20:
                vols.append(v2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)
        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - target).pow(2)

    @staticmethod
    def anchor_spread_loss(anchors, target_cos=0.0):
        """Repulsion loss keeping anchors spread.

        Mean over all ordered anchor pairs of relu(cos - target_cos).
        Returns zero when there are fewer than two anchors (no pairs).
        """
        a = F.normalize(anchors, dim=-1)
        n = a.shape[0]
        if n < 2:
            # mean() over an empty selection would be NaN.
            return a.new_zeros(())
        sim = a @ a.T
        off_diag = ~torch.eye(n, dtype=torch.bool, device=a.device)
        return F.relu(sim[off_diag] - target_cos).mean()

    @staticmethod
    def diagnostics(constellation, emb):
        """Compute health metrics from a constellation and embeddings.

        Returns a dict with the number of distinct nearest anchors
        ('n_active'), the mean cosine to the nearest anchor ('nearest_cos')
        and std/min/max of the per-anchor assignment counts.
        """
        tri, nearest = constellation.triangulate(emb, training=False)
        n_active = nearest.unique().numel()
        anchors_n = F.normalize(constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values.mean().item()
        counts = torch.bincount(nearest, minlength=constellation.n_anchors).float()
        return {
            'n_active': n_active,
            'nearest_cos': nearest_cos,
            'anchor_util_std': counts.std().item(),
            'anchor_util_min': counts.min().item(),
            'anchor_util_max': counts.max().item(),
        }
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC AUTOGRAD — Form 12
# ══════════════════════════════════════════════════════════════════
class GeometricAutograd(torch.autograd.Function):
    """Manifold-aware gradient correction on S^(D-1).

    Forward: identity on ``emb``.
    Backward: damps the radial part of the incoming gradient by
    ``tang_strength`` and, when ``sep_strength > 0``, removes the positive
    component of the gradient along each sample's nearest anchor.

    The original notes cite tang=0.01, sep=1.0 as proven settings.
    Only ``emb`` receives a gradient; the other inputs get None.
    """

    @staticmethod
    def forward(ctx, emb, anchors, tang_strength, sep_strength):
        ctx.tang = tang_strength
        ctx.sep = sep_strength
        ctx.save_for_backward(emb, anchors)
        return emb

    @staticmethod
    def backward(ctx, grad):
        emb, anchors = ctx.saved_tensors

        # Split the gradient into its component along emb and the remainder,
        # then recombine with the radial part scaled by (1 - tang).
        radial = (grad * emb).sum(dim=-1, keepdim=True) * emb
        tangential = grad - radial
        corrected = tangential + (1.0 - ctx.tang) * radial

        if ctx.sep > 0:
            unit_anchors = F.normalize(anchors.detach(), dim=-1)
            closest = unit_anchors[(emb @ unit_anchors.T).argmax(dim=-1)]
            # Only the positive projection onto the nearest anchor is removed.
            along = (corrected * closest).sum(dim=-1, keepdim=True)
            corrected = corrected - ctx.sep * F.relu(along) * closest

        return corrected, None, None, None