# Source: AbstractPhil — run7/trainer.py (renamed from trainer.py, commit 7c43d2b, verified)
#!/usr/bin/env python3
"""
CIFAR-10 β€” Dual-Stream GeoLIP ViT β€” Experiment 7
==================================================
Warm start from v1 best checkpoint.
+ CV loss on geometric route (gentle, 0.01 weight)
+ EmbeddingAutograd on std/fused path (tangential + separation)
+ Anchor spread loss
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os, time
import numpy as np
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
# Runtime device; TF32 fast-math enabled for matmul/cuDNN on Ampere+ GPUs.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Architecture (must match v1)
NUM_CLASSES = 10   # CIFAR-10 label count
IMG_SIZE = 32      # input image side
PATCH_SIZE = 4     # ViT patch side -> (32/4)^2 = 64 patch tokens
EMBED_DIM = 384
STREAM_DIM = 192
FUSED_DIM = 256
N_DUAL_BLOCKS = 2
N_FUSED_BLOCKS = 4
N_HEADS = 8
OUTPUT_DIM = 128   # final embedding dim (also used as the mastery-queue dim below)
N_ANCHORS = 64     # constellation anchor count (utilization reported per epoch)
N_COMP = 8
D_COMP = 64
ANCHOR_DROP = 0.10
CV_TARGET = 0.22   # target coefficient-of-variation, passed to the model's CV loss
# NEW for v2
CV_WEIGHT = 1.0 # FULL CV — direct contest with InfoNCE
ENABLE_AUTOGRAD = True
AUTOGRAD_TANG = 0.5  # tangential autograd-loss weight
AUTOGRAD_SEP = 0.1   # separation autograd-loss weight
# Training
BATCH = 128
EPOCHS = 100 # short run — see what CV does
LR = 1e-4
WARMUP = 2           # warmup epochs (linear LR ramp before cosine decay)
GRAD_CLIP = 1.0      # global grad-norm clip
INFONCE_WEIGHT = 0.1 # reduced — let the sphere relax
BCE_WEIGHT = 1.0
CM_WEIGHT = 0.1
INFONCE_TEMP = 0.07
# Warm start
# NOTE(review): empty path — os.path.exists("") is False, so the warm start
# below is effectively disabled and the run trains from scratch.
V1_CKPT = ""
# Startup banner echoing the run configuration.
print("=" * 60)
# NOTE(review): the banner says "EXP 3" while the module docstring says
# "Experiment 7" and checkpoints/log dirs are tagged v3 — reconcile the labels.
print("CIFAR-10 β€” Dual-Stream GeoLIP ViT β€” EXP 3 (FULL CV)")
print(f" Warm start from: {V1_CKPT}")
print(f" CV: weight={CV_WEIGHT} (FULL), target={CV_TARGET}")
print(f" InfoNCE: weight={INFONCE_WEIGHT} (REDUCED)")
print(f" Autograd: tang={AUTOGRAD_TANG}, sep={AUTOGRAD_SEP}")
print(f" LR: {LR}, epochs: {EPOCHS}")
print(f" Device: {DEVICE}")
print("=" * 60)
# ══════════════════════════════════════════════════════════════════
# DATA
# ══════════════════════════════════════════════════════════════════
# CIFAR-10 per-channel statistics used for input normalization.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

# Shared tail of both pipelines: PIL image -> normalized tensor.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]

# Training pipeline: spatial + color augmentation, then normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.2),
] + _to_normalized_tensor)

# Validation pipeline: no augmentation, only normalization.
val_transform = transforms.Compose(list(_to_normalized_tensor))
class TwoViewDataset(torch.utils.data.Dataset):
    """Yield two independently augmented views of each image plus its label.

    Wraps a raw torchvision CIFAR dataset (read via ``.data`` / ``.targets``,
    bypassing the base dataset's own transform) and applies ``transform``
    twice per item — one stochastic draw per view.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds        # raw dataset exposing .data / .targets
        self.transform = transform # augmentation applied once per view

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        # Local import: PIL is only needed when items are actually fetched.
        from PIL import Image
        raw, label = self.base.data[idx], self.base.targets[idx]
        pil_img = Image.fromarray(raw)
        view_a = self.transform(pil_img)
        view_b = self.transform(pil_img)
        return view_a, view_b, label
# Raw CIFAR-10 train split (no transform here: TwoViewDataset reads the raw
# arrays and applies train_transform itself, twice per item).
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, train_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)
# drop_last avoids a ragged final batch during training.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)
# Human-readable class names (index-aligned with CIFAR-10 labels).
CIFAR_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']
print(f" Train: {len(train_ds):,} (two views) Val: {len(val_ds):,}")
# ══════════════════════════════════════════════════════════════════
# BUILD MODEL + WARM START
# ══════════════════════════════════════════════════════════════════
print(f"\n Building model...")
# NOTE(review): create_dual_stream_vit is neither defined nor imported in this
# file — it presumably comes from the project's model module; confirm the
# missing import (same for MasteryQueue used further down).
model = create_dual_stream_vit(
    num_classes=NUM_CLASSES, img_size=IMG_SIZE, patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM, stream_dim=STREAM_DIM, fused_dim=FUSED_DIM,
    n_dual_blocks=N_DUAL_BLOCKS, n_fused_blocks=N_FUSED_BLOCKS,
    n_heads=N_HEADS, output_dim=OUTPUT_DIM,
    n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
    anchor_drop=ANCHOR_DROP, cv_target=CV_TARGET,
    dropout=0.1, infonce_temp=INFONCE_TEMP,
    infonce_weight=INFONCE_WEIGHT, bce_weight=BCE_WEIGHT,
    cm_weight=CM_WEIGHT, cv_weight=CV_WEIGHT,
    autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
    enable_autograd=ENABLE_AUTOGRAD,
).to(DEVICE)
# Load v1 weights
# With V1_CKPT == "" this branch never runs and training starts from scratch.
if os.path.exists(V1_CKPT):
    # weights_only=False performs a full pickle load — only safe for
    # checkpoints from a trusted source.
    ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["state_dict"])
    print(f" βœ“ Loaded v1 weights: epoch {ckpt['epoch']}, "
          f"val_acc {ckpt['val_acc']:.1f}%")
