"""
GeoLIP Core -- Back to Basics
==============================
Conv encoder -> sphere -> constellation -> patchwork -> classifier.
No streams. No GAL. No Procrustes. No mastery queue.
Just the geometric classification pipeline.

Two augmented views -> InfoNCE + CE + CV.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os, time
import numpy as np
from itertools import combinations
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter


# Prefer GPU when available; the TF32 flags speed up matmul/conv on
# Ampere+ GPUs at slightly reduced precision (no-ops on CPU).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
| |
| |
| |
|
|
def uniform_hypersphere_init(n, d, repulsion_steps=200, step_size=0.05):
    """Initialize ``n`` roughly uniformly spread unit vectors in R^d.

    For ``n <= d`` an exactly orthonormal set is returned (rows from a QR
    factorization).  For ``n > d`` a full orthonormal basis is padded with
    random unit vectors, then each vector is iteratively repelled from its
    nearest neighbour to spread the set out.

    Args:
        n: number of vectors to produce.
        d: ambient dimension.
        repulsion_steps: repulsion iterations (used only when ``n > d``);
            was a hard-coded 200 in the original.
        step_size: repulsion step length per iteration; was a hard-coded 0.05.

    Returns:
        ``(n, d)`` tensor whose rows have unit L2 norm.
    """
    if n <= d:
        # QR of a (d, n) Gaussian yields n orthonormal columns; transpose
        # so the vectors are rows.
        M = torch.randn(d, n)
        Q, _ = torch.linalg.qr(M)
        return Q.T.contiguous()
    # Full orthonormal basis plus random unit fillers for the remainder.
    M = torch.randn(d, d)
    Q, _ = torch.linalg.qr(M)
    extra = F.normalize(torch.randn(n - d, d), dim=-1)
    vecs = torch.cat([Q.T, extra], dim=0)
    # Push every vector away from its current nearest neighbour.
    for _ in range(repulsion_steps):
        sim = vecs @ vecs.T
        sim.fill_diagonal_(-2.0)  # exclude self from the nearest-neighbour argmax
        nn_idx = sim.argmax(dim=1)
        vecs = F.normalize(vecs - step_size * vecs[nn_idx], dim=-1)
    return vecs
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """A learnable set of unit-norm anchor vectors on the hypersphere.

    Embeddings are "triangulated" against the anchors: the module reports
    the (1 - cosine) distance to every anchor plus the index of the nearest
    one.  During training a random subset of anchors can be dropped.
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        self.anchors = nn.Parameter(uniform_hypersphere_init(n_anchors, dim))
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Return ``(1 - cos, nearest)`` for embeddings against the anchors.

        With ``training=True`` and ``anchor_drop > 0``, anchors are randomly
        dropped (at least two always survive) and ``nearest`` is reported in
        GLOBAL anchor indices; note the distance matrix then has one column
        per surviving anchor only.
        """
        anchors = F.normalize(self.anchors, dim=-1)
        drop_active = training and self.anchor_drop > 0
        if not drop_active:
            cos = emb @ anchors.T
            _, nearest = cos.max(dim=-1)
            return 1.0 - cos, nearest
        # Anchor dropout: sample a keep-mask, forcing >= 2 survivors.
        keep = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
        if keep.sum() < 2:
            keep[:2] = True
        cos = emb @ anchors[keep].T
        _, local = cos.max(dim=-1)
        # Translate argmax over survivors back to global anchor indices.
        nearest = keep.nonzero(as_tuple=True)[0][local]
        return 1.0 - cos, nearest
|
|
|
|
class Patchwork(nn.Module):
    """Encode the anchor-distance vector as a patchwork of small MLPs.

    Anchor ``i`` is routed to component ``i % n_comp``; each component has
    its own 2-layer MLP and the component outputs are concatenated.

    Fix: each component's input Linear is now sized to the number of anchors
    actually routed to it.  The original sized every component with
    ``n_anchors // n_comp``, which crashed with a shape mismatch whenever
    ``n_anchors`` was not divisible by ``n_comp`` (round-robin assignment
    gives some components one extra anchor).

    Args:
        n_anchors: length of the incoming distance vector.
        n_comp: number of component MLPs.
        d_comp: output width of each component.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer('asgn', asgn)
        self.comps = nn.ModuleList([
            nn.Sequential(
                # Input width = exact count of anchors assigned to component k.
                nn.Linear(int((asgn == k).sum()), d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for k in range(n_comp)])

    def forward(self, tri):
        """tri: (B, n_anchors) -> (B, n_comp * d_comp)."""
        return torch.cat([self.comps[k](tri[:, self.asgn == k])
                          for k in range(self.n_comp)], -1)
|
|
|
|
| |
| |
| |
|
|
class ConvEncoder(nn.Module):
    """
    Simple conv backbone. No attention, no geometric layers.
    Just feature extraction into a flat vector.
    """

    def __init__(self, output_dim=128):
        super().__init__()

        def block(c_in, c_out):
            # Two 3x3 conv/BN/GELU layers followed by 2x spatial downsampling.
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
                nn.Conv2d(c_out, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
                nn.MaxPool2d(2),
            ]

        # Three stages (3 -> 64 -> 128 -> 256) then global average pooling.
        self.features = nn.Sequential(
            *block(3, 64),
            *block(64, 128),
            *block(128, 256),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Project the pooled 256-d feature to the requested embedding size.
        self.proj = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        """x: (B, 3, H, W) -> (B, output_dim)."""
        return self.proj(self.features(x))
|
|
|
|
| |
| |
| |
|
|
class GeoLIPCore(nn.Module):
    """Geometric classification pipeline.

    Images -> ConvEncoder -> L2-normalized embedding on the unit sphere ->
    distances to a learnable anchor Constellation -> Patchwork encoding ->
    linear classifier over [patchwork, embedding].

    ``compute_loss`` combines CE, InfoNCE across two views, an
    embedding-to-nearest-anchor attraction term, a simplex-volume CV
    regularizer, and an anchor spread penalty.
    """

    def __init__(
        self,
        num_classes=10,
        output_dim=128,       # embedding (sphere) dimension
        n_anchors=64,
        n_comp=8,             # patchwork components
        d_comp=64,            # width per patchwork component
        anchor_drop=0.15,     # anchor dropout prob used only for `nearest` in training
        cv_target=0.22,       # target coefficient-of-variation for _cv_loss
        infonce_temp=0.07,    # InfoNCE temperature
    ):
        super().__init__()
        self.num_classes = num_classes
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp

        # Snapshot of the constructor arguments (saved into checkpoints).
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}

        self.encoder = ConvEncoder(output_dim)
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp

        # Classifier sees the patchwork code concatenated with the raw embedding.
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes))

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal linears, kaiming convs, unit/zero norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return dict with 'logits', 'embedding', 'triangulation', 'nearest'."""
        feat = self.encoder(x)
        emb = F.normalize(feat, dim=-1)

        # Patchwork always consumes the FULL-width triangulation (no drop),
        # so its input size matches the component Linears.
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)

        # During training, only `nearest` is recomputed under anchor dropout;
        # the dropped-width distance matrix is discarded.
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)

        logits = self.classifier(torch.cat([pw, emb], dim=-1))

        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Build the total loss and a dict of per-term values/metrics.

        Args:
            output: forward() dict for view 1.
            targets: (B,) class labels.
            output_aug: optional forward() dict for the second view; enables
                the InfoNCE term.

        Returns:
            (total_loss, log_dict) — log_dict mixes tensors (loss terms)
            and Python floats (metrics ending in `acc`/`cos`).
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # Supervised cross-entropy + batch accuracy (float, detached).
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE between the two views; positives are the diagonal pairs.
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc

        # Attraction: pull each embedding toward its nearest anchor.
        # Gradients flow into both the encoder and the anchors.
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values
        l_attract = (1.0 - nearest_cos).mean()
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()

        # Simplex-volume coefficient-of-variation regularizer.
        l_cv = self._cv_loss(emb)
        ld['cv'] = l_cv

        # Spread: penalize positive off-diagonal anchor-anchor cosine.
        sim_a = anchors_n @ anchors_n.T
        mask = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=anchors_n.device)
        l_spread = F.relu(sim_a[mask]).mean()
        ld['spread'] = l_spread

        # Weighted sum; the NCE term contributes 0.0 when no second view given.
        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """
        Push anchors toward CLASS centroids, not nearest-anchor centroids.

        Phase 1: Compute class centroids from labels
        Phase 2: Each class owns (n_anchors / n_classes) anchors
        Phase 3: Assigned anchors blend toward their class centroid
                 with small angular offsets so they don't all collapse

        This works even when anchors start bunched at origin.

        Args:
            emb_buffer: (M, output_dim) recent embeddings (re-normalized here).
            label_buffer: (M,) labels aligned with emb_buffer.
            lr: blend factor toward the target (0 = no move, 1 = jump).

        Returns:
            Number of anchors updated (int).
        """
        anchors = self.constellation.anchors.data
        n_a = anchors.shape[0]
        emb_n = F.normalize(emb_buffer, dim=-1)
        device = anchors.device

        # Phase 1: one unit-norm centroid per class present in the buffer.
        classes = label_buffer.unique()
        n_cls = classes.shape[0]
        centroids = []
        for c in classes:
            mask = label_buffer == c
            if mask.sum() > 0:
                centroids.append(F.normalize(emb_n[mask].mean(0, keepdim=True), dim=-1))
        if len(centroids) == 0:
            return 0
        centroids = torch.cat(centroids, dim=0)

        # Phase 2: greedy anchor->class assignment in descending-cosine
        # order, capping each class at anchors_per_class + 1 anchors.
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T
        anchors_per_class = n_a // n_cls
        assigned_class = torch.full((n_a,), -1, dtype=torch.long, device=device)
        class_count = torch.zeros(n_cls, dtype=torch.long, device=device)

        # Flattened (anchor, class) pairs sorted by similarity; row-major
        # flatten means idx // n_cls is the anchor, idx % n_cls the class.
        _, flat_idx = cos.flatten().sort(descending=True)
        for idx in flat_idx:
            a = (idx // n_cls).item()
            c = (idx % n_cls).item()
            if assigned_class[a] >= 0:
                continue
            if class_count[c] >= anchors_per_class + 1:
                continue
            assigned_class[a] = c
            class_count[c] += 1
            if (assigned_class >= 0).all():
                break

        # Anchors still unassigned (every class at capacity) take their
        # best-matching class regardless of the cap.
        unassigned = (assigned_class < 0).nonzero(as_tuple=True)[0]
        if len(unassigned) > 0:
            leftover_cos = anchors_n[unassigned] @ centroids.T
            assigned_class[unassigned] = leftover_cos.argmax(dim=1)

        # Phase 3: blend each anchor toward its centroid; every anchor after
        # the first of a class gets a small tangential jitter so co-assigned
        # anchors don't collapse onto the same point.
        moved = 0
        for a in range(n_a):
            c = assigned_class[a].item()
            target = centroids[c]
            # Rank of anchor `a` among earlier anchors assigned to class c.
            rank_in_class = (assigned_class[:a] == c).sum().item()
            if anchors_per_class > 1 and rank_in_class > 0:
                noise = torch.randn_like(target) * 0.05
                # Project out the radial component so the jitter is tangential.
                noise = noise - (noise * target).sum() * target
                target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)

            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1

        return moved

    def _cv_loss(self, emb, n_samples=64, n_points=5):
        """Squared deviation of simplex-volume CV from ``self.cv_target``.

        Samples ``n_samples`` random ``n_points``-subsets of embeddings,
        computes each subset's simplex volume via the Cayley-Menger
        determinant, and regularizes the coefficient of variation
        (std / mean) of those volumes toward ``cv_target``.
        """
        B = emb.shape[0]
        if B < n_points: return torch.tensor(0.0, device=emb.device)
        vols = []
        for _ in range(n_samples):
            # NOTE(review): points are drawn from the first min(B, 512)
            # embeddings only — presumably a cost cap; confirm intent.
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            # Pairwise squared distances via the Gram matrix.
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)  # clamp tiny negatives from floating-point error
            # Bordered Cayley-Menger matrix; prefactor gives squared volume.
            N = n_points
            cm = torch.zeros(1, N+1, N+1, device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            k = N - 1
            pf = ((-1.0)**(k+1)) / ((2.0**k) * (math.factorial(k)**2))
            v2 = pf * torch.linalg.det(cm.float())
            if v2[0].item() > 1e-20:  # skip (near-)degenerate simplices
                vols.append(v2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            # Not enough non-degenerate samples to estimate a CV.
            return torch.tensor(0.0, device=emb.device)
        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - self.cv_target).pow(2)
|
|
|
|
| |
| |
| |
|
|
# CIFAR-10 per-channel normalization statistics.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)


class TwoViewDataset(torch.utils.data.Dataset):
    """Wraps a dataset so each item yields two independently transformed views."""

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img, label = self.base[i]
        # Apply the (stochastic) transform twice to get two distinct views.
        view_a = self.transform(img)
        view_b = self.transform(img)
        return view_a, view_b, label
|
|
|
|
| |
| |
| |
|
|
# ---- Hyperparameters -------------------------------------------------------
NUM_CLASSES = 10
OUTPUT_DIM = 128
N_ANCHORS = 64
N_COMP = 8
D_COMP = 64
BATCH = 256
EPOCHS = 100
LR = 3e-4


# Startup banner. (fixed: the original strings contained mojibake
# glyphs for arrows/dashes/multiplication; rendered in plain ASCII.)
print("=" * 60)
print("GeoLIP Core -- Conv + Constellation + Patchwork")
print(f" Encoder: 6-layer conv -> {OUTPUT_DIM}-d sphere")
print(f" Constellation: {N_ANCHORS} anchors, {N_COMP}x{D_COMP} patchwork")
print(f" Loss: CE + InfoNCE + CV(0.22)")
print(f" Batch: {BATCH}, LR: {LR}, Epochs: {EPOCHS}")
print(f" Device: {DEVICE}")
print("=" * 60)
|
|
# Training augmentation: crop/flip/jitter, then normalize with CIFAR stats.
aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# Validation: deterministic tensor conversion + normalization only.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])


# The raw train set carries no transform; TwoViewDataset applies
# aug_transform twice per sample to produce the two InfoNCE views.
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)


# drop_last keeps every training batch exactly BATCH-sized, which the
# InfoNCE diagonal labels rely on.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)


print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")
|
|
| |
| model = GeoLIPCore( |
| num_classes=NUM_CLASSES, output_dim=OUTPUT_DIM, |
| n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP, |
| ).to(DEVICE) |
|
|
| n_params = sum(p.numel() for p in model.parameters()) |
| print(f" Parameters: {n_params:,}") |
|
|
| optimizer = torch.optim.Adam(model.parameters(), lr=LR) |
| total_steps = len(train_loader) * EPOCHS |
| warmup_steps = len(train_loader) * 3 |
| scheduler = torch.optim.lr_scheduler.SequentialLR( |
| optimizer, |
| [torch.optim.lr_scheduler.LinearLR( |
| optimizer, start_factor=0.01, total_iters=warmup_steps), |
| torch.optim.lr_scheduler.CosineAnnealingLR( |
| optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)], |
| milestones=[warmup_steps]) |
|
|
| scaler = torch.amp.GradScaler("cuda") |
| os.makedirs("checkpoints", exist_ok=True) |
| writer = SummaryWriter("runs/geolip_core") |
| best_acc = 0.0 |
| gs = 0 |
|
|
| |
| PUSH_INTERVAL = 50 |
| PUSH_LR = 0.1 |
| PUSH_BUFFER_SIZE = 5000 |
| emb_buffer = None |
| lbl_buffer = None |
| push_count = 0 |
|
|
print(f"\n{'='*60}")
# (fixed) banner mojibake rendered as ASCII.
print(f"TRAINING -- {EPOCHS} epochs")
print(f" Anchor push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
print(f"{'='*60}")


# ---------------------------------------------------------------------------
# Training loop: two augmented views per batch -> CE + InfoNCE + geometric
# regularizers, with periodic anchor pushes toward class centroids.
# ---------------------------------------------------------------------------
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    tot_loss, tot_ce, tot_nce, tot_cv = 0, 0, 0, 0
    tot_acc, tot_nce_acc, tot_nearest_cos, n = 0, 0, 0, 0
    correct, total = 0, 0

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Forward both views under bf16 autocast; loss couples them via InfoNCE.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1)
            out2 = model(v2)
            loss, ld = model.compute_loss(out1, targets, output_aug=out2)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping on true grads
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()
        scheduler.step()
        gs += 1

        # Maintain a rolling buffer of recent embeddings/labels for pushes.
        with torch.no_grad():
            batch_emb = out1['embedding'].detach().float()
            if emb_buffer is None:
                emb_buffer = batch_emb
                lbl_buffer = targets.detach()
            else:
                emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
                lbl_buffer = torch.cat([lbl_buffer, targets.detach()])[-PUSH_BUFFER_SIZE:]

        # Periodically drag anchors toward class centroids (once the buffer
        # has enough samples to form stable centroids).
        if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
            moved = model.push_anchors_to_centroids(
                emb_buffer, lbl_buffer, lr=PUSH_LR)
            push_count += 1
            writer.add_scalar("step/anchors_moved", moved, gs)

        preds = out1['logits'].argmax(-1)
        correct += (preds == targets).sum().item()
        total += targets.shape[0]
        tot_loss += loss.item()
        tot_nce_acc += ld.get('nce_acc', 0)
        tot_nearest_cos += ld.get('nearest_cos', 0)
        n += 1

        if n % 10 == 0:
            # (fixed) dropped the stray `ordered=True` kwarg, which tqdm
            # would have displayed as a bogus "ordered=True" postfix field.
            pbar.set_postfix(
                loss=f"{tot_loss/n:.4f}",
                acc=f"{100*correct/total:.0f}%",
                nce=f"{tot_nce_acc/n:.2f}",
                cos=f"{ld.get('nearest_cos', 0):.3f}",
                push=push_count)

    elapsed = time.time() - t0
    train_acc = 100 * correct / total

    # ---- Validation pass --------------------------------------------------
    model.eval()
    vc, vt_n, vl = 0, 0, 0
    all_embs = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for imgs, lbls in val_loader:
            imgs = imgs.to(DEVICE)
            lbls = lbls.to(DEVICE)
            out = model(imgs)
            vc += (out['logits'].argmax(-1) == lbls).sum().item()
            vt_n += lbls.shape[0]
            vl += F.cross_entropy(out['logits'], lbls).item()
            all_embs.append(out['embedding'].float().cpu())

    val_acc = 100 * vc / vt_n

    # ---- Simplex-volume CV diagnostic on 2000 val embeddings --------------
    embs = torch.cat(all_embs)[:2000].to(DEVICE)
    with torch.no_grad():
        vols = []
        for _ in range(200):
            idx = torch.randperm(2000)[:5]
            pts = embs[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            # Cayley-Menger for a 4-simplex: vol^2 = -det(CM) / (2^4 * (4!)^2)
            # = -det(CM) / 9216, matching GeoLIPCore._cv_loss with n_points=5.
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:
                vols.append(v2[0].sqrt())
        v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0

    # ---- Count anchors actually serving as nearest for val embeddings -----
    with torch.no_grad():
        _, vnp = model.constellation.triangulate(embs, training=False)
        n_active = vnp.cpu().unique().numel()

    writer.add_scalar("epoch/train_acc", train_acc, epoch+1)
    writer.add_scalar("epoch/val_acc", val_acc, epoch+1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch+1)
    writer.add_scalar("epoch/anchors", n_active, epoch+1)
    writer.add_scalar("epoch/nearest_cos", tot_nearest_cos / n, epoch+1)
    writer.add_scalar("epoch/push_count", push_count, epoch+1)

    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
        }, "checkpoints/geolip_core_best.pt")
        # (fixed) the original marker literal was mojibake split across two
        # source lines (a syntax error); use a plain ASCII best-epoch marker.
        mk = " *"

    nce_m = tot_nce_acc / n
    cos_m = tot_nearest_cos / n
    # (fixed) in-band / out-of-band glyphs were mojibake; use +/-.
    cv_band = "+" if 0.18 <= v_cv <= 0.25 else "-"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
          f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
          f"push={push_count} ({elapsed:.0f}s){mk}")


writer.close()
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f" Parameters: {n_params:,}")
print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")