# geolip-vit-tri-stream / trainer.py
# (Hugging Face file-page header residue — "AbstractPhil's picture",
#  "Update trainer.py", commit "3eebacd verified" — commented out so the
#  file parses as Python.)
#!/usr/bin/env python3
"""
CIFAR-10 β€” Tri-Stream GeoLIP ViT v8
=====================================
v7β†’v8 changes:
1. GAL_UPDATE_INTERVAL: 50 β†’ 25 (2Γ— more frequent)
2. GAL_LR: 0.01 β†’ 0.015 (+50% response)
3. Tracks nce_b and geo_nce_acc separately
4. stream_b_nce_weight=0.5, geo_nce_weight=0.5
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os, time
import numpy as np
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ── Architecture ──
NUM_CLASSES = 10
IMG_SIZE = 32
PATCH_SIZE = 4
EMBED_DIM = 384
STREAM_DIM = 192
N_BLOCKS = 9
N_HEADS = 8
OUTPUT_DIM = 256
N_ANCHORS = 128
N_GAL_ANCHORS = 64
N_COMP = 16
D_COMP = 128
ANCHOR_DROP = 0.10
CV_TARGET = 0.22
# ── Loss weights ──
CV_WEIGHT = 0.1
ENABLE_AUTOGRAD = True
AUTOGRAD_TANG = 1.0
AUTOGRAD_SEP = 0.1
LABEL_SMOOTHING = 0.1
INFONCE_WEIGHT = 0.1
BCE_WEIGHT = 1.0
CM_WEIGHT = 0.1
INFONCE_TEMP = 0.07
# ── v8: Stream B + Geo NCE weights ──
STREAM_B_NCE_WEIGHT = 0.5
GEO_NCE_WEIGHT = 0.5
# ── v8: GAL β€” faster updates, stronger response ──
GAL_UPDATE_INTERVAL = 25 # was 50
GAL_LR = 0.015 # was 0.01 (+50%)
GAL_BUFFER_SIZE = 50000
USE_WHITENED_PROCRUSTES = False
# ── Mastery queue ──
MASTERY_PATIENCE = 50
MASTERY_MARGIN_START = 0.1
MASTERY_MARGIN_END = 0.3
MASTERY_MARGIN_WARMUP = 5000
MASTERY_MIN_SIZE = 1024
MASTERY_MAX_SIZE = 16384
MASTERY_INITIAL_SIZE = 4096
MASTERY_RESIZE_STEP = 2048
MASTERY_RESIZE_COOLDOWN = 5
MASTERY_OVERFIT_THRESH = 3.0
# ── Training ──
BATCH = 256
EPOCHS = 100
LR = 3e-4
WARMUP = 5
GRAD_CLIP = 1.0
V1_CKPT = "" # set to checkpoint path for warm start
print("=" * 60)
print("CIFAR-10 β€” Tri-Stream GeoLIP ViT v8")
print(f" Architecture: {N_BLOCKS}Γ— TriStreamBlock")
print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}Γ—{D_COMP} pw")
print(f" GAL: {N_GAL_ANCHORS} anchors, Procrustes every {GAL_UPDATE_INTERVAL} "
f"batches (lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES})")
print(f" v8 fixes: uniform hypersphere init, gate_init=1/(2Γ—{N_BLOCKS})")
print(f" v8 fixes: InfoNCE on emb_b (w={STREAM_B_NCE_WEIGHT}) "
f"+ geo_emb (w={GEO_NCE_WEIGHT})")
print(f" Device: {DEVICE}")
print("=" * 60)
# ══════════════════════════════════════════════════════════════════
# DATA
# ══════════════════════════════════════════════════════════════════
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)
class DualAugDataset(torch.utils.data.Dataset):
    """Yield two independently-augmented views of each underlying sample.

    Wraps ``base_ds`` (any indexable of ``(img, label)`` pairs) and applies
    ``transform`` twice to the same image, returning ``(view1, view2, label)``
    — the dual-view input the contrastive losses in this trainer expect.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img, label = self.base[i]
        view_a = self.transform(img)
        view_b = self.transform(img)
        return view_a, view_b, label
