| |
"""
GEOLIP VISION ALIGNMENT BANK
==============================
CaptionBERT architecture applied to 34 vision experts.

CaptionBERT:
  5 BERT experts → GPA consensus → per-expert whitened Procrustes
  → AlignmentBank(rotations, whiteners, means, anchors, geo_proj)
  → compute_bank_loss(agreement, ortho, spread, entropy, cross_var, disagree, CV)
  → student losses: InfoNCE + MSE against consensus

This file:
  34 vision experts → GPA consensus → per-expert whitened Procrustes
  → VisionAlignmentBank(34 rotations, whiteners, means, anchors, geo_proj)
  → same compute_bank_loss
  → same student losses against consensus
  → classification through constellation + patchwork (transferred from soup)

Data: AbstractPhil/bulk-coco-features (118K train + 5K val, pre-extracted)
"""
|
|
| import gc |
| import math |
| import os |
| import time |
| import json |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.tensorboard import SummaryWriter |
| from datasets import load_dataset |
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "AbstractPhil/geolip-vit-x34"  # HF model repo for checkpoint upload
SOUP_PATH = "soup_patchwork.pt"          # optional warm-start checkpoint


# ---- architecture dimensions ----
D_SHARED = 1024   # shared embedding space all experts are projected into
N_ANCHORS = 256   # hypersphere anchors (constellation + bank)
N_CLASSES = 80    # COCO multi-label categories
N_COMP = 8        # patchwork components
D_COMP = 128      # per-component patchwork output width
D_BANK = 128      # geometric-context vector produced by the bank


# ---- training hyperparameters / loss weights ----
BATCH = 128
EPOCHS = 20
LR = 5e-4
W_NCE = 1.0       # InfoNCE against consensus
W_MSE = 0.5       # MSE against consensus
W_CV = 0.001      # simplex-volume CV regularizer
W_BANK = 1.0      # compute_bank_loss total
W_CLS = 0.3       # multi-label BCE
GRAD_CLIP = 1.0


# 34 pre-extracted vision expert feature subsets on the HF hub.
SUBSETS = [
    "clip_b16_laion2b", "clip_b16_openai", "clip_b32_datacomp",
    "clip_b32_laion2b", "clip_b32_openai", "clip_bigg14_laion2b",
    "clip_g14_laion2b", "clip_h14_laion2b", "clip_l14_336_openai",
    "clip_l14_datacomp", "clip_l14_laion2b", "clip_l14_openai",
    "dinov2_b14", "dinov2_b14_reg", "dinov2_g14", "dinov2_g14_reg",
    "dinov2_l14", "dinov2_l14_reg", "dinov2_s14", "dinov2_s14_reg",
    "mae_b16", "mae_h14", "mae_l16",
    "siglip2_b16_256", "siglip2_b16_512", "siglip2_l16_384",
    "siglip_b16_384", "siglip_b16_512", "siglip_l16_256",
    "siglip_l16_384", "siglip_so400m_384",
    "vit_b16_21k", "vit_l16_21k", "vit_s16_21k",
]


print("=" * 65)
print("GEOLIP VISION ALIGNMENT BANK")
print(f" {len(SUBSETS)} experts β CaptionBERT AlignmentBank")
print(f" Device: {DEVICE}")
print("=" * 65)
|
|
|
|
| |
| |
| |
|
|
def cayley_menger_vol2(pts):
    """Squared simplex volume of each point set via the Cayley-Menger determinant.

    pts: (B, V, D) batches of V points; returns (B,) squared (V-1)-simplex volumes.
    """
    pts = pts.float()
    pairwise = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = pairwise.pow(2).sum(dim=-1)
    batch, n_pts, _ = sq_dist.shape
    # Bordered distance matrix: zero corner, ones border, squared distances inside.
    border = torch.zeros(batch, n_pts + 1, n_pts + 1, device=sq_dist.device, dtype=torch.float32)
    border[:, 0, 1:] = 1
    border[:, 1:, 0] = 1
    border[:, 1:, 1:] = sq_dist
    sign = (-1.0) ** n_pts
    fact = math.factorial(n_pts - 1)
    scale = sign / ((2.0 ** (n_pts - 1)) * fact * fact)
    return scale * torch.linalg.det(border)