else:
    print(f" ⚠ No v1 checkpoint found at {V1_CKPT}, training from scratch")
n_params = sum(p.numel() for p in model.parameters())
# Param groups: geo params get separate tracking.  A parameter belongs to the
# "geo" route when its qualified name mentions any geometric submodule.
geo_names = {'geo_proj', 'dual_blocks', 'constellation', 'patchwork'}
geo_params, std_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    bucket = geo_params if any(gn in name for gn in geo_names) else std_params
    bucket.append(param)
n_geo = sum(p.numel() for p in geo_params)
n_std = sum(p.numel() for p in std_params)
print(f" Parameters: {n_params:,}")
print(f" Geo route: {n_geo:,} ({100*n_geo/n_params:.1f}%)")
print(f" Std route: {n_std:,} ({100*n_std/n_params:.1f}%)")
# ══════════════════════════════════════════════════════════════════
# TRAINING
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*60}")
print(f"TRAINING β€” {EPOCHS} epochs, lr={LR}")
print(f" CV={CV_WEIGHT}, autograd={'ON' if ENABLE_AUTOGRAD else 'OFF'}")
print(f" Mastery: patience=50 batches, queue=4096")
print(f"{'='*60}")

# Two param groups currently share the same LR; splitting them now allows the
# geo route to be retuned later without rebuilding the optimizer.
optimizer = torch.optim.Adam([
    {'params': geo_params, 'lr': LR},
    {'params': std_params, 'lr': LR},
], lr=LR)

# Per-step schedule: linear warmup for WARMUP epochs, then cosine decay.
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * WARMUP
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])

# NOTE(review): GradScaler targets fp16 loss scaling; under the bfloat16
# autocast used below it is effectively a pass-through — harmless, but
# confirm whether it is intentional.
scaler = torch.amp.GradScaler("cuda")

