"""
SpectralViT CIFAR-100 Trainer
==============================

Pure SpectralCell transformer on CIFAR-100: no conv backbone, Cayley
hypersphere positional encoding.

Training recipe:
  - Soft hand: σ=0.15 boost=1.0 cv_penalty=0.01
  - CV from cm_vol2 in compute graph + EMA drift tracking
  - CutMix α=1.0 prob=0.5
  - Grad clip 0.5 on cross-attn params only
  - Adam (no weight decay)

Expects SpectralCell, SpectralViT and cv_of (plus batched_svd, used by the
profiler) to already be defined by prior notebook cells.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from tqdm import tqdm
import torchvision
import torchvision.transforms as T
import numpy as np

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# ── CutMix ───────────────────────────────────────────────────────
def _clipped_span(center, size, limit):
    """Return the [start, stop) interval of width ~``size`` around ``center``, clipped to [0, limit]."""
    half = size // 2
    return np.clip(center - half, 0, limit), np.clip(center + half, 0, limit)


def cutmix_data(x, y, alpha=1.0):
    """Paste a random rectangle from a shuffled copy of the batch into each image.

    Args:
        x: image batch, unpacked as (batch, channels, height, width).
        y: labels aligned with ``x``.
        alpha: concentration of the Beta(alpha, alpha) draw for the mix ratio.

    Returns:
        ``(mixed, y_pasted, lam)``. ``mixed`` is a copy of ``x`` with the box
        replaced. ``y_pasted`` holds the labels of the donor samples. ``lam``
        is the fraction of pixels still owned by the original sample,
        recomputed from the box after clipping.
    """
    batch_size = x.shape[0]
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(batch_size, device=x.device)
    _, _, height, width = x.shape

    # A box covering a (1 - lam) fraction of the area has sides scaled by sqrt(1 - lam).
    side_frac = np.sqrt(1.0 - lam)
    box_h = int(height * side_frac)
    box_w = int(width * side_frac)
    center_y = np.random.randint(height)
    center_x = np.random.randint(width)

    top, bottom = _clipped_span(center_y, box_h, height)
    left, right = _clipped_span(center_x, box_w, width)

    mixed = x.clone()
    mixed[:, :, top:bottom, left:right] = x[perm, :, top:bottom, left:right]
    pasted_frac = (bottom - top) * (right - left) / (height * width)
    return mixed, y[perm], 1.0 - pasted_frac


def cutmix_criterion(criterion, logits, y_a, y_b, lam):
    """Blend ``criterion`` over both label sets, weighted ``lam`` / ``1 - lam``."""
    loss_a = criterion(logits, y_a)
    loss_b = criterion(logits, y_b)
    return lam * loss_a + (1.0 - lam) * loss_b


# ── CV from cm_vol2 ─────────────────────────────────────────────
@torch.no_grad()
def compute_target_cv(V, D, n_trials=20, n_samples=200, device='cuda'):
    """Estimate the reference CV for ``V`` random unit vectors in R^D.

    Each trial draws Gaussian vectors and L2-normalizes them, then scores them
    with ``cv_of`` (defined in a prior cell). Trials whose CV is not positive
    are discarded.

    Returns:
        The mean of the kept CVs, or 0.0 if every trial was discarded.
    """
    positive = []
    for _trial in range(n_trials):
        cloud = F.normalize(torch.randn(V, D, device=device), dim=-1)
        value = cv_of(cloud, n_samples=n_samples)
        if value > 0:
            positive.append(value)
    if not positive:
        return 0.0
    return sum(positive) / len(positive)


def batch_cv_from_cm(cm_vol2):
    """Coefficient of variation (std / mean) of sqrt(|cm_vol2|), kept differentiable.

    Only entries with ``|cm_vol2| > 1e-16`` participate.

    Returns:
        ``(cv, True)``, or ``(tensor(0.0), False)`` when fewer than 10
        entries qualify.

    NOTE(review): d/dx sqrt(x) is unbounded near 0, so entries just above the
    1e-16 threshold can contribute very large gradients.
    """
    mask = cm_vol2.abs() > 1e-16
    if mask.sum() < 10:
        return torch.tensor(0.0, device=cm_vol2.device), False
    volumes = cm_vol2[mask].abs().sqrt()
    return volumes.std() / (volumes.mean() + 1e-8), True


# ── Data ─────────────────────────────────────────────────────────
N_CLASSES = 100
# Normalization constants shared by the train and test pipelines.
mean = (0.5, 0.5, 0.5)
std = (0.2, 0.2, 0.2)

train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.RandomRotation(15),
    T.ToTensor(),
    T.Normalize(mean, std),
    T.RandomErasing(p=0.25, scale=(0.02, 0.2)),  # tensor-only transform, so it follows ToTensor
])
test_transform = T.Compose([T.ToTensor(), T.Normalize(mean, std)])

cifar_train = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=train_transform)
cifar_test = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=test_transform)

train_loader = torch.utils.data.DataLoader(
    cifar_train,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    cifar_test,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# ── Config ───────────────────────────────────────────────────────
EPOCHS = 200
LR = 1e-3
LABEL_SMOOTH = 0.1
SIGMA = 0.15          # width of the Gaussian proximity kernel around target_cv
BOOST = 1.0           # max extra weight on CE when ema_cv sits on target
CV_PENALTY_W = 0.01   # max weight of the differentiable CV penalty
EMA_MOMENTUM = 0.99
CUTMIX_ALPHA = 1.0
CUTMIX_PROB = 0.5

# ── Build ────────────────────────────────────────────────────────
model = SpectralViT(
    img_size=32,
    patch_size=4,
    in_channels=3,
    embed_dim=256,
    depth=6,
    cell_V=16,
    cell_D=16,
    cell_hidden=256,
    cell_depth=2,
    n_cross=2,
    n_heads=4,
    n_classes=N_CLASSES,
    dropout=0.1,
).to(device)
cross_attn_params = model.get_cross_attn_params()

print(f"Computing target CV for V=16, D=16 on S^15...")
target_cv = compute_target_cv(16, 16, n_trials=20, device=device)

print(f"\n{'═' * 70}")
print(f" SpectralViT — Pure SpectralCell Transformer")
print(f"{'═' * 70}")
model.summary()
print(f" Soft hand: target_cv={target_cv:.4f} σ={SIGMA} boost={BOOST}")
print(f" CV penalty: {CV_PENALTY_W} (differentiable through cm_vol2)")
print(f" EMA momentum: {EMA_MOMENTUM}")
print(f" Grad clip: 0.5 cross-attn only, uncapped otherwise")
import contextlib  # used by the _timed profiling helper below
import copy        # used to snapshot/restore state around profile_full_step

print(f" CutMix: α={CUTMIX_ALPHA} prob={CUTMIX_PROB}")
print(f" Optimizer: Adam lr={LR}")
print(f" CIFAR-100, {EPOCHS} epochs")

criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
opt = torch.optim.Adam(model.parameters(), lr=LR)

# get_cross_attn_params() is defined outside this script. If it returns a
# one-shot iterator (as nn.Module.parameters() does), the first
# clip_grad_norm_ call would exhaust it and every later call would silently
# clip nothing. Materialize it once so every step clips the same params.
if not isinstance(cross_attn_params, torch.Tensor):
    cross_attn_params = list(cross_attn_params)


def cosine_lr(epoch):
    """LambdaLR multiplier: cosine from 1.0 down to 1e-5 / LR over EPOCHS - 1 steps.

    ``epoch`` is the number of ``sched.step()`` calls so far (0 at construction).
    """
    t = epoch / float(max(1, EPOCHS - 1))
    min_ratio = 1e-5 / LR
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * t))


sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=cosine_lr)


# ── Component Profiler ────────────────────────────────────────────
def _sync():
    """Wait for queued CUDA work; no-op without CUDA (the script may run on CPU)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