|
|
def cv_loss(emb, target=0.2, n_samples=16):
    """Penalize deviation of the sampled 4-simplex volume CV from `target`.

    Draws `n_samples` random 5-point subsets of `emb`, takes their simplex
    volumes, and returns |std/mean - target|. Zero when the batch is < 5.
    """
    batch = emb.shape[0]
    if batch < 5:
        return torch.tensor(0.0, device=emb.device)
    samples = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        samples.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    volumes = torch.stack(samples)
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).abs()
|
|
def cv_metric(emb, n=200):
    """Monitoring-only coefficient of variation of sampled simplex volumes.

    Like cv_loss but returns a plain float and skips degenerate (zero) volumes.
    """
    batch = emb.shape[0]
    if batch < 5:
        return 0.0
    volumes = []
    for _ in range(n):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            volumes.append(vol)
    # Not enough usable samples -> report 0 rather than a noisy estimate.
    if len(volumes) < 10:
        return 0.0
    arr = np.array(volumes)
    return float(arr.std() / (arr.mean() + 1e-8))
|
|
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-aligned embedding batches.

    Returns (loss, top-1 accuracy of matching each a-row to its paired b-row).
    """
    a_n = F.normalize(a, dim=-1)
    b_n = F.normalize(b, dim=-1)
    sim = a_n @ b_n.T / temperature
    targets = torch.arange(sim.shape[0], device=sim.device)
    loss_ab = F.cross_entropy(sim, targets)
    loss_ba = F.cross_entropy(sim.T, targets)
    loss = 0.5 * (loss_ab + loss_ba)
    with torch.no_grad():
        acc = (sim.argmax(-1) == targets).float().mean().item()
    return loss, acc
|
|
|
|
| |
| |
| |
|
|
def compute_bank_loss(bank, embedding):
    """CaptionBERT-style geometric regularizer for an AlignmentBank.

    Scores `embedding` (B, d_embed) against every expert's trainable
    whitened-Procrustes frame and returns (total_loss, diagnostics dict).
    Terms: per-sample expert agreement, rotation orthogonality, anchor
    spread, anchor-assignment entropy, cross-expert variance, match to
    calibrated disagreement targets, and a simplex-volume CV target.
    """
    B = embedding.shape[0]
    emb = embedding.float()

    # Per-expert round trip: center -> whiten -> rotate into the expert frame
    # -> rotate back; cosine measures how well emb survives the round trip
    # (1.0 exactly when the rotation is orthogonal).
    expert_cos_list = []
    expert_projected = []
    for i in range(bank.n_experts):
        R = bank.expert_rotations[i]
        W = bank.expert_whiteners[i]
        mu = bank.expert_means[i]
        centered = emb - mu
        whitened = centered @ W
        whitened_n = F.normalize(whitened, dim=-1)
        in_expert = whitened_n @ R.T
        back = in_expert @ R
        cos = F.cosine_similarity(whitened_n, back, dim=-1)
        expert_cos_list.append(cos)
        expert_projected.append(in_expert)

    expert_cos = torch.stack(expert_cos_list, dim=-1)  # (B, n_experts)

    # (1) agreement: all experts should score each sample similarly.
    expert_mean = expert_cos.mean(dim=-1, keepdim=True)
    l_agreement = (expert_cos - expert_mean).pow(2).mean()

    # (2) keep learned rotations close to orthogonal (R R^T ~ I).
    l_ortho = 0.0
    for i in range(bank.n_experts):
        R = bank.expert_rotations[i]
        l_ortho += (R @ R.T - torch.eye(bank.d_embed, device=R.device)).pow(2).mean()
    l_ortho = l_ortho / bank.n_experts

    # (3) anchors spread over the hypersphere (penalize pairwise cosine).
    anchors_n = F.normalize(bank.anchors, dim=-1)
    anchor_sim = anchors_n @ anchors_n.T
    anchor_sim.fill_diagonal_(0)
    l_spread = anchor_sim.pow(2).mean()

    # (4) encourage sharp (low-entropy) soft assignment of samples to anchors;
    # the *10 acts as an inverse softmax temperature.
    anchor_cos = emb @ anchors_n.T
    anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
    l_entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()

    # (5/6) pairwise cosine between expert-frame projections of each sample.
    cross_cos = []
    for i in range(bank.n_experts):
        for j in range(i + 1, bank.n_experts):
            cc = F.cosine_similarity(expert_projected[i], expert_projected[j], dim=-1)
            cross_cos.append(cc)

    if cross_cos:
        cross_features = torch.stack(cross_cos, dim=-1)
        l_cross_var = cross_features.var(dim=0).mean()

        # Match the cross-cos / disagreement statistics measured at
        # calibration time (stored as buffers on the bank).
        batch_cross_mean = cross_features.mean()
        batch_cross_std = cross_features.std()
        per_sample_agreement = expert_cos.mean(dim=-1)
        per_sample_disagreement = expert_cos.std(dim=-1)
        batch_disagree_ratio = (per_sample_disagreement / (per_sample_agreement + 1e-8)).mean()
        l_disagree = (
            (batch_cross_mean - bank.target_cross_cos_mean).pow(2) +
            (batch_cross_std - bank.target_cross_cos_std).pow(2) +
            (batch_disagree_ratio - bank.target_disagreement_ratio).pow(2))
    else:
        # n_experts < 2: no pairs to compare.
        l_cross_var = torch.tensor(0.0, device=emb.device)
        l_disagree = torch.tensor(0.0, device=emb.device)

    # (7) simplex-volume CV of the batch vs the bank's calibrated target
    # (same Cayley-Menger math as cayley_menger_vol2, inlined).
    l_emb_cv = torch.tensor(0.0, device=emb.device)
    if B >= 10:
        emb_n = F.normalize(emb, dim=-1)
        vols = []
        for _ in range(16):
            idx = torch.randperm(B, device=emb.device)[:5]
            pts = emb_n[idx].unsqueeze(0)
            diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
            d2 = (diff*diff).sum(-1)
            Bv, V, _ = d2.shape
            cm = torch.zeros(Bv, V+1, V+1, device=d2.device, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            s = (-1.0)**V; f = math.factorial(V-1)
            v2 = s / ((2.0**(V-1))*f*f) * torch.linalg.det(cm)
            vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
        stacked = torch.stack(vols)
        emb_cv = stacked.std() / (stacked.mean() + 1e-8)
        l_emb_cv = (emb_cv - bank.target_cv).abs()

    total = (1.0*l_agreement + 1.0*l_ortho + 0.5*l_spread +
             0.1*l_entropy + 0.3*l_cross_var + 0.3*l_emb_cv + 0.5*l_disagree)

    diagnostics = {
        "agreement": l_agreement.item(),
        # l_ortho is a tensor whenever n_experts > 0; float 0.0 otherwise.
        "ortho": l_ortho.item() if torch.is_tensor(l_ortho) else l_ortho,
        "spread": l_spread.item(), "entropy": l_entropy.item(),
        "cross_var": l_cross_var.item(), "disagree": l_disagree.item(),
        "emb_cv": emb_cv.item() if B >= 10 else 0.0,
        "expert_cos_mean": expert_cos.mean().item(),
        "expert_cos_std": expert_cos.std().item(),
    }
    return total, diagnostics
|
|
|
|
| |
| |
| |
|
|
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric PSD matrix via eigendecomposition.

    Eigenvalues are clamped at `eps` so near-singular covariances stay stable.
    """
    eigvals, eigvecs = torch.linalg.eigh(cov)
    inv_sqrt_vals = torch.clamp(eigvals, min=eps).rsqrt()
    return eigvecs @ torch.diag(inv_sqrt_vals) @ eigvecs.T
