#!/usr/bin/env python3
"""
GeoLIP Core — Back to Basics
==============================
Conv encoder → sphere → constellation → patchwork → classifier.
No streams. No GAL. No Procrustes. No mastery queue.
Just the geometric classification pipeline.
Two augmented views → InfoNCE + CE + CV.
"""

import math
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import combinations
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# ══════════════════════════════════════════════════════════════════
# UNIFORM HYPERSPHERE INIT
# ══════════════════════════════════════════════════════════════════
def uniform_hypersphere_init(n, d):
    """Return (n, d) unit vectors roughly uniformly spread on the d-sphere.

    n <= d: rows of an orthonormal basis (exactly mutually orthogonal).
    n > d : orthonormal basis plus random extras, then a short repulsion
            loop that pushes each vector away from its current nearest
            neighbor to spread them out.
    """
    if n <= d:
        M = torch.randn(d, n)
        Q, _ = torch.linalg.qr(M)
        return Q.T.contiguous()
    M = torch.randn(d, d)
    Q, _ = torch.linalg.qr(M)
    basis = Q.T
    extra = F.normalize(torch.randn(n - d, d), dim=-1)
    vecs = torch.cat([basis, extra], dim=0)
    for _ in range(200):
        sim = vecs @ vecs.T
        sim.fill_diagonal_(-2.0)  # exclude self from nearest-neighbor search
        nn_idx = sim.argmax(dim=1)
        # Step each vector away from its nearest neighbor, renormalize.
        vecs = F.normalize(vecs - 0.05 * vecs[nn_idx], dim=-1)
    return vecs


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION + PATCHWORK
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """Learnable anchor set on the unit sphere used to triangulate embeddings."""

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        self.anchors = nn.Parameter(uniform_hypersphere_init(n_anchors, dim))
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Return (tri, nearest) for embeddings ``emb`` (B, D), assumed unit-norm.

        tri     — cosine distance 1 - cos to each (kept) anchor.
        nearest — global index of the closest anchor per embedding.

        With ``training=True`` and anchor_drop > 0, a random subset of
        anchors is dropped; ``nearest`` is still reported in global anchor
        indices, but note tri then has fewer columns than n_anchors (the
        caller in this file only uses ``nearest`` from the dropout path).
        """
        anchors = F.normalize(self.anchors, dim=-1)
        if training and self.anchor_drop > 0:
            mask = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
            if mask.sum() < 2:
                mask[:2] = True  # guarantee at least two anchors survive
            anchors = anchors[mask]
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest_local = cos.max(dim=-1)
            # Map local (post-drop) indices back to global anchor indices.
            nearest = mask.nonzero(as_tuple=True)[0][nearest_local]
        else:
            cos = emb @ anchors.T
            tri = 1.0 - cos
            _, nearest = cos.max(dim=-1)
        return tri, nearest


class Patchwork(nn.Module):
    """Split the triangulation vector into n_comp groups, encode each with its own MLP."""

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        # Anchor i belongs to component i % n_comp (round-robin assignment).
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        anchors_per = n_anchors // n_comp  # assumes n_anchors divisible by n_comp
        self.comps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(anchors_per, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for _ in range(n_comp)])

    def forward(self, tri):
        """(B, n_anchors) triangulation → (B, n_comp * d_comp) concatenated codes."""
        return torch.cat(
            [self.comps[k](tri[:, self.asgn == k]) for k in range(self.n_comp)], -1)


# ══════════════════════════════════════════════════════════════════
# CONV ENCODER
# ══════════════════════════════════════════════════════════════════
class ConvEncoder(nn.Module):
    """
    Simple conv backbone. No attention, no geometric layers.
    Just feature extraction into a flat vector.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        self.features = nn.Sequential(
            # 32×32 → 16×16
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
            nn.MaxPool2d(2),
            # 16×16 → 8×8
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
            nn.MaxPool2d(2),
            # 8×8 → 4×4
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
            nn.MaxPool2d(2),
            # 4×4 → global
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.proj = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        """(B, 3, H, W) image batch → (B, output_dim) feature vector."""
        return self.proj(self.features(x))


# ══════════════════════════════════════════════════════════════════
# GEOLIP CORE
# ══════════════════════════════════════════════════════════════════
class GeoLIPCore(nn.Module):
    """Full pipeline: ConvEncoder → unit sphere → Constellation → Patchwork → classifier."""

    def __init__(
        self,
        num_classes=10,
        output_dim=128,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        cv_target=0.22,
        infonce_temp=0.07,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp
        # Snapshot constructor args (locals() at this point holds only the
        # parameters) so checkpoints can rebuild the model.
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}
        self.encoder = ConvEncoder(output_dim)
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp
        # Classifier sees patchwork code concatenated with the raw embedding.
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes))
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal linears, kaiming convs, unit-gain norms."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return dict with logits, L2-normalized embedding, triangulation, nearest anchor."""
        feat = self.encoder(x)
        emb = F.normalize(feat, dim=-1)
        # Full tri for patchwork (needs all anchor columns)
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)
        # Dropout version for nearest tracking only
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)
        logits = self.classifier(torch.cat([pw, emb], dim=-1))
        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Combine CE + InfoNCE + anchor attraction + CV + anchor spread.

        Returns (total_loss, loss_dict); loss_dict mixes tensors (loss terms)
        and floats (metrics such as 'acc', 'nce_acc', 'nearest_cos').
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # CE
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE: positives are the matching index in the second view.
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc

        # ── Anchor attraction: pull each embedding toward its nearest anchor ──
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T          # (B, n_anchors)
        nearest_cos = cos_to_anchors.max(dim=1).values  # (B,)
        l_attract = (1.0 - nearest_cos).mean()      # 0 when on top of anchor
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()

        # CV: regularize the spread (coefficient of variation) of simplex volumes.
        l_cv = self._cv_loss(emb)
        ld['cv'] = l_cv

        # Anchor spread: penalize positive pairwise cosine between anchors.
        sim_a = anchors_n @ anchors_n.T
        mask = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=anchors_n.device)
        l_spread = F.relu(sim_a[mask]).mean()
        ld['spread'] = l_spread

        # Total
        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """
        Push anchors toward CLASS centroids, not nearest-anchor centroids.

        Phase 1: Compute class centroids from labels
        Phase 2: Each class owns (n_anchors / n_classes) anchors
        Phase 3: Assigned anchors blend toward their class centroid
                 with small angular offsets so they don't all collapse

        This works even when anchors start bunched at origin.
        Returns the number of anchors moved.
        """
        anchors = self.constellation.anchors.data   # (A, D)
        n_a = anchors.shape[0]
        emb_n = F.normalize(emb_buffer, dim=-1)
        device = anchors.device

        # Phase 1: class centroids
        classes = label_buffer.unique()
        n_cls = classes.shape[0]
        centroids = []
        for c in classes:
            mask = label_buffer == c
            if mask.sum() > 0:
                centroids.append(F.normalize(emb_n[mask].mean(0, keepdim=True), dim=-1))
        if len(centroids) == 0:
            return 0
        centroids = torch.cat(centroids, dim=0)     # (C, D)

        # Phase 2: assign anchors to classes round-robin
        # Sort anchors by cosine to each centroid, greedily assign
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T               # (A, C)
        anchors_per_class = n_a // n_cls
        assigned_class = torch.full((n_a,), -1, dtype=torch.long, device=device)
        class_count = torch.zeros(n_cls, dtype=torch.long, device=device)
        # Greedy: for each anchor, assign to its best class if that class has room
        _, flat_idx = cos.flatten().sort(descending=True)
        for idx in flat_idx:
            a = (idx // n_cls).item()   # row of the (A, C) cosine matrix
            c = (idx % n_cls).item()    # column
            if assigned_class[a] >= 0:
                continue
            if class_count[c] >= anchors_per_class + 1:  # +1 for remainder
                continue
            assigned_class[a] = c
            class_count[c] += 1
            if (assigned_class >= 0).all():
                break
        # Unassigned leftovers → nearest centroid
        unassigned = (assigned_class < 0).nonzero(as_tuple=True)[0]
        if len(unassigned) > 0:
            leftover_cos = anchors_n[unassigned] @ centroids.T
            assigned_class[unassigned] = leftover_cos.argmax(dim=1)

        # Phase 3: push each anchor toward its class centroid
        moved = 0
        for a in range(n_a):
            c = assigned_class[a].item()
            target = centroids[c]
            # Add small angular offset so co-class anchors don't collapse
            rank_in_class = (assigned_class[:a] == c).sum().item()
            if anchors_per_class > 1 and rank_in_class > 0:
                # Tiny perpendicular perturbation
                noise = torch.randn_like(target) * 0.05
                noise = noise - (noise * target).sum() * target  # project out radial
                target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)
            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1
        return moved

    def _cv_loss(self, emb, n_samples=64, n_points=5):
        """Penalize deviation of the coefficient of variation of random
        n_point-simplex volumes (Cayley–Menger determinant) from cv_target."""
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)
        vols = []
        for _ in range(n_samples):
            # Sample points from the first min(B, 512) embeddings only.
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)  # clamp tiny negatives from roundoff
            N = n_points
            # Cayley–Menger matrix: bordered squared-distance matrix.
            cm = torch.zeros(1, N + 1, N + 1, device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            k = N - 1
            pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
            v2 = pf * torch.linalg.det(cm.float())
            if v2[0].item() > 1e-20:  # skip degenerate simplices
                vols.append(v2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)
        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - self.cv_target).pow(2)


