#!/usr/bin/env python3
"""
34-EXPERT PATCHWORK MODEL
==========================
Pre-extracted features from 34 vision models → learned projectors →
cross-expert fusion → constellation triangulation → patchwork →
COCO multi-label.

Architecture:
  Per-expert: Linear(d_expert → d_shared) + LayerNorm
  Fusion:     Cross-attention over 34 expert tokens → fused embedding
  Geometry:   Constellation(n_anchors) → triangulation → Patchwork → MLP
  Output:     80-class multi-label (BCE)

Training: Adam + geometric autograd (tang=0.01, sep=1.0, cv=0.001)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from datasets import load_dataset
import gc

# Global hyper-parameters used throughout the script.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
D_SHARED = 1024    # shared width every expert is projected to
N_ANCHORS = 256    # number of constellation anchors
N_CLASSES = 80     # COCO object categories (multi-label)
N_COMP = 8         # number of patchwork component MLPs
D_COMP = 128       # output width of each patchwork component

print("=" * 65)
print("34-EXPERT PATCHWORK MODEL")
print("=" * 65)
print(f" Device: {DEVICE}")
print(f" Shared dim: {D_SHARED}, Anchors: {N_ANCHORS}, Classes: {N_CLASSES}")

# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════

def tangential_projection(grad, embedding):
    """Split `grad` into components tangential and radial to `embedding`.

    The embedding is detached and unit-normalized before projecting, so
    only its direction matters. Returns (tangential, radial), both cast
    back to `grad`'s original dtype.
    """
    emb_n = F.normalize(embedding.detach().float(), dim=-1)
    grad_f = grad.float()
    # Radial part = projection of the gradient onto the embedding direction.
    radial = (grad_f * emb_n).sum(dim=-1, keepdim=True) * emb_n
    return (grad_f - radial).to(grad.dtype), radial.to(grad.dtype)


def cayley_menger_vol2(pts):
    """Squared simplex volume of each point set via the Cayley–Menger determinant.

    pts: (B, V, D) — B sets of V points. Returns a (B,) tensor of squared
    volumes; values can dip slightly negative from numerical error, so
    callers clamp with relu before taking sqrt.
    """
    pts = pts.float()
    diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    d2 = (diff * diff).sum(-1)  # pairwise squared distances, (B, V, V)
    B, V, _ = d2.shape
    # Bordered (V+1)x(V+1) Cayley–Menger matrix: first row/col are ones,
    # interior is the squared-distance matrix.
    cm = torch.zeros(B, V+1, V+1, device=d2.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
    # vol² = (-1)^V / (2^(V-1) ((V-1)!)²) · det(CM)
    s = (-1.0)**V; f = math.factorial(V-1)
    return s / ((2.0**(V-1)) * f*f) * torch.linalg.det(cm)


def cv_loss(emb, target=0.2, n_samples=16):
    """Coefficient-of-variation loss over random 5-point simplex volumes.

    Samples `n_samples` random 5-subsets of the batch, computes their
    simplex volumes, and penalizes |std/mean - target|. Differentiable;
    returns 0 when the batch is too small to form a 5-simplex.
    """
    B = emb.shape[0]
    if B < 5:
        return torch.tensor(0.0, device=emb.device)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B, device=emb.device)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        # relu clamps tiny negative vol² from numerical error before sqrt.
        vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
    stacked = torch.stack(vols)
    return (stacked.std() / (stacked.mean() + 1e-8) - target).abs()


@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Monitoring-only CV of sampled 5-point simplex volumes (no gradients).

    Unlike cv_loss, zero-volume samples are dropped, and 0.0 is returned
    when fewer than 10 usable samples remain.
    """
    B = emb.shape[0]
    if B < 5:
        return 0.0
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(B)[:5]
        v2 = cayley_menger_vol2(emb[idx].unsqueeze(0))
        v = torch.sqrt(F.relu(v2[0]) + 1e-12).item()
        if v > 0:
            vols.append(v)
    if len(vols) < 10:
        return 0.0
    a = torch.tensor(vols)
    return float(a.std() / (a.mean() + 1e-8))