@contextlib.contextmanager
def _timed(timings, key):
    """Store the wall time of the enclosed region, in ms, as ``timings[key]``.

    Synchronizes before starting and before stopping the clock so that
    asynchronously launched GPU kernels are charged to the region that issued them.
    """
    _sync()
    start = time.perf_counter()
    yield
    _sync()
    timings[key] = (time.perf_counter() - start) * 1000


@torch.no_grad()
def profile_cell_internals(cell, tokens, device):
    """Profile each stage inside a single SpectralCell.format() call.

    Re-runs the cell's stages by hand and times each one:
    encoder MLP → row normalize → CM validation (if enabled) → pairwise d²
    → patchwork → batched SVD → cross-attention over singular values
    → recompose → output MLP.

    This reaches into the cell's internal attributes and uses ``batched_svd``
    from a prior cell, so it must be kept in sync with SpectralCell's
    implementation. Reshapes and concatenations between stages are
    deliberately left outside the timed regions.

    Args:
        cell: the SpectralCell to profile.
        tokens: token tensor, unpacked as (B, N, dim).
        device: device on which to build the CM vertex-index tensor.

    Returns:
        dict mapping stage name -> elapsed milliseconds.
    """
    B, N, _ = tokens.shape
    timings = {}

    flat = tokens.reshape(B * N, -1)
    with _timed(timings, 'enc_mlp'):
        h = F.gelu(cell.enc_in(flat))
        for block in cell.enc_blocks:
            h = h + block(h)
        M = cell.enc_out(h).reshape(B * N, cell.V, cell.D)

    # Row magnitudes are captured before normalizing; they feed the output MLP.
    with _timed(timings, 'normalize'):
        row_mag = M.norm(dim=-1)
        M = F.normalize(M, dim=-1)

    if cell.cm_enabled and cell.cm is not None:
        nv = cell._cm_k + 1
        cm_idx = torch.linspace(0, cell.V - 1, nv).long().to(device)
        with _timed(timings, 'cm_validation'):
            cm_verts = M[:, cm_idx, :]
            _cm_d2, _cm_vol2 = cell.cm(cm_verts)  # results unused; only cost is measured
    else:
        timings['cm_validation'] = 0.0

    with _timed(timings, 'pairwise_d2'):
        gram = torch.bmm(M, M.transpose(1, 2))
        d2_full = 2.0 - 2.0 * gram  # |a - b|² = 2 - 2 a·b for unit-norm rows
        d2_pairs = d2_full[:, cell._triu_i, cell._triu_j]

    with _timed(timings, 'patchwork'):
        pw_features = cell.patchwork(d2_pairs)

    with _timed(timings, 'svd_eigh'):
        U, S, Vt = batched_svd(M)

    S = S.reshape(B, N, cell.D)
    with _timed(timings, 'cross_attn'):
        for layer in cell.cross_attn:
            S = layer(S)

    U_flat = U.reshape(B * N, cell.V, cell.D)
    S_flat = S.reshape(B * N, cell.D)
    Vt_flat = Vt.reshape(B * N, cell.D, cell.D)
    with _timed(timings, 'recompose'):
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)

    out_features = torch.cat([M_hat.reshape(B * N, -1), pw_features, row_mag], dim=-1)
    with _timed(timings, 'dec_mlp'):
        h = F.gelu(cell.out_in(out_features))
        for block in cell.out_blocks:
            h = h + block(h)
        cell.out_proj(h)  # output discarded; only the cost is measured

    return timings


