| |
| """ |
| GEOLIP VISION ENCODER β FROM SCRATCH |
| ====================================== |
| From-scratch ViT trained against frozen soup consensus targets. |
| |
| Phase 0: Pre-compute consensus targets from frozen soup |
| Phase 1: Pre-cache all COCO images as tensors (once, then reuse) |
| Phase 2: Train from-scratch ViT with full GeoLIP loss stack |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import os |
| import gc |
| import time |
| import math |
| import numpy as np |
| from tqdm import tqdm |
|
|
| DEVICE = "cuda" |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
| |
| D_MODEL = 384 |
| N_HEADS = 6 |
| N_LAYERS = 6 |
| D_FF = 1536 |
| PATCH_SIZE = 16 |
| IMAGE_SIZE = 224 |
| D_ANCHOR = 128 |
| N_ANCHORS = 256 |
| N_CLASSES = 80 |
| N_COMP = 8 |
| D_COMP = 64 |
| DROPOUT = 0.1 |
|
|
| |
| BATCH = 48 |
| EPOCHS = 20 |
| LR = 3e-4 |
| WARMUP_STEPS = 500 |
| GRAD_CLIP = 1.0 |
|
|
| EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"] |
| N_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2 |
|
|
| print("=" * 65) |
| print("GEOLIP VISION ENCODER β FROM SCRATCH") |
| print(f" ViT: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}") |
| print(f" {N_PATCHES} patches + CLS β {D_ANCHOR}-d output") |
| print(f" Device: {DEVICE}") |
| print("=" * 65) |
|
|
|
|
| |
| |
| |
|
|
def cayley_menger_vol2(pts):
    """Squared volume of the simplex spanned by each batch of V points.

    Uses the Cayley-Menger determinant: for V points the bordered matrix of
    pairwise squared distances yields vol^2 up to a sign/factorial scale.

    pts: (B, V, D) tensor -> returns a (B,) float32 tensor of squared volumes.
    """
    pts = pts.float()
    delta = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    sq_dist = delta.pow(2).sum(dim=-1)  # (B, V, V) pairwise squared distances
    n_batch, n_pts, _ = sq_dist.shape
    # Bordered matrix: zero corner, ones on the first row/column, d^2 inside.
    bordered = torch.zeros(n_batch, n_pts + 1, n_pts + 1,
                           device=sq_dist.device, dtype=torch.float32)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist
    sign = (-1.0) ** n_pts
    fact = math.factorial(n_pts - 1)
    scale = sign / ((2.0 ** (n_pts - 1)) * fact * fact)
    return scale * torch.linalg.det(bordered)
|
|
def cv_loss(emb, target=0.2, n_samples=16):
    """Differentiable penalty on the coefficient of variation (CV) of
    random 5-point simplex volumes, pulled toward `target`.

    Returns 0 when the batch is too small to form a 5-point simplex.
    """
    batch = emb.shape[0]
    if batch < 5:
        return torch.tensor(0.0, device=emb.device)
    samples = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        samples.append(torch.sqrt(F.relu(sq_vol[0]) + 1e-12))
    vols = torch.stack(samples)
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).abs()
|
|
@torch.no_grad()
def cv_metric(emb, n_samples=200):
    """Evaluation-only CV estimate over random 5-point simplex volumes.

    Returns 0.0 when fewer than 5 embeddings are given or fewer than 10
    strictly-positive volumes were sampled.
    """
    batch = emb.shape[0]
    if batch < 5:
        return 0.0
    vols = []
    for _ in range(n_samples):
        pick = torch.randperm(batch, device=emb.device)[:5]
        sq_vol = cayley_menger_vol2(emb[pick].unsqueeze(0))
        vol = torch.sqrt(F.relu(sq_vol[0]) + 1e-12).item()
        if vol > 0:
            vols.append(vol)
    if len(vols) < 10:
        return 0.0
    arr = np.array(vols)
    return float(arr.std() / (arr.mean() + 1e-8))
|
|
def infonce(a, b, temperature=0.07):
    """Symmetric InfoNCE between row-aligned embedding batches.

    Rows of `a` and `b` are positives of each other; all other rows are
    negatives. Returns (loss, accuracy) where accuracy is the fraction of
    rows whose argmax similarity is the matching index.
    """
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = a @ b.T / temperature
    labels = torch.arange(logits.shape[0], device=logits.device)
    loss_ab = F.cross_entropy(logits, labels)
    loss_ba = F.cross_entropy(logits.T, labels)
    loss = (loss_ab + loss_ba) / 2
    with torch.no_grad():
        acc = (logits.argmax(-1) == labels).float().mean().item()
    return loss, acc
|
|
def whitened_procrustes_loss(emb, targets):
    """Alignment loss: 1 - mean cosine similarity between the mean-centered
    embeddings and mean-centered targets.

    Returns 0 for batches smaller than 10 (centering is too noisy there).
    """
    if emb.shape[0] < 10:
        return torch.tensor(0.0, device=emb.device)
    emb_centered = emb.float() - emb.float().mean(0, keepdim=True)
    tgt_centered = targets.float() - targets.float().mean(0, keepdim=True)
    cos = F.cosine_similarity(emb_centered, tgt_centered, dim=-1)
    return 1.0 - cos.mean()
|
|
class EmbeddingAutograd(torch.autograd.Function):
    """Identity in the forward pass; performs gradient surgery in backward.

    Backward (a) scales the radial (norm-changing) gradient component by
    (1 - tang), keeping updates mostly tangential to the hypersphere, and
    (b) when sep > 0, removes the gradient component that pushes an
    embedding toward its nearest anchor (only when it actually points
    toward it), encouraging separation from the frozen anchors.
    """
    @staticmethod
    def forward(ctx, x, embedding, anchors, tang, sep):
        # x passes through untouched; embedding/anchors are stashed for backward.
        ctx.save_for_backward(embedding, anchors)
        ctx.tang = tang; ctx.sep = sep
        return x
    @staticmethod
    def backward(ctx, grad_output):
        embedding, anchors = ctx.saved_tensors
        emb_n = F.normalize(embedding.detach().float(), dim=-1)
        anchors_n = F.normalize(anchors.detach().float(), dim=-1)
        grad_f = grad_output.float()
        # Radial component of the gradient (projection onto the embedding direction).
        radial = (grad_f * emb_n).sum(-1, keepdim=True) * emb_n
        # Keep the tangential part; attenuate the radial part by (1 - tang).
        corrected = (grad_f - radial) + (1.0 - ctx.tang) * radial
        if ctx.sep > 0:
            # Identify each embedding's nearest anchor by cosine similarity.
            cos_to = emb_n @ anchors_n.T
            nearest = anchors_n[cos_to.argmax(dim=-1)]
            toward = (corrected * nearest).sum(-1, keepdim=True)
            # Subtract only the positive (anchor-ward) component, scaled by sep.
            corrected = corrected - ctx.sep * (toward > 0).float() * toward * nearest
        # Gradients flow only to x; embedding/anchors/tang/sep get none.
        return corrected.to(grad_output.dtype), None, None, None, None
|
|
|
|
| |
| |
| |
|
|
class Constellation(nn.Module):
    """Learnable set of N_ANCHORS unit vectors; embeddings are described by
    their cosine distance to every anchor."""
    def __init__(self):
        super().__init__()
        init = F.normalize(torch.randn(N_ANCHORS, D_ANCHOR), dim=-1)
        self.anchors = nn.Parameter(init)

    def triangulate(self, emb):
        """Return (1 - cosine) distances to all anchors and each embedding's
        nearest-anchor index."""
        unit_anchors = F.normalize(self.anchors, dim=-1)
        cos = emb @ unit_anchors.T
        nearest = cos.argmax(dim=-1)
        return 1.0 - cos, nearest
|
|
class Patchwork(nn.Module):
    """Splits the anchor-distance vector into N_COMP round-robin groups and
    encodes each group with a small MLP; the outputs are concatenated."""
    def __init__(self):
        super().__init__()
        self.n_comp = N_COMP
        # Round-robin anchor -> component assignment (anchor i -> i % N_COMP).
        asgn = torch.arange(N_ANCHORS) % N_COMP
        self.register_buffer("asgn", asgn)
        comps = []
        for k in range(N_COMP):
            in_dim = (asgn == k).sum().item()
            comps.append(nn.Sequential(
                nn.Linear(in_dim, D_COMP * 2), nn.GELU(),
                nn.Linear(D_COMP * 2, D_COMP), nn.LayerNorm(D_COMP)))
        self.comps = nn.ModuleList(comps)

    def forward(self, tri):
        """tri: (B, N_ANCHORS) distances -> (B, N_COMP * D_COMP) encoding."""
        pieces = [self.comps[k](tri[:, self.asgn == k])
                  for k in range(self.n_comp)]
        return torch.cat(pieces, -1)
|
|
class FrozenSoup(nn.Module):
    """Pre-trained consensus head: constellation triangulation + patchwork
    encoding feeding a multi-label classifier over N_CLASSES."""
    def __init__(self):
        super().__init__()
        self.constellation = Constellation()
        self.patchwork = Patchwork()
        pw_dim = N_COMP * D_COMP
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + D_ANCHOR, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.0),
            nn.Linear(pw_dim, N_CLASSES))

    def forward(self, emb_128):
        """emb_128: (B, D_ANCHOR) embeddings -> (logits, tri, nearest)."""
        tri, nearest = self.constellation.triangulate(emb_128)
        pw = self.patchwork(tri)
        features = torch.cat([pw, emb_128], -1)
        logits = self.classifier(features)
        return logits, tri, nearest
|
|
|
|
| |
| |
| |
|
|
class GeoLIPViTEncoder(nn.Module):
    """From-scratch ViT that maps images to unit vectors in anchor space.

    224x224 input -> 16x16 patch embedding -> pre-norm transformer encoder
    -> mean-pool over patch tokens -> MLP projection to D_ANCHOR dims,
    L2-normalized onto the hypersphere.
    """
    def __init__(self):
        super().__init__()
        # Non-overlapping patchify: conv with kernel == stride == PATCH_SIZE.
        self.patch_embed = nn.Conv2d(3, D_MODEL, kernel_size=PATCH_SIZE,
                                     stride=PATCH_SIZE)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, D_MODEL))
        self.pos_embed = nn.Parameter(torch.zeros(1, N_PATCHES + 1, D_MODEL))
        self.embed_norm = nn.LayerNorm(D_MODEL)
        self.embed_drop = nn.Dropout(DROPOUT)


        # Pre-norm (norm_first) GELU transformer; nested-tensor fast path
        # disabled for deterministic behavior without padding masks.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=D_MODEL, nhead=N_HEADS, dim_feedforward=D_FF,
            dropout=DROPOUT, activation="gelu", batch_first=True,
            norm_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=N_LAYERS, enable_nested_tensor=False)


        self.output_proj = nn.Sequential(
            nn.Linear(D_MODEL, D_MODEL), nn.GELU(),
            nn.LayerNorm(D_MODEL),
            nn.Linear(D_MODEL, D_ANCHOR))


        self._init_weights()


    def _init_weights(self):
        # Xavier for every linear/conv weight; trunc-normal for the learned
        # tokens (overwriting the zeros they were constructed with).
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)


    def forward(self, pixel_values):
        """pixel_values: (B, 3, IMAGE_SIZE, IMAGE_SIZE) -> (B, D_ANCHOR) unit vectors."""
        B = pixel_values.shape[0]
        # (B, D_MODEL, 14, 14) -> (B, N_PATCHES, D_MODEL)
        x = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.embed_drop(self.embed_norm(x))
        x = self.encoder(x)
        # NOTE(review): pooling averages patch tokens only; the CLS token is
        # carried through the encoder but never read at the output.
        pooled = x[:, 1:, :].mean(dim=1)
        return F.normalize(self.output_proj(pooled), dim=-1)
|
|
|
|
| |
| |
| |
|
|
| print(f"\n Loading soup...") |
| ckpt = torch.load("checkpoints/base_tier_best.pt", map_location="cpu", weights_only=False) |
| soup = FrozenSoup() |
| soup_sd = {k: v for k, v in ckpt["state_dict"].items() |
| if k.startswith("constellation.") or k.startswith("patchwork.") or k.startswith("classifier.")} |
| soup.load_state_dict(soup_sd, strict=True) |
| soup = soup.eval().to(DEVICE) |
| for p in soup.parameters(): |
| p.requires_grad = False |
| consensus_cv = ckpt.get("consensus_cv_128", 0.27) |
| print(f" Soup: mAP={ckpt['mAP']:.3f} CV_target={consensus_cv:.4f}") |
|
|
| |
class ExpertProjector(nn.Module):
    """Projects a frozen expert feature vector onto the unit hypersphere in
    anchor space (linear map + LayerNorm + L2 normalization).

    Args:
        in_dim: dimensionality of the incoming expert features. Defaults to
            768, the width of the cached expert features this script loads;
            parameterized so experts of other widths can reuse the class.
    """
    def __init__(self, in_dim=768):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(in_dim, D_ANCHOR), nn.LayerNorm(D_ANCHOR))

    def forward(self, x):
        """x: (B, in_dim) features -> (B, D_ANCHOR) unit vectors."""
        # L2-normalize so downstream cosine/averaging geometry is well-defined.
        return F.normalize(self.proj(x), dim=-1)
|
|
from datasets import load_dataset


# Expert projectors (one per frozen expert), restored from the same checkpoint.
# They are only needed once, to pre-compute the consensus targets below.
projectors = nn.ModuleList([ExpertProjector() for _ in range(3)])
proj_sd = {k.replace("projectors.", ""): v for k, v in ckpt["state_dict"].items()
           if k.startswith("projectors.")}
projectors.load_state_dict(proj_sd)
projectors = projectors.eval().to(DEVICE)
|
|
# ---- Phase 0: consensus targets per split (computed once, cached to disk) ----
for split_name, split_key in [("train", "train"), ("val", "val")]:
    cache_path = f"cached_{split_name}_targets.pt"
    if os.path.exists(cache_path):
        # Fast path: reuse previously fused targets and labels.
        cached = torch.load(cache_path, weights_only=False)
        if split_name == "train":
            train_targets = cached["targets"]; train_labels = cached["labels"]
            train_ids = cached["image_ids"]; train_id_map = {iid: i for i, iid in enumerate(train_ids)}
            N_train = len(train_ids)
        else:
            val_targets = cached["targets"]; val_labels = cached["labels"]
            val_ids = cached["image_ids"]; val_id_map = {iid: i for i, iid in enumerate(val_ids)}
            N_val = len(val_ids)
        print(f" {split_name}: loaded cached targets ({len(cached['targets']):,})")
        continue


    print(f" Computing {split_name} targets...")
    # The first expert defines the image-id ordering and the multi-hot labels.
    ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split=split_key)
    ids = ref["image_id"]; N = len(ids)
    id_map = {iid: i for i, iid in enumerate(ids)}
    labels = torch.zeros(N, N_CLASSES)
    for i, labs in enumerate(ref["labels"]):
        for l in labs:
            if l < N_CLASSES: labels[i, l] = 1.0


    # Gather each expert's raw 768-d features, row-aligned by image id.
    expert_feats = []
    for name in tqdm(EXPERTS, desc=f" Loading {split_name} experts"):
        ds = load_dataset("AbstractPhil/bulk-coco-features", name, split=split_key)
        feats = torch.zeros(N, 768)
        for row in ds:
            if row["image_id"] in id_map:
                feats[id_map[row["image_id"]]] = torch.tensor(row["features"], dtype=torch.float32)
        expert_feats.append(feats)
        del ds


    # Fuse: project each expert into anchor space, average, re-normalize.
    targets = torch.zeros(N, D_ANCHOR)
    with torch.no_grad():
        for j in tqdm(range(0, N, 512), desc=f" Fusing {split_name}"):
            end = min(j + 512, N)
            batch = [expert_feats[e][j:end].to(DEVICE) for e in range(3)]
            projected = [projectors[e](batch[e]) for e in range(3)]
            fused = F.normalize(sum(projected) / 3, dim=-1)
            targets[j:end] = fused.cpu()


    torch.save({"targets": targets, "labels": labels, "image_ids": ids}, cache_path)
    print(f" {split_name}: {N:,} targets computed and cached")


    if split_name == "train":
        train_targets = targets; train_labels = labels
        train_ids = ids; train_id_map = id_map; N_train = N
    else:
        val_targets = targets; val_labels = labels
        val_ids = ids; val_id_map = id_map; N_val = N
    del expert_feats; gc.collect()


# Projectors were only needed for target computation; free the memory.
del projectors, proj_sd; gc.collect()


# Keep targets/labels resident on GPU for fast batched indexing during training.
train_targets_gpu = train_targets.to(DEVICE)
train_labels_gpu = train_labels.to(DEVICE)
val_targets_gpu = val_targets.to(DEVICE)  # NOTE(review): appears unused below — confirm
anchors_frozen = soup.constellation.anchors.detach()
|
|
| |
# ImageNet-style preprocessing, applied once while caching images to disk.
from torchvision import transforms
img_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
| |
| |
| |
|
|
def cache_images(split_name, split_key, id_map, N):
    """Materialize every image of a COCO split as one fp16 tensor on disk.

    Streams the parquet-converted COCO split, resizes/normalizes each image
    with `img_transform`, and stores it at the row given by `id_map[image_id]`.
    The resulting tensor is saved to `cached_<split>_images.pt` and reloaded
    on subsequent runs.

    Args:
        split_name: label used for the cache filename and logging.
        split_key: HuggingFace split key ("train" / "validation").
        id_map: image_id -> row index for this split.
        N: number of rows (images) expected.

    Returns:
        (N, 3, IMAGE_SIZE, IMAGE_SIZE) float16 tensor; rows for images that
        were missing or failed to decode remain all-zero (masked downstream).
    """
    cache_path = f"cached_{split_name}_images.pt"
    if os.path.exists(cache_path):
        print(f" Loading cached {split_name} images...")
        data = torch.load(cache_path, weights_only=True)
        # Per-image footprint = total bytes / number of images.
        # (Previous formula multiplied and divided by shape[0], printing the
        # TOTAL size while labeling it "MB/img".)
        mb_per_img = data.element_size() * data.nelement() / data.shape[0] / 1e6
        print(f" {split_name}: {data.shape} ({mb_per_img:.1f} MB/img)")
        return data


    print(f" Caching {split_name} images ({N:,})...")
    images = torch.zeros(N, 3, IMAGE_SIZE, IMAGE_SIZE, dtype=torch.float16)
    stream = load_dataset("rafaelpadilla/coco2017", split=split_key,
                          revision="refs/convert/parquet", streaming=True)


    cached = 0
    for row in tqdm(stream, desc=f" Caching {split_name}", total=N):
        iid = row.get("image_id")
        if iid not in id_map:
            continue
        try:
            img = row["image"].convert("RGB")
            tensor = img_transform(img).half()
            images[id_map[iid]] = tensor
            cached += 1
        except Exception:
            # Best-effort: skip undecodable images. Narrowed from a bare
            # `except:` so Ctrl-C / SystemExit still propagate.
            continue


    print(f" Cached {cached}/{N} images")
    torch.save(images, cache_path)
    size_mb = os.path.getsize(cache_path) / 1e6
    print(f" Saved: {cache_path} ({size_mb:.0f} MB)")
    return images
|
|
| train_images = cache_images("train", "train", train_id_map, N_train) |
| val_images = cache_images("val", "validation", val_id_map, N_val) |
|
|
|
|
| |
| |
| |
|
|
# ---- Instantiate the from-scratch encoder and report its footprint ----
print(f"\n{'='*65}")
print("BUILD ENCODER")
print(f"{'='*65}")


encoder = GeoLIPViTEncoder().to(DEVICE)
n_params = sum(p.numel() for p in encoder.parameters())
print(f" Architecture: {N_LAYERS}L/{D_MODEL}d/{N_HEADS}h, patch{PATCH_SIZE}")
print(f" Input: {IMAGE_SIZE}Γ{IMAGE_SIZE} β {N_PATCHES} patches")
print(f" Output: {D_ANCHOR}-d (on hypersphere)")
print(f" Parameters: {n_params:,}")
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate(encoder, soup, val_images, val_targets, val_labels, desc="Val"):
    """Run the encoder over the whole validation split and score it.

    Skips all-zero image rows (images that failed to cache). Returns a dict
    with multi-label mAP, macro-F1, retrieval R@1 against consensus targets,
    mean cosine to targets, volume-CV, active anchor count and images seen.
    """
    encoder.eval()
    N = val_images.shape[0]
    all_logits = torch.zeros(N, N_CLASSES)
    all_embs = torch.zeros(N, D_ANCHOR)
    n_seen = 0


    for j in tqdm(range(0, N, BATCH), desc=f" {desc}", leave=False):
        end = min(j + BATCH, N)
        pixels = val_images[j:end].float().to(DEVICE)
        # Rows that are (near) all-zero were never cached; mask them out.
        mask = pixels.abs().sum(dim=(1, 2, 3)) > 0.1
        if mask.sum() == 0:
            continue


        emb = encoder(pixels[mask])
        logits, _, nearest = soup(emb)


        # Scatter the compacted (masked) outputs back to original row positions.
        k = 0
        for idx in range(j, end):
            if idx - j < len(mask) and mask[idx - j]:
                all_logits[idx] = logits[k].cpu().float()
                all_embs[idx] = emb[k].cpu().float()
                k += 1
                n_seen += 1


    # mAP averaged over classes that have at least one positive label.
    v_lab = val_labels
    ap_sum, nv = 0, 0
    for c in range(N_CLASSES):
        if v_lab[:, c].sum() > 0:
            si = all_logits[:, c].argsort(descending=True)
            st = v_lab[:, c][si]
            # Precision@k over the ranked list, averaged at positive positions.
            pak = st.cumsum(0) / torch.arange(1, len(st) + 1).float()
            ap_sum += (pak * st).sum().item() / st.sum().item(); nv += 1
    mAP = ap_sum / max(nv, 1)


    # Macro F1 at a fixed 0.5 sigmoid threshold (zero-F1 classes excluded).
    vp = (all_logits.sigmoid() > 0.5).float()
    tp = (vp * v_lab).sum(0); fp = (vp * (1 - v_lab)).sum(0)
    fn = ((1 - vp) * v_lab).sum(0)
    pr = tp / (tp + fp + 1e-8); rc = tp / (tp + fn + 1e-8)
    f1 = 2 * pr * rc / (pr + rc + 1e-8)
    macro_f1 = f1[f1 > 0].mean().item()


    # Mean cosine similarity between each seen embedding and its consensus target.
    valid = all_embs.norm(dim=-1) > 0.1
    v_cos = F.cosine_similarity(
        all_embs[valid], val_targets[valid], dim=-1).mean().item() if valid.sum() > 0 else 0.0


    # Retrieval R@1: does each embedding rank its own target first?
    if valid.sum() > 100:
        sim = all_embs[valid] @ val_targets[valid].T
        r1 = (sim.argmax(-1) == torch.arange(valid.sum())).float().mean().item()
    else:
        r1 = 0.0


    # How many distinct anchors are some embedding's nearest anchor.
    valid_embs = all_embs[valid].to(DEVICE)
    if valid_embs.shape[0] > 0:
        _, v_nearest = soup.constellation.triangulate(valid_embs)
        n_active = v_nearest.cpu().unique().numel()
    else:
        n_active = 0


    # Volume-CV estimated on (up to) the first 2000 valid embeddings.
    v_cv = cv_metric(valid_embs[:2000]) if valid_embs.shape[0] > 100 else 0.0


    return {
        "mAP": mAP, "f1": macro_f1, "r1": r1, "cos": v_cos,
        "cv": v_cv, "n_active": n_active, "n_seen": n_seen,
    }
|
|
|
|
| |
| |
| |
|
|
# ---- Optimizer, LR schedule, AMP and logging setup ----
print(f"\n{'='*65}")
print("TRAINING")
print(f" {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" Losses: InfoNCE + MSE + CV + BCE + Procrustes alignment")
print(f" CV target: {consensus_cv:.4f}")
print(f" Images: train={N_train:,} val={N_val:,} (cached as tensors)")
print(f"{'='*65}")


optimizer = torch.optim.Adam(encoder.parameters(), lr=LR)
n_batches = N_train // BATCH
total_steps = n_batches * EPOCHS
# Linear warmup for WARMUP_STEPS, then cosine decay down to eta_min.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                       total_iters=WARMUP_STEPS),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - WARMUP_STEPS, 1), eta_min=1e-6)],
    milestones=[WARMUP_STEPS])


# NOTE(review): autocast below runs in bf16, where loss scaling is generally
# unnecessary; the scaler is kept for the unified step API — confirm intent.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)


from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("runs/geolip_vit_encoder")
best_mAP = 0.0
gs = 0  # global optimizer-step counter (used for step-level logging)
|
|
# ---- Phase 2: main training loop ----
for epoch in range(EPOCHS):
    encoder.train()
    t0 = time.time()
    perm = torch.randperm(N_train)  # fresh shuffle each epoch


    # Running sums for epoch-level averages ("n" counts contributing batches).
    acc = {"loss": 0, "nce": 0, "mse": 0, "bce": 0, "cv": 0, "align": 0,
           "nce_acc": 0, "n": 0}


    pbar = tqdm(range(0, N_train, BATCH),
                desc=f"E{epoch+1:2d}/{EPOCHS} train", unit="batch")
    for i in pbar:
        idx = perm[i:i+BATCH]
        if len(idx) < 4:
            continue


        pixels = train_images[idx].float().to(DEVICE)
        targets = train_targets_gpu[idx]
        labels = train_labels_gpu[idx]


        # Drop all-zero rows (images that failed to cache); need >= 4 left.
        valid = pixels.abs().sum(dim=(1, 2, 3)) > 0.1
        if valid.sum() < 4:
            continue
        pixels = pixels[valid]
        targets = targets[valid]
        labels = labels[valid]


        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            emb = encoder(pixels)
            # Gradient surgery (forward is identity): damp the radial
            # gradient (tang=0.01) and push away from the nearest frozen
            # anchor (sep=1.0).
            emb = EmbeddingAutograd.apply(emb, emb, anchors_frozen, 0.01, 1.0)


            l_nce, nce_acc = infonce(emb, targets)
            l_mse = F.mse_loss(emb, targets)
            l_cv = cv_loss(emb, target=consensus_cv)
            l_align = whitened_procrustes_loss(emb, targets)


            # Multi-label classification through the frozen soup head.
            logits, _, _ = soup(emb)
            l_bce = F.binary_cross_entropy_with_logits(logits, labels)


            # Fixed loss weights: contrastive-dominant, CV lightly regularized.
            loss = (1.0 * l_nce + 0.5 * l_mse + 0.3 * l_bce
                    + 0.5 * l_align + 0.001 * l_cv)


        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping raw grad norms
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()


        acc["loss"] += loss.item()
        acc["nce"] += l_nce.item()
        acc["mse"] += l_mse.item()
        acc["bce"] += l_bce.item()
        acc["cv"] += l_cv.item()
        acc["align"] += l_align.item()
        acc["nce_acc"] += nce_acc
        acc["n"] += 1
        gs += 1


        # Step-level tensorboard logging every 50 optimizer steps.
        if gs % 50 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/nce", l_nce.item(), gs)
            writer.add_scalar("step/mse", l_mse.item(), gs)
            writer.add_scalar("step/bce", l_bce.item(), gs)
            writer.add_scalar("step/cv", l_cv.item(), gs)
            writer.add_scalar("step/align", l_align.item(), gs)
            writer.add_scalar("step/nce_acc", nce_acc, gs)
            writer.add_scalar("step/lr", scheduler.get_last_lr()[0], gs)


        # Refresh progress-bar averages every 20 counted batches.
        if acc["n"] % 20 == 0:
            d = acc["n"]
            pbar.set_postfix(
                loss=f"{acc['loss']/d:.4f}",
                nce_acc=f"{acc['nce_acc']/d:.3f}",
                cos=f"{1-acc['align']/d:.3f}",
                # NOTE(review): `ordered=True` becomes a literal postfix
                # field in tqdm; likely unintended — confirm.
                ordered=True)


    elapsed = time.time() - t0
    d = max(acc["n"], 1)
    print(f" E{epoch+1} train: {elapsed:.0f}s "
          f"loss={acc['loss']/d:.4f} nce={acc['nce']/d:.4f} "
          f"mse={acc['mse']/d:.4f} bce={acc['bce']/d:.4f} "
          f"nce_acc={acc['nce_acc']/d:.3f}")


    # Epoch-level training metrics.
    writer.add_scalar("epoch/train_loss", acc["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_nce", acc["nce"] / d, epoch + 1)
    writer.add_scalar("epoch/train_mse", acc["mse"] / d, epoch + 1)
    writer.add_scalar("epoch/train_bce", acc["bce"] / d, epoch + 1)
    writer.add_scalar("epoch/train_cv", acc["cv"] / d, epoch + 1)
    writer.add_scalar("epoch/train_align", acc["align"] / d, epoch + 1)
    writer.add_scalar("epoch/train_nce_acc", acc["nce_acc"] / d, epoch + 1)


    # Full validation pass after every epoch.
    m = evaluate(encoder, soup, val_images, val_targets, val_labels)


    writer.add_scalar("epoch/val_mAP", m["mAP"], epoch + 1)
    writer.add_scalar("epoch/val_F1", m["f1"], epoch + 1)
    writer.add_scalar("epoch/val_R@1", m["r1"], epoch + 1)
    writer.add_scalar("epoch/val_cos", m["cos"], epoch + 1)
    writer.add_scalar("epoch/val_cv", m["cv"], epoch + 1)
    writer.add_scalar("epoch/val_anchors", m["n_active"], epoch + 1)


    # Save a best-by-mAP checkpoint with the architecture config embedded.
    mk = ""
    if m["mAP"] > best_mAP:
        best_mAP = m["mAP"]
        torch.save({
            "encoder_state_dict": encoder.state_dict(),
            "config": {"d_model": D_MODEL, "n_heads": N_HEADS,
                       "n_layers": N_LAYERS, "d_ff": D_FF,
                       "patch_size": PATCH_SIZE, "image_size": IMAGE_SIZE,
                       "output_dim": D_ANCHOR},
            "mAP": m["mAP"], "f1": m["f1"], "r1": m["r1"],
            "cos": m["cos"], "cv": m["cv"],
            "epoch": epoch + 1, "n_active": m["n_active"],
            "consensus_cv": consensus_cv,
        }, "checkpoints/geolip_vit_encoder_best.pt")
        mk = " β"


    # Per-epoch resumable checkpoint (optimizer / scheduler / scaler state).
    torch.save({
        "encoder_state_dict": encoder.state_dict(),
        "epoch": epoch + 1, "mAP": m["mAP"],
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
        "gs": gs,
    }, f"checkpoints/geolip_vit_e{epoch+1:02d}.pt")


    print(f" E{epoch+1} val: mAP={m['mAP']:.3f} F1={m['f1']:.3f} "
          f"R@1={m['r1']:.3f} cos={m['cos']:.3f} cv={m['cv']:.4f} "
          f"anchors={m['n_active']}/256 seen={m['n_seen']}/{N_val}{mk}")
|
|
# ---- Final summary ----
writer.close()


print(f"\n Best mAP: {best_mAP:.3f}")
print(f" Encoder: {n_params:,} params (from scratch)")
print(f" Checkpoints saved every epoch in checkpoints/")
print(f" Tensorboard: runs/geolip_vit_encoder")
print(f"\n{'='*65}\nDONE\n{'='*65}")