| |
| """ |
| CIFAR-10 β Dual-Stream GeoLIP ViT β Experiment 7 |
| ================================================== |
| Warm start from v1 best checkpoint. |
| + CV loss on geometric route (gentle, 0.01 weight) |
| + EmbeddingAutograd on std/fused path (tangential + separation) |
| + Anchor spread loss |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import os, time |
| import numpy as np |
| from tqdm import tqdm |
| from torchvision import datasets, transforms |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
| |
# ---- Model architecture ----
NUM_CLASSES = 10       # CIFAR-10
IMG_SIZE = 32
PATCH_SIZE = 4         # 32/4 = 8 -> 64 patches per image
EMBED_DIM = 384
STREAM_DIM = 192       # per-stream width in the dual blocks
FUSED_DIM = 256        # width after stream fusion
N_DUAL_BLOCKS = 2
N_FUSED_BLOCKS = 4
N_HEADS = 8
OUTPUT_DIM = 128       # final embedding dimension (also MasteryQueue dim below)
N_ANCHORS = 64
N_COMP = 8
D_COMP = 64
ANCHOR_DROP = 0.10
CV_TARGET = 0.22       # target coefficient of variation of simplex volumes

# ---- Geometric-regularization / autograd-loss weights ----
CV_WEIGHT = 1.0
ENABLE_AUTOGRAD = True
AUTOGRAD_TANG = 0.5    # tangential component weight
AUTOGRAD_SEP = 0.1     # separation component weight

# ---- Optimization hyperparameters ----
BATCH = 128
EPOCHS = 100
LR = 1e-4
WARMUP = 2             # warmup length in epochs (converted to steps below)
GRAD_CLIP = 1.0
INFONCE_WEIGHT = 0.1
BCE_WEIGHT = 1.0
CM_WEIGHT = 0.1
INFONCE_TEMP = 0.07

# Path to the v1 checkpoint to warm-start from; empty -> train from scratch.
V1_CKPT = ""
|
|
| print("=" * 60) |
| print("CIFAR-10 β Dual-Stream GeoLIP ViT β EXP 3 (FULL CV)") |
| print(f" Warm start from: {V1_CKPT}") |
| print(f" CV: weight={CV_WEIGHT} (FULL), target={CV_TARGET}") |
| print(f" InfoNCE: weight={INFONCE_WEIGHT} (REDUCED)") |
| print(f" Autograd: tang={AUTOGRAD_TANG}, sep={AUTOGRAD_SEP}") |
| print(f" LR: {LR}, epochs: {EPOCHS}") |
| print(f" Device: {DEVICE}") |
| print("=" * 60) |
|
|
| |
| |
| |
|
|
# Per-channel mean/std used to normalize images (CIFAR-10 training-set stats).
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

# Training augmentation: random crop with padding, horizontal flip, color
# jitter, and occasional grayscale, then tensor conversion + normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# Validation: deterministic conversion + normalization only.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
|
|
|
|
class TwoViewDataset(torch.utils.data.Dataset):
    """Wrap a CIFAR-style dataset and yield two independently-augmented
    views of each image plus its label.

    Reads raw pixel arrays from ``base_ds.data`` / ``base_ds.targets``
    directly (any transform attached to ``base_ds`` is bypassed) and applies
    ``transform`` twice, producing two stochastic views per sample.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        raw, label = self.base.data[idx], self.base.targets[idx]
        from PIL import Image  # deferred so PIL is only needed at access time
        pil_img = Image.fromarray(raw)
        view_a = self.transform(pil_img)
        view_b = self.transform(pil_img)
        return view_a, view_b, label
|
|
|
|
# CIFAR-10 datasets. The raw training set is loaded WITHOUT a transform:
# TwoViewDataset applies train_transform twice per sample itself.
# download=True fetches the archive into ./data on first run.
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, train_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)

# drop_last=True: training sees only full batches of size BATCH.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)

# Canonical CIFAR-10 class names, indexed by label id.
CIFAR_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']
print(f" Train: {len(train_ds):,} (two views) Val: {len(val_ds):,}")
|
|
| |
| |
| |
|
|
| print(f"\n Building model...") |
| model = create_dual_stream_vit( |
| num_classes=NUM_CLASSES, img_size=IMG_SIZE, patch_size=PATCH_SIZE, |
| embed_dim=EMBED_DIM, stream_dim=STREAM_DIM, fused_dim=FUSED_DIM, |
| n_dual_blocks=N_DUAL_BLOCKS, n_fused_blocks=N_FUSED_BLOCKS, |
| n_heads=N_HEADS, output_dim=OUTPUT_DIM, |
| n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP, |
| anchor_drop=ANCHOR_DROP, cv_target=CV_TARGET, |
| dropout=0.1, infonce_temp=INFONCE_TEMP, |
| infonce_weight=INFONCE_WEIGHT, bce_weight=BCE_WEIGHT, |
| cm_weight=CM_WEIGHT, cv_weight=CV_WEIGHT, |
| autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP, |
| enable_autograd=ENABLE_AUTOGRAD, |
| ).to(DEVICE) |
|
|
| |
# Optionally warm-start from the v1 checkpoint (status markers had been
# mojibake-garbled to "β"; restored to ✓ / ⚠).
# NOTE(review): weights_only=False unpickles arbitrary objects — only load
# checkpoints from trusted sources.
if os.path.exists(V1_CKPT):
    ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["state_dict"])
    print(f" ✓ Loaded v1 weights: epoch {ckpt['epoch']}, "
          f"val_acc {ckpt['val_acc']:.1f}%")