@torch.no_grad()
def profile_forward(model, images, device):
    """Profile each component of the SpectralViT forward pass.

    Mirrors the forward pass by hand: patch embed → Cayley PE →
    pre-norm residual cells → last cell via ``.format()`` → final norm
    → mean-pool → classifier. Runs in eval mode and restores the model's
    previous train/eval mode afterwards.

    Returns:
        dict of component -> time_ms.
    """
    was_training = model.training
    model.eval()
    try:
        images = images.to(device)
        _sync()
        timings = {}

        with _timed(timings, 'patch_embed'):
            tokens = model.patch_embed(images)

        with _timed(timings, 'cayley_pe'):
            tokens = model.pos_enc(tokens)

        for i in range(model.depth - 1):
            with _timed(timings, f'cell_{i}'):
                tokens = tokens + model.cells[i](model.norms[i](tokens))

        with _timed(timings, f'cell_{model.depth - 1}_format'):
            last_out = model.cells[-1].format(model.norms[-1](tokens))
            tokens = tokens + last_out['output']

        with _timed(timings, 'classify'):
            tokens = model.final_norm(tokens)
            model.classifier(tokens.mean(dim=1))

        return timings
    finally:
        model.train(was_training)


def profile_full_step(model, images, labels, criterion, opt, cross_attn_params,
                      device, restore_state=True):
    """Profile a full train step: forward + loss + backward + clip + optimizer.

    The step really calls ``opt.step()``. Without restoration, each profiling
    call would apply an extra plain-CE update (no CutMix, no soft-hand
    weighting) and advance Adam's moments in the middle of training. With
    ``restore_state=True`` (the default), the model and optimizer state dicts
    are snapshotted first and restored afterwards, so profiling does not
    perturb training. The model's train/eval mode is restored as well.

    Returns:
        dict of phase -> time_ms.
    """
    was_training = model.training
    snapshot = None
    if restore_state:
        snapshot = (copy.deepcopy(model.state_dict()), copy.deepcopy(opt.state_dict()))
    model.train()
    try:
        images, labels = images.to(device), labels.to(device)
        timings = {}

        with _timed(timings, 'forward'):
            out = model(images)

        with _timed(timings, 'loss'):
            ce_loss = criterion(out['logits'], labels)

        with _timed(timings, 'backward'):
            opt.zero_grad(set_to_none=True)
            ce_loss.backward()

        with _timed(timings, 'grad_clip'):
            nn.utils.clip_grad_norm_(cross_attn_params, max_norm=0.5)

        with _timed(timings, 'optim_step'):
            opt.step()

        return timings
    finally:
        opt.zero_grad(set_to_none=True)
        if snapshot is not None:
            model_state, opt_state = snapshot
            model.load_state_dict(model_state)
            opt.load_state_dict(opt_state)
        model.train(was_training)