os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/cifar10_dual_stream_v3_mastery")
best_acc = 0.0
gs = 0  # global step counter for step-level TensorBoard scalars

# Mastery queue — activates after 50 consecutive perfect nce_acc batches.
# NOTE(review): MasteryQueue is not defined or imported in this file — it must
# come from the same module as create_dual_stream_vit; confirm.
mastery = MasteryQueue(dim=OUTPUT_DIM, max_size=4096, patience=50, device=DEVICE)
# ── Epoch loop: train on two-view batches, then validate and checkpoint ──
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    # Running sums for the epoch; "n" counts batches, the other keys are
    # per-batch values accumulated for later averaging.
    acc_dict = {"loss": 0, "bce": 0, "nce": 0, "nce_acc": 0,
                "cm": 0, "cm_valid": 0, "cv": 0, "cv_fused": 0, "cv_geo": 0,
                "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
                "correct": 0, "total": 0, "n": 0}
    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="batch")
    for v1, v2, labels in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        # Forward both augmented views under bf16 autocast; the model's own
        # compute_loss combines them with the mastery queue.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1, targets=labels)
            out2 = model(v2, targets=labels)
            loss, ld = model.compute_loss(
                out1, labels, output_aug=out2, mastery_queue=mastery)
        # Check mastery activation
        mastery.check_activation(ld.get('nce_acc', 0))
        # Backward + clipped step; grads are unscaled before clipping so the
        # clip threshold applies to true gradient norms.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()  # per-step LR schedule (warmup + cosine)
        # Metric accumulation; loss-dict entries may be tensors or floats.
        with torch.no_grad():
            preds = out1['logits'].argmax(dim=-1)
            acc_dict["correct"] += (preds == labels).sum().item()
            acc_dict["total"] += labels.shape[0]
            acc_dict["loss"] += loss.item()
            for k in ["bce", "nce", "cm", "cv", "spread", "mastery"]:
                v = ld.get(k, 0)
                acc_dict[k] += v.item() if torch.is_tensor(v) else v
            acc_dict["nce_acc"] += ld.get("nce_acc", 0)
            acc_dict["cm_valid"] += ld.get("cm_valid", 0)
            acc_dict["hard_neg"] += ld.get("hard_neg_cos", 0)
            acc_dict["hard_pos"] += ld.get("hard_pos_cos", 0)
            acc_dict["cv_fused"] += ld.get("cv_fused", 0)
            acc_dict["cv_geo"] += ld.get("cv_geo", 0)
            acc_dict["n"] += 1; gs += 1
        # Step-level TensorBoard logging every 20 global steps.
        if gs % 20 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/cv", ld.get("cv", torch.tensor(0)).item()
                              if torch.is_tensor(ld.get("cv", 0))
                              else ld.get("cv", 0), gs)
            if mastery.active:
                writer.add_scalar("step/mastery",
                                  ld.get("mastery", torch.tensor(0)).item()
                                  if torch.is_tensor(ld.get("mastery", 0))
                                  else ld.get("mastery", 0), gs)
        # Refresh the progress-bar postfix every 10 batches.
        if acc_dict["n"] % 10 == 0:
            d = acc_dict["n"]
            train_acc = 100 * acc_dict["correct"] / acc_dict["total"]
            cv_mean = acc_dict["cv"] / d
            cvf = acc_dict["cv_fused"] / d
            cvg = acc_dict["cv_geo"] / d
            cmv = acc_dict["cm_valid"] / d
            mst = acc_dict["mastery"] / d
            stage = "M" if mastery.active else "S1"
            # NOTE(review): tqdm's set_postfix has no `ordered` parameter, so
            # "ordered=True" is rendered as an extra postfix stat — confirm.
            pbar.set_postfix(
                loss=f"{acc_dict['loss']/d:.4f}",
                acc=f"{train_acc:.1f}%",
                cvf=f"{cvf:.4f}",
                cvg=f"{cvg:.4f}",
                cm=f"{cmv:.0%}",
                mst=f"{mst:.4f}",
                stg=stage,
                ordered=True)
    # ── Epoch-level training metrics ──
    elapsed = time.time() - t0
    d = max(acc_dict["n"], 1)  # guard against an empty loader
    train_acc = 100 * acc_dict["correct"] / acc_dict["total"]
    writer.add_scalar("epoch/train_loss", acc_dict["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
    writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_loss", acc_dict["cv"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_fused", acc_dict["cv_fused"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
    writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
    # ── Validation ──
    model.eval()
    val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
    all_embs = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for images, labels_v in val_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels_v = labels_v.to(DEVICE, non_blocking=True)
            out = model(images, apply_autograd=False)
            preds = out['logits'].argmax(dim=-1)
            val_correct += (preds == labels_v).sum().item()
            val_total += labels_v.shape[0]
            # BCE against one-hot targets, mirroring the training BCE term.
            one_hot = F.one_hot(labels_v, NUM_CLASSES).float()
            loss_v = F.binary_cross_entropy_with_logits(out['logits'], one_hot)
            val_loss_sum += loss_v.item()
            val_n += 1
            all_embs.append(out['embedding'].float().cpu())
    val_acc = 100 * val_correct / val_total
    val_loss = val_loss_sum / max(val_n, 1)
    # Quick CV check on val embeddings: Monte-Carlo sample 200 random
    # 5-point subsets, compute each subset's squared 4-simplex volume via a
    # Cayley-Menger-style determinant (vol^2 = -det(cm) / 9216 here), and
    # report std/mean of the resulting volumes.
    embs = torch.cat(all_embs)
    with torch.no_grad():
        sample = embs[:2000].to(DEVICE)
        vols = []
        for _ in range(200):
            idx = torch.randperm(2000)[:5]
            pts = sample[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            # Pairwise squared distances from the Gram matrix; relu clamps
            # tiny negatives caused by rounding.
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            # NOTE(review): `v2` shadows the training-view tensor of the same
            # name — harmless at this point in the epoch, but worth renaming.
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:
                vols.append(v2[0].sqrt())
    if len(vols) > 10:
        vols_t = torch.stack(vols)
        v_cv = (vols_t.std() / (vols_t.mean() + 1e-8)).item()
    else:
        v_cv = 0.0
    # Anchor utilization: count distinct anchors hit by the first 2000
    # validation embeddings.
    with torch.no_grad():
        _, v_np = model.constellation.triangulate(
            embs[:2000].to(DEVICE), training=False)
        n_active = v_np.cpu().unique().numel()
    writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
    writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)
    # ── Checkpointing: best-so-far, plus a snapshot every 10 epochs ──
    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "val_loss": val_loss,
            "val_cv": v_cv,
            "mastery": mastery.state_dict(),
        }, "checkpoints/dual_stream_v3_best.pt")
        mk = " β˜…"
    if (epoch + 1) % 10 == 0:
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "optimizer": optimizer.state_dict(),
        }, f"checkpoints/dual_stream_v3_e{epoch+1:03d}.pt")
    # One-line epoch summary (hard-neg/pos stats only shown once mastery is on).
    cv_m = acc_dict["cv"] / d
    cvf = acc_dict["cv_fused"] / d
    cvg = acc_dict["cv_geo"] / d
    nce_a = acc_dict["nce_acc"] / d
    cmv = acc_dict["cm_valid"] / d
    mst_m = acc_dict["mastery"] / d
    hn = acc_dict["hard_neg"] / d if mastery.active else 0
    hp = acc_dict["hard_pos"] / d if mastery.active else 0
    stage = "MASTERY" if mastery.active else "stage1"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
          f"cv={v_cv:.4f}(f={cvf:.5f} g={cvg:.5f}) "
          f"nce={nce_a:.2f} cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
          f"[{stage}] mst={mst_m:.4f} hn={hn:.3f} hp={hp:.3f} "
          f"q={mastery.size} ({elapsed:.0f}s){mk}")
# Flush TensorBoard logs, then print the final summary banner.
writer.close()
divider = "=" * 60
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f"\n{divider}")
print("DONE")
print(f"{divider}")