#!/usr/bin/env python3
"""
CIFAR-10 — Tri-Stream GeoLIP ViT v8
=====================================

Trains a tri-stream ViT on CIFAR-10 using two augmented views per image,
periodic Procrustes updates of the GAL anchors, and a mastery queue.

v7→v8 changes:
  1. GAL_UPDATE_INTERVAL: 50 → 25 (2× more frequent)
  2. GAL_LR: 0.01 → 0.015 (+50% response)
  3. Tracks nce_b and geo_nce_acc separately
  4. stream_b_nce_weight=0.5, geo_nce_weight=0.5

NOTE(review): ``create_tri_stream_vit``, ``MasteryQueue`` and ``SimplexBuffer``
are used by this script but are neither defined nor imported in this file;
they must already be in scope (e.g. from earlier notebook cells) — confirm.
"""
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# ── Architecture ──
NUM_CLASSES = 10
IMG_SIZE = 32
PATCH_SIZE = 4
EMBED_DIM = 384
STREAM_DIM = 192
N_BLOCKS = 9
N_HEADS = 8
OUTPUT_DIM = 256
N_ANCHORS = 128
N_GAL_ANCHORS = 64
N_COMP = 16
D_COMP = 128
ANCHOR_DROP = 0.10
CV_TARGET = 0.22

# ── Loss weights ──
CV_WEIGHT = 0.1
ENABLE_AUTOGRAD = True
AUTOGRAD_TANG = 1.0
AUTOGRAD_SEP = 0.1
LABEL_SMOOTHING = 0.1
INFONCE_WEIGHT = 0.1
BCE_WEIGHT = 1.0
CM_WEIGHT = 0.1
INFONCE_TEMP = 0.07

# ── v8: Stream B + Geo NCE weights ──
STREAM_B_NCE_WEIGHT = 0.5
GEO_NCE_WEIGHT = 0.5

# ── v8: GAL — faster updates, stronger response ──
GAL_UPDATE_INTERVAL = 25      # was 50
GAL_LR = 0.015                # was 0.01 (+50%)
GAL_BUFFER_SIZE = 50000
USE_WHITENED_PROCRUSTES = False

# ── Mastery queue ──
MASTERY_PATIENCE = 50
MASTERY_MARGIN_START = 0.1
MASTERY_MARGIN_END = 0.3
MASTERY_MARGIN_WARMUP = 5000
MASTERY_MIN_SIZE = 1024
MASTERY_MAX_SIZE = 16384
MASTERY_INITIAL_SIZE = 4096
MASTERY_RESIZE_STEP = 2048
MASTERY_RESIZE_COOLDOWN = 5
MASTERY_OVERFIT_THRESH = 3.0

# ── Training ──
BATCH = 256
EPOCHS = 100
LR = 3e-4
WARMUP = 5                    # epochs of linear LR warm-up
GRAD_CLIP = 1.0
V1_CKPT = ""                  # set to checkpoint path for warm start

print("=" * 60)
print("CIFAR-10 — Tri-Stream GeoLIP ViT v8")
print(f" Architecture: {N_BLOCKS}× TriStreamBlock")
print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}×{D_COMP} pw")
print(f" GAL: {N_GAL_ANCHORS} anchors, Procrustes every {GAL_UPDATE_INTERVAL} "
      f"batches (lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES})")
print(f" v8 fixes: uniform hypersphere init, gate_init=1/(2×{N_BLOCKS})")
print(f" v8 fixes: InfoNCE on emb_b (w={STREAM_B_NCE_WEIGHT}) "
      f"+ geo_emb (w={GEO_NCE_WEIGHT})")
print(f" Device: {DEVICE}")
print("=" * 60)


# ══════════════════════════════════════════════════════════════════
# DATA
# ══════════════════════════════════════════════════════════════════

CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)


