"""
CIFAR-10 — Tri-Stream GeoLIP ViT v8
=====================================
v7 → v8 changes:
1. GAL_UPDATE_INTERVAL: 50 → 25 (2× more frequent)
2. GAL_LR: 0.01 → 0.015 (+50% response)
3. Tracks nce_b and geo_nce_acc separately
4. stream_b_nce_weight=0.5, geo_nce_weight=0.5
"""

import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from tqdm import tqdm

# NOTE(review): create_tri_stream_vit, MasteryQueue and SimplexBuffer are used
# below but no import for them is visible in this chunk — confirm they are
# imported/defined elsewhere in the file.

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# TF32 speeds up fp32 matmul/conv on Ampere+ GPUs with negligible accuracy cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
| |
# ---- Architecture ----
NUM_CLASSES = 10
IMG_SIZE = 32
PATCH_SIZE = 4          # 32/4 -> 8x8 = 64 patches
EMBED_DIM = 384
STREAM_DIM = 192
N_BLOCKS = 9
N_HEADS = 8
OUTPUT_DIM = 256        # dimensionality of the sphere embedding
N_ANCHORS = 128
N_GAL_ANCHORS = 64
N_COMP = 16
D_COMP = 128
ANCHOR_DROP = 0.10
CV_TARGET = 0.22

# ---- Loss weights / regularization ----
CV_WEIGHT = 0.1
ENABLE_AUTOGRAD = True
AUTOGRAD_TANG = 1.0
AUTOGRAD_SEP = 0.1
LABEL_SMOOTHING = 0.1
INFONCE_WEIGHT = 0.1
BCE_WEIGHT = 1.0
CM_WEIGHT = 0.1
INFONCE_TEMP = 0.07

# ---- v8: separate InfoNCE weights for stream B and the geometric embedding ----
STREAM_B_NCE_WEIGHT = 0.5
GEO_NCE_WEIGHT = 0.5

# ---- GAL (Procrustes anchor alignment) ----
GAL_UPDATE_INTERVAL = 25        # v8: was 50 — align twice as often
GAL_LR = 0.015                  # v8: was 0.01
GAL_BUFFER_SIZE = 50000
USE_WHITENED_PROCRUSTES = False

# ---- Mastery queue (curriculum) ----
MASTERY_PATIENCE = 50
MASTERY_MARGIN_START = 0.1
MASTERY_MARGIN_END = 0.3
MASTERY_MARGIN_WARMUP = 5000
MASTERY_MIN_SIZE = 1024
MASTERY_MAX_SIZE = 16384
MASTERY_INITIAL_SIZE = 4096
MASTERY_RESIZE_STEP = 2048
MASTERY_RESIZE_COOLDOWN = 5
MASTERY_OVERFIT_THRESH = 3.0

# ---- Optimization ----
BATCH = 256
EPOCHS = 100
LR = 3e-4
WARMUP = 5                      # warmup epochs (scheduler steps per batch)
GRAD_CLIP = 1.0
V1_CKPT = ""                    # optional checkpoint path to warm-start from
|
# Startup banner summarizing the v8 configuration.
print("=" * 60)
print("CIFAR-10 — Tri-Stream GeoLIP ViT v8")
print(f" Architecture: {N_BLOCKS}× TriStreamBlock")
print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}×{D_COMP} pw")
print(f" GAL: {N_GAL_ANCHORS} anchors, Procrustes every {GAL_UPDATE_INTERVAL} "
      f"batches (lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES})")
print(f" v8 fixes: uniform hypersphere init, gate_init=1/(2×{N_BLOCKS})")
print(f" v8 fixes: InfoNCE on emb_b (w={STREAM_B_NCE_WEIGHT}) "
      f"+ geo_emb (w={GEO_NCE_WEIGHT})")
print(f" Device: {DEVICE}")
print("=" * 60)
|
|
| |
| |
| |
|
|
# Per-channel normalization statistics for CIFAR-10 (RGB).
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
|
|
class DualAugDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each item yields TWO independent augmentations.

    Returns (view1, view2, label); the two views come from applying
    ``transform`` twice to the same base image, which feeds the
    contrastive (InfoNCE) losses downstream.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img, label = self.base[i]
        # Two separate transform calls -> two (stochastically) different views.
        return self.transform(img), self.transform(img), label
|
|
# Training uses heavier augmentation; validation only normalizes.
aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# raw_train keeps PIL images; DualAugDataset applies aug_transform twice.
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = DualAugDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)

train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)

