# geolip-vit-x34 — fault/prepare_data.py
# (Hugging Face listing header: renamed from failed/prepare_data.py to fault/prepare_data.py, commit ce6e5b5)
#!/usr/bin/env python3
"""
34-EXPERT PATCHWORK MODEL
==========================
Pre-extracted features from 34 vision models β†’ learned projectors β†’
cross-expert fusion β†’ constellation triangulation β†’ patchwork β†’ COCO multi-label.
Architecture:
Per-expert: Linear(d_expert β†’ d_shared) + LayerNorm
Fusion: Cross-attention over 34 expert tokens β†’ fused embedding
Geometry: Constellation(n_anchors) β†’ triangulation β†’ Patchwork β†’ MLP
Output: 80-class multi-label (BCE)
Training: Adam + geometric autograd (tang=0.01, sep=1.0, cv=0.001)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from datasets import load_dataset
import gc
# Global hyper-parameters shared by the model, data loading, and training code below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
D_SHARED = 1024   # shared width every expert is projected to
N_ANCHORS = 256   # constellation anchor count
N_CLASSES = 80    # COCO category count (multi-label targets)
N_COMP = 8        # number of patchwork component MLPs
D_COMP = 128      # output width of each patchwork component
print("=" * 65)
print("34-EXPERT PATCHWORK MODEL")
print("=" * 65)
print(f" Device: {DEVICE}")
print(f" Shared dim: {D_SHARED}, Anchors: {N_ANCHORS}, Classes: {N_CLASSES}")
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════
def tangential_projection(grad, embedding):
    """Decompose *grad* into components tangential and radial to *embedding*.

    The embedding is normalized (detached) so the radial direction is the
    unit vector through each embedding; the radial part is the projection
    of the gradient onto that direction, the tangential part is the rest.

    Returns (tangential, radial), both cast back to grad's dtype.
    """
    unit = F.normalize(embedding.detach().float(), dim=-1)
    g = grad.float()
    radial_part = (g * unit).sum(dim=-1, keepdim=True) * unit
    tangential_part = g - radial_part
    return tangential_part.to(grad.dtype), radial_part.to(grad.dtype)
def cayley_menger_vol2(pts):
    """Squared volume of the simplex spanned by each batch of points.

    pts: (B, V, D) — each batch item holds V points, i.e. a (V-1)-simplex.
    Uses the Cayley-Menger determinant: the (V+1)x(V+1) matrix of squared
    pairwise distances bordered by a row/column of ones, scaled by
    (-1)^V / (2^(V-1) * ((V-1)!)^2).

    Returns a (B,) tensor of squared volumes (can be slightly negative
    from floating-point error for degenerate simplices).
    """
    pts = pts.float()
    pairwise = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = pairwise.pow(2).sum(-1)
    B, V, _ = sq_dist.shape
    # Bordered distance matrix: first row/col ones (except the corner), rest d^2.
    cm = torch.zeros(B, V + 1, V + 1, device=sq_dist.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = sq_dist
    fact = math.factorial(V - 1)
    coeff = ((-1.0) ** V) / ((2.0 ** (V - 1)) * fact * fact)
    return coeff * torch.linalg.det(cm)
def cv_loss(emb, target=0.2, n_samples=16):
    """Differentiable penalty pushing the coefficient of variation of
    random 5-point simplex volumes toward *target*.

    Samples n_samples random 5-subsets of the batch, computes each
    simplex volume via the Cayley-Menger determinant, and returns
    |std/mean - target|. Returns 0 when the batch is too small (<5).
    """
    n = emb.shape[0]
    if n < 5:
        return torch.tensor(0.0, device=emb.device)
    samples = []
    for _ in range(n_samples):
        pick = torch.randperm(n, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # relu + eps keeps sqrt well-defined on numerically negative volumes
        samples.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(samples)
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).abs()
@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Monte-Carlo estimate of the volume coefficient of variation (eval only).

    Same quantity as cv_loss but non-differentiable, with more samples and
    strictly-positive-volume filtering. Returns 0.0 when the batch is too
    small or fewer than 10 valid volumes were collected.
    """
    n = emb.shape[0]
    if n < 5:
        return 0.0
    volumes = []
    for _ in range(n_samples):
        pick = torch.randperm(n)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            volumes.append(vol)
    if len(volumes) < 10:
        return 0.0
    t = torch.tensor(volumes)
    return float(t.std() / (t.mean() + 1e-8))
