| |
| """ |
| GeoLIP Core — Back to Basics |
| ============================== |
| Conv encoder → sphere → constellation → patchwork → classifier. |
| No streams. No GAL. No Procrustes. No mastery queue. |
| Just the geometric classification pipeline. |
| |
| Two augmented views → InfoNCE + CE + CV. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import os, time |
| import numpy as np |
| from itertools import combinations |
| from tqdm import tqdm |
| from torchvision import datasets, transforms |
| from torch.utils.tensorboard import SummaryWriter |
|
|
# Pick the compute device once at import time: CUDA when PyTorch reports it
# as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Turn on the TF32 flags for both CUDA matmul and cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
| |
| |
| |
|
|
def uniform_hypersphere_init(n, d):
    """Return ``n`` unit vectors in ``d`` dimensions, spread over the sphere.

    Args:
        n: Number of vectors to produce.
        d: Dimensionality of each vector.

    Returns:
        A float tensor of shape ``(n, d)`` with unit-norm rows.

        * ``n <= d``: the rows are mutually orthonormal (taken from the QR
          factorisation of a random ``d x n`` Gaussian matrix).
        * ``n > d``: a random orthonormal basis plus ``n - d`` random unit
          directions, relaxed by 200 rounds in which each point steps away
          from its most similar neighbour and is re-normalised.
    """
    if n <= d:
        q, _ = torch.linalg.qr(torch.randn(d, n))
        return q.T.contiguous()

    # More points than dimensions: seed with an orthonormal basis, then pad
    # with random directions.
    q, _ = torch.linalg.qr(torch.randn(d, d))
    random_dirs = F.normalize(torch.randn(n - d, d), dim=-1)
    points = torch.cat([q.T, random_dirs], dim=0)

    for _ in range(200):
        gram = points @ points.T
        # Self-similarity is always 1; push it below any real cosine so a
        # point is never picked as its own nearest neighbour.
        gram.fill_diagonal_(-2.0)
        closest = gram.argmax(dim=1)
        points = F.normalize(points - 0.05 * points[closest], dim=-1)
    return points
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """A learnable set of anchor directions used to triangulate embeddings.

    Args:
        n_anchors: Number of anchor vectors.
        dim: Dimensionality of each anchor.
        anchor_drop: Probability of dropping each anchor in
            ``triangulate(..., training=True)``; ``0`` disables dropping.
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        initial = uniform_hypersphere_init(n_anchors, dim)
        self.anchors = nn.Parameter(initial)
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Measure cosine distance from each embedding to the anchors.

        Args:
            emb: Embeddings multiplied against the normalised anchors;
                presumably ``(batch, dim)`` and already unit-norm (confirm
                against the caller).
            training: When true and ``anchor_drop > 0``, a random subset of
                anchors is dropped before measuring.

        Returns:
            ``(tri, nearest)``. ``tri`` is ``1 - cosine`` against the anchors
            that were kept, so it has fewer columns when anchors were
            dropped. ``nearest`` is the index, in the full anchor table, of
            the most similar kept anchor per embedding.
        """
        unit_anchors = F.normalize(self.anchors, dim=-1)

        if not (training and self.anchor_drop > 0):
            cosine = emb @ unit_anchors.T
            return 1.0 - cosine, cosine.max(dim=-1).indices

        keep = torch.rand(unit_anchors.shape[0], device=unit_anchors.device) > self.anchor_drop
        # Never let dropout leave fewer than two anchors to compare against.
        if keep.sum() < 2:
            keep[:2] = True

        cosine = emb @ unit_anchors[keep].T
        # Translate positions within the kept subset back to global anchor ids.
        kept_ids = keep.nonzero(as_tuple=True)[0]
        return 1.0 - cosine, kept_ids[cosine.max(dim=-1).indices]
|
|
|
|
class Patchwork(nn.Module):
    """Split triangulation columns across several small MLPs and concatenate.

    Anchor column ``j`` is routed to component ``j % n_comp``. Each component
    is ``Linear -> GELU -> Linear -> LayerNorm`` and emits ``d_comp``
    features, so the output has ``n_comp * d_comp`` features.

    Args:
        n_anchors: Number of triangulation columns expected by ``forward``.
        n_comp: Number of components.
        d_comp: Output width of each component.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer('asgn', asgn)
        # Size every component from the columns it actually receives. When
        # n_anchors is not a multiple of n_comp the first few components get
        # one extra column; sizing all of them as n_anchors // n_comp would
        # make forward() fail with a shape mismatch.
        widths = torch.bincount(asgn, minlength=n_comp).tolist()
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(width, d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for width in widths])

    def forward(self, tri):
        """Run each component on its own columns of ``tri``.

        Args:
            tri: 2-D tensor indexed as ``tri[:, columns]``; it must have
                ``n_anchors`` columns.

        Returns:
            The component outputs concatenated along the last dimension.
        """
        return torch.cat([self.comps[k](tri[:, self.asgn == k])
                          for k in range(self.n_comp)], -1)
