"""
34-EXPERT PATCHWORK MODEL
==========================
Pre-extracted features from 34 vision models → learned projectors →
cross-expert fusion → constellation triangulation → patchwork → COCO multi-label.

Architecture:
  Per-expert: Linear(d_expert → d_shared) + LayerNorm
  Fusion: Cross-attention over 34 expert tokens → fused embedding
  Geometry: Constellation(n_anchors) → triangulation → Patchwork → MLP
  Output: 80-class multi-label (BCE)

Training: Adam + geometric autograd (tang=0.01, sep=1.0, cv=0.001)
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| from datasets import load_dataset |
| import gc |
|
|
# ---- Runtime / architecture hyperparameters ----
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
D_SHARED = 1024   # shared width every expert is projected into
N_ANCHORS = 256   # number of constellation anchors
N_CLASSES = 80    # COCO object categories (multi-label targets)
N_COMP = 8        # patchwork component count
D_COMP = 128      # output width of each patchwork component

# Startup banner.
print("=" * 65)
print("34-EXPERT PATCHWORK MODEL")
print("=" * 65)
print(f" Device: {DEVICE}")
print(f" Shared dim: {D_SHARED}, Anchors: {N_ANCHORS}, Classes: {N_CLASSES}")
|
|
|
|
| |
| |
| |
|
|
def tangential_projection(grad, embedding):
    """Split *grad* into components tangential and radial to *embedding*.

    The embedding is detached and L2-normalized, so the radial part is the
    projection of the gradient onto the unit embedding direction and the
    tangential part is the remainder.

    Returns:
        (tangential, radial) — both cast back to grad's original dtype.
    """
    direction = F.normalize(embedding.detach().float(), dim=-1)
    g = grad.float()
    radial_part = direction * (g * direction).sum(dim=-1, keepdim=True)
    tangential = g - radial_part
    return tangential.to(grad.dtype), radial_part.to(grad.dtype)
|
|
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley–Menger determinant.

    Args:
        pts: (B, V, d) batch of simplices, each defined by V points.

    Returns:
        (B,) tensor of squared (V-1)-dimensional volumes:
        vol² = (-1)^V / (2^(V-1) ((V-1)!)²) · det(CM).
    """
    pts = pts.float()
    pairwise = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = pairwise.pow(2).sum(dim=-1)
    batch, n_pts, _ = sq_dist.shape
    # Bordered matrix: top row / left column of ones, zero corner,
    # squared distances in the interior.
    bordered = torch.zeros(batch, n_pts + 1, n_pts + 1,
                           device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist
    fact = math.factorial(n_pts - 1)
    coeff = (-1.0) ** n_pts / ((2.0 ** (n_pts - 1)) * fact * fact)
    return coeff * torch.linalg.det(bordered)
|
|
def cv_loss(emb, target=0.2, n_samples=16):
    """Loss pushing the coefficient of variation of sampled simplex
    volumes toward *target*.

    Draws *n_samples* random 5-point subsets of the batch, computes each
    subset's simplex volume, and returns |std/mean - target|. Returns a
    zero tensor when the batch has fewer than 5 rows.
    """
    n = emb.shape[0]
    if n < 5:
        return torch.tensor(0.0, device=emb.device)
    samples = []
    for _ in range(n_samples):
        pick = torch.randperm(n, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        samples.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(samples)
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).abs()
|
|
@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Monitoring-only coefficient of variation of sampled simplex volumes.

    Like cv_loss but detached, with more samples, keeping only strictly
    positive volumes. Returns 0.0 when the batch has fewer than 5 rows or
    fewer than 10 positive volumes were collected.
    """
    n = emb.shape[0]
    if n < 5:
        return 0.0
    collected = []
    for _ in range(n_samples):
        pick = torch.randperm(n)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            collected.append(vol)
    if len(collected) < 10:
        return 0.0
    stacked = torch.tensor(collected)
    return float(stacked.std() / (stacked.mean() + 1e-8))
|
|
def anchor_spread_loss(anchors):
    """Penalize pairwise cosine similarity between anchors.

    Anchors are unit-normalized; self-similarity (the diagonal) is removed
    so only cross-anchor similarity contributes. Minimizing this spreads
    anchors apart on the sphere.

    Args:
        anchors: (n_anchors, d) anchor matrix.

    Returns:
        Scalar mean of squared off-diagonal cosine similarities.
    """
    unit = F.normalize(anchors, dim=-1)
    # torch.eye replaces the original torch.diag(torch.ones(...)) — same
    # identity matrix, built directly and with matching dtype/device.
    eye = torch.eye(anchors.shape[0], device=anchors.device, dtype=unit.dtype)
    off_diag = unit @ unit.T - eye
    return off_diag.pow(2).mean()
|
|
def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Mean entropy of the soft assignment of embeddings to anchors.

    Anchors are unit-normalized; *emb* is used as-is. Similarities are
    scaled by *sharpness* before the softmax. Lower values mean sharper
    (more confident) anchor assignments.
    """
    unit_anchors = F.normalize(anchors, dim=-1)
    logits = emb @ unit_anchors.T * sharpness
    assignment = F.softmax(logits, dim=-1)
    entropy = -(assignment * (assignment + 1e-12).log()).sum(-1)
    return entropy.mean()
|
|
class EmbeddingAutograd(torch.autograd.Function):
    """Straight-through op that reshapes gradients flowing into the embedding.

    forward() returns ``x`` unchanged; backward() edits the incoming
    gradient in two ways before passing it on:
      1. damps the radial (norm-changing) component by factor ``tang``;
      2. removes (scaled by ``sep``) the gradient component pointing toward
         the nearest anchor when that component is positive — discouraging
         collapse onto anchors.
    Only the first input (``x``) receives a gradient; embedding/anchors and
    the two scalars get None.
    """
    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        # Stash tensors and coefficients for the custom backward pass.
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang; ctx.sep = sep
        return x
    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        # Split gradient into tangential and radial parts w.r.t. the embedding.
        tang_grad, norm_grad = tangential_projection(grad_f, emb_n)
        # corrected = grad - tang * radial: shrinks the radial component.
        corrected = tang_grad + (1.0 - ctx.tang) * norm_grad
        if ctx.sep > 0:
            # Component of the corrected gradient along each sample's
            # nearest (unit) anchor; subtract it when positive.
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(dim=-1, keepdim=True)
            collapse = toward * nearest
            corrected = corrected - ctx.sep * (toward > 0).float() * collapse
        return corrected.to(grad_output.dtype), None, None, None, None
|
|
|
|
| |
| |
| |
|
|
class ExpertProjector(nn.Module):
    """Project one expert's features into the shared space.

    Bottleneck MLP: Linear → GELU → Linear → LayerNorm, with hidden width
    min(d_in, d_out).
    """
    def __init__(self, d_in, d_out=D_SHARED):
        super().__init__()
        hidden = min(d_in, d_out)
        self.net = nn.Sequential(
            nn.Linear(d_in, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_out),
            nn.LayerNorm(d_out),
        )

    def forward(self, x):
        """x: (..., d_in) → (..., d_out)."""
        return self.net(x)
|
|
|
|
class ExpertFusion(nn.Module):
    """Fuse N expert tokens into a single embedding via cross-attention.

    One learned query token attends over all expert tokens through a stack
    of pre-norm TransformerDecoder layers; the final query state (after a
    LayerNorm) is the fused embedding.
    """
    def __init__(self, d_model=D_SHARED, n_heads=8, n_layers=2):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        decoder_layers = [
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=d_model * 2,
                dropout=0.1,
                batch_first=True,
                norm_first=True,
            )
            for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, expert_tokens):
        """expert_tokens: (B, N_experts, d_model) → fused (B, d_model)."""
        batch = expert_tokens.shape[0]
        token = self.query.expand(batch, -1, -1)
        for decoder in self.layers:
            # Query is the target; expert tokens are the memory.
            token = decoder(token, expert_tokens)
        return self.norm(token.squeeze(1))
