# geolip-vit-base-x3 / run_2_train_vit_with_soup.py
# (Hugging Face page header residue, preserved as a comment:
#  AbstractPhil's picture — "Update run_2_train_vit_with_soup.py" — d5583cc verified)
#!/usr/bin/env python3
"""
GEOLIP VISION ENCODER β€” FROM SCRATCH
======================================
From-scratch ViT trained against frozen soup consensus targets.
Phase 0: Pre-compute consensus targets from frozen soup
Phase 1: Pre-cache all COCO images as tensors (once, then reuse)
Phase 2: Train from-scratch ViT with full GeoLIP loss stack
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import gc
import time
import math
import numpy as np
from tqdm import tqdm
DEVICE = "cuda"
# Allow TF32 matmuls/convs for throughput on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Architecture
D_MODEL = 384      # transformer hidden width
N_HEADS = 6        # attention heads
N_LAYERS = 6       # encoder layers
D_FF = 1536        # feed-forward width (4x D_MODEL)
PATCH_SIZE = 16    # square patch side in pixels
IMAGE_SIZE = 224   # input resolution (square)
D_ANCHOR = 128     # output embedding dim (matches the soup's anchor space)
N_ANCHORS = 256    # anchors in the frozen constellation
N_CLASSES = 80     # COCO object categories
N_COMP = 8         # patchwork component count
D_COMP = 64        # per-component output width
DROPOUT = 0.1
# Training
BATCH = 48
EPOCHS = 20
LR = 3e-4
WARMUP_STEPS = 500
GRAD_CLIP = 1.0
# Frozen expert feature sets used to build consensus targets.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
N_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
print("=" * 65)
print("GEOLIP VISION ENCODER β€” FROM SCRATCH")
print(f" ViT: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}")
print(f" {N_PATCHES} patches + CLS β†’ {D_ANCHOR}-d output")
print(f" Device: {DEVICE}")
print("=" * 65)
# ══════════════════════════════════════════════════════════════════
# GEOMETRIC PRIMITIVES
# ══════════════════════════════════════════════════════════════════
def cayley_menger_vol2(pts):
    """Squared simplex volume per batch via the Cayley-Menger determinant.

    pts: (B, V, D) batch of V points each; returns a (B,) tensor holding the
    squared volume of the (V-1)-simplex spanned by each batch's points.
    """
    pts = pts.float()
    # Pairwise squared distances, shape (B, V, V).
    delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = (delta * delta).sum(-1)
    n_batch, n_pts, _ = sq_dist.shape
    # Bordered Cayley-Menger matrix: first row/col are 1 (0 at the corner).
    cm = torch.zeros(n_batch, n_pts + 1, n_pts + 1,
                     device=sq_dist.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = sq_dist
    # vol^2 = (-1)^V / (2^(V-1) * ((V-1)!)^2) * det(CM)
    sign = (-1.0) ** n_pts
    fact = math.factorial(n_pts - 1)
    scale = sign / ((2.0 ** (n_pts - 1)) * fact * fact)
    return scale * torch.linalg.det(cm)
def cv_loss(emb, target=0.2, n_samples=16):
    """Coefficient-of-variation regularizer on sampled 5-point simplex volumes.

    Draws n_samples random 5-subsets of the batch, measures each simplex
    volume, and penalizes |std/mean - target|. Returns 0 when the batch is
    too small to form a 5-point simplex.
    """
    batch = emb.shape[0]
    if batch < 5:
        return torch.tensor(0.0, device=emb.device)
    samples = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        # Clamp negatives (numerical noise) before the sqrt.
        samples.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(samples)
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).abs()
@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Monitoring-only CV of sampled simplex volumes (no gradients).

    Same quantity as cv_loss measures, but computed in numpy over many more
    samples and returned as a plain float; 0.0 when the batch is too small
    or too few positive-volume samples were found.
    """
    batch = emb.shape[0]
    if batch < 5:
        return 0.0
    vols = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            vols.append(vol)
    # Need enough positive samples for a stable estimate.
    if len(vols) < 10:
        return 0.0
    arr = np.array(vols)
    return float(arr.std() / (arr.mean() + 1e-8))
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between two batches of paired embeddings.

    Rows of `a` and `b` are positives at matching indices; all other rows in
    the batch serve as negatives. Returns (loss, top-1 a->b retrieval
    accuracy as a float).
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    logits = a_unit @ b_unit.T / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss_ab = F.cross_entropy(logits, labels)
    loss_ba = F.cross_entropy(logits.T, labels)
    loss = 0.5 * (loss_ab + loss_ba)
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc
def whitened_procrustes_loss(emb, targets):
    """Mean-centered cosine alignment loss between embeddings and targets.

    Despite the name, no whitening/rotation is performed: both batches are
    mean-centered, then the loss is 1 - mean cosine similarity of the paired
    centered rows. Returns 0 for batches smaller than 10.
    """
    batch = emb.shape[0]
    if batch < 10:
        return torch.tensor(0.0, device=emb.device)
    emb_centered = emb.float() - emb.float().mean(0, keepdim=True)
    tgt_centered = targets.float() - targets.float().mean(0, keepdim=True)
    cos = F.cosine_similarity(emb_centered, tgt_centered, dim=-1)
    return 1.0 - cos.mean()
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; reshapes the gradient in backward.

    The backward pass (a) damps the radial gradient component (the part
    parallel to the unit-normalized embedding) by factor `tang`, and
    (b) when `sep` > 0, removes `sep` times any gradient component that
    points toward the embedding's nearest anchor. The forward output is
    `x` unchanged, so this only alters optimization dynamics.
    """
    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        # Stash tensors and the two scalar strengths for backward.
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang; ctx.sep = sep
        return x
    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        # Radial part of the gradient: projection onto the unit embedding.
        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        # Keep the tangential part; scale the radial part by (1 - tang).
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            # Find each embedding's nearest anchor (max cosine) and strip
            # the gradient component pushing toward it (only when positive).
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest
        # Gradients flow only to `x`; the other four inputs get None.
        return corrected.to(grad_output.dtype), None, None, None, None
# ══════════════════════════════════════════════════════════════════
# FROZEN SOUP
# ══════════════════════════════════════════════════════════════════
class Constellation(nn.Module):
    """Learnable bank of unit anchors on the D_ANCHOR hypersphere."""
    def __init__(self):
        super().__init__()
        init = F.normalize(torch.randn(N_ANCHORS, D_ANCHOR), dim=-1)
        self.anchors = nn.Parameter(init)
    def triangulate(self, emb):
        """Return (1 - cosine to every anchor, index of the nearest anchor)."""
        unit_anchors = F.normalize(self.anchors, dim=-1)
        sims = emb @ unit_anchors.T
        return 1.0 - sims, sims.argmax(dim=-1)
class Patchwork(nn.Module):
    """Routes anchor-distance features through N_COMP small MLP components.

    Anchors are assigned round-robin to N_COMP groups; each group's distance
    slice feeds its own 2-layer MLP, and the D_COMP-wide outputs are
    concatenated.
    """
    def __init__(self):
        super().__init__()
        self.n_comp = N_COMP
        asgn = torch.arange(N_ANCHORS) % N_COMP
        self.register_buffer("asgn", asgn)
        comps = []
        for k in range(N_COMP):
            in_width = (asgn == k).sum().item()
            comps.append(nn.Sequential(
                nn.Linear(in_width, D_COMP * 2), nn.GELU(),
                nn.Linear(D_COMP * 2, D_COMP), nn.LayerNorm(D_COMP)))
        self.comps = nn.ModuleList(comps)
    def forward(self, tri):
        # Slice the distance vector per group and concatenate the outputs.
        pieces = [self.comps[k](tri[:, self.asgn == k])
                  for k in range(self.n_comp)]
        return torch.cat(pieces, -1)
class FrozenSoup(nn.Module):
    """Consensus head: constellation -> patchwork -> multi-label classifier.

    Loaded from a checkpoint and frozen; supplies classification logits for
    128-d embeddings during training of the new encoder.
    """
    def __init__(self):
        super().__init__()
        self.constellation = Constellation()
        self.patchwork = Patchwork()
        pw_dim = N_COMP * D_COMP
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.0),
            nn.Linear(pw_dim, N_CLASSES))
    def forward(self, emb_128):
        """Return (class logits, anchor distances, nearest-anchor indices)."""
        tri, nearest = self.constellation.triangulate(emb_128)
        pw_feats = self.patchwork(tri)
        head_in = torch.cat([pw_feats, emb_128], -1)
        return self.classifier(head_in), tri, nearest
# ══════════════════════════════════════════════════════════════════
# FROM-SCRATCH ViT ENCODER
# ══════════════════════════════════════════════════════════════════
class GeoLIPViTEncoder(nn.Module):
    """From-scratch ViT mapping images to unit-norm D_ANCHOR-d embeddings.

    Pre-norm transformer over conv-patchified tokens with a CLS token and
    learned positional embeddings; the output is the mean of the patch
    tokens projected and L2-normalized onto the hypersphere.
    """
    def __init__(self):
        super().__init__()
        # Patchify via strided conv: (B,3,H,W) -> (B, D_MODEL, H/P, W/P).
        self.patch_embed = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE,
                                     stride=PATCH_SIZE)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
        self.pos_embed = nn.Parameter(torch.zeros(1, N_PATCHES + 1, D_MODEL))
        self.embed_norm = nn.LayerNorm(D_MODEL)
        self.embed_drop = nn.Dropout(DROPOUT)
        layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL, nhead=N_HEADS, dim_feedforward=D_FF,
            dropout=DROPOUT, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            layer, num_layers=N_LAYERS, enable_nested_tensor=False)
        self.output_proj = nn.Sequential(
            nn.Linear(D_MODEL, D_MODEL), nn.GELU(),
            nn.LayerNorm(D_MODEL),
            nn.Linear(D_MODEL, D_ANCHOR))
        self._init_weights()
    def _init_weights(self):
        """Xavier for linear/conv, unit/zero LayerNorm, trunc-normal tokens."""
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
    def forward(self, pixel_values):
        batch = pixel_values.shape[0]
        tokens = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.embed_drop(self.embed_norm(tokens))
        tokens = self.encoder(tokens)
        # Mean-pool the patch tokens (CLS excluded), project, normalize.
        pooled = tokens[:, 1:, :].mean(dim=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
# ══════════════════════════════════════════════════════════════════
# LOAD SOUP + PRE-COMPUTE TARGETS
# ══════════════════════════════════════════════════════════════════
# Load the frozen soup head from the base-tier checkpoint; only the
# constellation / patchwork / classifier weights are kept.
print(f"\n Loading soup...")
ckpt = torch.load("checkpoints/base_tier_best.pt", map_location="cpu", weights_only=False)
soup = FrozenSoup()
soup_sd = {k: v for k, v in ckpt["state_dict"].items()
           if k.startswith("constellation.") or k.startswith("patchwork.") or k.startswith("classifier.")}
soup.load_state_dict(soup_sd, strict=True)
soup = soup.eval().to(DEVICE)
# Freeze: the soup supplies targets/logits but is never trained here.
for p in soup.parameters():
    p.requires_grad = False
# CV target recorded by the soup run (fallback 0.27 if absent).
consensus_cv = ckpt.get("consensus_cv_128", 0.27)
print(f" Soup: mAP={ckpt['mAP']:.3f} CV_target={consensus_cv:.4f}")
# Rebuild projectors for target generation
class ExpertProjector(nn.Module):
    """Maps a 768-d expert feature to a unit-norm D_ANCHOR-d embedding."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(768, D_ANCHOR), nn.LayerNorm(D_ANCHOR))
    def forward(self, x):
        projected = self.proj(x)
        return F.normalize(projected, dim=-1)