# Train-time augmentation; applied independently to each of the two views.
aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# Validation path is deterministic: tensor conversion + normalization only.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# Base train set is loaded WITHOUT a transform; DualAugDataset applies
# aug_transform twice per sample to produce the two views.
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = DualAugDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)
# drop_last=True keeps every train batch at exactly BATCH samples —
# presumably so the in-batch contrastive terms see a fixed batch size.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=2, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)
print(f" Train: {len(train_ds):,} (two views) Val: {len(val_ds):,}")
# ══════════════════════════════════════════════════════════════════
# BUILD MODEL
# ══════════════════════════════════════════════════════════════════
print(f"\n Building model...")
model = create_tri_stream_vit(
num_classes=NUM_CLASSES, img_size=IMG_SIZE, patch_size=PATCH_SIZE,
embed_dim=EMBED_DIM, stream_dim=STREAM_DIM, n_blocks=N_BLOCKS,
n_heads=N_HEADS, output_dim=OUTPUT_DIM,
n_anchors=N_ANCHORS, n_gal_anchors=N_GAL_ANCHORS,
n_comp=N_COMP, d_comp=D_COMP,
anchor_drop=ANCHOR_DROP, cv_target=CV_TARGET,
dropout=0.1, infonce_temp=INFONCE_TEMP,
infonce_weight=INFONCE_WEIGHT, bce_weight=BCE_WEIGHT,
cm_weight=CM_WEIGHT, cv_weight=CV_WEIGHT,
autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
enable_autograd=ENABLE_AUTOGRAD,
label_smoothing=LABEL_SMOOTHING,
stream_b_nce_weight=STREAM_B_NCE_WEIGHT,
geo_nce_weight=GEO_NCE_WEIGHT,
).to(DEVICE)
if V1_CKPT and os.path.exists(V1_CKPT):
ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
missing, unexpected = model.load_state_dict(
ckpt["state_dict"], strict=False)
print(f" βœ“ Loaded weights: epoch {ckpt.get('epoch', '?')}")
if missing:
print(f" New params (expected): {len(missing)}")
else:
print(f" Training from scratch")
total_params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {total_params:,}")
# ══════════════════════════════════════════════════════════════════
# OPTIMIZER + SCHEDULER
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*60}")
print(f"TRAINING β€” {EPOCHS} epochs, lr={LR}, batch={BATCH}")
print(f" GAL Procrustes: every {GAL_UPDATE_INTERVAL} batches, "
      f"lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES}")
print(f"{'='*60}")
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# The scheduler is stepped once per BATCH in the training loop, so warmup
# and cosine horizons are expressed in optimizer steps, not epochs.
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * WARMUP
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])
# NOTE(review): the loop autocasts to bfloat16, which normally needs no loss
# scaling; GradScaler is harmless here but likely redundant — confirm intent.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/cifar10_tri_stream_v8")
best_acc = 0.0
gs = 0  # global step counter (batches seen across all epochs)
# Mastery queue
# NOTE(review): MasteryQueue and SimplexBuffer are not imported in this file —
# presumably provided by the same module as create_tri_stream_vit; confirm.
mastery = MasteryQueue(
    dim=OUTPUT_DIM, min_size=MASTERY_MIN_SIZE, max_size=MASTERY_MAX_SIZE,
    initial_size=MASTERY_INITIAL_SIZE, patience=MASTERY_PATIENCE,
    device=DEVICE, margin_start=MASTERY_MARGIN_START,
    margin_end=MASTERY_MARGIN_END, margin_warmup=MASTERY_MARGIN_WARMUP,
    resize_step=MASTERY_RESIZE_STEP, resize_cooldown=MASTERY_RESIZE_COOLDOWN,
    overfit_threshold=MASTERY_OVERFIT_THRESH)
