# geolip-vit-x34 / trainer_v2.py
# (Hugging Face listing residue removed: renamed from student_trainer.py,
#  upstream commit 7896108, author AbstractPhil.)
#!/usr/bin/env python3
"""
GEOLIP VISION ALIGNMENT BANK
==============================
CaptionBERT architecture applied to 34 vision experts.
CaptionBERT:
5 BERT experts β†’ GPA consensus β†’ per-expert whitened Procrustes
β†’ AlignmentBank(rotations, whiteners, means, anchors, geo_proj)
β†’ compute_bank_loss(agreement, ortho, spread, entropy, cross_var, disagree, CV)
β†’ student losses: InfoNCE + MSE against consensus
This file:
34 vision experts β†’ GPA consensus β†’ per-expert whitened Procrustes
β†’ VisionAlignmentBank(34 rotations, whiteners, means, anchors, geo_proj)
β†’ same compute_bank_loss
β†’ same student losses against consensus
β†’ classification through constellation + patchwork (transferred from soup)
Data: AbstractPhil/bulk-coco-features (118K train + 5K val, pre-extracted)
"""
import gc
import math
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "AbstractPhil/geolip-vit-x34"  # HF repo the best checkpoint is uploaded to
SOUP_PATH = "soup_patchwork.pt"          # optional warm-start checkpoint for constellation/patchwork
# Architecture
D_SHARED = 1024    # shared embedding width all experts are projected to
N_ANCHORS = 256    # constellation anchor count
N_CLASSES = 80     # COCO categories (multi-label)
N_COMP = 8         # patchwork component count
D_COMP = 128       # per-component patchwork output width
D_BANK = 128       # geometric-context width produced by the bank
# Training
BATCH = 128
EPOCHS = 20
LR = 5e-4
# Loss weights: InfoNCE, MSE-to-consensus, volume-CV regularizer,
# bank geometry loss, multi-label classification.
W_NCE = 1.0
W_MSE = 0.5
W_CV = 0.001
W_BANK = 1.0
W_CLS = 0.3
GRAD_CLIP = 1.0
# The 34 pre-extracted feature subsets of AbstractPhil/bulk-coco-features,
# one per vision expert (CLIP / DINOv2 / MAE / SigLIP / ViT-21k variants).
SUBSETS = [
    "clip_b16_laion2b", "clip_b16_openai", "clip_b32_datacomp",
    "clip_b32_laion2b", "clip_b32_openai", "clip_bigg14_laion2b",
    "clip_g14_laion2b", "clip_h14_laion2b", "clip_l14_336_openai",
    "clip_l14_datacomp", "clip_l14_laion2b", "clip_l14_openai",
    "dinov2_b14", "dinov2_b14_reg", "dinov2_g14", "dinov2_g14_reg",
    "dinov2_l14", "dinov2_l14_reg", "dinov2_s14", "dinov2_s14_reg",
    "mae_b16", "mae_h14", "mae_l16",
    "siglip2_b16_256", "siglip2_b16_512", "siglip2_l16_384",
    "siglip_b16_384", "siglip_b16_512", "siglip_l16_256",
    "siglip_l16_384", "siglip_so400m_384",
    "vit_b16_21k", "vit_l16_21k", "vit_s16_21k",
]
print("=" * 65)
print("GEOLIP VISION ALIGNMENT BANK")
print(f" {len(SUBSETS)} experts β†’ CaptionBERT AlignmentBank")
print(f" Device: {DEVICE}")
print("=" * 65)
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES (exact copy from cotrain_bank.py)
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared simplex volume via the Cayley-Menger determinant.

    pts: (B, V, D) batch of V-point simplices. Returns a (B,) float tensor of
    squared (V-1)-dimensional volumes.
    """
    pts = pts.float()
    # Pairwise squared distances between the V vertices of each simplex.
    pair_diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = pair_diff.pow(2).sum(-1)
    batch, n_pts, _ = sq_dist.shape
    # Bordered (V+1)x(V+1) Cayley-Menger matrix: zero corner, ones on the
    # first row/column, squared distances in the interior.
    bordered = torch.zeros(batch, n_pts + 1, n_pts + 1,
                           device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist
    sign = (-1.0) ** n_pts
    fact = math.factorial(n_pts - 1)
    scale = sign / ((2.0 ** (n_pts - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
def cv_loss(emb, target=0.2, n_samples=16):
    """Differentiable penalty |CV - target| on sampled 4-simplex volumes.

    Draws *n_samples* random 5-point subsets of *emb*, computes the
    coefficient of variation (std/mean) of their simplex volumes, and
    returns its absolute deviation from *target*. Zero when the batch has
    fewer than 5 rows.
    """
    n = emb.shape[0]
    if n < 5:
        return torch.tensor(0.0, device=emb.device)
    samples = []
    for _ in range(n_samples):
        pick = torch.randperm(n, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # Clamp at zero before the sqrt: determinant noise can go negative.
        samples.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(samples)
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).abs()
def cv_metric(emb, n=200):
    """Monte-Carlo estimate of the simplex-volume coefficient of variation.

    Non-differentiable logging counterpart of cv_loss: samples *n* random
    5-point simplices, keeps the strictly positive volumes, and returns
    std/mean as a python float. Returns 0.0 when fewer than 5 points are
    available or fewer than 10 positive volumes were collected.
    """
    count = emb.shape[0]
    if count < 5:
        return 0.0
    collected = []
    for _ in range(n):
        pick = torch.randperm(count, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            collected.append(vol)
    if len(collected) < 10:
        return 0.0
    arr = np.array(collected)
    return float(arr.std() / (arr.mean() + 1e-8))
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between two row-aligned batches.

    Rows of *a* and *b* at the same index are positives; all other rows are
    in-batch negatives. Returns (loss, top-1 retrieval accuracy a->b).
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    logits = a_unit @ b_unit.T / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    # Average both retrieval directions so the loss is symmetric in (a, b).
    loss_ab = F.cross_entropy(logits, labels)
    loss_ba = F.cross_entropy(logits.T, labels)
    loss = (loss_ab + loss_ba) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc
# ══════════════════════════════════════════════════════════════════
# BANK LOSS (exact from cotrain_bank.py β€” compute_bank_loss)
# ══════════════════════════════════════════════════════════════════
def compute_bank_loss(bank, embedding):
    """Geometric regularizer suite over the bank's per-expert frames.

    Args:
        bank: module exposing expert_rotations / expert_whiteners /
            expert_means (per-expert), anchors, d_embed, n_experts and the
            calibrated target_* buffers.
        embedding: (B, d_embed) batch.

    Returns:
        (total weighted loss tensor, diagnostics dict of python floats).
    """
    B = embedding.shape[0]
    emb = embedding.float()
    expert_cos_list = []
    expert_projected = []
    # Round-trip each embedding through every expert's frame:
    # center -> whiten -> L2-normalize -> rotate into expert space -> rotate back.
    for i in range(bank.n_experts):
        R = bank.expert_rotations[i]
        W = bank.expert_whiteners[i]
        mu = bank.expert_means[i]
        centered = emb - mu
        whitened = centered @ W
        whitened_n = F.normalize(whitened, dim=-1)
        in_expert = whitened_n @ R.T
        back = in_expert @ R
        # cos == 1 exactly when R is orthogonal; measures round-trip fidelity.
        cos = F.cosine_similarity(whitened_n, back, dim=-1)
        expert_cos_list.append(cos)
        expert_projected.append(in_expert)
    expert_cos = torch.stack(expert_cos_list, dim=-1)
    # 1. Expert agreement: variance of per-expert fidelity around its mean.
    expert_mean = expert_cos.mean(dim=-1, keepdim=True)
    l_agreement = (expert_cos - expert_mean).pow(2).mean()
    # 2. Rotation orthogonality: penalize drift of R @ R.T away from identity
    # (the rotations are trainable after Procrustes initialization).
    l_ortho = 0.0
    for i in range(bank.n_experts):
        R = bank.expert_rotations[i]
        l_ortho += (R @ R.T - torch.eye(bank.d_embed, device=R.device)).pow(2).mean()
    l_ortho = l_ortho / bank.n_experts
    # 3. Anchor spread: push anchors toward mutual orthogonality
    # (self-similarity is zeroed out before the penalty).
    anchors_n = F.normalize(bank.anchors, dim=-1)
    anchor_sim = anchors_n @ anchors_n.T
    anchor_sim.fill_diagonal_(0)
    l_spread = anchor_sim.pow(2).mean()
    # 4. Anchor entropy: temperature-10 softmax over anchor similarities;
    # minimizing entropy sharpens each sample's anchor assignment.
    anchor_cos = emb @ anchors_n.T
    anchor_probs = F.softmax(anchor_cos * 10, dim=-1)
    l_entropy = -(anchor_probs * (anchor_probs + 1e-12).log()).sum(-1).mean()
    # 5. Cross-expert differentiation over all expert pairs.
    cross_cos = []
    for i in range(bank.n_experts):
        for j in range(i + 1, bank.n_experts):
            cc = F.cosine_similarity(expert_projected[i], expert_projected[j], dim=-1)
            cross_cos.append(cc)
    if cross_cos:
        cross_features = torch.stack(cross_cos, dim=-1)
        l_cross_var = cross_features.var(dim=0).mean()
        batch_cross_mean = cross_features.mean()
        batch_cross_std = cross_features.std()
        per_sample_agreement = expert_cos.mean(dim=-1)
        per_sample_disagreement = expert_cos.std(dim=-1)
        batch_disagree_ratio = (per_sample_disagreement / (per_sample_agreement + 1e-8)).mean()
        # Pull batch statistics toward the Procrustes-calibrated targets.
        l_disagree = (
            (batch_cross_mean - bank.target_cross_cos_mean).pow(2) +
            (batch_cross_std - bank.target_cross_cos_std).pow(2) +
            (batch_disagree_ratio - bank.target_disagreement_ratio).pow(2))
    else:
        # Single-expert bank: no pairs to compare.
        l_cross_var = torch.tensor(0.0, device=emb.device)
        l_disagree = torch.tensor(0.0, device=emb.device)
    # 7. Embedding CV: inline Cayley-Menger volume sampling (same math as
    # cayley_menger_vol2/cv_loss) targeting the calibrated consensus CV.
    # Needs at least 10 samples to be meaningful.
    l_emb_cv = torch.tensor(0.0, device=emb.device)
    if B >= 10:
        emb_n = F.normalize(emb, dim=-1)
        vols = []
        for _ in range(16):
            idx = torch.randperm(B, device=emb.device)[:5]
            pts = emb_n[idx].unsqueeze(0)
            diff = pts.unsqueeze(-2) - pts.unsqueeze(-3)
            d2 = (diff*diff).sum(-1)
            Bv, V, _ = d2.shape
            cm = torch.zeros(Bv, V+1, V+1, device=d2.device, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            s = (-1.0)**V; f = math.factorial(V-1)
            v2 = s / ((2.0**(V-1))*f*f) * torch.linalg.det(cm)
            vols.append(torch.sqrt(F.relu(v2[0]) + 1e-12))
        stacked = torch.stack(vols)
        emb_cv = stacked.std() / (stacked.mean() + 1e-8)
        l_emb_cv = (emb_cv - bank.target_cv).abs()
    # Fixed weighting mirrors the CaptionBERT recipe.
    total = (1.0*l_agreement + 1.0*l_ortho + 0.5*l_spread +
             0.1*l_entropy + 0.3*l_cross_var + 0.3*l_emb_cv + 0.5*l_disagree)
    diagnostics = {
        "agreement": l_agreement.item(),
        # l_ortho stays a float only for n_experts == 0.
        "ortho": l_ortho.item() if torch.is_tensor(l_ortho) else l_ortho,
        "spread": l_spread.item(), "entropy": l_entropy.item(),
        "cross_var": l_cross_var.item(), "disagree": l_disagree.item(),
        "emb_cv": emb_cv.item() if B >= 10 else 0.0,
        "expert_cos_mean": expert_cos.mean().item(),
        "expert_cos_std": expert_cos.std().item(),
    }
    return total, diagnostics
# ══════════════════════════════════════════════════════════════════
# ALIGNMENT UTILITIES (exact from cotrain_bank.py)
# ══════════════════════════════════════════════════════════════════
def symmetric_inv_sqrt(cov, eps=1e-6):
    """Inverse matrix square root of a symmetric PSD matrix (ZCA whitener).

    Eigenvalues are clamped at *eps* before the -1/2 power so near-singular
    covariances still produce a finite result.
    """
    evals, evecs = torch.linalg.eigh(cov)
    inv_sqrt_evals = torch.clamp(evals, min=eps).rsqrt()
    return evecs @ torch.diag(inv_sqrt_evals) @ evecs.T
def procrustes_align(source, target, n_align=10000):
    """Whitened orthogonal Procrustes fit of *source* onto *target*.

    Both clouds are centered, whitened to identity covariance (eigh-based
    inverse square root, eigenvalues clamped at 1e-6) and L2-normalized
    before the SVD, so the rotation is estimated in a scale-free space.

    Returns a dict with "rotation", "source_mean", "source_whitener" and
    "target_unwhitener" (pseudo-inverse of the target whitener, used to map
    rotated points back into the target's original covariance frame).
    """
    n = min(n_align, source.shape[0], target.shape[0])
    src = source[:n].float()
    tgt = target[:n].float()
    src_mean = src.mean(0, keepdim=True)
    tgt_mean = tgt.mean(0, keepdim=True)
    src_c = src - src_mean
    tgt_c = tgt - tgt_mean
    denom = max(src_c.shape[0] - 1, 1)
    # Inline ZCA whiteners for both clouds (same math as symmetric_inv_sqrt).
    s_evals, s_evecs = torch.linalg.eigh(src_c.T @ src_c / denom)
    src_whiten = s_evecs @ torch.diag(torch.clamp(s_evals, min=1e-6).rsqrt()) @ s_evecs.T
    t_evals, t_evecs = torch.linalg.eigh(tgt_c.T @ tgt_c / denom)
    tgt_whiten = t_evecs @ torch.diag(torch.clamp(t_evals, min=1e-6).rsqrt()) @ t_evecs.T
    src_w = F.normalize(src_c @ src_whiten, dim=-1)
    tgt_w = F.normalize(tgt_c @ tgt_whiten, dim=-1)
    u, _, vt = torch.linalg.svd(tgt_w.T @ src_w, full_matrices=False)
    return {
        "rotation": u @ vt,
        "source_mean": src_mean.squeeze(0),
        "source_whitener": src_whiten,
        "target_unwhitener": torch.linalg.pinv(tgt_whiten),
    }
def apply_align(emb, a):
    """Apply a procrustes_align() result: center, whiten, rotate, unwhiten."""
    centered = emb.float() - a["source_mean"]
    whitened = centered @ a["source_whitener"]
    rotated = whitened @ a["rotation"].T
    return rotated @ a["target_unwhitener"]
# ══════════════════════════════════════════════════════════════════
# VISION ALIGNMENT BANK (CaptionBERT AlignmentBank for 34 experts)
# ══════════════════════════════════════════════════════════════════
class VisionAlignmentBank(nn.Module):
    """
    Exact CaptionBERT AlignmentBank architecture scaled to 34 vision experts.
    Per-expert: rotation (DxD), whitener (DxD), mean (D,) - Procrustes
    calibrated at startup, then trainable.
    Anchors: (N_ANCHORS, D) points on the hypersphere.
    geo_proj: MLP mapping [expert_cos + expert_mse + cross_cos +
    disagreement + norms + anchor_cos] down to a d_bank context vector.
    """
    def __init__(self, d_embed=D_SHARED, n_experts=34, n_anchors=N_ANCHORS, d_bank=D_BANK):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank
        # Per-expert Procrustes (calibrated, then trainable); identity init
        # is overwritten from expert_calibrations before training.
        self.expert_rotations = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_whiteners = nn.ParameterList([
            nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_means = nn.ParameterList([
            nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])
        # Constellation anchors (random unit directions, trained jointly)
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
        # Geometric context projection. geo_dim must match the torch.cat
        # layout in forward(): expert_cos (E) + expert_mse (E) +
        # cross pairs (E choose 2) + disagreement ratio (1) +
        # norm_ratio (E) + anchor_cos (n_anchors).
        n_cross = n_experts * (n_experts - 1) // 2
        geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(geo_dim, d_bank * 4), nn.GELU(), nn.LayerNorm(d_bank * 4),
            nn.Linear(d_bank * 4, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank))
        # Calibrated targets: buffers so they ride along in state_dict but
        # receive no gradients; filled from Phase-1 calibration statistics.
        self.register_buffer("target_cv", torch.tensor(0.2))
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))
    def forward(self, embedding):
        """Annotate a (B, d_embed) batch with 34-expert geometric context.

        Returns:
            enriched: (B, d_embed + d_bank) input concatenated with context.
            geo_context: (B, d_bank) projected geometric features.
            diagnostics: dict of python-float summary statistics.
        """
        B = embedding.shape[0]
        emb = embedding.float()
        expert_consistency = []
        expert_recon = []
        expert_projected = []
        # Same round-trip as compute_bank_loss, plus per-sample MSE.
        for i in range(self.n_experts):
            R = self.expert_rotations[i]
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
            centered = emb - mu
            whitened = centered @ W
            whitened_n = F.normalize(whitened, dim=-1)
            in_expert = whitened_n @ R.T
            back = in_expert @ R
            cos = F.cosine_similarity(whitened_n, back, dim=-1)
            recon = (whitened_n - back).pow(2).mean(dim=-1)
            expert_consistency.append(cos)
            expert_recon.append(recon)
            expert_projected.append(in_expert)
        expert_cos = torch.stack(expert_consistency, dim=-1)
        expert_mse = torch.stack(expert_recon, dim=-1)
        # Cross-expert (all pairs — 34 choose 2 = 561 pairs)
        cross_cos = []
        for i in range(self.n_experts):
            for j in range(i + 1, self.n_experts):
                cc = F.cosine_similarity(expert_projected[i], expert_projected[j], dim=-1)
                cross_cos.append(cc)
        cross_features = torch.stack(cross_cos, dim=-1)
        per_sample_agreement = expert_cos.mean(dim=-1)
        per_sample_disagreement = expert_cos.std(dim=-1)
        disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)
        # Relative whitened norm per expert, normalized by the cross-expert mean.
        expert_norms = []
        for i in range(self.n_experts):
            W = self.expert_whiteners[i]; mu = self.expert_means[i]
            whitened = (emb - mu) @ W
            expert_norms.append(whitened.norm(dim=-1))
        norm_ratio = torch.stack(expert_norms, dim=-1)
        norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)
        anchors_n = F.normalize(self.anchors, dim=-1)
        anchor_cos = emb @ anchors_n.T
        # Ordering here must match geo_dim in __init__.
        geo_input = torch.cat([
            expert_cos, expert_mse, cross_features,
            disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)
        enriched = torch.cat([embedding, geo_context], dim=-1)
        diagnostics = {
            "expert_cos_mean": expert_cos.mean().item(),
            "expert_cos_std": expert_cos.std().item(),
            "cross_expert_cos": cross_features.mean().item(),
            "anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
            "disagreement_ratio": disagreement_ratio.mean().item(),
        }
        return enriched, geo_context, diagnostics
# ══════════════════════════════════════════════════════════════════
# FULL MODEL: bank + constellation + patchwork + classifier
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """Learnable set of unit anchors used to triangulate embeddings."""

    def __init__(self, n_anchors, d):
        super().__init__()
        self.n_anchors = n_anchors
        # Random directions on the hypersphere, trained with the model.
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, d), dim=-1))

    def triangulate(self, emb):
        """Return (cosine distance to each anchor, index of nearest anchor)."""
        unit_anchors = F.normalize(self.anchors, dim=-1)
        sims = emb @ unit_anchors.T
        return 1.0 - sims, sims.argmax(dim=-1)
class Patchwork(nn.Module):
    """Round-robin split of anchor distances into per-component MLPs.

    Anchor i is assigned to component i % n_comp; each component maps its
    slice of the triangulation vector to d_comp features, and the component
    outputs are concatenated (final width n_comp * d_comp).
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        assignment = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", assignment)
        blocks = []
        for k in range(n_comp):
            in_dim = int((assignment == k).sum())
            blocks.append(nn.Sequential(
                nn.Linear(in_dim, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp)))
        self.comps = nn.ModuleList(blocks)

    def forward(self, tri):
        pieces = [comp(tri[:, self.asgn == k]) for k, comp in enumerate(self.comps)]
        return torch.cat(pieces, -1)
class VisionBankModel(nn.Module):
    """
    34-expert AlignmentBank + constellation + patchwork + classifier.
    Input: L2-normalized consensus embedding (1024-d) from GPA of 34 experts.
    Bank: annotates with 34-expert geometric context.
    Downstream: constellation -> patchwork -> classifier (multi-label COCO).
    """
    def __init__(self, n_experts=34, d_shared=D_SHARED, n_anchors=N_ANCHORS,
                 n_comp=N_COMP, d_comp=D_COMP, n_classes=N_CLASSES, d_bank=D_BANK):
        super().__init__()
        self.bank = VisionAlignmentBank(d_shared, n_experts, n_anchors, d_bank)
        self.constellation = Constellation(n_anchors, d_shared)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp
        # Classifier consumes patchwork output + raw embedding + bank context.
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + d_shared + d_bank, d_shared), nn.GELU(),
            nn.LayerNorm(d_shared), nn.Dropout(0.1),
            nn.Linear(d_shared, d_shared // 2), nn.GELU(),
            nn.Linear(d_shared // 2, n_classes))
    def forward(self, embedding):
        """Returns (logits, embedding, tri, nearest, bank diagnostics).

        NOTE(review): the second return value is the INPUT embedding,
        unchanged, and `enriched` from the bank is never consumed — the bank
        signal reaches the classifier only through geo_ctx. Callers that
        treat the second value as a "student" output are comparing the input
        with itself; confirm this is intended.
        """
        enriched, geo_ctx, bank_diag = self.bank(embedding)
        tri, nearest = self.constellation.triangulate(embedding)
        pw = self.patchwork(tri)
        logits = self.classifier(torch.cat([pw, embedding, geo_ctx], dim=-1))
        return logits, embedding, tri, nearest, bank_diag
# ══════════════════════════════════════════════════════════════════
# PHASE 0: LOAD ALL EXPERT FEATURES
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 0: LOAD EXPERT FEATURES")
print(f"{'='*65}")
# Reference for image_id alignment: the first subset fixes the row order all
# other experts are mapped into.
ref = load_dataset("AbstractPhil/bulk-coco-features", SUBSETS[0], split="train")
train_ids = ref["image_id"]; N_train = len(train_ids)
train_id_map = {iid: i for i, iid in enumerate(train_ids)}
train_labels_raw = ref["labels"]
# Multi-hot (N_train, N_CLASSES) matrix for the multi-label COCO task.
train_label_matrix = torch.zeros(N_train, N_CLASSES)
for i, labs in enumerate(train_labels_raw):
    for l in labs:
        if l < N_CLASSES: train_label_matrix[i, l] = 1.0
ref_val = load_dataset("AbstractPhil/bulk-coco-features", SUBSETS[0], split="val")
val_ids = ref_val["image_id"]; N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
val_labels_raw = ref_val["labels"]
val_label_matrix = torch.zeros(N_val, N_CLASSES)
for i, labs in enumerate(val_labels_raw):
    for l in labs:
        if l < N_CLASSES: val_label_matrix[i, l] = 1.0
print(f" Train: {N_train:,} Val: {N_val:,}")
# Load all 34 experts, row-aligned to the reference image_id order and
# L2-normalized. Rows missing from a subset stay all-zero.
expert_dims = {}
train_expert_embs = {}
val_expert_embs = {}
for name in SUBSETS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="train")
    dim = len(ds[0]["features"]); expert_dims[name] = dim
    feats = torch.zeros(N_train, dim)
    for row in ds:
        if row["image_id"] in train_id_map:
            feats[train_id_map[row["image_id"]]] = torch.tensor(row["features"], dtype=torch.float32)
    train_expert_embs[name] = F.normalize(feats, dim=-1)
    ds_v = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats_v = torch.zeros(N_val, dim)
    for row in ds_v:
        if row["image_id"] in val_id_map:
            feats_v[val_id_map[row["image_id"]]] = torch.tensor(row["features"], dtype=torch.float32)
    val_expert_embs[name] = F.normalize(feats_v, dim=-1)
    print(f" {name:<30} dim={dim}", flush=True)
    # Free HF dataset handles before loading the next expert.
    del ds, ds_v; gc.collect()
# ══════════════════════════════════════════════════════════════════
# PHASE 1: GPA β†’ CONSENSUS + PER-EXPERT PROCRUSTES
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 1: GPA ALIGNMENT + PROCRUSTES CALIBRATION")
print(f"{'='*65}")
# Project all to D_SHARED for GPA (PCA for d>1024, pad for d<1024)
def project_to_shared(feats, d_out=D_SHARED):
    """Bring one expert's features to the shared width.

    Identity when already d_out wide; zero-pad then renormalize when
    narrower; PCA-project (SVD of the centered features) when wider.
    """
    d_in = feats.shape[1]
    if d_in == d_out:
        return feats
    if d_in < d_out:
        pad = torch.zeros(feats.shape[0], d_out - d_in)
        return F.normalize(torch.cat([feats, pad], -1), dim=-1)
    # d_in > d_out: project the (uncentered) features onto the top-d_out
    # right singular vectors of the centered data, then renormalize.
    centered = feats - feats.mean(0, keepdim=True)
    _, _, Vt = torch.linalg.svd(centered, full_matrices=False)
    return F.normalize(feats @ Vt[:d_out].T, dim=-1)
projected = {n: project_to_shared(train_expert_embs[n]) for n in SUBSETS}
# GPA: iteratively Procrustes-align every expert toward the running mean
# shape until total per-iteration movement drops below 1e-8 (max 20 iters).
current = {i: projected[SUBSETS[i]].float() for i in range(len(SUBSETS))}
for gpa_iter in range(20):
    mean_shape = sum(current[i] for i in range(len(SUBSETS))) / len(SUBSETS)
    delta = 0.0
    new_current = {}
    for i in range(len(SUBSETS)):
        info = procrustes_align(current[i], mean_shape)
        new_current[i] = apply_align(current[i], info)
        delta += (new_current[i] - current[i]).pow(2).mean().item()
    current = new_current
    if gpa_iter == 0 or (gpa_iter+1) % 5 == 0:
        print(f" GPA iter {gpa_iter+1}: delta={delta:.8f}")
    if delta < 1e-8:
        print(f" Converged at iteration {gpa_iter+1}"); break
# Consensus = normalized mean of the aligned experts; its volume CV becomes
# the bank's target_cv.
consensus = F.normalize(
    sum(current[i] for i in range(len(SUBSETS))) / len(SUBSETS), dim=-1)
consensus_cv = cv_metric(consensus[:5000].to(DEVICE))
print(f" Consensus CV: {consensus_cv:.4f}")
# Per-expert Procrustes calibration (expert β†’ consensus space)
print(f"\n Calibrating {len(SUBSETS)} expert Procrustes...")
expert_calibrations = []
for i, name in enumerate(SUBSETS):
    info = procrustes_align(current[i], consensus)
    expert_calibrations.append(info)
    # Sanity check on the first 5k rows: aligned expert vs consensus cosine.
    c = F.cosine_similarity(
        apply_align(current[i][:5000], info),
        consensus[:5000], dim=-1).mean().item()
    if i < 5 or i >= len(SUBSETS)-3:
        print(f" {name:<30} cos={c:.4f}")
    elif i == 5:
        print(f" ...")
# Compute bank calibration targets on consensus
print(f"\n Computing bank calibration targets...")
with torch.no_grad():
    cons_dev = consensus[:10000].to(DEVICE)
    # Simulate bank on consensus to get cross-expert targets.
    # We need initial expert_cos and cross_cos from calibrated Procrustes.
    tmp_expert_cos = []
    tmp_expert_proj = []
    for i in range(len(SUBSETS)):
        R = expert_calibrations[i]["rotation"].to(DEVICE)
        W = expert_calibrations[i]["source_whitener"].to(DEVICE)
        mu = expert_calibrations[i]["source_mean"].to(DEVICE)
        centered = cons_dev - mu
        whitened_n = F.normalize(centered @ W, dim=-1)
        in_expert = whitened_n @ R.T
        back = in_expert @ R
        cos = F.cosine_similarity(whitened_n, back, dim=-1)
        tmp_expert_cos.append(cos)
        tmp_expert_proj.append(in_expert)
    expert_cos_stack = torch.stack(tmp_expert_cos, dim=-1)
    target_cross_cos_vals = []
    for i in range(len(SUBSETS)):
        for j in range(i+1, len(SUBSETS)):
            cc = F.cosine_similarity(tmp_expert_proj[i], tmp_expert_proj[j], dim=-1)
            target_cross_cos_vals.append(cc)
    cross_stack = torch.stack(target_cross_cos_vals, dim=-1)
    # These statistics become the bank's disagreement targets.
    calib_cross_mean = cross_stack.mean().item()
    calib_cross_std = cross_stack.std().item()
    calib_agree = expert_cos_stack.mean(dim=-1)
    calib_disagree = expert_cos_stack.std(dim=-1)
    calib_ratio = (calib_disagree / (calib_agree + 1e-8)).mean().item()
print(f" target_cv: {consensus_cv:.4f}")
print(f" target_cross_cos_mean: {calib_cross_mean:.4f}")
print(f" target_cross_cos_std: {calib_cross_std:.4f}")
print(f" target_disagreement_ratio: {calib_ratio:.4f}")
# ══════════════════════════════════════════════════════════════════
# PHASE 2: BUILD MODEL + LOAD SOUP DOWNSTREAM
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 2: BUILD MODEL")
print(f"{'='*65}")
model = VisionBankModel(
    n_experts=len(SUBSETS), d_shared=D_SHARED, n_anchors=N_ANCHORS,
    n_comp=N_COMP, d_comp=D_COMP, n_classes=N_CLASSES, d_bank=D_BANK).to(DEVICE)
# Initialize bank with calibrated Procrustes
with torch.no_grad():
    for i, name in enumerate(SUBSETS):
        cal = expert_calibrations[i]
        model.bank.expert_rotations[i].copy_(cal["rotation"])
        model.bank.expert_whiteners[i].copy_(cal["source_whitener"])
        model.bank.expert_means[i].copy_(cal["source_mean"])
    # Phase-1 statistics become the fixed regularization targets.
    model.bank.target_cv.fill_(consensus_cv)
    model.bank.target_cross_cos_mean.fill_(calib_cross_mean)
    model.bank.target_cross_cos_std.fill_(calib_cross_std)
    model.bank.target_disagreement_ratio.fill_(calib_ratio)
print(f" βœ“ Bank calibrated with {len(SUBSETS)} expert Procrustes")
# Transfer soup constellation + patchwork (classifier is new due to +d_bank dim)
if os.path.exists(SOUP_PATH):
    # SECURITY: weights_only=False unpickles arbitrary objects — only load
    # soup checkpoints from a trusted source.
    soup_ckpt = torch.load(SOUP_PATH, map_location="cpu", weights_only=False)
    soup_state = soup_ckpt["state_dict"]
    # Same anchors seed both the constellation and the bank.
    model.constellation.anchors.data.copy_(soup_state["constellation.anchors"].to(DEVICE))
    model.bank.anchors.data.copy_(soup_state["constellation.anchors"].to(DEVICE))
    pw_state = {k.replace("patchwork.", ""): v for k, v in soup_state.items() if k.startswith("patchwork.")}
    model.patchwork.load_state_dict(pw_state)
    print(f" βœ“ Constellation + patchwork transferred from soup")
    del soup_ckpt, soup_state
else:
    print(f" ⚠ No soup checkpoint β€” using random initialization")
# Parameter accounting per sub-module.
n_bank_p = sum(p.numel() for p in model.bank.parameters())
n_const = sum(p.numel() for p in model.constellation.parameters())
n_pw = sum(p.numel() for p in model.patchwork.parameters())
n_cls = sum(p.numel() for p in model.classifier.parameters())
n_total = sum(p.numel() for p in model.parameters())
print(f"\n Parameters:")
print(f" bank: {n_bank_p:>12,}")
print(f" constellation: {n_const:>12,}")
print(f" patchwork: {n_pw:>12,}")
print(f" classifier: {n_cls:>12,}")
print(f" total: {n_total:>12,}")
# ══════════════════════════════════════════════════════════════════
# PHASE 3: TRAINING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("PHASE 3: TRAIN")
print(f"{'='*65}")
# Consensus targets on GPU
train_targets = consensus[:N_train].to(DEVICE)
# NOTE(review): val_targets is computed but never used below —
# val_consensus (the GPA-aligned version) supersedes it.
val_targets = F.normalize(
    sum(project_to_shared(val_expert_embs[n]) for n in SUBSETS).float() / len(SUBSETS),
    dim=-1)
# GPA-align val consensus (single alignment pass toward the naive mean)
val_current = {i: project_to_shared(val_expert_embs[SUBSETS[i]]).float() for i in range(len(SUBSETS))}
val_mean = sum(val_current[i] for i in range(len(SUBSETS))) / len(SUBSETS)
for i in range(len(SUBSETS)):
    info = procrustes_align(val_current[i], val_mean)
    val_current[i] = apply_align(val_current[i], info)
val_consensus = F.normalize(sum(val_current[i] for i in range(len(SUBSETS))) / len(SUBSETS), dim=-1).to(DEVICE)
train_labels_gpu = train_label_matrix.to(DEVICE)
val_labels_gpu = val_label_matrix.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/vision_alignment_bank")
best_mAP = 0.0; gs = 0
for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(N_train, device=DEVICE)
    tl, tn, nb = 0, 0, 0
    for i in range(0, N_train, BATCH):
        idx = perm[i:i+BATCH]
        # Skip tail batches too small for 5-point simplex sampling.
        if len(idx) < 8: continue
        emb = train_targets[idx]
        labels = train_labels_gpu[idx]
        logits, out_emb, tri, nearest, bank_diag = model(emb)
        # Student losses.
        # NOTE(review): model() returns the input embedding unchanged as
        # out_emb, so l_nce/l_mse/l_cv compare the target with itself and
        # carry no gradient into the model — only l_cls and l_bank train.
        # Confirm whether out_emb was meant to be a learned projection.
        l_nce, nce_acc = infonce(out_emb, emb)
        l_mse = F.mse_loss(out_emb, emb)
        l_cv = cv_loss(out_emb, target=consensus_cv)
        l_cls = F.binary_cross_entropy_with_logits(logits, labels)
        # Bank losses
        l_bank, bdiag = compute_bank_loss(model.bank, out_emb)
        loss = W_NCE*l_nce + W_MSE*l_mse + W_CV*l_cv + W_CLS*l_cls + W_BANK*l_bank
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step(); optimizer.zero_grad(set_to_none=True)
        tl += loss.item(); tn += nce_acc; nb += 1; gs += 1
        if gs % 100 == 0:
            writer.add_scalar("train/loss", loss.item(), gs)
            writer.add_scalar("train/nce", l_nce.item(), gs)
            writer.add_scalar("train/bank", l_bank.item(), gs)
            writer.add_scalar("train/cls", l_cls.item(), gs)
            writer.add_scalar("train/nce_acc", nce_acc, gs)
            for k, v in bdiag.items():
                writer.add_scalar(f"bank/{k}", v, gs)
    # Validation
    model.eval()
    with torch.no_grad():
        all_lo, all_em = [], []
        for j in range(0, N_val, BATCH):
            end = min(j+BATCH, N_val)
            lo, em, _, _, _ = model(val_consensus[j:end])
            all_lo.append(lo.cpu()); all_em.append(em.cpu())
        v_lo = torch.cat(all_lo); v_em = torch.cat(all_em)
        # mAP: per-class average precision over classes with positives.
        v_lab = val_label_matrix
        ap_sum, nv = 0, 0
        for c in range(N_CLASSES):
            if v_lab[:,c].sum() > 0:
                si = v_lo[:,c].argsort(descending=True); st = v_lab[:,c][si]
                pak = st.cumsum(0)/torch.arange(1,len(st)+1).float()
                ap_sum += (pak*st).sum().item()/st.sum().item(); nv += 1
        mAP = ap_sum/max(nv,1)
        v_cos = F.cosine_similarity(v_em, val_consensus.cpu(), dim=-1).mean().item()
        v_cv = cv_metric(v_em[:2000].to(DEVICE))
        # R@1: fraction of rows retrieving their own consensus row.
        sim = v_em @ val_consensus.cpu().T
        r1 = (sim.argmax(-1) == torch.arange(N_val)).float().mean().item()
    writer.add_scalar("val/mAP", mAP, epoch+1)
    writer.add_scalar("val/cos", v_cos, epoch+1)
    writer.add_scalar("val/cv", v_cv, epoch+1)
    writer.add_scalar("val/R@1", r1, epoch+1)
    mk = ""
    # Checkpoint on mAP improvement; config is stored for reload-time rebuild.
    if mAP > best_mAP:
        best_mAP = mAP
        torch.save({"state_dict": model.state_dict(), "mAP": mAP, "r1": r1, "cv": v_cv,
                    "config": {"n_experts": len(SUBSETS), "d_shared": D_SHARED,
                               "n_anchors": N_ANCHORS, "n_comp": N_COMP,
                               "d_comp": D_COMP, "n_classes": N_CLASSES, "d_bank": D_BANK}},
                   "checkpoints/best.pt"); mk = " β˜…"
    print(f" E{epoch+1:2d}: mAP={mAP:.3f} R@1={r1:.3f} cos={v_cos:.3f} "
          f"cv={v_cv:.4f} nce={tn/nb:.3f} loss={tl/nb:.4f}{mk}")
writer.close()
# Upload (best-effort: any hub failure is reported, never fatal)
print(f"\n Best mAP: {best_mAP:.3f}")
try:
    from huggingface_hub import HfApi
    api = HfApi()
    if os.path.exists("checkpoints/best.pt"):
        api.upload_file(path_or_fileobj="checkpoints/best.pt",
                        path_in_repo="vision_bank_best.pt",
                        repo_id=REPO_ID, repo_type="model")
        print(f" βœ“ Uploaded to {REPO_ID}")
except Exception as e:
    print(f" Upload: {e}")
print(f"\n{'='*65}\nDONE")