from datasets import load_dataset
# Rebuild the three expert projectors (CLIP / DINOv2 / SigLIP -> 128-d)
# from the same checkpoint; used only to pre-compute consensus targets.
projectors = nn.ModuleList([ExpertProjector() for _ in range(3)])
proj_sd = {k.replace("projectors.", ""): v for k, v in ckpt["state_dict"].items()
           if k.startswith("projectors.")}
projectors.load_state_dict(proj_sd)
projectors = projectors.eval().to(DEVICE)
# Compute (or load cached) 128-d consensus targets + multi-hot labels for
# both splits. Target = L2-normalized mean of the three projected experts.
for split_name, split_key in [("train", "train"), ("val", "val")]:
    cache_path = f"cached_{split_name}_targets.pt"
    if os.path.exists(cache_path):
        cached = torch.load(cache_path, weights_only=False)
        if split_name == "train":
            train_targets = cached["targets"]; train_labels = cached["labels"]
            train_ids = cached["image_ids"]; train_id_map = {iid: i for i, iid in enumerate(train_ids)}
            N_train = len(train_ids)
        else:
            val_targets = cached["targets"]; val_labels = cached["labels"]
            val_ids = cached["image_ids"]; val_id_map = {iid: i for i, iid in enumerate(val_ids)}
            N_val = len(val_ids)
        print(f" {split_name}: loaded cached targets ({len(cached['targets']):,})")
        continue
    print(f" Computing {split_name} targets...")
    # The first expert's dataset fixes the row order (image_id -> row index).
    ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split=split_key)
    ids = ref["image_id"]; N = len(ids)
    id_map = {iid: i for i, iid in enumerate(ids)}
    # Multi-hot COCO labels, shape (N, N_CLASSES).
    labels = torch.zeros(N, N_CLASSES)
    for i, labs in enumerate(ref["labels"]):
        for l in labs:
            if l < N_CLASSES: labels[i, l] = 1.0
    # Gather each expert's 768-d features aligned to the reference order.
    expert_feats = []
    for name in tqdm(EXPERTS, desc=f" Loading {split_name} experts"):
        ds = load_dataset("AbstractPhil/bulk-coco-features", name, split=split_key)
        feats = torch.zeros(N, 768)
        for row in ds:
            if row["image_id"] in id_map:
                feats[id_map[row["image_id"]]] = torch.tensor(row["features"], dtype=torch.float32)
        expert_feats.append(feats)
        del ds
    # Project each expert to 128-d, average, re-normalize -> consensus target.
    targets = torch.zeros(N, D_ANCHOR)
    with torch.no_grad():
        for j in tqdm(range(0, N, 512), desc=f" Fusing {split_name}"):
            end = min(j + 512, N)
            batch = [expert_feats[e][j:end].to(DEVICE) for e in range(3)]
            projected = [projectors[e](batch[e]) for e in range(3)]
            fused = F.normalize(sum(projected) / 3, dim=-1)
            targets[j:end] = fused.cpu()
    torch.save({"targets": targets, "labels": labels, "image_ids": ids}, cache_path)
    print(f" {split_name}: {N:,} targets computed and cached")
    if split_name == "train":
        train_targets = targets; train_labels = labels
        train_ids = ids; train_id_map = id_map; N_train = N
    else:
        val_targets = targets; val_labels = labels
        val_ids = ids; val_id_map = id_map; N_val = N
    del expert_feats; gc.collect()