else:
    print(f" ⚠ No v1 checkpoint found at {V1_CKPT}, training from scratch")
|
|
# Total parameter count (trainable or not).
n_params = sum(p.numel() for p in model.parameters())

# Partition trainable parameters: anything whose name mentions a geometric
# module goes to the geo route, the rest to the standard route.
geo_names = {'geo_proj', 'dual_blocks', 'constellation', 'patchwork'}
geo_params, std_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    bucket = geo_params if any(g in name for g in geo_names) else std_params
    bucket.append(param)

n_geo = sum(p.numel() for p in geo_params)
n_std = sum(p.numel() for p in std_params)
print(f" Parameters: {n_params:,}")
print(f" Geo route: {n_geo:,} ({100*n_geo/n_params:.1f}%)")
print(f" Std route: {n_std:,} ({100*n_std/n_params:.1f}%)")
|
|
| |
| |
| |
|
|
| print(f"\n{'='*60}") |
| print(f"TRAINING β {EPOCHS} epochs, lr={LR}") |
| print(f" CV={CV_WEIGHT}, autograd={'ON' if ENABLE_AUTOGRAD else 'OFF'}") |
| print(f" Mastery: patience=50 batches, queue=4096") |
| print(f"{'='*60}") |
|
|
| optimizer = torch.optim.Adam([ |
| {'params': geo_params, 'lr': LR}, |
| {'params': std_params, 'lr': LR}, |
| ], lr=LR) |
|
|
| total_steps = len(train_loader) * EPOCHS |
| warmup_steps = len(train_loader) * WARMUP |
| scheduler = torch.optim.lr_scheduler.SequentialLR( |
| optimizer, |
| [torch.optim.lr_scheduler.LinearLR( |
| optimizer, start_factor=0.1, total_iters=warmup_steps), |
| torch.optim.lr_scheduler.CosineAnnealingLR( |
| optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)], |
| milestones=[warmup_steps]) |
|
|
| scaler = torch.amp.GradScaler("cuda") |
| os.makedirs("checkpoints", exist_ok=True) |
| writer = SummaryWriter("runs/cifar10_dual_stream_v3_mastery") |
| best_acc = 0.0 |
| gs = 0 |
|
|
| |
| mastery = MasteryQueue(dim=OUTPUT_DIM, max_size=4096, patience=50, device=DEVICE) |
|
|
def _ld_scalar(ld, key):
    """Return ld[key] as a plain number (loss dicts mix tensors and floats)."""
    v = ld.get(key, 0)
    return v.item() if torch.is_tensor(v) else v