def print_profile(timings, label=""):
    """Print ``timings`` (name -> ms) in a box, slowest first, with % bars."""
    total = sum(timings.values())
    print(f"\n ┌─ PROFILE {label} {'─' * max(0, 48 - len(label))}┐")
    for name, ms in sorted(timings.items(), key=lambda x: -x[1]):
        pct = ms / total * 100 if total > 0 else 0.0  # guard an all-zero profile
        bar = '█' * int(pct / 3)
        print(f" │ {name:<22s} {ms:7.1f}ms {pct:5.1f}% {bar}")
    print(f" │ {'TOTAL':<22s} {total:7.1f}ms")
    print(f" └{'─' * 55}┘")


# ── Training ─────────────────────────────────────────────────────
# Run initial profile before training
print("\n Initial profiling (3 warmup + 1 measured)...")
sample_batch = next(iter(train_loader))
with torch.no_grad():  # warmup only needs the kernels launched, not an autograd graph
    for _ in range(3):
        model(sample_batch[0].to(device))
_sync()

fwd_profile = profile_forward(model, sample_batch[0], device)
print_profile(fwd_profile, "FORWARD COMPONENTS")

step_profile = profile_full_step(model, sample_batch[0], sample_batch[1],
                                 criterion, opt, cross_attn_params, device)
print_profile(step_profile, "FULL TRAIN STEP")

# Cell internals for one cell (fed the post-PE tokens, un-normed; timing only)
with torch.no_grad():
    tokens = model.pos_enc(model.patch_embed(sample_batch[0].to(device)))
cell_profile = profile_cell_internals(model.cells[0], tokens, device)
print_profile(cell_profile, "CELL INTERNALS (cell_0)")

best_acc = 0
t0 = time.time()
ema_cv = target_cv