def anchor_spread_loss(anchors):
    """Mean squared off-diagonal cosine similarity between anchors.

    Zero when all anchors are mutually orthogonal; grows as anchors
    cluster, so minimizing it spreads them over the hypersphere.
    """
    unit = F.normalize(anchors, dim=-1)
    gram = unit @ unit.T
    # subtract the (all-ones) diagonal so self-similarity is not penalized
    off_diag = gram - torch.diag(torch.ones(anchors.shape[0], device=anchors.device))
    return off_diag.pow(2).mean()
def anchor_entropy_loss(emb, anchors, sharpness=10.0):
    """Mean entropy of the softmax assignment of embeddings to anchors.

    Higher sharpness makes the softmax more peaked; minimizing this loss
    sharpens each embedding's anchor assignment.
    """
    unit_anchors = F.normalize(anchors, dim=-1)
    logits = (emb @ unit_anchors.T) * sharpness
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * (probs + 1e-12).log()).sum(-1)
    return entropy.mean()
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in forward; rewrites the incoming gradient in backward.

    The gradient flowing into the embedding is decomposed into tangential
    and radial parts (w.r.t. the hypersphere), the radial part is damped
    by factor (1 - tang), and — when sep > 0 — the component of the
    gradient pointing toward the nearest anchor is suppressed.
    """
    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        # Save tensors/scalars needed to rewrite the gradient; output is x unchanged.
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang; ctx.sep = sep
        return x
    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        # tangential_projection returns (tangential, radial); "norm_grad" is the radial part.
        tang_grad, norm_grad = tangential_projection(grad_f, emb_n)
        # Keep full tangential gradient; scale the radial part by (1 - tang).
        corrected = tang_grad + (1.0 - ctx.tang) * norm_grad
        if ctx.sep > 0:
            # Find each embedding's nearest anchor by cosine similarity.
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            # Component of the corrected gradient along that anchor direction.
            toward = (corrected * nearest).sum(dim=-1, keepdim=True)
            collapse = toward * nearest
            # Remove (scaled by sep) only the positive-aligned component —
            # presumably to discourage collapse onto the nearest anchor.
            corrected = corrected - ctx.sep * (toward > 0).float() * collapse
        # One gradient per forward input: only x receives a gradient.
        return corrected.to(grad_output.dtype), None, None, None, None
# ══════════════════════════════════════════════════════════════════
# MODEL COMPONENTS
# ══════════════════════════════════════════════════════════════════
class ExpertProjector(nn.Module):
    """Project one expert's features from d_in to d_out.

    Linear down/up through a GELU bottleneck of width min(d_in, d_out),
    followed by LayerNorm on the shared width.
    """
    def __init__(self, d_in, d_out=D_SHARED):
        super().__init__()
        bottleneck = min(d_in, d_out)
        layers = [
            nn.Linear(d_in, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, d_out),
            nn.LayerNorm(d_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: (B, d_in) → (B, d_out)."""
        return self.net(x)
