# Source: geolip-spectral-vit / trainer.py (author: AbstractPhil), commit 3e3ca72 "Update trainer.py" (verified).
"""
SpectralViT CIFAR-100 Trainer
==============================
Pure SpectralCell transformer. No conv backbone.
Cayley hypersphere positional encoding.
Same proven training recipe:
- Soft hand: σ=0.15 boost=1.0 cv_penalty=0.01
- CV from cm_vol2 in compute graph + EMA drift tracking
- CutMix α=1.0 prob=0.5
- Grad clip 0.5 cross-attn only
- Adam (no weight decay)
SpectralCell, SpectralViT, cv_of, and batched_svd must already be defined in the namespace (from prior notebook cells).
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from tqdm import tqdm
import torchvision
import torchvision.transforms as T
import numpy as np
# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# ── CutMix ───────────────────────────────────────────────────────
def cutmix_data(x, y, alpha=1.0):
    """CutMix a batch: paste a random rectangle from a shuffled copy of the batch.

    Args:
        x: image batch of shape (B, C, H, W).
        y: labels of the un-mixed images, indexed along the batch dimension.
        alpha: Beta(alpha, alpha) parameter used to draw the nominal mix ratio.

    Returns:
        (mixed, y_b, lam): ``mixed`` is a new tensor (``x`` is not modified),
        ``y_b`` are the labels of the images the patch came from, and ``lam`` is
        the fraction of each image that still belongs to the original sample,
        recomputed from the clipped rectangle's actual area.
    """
    batch = x.shape[0]
    mix_ratio = np.random.beta(alpha, alpha)
    perm = torch.randperm(batch, device=x.device)
    _, _, height, width = x.shape

    # Side length of the pasted box scales with sqrt(1 - mix_ratio).
    side = np.sqrt(1.0 - mix_ratio)
    box_h = int(height * side)
    box_w = int(width * side)
    center_y = np.random.randint(height)
    center_x = np.random.randint(width)

    half_h = box_h // 2
    half_w = box_w // 2
    top = np.clip(center_y - half_h, 0, height)
    bottom = np.clip(center_y + half_h, 0, height)
    left = np.clip(center_x - half_w, 0, width)
    right = np.clip(center_x + half_w, 0, width)

    mixed = x.clone()
    mixed[:, :, top:bottom, left:right] = x[perm, :, top:bottom, left:right]

    # Clipping at the borders can shrink the box, so derive lam from the real area.
    kept = 1.0 - ((bottom - top) * (right - left) / (height * width))
    return mixed, y[perm], kept
def cutmix_criterion(criterion, logits, y_a, y_b, lam):
    """Area-weighted CutMix loss: ``lam`` on the original labels, ``1 - lam`` on the pasted ones."""
    loss_original = criterion(logits, y_a)
    loss_pasted = criterion(logits, y_b)
    return lam * loss_original + (1.0 - lam) * loss_pasted
# ── CV from cm_vol2 ─────────────────────────────────────────────
@torch.no_grad()
def compute_target_cv(V, D, n_trials=20, n_samples=200, device='cuda'):
    """Estimate a reference CV for ``V`` random points on the unit sphere S^(D-1).

    Draws ``n_trials`` independent clouds of ``V`` Gaussian vectors in R^D,
    projects each onto the unit sphere, and scores it with ``cv_of``. ``cv_of`` is
    defined outside this file (a prior notebook cell) and is assumed to return a
    scalar. Only positive, finite scores are averaged.

    Args:
        V: number of points per cloud.
        D: ambient dimension of each point.
        n_trials: number of random clouds to average over.
        n_samples: forwarded unchanged to ``cv_of``.
        device: device on which the random points are generated.

    Returns:
        The mean CV as a plain Python float, or 0.0 if no trial produced a usable value.
    """
    cvs = []
    for _ in range(n_trials):
        pts = F.normalize(torch.randn(V, D, device=device), dim=-1)
        # cv_of may return a 0-d (possibly GPU) tensor. Coerce it so the estimate,
        # and every downstream use in the training loop (EMA, math.exp), is a host float.
        cv = float(cv_of(pts, n_samples=n_samples))
        # NaN fails `> 0` anyway; also drop +inf so one bad trial cannot poison the mean.
        if math.isfinite(cv) and cv > 0:
            cvs.append(cv)
    return sum(cvs) / len(cvs) if cvs else 0.0
def batch_cv_from_cm(cm_vol2):
    """Coefficient of variation of volumes recovered from squared volumes ``cm_vol2``.

    Entries whose magnitude is at most 1e-16 are treated as degenerate and dropped.
    The surviving |vol²| values are square-rooted, and the CV is
    ``std / (mean + 1e-8)``. The result stays in the autograd graph, so it can be
    used directly as a loss term.

    Args:
        cm_vol2: tensor of squared volumes; all entries are pooled regardless of shape.

    Returns:
        (cv, ok): ``cv`` is a 0-d tensor. ``ok`` is False, and ``cv`` is 0, when
        fewer than 10 entries survive the degeneracy filter.
    """
    magnitude = cm_vol2.abs()
    keep = magnitude > 1e-16
    if keep.sum() < 10:
        return torch.tensor(0.0, device=cm_vol2.device), False
    vol = magnitude[keep].sqrt()
    return vol.std() / (vol.mean() + 1e-8), True
# ── Data ─────────────────────────────────────────────────────────
N_CLASSES = 100

# Hand-picked per-channel normalisation constants, identical for R, G and B.
mean = (0.5,) * 3
std = (0.2,) * 3

# Train-time augmentation: PIL-level geometric/photometric jitter first, then
# tensor conversion, normalisation, and tensor-level random erasing.
train_transform = T.Compose(
    [
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        T.RandomRotation(15),
        T.ToTensor(),
        T.Normalize(mean, std),
        T.RandomErasing(p=0.25, scale=(0.02, 0.2)),
    ]
)
# Evaluation: no augmentation, only the same normalisation.
test_transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
)

# Downloads CIFAR-100 into ./data on first use.
cifar_train = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
cifar_test = torchvision.datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

train_loader = torch.utils.data.DataLoader(
    cifar_train,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,  # every training batch has exactly 256 samples
)
test_loader = torch.utils.data.DataLoader(
    cifar_test,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
# ── Config ───────────────────────────────────────────────────────
# Schedule / optimisation
EPOCHS = 200
LR = 0.001
LABEL_SMOOTH = 0.1

# Soft-hand CV steering: Gaussian proximity width, CE boost, and CV penalty weight
SIGMA = 0.15
BOOST = 1.0
CV_PENALTY_W = 0.01
EMA_MOMENTUM = 0.99  # EMA of the per-batch CV

# CutMix
CUTMIX_ALPHA = 1.0
CUTMIX_PROB = 0.5
# ── Build ────────────────────────────────────────────────────────
# Cell geometry shared by the model and the target-CV estimate below, so the two
# cannot drift apart: V points in R^D, i.e. on the sphere S^(D-1).
CELL_V, CELL_D = 16, 16

model = SpectralViT(
    img_size=32,
    patch_size=4,
    in_channels=3,
    embed_dim=256,
    depth=6,
    cell_V=CELL_V,
    cell_D=CELL_D,
    cell_hidden=256,
    cell_depth=2,
    n_cross=2,
    n_heads=4,
    n_classes=N_CLASSES,
    dropout=0.1,
).to(device)

# clip_grad_norm_ receives this object on every training step. Materialise it so a
# generator return value is not exhausted after the first call, which would silently
# disable clipping. A lone tensor is wrapped rather than iterated row by row.
_cross_attn = model.get_cross_attn_params()
cross_attn_params = [_cross_attn] if isinstance(_cross_attn, torch.Tensor) else list(_cross_attn)
del _cross_attn

print(f"Computing target CV for V={CELL_V}, D={CELL_D} on S^{CELL_D - 1}...")
target_cv = compute_target_cv(CELL_V, CELL_D, n_trials=20, device=device)

print(f"\n{'═' * 70}")
print(" SpectralViT — Pure SpectralCell Transformer")
print(f"{'═' * 70}")
model.summary()
print(f" Soft hand: target_cv={target_cv:.4f} σ={SIGMA} boost={BOOST}")
print(f" CV penalty: {CV_PENALTY_W} (differentiable through cm_vol2)")
print(f" EMA momentum: {EMA_MOMENTUM}")
print(" Grad clip: 0.5 cross-attn only, uncapped otherwise")
print(f" CutMix: α={CUTMIX_ALPHA} prob={CUTMIX_PROB}")
print(f" Optimizer: Adam lr={LR}")
print(f" CIFAR-100, {EPOCHS} epochs")

criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
# Plain Adam over all parameters (no weight decay).
opt = torch.optim.Adam(model.parameters(), lr=LR)
def cosine_lr(epoch, *, total_epochs=None, base_lr=None, min_lr=1e-5):
    """LambdaLR multiplier for a cosine decay from ``base_lr`` down to ``min_lr``.

    The multiplier is 1.0 at epoch 0 and exactly ``min_lr / base_lr`` at epoch
    ``total_epochs - 1``, the last epoch that is trained.

    Args:
        epoch: scheduler epoch index (LambdaLR passes 0, 1, 2, ...).
        total_epochs: schedule length; defaults to the module-level ``EPOCHS``.
        base_lr: the optimizer's base LR; defaults to the module-level ``LR``.
        min_lr: LR floor reached on the final epoch.

    Returns:
        Float multiplier that LambdaLR applies to the base LR.
    """
    if total_epochs is None:
        total_epochs = EPOCHS
    if base_lr is None:
        base_lr = LR
    # max(1, ...) keeps a 1-epoch schedule from dividing by zero.
    t = epoch / float(max(1, total_epochs - 1))
    min_ratio = min_lr / base_lr
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * t))
# Per-epoch schedule: sched.step() is called once per epoch, and cosine_lr maps the
# epoch index to a multiplier on the optimizer's base LR.
sched = torch.optim.lr_scheduler.LambdaLR(opt, cosine_lr)
# ── Component Profiler ────────────────────────────────────────────
@torch.no_grad()
def profile_cell_internals(cell, tokens, device):
    """Time each stage of a single SpectralCell's pipeline.

    The stages are re-implemented here one by one, using the cell's submodules and
    buffers and the external ``batched_svd`` helper, which must be in scope.
    NOTE(review): this mirrors the cell's format() internals by hand and must be
    kept in sync with SpectralCell if that changes.

    CUDA is synchronised around each stage only when ``device`` is a CUDA device,
    so the profiler also runs on the script's CPU fallback.

    Args:
        cell: a SpectralCell-like module.
        tokens: input of shape (B, N, dim), unpacked as such below.
        device: device used for the CM index tensor and for the sync decision.

    Returns:
        dict mapping stage name -> wall time in milliseconds. ``cm_validation`` is
        0.0 when the cell's CM check is disabled.
    """
    from contextlib import contextmanager

    on_cuda = str(device).startswith('cuda')
    timings = {}

    @contextmanager
    def timed(key):
        # Sync on both sides so asynchronous kernels are charged to this stage.
        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        yield
        if on_cuda:
            torch.cuda.synchronize()
        timings[key] = (time.perf_counter() - start) * 1000

    B, N, _ = tokens.shape

    # Encoder MLP: each token -> a (V, D) matrix.
    flat = tokens.reshape(B * N, -1)
    with timed('enc_mlp'):
        h = F.gelu(cell.enc_in(flat))
        for block in cell.enc_blocks:
            h = h + block(h)
        M = cell.enc_out(h).reshape(B * N, cell.V, cell.D)

    # Capture row magnitudes, then project rows onto the unit sphere.
    with timed('normalize'):
        row_mag = M.norm(dim=-1)
        M = F.normalize(M, dim=-1)

    # Cayley-Menger validation on evenly spaced rows (only if the cell enables it).
    if cell.cm_enabled and cell.cm is not None:
        nv = cell._cm_k + 1
        cm_idx = torch.linspace(0, cell.V - 1, nv).long().to(device)
        with timed('cm_validation'):
            cm_verts = M[:, cm_idx, :]
            _cm_d2, _cm_vol2 = cell.cm(cm_verts)
    else:
        timings['cm_validation'] = 0.0

    # Full pairwise squared distances between unit rows: |a-b|² = 2 - 2<a,b>.
    with timed('pairwise_d2'):
        gram = torch.bmm(M, M.transpose(1, 2))
        d2_full = 2.0 - 2.0 * gram
        d2_pairs = d2_full[:, cell._triu_i, cell._triu_j]

    with timed('patchwork'):
        pw_features = cell.patchwork(d2_pairs)

    with timed('svd_eigh'):
        U, S, Vt = batched_svd(M)

    # Cross-attention over the N tokens' singular-value vectors.
    S = S.reshape(B, N, cell.D)
    with timed('cross_attn'):
        for layer in cell.cross_attn:
            S = layer(S)

    # Recompose M_hat = U diag(S) Vt with the attended singular values.
    U_flat = U.reshape(B * N, cell.V, cell.D)
    S_flat = S.reshape(B * N, cell.D)
    Vt_flat = Vt.reshape(B * N, cell.D, cell.D)
    with timed('recompose'):
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)

    # Output MLP (its result is only timed, not returned).
    out_features = torch.cat([M_hat.reshape(B * N, -1), pw_features, row_mag], dim=-1)
    with timed('dec_mlp'):
        h = F.gelu(cell.out_in(out_features))
        for block in cell.out_blocks:
            h = h + block(h)
        cell.out_proj(h)

    return timings
@torch.no_grad()
def profile_forward(model, images, device):
    """Time each component of a SpectralViT forward pass.

    Stages, in order:
    - patch embedding;
    - positional encoding;
    - every pre-norm residual cell except the last;
    - the last cell via its ``format()`` call;
    - final norm + mean-pool + classifier.

    The model is switched to eval mode for the measurement, and its previous
    train/eval mode is restored afterwards, even on error. CUDA is synchronised
    around each stage only when ``device`` is a CUDA device, so this also runs on CPU.

    Args:
        model: module exposing patch_embed, pos_enc, norms, cells, depth,
            final_norm and classifier; ``cells[-1]`` must provide ``format()``,
            which returns a dict with an ``'output'`` entry.
        images: input batch; it is moved to ``device``.
        device: device to run on.

    Returns:
        dict mapping component name -> wall time in milliseconds.
    """
    from contextlib import contextmanager

    on_cuda = str(device).startswith('cuda')
    timings = {}

    @contextmanager
    def timed(key):
        # Sync on both sides so asynchronous kernels are charged to this stage.
        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        yield
        if on_cuda:
            torch.cuda.synchronize()
        timings[key] = (time.perf_counter() - start) * 1000

    was_training = model.training
    model.eval()
    try:
        images = images.to(device)

        with timed('patch_embed'):
            tokens = model.patch_embed(images)

        with timed('cayley_pe'):
            tokens = model.pos_enc(tokens)

        # All but the last cell: plain pre-norm residual forward.
        for i in range(model.depth - 1):
            with timed(f'cell_{i}'):
                tokens = tokens + model.cells[i](model.norms[i](tokens))

        # The last cell goes through .format(), which returns a dict.
        with timed(f'cell_{model.depth - 1}_format'):
            last_out = model.cells[-1].format(model.norms[-1](tokens))
            tokens = tokens + last_out['output']

        with timed('classify'):
            pooled = model.final_norm(tokens).mean(dim=1)
            model.classifier(pooled)
    finally:
        # Do not leak eval mode to the caller.
        model.train(was_training)

    return timings
def profile_full_step(model, images, labels, criterion, opt, cross_attn_params, device, *, restore=True):
    """Time one plain training step: forward, loss, backward, grad clip, optimizer step.

    The step is bare cross-entropy: no CutMix and no CV weighting. With
    ``restore=True`` (the default), the step leaves no trace:
    - model and optimizer ``state_dict``s are deep-copied before the step;
    - both are loaded back afterwards, even if a stage raises;
    - gradients are cleared.

    This keeps profiling from injecting an extra, off-recipe update (and an extra
    Adam step count) into training. It costs a temporary copy of the parameters
    and optimizer state. Pass ``restore=False`` to keep the update.

    CUDA is synchronised around each stage only when ``device`` is a CUDA device.

    Args:
        model: module whose forward returns a dict with ``'logits'``.
        images, labels: one batch; both are moved to ``device``.
        criterion: loss function ``criterion(logits, labels)``.
        opt: optimizer over the model's parameters.
        cross_attn_params: parameters whose grad norm is clipped to 0.5.
        device: device to run on.
        restore: keyword-only; undo the step afterwards (see above).

    Returns:
        dict mapping stage name -> wall time in milliseconds.
    """
    import copy
    from contextlib import contextmanager

    on_cuda = str(device).startswith('cuda')
    timings = {}

    @contextmanager
    def timed(key):
        # Sync on both sides so asynchronous kernels are charged to this stage.
        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        yield
        if on_cuda:
            torch.cuda.synchronize()
        timings[key] = (time.perf_counter() - start) * 1000

    model.train()
    images, labels = images.to(device), labels.to(device)

    if restore:
        model_snapshot = copy.deepcopy(model.state_dict())
        opt_snapshot = copy.deepcopy(opt.state_dict())

    try:
        with timed('forward'):
            out = model(images)
        with timed('loss'):
            ce_loss = criterion(out['logits'], labels)
        with timed('backward'):
            opt.zero_grad(set_to_none=True)
            ce_loss.backward()
        with timed('grad_clip'):
            nn.utils.clip_grad_norm_(cross_attn_params, max_norm=0.5)
        with timed('optim_step'):
            opt.step()
    finally:
        if restore:
            model.load_state_dict(model_snapshot)
            opt.load_state_dict(opt_snapshot)
            opt.zero_grad(set_to_none=True)

    return timings
def print_profile(timings, label=""):
    """Print a boxed table of ``timings`` (name -> ms), slowest stage first.

    Each row shows the absolute time, the stage's share of the total, and a bar
    with one block per 3 percentage points. An empty or all-zero profile prints
    0.0% shares instead of raising ZeroDivisionError.
    """
    total = sum(timings.values())
    # 44 makes the header's inner width 55, matching the footer rule below.
    print(f"\n ┌─ PROFILE {label} {'─' * max(0, 44 - len(label))}┐")
    for name, ms in sorted(timings.items(), key=lambda kv: -kv[1]):
        pct = ms / total * 100 if total > 0 else 0.0
        bar = '█' * int(pct / 3)
        print(f" │ {name:<22s} {ms:7.1f}ms {pct:5.1f}% {bar}")
    print(f" │ {'TOTAL':<22s} {total:7.1f}ms")
    print(f" └{'─' * 55}┘")
# ── Training ─────────────────────────────────────────────────────
# Run initial profile before training
print("\n Initial profiling (3 warmup + 1 measured)...")
sample_batch = next(iter(train_loader))

# Warm-up passes so first-call overheads do not pollute the measured profile.
# No autograd graph is needed for these, so skip building one.
with torch.no_grad():
    for _ in range(3):
        model(sample_batch[0].to(device))
# Only synchronise when actually on CUDA; the script falls back to CPU otherwise.
if device.type == 'cuda':
    torch.cuda.synchronize()

fwd_profile = profile_forward(model, sample_batch[0], device)
print_profile(fwd_profile, "FORWARD COMPONENTS")

step_profile = profile_full_step(model, sample_batch[0], sample_batch[1],
                                 criterion, opt, cross_attn_params, device)
print_profile(step_profile, "FULL TRAIN STEP")

# Cell internals for the first cell, fed with the embedded + position-encoded batch.
with torch.no_grad():
    tokens = model.pos_enc(model.patch_embed(sample_batch[0].to(device)))
cell_profile = profile_cell_internals(model.cells[0], tokens, device)
print_profile(cell_profile, "CELL INTERNALS (cell_0)")
best_acc = 0
t0 = time.time()
ema_cv = target_cv

for epoch in range(1, EPOCHS + 1):
    # ── Train ──
    model.train()
    correct, total = 0, 0
    ep_boost_sum, ep_n = 0, 0
    ep_cv_penalty_sum = 0

    for images, labels in tqdm(train_loader, desc=f"Ep {epoch:3d}", leave=False):
        images, labels = images.to(device), labels.to(device)

        use_cutmix = CUTMIX_ALPHA > 0 and np.random.random() < CUTMIX_PROB
        if use_cutmix:
            images, labels_b, lam = cutmix_data(images, labels, CUTMIX_ALPHA)
        else:
            labels_b, lam = labels, 1.0

        out = model(images)
        if use_cutmix:
            ce_loss = cutmix_criterion(criterion, out['logits'], labels, labels_b, lam)
        else:
            ce_loss = criterion(out['logits'], labels)

        # CV from the deepest cell's cm_vol2 (kept in the graph for the penalty term).
        last_cell = out['last_cell']
        batch_cv, cv_valid = batch_cv_from_cm(last_cell['cm_vol2'])
        if cv_valid:
            with torch.no_grad():
                ema_cv = EMA_MOMENTUM * ema_cv + (1.0 - EMA_MOMENTUM) * batch_cv.item()

        # Soft hand: Gaussian proximity of the EMA CV to the target. Near the target,
        # the CE term is boosted; away from it, the CV penalty takes over.
        prox = math.exp(-(ema_cv - target_cv) ** 2 / (2 * SIGMA ** 2))
        primary_w = 1.0 + BOOST * prox
        cv_w = CV_PENALTY_W * (1.0 - prox)
        diff_cv_loss = (batch_cv - target_cv).pow(2) if cv_valid else torch.tensor(0.0, device=device)
        loss = primary_w * ce_loss + cv_w * diff_cv_loss

        opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(cross_attn_params, max_norm=0.5)  # cross-attn only
        opt.step()

        # Train accuracy is scored against the primary labels (approximate under CutMix).
        correct += (out['logits'].argmax(-1) == labels).sum().item()
        total += images.shape[0]
        ep_boost_sum += primary_w
        if cv_valid:
            ep_cv_penalty_sum += cv_w * diff_cv_loss.item()
        ep_n += 1

    # Record the LR actually used this epoch *before* advancing the schedule;
    # reading it afterwards would log the next epoch's (or a past-the-end) value.
    lr = opt.param_groups[0]['lr']
    sched.step()
    train_acc = correct / total
    avg_boost = ep_boost_sum / ep_n
    avg_cv_penalty = ep_cv_penalty_sum / ep_n

    # ── Validation ──
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            out = model(images)
            val_correct += (out['logits'].argmax(-1) == labels).sum().item()
            val_total += images.shape[0]
    val_acc = val_correct / val_total

    improved = val_acc > best_acc
    star = ' ★' if improved else ''
    if improved:
        best_acc = val_acc

    if epoch <= 5 or epoch % 5 == 0 or epoch == EPOCHS:
        with torch.no_grad():
            # Mean of the deepest cell's 'S_orig' over the last validation batch
            # (presumably its singular values before cross-attention).
            last_cell = out['last_cell']
            S = last_cell['S_orig'].mean(dim=(0, 1))
            s_str = ', '.join(f'{v:.2f}' for v in S.tolist()[:4]) + f'...{S[-1]:.2f}'
            # Rotation angle magnitudes from the positional encoding.
            angles = model.pos_enc.angles
            angle_mean = angles.abs().mean().item()
            angle_max = angles.abs().max().item()
        print(f" ep{epoch:3d} acc={val_acc:.1%}{star} train={train_acc:.1%} "
              f"ema_cv={ema_cv:.4f} boost={avg_boost:.3f} cv_pen={avg_cv_penalty:.5f} lr={lr:.6f}")
        print(f" S=[{s_str}] PE angles: mean={angle_mean:.4f} max={angle_max:.4f}")

    # ── Per-class accuracy ──
    if epoch <= 3 or epoch % 20 == 0 or epoch == EPOCHS:
        pcc = torch.zeros(N_CLASSES)  # correct predictions per class
        pct = torch.zeros(N_CLASSES)  # samples per class
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                out = model(images)
                preds = out['logits'].argmax(-1)
                # One bincount per tally instead of a 100-way loop with per-class .item() syncs.
                pct += torch.bincount(labels, minlength=N_CLASSES).float().cpu()
                pcc += torch.bincount(labels[preds == labels], minlength=N_CLASSES).float().cpu()
        pca = pcc / (pct + 1e-8)
        sorted_idx = pca.argsort(descending=True)
        print(" Top 5: ", end='')
        for c in sorted_idx[:5].tolist():
            print(f"c{c}={pca[c]:.0%} ", end='')
        print("\n Bot 5: ", end='')
        for c in sorted_idx.flip(0)[:5].tolist():
            print(f"c{c}={pca[c]:.0%} ", end='')
        print(f"\n Mean: {pca.mean():.1%} Std: {pca.std():.1%}")

    # ── Periodic profiling ──
    if epoch == 1 or epoch % 50 == 0:
        sb = next(iter(train_loader))
        fwd_p = profile_forward(model, sb[0], device)
        step_p = profile_full_step(model, sb[0], sb[1], criterion, opt, cross_attn_params, device)
        print_profile(fwd_p, f"FORWARD ep{epoch}")
        print_profile(step_p, f"STEP ep{epoch}")
        with torch.no_grad():
            toks = model.pos_enc(model.patch_embed(sb[0].to(device)))
        cell_p = profile_cell_internals(model.cells[-1], toks, device)
        print_profile(cell_p, f"CELL INTERNALS ep{epoch}")
elapsed = time.time() - t0

print(f"\n{'═' * 70}")
print(" COMPLETE — SpectralViT")
print(f"{'═' * 70}")
print(f" Best val acc: {best_acc:.1%}")
print(f" Final EMA CV: {ema_cv:.4f} (target: {target_cv:.4f})")
print(f" Time: {elapsed:.0f}s ({elapsed/EPOCHS:.1f}s/epoch)")

n_params = sum(p.numel() for p in model.parameters())
print(f" Total params: {n_params:,}")
# Read the cell count from the model instead of hard-coding it, so the summary
# stays correct if `depth` changes at construction time.
print(f" Architecture: PatchEmbed(4×4) → CayleyPE → {model.depth}× SpectralCell → pool → classify")
print(" No conv. No external attention. Pure geometric formatting.")