# ══════════════════════════════════════════════════════════════════
# DATA
# ══════════════════════════════════════════════════════════════════
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)


class TwoViewDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each item yields two independent augmentations + label."""

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img, label = self.base[i]
        return self.transform(img), self.transform(img), label


# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════

# Config
NUM_CLASSES = 10
OUTPUT_DIM = 128
N_ANCHORS = 64
N_COMP = 8
D_COMP = 64
BATCH = 256
EPOCHS = 100
LR = 3e-4

print("=" * 60)
print("GeoLIP Core — Conv + Constellation + Patchwork")
print(f" Encoder: 6-layer conv → {OUTPUT_DIM}-d sphere")
print(f" Constellation: {N_ANCHORS} anchors, {N_COMP}×{D_COMP} patchwork")
print(f" Loss: CE + InfoNCE + CV(0.22)")
print(f" Batch: {BATCH}, LR: {LR}, Epochs: {EPOCHS}")
print(f" Device: {DEVICE}")
print("=" * 60)

aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False, download=True,
                          transform=val_transform)

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)
print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")

# Build
model = GeoLIPCore(
    num_classes=NUM_CLASSES,
    output_dim=OUTPUT_DIM,
    n_anchors=N_ANCHORS,
    n_comp=N_COMP,
    d_comp=D_COMP,
).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {n_params:,}")

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * 3
# Linear warmup for 3 epochs, then cosine decay to eta_min.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])
# NOTE(review): autocast below uses bfloat16, for which gradient scaling is a
# no-op in practice (scaling exists for float16 underflow); kept for parity.
scaler = torch.amp.GradScaler("cuda")

os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/geolip_core")
best_acc = 0.0
gs = 0

# Anchor push config
PUSH_INTERVAL = 50       # batches between centroid pushes
PUSH_LR = 0.1            # blend rate toward centroid
PUSH_BUFFER_SIZE = 5000
emb_buffer = None        # (N, D) accumulated embeddings
lbl_buffer = None        # (N,) accumulated labels
push_count = 0

print(f"\n{'='*60}")
print(f"TRAINING — {EPOCHS} epochs")
print(f" Anchor push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
print(f"{'='*60}")

for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    tot_loss = 0
    tot_nce_acc, tot_nearest_cos, n = 0, 0, 0
    correct, total = 0, 0
    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1)
            out2 = model(v2)
            loss, ld = model.compute_loss(out1, targets, output_aug=out2)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        gs += 1

        # ── Accumulate embeddings for anchor push ──
        with torch.no_grad():
            batch_emb = out1['embedding'].detach().float()
            if emb_buffer is None:
                emb_buffer = batch_emb
                lbl_buffer = targets.detach()
            else:
                # Sliding window of the most recent PUSH_BUFFER_SIZE samples.
                emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
                lbl_buffer = torch.cat([lbl_buffer, targets.detach()])[-PUSH_BUFFER_SIZE:]

        # ── Periodic anchor push toward class centroids ──
        if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
            moved = model.push_anchors_to_centroids(
                emb_buffer, lbl_buffer, lr=PUSH_LR)
            push_count += 1
            writer.add_scalar("step/anchors_moved", moved, gs)

        preds = out1['logits'].argmax(-1)
        correct += (preds == targets).sum().item()
        total += targets.shape[0]
        tot_loss += loss.item()
        tot_nce_acc += ld.get('nce_acc', 0)
        tot_nearest_cos += ld.get('nearest_cos', 0)
        n += 1
        if n % 10 == 0:
            # FIX: dropped spurious ordered=True — tqdm.set_postfix has no such
            # parameter; it was absorbed into **kwargs and shown as a bogus field.
            pbar.set_postfix(
                loss=f"{tot_loss/n:.4f}",
                acc=f"{100*correct/total:.0f}%",
                nce=f"{tot_nce_acc/n:.2f}",
                cos=f"{ld.get('nearest_cos', 0):.3f}",
                push=push_count)

    elapsed = time.time() - t0
    train_acc = 100 * correct / total

    # Val
    model.eval()
    vc, vt_n, vl = 0, 0, 0  # NOTE(review): vl (val loss) is accumulated but never logged
    all_embs = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for imgs, lbls in val_loader:
            imgs = imgs.to(DEVICE)
            lbls = lbls.to(DEVICE)
            out = model(imgs)
            vc += (out['logits'].argmax(-1) == lbls).sum().item()
            vt_n += lbls.shape[0]
            vl += F.cross_entropy(out['logits'], lbls).item()
            all_embs.append(out['embedding'].float().cpu())
    val_acc = 100 * vc / vt_n

    # CV over a pool of up to 2000 validation embeddings
    embs = torch.cat(all_embs)[:2000].to(DEVICE)
    with torch.no_grad():
        vols = []
        for _ in range(200):
            # FIX: sample within the actual pool size on the right device
            # (was a hard-coded CPU torch.randperm(2000)).
            idx = torch.randperm(embs.shape[0], device=DEVICE)[:5]
            pts = embs[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            # Cayley–Menger prefactor for 5 points: -1 / (2^4 * (4!)^2) = -1/9216
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:
                vols.append(v2[0].sqrt())
        v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() \
            if len(vols) > 10 else 0

    # Anchors actually used as nearest neighbor by val embeddings
    with torch.no_grad():
        _, vnp = model.constellation.triangulate(embs, training=False)
        n_active = vnp.cpu().unique().numel()

    writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
    writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
    writer.add_scalar("epoch/anchors", n_active, epoch + 1)
    writer.add_scalar("epoch/nearest_cos", tot_nearest_cos / n, epoch + 1)
    writer.add_scalar("epoch/push_count", push_count, epoch + 1)

    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
        }, "checkpoints/geolip_core_best.pt")
        mk = " ★"

    nce_m = tot_nce_acc / n
    cos_m = tot_nearest_cos / n
    cv_band = "✓" if 0.18 <= v_cv <= 0.25 else "✗"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
          f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
          f"push={push_count} ({elapsed:.0f}s){mk}")

writer.close()
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f" Parameters: {n_params:,}")
print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")