print(f" Train: {len(train_ds):,} (two views) Val: {len(val_ds):,}")
|
|
| |
| |
| |
|
|
print(f"\n Building model...")
# NOTE(review): create_tri_stream_vit is not imported anywhere in this chunk —
# confirm it is defined/imported elsewhere in the file.
model = create_tri_stream_vit(
    num_classes=NUM_CLASSES, img_size=IMG_SIZE, patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM, stream_dim=STREAM_DIM, n_blocks=N_BLOCKS,
    n_heads=N_HEADS, output_dim=OUTPUT_DIM,
    n_anchors=N_ANCHORS, n_gal_anchors=N_GAL_ANCHORS,
    n_comp=N_COMP, d_comp=D_COMP,
    anchor_drop=ANCHOR_DROP, cv_target=CV_TARGET,
    dropout=0.1, infonce_temp=INFONCE_TEMP,
    infonce_weight=INFONCE_WEIGHT, bce_weight=BCE_WEIGHT,
    cm_weight=CM_WEIGHT, cv_weight=CV_WEIGHT,
    autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
    enable_autograd=ENABLE_AUTOGRAD,
    label_smoothing=LABEL_SMOOTHING,
    stream_b_nce_weight=STREAM_B_NCE_WEIGHT,
    geo_nce_weight=GEO_NCE_WEIGHT,
).to(DEVICE)

# Optional warm start from an earlier checkpoint. strict=False lets new v8
# parameters stay freshly initialized (reported as "missing").
if V1_CKPT and os.path.exists(V1_CKPT):
    ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
    missing, unexpected = model.load_state_dict(
        ckpt["state_dict"], strict=False)
    print(f" ✓ Loaded weights: epoch {ckpt.get('epoch', '?')}")
    if missing:
        print(f"   New params (expected): {len(missing)}")
else:
    print(f" Training from scratch")

total_params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {total_params:,}")
|
|
| |
| |
| |
|
|
print(f"\n{'='*60}")
print(f"TRAINING — {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" GAL Procrustes: every {GAL_UPDATE_INTERVAL} batches, "
      f"lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES}")
print(f"{'='*60}")

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# Per-BATCH LR schedule: linear warmup for WARMUP epochs, then cosine decay
# to eta_min. scheduler.step() is called once per batch in the train loop.
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * WARMUP
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])

# NOTE(review): autocast below uses bfloat16, which shares fp32's exponent
# range and does not need loss scaling — GradScaler is harmless here but
# could be dropped; kept to preserve existing behavior.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/cifar10_tri_stream_v8")
best_acc = 0.0
gs = 0  # global step (batch) counter
|
|
| |
# Curriculum queue of hard examples; size and margin are annealed during
# training. NOTE(review): MasteryQueue and SimplexBuffer are not imported in
# this chunk — confirm they come from elsewhere in the file.
mastery = MasteryQueue(
    dim=OUTPUT_DIM, min_size=MASTERY_MIN_SIZE, max_size=MASTERY_MAX_SIZE,
    initial_size=MASTERY_INITIAL_SIZE, patience=MASTERY_PATIENCE,
    device=DEVICE, margin_start=MASTERY_MARGIN_START,
    margin_end=MASTERY_MARGIN_END, margin_warmup=MASTERY_MARGIN_WARMUP,
    resize_step=MASTERY_RESIZE_STEP, resize_cooldown=MASTERY_RESIZE_COOLDOWN,
    overfit_threshold=MASTERY_OVERFIT_THRESH)

# Replay buffer of pooled geometric features used for GAL Procrustes updates.
simplex_buf = SimplexBuffer(
    dim=STREAM_DIM, max_size=GAL_BUFFER_SIZE, device=DEVICE)