|
|
def procrustes_align(source, target, n_align=10000):
    """Fit a whitened orthogonal Procrustes map from `source` onto `target`.

    Uses up to `n_align` row-aligned pairs. Returns the pieces `apply_align`
    consumes: rotation, source mean/whitener, and the target un-whitener.
    """
    n = min(n_align, source.shape[0], target.shape[0])
    src = source[:n].float()
    tgt = target[:n].float()
    src_mu = src.mean(0, keepdim=True)
    tgt_mu = tgt.mean(0, keepdim=True)
    src_c = src - src_mu
    tgt_c = tgt - tgt_mu
    denom = max(src_c.shape[0] - 1, 1)
    src_whiten = symmetric_inv_sqrt((src_c.T @ src_c) / denom)
    tgt_whiten = symmetric_inv_sqrt((tgt_c.T @ tgt_c) / denom)
    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)
    # Orthogonal Procrustes in the whitened, normalized space.
    U, _, Vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    return {
        "rotation": U @ Vt,
        "source_mean": src_mu.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
    }
|
|
def apply_align(emb, a):
    """Push embeddings through a fitted alignment dict from procrustes_align."""
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
|
|
|
|
| |
| |
| |
|
|
class VisionAlignmentBank(nn.Module):
    """
    Exact CaptionBERT AlignmentBank architecture scaled to 34 vision experts.

    Per-expert: rotation (DxD), whitener (DxD), mean (D,)
    Anchors: (N_ANCHORS, D) on hypersphere
    geo_proj: expert_cos + expert_mse + cross_cos + disagreement + norms + anchor_cos -> d_bank
    """
    def __init__(self, d_embed=D_SHARED, n_experts=34, n_anchors=N_ANCHORS, d_bank=D_BANK):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # Per-expert calibrated frames (seeded from Procrustes, then trainable).
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])

        # Learnable anchor constellation on the unit hypersphere.
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Geometric feature width: cos + mse + pairwise cross-cos
        # + disagreement ratio + per-expert norms + anchor cosines.
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 4), nn.GELU(), nn.LayerNorm(d_bank * 4),
            nn.Linear(d_bank * 4, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank))

        # Calibration targets filled in after GPA (buffers: saved, not trained).
        self.register_buffer("target_cv", torch.tensor(0.2))
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def forward(self, embedding):
        """Annotate `embedding` (B, d_embed) with 34-expert geometric context.

        Returns:
            enriched:    (B, d_embed + d_bank) input concatenated with geo context
            geo_context: (B, d_bank)
            diagnostics: dict of scalar monitoring values
        """
        B = embedding.shape[0]
        emb = embedding.float()

        expert_consistency = []
        expert_recon = []
        expert_projected = []
        expert_norms = []
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            centered = emb - mu
            whitened = centered @ W
            # FIX: norm collected here; the original re-whitened every expert
            # in a second loop just to take this same norm.
            expert_norms.append(whitened.norm(dim=-1))
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T
            back = in_expert @ R
            cos = F.cosine_similarity(whitened_n, back, dim=-1)
            recon = (whitened_n - back).pow(2).mean(dim=-1)
            expert_consistency.append(cos)
            expert_recon.append(recon)
            expert_projected.append(in_expert)

        expert_cos = torch.stack(expert_consistency, dim=-1)
        expert_mse = torch.stack(expert_recon, dim=-1)

        # Pairwise cosine between expert-frame projections.
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cc = F.cosine_similarity(expert_projected[i], expert_projected[j], dim=-1)
                cross_cos.append(cc)
        cross_features = torch.stack(cross_cos, dim=-1)

        per_sample_agreement = expert_cos.mean(dim=-1)
        per_sample_disagreement = expert_cos.std(dim=-1)
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)

        # Relative whitened norm per expert, normalized by the per-sample mean.
        norm_ratio = torch.stack(expert_norms, dim=-1)
        norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)

        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T

        geo_input = torch.cat([
            expert_cos, expert_mse, cross_features,
            disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)
        enriched = torch.cat([embedding, geo_context], dim=-1)

        diagnostics = {
            "expert_cos_mean": expert_cos.mean().item(),
            "expert_cos_std": expert_cos.std().item(),
            "cross_expert_cos": cross_features.mean().item(),
            "anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
            "disagreement_ratio": disagreement_ratio.mean().item(),
        }
        return enriched, geo_context, diagnostics
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """Learnable anchor set on the unit hypersphere for triangulating embeddings."""

    def __init__(self, n_anchors, d):
        super().__init__()
        self.n_anchors = n_anchors
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))

    def triangulate(self, emb):
        """Return (cosine distance to every anchor, index of the nearest anchor)."""
        unit_anchors = F.normalize(self.anchors, dim=-1)
        sim = emb @ unit_anchors.T
        return 1.0 - sim, sim.argmax(dim=-1)