class DualAugDataset(torch.utils.data.Dataset):
    """Wrap a dataset so every item yields two independent augmentations.

    ``transform`` is applied twice to the same source image, so each item is
    a ``(view1, view2, label)`` triple.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img, label = self.base[i]
        return self.transform(img), self.transform(img), label


aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = DualAugDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False, download=True,
                          transform=val_transform)

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)
print(f" Train: {len(train_ds):,} (two views) Val: {len(val_ds):,}")


# ══════════════════════════════════════════════════════════════════
# BUILD MODEL
# ══════════════════════════════════════════════════════════════════

print("\n Building model...")
model = create_tri_stream_vit(
    num_classes=NUM_CLASSES, img_size=IMG_SIZE, patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM, stream_dim=STREAM_DIM,
    n_blocks=N_BLOCKS, n_heads=N_HEADS,
    output_dim=OUTPUT_DIM, n_anchors=N_ANCHORS,
    n_gal_anchors=N_GAL_ANCHORS,
    n_comp=N_COMP, d_comp=D_COMP,
    anchor_drop=ANCHOR_DROP, cv_target=CV_TARGET, dropout=0.1,
    infonce_temp=INFONCE_TEMP, infonce_weight=INFONCE_WEIGHT,
    bce_weight=BCE_WEIGHT, cm_weight=CM_WEIGHT, cv_weight=CV_WEIGHT,
    autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
    enable_autograd=ENABLE_AUTOGRAD, label_smoothing=LABEL_SMOOTHING,
    stream_b_nce_weight=STREAM_B_NCE_WEIGHT,
    geo_nce_weight=GEO_NCE_WEIGHT,
).to(DEVICE)

if V1_CKPT and os.path.exists(V1_CKPT):
    # weights_only=False unpickles arbitrary objects: only point V1_CKPT at
    # checkpoints you produced yourself.
    ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
    # Accept both the {"state_dict": ...} layout this script saves and a
    # bare state dict.
    state = ckpt.get("state_dict", ckpt)
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f" ✓ Loaded weights: epoch {ckpt.get('epoch', '?')}")
    if missing:
        print(f" New params (expected): {len(missing)}")
    if unexpected:
        # strict=False silently drops these — surface them so a checkpoint
        # from a mismatched architecture is noticed.
        print(f" Ignored checkpoint params: {len(unexpected)}")
else:
    print(" Training from scratch")

total_params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {total_params:,}")


# ══════════════════════════════════════════════════════════════════
# OPTIMIZER + SCHEDULER
# ══════════════════════════════════════════════════════════════════

print(f"\n{'='*60}")
print(f"TRAINING — {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" GAL Procrustes: every {GAL_UPDATE_INTERVAL} batches, "
      f"lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES}")
print(f"{'='*60}")

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * WARMUP
# Per-step schedule: linear warm-up, then cosine decay to eta_min.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])
# Training autocasts to bf16, which does not need loss scaling; the scaler is
# kept for the scale/unscale/step calls in the loop. Enable it only on CUDA —
# GradScaler("cuda") otherwise warns and disables itself on the CPU fallback.
scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE == "cuda"))
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/cifar10_tri_stream_v8")
best_acc = 0.0
gs = 0  # global optimizer-step counter
# ══════════════════════════════════════════════════════════════════
# DIAGNOSTIC HELPERS
# ══════════════════════════════════════════════════════════════════

VAL_DIAG_SAMPLES = 2000   # val embeddings used for the CV / anchor diagnostics
VAL_CV_TRIALS = 200       # random 5-point subsets drawn for the volume CV


@torch.no_grad()
def _simplex_volume_cv(embs, n_trials=VAL_CV_TRIALS):
    """Return std/mean of 4-simplex volumes over random 5-point subsets.

    Each trial picks 5 distinct rows of ``embs`` (one ``randperm`` per trial,
    in order), computes the squared volume with the Cayley–Menger determinant,
    and keeps volumes whose square exceeds 1e-20. Returns 0.0 if ``embs`` has
    fewer than 5 rows or at most 10 volumes survive.
    """
    pts_all = embs.to(DEVICE).float()
    n = pts_all.shape[0]
    if n < 5 or n_trials <= 0:
        return 0.0
    idx = torch.stack([torch.randperm(n)[:5] for _ in range(n_trials)])
    pts = pts_all[idx.to(pts_all.device)]              # (trials, 5, D)
    gram = torch.bmm(pts, pts.transpose(1, 2))          # (trials, 5, 5)
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # Pairwise squared distances; relu clamps tiny negative round-off.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
    cm = torch.zeros(n_trials, 6, 6, device=pts_all.device, dtype=torch.float32)
    cm[:, 0, 1:] = 1
    cm[:, 1:, 0] = 1
    cm[:, 1:, 1:] = d2
    # Cayley–Menger, 4-simplex: V² = -det(CM) / (2⁴ · (4!)²) = -det(CM) / 9216
    vol_sq = -torch.linalg.det(cm) / 9216
    vols = vol_sq[vol_sq > 1e-20].sqrt()
    if vols.numel() <= 10:
        return 0.0
    return (vols.std() / (vols.mean() + 1e-8)).item()


def _mean_gate(gates):
    """Mean of a sequence of per-block gate values (floats or tensors).

    Each tensor entry is reduced with ``.mean()`` first. Returns None when
    ``gates`` is None or empty.
    """
    if gates is None or len(gates) == 0:
        return None
    return float(np.mean(
        [float(torch.as_tensor(g).detach().float().mean()) for g in gates]))


# Mastery queue
mastery = MasteryQueue(
    dim=OUTPUT_DIM, min_size=MASTERY_MIN_SIZE,
    max_size=MASTERY_MAX_SIZE, initial_size=MASTERY_INITIAL_SIZE,
    patience=MASTERY_PATIENCE, device=DEVICE,
    margin_start=MASTERY_MARGIN_START, margin_end=MASTERY_MARGIN_END,
    margin_warmup=MASTERY_MARGIN_WARMUP,
    resize_step=MASTERY_RESIZE_STEP, resize_cooldown=MASTERY_RESIZE_COOLDOWN,
    overfit_threshold=MASTERY_OVERFIT_THRESH)

# GAL simplex buffer
simplex_buf = SimplexBuffer(
    dim=STREAM_DIM, max_size=GAL_BUFFER_SIZE, device=DEVICE)
gal_update_count = 0

# ══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ══════════════════════════════════════════════════════════════════

# Per-batch loss-dict entries summed verbatim into acc_dict.
_SUMMED_KEYS = ("acc_a", "acc_b", "geo_acc", "nce_acc", "nce_b_acc",
                "geo_nce_acc", "cm_valid", "cv_main", "cv_geo")

for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    acc_dict = {
        "loss": 0, "ce": 0, "bce": 0, "geo_bce": 0,
        "acc_a": 0, "acc_b": 0, "geo_acc": 0,
        "nce": 0, "nce_acc": 0,
        "nce_b": 0, "nce_b_acc": 0,
        "geo_nce": 0, "geo_nce_acc": 0,
        "cm": 0, "cm_valid": 0,
        "cv": 0, "cv_main": 0, "cv_geo": 0,
        "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
        "correct": 0, "total": 0, "n": 0}

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="batch")
    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1, apply_autograd=True)
            out2 = model(v2, apply_autograd=True)
            loss, ld = model.compute_loss(
                out1, targets, output_aug=out2, mastery_queue=mastery)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        mastery.check_activation(ld.get('nce_acc', 0))

        pool_geo = out1.get('pool_geo')
        if pool_geo is not None:
            simplex_buf.push(pool_geo.float(), targets)

        gs += 1
        if gs % GAL_UPDATE_INTERVAL == 0 and simplex_buf.size > 500:
            score = model.update_gal_anchors(
                simplex_buf, lr=GAL_LR, whiten=USE_WHITENED_PROCRUSTES)
            if score is not None:
                gal_update_count += 1
                writer.add_scalar("step/procrustes_score", score, gs)

        # Track
        preds = out1['logits_a'].argmax(-1)
        acc_dict["correct"] += (preds == targets).sum().item()
        acc_dict["total"] += targets.shape[0]
        acc_dict["loss"] += loss.item()
        for k in ["ce", "bce", "geo_bce", "nce", "nce_b", "geo_nce",
                  "cm", "cv", "spread", "mastery"]:
            v = ld.get(k, 0)
            acc_dict[k] += v.item() if torch.is_tensor(v) else v
        for k in _SUMMED_KEYS:
            acc_dict[k] += ld.get(k, 0)
        acc_dict["hard_neg"] += ld.get("hard_neg_cos", 0)
        acc_dict["hard_pos"] += ld.get("hard_pos_cos", 0)
        acc_dict["n"] += 1

        if acc_dict["n"] % 10 == 0:
            d = acc_dict["n"]
            ta = 100 * acc_dict["correct"] / acc_dict["total"]
            ga = 100 * acc_dict["geo_acc"] / d
            nb = acc_dict["nce_b_acc"] / d
            stg = "M" if mastery.active else "S1"
            # tqdm.set_postfix has no `ordered` option — passing one would
            # just show up as an extra "ordered=True" field on the bar.
            pbar.set_postfix(
                loss=f"{acc_dict['loss']/d:.4f}", a=f"{ta:.0f}%",
                ga=f"{ga:.0f}%", nb=f"{nb:.2f}", stg=stg,
                gal=gal_update_count)

        if gs % 20 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
            writer.add_scalar("step/nce_b_acc", ld.get("nce_b_acc", 0), gs)
            writer.add_scalar("step/geo_nce_acc", ld.get("geo_nce_acc", 0), gs)
            gate_a = _mean_gate(out1.get('gates_a'))
            if gate_a is not None:
                writer.add_scalar("step/gate_a_mean", gate_a, gs)
                # Average gates_b over its own length, not len(gates_a).
                gate_b = _mean_gate(out1.get('gates_b'))
                writer.add_scalar("step/gate_b_mean",
                                  gate_b if gate_b is not None else 0.0, gs)

    # ── Epoch stats ──
    elapsed = time.time() - t0
    d = max(acc_dict["n"], 1)  # guard against an empty train loader
    train_acc = 100 * acc_dict["correct"] / max(acc_dict["total"], 1)
    writer.add_scalar("epoch/train_loss", acc_dict["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
    writer.add_scalar("epoch/acc_a", 100 * acc_dict["acc_a"] / d, epoch + 1)
    writer.add_scalar("epoch/acc_b", 100 * acc_dict["acc_b"] / d, epoch + 1)
    writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/nce_b_acc", acc_dict["nce_b_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/geo_nce_acc", acc_dict["geo_nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_main", acc_dict["cv_main"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
    writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
    writer.add_scalar("epoch/gal_updates", gal_update_count, epoch + 1)

    # ── Validation ──
    model.eval()
    val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
    val_geo_correct = 0
    val_b_correct = 0
    all_embs, n_embs = [], 0
    gate_probe = None
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for images, labels_v in val_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels_v = labels_v.to(DEVICE, non_blocking=True)
            if gate_probe is None:
                # val_loader is unshuffled with a deterministic transform, so
                # this is a fixed 4-image probe; capturing it here avoids
                # restarting the loader (and its workers) after validation.
                gate_probe = images[:4].clone()
            out = model(images, apply_autograd=False)
            preds = out['logits_a'].argmax(dim=-1)
            val_correct += (preds == labels_v).sum().item()
            val_b_correct += (out['logits_b'].argmax(-1) == labels_v).sum().item()
            val_geo_correct += (out['geo_logits'].argmax(-1) == labels_v).sum().item()
            val_total += labels_v.shape[0]
            loss_v = F.cross_entropy(out['logits_a'], labels_v)
            val_loss_sum += loss_v.item()
            val_n += 1
            # Only the first VAL_DIAG_SAMPLES embeddings feed the diagnostics.
            if n_embs < VAL_DIAG_SAMPLES:
                emb = out['embedding'].float().cpu()
                all_embs.append(emb)
                n_embs += emb.shape[0]

    val_total_safe = max(val_total, 1)
    val_acc = 100 * val_correct / val_total_safe
    val_b_acc = 100 * val_b_correct / val_total_safe
    val_geo_acc = 100 * val_geo_correct / val_total_safe
    val_loss = val_loss_sum / max(val_n, 1)

    # ── Val embedding diagnostics ──
    diag_embs = torch.cat(all_embs)[:VAL_DIAG_SAMPLES]
    v_cv = _simplex_volume_cv(diag_embs)
    with torch.no_grad():
        _, v_np = model.constellation.triangulate(
            diag_embs.to(DEVICE), training=False)
    n_active = v_np.cpu().unique().numel()

    writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
    writer.add_scalar("epoch/val_b_acc", val_b_acc, epoch + 1)
    writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
    writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)

    mastery.update_size(train_acc, val_acc, epoch + 1)

    # ── Checkpoint ──
    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "val_b_acc": val_b_acc,
            "val_geo_acc": val_geo_acc,
            "mastery": mastery.state_dict(),
            "gal_updates": gal_update_count,
        }, "checkpoints/tri_stream_v8_best.pt")
        mk = " ★"

    if (epoch + 1) % 10 == 0:
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "optimizer": optimizer.state_dict(),
        }, f"checkpoints/tri_stream_v8_e{epoch+1:03d}.pt")

    # ── Epoch print — v8: shows B acc + nce_b + geo_nce ──
    ab = 100 * acc_dict["acc_b"] / d
    nb_acc = acc_dict["nce_b_acc"] / d
    gn_acc = acc_dict["geo_nce_acc"] / d
    cvf = acc_dict["cv_main"] / d
    cvg = acc_dict["cv_geo"] / d
    cmv = acc_dict["cm_valid"] / d
    stage = "MASTERY" if mastery.active else "stage1"

    # Gate check — diagnostic only, so a failure is reported but never
    # aborts training.
    gate_mean = None
    if gate_probe is not None:
        try:
            with torch.no_grad():
                probe_out = model(gate_probe, apply_autograd=False)
            gate_mean = _mean_gate(probe_out.get('gates_a', []))
        except Exception as err:
            print(f"  [warn] gate probe failed: {err!r}")
    gate_str = f"g={gate_mean:.4f}" if gate_mean is not None else "g=?"

    print(f" E{epoch+1:3d}: A={train_acc:.1f}% B={ab:.0f}% "
          f"val={val_acc:.1f}%/{val_b_acc:.1f}%/{val_geo_acc:.1f}% "
          f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
          f"nb={nb_acc:.2f} gn={gn_acc:.2f} "
          f"cv={v_cv:.4f}(m={cvf:.5f} g={cvg:.5f}) "
          f"cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
          f"[{stage}] {gate_str} "
          f"gal={gal_update_count} ({elapsed:.0f}s){mk}")

writer.close()
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")