gal_update_count = 0  # how many Procrustes updates actually ran
|
|
| |
| |
| |
|
|
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()

    # Running epoch sums; "n" counts batches, "total" counts samples.
    acc_dict = {
        "loss": 0, "ce": 0, "bce": 0, "geo_bce": 0,
        "acc_a": 0, "acc_b": 0, "geo_acc": 0,
        "nce": 0, "nce_acc": 0,
        "nce_b": 0, "nce_b_acc": 0,
        "geo_nce": 0, "geo_nce_acc": 0,
        "cm": 0, "cm_valid": 0, "cv": 0, "cv_main": 0, "cv_geo": 0,
        "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
        "correct": 0, "total": 0, "n": 0}

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}",
                unit="batch")

    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Two augmented views of the same batch feed the InfoNCE terms.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1, apply_autograd=True)
            out2 = model(v2, apply_autograd=True)
            loss, ld = model.compute_loss(
                out1, targets, output_aug=out2, mastery_queue=mastery)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale before clipping raw grads
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()  # LR schedule advances per batch

        mastery.check_activation(ld.get('nce_acc', 0))

        # Feed the GAL buffer with pooled geometric features (if exposed).
        pool_geo = out1.get('pool_geo')
        if pool_geo is not None:
            simplex_buf.push(pool_geo.float(), targets)

        gs += 1
        # Periodic Procrustes alignment of GAL anchors (v8: every 25 batches),
        # once the buffer has enough samples to be meaningful.
        if gs % GAL_UPDATE_INTERVAL == 0 and simplex_buf.size > 500:
            score = model.update_gal_anchors(
                simplex_buf, lr=GAL_LR, whiten=USE_WHITENED_PROCRUSTES)
            if score is not None:
                gal_update_count += 1
                writer.add_scalar("step/procrustes_score", score, gs)

        # ---- per-batch bookkeeping ----
        preds = out1['logits_a'].argmax(-1)
        acc_dict["correct"] += (preds == targets).sum().item()
        acc_dict["total"] += targets.shape[0]
        acc_dict["loss"] += loss.item()

        for k in ["ce", "bce", "geo_bce", "nce", "nce_b", "geo_nce",
                  "cm", "cv", "spread", "mastery"]:
            v = ld.get(k, 0)
            acc_dict[k] += v.item() if torch.is_tensor(v) else v

        acc_dict["acc_a"] += ld.get("acc_a", 0)
        acc_dict["acc_b"] += ld.get("acc_b", 0)
        acc_dict["geo_acc"] += ld.get("geo_acc", 0)
        acc_dict["nce_acc"] += ld.get("nce_acc", 0)
        acc_dict["nce_b_acc"] += ld.get("nce_b_acc", 0)
        acc_dict["geo_nce_acc"] += ld.get("geo_nce_acc", 0)
        acc_dict["cm_valid"] += ld.get("cm_valid", 0)
        acc_dict["cv_main"] += ld.get("cv_main", 0)
        acc_dict["cv_geo"] += ld.get("cv_geo", 0)
        acc_dict["hard_neg"] += ld.get("hard_neg_cos", 0)
        acc_dict["hard_pos"] += ld.get("hard_pos_cos", 0)
        acc_dict["n"] += 1

        if acc_dict["n"] % 10 == 0:
            d = acc_dict["n"]
            ta = 100 * acc_dict["correct"] / acc_dict["total"]
            ga = 100 * acc_dict["geo_acc"] / d
            nb = acc_dict["nce_b_acc"] / d
            stg = "M" if mastery.active else "S1"
            # FIX: tqdm.set_postfix has no "ordered" parameter — the original
            # ordered=True was rendered as a bogus "ordered=True" stat; removed.
            pbar.set_postfix(
                loss=f"{acc_dict['loss']/d:.4f}",
                a=f"{ta:.0f}%",
                ga=f"{ga:.0f}%",
                nb=f"{nb:.2f}",
                stg=stg,
                gal=gal_update_count)

        if gs % 20 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
            writer.add_scalar("step/nce_b_acc", ld.get("nce_b_acc", 0), gs)
            writer.add_scalar("step/geo_nce_acc", ld.get("geo_nce_acc", 0), gs)
            gates_a = out1.get('gates_a', [])
            if gates_a:
                writer.add_scalar("step/gate_a_mean",
                                  sum(gates_a) / len(gates_a), gs)
                # NOTE(review): gates_b mean is normalized by len(gates_a);
                # presumably both lists have equal length — confirm.
                writer.add_scalar(
                    "step/gate_b_mean",
                    sum(out1.get('gates_b', [0])) / max(len(gates_a), 1), gs)

    # ---- epoch summary (train side) ----
    elapsed = time.time() - t0
    d = acc_dict["n"]
    train_acc = 100 * acc_dict["correct"] / acc_dict["total"]

    writer.add_scalar("epoch/train_loss", acc_dict["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
    writer.add_scalar("epoch/acc_a", 100 * acc_dict["acc_a"] / d, epoch + 1)
    writer.add_scalar("epoch/acc_b", 100 * acc_dict["acc_b"] / d, epoch + 1)
    writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/nce_b_acc", acc_dict["nce_b_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/geo_nce_acc", acc_dict["geo_nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_main", acc_dict["cv_main"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
    writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
    writer.add_scalar("epoch/gal_updates", gal_update_count, epoch + 1)

    # ---- validation ----
    model.eval()
    val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
    val_geo_correct = 0
    val_b_correct = 0
    all_embs = []

    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for images, labels_v in val_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels_v = labels_v.to(DEVICE, non_blocking=True)
            out = model(images, apply_autograd=False)
            preds = out['logits_a'].argmax(dim=-1)
            val_correct += (preds == labels_v).sum().item()
            val_b_correct += (out['logits_b'].argmax(-1) == labels_v).sum().item()
            val_geo_correct += (out['geo_logits'].argmax(-1) == labels_v).sum().item()
            val_total += labels_v.shape[0]
            loss_v = F.cross_entropy(out['logits_a'], labels_v)
            val_loss_sum += loss_v.item()
            val_n += 1
            all_embs.append(out['embedding'].float().cpu())

    val_acc = 100 * val_correct / val_total
    val_b_acc = 100 * val_b_correct / val_total
    val_geo_acc = 100 * val_geo_correct / val_total
    val_loss = val_loss_sum / max(val_n, 1)

    # ---- embedding-volume CV: sample 200 random 5-point subsets and measure
    # the coefficient of variation of their 4-simplex volumes ----
    embs = torch.cat(all_embs)
    with torch.no_grad():
        sample = embs[:2000].to(DEVICE)
        vols = []
        for _ in range(200):
            idx = torch.randperm(2000)[:5]
            pts = sample[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            # Pairwise squared distances; relu clamps tiny fp negatives.
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            # Cayley–Menger matrix for 5 points (4-simplex):
            # V^2 = -det(CM) / (2^4 * (4!)^2) = -det(CM) / 9216
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2
            vol_sq = -torch.linalg.det(cm) / 9216  # renamed from v2 (shadowed loop var)
            if vol_sq[0].item() > 1e-20:
                vols.append(vol_sq[0].sqrt())
        v_cv = (torch.stack(vols).std()
                / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0.0

    # Count how many anchors the constellation actually assigns on val data.
    with torch.no_grad():
        _, v_np = model.constellation.triangulate(
            embs[:2000].to(DEVICE), training=False)
        n_active = v_np.cpu().unique().numel()

    writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
    writer.add_scalar("epoch/val_b_acc", val_b_acc, epoch + 1)
    writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
    writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)

    mastery.update_size(train_acc, val_acc, epoch + 1)

    # ---- checkpointing ----
    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "val_b_acc": val_b_acc,
            "val_geo_acc": val_geo_acc,
            "mastery": mastery.state_dict(),
            "gal_updates": gal_update_count,
        }, "checkpoints/tri_stream_v8_best.pt")
        mk = " ✓"  # best-so-far marker appended to the epoch line

    if (epoch + 1) % 10 == 0:
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "optimizer": optimizer.state_dict(),
        }, f"checkpoints/tri_stream_v8_e{epoch+1:03d}.pt")

    # ---- epoch report ----
    ga = 100 * acc_dict["geo_acc"] / d
    ab = 100 * acc_dict["acc_b"] / d
    nb_acc = acc_dict["nce_b_acc"] / d
    gn_acc = acc_dict["geo_nce_acc"] / d
    cvf = acc_dict["cv_main"] / d
    cvg = acc_dict["cv_geo"] / d
    cmv = acc_dict["cm_valid"] / d
    stage = "MASTERY" if mastery.active else "stage1"

    # Best-effort probe of current gate values on a few val images.
    last_gates = []
    try:
        model.eval()
        with torch.no_grad():
            sample_imgs = next(iter(val_loader))[0][:4].to(DEVICE)
            sample_out = model(sample_imgs, apply_autograd=False)
            last_gates = sample_out.get('gates_a', [])
    except Exception:  # FIX: was a bare except; diagnostic only, never fatal
        pass
    gate_str = f"g={np.mean(last_gates):.4f}" if last_gates else "g=?"

    print(f" E{epoch+1:3d}: A={train_acc:.1f}% B={ab:.0f}% "
          f"val={val_acc:.1f}%/{val_b_acc:.1f}%/{val_geo_acc:.1f}% "
          f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
          f"nb={nb_acc:.2f} gn={gn_acc:.2f} "
          f"cv={v_cv:.4f}(m={cvf:.5f} g={cvg:.5f}) "
          f"cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
          f"[{stage}] {gate_str} "
          f"gal={gal_update_count} ({elapsed:.0f}s){mk}")
|
|
# Flush TensorBoard logs and report the best validation accuracy seen.
writer.close()
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")