|
|
|
|
| |
| |
| |
|
|
class ConvEncoder(nn.Module):
    """Plain convolutional backbone that maps an image to a flat feature vector.

    Three stages of widths 64, 128 and 256. Each stage is two
    ``Conv3x3 -> BatchNorm -> GELU`` units followed by a 2x2 max-pool. The
    result is global-average-pooled, flattened, and projected to
    ``output_dim`` with ``Linear -> LayerNorm``.

    Args:
        output_dim: Width of the returned feature vector.
    """

    def __init__(self, output_dim=128):
        super().__init__()

        layers = []
        in_channels = 3
        for width in (64, 128, 256):
            for _ in range(2):
                layers.append(nn.Conv2d(in_channels, width, 3, padding=1))
                layers.append(nn.BatchNorm2d(width))
                layers.append(nn.GELU())
                in_channels = width
            layers.append(nn.MaxPool2d(2))

        # Collapse whatever spatial extent is left to one value per channel.
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Flatten())
        self.features = nn.Sequential(*layers)

        self.proj = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        """Encode ``x`` (a batch of 3-channel images) to ``output_dim`` features."""
        pooled = self.features(x)
        return self.proj(pooled)
|
|
|
|
| |
| |
| |
|
|
class GeoLIPCore(nn.Module):
    """Conv encoder -> unit sphere -> anchor constellation -> patchwork -> classifier.

    Args:
        num_classes: Number of output logits.
        output_dim: Width of the encoder output / embedding.
        n_anchors: Number of constellation anchors.
        n_comp: Number of patchwork components.
        d_comp: Output width of each patchwork component.
        anchor_drop: Anchor dropout probability passed to ``Constellation``.
        cv_target: Target coefficient of variation used by ``_cv_loss``.
        infonce_temp: Temperature dividing the InfoNCE similarity matrix.
    """

    def __init__(
        self,
        num_classes=10,
        output_dim=128,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        cv_target=0.22,
        infonce_temp=0.07,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp

        # Constructor arguments, listed explicitly so that adding a local
        # variable above this point can never leak into the saved config.
        self.config = {
            'num_classes': num_classes,
            'output_dim': output_dim,
            'n_anchors': n_anchors,
            'n_comp': n_comp,
            'd_comp': d_comp,
            'anchor_drop': anchor_drop,
            'cv_target': cv_target,
            'infonce_temp': infonce_temp,
        }

        self.encoder = ConvEncoder(output_dim)
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp

        # The classifier sees the patchwork features and the raw embedding.
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes))

        self._init_weights()

    def _init_weights(self):
        """Re-initialise Linear, Conv2d and normalisation layers in every submodule."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            Dict with ``'logits'``, ``'embedding'`` (L2-normalised encoder
            output), ``'triangulation'`` (distances to all anchors) and
            ``'nearest'`` (nearest anchor index per sample).
        """
        feat = self.encoder(x)
        emb = F.normalize(feat, dim=-1)

        # The patchwork always receives the triangulation against ALL
        # anchors, so its input width never changes.
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)

        # In training mode only the reported nearest-anchor index is
        # recomputed with anchor dropout; `tri` and `pw` are left untouched.
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)

        logits = self.classifier(torch.cat([pw, emb], dim=-1))

        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Combine the classification, contrastive and geometric loss terms.

        Args:
            output: Dict from ``forward`` for the primary view.
            targets: Class indices for cross-entropy.
            output_aug: Optional ``forward`` dict for a second view of the
                same batch; enables the InfoNCE term.

        Returns:
            ``(loss, ld)``. ``loss`` is
            ``ce + 1.0*nce + 0.5*attract + 0.01*cv + 0.001*spread``.
            ``ld`` holds each term as a tensor (``'ce'``, ``'nce'``,
            ``'attract'``, ``'cv'``, ``'spread'``, ``'total'``) plus float
            diagnostics (``'acc'``, ``'nce_acc'``, ``'nearest_cos'``).
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # Cross-entropy on the classifier logits.
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE: sample i in the primary view should match sample i in the
        # augmented view (primary -> augmented direction only).
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            ld['nce'] = l_nce
            ld['nce_acc'] = (sim.argmax(1) == labels_nce).float().mean().item()

        # Attraction: pull every embedding towards its most similar anchor.
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values
        l_attract = (1.0 - nearest_cos).mean()
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()

        # Volume-regularity term on the embeddings.
        l_cv = self._cv_loss(emb)
        ld['cv'] = l_cv

        # Spread: penalise positive cosine between distinct anchors. With a
        # single anchor there are no pairs; the mean of an empty selection
        # would be NaN and poison the total, so the term is zero instead.
        n_anchors = anchors_n.shape[0]
        if n_anchors > 1:
            sim_a = anchors_n @ anchors_n.T
            off_diag = ~torch.eye(n_anchors, dtype=torch.bool, device=anchors_n.device)
            l_spread = F.relu(sim_a[off_diag]).mean()
        else:
            l_spread = anchors_n.new_zeros(())
        ld['spread'] = l_spread

        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """Blend every anchor towards the centroid of a class assigned to it.

        Steps:
            1. Compute one unit-norm centroid per label in ``label_buffer``.
            2. Greedily assign anchors to classes in order of decreasing
               anchor/centroid cosine, capping each class at
               ``n_anchors // n_classes + 1`` anchors.
            3. Move each anchor a fraction ``lr`` of the way to its class
               centroid and re-normalise. When a class is entitled to more
               than one anchor, every anchor after its first aims at a
               slightly perturbed copy of the centroid.

        Args:
            emb_buffer: Embeddings, one row per sample (normalised here).
            label_buffer: Class label per row of ``emb_buffer``.
            lr: Blend factor in ``anchor + lr * (target - anchor)``.

        Returns:
            Number of anchors updated, or ``0`` when the buffers are empty.
        """
        anchors = self.constellation.anchors.data
        n_a = anchors.shape[0]
        emb_n = F.normalize(emb_buffer, dim=-1)

        # Step 1: class centroids on the unit sphere.
        classes = label_buffer.unique()
        n_cls = classes.shape[0]
        if n_cls == 0:
            return 0
        centroids = torch.cat([
            F.normalize(emb_n[label_buffer == c].mean(0, keepdim=True), dim=-1)
            for c in classes], dim=0)

        # Step 2: greedy capacity-limited assignment. The sorted order is
        # pulled to the host once so the loop runs on plain Python ints
        # instead of synchronising with the device on every element.
        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T
        anchors_per_class = n_a // n_cls
        capacity = anchors_per_class + 1

        _, flat_idx = cos.flatten().sort(descending=True)
        assigned = [-1] * n_a
        class_count = [0] * n_cls
        remaining = n_a
        for flat in flat_idx.tolist():
            a, c = divmod(flat, n_cls)
            if assigned[a] >= 0 or class_count[c] >= capacity:
                continue
            assigned[a] = c
            class_count[c] += 1
            remaining -= 1
            if remaining == 0:
                break
        # No anchor can be left over: the total capacity
        # n_cls * (n_a // n_cls + 1) exceeds n_a, so a class with room always
        # exists, and an anchor is skipped for a class only once it is full.

        # Step 3: move the anchors.
        seen_in_class = [0] * n_cls
        moved = 0
        for a, c in enumerate(assigned):
            target = centroids[c]
            if anchors_per_class > 1 and seen_in_class[c] > 0:
                # Perturb within the tangent plane at the centroid so that
                # anchors sharing a class do not aim at the same point.
                noise = torch.randn_like(target) * 0.05
                noise = noise - (noise * target).sum() * target
                target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)
            seen_in_class[c] += 1

            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1

        return moved

    def _cv_loss(self, emb, n_samples=64, n_points=5):
        """Penalise deviation of simplex-volume variability from ``cv_target``.

        Draws ``n_samples`` random subsets of ``n_points`` embeddings (from
        the first ``min(B, 512)`` rows), computes each subset's simplex
        volume via the Cayley-Menger determinant, and returns
        ``(std/mean - cv_target) ** 2`` over the volumes.

        Returns a zero scalar when the batch has fewer than ``n_points``
        rows, or when fewer than five subsets have a usable volume.
        """
        B = emb.shape[0]
        if B < n_points:
            return emb.new_zeros(())

        N = n_points
        k = N - 1  # dimension of the simplex spanned by N points
        # Cayley-Menger prefactor: V^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM).
        pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))

        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            # Squared pairwise distances; clamp tiny negatives from rounding.
            d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

            # Bordered distance matrix: zero corner, ones on the first
            # row/column, squared distances in the remaining block.
            cm = torch.zeros(1, N + 1, N + 1, device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2

            v2 = pf * torch.linalg.det(cm.float())
            # Skip degenerate subsets whose squared volume is not positive.
            if v2[0].item() > 1e-20:
                vols.append(v2[0].to(emb.dtype).sqrt())

        if len(vols) < 5:
            return emb.new_zeros(())
        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - self.cv_target).pow(2)
|
|