# Projectors are only needed for target generation; free them now.
del projectors, proj_sd; gc.collect()
train_targets_gpu = train_targets.to(DEVICE)
train_labels_gpu = train_labels.to(DEVICE)
val_targets_gpu = val_targets.to(DEVICE)
# Frozen anchor bank used by the gradient-shaping autograd function.
anchors_frozen = soup.constellation.anchors.detach()
# Image preprocessing
from torchvision import transforms
# ImageNet mean/std normalization; images are resized (not center-cropped)
# to IMAGE_SIZE x IMAGE_SIZE, so aspect ratio is not preserved.
img_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ══════════════════════════════════════════════════════════════════
# PRE-CACHE IMAGES AS TENSORS
# ══════════════════════════════════════════════════════════════════
def cache_images(split_name, split_key, id_map, N):
    """Cache all images of a split as one (N, 3, H, W) float16 tensor on disk.

    Streams the COCO parquet dataset, resizes/normalizes each image through
    `img_transform`, and stores it at the row given by `id_map[image_id]`.
    Rows whose image is missing or fails to decode stay all-zero; downstream
    code masks those rows out. Returns the full tensor (loaded straight from
    `cache_path` when the cache file already exists).
    """
    cache_path = f"cached_{split_name}_images.pt"
    if os.path.exists(cache_path):
        print(f" Loading cached {split_name} images...")
        data = torch.load(cache_path, weights_only=True)
        # BUGFIX: the old expression multiplied total bytes by shape[0] and
        # then divided by shape[0] again, printing TOTAL MB labeled "MB/img".
        # Per-image footprint = total bytes / number of images.
        mb_per_img = data.element_size() * data.nelement() / data.shape[0] / 1e6
        print(f" {split_name}: {data.shape} ({mb_per_img:.1f} MB/img)")
        return data
    print(f" Caching {split_name} images ({N:,})...")
    images = torch.zeros(N, 3, IMAGE_SIZE, IMAGE_SIZE, dtype=torch.float16)
    stream = load_dataset("rafaelpadilla/coco2017", split=split_key,
                          revision="refs/convert/parquet", streaming=True)
    cached = 0
    for row in tqdm(stream, desc=f" Caching {split_name}", total=N):
        iid = row.get("image_id")
        if iid not in id_map:
            continue
        try:
            img = row["image"].convert("RGB")
            tensor = img_transform(img).half()
            images[id_map[iid]] = tensor
            cached += 1
        except Exception:
            # Best-effort: skip undecodable images (their rows remain zero).
            # Narrowed from a bare `except:` so Ctrl-C still interrupts.
            continue
    print(f" Cached {cached}/{N} images")
    torch.save(images, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Saved: {cache_path} ({size_mb:.0f} MB)")
    return images
# Materialize (or load) the pre-resized image tensors for both splits.
train_images = cache_images("train", "train", train_id_map, N_train)
val_images = cache_images("val", "validation", val_id_map, N_val)
# ══════════════════════════════════════════════════════════════════
# BUILD ENCODER
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("BUILD ENCODER")
print(f"{'='*65}")
encoder = GeoLIPViTEncoder().to(DEVICE)
n_params = sum(p.numel() for p in encoder.parameters())
print(f" Architecture: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}")
print(f" Input: {IMAGE_SIZE}Γ—{IMAGE_SIZE} β†’ {N_PATCHES} patches")
print(f" Output: {D_ANCHOR}-d (on hypersphere)")
print(f" Parameters: {n_params:,}")
# ══════════════════════════════════════════════════════════════════
# EVALUATION
# ══════════════════════════════════════════════════════════════════
@torch.no_grad()
def evaluate(encoder, soup, val_images, val_targets, val_labels, desc="Val"):
    """Evaluate encoder + frozen soup head on the cached validation split.

    Rows of `val_images` that are all-zero (images that failed to cache) are
    skipped; their logits/embeddings stay zero and are masked out of the
    embedding metrics. Returns a dict with mAP, macro-F1, R@1, mean cosine
    to the consensus targets, volume CV, number of active anchors, and the
    number of images actually scored.
    """
    encoder.eval()
    N = val_images.shape[0]
    all_logits = torch.zeros(N, N_CLASSES)
    all_embs = torch.zeros(N, D_ANCHOR)
    n_seen = 0
    for j in tqdm(range(0, N, BATCH), desc=f" {desc}", leave=False):
        end = min(j + BATCH, N)
        pixels = val_images[j:end].float().to(DEVICE)
        # Skip zero images (failed to cache)
        mask = pixels.abs().sum(dim=(1, 2, 3)) > 0.1
        if mask.sum() == 0:
            continue
        emb = encoder(pixels[mask])
        logits, _, _ = soup(emb)
        # Scatter the masked batch results back to their global rows.
        rows = torch.nonzero(mask, as_tuple=True)[0] + j
        all_logits[rows] = logits.cpu().float()
        all_embs[rows] = emb.cpu().float()
        # BUGFIX: count evaluated IMAGES, not batches — the caller reports
        # this value against N_val ("seen=x/N_val").
        n_seen += int(mask.sum().item())
    # mAP: per-class average precision over classes with positives.
    v_lab = val_labels
    ap_sum, nv = 0, 0
    for c in range(N_CLASSES):
        if v_lab[:, c].sum() > 0:
            si = all_logits[:, c].argsort(descending=True)
            st = v_lab[:, c][si]
            pak = st.cumsum(0) / torch.arange(1, len(st) + 1).float()
            ap_sum += (pak * st).sum().item() / st.sum().item()
            nv += 1
    mAP = ap_sum / max(nv, 1)
    # Macro-F1 at a fixed 0.5 sigmoid threshold (classes with F1=0 excluded).
    vp = (all_logits.sigmoid() > 0.5).float()
    tp = (vp * v_lab).sum(0)
    fp = (vp * (1 - v_lab)).sum(0)
    fn = ((1 - vp) * v_lab).sum(0)
    pr = tp / (tp + fp + 1e-8)
    rc = tp / (tp + fn + 1e-8)
    f1 = 2 * pr * rc / (pr + rc + 1e-8)
    macro_f1 = f1[f1 > 0].mean().item()
    # Mean cosine similarity to the pre-computed consensus targets.
    valid = all_embs.norm(dim=-1) > 0.1
    v_cos = F.cosine_similarity(
        all_embs[valid], val_targets[valid], dim=-1).mean().item() if valid.sum() > 0 else 0.0
    # Retrieval R@1 against the target bank (only meaningful at scale).
    if valid.sum() > 100:
        sim = all_embs[valid] @ val_targets[valid].T
        r1 = (sim.argmax(-1) == torch.arange(valid.sum())).float().mean().item()
    else:
        r1 = 0.0
    # Distinct soup anchors that are nearest to at least one embedding.
    valid_embs = all_embs[valid].to(DEVICE)
    if valid_embs.shape[0] > 0:
        _, v_nearest = soup.constellation.triangulate(valid_embs)
        n_active = v_nearest.cpu().unique().numel()
    else:
        n_active = 0
    # Coefficient of variation of sampled simplex volumes (geometry health).
    v_cv = cv_metric(valid_embs[:2000]) if valid_embs.shape[0] > 100 else 0.0
    return {
        "mAP": mAP, "f1": macro_f1, "r1": r1, "cos": v_cos,
        "cv": v_cv, "n_active": n_active, "n_seen": n_seen,
    }
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*65}")
print("TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Losses: InfoNCE + MSE + CV + BCE + Procrustes alignment")
print(f" CV target: {consensus_cv:.4f}")
print(f" Images: train={N_train:,} val={N_val:,} (cached as tensors)")
print(f"{'='*65}")
optimizer = torch.optim.Adam(encoder.parameters(), lr=LR)
n_batches = N_train // BATCH
total_steps = n_batches * EPOCHS
# Linear warmup for WARMUP_STEPS, then cosine decay to eta_min.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                       total_iters=WARMUP_STEPS),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - WARMUP_STEPS, 1), eta_min=1e-6)],
    milestones=[WARMUP_STEPS])