for epoch in range(1, EPOCHS + 1):
    model.train()
    correct, total = 0, 0
    ep_boost_sum, ep_n = 0, 0

    for images, labels in tqdm(train_loader, desc=f"Ep {epoch:3d}", leave=False):
        images, labels = images.to(device), labels.to(device)

        use_cutmix = CUTMIX_ALPHA > 0 and np.random.random() < CUTMIX_PROB
        if use_cutmix:
            images, labels_b, lam = cutmix_data(images, labels, CUTMIX_ALPHA)
        else:
            labels_b, lam = labels, 1.0

        out = model(images)
        if use_cutmix:
            ce_loss = cutmix_criterion(criterion, out['logits'], labels, labels_b, lam)
        else:
            ce_loss = criterion(out['logits'], labels)

        # CV from the deepest cell's cm_vol2 (stays in the autograd graph)
        last_cell = out['last_cell']
        batch_cv, cv_valid = batch_cv_from_cm(last_cell['cm_vol2'])
        if cv_valid:
            ema_cv = EMA_MOMENTUM * ema_cv + (1.0 - EMA_MOMENTUM) * batch_cv.item()

        # Soft hand: Gaussian proximity (width SIGMA) of the EMA CV to target.
        # On target, CE is weighted up to 1 + BOOST; away from it, the
        # differentiable CV penalty weight grows toward CV_PENALTY_W.
        prox = math.exp(-(ema_cv - target_cv) ** 2 / (2 * SIGMA ** 2))
        primary_w = 1.0 + BOOST * prox
        cv_w = CV_PENALTY_W * (1.0 - prox)

        diff_cv_loss = (batch_cv - target_cv).pow(2) if cv_valid else torch.tensor(0.0, device=device)
        loss = primary_w * ce_loss + cv_w * diff_cv_loss

        opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(cross_attn_params, max_norm=0.5)
        opt.step()

        # Under CutMix this scores only against the un-shuffled labels.
        correct += (out['logits'].argmax(-1) == labels).sum().item()
        total += images.shape[0]
        ep_boost_sum += primary_w
        ep_n += 1

    sched.step()
    train_acc = correct / total
    avg_boost = ep_boost_sum / ep_n

    # Val
    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            out = model(images)
            val_correct += (out['logits'].argmax(-1) == labels).sum().item()
            val_total += images.shape[0]
    val_acc = val_correct / val_total

    is_best = val_acc > best_acc
    star = ' ★' if is_best else ''
    if is_best:
        best_acc = val_acc
    lr = opt.param_groups[0]['lr']  # already advanced by sched.step() above

    if epoch <= 5 or epoch % 5 == 0 or epoch == EPOCHS:
        with torch.no_grad():
            # Singular values of the deepest cell, from the last validation batch.
            last_cell = out['last_cell']
            s_vals = last_cell['S_orig'].mean(dim=(0, 1)).tolist()
            s_str = ', '.join(f'{v:.2f}' for v in s_vals[:4]) + f'...{s_vals[-1]:.2f}'

            # Rotation angle magnitudes from PE
            angle_abs = model.pos_enc.angles.abs()
            angle_mean = angle_abs.mean().item()
            angle_max = angle_abs.max().item()

        print(f" ep{epoch:3d} acc={val_acc:.1%}{star} train={train_acc:.1%} "
              f"ema_cv={ema_cv:.4f} boost={avg_boost:.3f} lr={lr:.6f}")
        print(f" S=[{s_str}] PE angles: mean={angle_mean:.4f} max={angle_max:.4f}")

    if epoch <= 3 or epoch % 20 == 0 or epoch == EPOCHS:
        # Per-class accuracy, counted with bincount instead of a per-class loop.
        pcc = torch.zeros(N_CLASSES)
        pct = torch.zeros(N_CLASSES)
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images)['logits'].argmax(-1)
                pct += torch.bincount(labels, minlength=N_CLASSES).cpu()
                pcc += torch.bincount(labels[preds == labels], minlength=N_CLASSES).cpu()
        pca = pcc / (pct + 1e-8)

        order = pca.argsort(descending=True).tolist()
        top5 = order[:5]
        bottom5 = order[::-1][:5]  # worst class first
        print(f" Top 5: " + ''.join(f"c{c}={pca[c]:.0%} " for c in top5))
        print(f" Bot 5: " + ''.join(f"c{c}={pca[c]:.0%} " for c in bottom5))
        print(f" Mean: {pca.mean():.1%} Std: {pca.std():.1%}")

    # Periodic profiling (model/optimizer state is restored by profile_full_step)
    if epoch == 1 or epoch % 50 == 0:
        sb = next(iter(train_loader))
        fwd_p = profile_forward(model, sb[0], device)
        step_p = profile_full_step(model, sb[0], sb[1], criterion, opt, cross_attn_params, device)
        print_profile(fwd_p, f"FORWARD ep{epoch}")
        print_profile(step_p, f"STEP ep{epoch}")
        # The last cell is fed post-PE tokens rather than its real input; this
        # is assumed to be shape-compatible and is only used for timing.
        with torch.no_grad():
            toks = model.pos_enc(model.patch_embed(sb[0].to(device)))
        cell_p = profile_cell_internals(model.cells[-1], toks, device)
        print_profile(cell_p, f"CELL INTERNALS ep{epoch}")

elapsed = time.time() - t0
print(f"\n{'═' * 70}")
print(f" COMPLETE — SpectralViT")
print(f"{'═' * 70}")
print(f" Best val acc: {best_acc:.1%}")
print(f" Final EMA CV: {ema_cv:.4f} (target: {target_cv:.4f})")
print(f" Time: {elapsed:.0f}s ({elapsed/EPOCHS:.1f}s/epoch)")
n_params = sum(p.numel() for p in model.parameters())
print(f" Total params: {n_params:,}")
print(f" Architecture: PatchEmbed(4×4) → CayleyPE → {model.depth}× SpectralCell → pool → classify")
print(f" No conv. No external attention. Pure geometric formatting.")