|
|
class Patchwork(nn.Module):
    """Round-robin-split anchor features into n_comp groups, each encoded by a small MLP."""

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        # Anchor k goes to component k % n_comp.
        assignment = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", assignment)
        components = []
        for k in range(n_comp):
            in_dim = int((assignment == k).sum().item())
            components.append(nn.Sequential(
                nn.Linear(in_dim, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp)))
        self.comps = nn.ModuleList(components)

    def forward(self, tri):
        """Encode (B, n_anchors) triangulation features into (B, n_comp * d_comp)."""
        pieces = []
        for k, comp in enumerate(self.comps):
            pieces.append(comp(tri[:, self.asgn == k]))
        return torch.cat(pieces, -1)
|
|
|
|
class VisionBankModel(nn.Module):
    """
    34-expert AlignmentBank + constellation + patchwork + classifier.

    Input: L2-normalized consensus embedding (1024-d) from GPA of 34 experts.
    Bank: annotates with 34-expert geometric context.
    Downstream: constellation, then patchwork, then classifier (multi-label COCO).
    """
    def __init__(self, n_experts=34, d_shared=D_SHARED, n_anchors=N_ANCHORS,
                 n_comp=N_COMP, d_comp=D_COMP, n_classes=N_CLASSES, d_bank=D_BANK):
        super().__init__()
        self.bank = VisionAlignmentBank(d_shared, n_experts, n_anchors, d_bank)
        self.constellation = Constellation(n_anchors, d_shared)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        patch_dim = n_comp * d_comp
        cls_in = patch_dim + d_shared + d_bank
        self.classifier = nn.Sequential(
            nn.Linear(cls_in, d_shared), nn.GELU(),
            nn.LayerNorm(d_shared), nn.Dropout(0.1),
            nn.Linear(d_shared, d_shared // 2), nn.GELU(),
            nn.Linear(d_shared // 2, n_classes))

    def forward(self, embedding):
        """Return (logits, input embedding, triangulation, nearest anchor, bank diagnostics)."""
        enriched, geo_context, bank_diag = self.bank(embedding)
        distances, nearest = self.constellation.triangulate(embedding)
        patch_features = self.patchwork(distances)
        cls_input = torch.cat([patch_features, embedding, geo_context], dim=-1)
        logits = self.classifier(cls_input)
        return logits, embedding, distances, nearest, bank_diag
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("PHASE 0: LOAD EXPERT FEATURES")
print(f"{'='*65}")


def _multihot(labels_raw, n_rows):
    """Per-image label lists -> (n_rows, N_CLASSES) multi-hot float matrix."""
    m = torch.zeros(n_rows, N_CLASSES)
    for i, labs in enumerate(labels_raw):
        for l in labs:
            if l < N_CLASSES:
                m[i, l] = 1.0
    return m


def _load_expert_split(subset, split, id_map, n_rows):
    """Load one expert subset split; return (L2-normalized (n_rows, dim) features, dim).

    Rows are placed by `id_map` (image_id -> row index); images missing from
    the split are left as zero rows.
    """
    ds = load_dataset("AbstractPhil/bulk-coco-features", subset, split=split)
    dim = len(ds[0]["features"])
    feats = torch.zeros(n_rows, dim)
    for row in ds:
        if row["image_id"] in id_map:
            feats[id_map[row["image_id"]]] = torch.tensor(row["features"], dtype=torch.float32)
    return F.normalize(feats, dim=-1), dim


# The reference subset fixes the canonical image ordering and supplies labels.
ref = load_dataset("AbstractPhil/bulk-coco-features", SUBSETS[0], split="train")
train_ids = ref["image_id"]; N_train = len(train_ids)
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
train_labels_raw = ref["labels"]
train_label_matrix = _multihot(train_labels_raw, N_train)


ref_val = load_dataset("AbstractPhil/bulk-coco-features", SUBSETS[0], split="val")
val_ids = ref_val["image_id"]; N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
val_labels_raw = ref_val["labels"]
val_label_matrix = _multihot(val_labels_raw, N_val)


print(f" Train: {N_train:,} Val: {N_val:,}")


# Pull every expert's pre-extracted features for both splits.
expert_dims = {}
train_expert_embs = {}
val_expert_embs = {}


for name in SUBSETS:
    train_expert_embs[name], dim = _load_expert_split(name, "train", train_id_map, N_train)
    expert_dims[name] = dim
    val_expert_embs[name], _ = _load_expert_split(name, "val", val_id_map, N_val)
    print(f" {name:<30} dim={dim}", flush=True)
    gc.collect()  # the HF dataset objects are large; reclaim between subsets
|
|
|
|
| |
| |
| |
|
|
# Generalized Procrustes Analysis: iteratively align all experts to their mean
# shape, then calibrate one whitened Procrustes per expert against consensus.
print(f"\n{'='*65}")
print("PHASE 1: GPA ALIGNMENT + PROCRUSTES CALIBRATION")
print(f"{'='*65}")
| |
def project_to_shared(feats, d_out=D_SHARED):
    """Bring (N, d_in) features into the shared d_out-dim space, L2-normalized.

    d_in == d_out: returned untouched. d_in < d_out: zero-padded then
    normalized. d_in > d_out: projected onto the top-d_out PCA directions.
    """
    d_in = feats.shape[1]
    if d_in == d_out:
        return feats
    if d_in < d_out:
        pad = torch.zeros(feats.shape[0], d_out - d_in)
        return F.normalize(torch.cat([feats, pad], -1), dim=-1)
    centered = feats - feats.mean(0, keepdim=True)
    _, _, Vt = torch.linalg.svd(centered, full_matrices=False)
    return F.normalize(feats @ Vt[:d_out].T, dim=-1)
|
|
| projected = {n: project_to_shared(train_expert_embs[n]) for n in SUBSETS} |
|
|
| |
| current = {i: projected[SUBSETS[i]].float() for i in range(len(SUBSETS))} |
| for gpa_iter in range(20): |
| mean_shape = sum(current[i] for i in range(len(SUBSETS))) / len(SUBSETS) |
| delta = 0.0 |
| new_current = {} |
| for i in range(len(SUBSETS)): |
| info = procrustes_align(current[i], mean_shape) |
| new_current[i] = apply_align(current[i], info) |
| delta += (new_current[i] - current[i]).pow(2).mean().item() |
| current = new_current |
| if gpa_iter == 0 or (gpa_iter+1) % 5 == 0: |
| print(f" GPA iter {gpa_iter+1}: delta={delta:.8f}") |
| if delta < 1e-8: |
| print(f" Converged at iteration {gpa_iter+1}"); break |
|
|
| consensus = F.normalize( |
| sum(current[i] for i in range(len(SUBSETS))) / len(SUBSETS), dim=-1) |
| consensus_cv = cv_metric(consensus[:5000].to(DEVICE)) |
| print(f" Consensus CV: {consensus_cv:.4f}") |
|
|
| |
| print(f"\n Calibrating {len(SUBSETS)} expert Procrustes...") |
| expert_calibrations = [] |
| for i, name in enumerate(SUBSETS): |
| info = procrustes_align(current[i], consensus) |
| expert_calibrations.append(info) |
| c = F.cosine_similarity( |
| apply_align(current[i][:5000], info), |
| consensus[:5000], dim=-1).mean().item() |
| if i < 5 or i >= len(SUBSETS)-3: |
| print(f" {name:<30} cos={c:.4f}") |
| elif i == 5: |
| print(f" ...") |
|
|
| |
| print(f"\n Computing bank calibration targets...") |
| with torch.no_grad(): |
| cons_dev = consensus[:10000].to(DEVICE) |
| |
| |
| tmp_expert_cos = [] |
| tmp_expert_proj = [] |
| for i in range(len(SUBSETS)): |
| R = expert_calibrations[i]["rotation"].to(DEVICE) |
| W = expert_calibrations[i]["source_whitener"].to(DEVICE) |
| mu = expert_calibrations[i]["source_mean"].to(DEVICE) |
| centered = cons_dev - mu |
| whitened_n = F.normalize(centered @ W, dim=-1) |
| in_expert = whitened_n @ R.T |
| back = in_expert @ R |
| cos = F.cosine_similarity(whitened_n, back, dim=-1) |
| tmp_expert_cos.append(cos) |
| tmp_expert_proj.append(in_expert) |
|
|
| expert_cos_stack = torch.stack(tmp_expert_cos, dim=-1) |
| target_cross_cos_vals = [] |
| for i in range(len(SUBSETS)): |
| for j in range(i+1, len(SUBSETS)): |
| cc = F.cosine_similarity(tmp_expert_proj[i], tmp_expert_proj[j], dim=-1) |
| target_cross_cos_vals.append(cc) |
| cross_stack = torch.stack(target_cross_cos_vals, dim=-1) |
|
|
| calib_cross_mean = cross_stack.mean().item() |
| calib_cross_std = cross_stack.std().item() |
| calib_agree = expert_cos_stack.mean(dim=-1) |
| calib_disagree = expert_cos_stack.std(dim=-1) |
| calib_ratio = (calib_disagree / (calib_agree + 1e-8)).mean().item() |
|
|
| print(f" target_cv: {consensus_cv:.4f}") |
| print(f" target_cross_cos_mean: {calib_cross_mean:.4f}") |
| print(f" target_cross_cos_std: {calib_cross_std:.4f}") |
| print(f" target_disagreement_ratio: {calib_ratio:.4f}") |
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("PHASE 2: BUILD MODEL")
print(f"{'='*65}")


model = VisionBankModel(
    n_experts=len(SUBSETS), d_shared=D_SHARED, n_anchors=N_ANCHORS,
    n_comp=N_COMP, d_comp=D_COMP, n_classes=N_CLASSES, d_bank=D_BANK).to(DEVICE)


# Seed the bank's per-expert frames with the Phase-1 Procrustes calibrations.
with torch.no_grad():
    for i, name in enumerate(SUBSETS):
        cal = expert_calibrations[i]
        model.bank.expert_rotations[i].copy_(cal["rotation"])
        model.bank.expert_whiteners[i].copy_(cal["source_whitener"])
        model.bank.expert_means[i].copy_(cal["source_mean"])


# Regularization targets measured in Phase 1 (buffers, not trained).
model.bank.target_cv.fill_(consensus_cv)
model.bank.target_cross_cos_mean.fill_(calib_cross_mean)
model.bank.target_cross_cos_std.fill_(calib_cross_std)
model.bank.target_disagreement_ratio.fill_(calib_ratio)
print(f" β Bank calibrated with {len(SUBSETS)} expert Procrustes")


# Optional warm start: reuse constellation anchors + patchwork weights from
# the earlier "soup" run (bank anchors are seeded from the same tensor).
# NOTE(review): weights_only=False can execute arbitrary pickled code — only
# load checkpoints you trust.
if os.path.exists(SOUP_PATH):
    soup_ckpt = torch.load(SOUP_PATH, map_location="cpu", weights_only=False)
    soup_state = soup_ckpt["state_dict"]
    model.constellation.anchors.data.copy_(soup_state["constellation.anchors"].to(DEVICE))
    model.bank.anchors.data.copy_(soup_state["constellation.anchors"].to(DEVICE))
    pw_state = {k.replace("patchwork.", ""): v for k, v in soup_state.items() if k.startswith("patchwork.")}
    model.patchwork.load_state_dict(pw_state)
    print(f" β Constellation + patchwork transferred from soup")
    del soup_ckpt, soup_state
else:
    print(f" β No soup checkpoint β using random initialization")


# Parameter counts per component (monitoring only).
n_bank_p = sum(p.numel() for p in model.bank.parameters())
n_const = sum(p.numel() for p in model.constellation.parameters())
n_pw = sum(p.numel() for p in model.patchwork.parameters())
n_cls = sum(p.numel() for p in model.classifier.parameters())
n_total = sum(p.numel() for p in model.parameters())
print(f"\n Parameters:")
print(f" bank: {n_bank_p:>12,}")
print(f" constellation: {n_const:>12,}")
print(f" patchwork: {n_pw:>12,}")
print(f" classifier: {n_cls:>12,}")
print(f" total: {n_total:>12,}")
|
|
|
|
| |
| |
| |
|
|
print(f"\n{'='*65}")
print("PHASE 3: TRAIN")
print(f"{'='*65}")


# Training inputs are the GPA consensus embeddings; the model reconstructs
# them (InfoNCE + MSE), regularizes bank geometry, and classifies COCO labels.
train_targets = consensus[:N_train].to(DEVICE)
# (FIX: removed dead `val_targets` — a full 34-expert mean computed but never
# read — and the unused `val_labels_gpu` GPU copy.)

# Validation consensus: an independent single-pass GPA-style alignment over
# the val split so it lives in (approximately) the same shared space.
val_current = {i: project_to_shared(val_expert_embs[SUBSETS[i]]).float() for i in range(len(SUBSETS))}
val_mean = sum(val_current[i] for i in range(len(SUBSETS))) / len(SUBSETS)
for i in range(len(SUBSETS)):
    info = procrustes_align(val_current[i], val_mean)
    val_current[i] = apply_align(val_current[i], info)
val_consensus = F.normalize(sum(val_current[i] for i in range(len(SUBSETS))) / len(SUBSETS), dim=-1).to(DEVICE)


train_labels_gpu = train_label_matrix.to(DEVICE)


optimizer = torch.optim.Adam(model.parameters(), lr=LR)
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/vision_alignment_bank")
best_mAP = 0.0; gs = 0


for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(N_train, device=DEVICE)
    tl, tn, nb = 0, 0, 0

    for i in range(0, N_train, BATCH):
        idx = perm[i:i+BATCH]
        if len(idx) < 8: continue  # InfoNCE / CV sampling need a minimum batch

        emb = train_targets[idx]
        labels = train_labels_gpu[idx]

        logits, out_emb, tri, nearest, bank_diag = model(emb)

        # Student losses against consensus + multi-label classification.
        l_nce, nce_acc = infonce(out_emb, emb)
        l_mse = F.mse_loss(out_emb, emb)
        l_cv = cv_loss(out_emb, target=consensus_cv)
        l_cls = F.binary_cross_entropy_with_logits(logits, labels)

        # Geometric bank regularization (same loss as CaptionBERT).
        l_bank, bdiag = compute_bank_loss(model.bank, out_emb)

        loss = W_NCE*l_nce + W_MSE*l_mse + W_CV*l_cv + W_CLS*l_cls + W_BANK*l_bank

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step(); optimizer.zero_grad(set_to_none=True)

        tl += loss.item(); tn += nce_acc; nb += 1; gs += 1

        if gs % 100 == 0:
            writer.add_scalar("train/loss", loss.item(), gs)
            writer.add_scalar("train/nce", l_nce.item(), gs)
            writer.add_scalar("train/bank", l_bank.item(), gs)
            writer.add_scalar("train/cls", l_cls.item(), gs)
            writer.add_scalar("train/nce_acc", nce_acc, gs)
            for k, v in bdiag.items():
                writer.add_scalar(f"bank/{k}", v, gs)

    # ---- per-epoch validation ----
    model.eval()
    with torch.no_grad():
        all_lo, all_em = [], []
        for j in range(0, N_val, BATCH):
            end = min(j+BATCH, N_val)
            lo, em, _, _, _ = model(val_consensus[j:end])
            all_lo.append(lo.cpu()); all_em.append(em.cpu())
        v_lo = torch.cat(all_lo); v_em = torch.cat(all_em)

        # Multi-label mAP over the 80 COCO classes (classes with no positives
        # in the val split are skipped).
        v_lab = val_label_matrix
        ap_sum, nv = 0, 0
        for c in range(N_CLASSES):
            if v_lab[:, c].sum() > 0:
                si = v_lo[:, c].argsort(descending=True); st = v_lab[:, c][si]
                pak = st.cumsum(0) / torch.arange(1, len(st)+1).float()
                ap_sum += (pak*st).sum().item() / st.sum().item(); nv += 1
        mAP = ap_sum / max(nv, 1)

        v_cos = F.cosine_similarity(v_em, val_consensus.cpu(), dim=-1).mean().item()
        v_cv = cv_metric(v_em[:2000].to(DEVICE))

        # Retrieval R@1 of output embeddings against the val consensus rows.
        sim = v_em @ val_consensus.cpu().T
        r1 = (sim.argmax(-1) == torch.arange(N_val)).float().mean().item()

    writer.add_scalar("val/mAP", mAP, epoch+1)
    writer.add_scalar("val/cos", v_cos, epoch+1)
    writer.add_scalar("val/cv", v_cv, epoch+1)
    writer.add_scalar("val/R@1", r1, epoch+1)

    mk = ""
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({"state_dict": model.state_dict(), "mAP": mAP, "r1": r1, "cv": v_cv,
                    "config": {"n_experts": len(SUBSETS), "d_shared": D_SHARED,
                               "n_anchors": N_ANCHORS, "n_comp": N_COMP,
                               "d_comp": D_COMP, "n_classes": N_CLASSES, "d_bank": D_BANK}},
                   "checkpoints/best.pt")
        mk = " *"  # FIX: original had an unterminated string literal here

    print(f" E{epoch+1:2d}: mAP={mAP:.3f} R@1={r1:.3f} cos={v_cos:.3f} "
          f"cv={v_cv:.4f} nce={tn/nb:.3f} loss={tl/nb:.4f}{mk}")


writer.close()


# Best-effort upload of the best checkpoint to the HF hub.
print(f"\n Best mAP: {best_mAP:.3f}")
try:
    from huggingface_hub import HfApi
    api = HfApi()
    if os.path.exists("checkpoints/best.pt"):
        api.upload_file(path_or_fileobj="checkpoints/best.pt",
                        path_in_repo="vision_bank_best.pt",
                        repo_id=REPO_ID, repo_type="model")
        print(f" β Uploaded to {REPO_ID}")
except Exception as e:
    print(f" Upload: {e}")


print(f"\n{'='*65}\nDONE")