|
|
|
|
class Constellation(nn.Module):
    """Learnable anchor set plus per-anchor rigidity statistics.

    Anchors are a (n_anchors, d_embed) parameter; ``rigidity`` and
    ``visit_count`` are non-trainable buffers updated online.
    """
    def __init__(self, n_anchors=N_ANCHORS, d_embed=D_SHARED, init_anchors=None):
        super().__init__()
        self.n_anchors = n_anchors
        if init_anchors is None:
            start = F.normalize(torch.randn(n_anchors, d_embed), dim=-1)
        else:
            start = init_anchors.clone()
        self.anchors = nn.Parameter(start)
        self.register_buffer("rigidity", torch.zeros(n_anchors))
        self.register_buffer("visit_count", torch.zeros(n_anchors))

    def triangulate(self, emb):
        """Return (cosine distances to all anchors, nearest-anchor index)."""
        unit_anchors = F.normalize(self.anchors, dim=-1)
        cos_sim = emb @ unit_anchors.T
        return 1.0 - cos_sim, cos_sim.argmax(dim=-1)

    @torch.no_grad()
    def update_rigidity(self, tri):
        """EMA-update per-anchor rigidity from a batch of triangulations.

        For each anchor owning at least 5 batch members, rigidity tracks
        the inverse of the members' triangulation spread, with a step size
        that shrinks as the anchor accumulates visits.
        """
        owner = tri.argmin(dim=-1)
        for a in range(self.n_anchors):
            mask = owner == a
            count = mask.sum()
            if count < 5:
                continue
            self.visit_count[a] += count.float()
            spread = tri[mask].std(dim=0).mean()
            alpha = min(0.1, 10.0 / (self.visit_count[a] + 1))
            self.rigidity[a] = (1 - alpha) * self.rigidity[a] + alpha / (spread + 0.01)