# Main training / validation loop.
# Fixes vs. previous revision:
#   * removed bogus `ordered=True` kwarg to tqdm.set_postfix — tqdm treats
#     unknown kwargs as postfix stats, so it showed a spurious "ordered=True";
#   * best-checkpoint marker string was mojibake-garbled (contained a raw
#     newline) — restored to " ✓";
#   * inner squared-volume variable renamed (was `v2`, shadowing the
#     augmented-view tensor `v2`);
#   * CM sampling now uses the actual sample size instead of hard-coded 2000;
#   * dropped unused locals (`cv_mean`, `cv_m`).
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()

    # Per-epoch accumulators; "n" counts batches, "total" counts samples.
    acc_dict = {"loss": 0, "bce": 0, "nce": 0, "nce_acc": 0,
                "cm": 0, "cm_valid": 0, "cv": 0, "cv_fused": 0, "cv_geo": 0,
                "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
                "correct": 0, "total": 0, "n": 0}

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="batch")
    for v1, v2, labels in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        # Forward both augmented views; compute_loss consumes view 2 as the
        # augmentation pair and the mastery queue for hard-example mining.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1, targets=labels)
            out2 = model(v2, targets=labels)
            loss, ld = model.compute_loss(
                out1, labels, output_aug=out2, mastery_queue=mastery)

        # Let the queue decide whether to switch into the mastery stage.
        mastery.check_activation(ld.get('nce_acc', 0))

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees true grads
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        with torch.no_grad():
            preds = out1['logits'].argmax(dim=-1)
            acc_dict["correct"] += (preds == labels).sum().item()
            acc_dict["total"] += labels.shape[0]

        acc_dict["loss"] += loss.item()
        for k in ["bce", "nce", "cm", "cv", "spread", "mastery"]:
            acc_dict[k] += _ld_scalar(ld, k)
        acc_dict["nce_acc"] += ld.get("nce_acc", 0)
        acc_dict["cm_valid"] += ld.get("cm_valid", 0)
        acc_dict["hard_neg"] += ld.get("hard_neg_cos", 0)
        acc_dict["hard_pos"] += ld.get("hard_pos_cos", 0)
        acc_dict["cv_fused"] += ld.get("cv_fused", 0)
        acc_dict["cv_geo"] += ld.get("cv_geo", 0)
        acc_dict["n"] += 1
        gs += 1

        # Step-level TensorBoard logging every 20 global steps.
        if gs % 20 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/cv", _ld_scalar(ld, "cv"), gs)
            if mastery.active:
                writer.add_scalar("step/mastery", _ld_scalar(ld, "mastery"), gs)

        # Refresh the progress-bar stats every 10 batches.
        if acc_dict["n"] % 10 == 0:
            d = acc_dict["n"]
            train_acc = 100 * acc_dict["correct"] / acc_dict["total"]
            stage = "M" if mastery.active else "S1"
            pbar.set_postfix(
                loss=f"{acc_dict['loss']/d:.4f}",
                acc=f"{train_acc:.1f}%",
                cvf=f"{acc_dict['cv_fused']/d:.4f}",
                cvg=f"{acc_dict['cv_geo']/d:.4f}",
                cm=f"{acc_dict['cm_valid']/d:.0%}",
                mst=f"{acc_dict['mastery']/d:.4f}",
                stg=stage)

    elapsed = time.time() - t0
    d = max(acc_dict["n"], 1)
    train_acc = 100 * acc_dict["correct"] / acc_dict["total"]

    writer.add_scalar("epoch/train_loss", acc_dict["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
    writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_loss", acc_dict["cv"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_fused", acc_dict["cv_fused"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
    writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)

    # ---- Validation ----
    model.eval()
    val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
    all_embs = []

    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for images, labels_v in val_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels_v = labels_v.to(DEVICE, non_blocking=True)
            out = model(images, apply_autograd=False)
            preds = out['logits'].argmax(dim=-1)
            val_correct += (preds == labels_v).sum().item()
            val_total += labels_v.shape[0]
            # The model trains with per-class BCE, so validate with the same.
            one_hot = F.one_hot(labels_v, NUM_CLASSES).float()
            loss_v = F.binary_cross_entropy_with_logits(out['logits'], one_hot)
            val_loss_sum += loss_v.item()
            val_n += 1
            all_embs.append(out['embedding'].float().cpu())

    val_acc = 100 * val_correct / val_total
    val_loss = val_loss_sum / max(val_n, 1)

    # Monte-Carlo estimate of the coefficient of variation of 4-simplex
    # volumes over validation embeddings, via the Cayley-Menger determinant.
    embs = torch.cat(all_embs)
    with torch.no_grad():
        sample = embs[:2000].to(DEVICE)
        n_sample = sample.shape[0]  # may be < 2000; don't hard-code the bound
        vols = []
        for _ in range(200):
            idx = torch.randperm(n_sample)[:5]
            pts = sample[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)  # clamp tiny negatives from float rounding
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            # Squared 4-simplex volume; 9216 = 2^4 * (4!)^2.
            vol_sq = -torch.linalg.det(cm) / 9216
            if vol_sq[0].item() > 1e-20:
                vols.append(vol_sq[0].sqrt())
        if len(vols) > 10:
            vols_t = torch.stack(vols)
            v_cv = (vols_t.std() / (vols_t.mean() + 1e-8)).item()
        else:
            v_cv = 0.0

    # Count how many anchors the constellation actually routes val samples to.
    with torch.no_grad():
        _, v_np = model.constellation.triangulate(
            embs[:2000].to(DEVICE), training=False)
        n_active = v_np.cpu().unique().numel()

    writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
    writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)

    # Save the best-so-far checkpoint (marker shown in the epoch summary line).
    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "val_loss": val_loss,
            "val_cv": v_cv,
            "mastery": mastery.state_dict(),
        }, "checkpoints/dual_stream_v3_best.pt")
        mk = " ✓"

    # Periodic snapshot every 10 epochs (includes optimizer state for resume).
    if (epoch + 1) % 10 == 0:
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "optimizer": optimizer.state_dict(),
        }, f"checkpoints/dual_stream_v3_e{epoch+1:03d}.pt")

    # Epoch summary line.
    cvf = acc_dict["cv_fused"] / d
    cvg = acc_dict["cv_geo"] / d
    nce_a = acc_dict["nce_acc"] / d
    cmv = acc_dict["cm_valid"] / d
    mst_m = acc_dict["mastery"] / d
    hn = acc_dict["hard_neg"] / d if mastery.active else 0
    hp = acc_dict["hard_pos"] / d if mastery.active else 0
    stage = "MASTERY" if mastery.active else "stage1"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
          f"cv={v_cv:.4f}(f={cvf:.5f} g={cvg:.5f}) "
          f"nce={nce_a:.2f} cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
          f"[{stage}] mst={mst_m:.4f} hn={hn:.3f} hp={hp:.3f} "
          f"q={mastery.size} ({elapsed:.0f}s){mk}")
|
|
# Flush TensorBoard logs and print the final summary.
writer.close()
separator = "=" * 60
print(f"\n Best val accuracy: {best_acc:.1f}%")
print("\n" + separator)
print("DONE")
print(separator)