def anchor_spread_loss(anchors):
    """Push anchors apart: mean squared off-diagonal cosine similarity.

    The diagonal (self-similarity = 1) is subtracted out so it contributes
    zero to the penalty.
    """
    a = F.normalize(anchors, dim=-1)
    sim = a @ a.T - torch.diag(torch.ones(anchors.shape[0], device=anchors.device))
    return sim.pow(2).mean()


def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Entropy of the sharpened softmax assignment of embeddings to anchors.

    Returns the mean per-sample entropy (non-negative); adding it to the
    loss with positive weight drives assignments toward one-hot.
    """
    a = F.normalize(anchors, dim=-1)
    probs = F.softmax(emb @ a.T * sharpness, dim=-1)
    return -(probs * (probs + 1e-12).log()).sum(-1).mean()


class EmbeddingAutograd(torch.autograd.Function):
    """Identity in forward; geometry-aware gradient correction in backward.

    Backward (1) damps the radial gradient component by factor `tang`
    (keeping 1-tang of it), and (2) with strength `sep`, cancels the part
    of the gradient that would pull an embedding toward its nearest
    anchor — an anti-collapse measure. Only the first input receives a
    gradient; the other four arguments get None.
    """

    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang; ctx.sep = sep
        return x

    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        tang_grad, norm_grad = tangential_projection(grad_f, emb_n)
        # Keep the tangential part intact, scale the radial part by (1 - tang).
        corrected = tang_grad + (1.0 - ctx.tang) * norm_grad
        if ctx.sep > 0:
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(dim=-1, keepdim=True)
            collapse = toward * nearest
            # Only cancel motion *toward* the nearest anchor (toward > 0);
            # movement away is left untouched.
            corrected = corrected - ctx.sep * (toward > 0).float() * collapse
        return corrected.to(grad_output.dtype), None, None, None, None


# ══════════════════════════════════════════════════════════════════
# MODEL COMPONENTS
# ══════════════════════════════════════════════════════════════════

class ExpertProjector(nn.Module):
    """d_expert → d_shared with bottleneck.

    Two-layer MLP whose hidden width is min(d_in, d_out), followed by
    LayerNorm on the shared-width output.
    """

    def __init__(self, d_in, d_out=D_SHARED):
        super().__init__()
        d_mid = min(d_in, d_out)  # bottleneck no wider than either end
        self.net = nn.Sequential(
            nn.Linear(d_in, d_mid), nn.GELU(),
            nn.Linear(d_mid, d_out), nn.LayerNorm(d_out),
        )

    def forward(self, x):
        return self.net(x)


class ExpertFusion(nn.Module):
    """
    Cross-attention fusion of N expert projections → single embedding.
    Uses a learned query token that attends to all expert outputs.
    """

    def __init__(self, d_model=D_SHARED, n_heads=8, n_layers=2):
        super().__init__()
        # Single learned query token, small init like the identity embeddings.
        self.query = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=d_model, nhead=n_heads,
                dim_feedforward=d_model * 2,
                dropout=0.1, batch_first=True, norm_first=True,
            ) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, expert_tokens):
        """
        expert_tokens: (B, N_experts, d_model)
        returns: (B, d_model)
        """
        B = expert_tokens.shape[0]
        q = self.query.expand(B, -1, -1)  # (B, 1, d_model)
        for layer in self.layers:
            # Decoder layer: self-attention over the single query token,
            # then cross-attention into the expert tokens (the "memory").
            q = layer(q, expert_tokens)
        return self.norm(q.squeeze(1))  # (B, d_model)


class Constellation(nn.Module):
    """Learnable anchor directions plus per-anchor running statistics.

    `rigidity` and `visit_count` are buffers (saved in the state dict but
    not optimized); they are maintained by update_rigidity() outside of
    autograd.
    """

    def __init__(self, n_anchors=N_ANCHORS, d_embed=D_SHARED, init_anchors=None):
        super().__init__()
        self.n_anchors = n_anchors
        if init_anchors is not None:
            self.anchors = nn.Parameter(init_anchors.clone())
        else:
            # Random unit directions on the d_embed-sphere.
            self.anchors = nn.Parameter(F.normalize(
                torch.randn(n_anchors, d_embed), dim=-1))
        self.register_buffer("rigidity", torch.zeros(n_anchors))
        self.register_buffer("visit_count", torch.zeros(n_anchors))

    def triangulate(self, emb):
        """Return (cosine distance to every anchor, index of nearest anchor).

        emb is assumed unit-normalized by the caller; distances are 1 - cos.
        """
        a = F.normalize(self.anchors, dim=-1)
        cos = emb @ a.T
        return 1.0 - cos, cos.argmax(dim=-1)

    @torch.no_grad()
    def update_rigidity(self, tri):
        """EMA-update each anchor's rigidity from its assigned samples.

        tri: (B, n_anchors) cosine distances. For every anchor with at
        least 5 assigned samples, rigidity tracks 1/(spread + 0.01) with
        an EMA rate that decays as the anchor accumulates visits.
        """
        nearest = tri.argmin(dim=-1)  # argmin of distance == argmax of cosine
        for i in range(self.n_anchors):
            m = nearest == i
            if m.sum() < 5:
                continue  # too few samples for a stable spread estimate
            self.visit_count[i] += m.sum().float()
            sp = tri[m].std(dim=0).mean()
            # EMA rate capped at 0.1, shrinking with total visits.
            alpha = min(0.1, 10.0 / (self.visit_count[i] + 1))
            self.rigidity[i] = (1-alpha)*self.rigidity[i] + alpha/(sp+0.01)


class Patchwork(nn.Module):
    """Split the triangulation vector among n_comp MLPs and concatenate.

    Anchors are assigned to components round-robin (anchor index mod
    n_comp); each component maps its slice of distances to d_comp dims.
    """

    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        # Round-robin assignment of anchors to components; stored as a
        # buffer so it moves with the module and serializes.
        asgn = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", asgn)
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear((asgn == k).sum().item(), d_comp * 2), nn.GELU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp),
        ) for k in range(n_comp)])

    def forward(self, tri):
        # tri: (B, n_anchors) → (B, n_comp * d_comp)
        return torch.cat([
            self.comps[k](tri[:, self.asgn == k]) for k in range(self.n_comp)
        ], dim=-1)


class SoupModel(nn.Module):
    """
    34-expert → projectors → fusion → constellation → patchwork → classifier.
    """

    def __init__(self, expert_dims_dict, n_anchors=N_ANCHORS, n_comp=N_COMP,
                 d_comp=D_COMP, n_classes=N_CLASSES, d_shared=D_SHARED,
                 init_anchors=None):
        super().__init__()
        # Sorted for a deterministic expert ordering (expert_ids index i
        # must always map to the same expert).
        self.expert_names = sorted(expert_dims_dict.keys())
        self.n_experts = len(self.expert_names)
        self.d_shared = d_shared
        # Per-expert projectors; "." is not allowed in ModuleDict keys,
        # hence the replace().
        self.projectors = nn.ModuleDict({
            name.replace(".", "_"): ExpertProjector(dim, d_shared)
            for name, dim in expert_dims_dict.items()
        })
        self.name_to_key = {name: name.replace(".", "_")
                            for name in expert_dims_dict}
        # Expert identity embeddings (learned, added to projected features)
        self.expert_ids = nn.Parameter(
            torch.randn(self.n_experts, d_shared) * 0.02)
        # Fusion: cross-attention over expert tokens
        self.fusion = ExpertFusion(d_shared, n_heads=8, n_layers=2)
        # Geometric pipeline
        self.constellation = Constellation(n_anchors, d_shared, init_anchors)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        # Classifier: patchwork output + fused embedding → multi-label
        pw_dim = n_comp * d_comp
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + d_shared, d_shared), nn.GELU(),
            nn.LayerNorm(d_shared), nn.Dropout(0.1),
            nn.Linear(d_shared, d_shared // 2), nn.GELU(),
            nn.Linear(d_shared // 2, n_classes),
        )

    def forward(self, expert_features_dict):
        """
        expert_features_dict: {name: (B, d_expert)} for each expert
        returns: (logits, emb, tri, nearest) — class logits (B, n_classes),
        unit embedding (B, d_shared), anchor distances (B, n_anchors),
        nearest-anchor indices (B,).
        """
        B = next(iter(expert_features_dict.values())).shape[0]
        # Project each expert
        tokens = []
        for i, name in enumerate(self.expert_names):
            key = self.name_to_key[name]
            feat = expert_features_dict[name]
            proj = self.projectors[key](feat)  # (B, d_shared)
            proj = proj + self.expert_ids[i]   # + identity
            tokens.append(proj)
        expert_stack = torch.stack(tokens, dim=1)  # (B, N, d_shared)
        # Fuse
        fused = self.fusion(expert_stack)  # (B, d_shared)
        emb = F.normalize(fused, dim=-1)   # on hypersphere
        # Triangulate
        tri, nearest = self.constellation.triangulate(emb)
        # Patchwork
        pw = self.patchwork(tri)  # (B, n_comp * d_comp)
        # Classify from patchwork + embedding
        combined = torch.cat([pw, emb], dim=-1)
        logits = self.classifier(combined)  # (B, n_classes)
        return logits, emb, tri, nearest

    def count_params(self):
        """Parameter counts per sub-module, for reporting."""
        total = sum(p.numel() for p in self.parameters())
        proj = sum(p.numel() for p in self.projectors.parameters())
        fuse = sum(p.numel() for p in self.fusion.parameters())
        geo = sum(p.numel() for p in self.constellation.parameters())
        pw = sum(p.numel() for p in self.patchwork.parameters())
        cls = sum(p.numel() for p in self.classifier.parameters())
        return {"total": total, "projectors": proj, "fusion": fuse,
                "constellation": geo, "patchwork": pw, "classifier": cls}


# ══════════════════════════════════════════════════════════════════
# DATA LOADING
# ══════════════════════════════════════════════════════════════════

# Hugging Face dataset subset names — one per pre-extracted expert.
SUBSETS = [
    "clip_b16_laion2b", "clip_b16_openai", "clip_b32_datacomp",
    "clip_b32_laion2b", "clip_b32_openai", "clip_bigg14_laion2b",
    "clip_g14_laion2b", "clip_h14_laion2b", "clip_l14_336_openai",
    "clip_l14_datacomp", "clip_l14_laion2b", "clip_l14_openai",
    "dinov2_b14", "dinov2_b14_reg", "dinov2_g14", "dinov2_g14_reg",
    "dinov2_l14", "dinov2_l14_reg", "dinov2_s14", "dinov2_s14_reg",
    "mae_b16", "mae_h14", "mae_l16",
    "siglip2_b16_256", "siglip2_b16_512", "siglip2_l16_384",
    "siglip_b16_384", "siglip_b16_512", "siglip_l16_256",
    "siglip_l16_384", "siglip_so400m_384",
    "vit_b16_21k", "vit_l16_21k", "vit_s16_21k",
]

print(f"\n Loading val features...")
# The first subset defines the canonical image ordering and labels.
ref_ds = load_dataset("AbstractPhil/bulk-coco-features", SUBSETS[0], split="val")
image_ids = ref_ds["image_id"]
labels_raw = ref_ds["labels"]
N = len(image_ids)
id_to_idx = {iid: i for i, iid in enumerate(image_ids)}

# Multi-label targets
label_matrix = torch.zeros(N, N_CLASSES)
for i, labs in enumerate(labels_raw):
    for l in labs:
        if l < N_CLASSES:  # guard against out-of-range label ids
            label_matrix[i, l] = 1.0

expert_features = {}
expert_dims = {}
for name in SUBSETS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    dim = len(ds[0]["features"])
    expert_dims[name] = dim
    feats = torch.zeros(N, dim)
    # Align rows by image_id against the reference subset's ordering;
    # images missing from this subset keep zero features.
    for row in ds:
        if row["image_id"] in id_to_idx:
            feats[id_to_idx[row["image_id"]]] = torch.tensor(
                row["features"], dtype=torch.float32)
    expert_features[name] = feats  # NOT normalized — projector handles it
    print(f" {name:<30} dim={dim}", flush=True)

print(f" Loaded {len(expert_features)} experts, N={N}")
print(f" Labels: {N_CLASSES} classes, multi-label")
print(f" Positive rate: {label_matrix.sum() / (N * N_CLASSES):.4f}")

# ══════════════════════════════════════════════════════════════════
# BUILD MODEL
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("BUILDING MODEL")
print(f"{'='*65}")

model = SoupModel(expert_dims, n_anchors=N_ANCHORS, n_comp=N_COMP,
                  d_comp=D_COMP, n_classes=N_CLASSES,
                  d_shared=D_SHARED).to(DEVICE)
params = model.count_params()
print(f" Parameters:")
for k, v in params.items():
    print(f" {k:<15}: {v:>10,}")

# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("TRAINING")
print(f"{'='*65}")

# Split 80/20
# NOTE(review): this is a sequential (not shuffled) split of the val set
# — fine only if the dataset rows are not ordered by content; confirm.
n_train = int(N * 0.8)
train_idx = torch.arange(n_train)
val_idx = torch.arange(n_train, N)

# Pre-stack features per expert on device
train_feats = {name: expert_features[name][:n_train].to(DEVICE)
               for name in SUBSETS}
val_feats = {name: expert_features[name][n_train:].to(DEVICE)
             for name in SUBSETS}
train_labels = label_matrix[:n_train].to(DEVICE)
val_labels = label_matrix[n_train:].to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
BATCH = 128
EPOCHS = 20
# Geometric autograd strengths: radial damping, anchor separation, CV weight.
TANG, SEP, CV_W = 0.01, 1.0, 0.001

for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(n_train, device=DEVICE)
    total_loss, total_correct, n_batches = 0, 0, 0
    for i in range(0, n_train, BATCH):
        idx = perm[i:i+BATCH]
        if len(idx) < 4:
            continue  # skip a tiny trailing batch
        # Gather batch
        batch_feats = {name: train_feats[name][idx] for name in SUBSETS}
        batch_labels = train_labels[idx]
        logits, emb, tri, nearest = model(batch_feats)
        anchors = model.constellation.anchors
        # Geometric autograd: re-run the post-embedding pipeline through
        # EmbeddingAutograd so the classification gradient reaching `emb`
        # is tangentially projected and anchor-separated.
        # NOTE(review): the classifier output from model() above is
        # discarded and recomputed here — the first pass's head work is
        # wasted compute (its emb/tri are still used).
        emb_g = EmbeddingAutograd.apply(emb, emb, anchors, TANG, SEP)
        tri_g, _ = model.constellation.triangulate(emb_g)
        pw_g = model.patchwork(tri_g)
        combined_g = torch.cat([pw_g, emb_g], dim=-1)
        logits = model.classifier(combined_g)
        # Multi-label BCE
        l_cls = F.binary_cross_entropy_with_logits(logits, batch_labels)
        # Geometric losses
        l_cv = CV_W * cv_loss(emb)
        l_spread = 1e-3 * anchor_spread_loss(anchors)
        l_ent = 1e-4 * anchor_entropy_loss(emb, anchors)
        loss = l_cls + l_cv + l_spread + l_ent
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step(); optimizer.zero_grad(set_to_none=True)
        # Buffer-only EMA update; outside autograd by design.
        model.constellation.update_rigidity(tri.detach())
        # Multi-label accuracy (threshold 0.5)
        preds = (logits.detach().sigmoid() > 0.5).float()
        # Element-wise accuracy over all (sample, class) cells — dominated
        # by the many negative labels, so expect high values.
        correct = (preds == batch_labels).float().mean().item()
        total_correct += correct
        total_loss += loss.item()
        n_batches += 1
    train_acc = total_correct / n_batches

    # Validation
    model.eval()
    with torch.no_grad():
        # Process val in chunks
        all_logits, all_embs = [], []
        for j in range(0, len(val_idx), BATCH):
            chunk_idx = torch.arange(j, min(j + BATCH, len(val_idx)))
            chunk_feats = {name: val_feats[name][chunk_idx]
                           for name in SUBSETS}
            lo, em, _, _ = model(chunk_feats)
            all_logits.append(lo)
            all_embs.append(em)
        v_logits = torch.cat(all_logits, 0)
        v_embs = torch.cat(all_embs, 0)
        v_preds = (v_logits.sigmoid() > 0.5).float()
        v_acc = (v_preds == val_labels).float().mean().item()
        v_cv = cv_metric(v_embs.cpu())

        # Per-class F1 (macro)
        tp = (v_preds * val_labels).sum(0)
        fp = (v_preds * (1 - val_labels)).sum(0)
        fn = ((1 - v_preds) * val_labels).sum(0)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        # NOTE(review): averaging only classes with F1 > 0 inflates the
        # macro-F1 relative to the standard definition.
        macro_f1 = f1[f1 > 0].mean().item()

        # mAP (AP per class from the ranked score list, averaged over
        # classes that have at least one positive in the val split)
        ap_sum = 0
        n_valid = 0
        for c in range(N_CLASSES):
            if val_labels[:, c].sum() > 0:
                scores = v_logits[:, c].cpu()
                targets = val_labels[:, c].cpu()
                sorted_idx = scores.argsort(descending=True)
                sorted_tgt = targets[sorted_idx]
                tp_cumsum = sorted_tgt.cumsum(0)
                precision_at_k = tp_cumsum / torch.arange(1, len(sorted_tgt) + 1).float()
                ap = (precision_at_k * sorted_tgt).sum() / sorted_tgt.sum()
                ap_sum += ap.item()
                n_valid += 1
        mAP = ap_sum / max(n_valid, 1)

    rig = model.constellation.rigidity
    # Log on the first epoch and every second epoch thereafter.
    if (epoch + 1) % 2 == 0 or epoch == 0:
        print(f" E{epoch+1:2d}: t_acc={train_acc:.3f} v_acc={v_acc:.3f} "
              f"mAP={mAP:.3f} F1={macro_f1:.3f} "
              f"cv={v_cv:.4f} rig={rig.mean():.1f}/{rig.max():.1f} "
              f"loss={total_loss/n_batches:.4f}")

# ══════════════════════════════════════════════════════════════════
# FINAL REPORT
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("FINAL REPORT")
print(f"{'='*65}")

model.eval()
with torch.no_grad():
    # Re-run the full val split in chunks for the final metrics.
    all_logits, all_embs = [], []
    for j in range(0, len(val_idx), BATCH):
        chunk_idx = torch.arange(j, min(j + BATCH, len(val_idx)))
        chunk_feats = {name: val_feats[name][chunk_idx] for name in SUBSETS}
        lo, em, _, _ = model(chunk_feats)
        all_logits.append(lo)
        all_embs.append(em)
    v_logits = torch.cat(all_logits, 0)
    v_embs = torch.cat(all_embs, 0)

# Top-5 and bottom-5 classes by AP
class_aps = {}
for c in range(N_CLASSES):
    if val_labels[:, c].sum() > 0:  # AP undefined with no positives
        scores = v_logits[:, c].cpu()
        targets = val_labels[:, c].cpu()
        sorted_idx = scores.argsort(descending=True)
        sorted_tgt = targets[sorted_idx]
        tp_cumsum = sorted_tgt.cumsum(0)
        prec_at_k = tp_cumsum / torch.arange(1, len(sorted_tgt) + 1).float()
        class_aps[c] = (prec_at_k * sorted_tgt).sum().item() / sorted_tgt.sum().item()

sorted_aps = sorted(class_aps.items(), key=lambda x: -x[1])
print(f"\n Top 5 classes by AP:")
for c, ap in sorted_aps[:5]:
    n = val_labels[:, c].sum().int().item()
    print(f" class {c:>3}: AP={ap:.3f} (n={n})")
print(f"\n Bottom 5 classes by AP:")
for c, ap in sorted_aps[-5:]:
    n = val_labels[:, c].sum().int().item()
    print(f" class {c:>3}: AP={ap:.3f} (n={n})")

final_cv = cv_metric(v_embs.cpu())
print(f"\n Final mAP: {sum(class_aps.values())/len(class_aps):.3f}")
print(f" Final CV: {final_cv:.4f}")
print(f" Embedding dim: {v_embs.shape[1]}")
print(f" Anchors: {model.constellation.n_anchors}")

# Expert contribution analysis
# NOTE(review): identity-embedding norm is used as a proxy for learned
# expert importance — suggestive, not a causal attribution.
print(f"\n Expert identity norms (learned importance):")
norms = model.expert_ids.detach().cpu().norm(dim=-1)
sorted_exp = sorted(zip(model.expert_names, norms.tolist()),
                    key=lambda x: -x[1])
for name, norm in sorted_exp[:5]:
    print(f" {name:<30} norm={norm:.4f}")
print(f" ...")
for name, norm in sorted_exp[-3:]:
    print(f" {name:<30} norm={norm:.4f}")

print(f"\n Parameters: {params['total']:,}")
print(f"\n{'='*65}")
print("DONE")
print(f"{'='*65}")