|
|
|
|
class Patchwork(nn.Module):
    """Encode the triangulation vector as a concatenation of per-group MLPs.

    Anchors are assigned round-robin to ``n_comp`` groups; each group's
    distances feed a small MLP producing ``d_comp`` features, and the
    outputs are concatenated into a (B, n_comp * d_comp) vector.
    """
    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        assignment = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", assignment)
        components = []
        for k in range(n_comp):
            group_size = int((assignment == k).sum())
            components.append(nn.Sequential(
                nn.Linear(group_size, d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
                nn.LayerNorm(d_comp),
            ))
        self.comps = nn.ModuleList(components)

    def forward(self, tri):
        """tri: (B, n_anchors) → (B, n_comp * d_comp)."""
        pieces = [net(tri[:, self.asgn == k])
                  for k, net in enumerate(self.comps)]
        return torch.cat(pieces, dim=-1)
|
|
|
|
class SoupModel(nn.Module):
    """End-to-end 34-expert model.

    Pipeline: per-expert projectors → identity-tagged tokens →
    cross-attention fusion → unit-sphere embedding → constellation
    triangulation → patchwork encoding → multi-label classifier.
    """
    def __init__(self, expert_dims_dict, n_anchors=N_ANCHORS,
                 n_comp=N_COMP, d_comp=D_COMP, n_classes=N_CLASSES,
                 d_shared=D_SHARED, init_anchors=None):
        """
        Args:
            expert_dims_dict: {expert_name: feature_dim} per expert.
            n_anchors: number of constellation anchors.
            n_comp / d_comp: patchwork component count and width.
            n_classes: number of output classes (multi-label logits).
            d_shared: shared projection width.
            init_anchors: optional (n_anchors, d_shared) initial anchors.
        """
        super().__init__()
        self.expert_names = sorted(expert_dims_dict.keys())
        self.n_experts = len(self.expert_names)
        self.d_shared = d_shared

        # nn.ModuleDict forbids "." in keys, so dots become underscores;
        # name_to_key keeps the original-name → key mapping for lookup.
        self.projectors = nn.ModuleDict({
            name.replace(".", "_"): ExpertProjector(dim, d_shared)
            for name, dim in expert_dims_dict.items()
        })
        self.name_to_key = {name: name.replace(".", "_")
                            for name in expert_dims_dict}

        # Learned per-expert identity vectors, added to each projected
        # token so the fusion attention can distinguish experts.
        self.expert_ids = nn.Parameter(
            torch.randn(self.n_experts, d_shared) * 0.02)

        self.fusion = ExpertFusion(d_shared, n_heads=8, n_layers=2)

        self.constellation = Constellation(n_anchors, d_shared, init_anchors)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)

        # Classifier consumes patchwork features concatenated with the raw
        # fused embedding (a residual path around the geometry stack).
        pw_dim = n_comp * d_comp
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + d_shared, d_shared),
            nn.GELU(),
            nn.LayerNorm(d_shared),
            nn.Dropout(0.1),
            nn.Linear(d_shared, d_shared // 2),
            nn.GELU(),
            nn.Linear(d_shared // 2, n_classes),
        )

    def forward(self, expert_features_dict):
        """Run the full pipeline.

        Args:
            expert_features_dict: {name: (B, d_expert)} for each name in
                self.expert_names.

        Returns:
            logits: (B, n_classes) multi-label logits.
            emb: (B, d_shared) unit-normalized fused embedding.
            tri: (B, n_anchors) cosine distances to anchors.
            nearest: (B,) index of each sample's nearest anchor.
        """
        # Project each expert and tag it with its identity vector.
        # (Removed an unused batch-size local the original computed here.)
        tokens = []
        for i, name in enumerate(self.expert_names):
            key = self.name_to_key[name]
            feat = expert_features_dict[name]
            proj = self.projectors[key](feat)
            proj = proj + self.expert_ids[i]
            tokens.append(proj)

        expert_stack = torch.stack(tokens, dim=1)  # (B, n_experts, d_shared)

        # Fuse and normalize onto the unit sphere.
        fused = self.fusion(expert_stack)
        emb = F.normalize(fused, dim=-1)

        # Geometry: triangulate against anchors, then patchwork-encode.
        tri, nearest = self.constellation.triangulate(emb)
        pw = self.patchwork(tri)

        combined = torch.cat([pw, emb], dim=-1)
        logits = self.classifier(combined)

        return logits, emb, tri, nearest

    def count_params(self):
        """Return parameter counts per sub-module plus the total."""
        total = sum(p.numel() for p in self.parameters())
        proj = sum(p.numel() for p in self.projectors.parameters())
        fuse = sum(p.numel() for p in self.fusion.parameters())
        geo = sum(p.numel() for p in self.constellation.parameters())
        pw = sum(p.numel() for p in self.patchwork.parameters())
        cls = sum(p.numel() for p in self.classifier.parameters())
        return {"total": total, "projectors": proj, "fusion": fuse,
                "constellation": geo, "patchwork": pw, "classifier": cls}
|
|
|
|
| |
| |
| |
|
|
# Names of the 34 pre-extracted feature subsets on the HuggingFace Hub:
# CLIP variants (backbone × pretraining corpus), DINOv2 (with/without
# registers), MAE, SigLIP / SigLIP2, and supervised ViT-21k checkpoints.
SUBSETS = [
    "clip_b16_laion2b", "clip_b16_openai", "clip_b32_datacomp",
    "clip_b32_laion2b", "clip_b32_openai", "clip_bigg14_laion2b",
    "clip_g14_laion2b", "clip_h14_laion2b", "clip_l14_336_openai",
    "clip_l14_datacomp", "clip_l14_laion2b", "clip_l14_openai",
    "dinov2_b14", "dinov2_b14_reg", "dinov2_g14", "dinov2_g14_reg",
    "dinov2_l14", "dinov2_l14_reg", "dinov2_s14", "dinov2_s14_reg",
    "mae_b16", "mae_h14", "mae_l16",
    "siglip2_b16_256", "siglip2_b16_512", "siglip2_l16_384",
    "siglip_b16_384", "siglip_b16_512", "siglip_l16_256",
    "siglip_l16_384", "siglip_so400m_384",
    "vit_b16_21k", "vit_l16_21k", "vit_s16_21k",
]
|
|
print(f"\n Loading val features...")
# The first subset serves as the reference for image ids and labels; all
# subsets are assumed to cover the same validation images — TODO confirm.
ref_ds = load_dataset("AbstractPhil/bulk-coco-features", SUBSETS[0], split="val")
image_ids = ref_ds["image_id"]
labels_raw = ref_ds["labels"]
N = len(image_ids)
id_to_idx = {iid: i for i, iid in enumerate(image_ids)}

# Build the (N, N_CLASSES) multi-hot label matrix; labels >= N_CLASSES
# are silently dropped.
label_matrix = torch.zeros(N, N_CLASSES)
for i, labs in enumerate(labels_raw):
    for l in labs:
        if l < N_CLASSES:
            label_matrix[i, l] = 1.0

# Load every expert's features, aligned to the reference image-id order.
# NOTE(review): row-by-row iteration over a HF dataset is slow; a columnar
# read (ds["features"]) would likely be much faster — verify schema first.
expert_features = {}
expert_dims = {}
for name in SUBSETS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    dim = len(ds[0]["features"])
    expert_dims[name] = dim
    feats = torch.zeros(N, dim)
    for row in ds:
        if row["image_id"] in id_to_idx:
            feats[id_to_idx[row["image_id"]]] = torch.tensor(
                row["features"], dtype=torch.float32)
    expert_features[name] = feats
    print(f" {name:<30} dim={dim}", flush=True)

print(f" Loaded {len(expert_features)} experts, N={N}")
print(f" Labels: {N_CLASSES} classes, multi-label")
print(f" Positive rate: {label_matrix.sum() / (N * N_CLASSES):.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("BUILDING MODEL")
print(f"{'='*65}")

# Instantiate the full model on the chosen device, using the per-expert
# feature dimensions discovered during loading.
model = SoupModel(expert_dims, n_anchors=N_ANCHORS,
                  n_comp=N_COMP, d_comp=D_COMP,
                  n_classes=N_CLASSES, d_shared=D_SHARED).to(DEVICE)

# Per-submodule parameter counts for the startup report.
params = model.count_params()
print(f" Parameters:")
for k, v in params.items():
    print(f" {k:<15}: {v:>10,}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("TRAINING")
print(f"{'='*65}")

# Sequential 80/20 train/val split.
# NOTE(review): assumes dataset order is not sorted by class/content —
# TODO confirm; otherwise the split is biased.
n_train = int(N * 0.8)
train_idx = torch.arange(n_train)
val_idx = torch.arange(n_train, N)

# All features are moved to the device once up-front (assumes they fit
# in device memory).
train_feats = {name: expert_features[name][:n_train].to(DEVICE) for name in SUBSETS}
val_feats = {name: expert_features[name][n_train:].to(DEVICE) for name in SUBSETS}
train_labels = label_matrix[:n_train].to(DEVICE)
val_labels = label_matrix[n_train:].to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
BATCH = 128
EPOCHS = 20
# Geometric-gradient coefficients: radial damping, anchor separation,
# and the coefficient-of-variation loss weight.
TANG, SEP, CV_W = 0.01, 1.0, 0.001

for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(n_train, device=DEVICE)
    total_loss, total_correct, n_batches = 0, 0, 0

    for i in range(0, n_train, BATCH):
        idx = perm[i:i+BATCH]
        # Skip trailing slivers too small for the geometric losses.
        if len(idx) < 4: continue

        batch_feats = {name: train_feats[name][idx] for name in SUBSETS}
        batch_labels = train_labels[idx]

        # First pass: obtain the embedding and triangulation.
        # NOTE(review): the logits from this pass are discarded and
        # recomputed below through the geometric path — extra compute.
        logits, emb, tri, nearest = model(batch_feats)
        anchors = model.constellation.anchors

        # Second pass through the geometry stack with the straight-through
        # EmbeddingAutograd wrapper, so backward applies the tangential /
        # anti-collapse gradient corrections.
        emb_g = EmbeddingAutograd.apply(emb, emb, anchors, TANG, SEP)
        tri_g, _ = model.constellation.triangulate(emb_g)
        pw_g = model.patchwork(tri_g)
        combined_g = torch.cat([pw_g, emb_g], dim=-1)
        logits = model.classifier(combined_g)

        # Multi-label classification loss.
        l_cls = F.binary_cross_entropy_with_logits(logits, batch_labels)

        # Geometric regularizers (fixed small weights).
        l_cv = CV_W * cv_loss(emb)
        l_spread = 1e-3 * anchor_spread_loss(anchors)
        l_ent = 1e-4 * anchor_entropy_loss(emb, anchors)

        loss = l_cls + l_cv + l_spread + l_ent
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step(); optimizer.zero_grad(set_to_none=True)

        # Online (no-grad) update of per-anchor rigidity statistics.
        model.constellation.update_rigidity(tri.detach())

        # Per-element multi-label accuracy (dominated by negatives, so
        # values near 1.0 are expected even for weak models).
        preds = (logits.detach().sigmoid() > 0.5).float()
        correct = (preds == batch_labels).float().mean().item()
        total_correct += correct
        total_loss += loss.item()
        n_batches += 1

    train_acc = total_correct / n_batches

    # ---- Validation pass ----
    model.eval()
    with torch.no_grad():
        # Batched inference over the validation split.
        all_logits, all_embs = [], []
        for j in range(0, len(val_idx), BATCH):
            chunk_idx = torch.arange(j, min(j + BATCH, len(val_idx)))
            chunk_feats = {name: val_feats[name][chunk_idx] for name in SUBSETS}
            lo, em, _, _ = model(chunk_feats)
            all_logits.append(lo)
            all_embs.append(em)

        v_logits = torch.cat(all_logits, 0)
        v_embs = torch.cat(all_embs, 0)

        v_preds = (v_logits.sigmoid() > 0.5).float()
        v_acc = (v_preds == val_labels).float().mean().item()
        v_cv = cv_metric(v_embs.cpu())

        # Per-class precision/recall/F1 at threshold 0.5.
        # NOTE(review): macro_f1 averages only classes with F1 > 0, which
        # inflates the metric relative to a standard macro-F1.
        tp = (v_preds * val_labels).sum(0)
        fp = (v_preds * (1 - val_labels)).sum(0)
        fn = ((1 - v_preds) * val_labels).sum(0)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        macro_f1 = f1[f1 > 0].mean().item()

        # Mean average precision over classes with at least one positive.
        ap_sum = 0
        n_valid = 0
        for c in range(N_CLASSES):
            if val_labels[:, c].sum() > 0:
                scores = v_logits[:, c].cpu()
                targets = val_labels[:, c].cpu()
                sorted_idx = scores.argsort(descending=True)
                sorted_tgt = targets[sorted_idx]
                tp_cumsum = sorted_tgt.cumsum(0)
                precision_at_k = tp_cumsum / torch.arange(1, len(sorted_tgt) + 1).float()
                ap = (precision_at_k * sorted_tgt).sum() / sorted_tgt.sum()
                ap_sum += ap.item()
                n_valid += 1
        mAP = ap_sum / max(n_valid, 1)

    # Progress line every other epoch (and at epoch 1).
    rig = model.constellation.rigidity
    if (epoch + 1) % 2 == 0 or epoch == 0:
        print(f" E{epoch+1:2d}: t_acc={train_acc:.3f} v_acc={v_acc:.3f} "
              f"mAP={mAP:.3f} F1={macro_f1:.3f} "
              f"cv={v_cv:.4f} rig={rig.mean():.1f}/{rig.max():.1f} "
              f"loss={total_loss/n_batches:.4f}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("FINAL REPORT")
print(f"{'='*65}")

# Final validation inference pass to collect logits and embeddings.
model.eval()
with torch.no_grad():
    all_logits, all_embs = [], []
    for j in range(0, len(val_idx), BATCH):
        chunk_idx = torch.arange(j, min(j + BATCH, len(val_idx)))
        chunk_feats = {name: val_feats[name][chunk_idx] for name in SUBSETS}
        lo, em, _, _ = model(chunk_feats)
        all_logits.append(lo)
        all_embs.append(em)

    v_logits = torch.cat(all_logits, 0)
    v_embs = torch.cat(all_embs, 0)

# Per-class average precision (classes with at least one positive).
class_aps = {}
for c in range(N_CLASSES):
    if val_labels[:, c].sum() > 0:
        scores = v_logits[:, c].cpu()
        targets = val_labels[:, c].cpu()
        sorted_idx = scores.argsort(descending=True)
        sorted_tgt = targets[sorted_idx]
        tp_cumsum = sorted_tgt.cumsum(0)
        prec_at_k = tp_cumsum / torch.arange(1, len(sorted_tgt) + 1).float()
        class_aps[c] = (prec_at_k * sorted_tgt).sum().item() / sorted_tgt.sum().item()

# Best and worst classes by AP.
sorted_aps = sorted(class_aps.items(), key=lambda x: -x[1])
print(f"\n Top 5 classes by AP:")
for c, ap in sorted_aps[:5]:
    n = val_labels[:, c].sum().int().item()
    print(f" class {c:>3}: AP={ap:.3f} (n={n})")

print(f"\n Bottom 5 classes by AP:")
for c, ap in sorted_aps[-5:]:
    n = val_labels[:, c].sum().int().item()
    print(f" class {c:>3}: AP={ap:.3f} (n={n})")

# Headline numbers.
final_cv = cv_metric(v_embs.cpu())
print(f"\n Final mAP: {sum(class_aps.values())/len(class_aps):.3f}")
print(f" Final CV: {final_cv:.4f}")
print(f" Embedding dim: {v_embs.shape[1]}")
print(f" Anchors: {model.constellation.n_anchors}")

# Norm of each expert's learned identity vector — a rough proxy for how
# strongly the fusion relies on that expert (not a calibrated importance).
print(f"\n Expert identity norms (learned importance):")
norms = model.expert_ids.detach().cpu().norm(dim=-1)
sorted_exp = sorted(zip(model.expert_names, norms.tolist()),
                    key=lambda x: -x[1])
for name, norm in sorted_exp[:5]:
    print(f" {name:<30} norm={norm:.4f}")
print(f" ...")
for name, norm in sorted_exp[-3:]:
    print(f" {name:<30} norm={norm:.4f}")

print(f"\n Parameters: {params['total']:,}")

print(f"\n{'='*65}")
print("DONE")
print(f"{'='*65}")