class ExpertFusion(nn.Module):
    """
    Cross-attention fusion of N expert projections → single embedding.
    Uses a learned query token that attends to all expert outputs.
    """
    def __init__(self, d_model=D_SHARED, n_heads=8, n_layers=2):
        super().__init__()
        # Single learned query token, small init to keep early attention mild.
        self.query = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        # Stack of pre-norm decoder layers; each self-attends over the (single)
        # query token and cross-attends to the expert tokens (memory).
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=d_model, nhead=n_heads,
                dim_feedforward=d_model * 2,
                dropout=0.1, batch_first=True,
                norm_first=True,
            ) for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
    def forward(self, expert_tokens):
        """
        expert_tokens: (B, N_experts, d_model)
        returns: (B, d_model)
        """
        B = expert_tokens.shape[0]
        q = self.query.expand(B, -1, -1)  # (B, 1, d_model)
        for layer in self.layers:
            # decoder layer signature: (tgt, memory) — query attends to experts
            q = layer(q, expert_tokens)
        return self.norm(q.squeeze(1))  # (B, d_model)
class Constellation(nn.Module):
    """Learned anchor set on the hypersphere with triangulation distances
    and a non-learned running "rigidity" statistic per anchor."""
    def __init__(self, n_anchors=N_ANCHORS, d_embed=D_SHARED, init_anchors=None):
        super().__init__()
        self.n_anchors = n_anchors
        if init_anchors is not None:
            self.anchors = nn.Parameter(init_anchors.clone())
        else:
            # Random unit-norm initialization.
            self.anchors = nn.Parameter(F.normalize(
                torch.randn(n_anchors, d_embed), dim=-1))
        # Buffers (not optimized): running rigidity score and visit counts.
        self.register_buffer("rigidity", torch.zeros(n_anchors))
        self.register_buffer("visit_count", torch.zeros(n_anchors))
    def triangulate(self, emb):
        """Return (1 - cosine) distances to every anchor and the nearest anchor index.

        emb: (B, d_embed) — assumed unit-normalized by the caller.
        returns: tri (B, n_anchors), nearest (B,)
        """
        a = F.normalize(self.anchors, dim=-1)
        cos = emb @ a.T
        return 1.0 - cos, cos.argmax(dim=-1)
    @torch.no_grad()
    def update_rigidity(self, tri):
        """EMA update of per-anchor rigidity from a batch of triangulations.

        For each anchor with at least 5 assigned points, rigidity moves
        toward 1/(spread + 0.01), where spread is the mean per-dimension
        std of those points' triangulation rows. The EMA rate alpha decays
        with visit count (capped at 0.1). Mutates buffers in place.
        """
        nearest = tri.argmin(dim=-1)
        for i in range(self.n_anchors):
            m = nearest == i
            if m.sum() < 5: continue  # too few points for a stable spread estimate
            self.visit_count[i] += m.sum().float()
            sp = tri[m].std(dim=0).mean()
            alpha = min(0.1, 10.0 / (self.visit_count[i] + 1))
            self.rigidity[i] = (1-alpha)*self.rigidity[i] + alpha/(sp+0.01)
class Patchwork(nn.Module):
    """Route triangulation distances through n_comp small component MLPs.

    Anchors are assigned to components round-robin (index mod n_comp);
    each component sees only its anchors' distances and emits d_comp
    features, concatenated into (B, n_comp * d_comp).
    """
    def __init__(self, n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP):
        super().__init__()
        self.n_comp = n_comp
        assignment = torch.arange(n_anchors) % n_comp
        self.register_buffer("asgn", assignment)
        components = []
        for k in range(n_comp):
            n_in = (assignment == k).sum().item()
            components.append(nn.Sequential(
                nn.Linear(n_in, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp),
            ))
        self.comps = nn.ModuleList(components)

    def forward(self, tri):
        """tri: (B, n_anchors) → (B, n_comp * d_comp)."""
        outputs = []
        for k in range(self.n_comp):
            outputs.append(self.comps[k](tri[:, self.asgn == k]))
        return torch.cat(outputs, dim=-1)