# GAL simplex buffer — fed with labelled pool_geo features in the train loop
# and consumed by model.update_gal_anchors() for Procrustes updates.
simplex_buf = SimplexBuffer(
    dim=STREAM_DIM, max_size=GAL_BUFFER_SIZE, device=DEVICE)
gal_update_count = 0
# ══════════════════════════════════════════════════════════════════
# TRAINING LOOP
# ══════════════════════════════════════════════════════════════════
# Main loop: per epoch — dual-view forward, combined loss, AMP backward with
# clipping, per-batch LR schedule, periodic GAL Procrustes updates, then
# validation + embedding diagnostics + checkpointing.
# Fix vs. original: the gate-check `try` used a bare `except:`, which also
# swallows KeyboardInterrupt/SystemExit; narrowed to `except Exception:`.
for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    # Running sums for the epoch; "n" counts batches, "total" counts samples.
    acc_dict = {
        "loss": 0, "ce": 0, "bce": 0, "geo_bce": 0,
        "acc_a": 0, "acc_b": 0, "geo_acc": 0,
        "nce": 0, "nce_acc": 0,
        "nce_b": 0, "nce_b_acc": 0,
        "geo_nce": 0, "geo_nce_acc": 0,
        "cm": 0, "cm_valid": 0, "cv": 0, "cv_main": 0, "cv_geo": 0,
        "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
        "correct": 0, "total": 0, "n": 0}
    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}",
                unit="batch")
    for v1, v2, targets in pbar:
        # v1/v2: two independently augmented views of the same images.
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1, apply_autograd=True)
            out2 = model(v2, apply_autograd=True)
            # Loss couples both views (contrastive terms) and the mastery queue.
            loss, ld = model.compute_loss(
                out1, targets, output_aug=out2, mastery_queue=mastery)
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to true grads.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer); scaler.update()
        scheduler.step()  # per-batch schedule; warmup/cosine are sized in steps
        mastery.check_activation(ld.get('nce_acc', 0))
        # Feed pooled geometric features (with labels) into the GAL buffer.
        pool_geo = out1.get('pool_geo')
        if pool_geo is not None:
            simplex_buf.push(pool_geo.float(), targets)
        gs += 1
        # v8: Procrustes anchor update every GAL_UPDATE_INTERVAL batches,
        # once the buffer holds a minimally representative sample (>500).
        if gs % GAL_UPDATE_INTERVAL == 0 and simplex_buf.size > 500:
            score = model.update_gal_anchors(
                simplex_buf, lr=GAL_LR, whiten=USE_WHITENED_PROCRUSTES)
            if score is not None:
                gal_update_count += 1
                writer.add_scalar("step/procrustes_score", score, gs)
        # Track
        preds = out1['logits_a'].argmax(-1)
        correct = (preds == targets).sum().item()
        acc_dict["correct"] += correct
        acc_dict["total"] += targets.shape[0]
        acc_dict["loss"] += loss.item()
        # Loss-dict entries may be tensors or plain floats; normalize to float.
        for k in ["ce", "bce", "geo_bce", "nce", "nce_b", "geo_nce",
                  "cm", "cv", "spread", "mastery"]:
            v = ld.get(k, 0)
            acc_dict[k] += v.item() if torch.is_tensor(v) else v
        acc_dict["acc_a"] += ld.get("acc_a", 0)
        acc_dict["acc_b"] += ld.get("acc_b", 0)
        acc_dict["geo_acc"] += ld.get("geo_acc", 0)
        acc_dict["nce_acc"] += ld.get("nce_acc", 0)
        acc_dict["nce_b_acc"] += ld.get("nce_b_acc", 0)
        acc_dict["geo_nce_acc"] += ld.get("geo_nce_acc", 0)
        acc_dict["cm_valid"] += ld.get("cm_valid", 0)
        acc_dict["cv_main"] += ld.get("cv_main", 0)
        acc_dict["cv_geo"] += ld.get("cv_geo", 0)
        acc_dict["hard_neg"] += ld.get("hard_neg_cos", 0)
        acc_dict["hard_pos"] += ld.get("hard_pos_cos", 0)
        acc_dict["n"] += 1
        # Refresh the progress-bar postfix every 10 batches.
        if acc_dict["n"] % 10 == 0:
            d = acc_dict["n"]
            ta = 100 * acc_dict["correct"] / acc_dict["total"]
            ga = 100 * acc_dict["geo_acc"] / d
            nb = acc_dict["nce_b_acc"] / d
            stg = "M" if mastery.active else "S1"
            # NOTE(review): `ordered=True` is not a tqdm display toggle; it is
            # rendered as a literal "ordered=True" stat in the bar. Harmless
            # but likely unintended — confirm and drop if so.
            pbar.set_postfix(
                loss=f"{acc_dict['loss']/d:.4f}",
                a=f"{ta:.0f}%",
                ga=f"{ga:.0f}%",
                nb=f"{nb:.2f}",
                stg=stg,
                gal=gal_update_count,
                ordered=True)
        # Step-level TensorBoard scalars every 20 global steps.
        if gs % 20 == 0:
            writer.add_scalar("step/loss", loss.item(), gs)
            writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
            writer.add_scalar("step/nce_b_acc", ld.get("nce_b_acc", 0), gs)
            writer.add_scalar("step/geo_nce_acc", ld.get("geo_nce_acc", 0), gs)
            gates_a = out1.get('gates_a', [])
            if gates_a:
                writer.add_scalar("step/gate_a_mean",
                                  sum(gates_a) / len(gates_a), gs)
                # NOTE(review): gate_b mean divides by len(gates_a); if the two
                # gate lists can differ in length this is the wrong denominator.
                writer.add_scalar("step/gate_b_mean",
                                  sum(out1.get('gates_b', [0])) / max(len(gates_a), 1), gs)
    # ── Epoch stats ──
    elapsed = time.time() - t0
    d = acc_dict["n"]
    train_acc = 100 * acc_dict["correct"] / acc_dict["total"]
    writer.add_scalar("epoch/train_loss", acc_dict["loss"] / d, epoch + 1)
    writer.add_scalar("epoch/train_acc", train_acc, epoch + 1)
    writer.add_scalar("epoch/acc_a", 100 * acc_dict["acc_a"] / d, epoch + 1)
    writer.add_scalar("epoch/acc_b", 100 * acc_dict["acc_b"] / d, epoch + 1)
    writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/nce_b_acc", acc_dict["nce_b_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/geo_nce_acc", acc_dict["geo_nce_acc"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_main", acc_dict["cv_main"] / d, epoch + 1)
    writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
    writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
    writer.add_scalar("epoch/gal_updates", gal_update_count, epoch + 1)
    # ── Validation ──
    model.eval()
    val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
    val_geo_correct = 0
    val_b_correct = 0
    all_embs = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for images, labels_v in val_loader:
            images = images.to(DEVICE, non_blocking=True)
            labels_v = labels_v.to(DEVICE, non_blocking=True)
            out = model(images, apply_autograd=False)
            # Accuracy of all three heads: stream A, stream B, geometric.
            preds = out['logits_a'].argmax(dim=-1)
            val_correct += (preds == labels_v).sum().item()
            val_b_correct += (out['logits_b'].argmax(-1) == labels_v).sum().item()
            val_geo_correct += (out['geo_logits'].argmax(-1) == labels_v).sum().item()
            val_total += labels_v.shape[0]
            # Plain CE on logits_a only — comparable across epochs, unlike the
            # composite training loss.
            loss_v = F.cross_entropy(out['logits_a'], labels_v)
            val_loss_sum += loss_v.item()
            val_n += 1
            all_embs.append(out['embedding'].float().cpu())
    val_acc = 100 * val_correct / val_total
    val_b_acc = 100 * val_b_correct / val_total
    val_geo_acc = 100 * val_geo_correct / val_total
    val_loss = val_loss_sum / max(val_n, 1)
    # ── Val embedding diagnostics ──
    # Monte-Carlo estimate of the spread (CV) of 4-simplex volumes over random
    # 5-point subsets of the first 2000 val embeddings, via the Cayley–Menger
    # determinant: V² = -det(CM) / (2⁴·(4!)²) = -det(CM) / 9216.
    embs = torch.cat(all_embs)
    with torch.no_grad():
        sample = embs[:2000].to(DEVICE)
        vols = []
        for _ in range(200):
            idx = torch.randperm(2000)[:5]
            pts = sample[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            # Squared pairwise distances recovered from the Gram matrix.
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)  # clamp tiny negatives from floating-point round-off
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:  # keep only numerically valid volumes
                vols.append(v2[0].sqrt())
        # Coefficient of variation of sampled volumes (0 if too few valid).
        v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0.0
    with torch.no_grad():
        # Count how many distinct constellation anchors the val set lands on.
        _, v_np = model.constellation.triangulate(
            embs[:2000].to(DEVICE), training=False)
        n_active = v_np.cpu().unique().numel()
    writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
    writer.add_scalar("epoch/val_b_acc", val_b_acc, epoch + 1)
    writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
    writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)
    # Let the mastery queue resize itself based on the train/val gap.
    mastery.update_size(train_acc, val_acc, epoch + 1)
    # ── Checkpoint ──
    mk = ""
    # Best-so-far checkpoint (by val_acc on logits_a).
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "val_b_acc": val_b_acc,
            "val_geo_acc": val_geo_acc,
            "mastery": mastery.state_dict(),
            "gal_updates": gal_update_count,
        }, "checkpoints/tri_stream_v8_best.pt")
        mk = " β˜…"
    # Periodic checkpoint every 10 epochs (includes optimizer for resuming).
    if (epoch + 1) % 10 == 0:
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
            "optimizer": optimizer.state_dict(),
        }, f"checkpoints/tri_stream_v8_e{epoch+1:03d}.pt")
    # ── Epoch print — v8: shows B acc + nce_b + geo_nce ──
    ga = 100 * acc_dict["geo_acc"] / d
    ab = 100 * acc_dict["acc_b"] / d
    nb_acc = acc_dict["nce_b_acc"] / d
    gn_acc = acc_dict["geo_nce_acc"] / d
    cvf = acc_dict["cv_main"] / d
    cvg = acc_dict["cv_geo"] / d
    cmv = acc_dict["cm_valid"] / d
    stage = "MASTERY" if mastery.active else "stage1"
    # Gate check: best-effort probe of gate values on a tiny val batch.
    last_gates = []
    try:
        model.eval()
        with torch.no_grad():
            sample_imgs = next(iter(val_loader))[0][:4].to(DEVICE)
            sample_out = model(sample_imgs, apply_autograd=False)
            last_gates = sample_out.get('gates_a', [])
    except Exception:  # FIX: was bare `except:` (also caught KeyboardInterrupt)
        pass  # diagnostic only; gate_str falls back to "g=?" below
    gate_str = f"g={np.mean(last_gates):.4f}" if last_gates else "g=?"
    print(f" E{epoch+1:3d}: A={train_acc:.1f}% B={ab:.0f}% "
          f"val={val_acc:.1f}%/{val_b_acc:.1f}%/{val_geo_acc:.1f}% "
          f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
          f"nb={nb_acc:.2f} gn={gn_acc:.2f} "
          f"cv={v_cv:.4f}(m={cvf:.5f} g={cvg:.5f}) "
          f"cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
          f"[{stage}] {gate_str} "
          f"gal={gal_update_count} ({elapsed:.0f}s){mk}")
# Flush and close TensorBoard logs, then print the run summary.
writer.close()
summary = [
    f"\n Best val accuracy: {best_acc:.1f}%",
    "\n" + "=" * 60,
    "DONE",
    "=" * 60,
]
for line in summary:
    print(line)