# NOTE(review): the autocast below uses bfloat16, for which GradScaler is
# effectively a pass-through; harmless, kept as written.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/geolip_vit_encoder")
best_mAP = 0.0
gs = 0  # global optimizer-step counter
for epoch in range(EPOCHS):
    encoder.train()
    t0 = time.time()
    perm = torch.randperm(N_train)
    # Accumulators
    acc = {"loss": 0, "nce": 0, "mse": 0, "bce": 0, "cv": 0, "align": 0,
           "nce_acc": 0, "n": 0}
    pbar = tqdm(range(0, N_train, BATCH),
                desc=f"E{epoch+1:2d}/{EPOCHS} train", unit="batch")
    for i in pbar:
        idx = perm[i:i+BATCH]
        if len(idx) < 4:
            continue
        pixels = train_images[idx].float().to(DEVICE)
        targets = train_targets_gpu[idx]
        labels = train_labels_gpu[idx]
        # Skip batches with too many zero images
        valid = pixels.abs().sum(dim=(1, 2, 3)) > 0.1
        if valid.sum() < 4:
            continue
        pixels = pixels[valid]
        targets = targets[valid]
        labels = labels[valid]
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            emb = encoder(pixels)
            # Gradient shaping only (forward is identity): damp the radial
            # gradient component (tang=0.01) and strip components pushing
            # toward the nearest frozen anchor (sep=1.0).
            emb = EmbeddingAutograd.apply(emb, emb, anchors_frozen, 0.01, 1.0)
            l_nce, nce_acc = infonce(emb, targets)
            l_mse = F.mse_loss(emb, targets)
            l_cv = cv_loss(emb, target=consensus_cv)
            l_align = whitened_procrustes_loss(emb, targets)
            logits, _, _ = soup(emb)
            l_bce = F.binary_cross_entropy_with_logits(logits, labels)
            # Weighted loss stack; the CV term is deliberately tiny.
            loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_bce
                    + 0.5 * l_align + 0.001 * l_cv)
        scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to true grads.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()
        acc["loss"] += loss.item()
        acc["nce"] += l_nce.item()
        acc["mse"] += l_mse.item()
        acc["bce"] += l_bce.item()
        acc["cv"] += l_cv.item()
        acc["align"] += l_align.item()
        acc["nce_acc"] += nce_acc
        acc["n"] += 1
        gs += 1
        # Tensorboard step logging
        if gs % 50 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/nce", l_nce.item(), gs)
            writer.add_scalar("step/mse", l_mse.item(), gs)
            writer.add_scalar("step/bce", l_bce.item(), gs)
            writer.add_scalar("step/cv", l_cv.item(), gs)
            writer.add_scalar("step/align", l_align.item(), gs)
            writer.add_scalar("step/nce_acc", nce_acc, gs)
            writer.add_scalar("step/lr", scheduler.get_last_lr()[0], gs)
        # Update tqdm
        if acc["n"] % 20 == 0:
            d = acc["n"]
            # NOTE(review): `ordered=True` is not a set_postfix option — it
            # is rendered as an extra "ordered=True" postfix field. Harmless.
            pbar.set_postfix(
                loss=f"{acc['loss']/d:.4f}",
                nce_acc=f"{acc['nce_acc']/d:.3f}",
                cos=f"{1-acc['align']/d:.3f}",
                ordered=True)
    elapsed = time.time() - t0
    d = max(acc["n"], 1)
    print(f" E{epoch+1} train: {elapsed:.0f}s "
          f"loss={acc['loss']/d:.4f} nce={acc['nce']/d:.4f} "
          f"mse={acc['mse']/d:.4f} bce={acc['bce']/d:.4f} "
          f"nce_acc={acc['nce_acc']/d:.3f}")
    # Epoch tensorboard
    writer.add_scalar("epoch/train_loss", acc["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_nce", acc["nce"] / d, epoch + 1)
    writer.add_scalar("epoch/train_mse", acc["mse"] / d, epoch + 1)
    writer.add_scalar("epoch/train_bce", acc["bce"] / d, epoch + 1)
    writer.add_scalar("epoch/train_cv", acc["cv"] / d, epoch + 1)
    writer.add_scalar("epoch/train_align", acc["align"] / d, epoch + 1)
    writer.add_scalar("epoch/train_nce_acc", acc["nce_acc"] / d, epoch + 1)
    # ── Validation ──
    m = evaluate(encoder, soup, val_images, val_targets, val_labels)
    writer.add_scalar("epoch/val_mAP", m["mAP"], epoch + 1)
    writer.add_scalar("epoch/val_F1", m["f1"], epoch + 1)
    writer.add_scalar("epoch/val_R@1", m["r1"], epoch + 1)
    writer.add_scalar("epoch/val_cos", m["cos"], epoch + 1)
    writer.add_scalar("epoch/val_cv", m["cv"], epoch + 1)
    writer.add_scalar("epoch/val_anchors", m["n_active"], epoch + 1)
    mk = ""
    if m["mAP"] > best_mAP:
        best_mAP = m["mAP"]
        # Best-by-mAP checkpoint carries the architecture config for reload.
        torch.save({
            "encoder_state_dict": encoder.state_dict(),
            "config": {"d_model": D_MODEL, "n_heads": N_HEADS,
                       "n_layers": N_LAYERS, "d_ff": D_FF,
                       "patch_size": PATCH_SIZE, "image_size": IMAGE_SIZE,
                       "output_dim": D_ANCHOR},
            "mAP": m["mAP"], "f1": m["f1"], "r1": m["r1"],
            "cos": m["cos"], "cv": m["cv"],
            "epoch": epoch + 1, "n_active": m["n_active"],
            "consensus_cv": consensus_cv,
        }, "checkpoints/geolip_vit_encoder_best.pt")
        mk = " β˜…"
    # Save every epoch checkpoint
    torch.save({
        "encoder_state_dict": encoder.state_dict(),
        "epoch": epoch + 1, "mAP": m["mAP"],
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "gs": gs,
    }, f"checkpoints/geolip_vit_e{epoch+1:02d}.pt")
    print(f" E{epoch+1} val: mAP={m['mAP']:.3f} F1={m['f1']:.3f} "
          f"R@1={m['r1']:.3f} cos={m['cos']:.3f} cv={m['cv']:.4f} "
          f"anchors={m['n_active']}/256 seen={m['n_seen']}/{N_val}{mk}")
writer.close()
print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Encoder: {n_params:,} params (from scratch)")
print(f" Checkpoints saved every epoch in checkpoints/")
print(f" Tensorboard: runs/geolip_vit_encoder")
print(f"\n{'='*65}\nDONE\n{'='*65}")