class SoupModel(nn.Module):
    """
    34-expert → projectors → fusion → constellation → patchwork → classifier.

    Each expert's pre-extracted feature vector is projected to a shared
    width, tagged with a learned identity embedding, fused by cross-
    attention into one unit-norm embedding, triangulated against the
    constellation anchors, expanded by the patchwork, and classified.
    """
    def __init__(self, expert_dims_dict, n_anchors=N_ANCHORS,
                 n_comp=N_COMP, d_comp=D_COMP, n_classes=N_CLASSES,
                 d_shared=D_SHARED, init_anchors=None):
        super().__init__()
        # Sorted for a deterministic expert ordering (matches expert_ids rows).
        self.expert_names = sorted(expert_dims_dict.keys())
        self.n_experts = len(self.expert_names)
        self.d_shared = d_shared
        # Per-expert projectors ("." is invalid in ModuleDict keys, hence replace).
        self.projectors = nn.ModuleDict({
            name.replace(".", "_"): ExpertProjector(dim, d_shared)
            for name, dim in expert_dims_dict.items()
        })
        self.name_to_key = {name: name.replace(".", "_")
                            for name in expert_dims_dict}
        # Expert identity embeddings (learned, added to projected features)
        self.expert_ids = nn.Parameter(
            torch.randn(self.n_experts, d_shared) * 0.02)
        # Fusion: cross-attention over expert tokens
        self.fusion = ExpertFusion(d_shared, n_heads=8, n_layers=2)
        # Geometric pipeline
        self.constellation = Constellation(n_anchors, d_shared, init_anchors)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        # Classifier: patchwork output + fused embedding → multi-label
        pw_dim = n_comp * d_comp
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + d_shared, d_shared),
            nn.GELU(),
            nn.LayerNorm(d_shared),
            nn.Dropout(0.1),
            nn.Linear(d_shared, d_shared // 2),
            nn.GELU(),
            nn.Linear(d_shared // 2, n_classes),
        )
    def forward(self, expert_features_dict):
        """
        expert_features_dict: {name: (B, d_expert)} for each expert

        Returns (logits, emb, tri, nearest):
          logits  (B, n_classes) — multi-label scores
          emb     (B, d_shared)  — unit-norm fused embedding
          tri     (B, n_anchors) — 1-cosine anchor distances
          nearest (B,)           — nearest-anchor index
        """
        B = next(iter(expert_features_dict.values())).shape[0]
        # Project each expert
        tokens = []
        for i, name in enumerate(self.expert_names):
            key = self.name_to_key[name]
            feat = expert_features_dict[name]
            proj = self.projectors[key](feat)  # (B, d_shared)
            proj = proj + self.expert_ids[i]   # + identity
            tokens.append(proj)
        expert_stack = torch.stack(tokens, dim=1)  # (B, N, d_shared)
        # Fuse
        fused = self.fusion(expert_stack)  # (B, d_shared)
        emb = F.normalize(fused, dim=-1)   # on hypersphere
        # Triangulate
        tri, nearest = self.constellation.triangulate(emb)
        # Patchwork
        pw = self.patchwork(tri)  # (B, n_comp * d_comp)
        # Classify from patchwork + embedding
        combined = torch.cat([pw, emb], dim=-1)
        logits = self.classifier(combined)  # (B, n_classes)
        return logits, emb, tri, nearest
    def count_params(self):
        """Parameter counts per sub-module (dict of ints); 'total' covers everything."""
        total = sum(p.numel() for p in self.parameters())
        proj = sum(p.numel() for p in self.projectors.parameters())
        fuse = sum(p.numel() for p in self.fusion.parameters())
        geo = sum(p.numel() for p in self.constellation.parameters())
        pw = sum(p.numel() for p in self.patchwork.parameters())
        cls = sum(p.numel() for p in self.classifier.parameters())
        return {"total": total, "projectors": proj, "fusion": fuse,
                "constellation": geo, "patchwork": pw, "classifier": cls}
# ══════════════════════════════════════════════════════════════════
# DATA LOADING
# ══════════════════════════════════════════════════════════════════
# The 34 expert feature subsets hosted in the HF dataset repo.
SUBSETS = [
    "clip_b16_laion2b", "clip_b16_openai", "clip_b32_datacomp",
    "clip_b32_laion2b", "clip_b32_openai", "clip_bigg14_laion2b",
    "clip_g14_laion2b", "clip_h14_laion2b", "clip_l14_336_openai",
    "clip_l14_datacomp", "clip_l14_laion2b", "clip_l14_openai",
    "dinov2_b14", "dinov2_b14_reg", "dinov2_g14", "dinov2_g14_reg",
    "dinov2_l14", "dinov2_l14_reg", "dinov2_s14", "dinov2_s14_reg",
    "mae_b16", "mae_h14", "mae_l16",
    "siglip2_b16_256", "siglip2_b16_512", "siglip2_l16_384",
    "siglip_b16_384", "siglip_b16_512", "siglip_l16_256",
    "siglip_l16_384", "siglip_so400m_384",
    "vit_b16_21k", "vit_l16_21k", "vit_s16_21k",
]
print(f"\n Loading val features...")
# Use the first subset as the reference for image ids and labels;
# assumes all subsets cover the same validation images.
ref_ds = load_dataset("AbstractPhil/bulk-coco-features", SUBSETS[0], split="val")
image_ids = ref_ds["image_id"]
labels_raw = ref_ds["labels"]
N = len(image_ids)
id_to_idx = {iid: i for i, iid in enumerate(image_ids)}
# Multi-label targets: one-hot rows over N_CLASSES (labels >= N_CLASSES dropped).
label_matrix = torch.zeros(N, N_CLASSES)
for i, labs in enumerate(labels_raw):
    for l in labs:
        if l < N_CLASSES:
            label_matrix[i, l] = 1.0
expert_features = {}
expert_dims = {}
for name in SUBSETS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    dim = len(ds[0]["features"])  # feature width inferred from the first row
    expert_dims[name] = dim
    feats = torch.zeros(N, dim)
    # Align rows to the reference ordering via image_id; rows with unknown
    # ids are skipped (and missing ids stay all-zero).
    for row in ds:
        if row["image_id"] in id_to_idx:
            feats[id_to_idx[row["image_id"]]] = torch.tensor(
                row["features"], dtype=torch.float32)
    expert_features[name] = feats  # NOT normalized — projector handles it
    print(f" {name:<30} dim={dim}", flush=True)
print(f" Loaded {len(expert_features)} experts, N={N}")
print(f" Labels: {N_CLASSES} classes, multi-label")
print(f" Positive rate: {label_matrix.sum() / (N * N_CLASSES):.4f}")
# ══════════════════════════════════════════════════════════════════
# BUILD MODEL
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("BUILDING MODEL")
print(f"{'='*65}")
# Instantiate with the per-expert feature dims discovered during loading.
model = SoupModel(expert_dims, n_anchors=N_ANCHORS,
                  n_comp=N_COMP, d_comp=D_COMP,
                  n_classes=N_CLASSES, d_shared=D_SHARED).to(DEVICE)
params = model.count_params()
print(f" Parameters:")
for k, v in params.items():
    print(f" {k:<15}: {v:>10,}")
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("TRAINING")
print(f"{'='*65}")
# Split 80/20 — NOTE(review): sequential split (no shuffle); fine only if the
# dataset ordering carries no structure — confirm.
n_train = int(N * 0.8)
train_idx = torch.arange(n_train)
val_idx = torch.arange(n_train, N)
# Pre-stack features per expert on device (all 34 experts resident in memory).
train_feats = {name: expert_features[name][:n_train].to(DEVICE) for name in SUBSETS}
val_feats = {name: expert_features[name][n_train:].to(DEVICE) for name in SUBSETS}
train_labels = label_matrix[:n_train].to(DEVICE)
val_labels = label_matrix[n_train:].to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
BATCH = 128
EPOCHS = 20
# Geometric-autograd coefficients: radial damping, anchor separation, CV weight
# (matches the "tang=0.01, sep=1.0, cv=0.001" in the module docstring).
TANG, SEP, CV_W = 0.01, 1.0, 0.001
# Main training loop: BCE classification + geometric regularizers, with the
# classifier path re-run through EmbeddingAutograd so embedding gradients
# get the tangential/separation treatment.
for epoch in range(EPOCHS):
    model.train()
    perm = torch.randperm(n_train, device=DEVICE)
    total_loss, total_correct, n_batches = 0, 0, 0
    for i in range(0, n_train, BATCH):
        idx = perm[i:i+BATCH]
        if len(idx) < 4: continue  # skip tiny trailing batches
        # Gather batch
        batch_feats = {name: train_feats[name][idx] for name in SUBSETS}
        batch_labels = train_labels[idx]
        # NOTE(review): the logits from this forward pass are discarded and
        # recomputed below through the geometric-autograd path — the
        # classifier/patchwork run twice per step.
        logits, emb, tri, nearest = model(batch_feats)
        anchors = model.constellation.anchors
        # Geometric autograd: identity forward, gradient surgery in backward.
        emb_g = EmbeddingAutograd.apply(emb, emb, anchors, TANG, SEP)
        tri_g, _ = model.constellation.triangulate(emb_g)
        pw_g = model.patchwork(tri_g)
        combined_g = torch.cat([pw_g, emb_g], dim=-1)
        logits = model.classifier(combined_g)
        # Multi-label BCE
        l_cls = F.binary_cross_entropy_with_logits(logits, batch_labels)
        # Geometric losses: volume-CV target, anchor spread, assignment entropy.
        l_cv = CV_W * cv_loss(emb)
        l_spread = 1e-3 * anchor_spread_loss(anchors)
        l_ent = 1e-4 * anchor_entropy_loss(emb, anchors)
        loss = l_cls + l_cv + l_spread + l_ent
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step(); optimizer.zero_grad(set_to_none=True)
        # Non-gradient buffer update for anchor rigidity statistics.
        model.constellation.update_rigidity(tri.detach())
        # Multi-label accuracy (threshold 0.5) — element-wise, so dominated
        # by the many easy negatives.
        preds = (logits.detach().sigmoid() > 0.5).float()
        correct = (preds == batch_labels).float().mean().item()
        total_correct += correct
        total_loss += loss.item()
        n_batches += 1
    train_acc = total_correct / n_batches
    # Validation
    model.eval()
    with torch.no_grad():
        # Process val in chunks
        all_logits, all_embs = [], []
        for j in range(0, len(val_idx), BATCH):
            chunk_idx = torch.arange(j, min(j + BATCH, len(val_idx)))
            chunk_feats = {name: val_feats[name][chunk_idx] for name in SUBSETS}
            lo, em, _, _ = model(chunk_feats)
            all_logits.append(lo)
            all_embs.append(em)
        v_logits = torch.cat(all_logits, 0)
        v_embs = torch.cat(all_embs, 0)
        v_preds = (v_logits.sigmoid() > 0.5).float()
        v_acc = (v_preds == val_labels).float().mean().item()
        v_cv = cv_metric(v_embs.cpu())
        # Per-class F1 (macro) — NOTE(review): averaged only over classes with
        # f1 > 0; if no class crosses zero this mean is NaN.
        tp = (v_preds * val_labels).sum(0)
        fp = (v_preds * (1 - val_labels)).sum(0)
        fn = ((1 - v_preds) * val_labels).sum(0)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        macro_f1 = f1[f1 > 0].mean().item()
        # mAP: per-class average precision over classes with at least one positive.
        ap_sum = 0
        n_valid = 0
        for c in range(N_CLASSES):
            if val_labels[:, c].sum() > 0:
                scores = v_logits[:, c].cpu()
                targets = val_labels[:, c].cpu()
                sorted_idx = scores.argsort(descending=True)
                sorted_tgt = targets[sorted_idx]
                tp_cumsum = sorted_tgt.cumsum(0)
                precision_at_k = tp_cumsum / torch.arange(1, len(sorted_tgt) + 1).float()
                ap = (precision_at_k * sorted_tgt).sum() / sorted_tgt.sum()
                ap_sum += ap.item()
                n_valid += 1
        mAP = ap_sum / max(n_valid, 1)
    rig = model.constellation.rigidity
    # Log every other epoch (plus the first).
    if (epoch + 1) % 2 == 0 or epoch == 0:
        print(f" E{epoch+1:2d}: t_acc={train_acc:.3f} v_acc={v_acc:.3f} "
              f"mAP={mAP:.3f} F1={macro_f1:.3f} "
              f"cv={v_cv:.4f} rig={rig.mean():.1f}/{rig.max():.1f} "
              f"loss={total_loss/n_batches:.4f}")
# ══════════════════════════════════════════════════════════════════
# FINAL REPORT
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("FINAL REPORT")
print(f"{'='*65}")
model.eval()
# Re-run validation inference once more for the final metrics.
with torch.no_grad():
    all_logits, all_embs = [], []
    for j in range(0, len(val_idx), BATCH):
        chunk_idx = torch.arange(j, min(j + BATCH, len(val_idx)))
        chunk_feats = {name: val_feats[name][chunk_idx] for name in SUBSETS}
        lo, em, _, _ = model(chunk_feats)
        all_logits.append(lo)
        all_embs.append(em)
    v_logits = torch.cat(all_logits, 0)
    v_embs = torch.cat(all_embs, 0)
# Top-5 and bottom-5 classes by AP (same per-class AP as the epoch loop).
class_aps = {}
for c in range(N_CLASSES):
    if val_labels[:, c].sum() > 0:
        scores = v_logits[:, c].cpu()
        targets = val_labels[:, c].cpu()
        sorted_idx = scores.argsort(descending=True)
        sorted_tgt = targets[sorted_idx]
        tp_cumsum = sorted_tgt.cumsum(0)
        prec_at_k = tp_cumsum / torch.arange(1, len(sorted_tgt) + 1).float()
        class_aps[c] = (prec_at_k * sorted_tgt).sum().item() / sorted_tgt.sum().item()
sorted_aps = sorted(class_aps.items(), key=lambda x: -x[1])
print(f"\n Top 5 classes by AP:")
for c, ap in sorted_aps[:5]:
    n = val_labels[:, c].sum().int().item()  # positive count for this class
    print(f" class {c:>3}: AP={ap:.3f} (n={n})")
print(f"\n Bottom 5 classes by AP:")
for c, ap in sorted_aps[-5:]:
    n = val_labels[:, c].sum().int().item()
    print(f" class {c:>3}: AP={ap:.3f} (n={n})")
final_cv = cv_metric(v_embs.cpu())
print(f"\n Final mAP: {sum(class_aps.values())/len(class_aps):.3f}")
print(f" Final CV: {final_cv:.4f}")
print(f" Embedding dim: {v_embs.shape[1]}")
print(f" Anchors: {model.constellation.n_anchors}")
# Expert contribution analysis — identity-embedding norm used as a rough
# proxy for learned importance (presumably; not a rigorous attribution).
print(f"\n Expert identity norms (learned importance):")
norms = model.expert_ids.detach().cpu().norm(dim=-1)
sorted_exp = sorted(zip(model.expert_names, norms.tolist()),
                    key=lambda x: -x[1])
for name, norm in sorted_exp[:5]:
    print(f" {name:<30} norm={norm:.4f}")
print(f" ...")
for name, norm in sorted_exp[-3:]:
    print(f" {name:<30} norm={norm:.4f}")
print(f"\n Parameters: {params['total']:,}")
print(f"\n{'='*65}")
print("DONE")
print(f"{'='*65}")