| """ |
| SpectralViT CIFAR-100 Trainer |
| ============================== |
| Pure SpectralCell transformer. No conv backbone. |
| Cayley hypersphere positional encoding. |
| |
| Same proven training recipe: |
| - Soft hand: σ=0.15 boost=1.0 cv_penalty=0.01 |
| - CV from cm_vol2 in compute graph + EMA drift tracking |
| - CutMix α=1.0 prob=0.5 |
| - Grad clip 0.5 cross-attn only |
| - Adam (no weight decay) |
| |
| SpectralCell, SpectralViT, and cv_of are in namespace from prior cells. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import time |
| from tqdm import tqdm |
| import torchvision |
| import torchvision.transforms as T |
| import numpy as np |
|
|
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
|
|
| |
|
|
def cutmix_data(x, y, alpha=1.0):
    """Apply CutMix to a batch by pasting one random box from shuffled samples.

    A mixing ratio is drawn from Beta(alpha, alpha) and converted into a box
    covering roughly ``1 - lam`` of the image area. The same box is then copied
    from a random permutation of the batch into a clone of ``x``; ``x`` itself
    is not modified.

    Args:
        x: 4-D image batch whose last two dimensions are height and width.
        y: Labels indexed along dim 0 in the same order as ``x``.
        alpha: Beta concentration parameter. Values <= 0 disable mixing: the
            box is empty and ``lam`` is 1.0.

    Returns:
        Tuple ``(x_mixed, y_shuffled, lam)``. ``y_shuffled`` holds the labels
        of the samples the box was taken from. ``lam`` is a built-in ``float``
        giving the fraction of each image that still belongs to ``y``,
        recomputed from the box actually pasted after clipping.
    """
    B, _, H, W = x.shape

    # np.random.beta raises for non-positive parameters; treat that as "no mix".
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    idx = torch.randperm(B, device=x.device)

    cut_rat = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(H * cut_rat), int(W * cut_rat)
    cy, cx = np.random.randint(H), np.random.randint(W)

    # Clip the box to the image and convert to built-in ints so that neither
    # the slice bounds nor the returned ratio are NumPy scalar types.
    y1 = int(np.clip(cy - cut_h // 2, 0, H))
    y2 = int(np.clip(cy + cut_h // 2, 0, H))
    x1 = int(np.clip(cx - cut_w // 2, 0, W))
    x2 = int(np.clip(cx + cut_w // 2, 0, W))

    x_mixed = x.clone()
    x_mixed[:, :, y1:y2, x1:x2] = x[idx, :, y1:y2, x1:x2]

    # Clipping can shrink the box, so derive lam from the pasted area rather
    # than trusting the sampled value.
    lam = 1.0 - ((y2 - y1) * (x2 - x1) / (H * W))
    return x_mixed, y[idx], lam
|
|
def cutmix_criterion(criterion, logits, y_a, y_b, lam):
    """Blend the loss against both CutMix label sets.

    ``criterion`` is evaluated once per label set; the two results are combined
    with weight ``lam`` for ``y_a`` and ``1 - lam`` for ``y_b``.
    """
    loss_a = criterion(logits, y_a)
    loss_b = criterion(logits, y_b)
    return lam * loss_a + (1.0 - lam) * loss_b
|
|
|
|
| |
|
|
@torch.no_grad()
def compute_target_cv(V, D, n_trials=20, n_samples=200, device='cuda'):
    """Average ``cv_of`` over random point sets with unit-norm rows.

    Each trial draws a (V, D) Gaussian matrix on ``device``, L2-normalises it
    along the last dimension, and scores it with ``cv_of`` (defined outside
    this file). Scores that are not strictly positive are discarded.

    Returns:
        The mean of the retained scores, or 0.0 when no trial produced a
        positive score.
    """
    scores = (
        cv_of(
            F.normalize(torch.randn(V, D, device=device), dim=-1),
            n_samples=n_samples,
        )
        for _ in range(n_trials)
    )
    kept = [score for score in scores if score > 0]
    if not kept:
        return 0.0
    return sum(kept) / len(kept)
|
|
def batch_cv_from_cm(cm_vol2):
    """Coefficient of variation of the volumes encoded in ``cm_vol2``.

    Entries whose magnitude is at most 1e-16 are ignored. The remaining
    entries are converted to volumes via ``sqrt(|v|)`` and the ratio
    ``std / (mean + 1e-8)`` is returned. The result is built from tensor ops
    only, so it stays attached to ``cm_vol2``'s autograd graph.

    Returns:
        ``(cv, True)`` when at least 10 entries survive the filter, otherwise
        ``(zero_tensor, False)`` with the zero tensor on ``cm_vol2``'s device.
    """
    magnitudes = torch.abs(cm_vol2)
    keep = magnitudes > 1e-16
    if keep.sum() < 10:
        return torch.tensor(0.0, device=cm_vol2.device), False

    volumes = torch.sqrt(magnitudes[keep])
    spread = torch.std(volumes)
    centre = torch.mean(volumes) + 1e-8
    return spread / centre, True
|
|
|
|
| |
|
|
N_CLASSES = 100

# Per-channel normalisation constants shared by the train and test pipelines.
mean = (0.5, 0.5, 0.5)
std = (0.2, 0.2, 0.2)

# Training pipeline: crop / flip / colour jitter / rotation are applied before
# ToTensor(); normalisation and random erasing are applied after it.
train_transform = T.Compose([
    T.RandomCrop(size=32, padding=4),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.RandomRotation(degrees=15),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
    T.RandomErasing(p=0.25, scale=(0.02, 0.2)),
])

# Evaluation pipeline: tensor conversion and normalisation only.
test_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

# Both splits are stored under ./data and downloaded when missing.
cifar_train = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=train_transform)
cifar_test = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=True, transform=test_transform)

# Settings common to both loaders; only shuffling and drop_last differ.
_loader_kwargs = dict(batch_size=256, num_workers=4, pin_memory=True)
train_loader = torch.utils.data.DataLoader(
    cifar_train, shuffle=True, drop_last=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(
    cifar_test, shuffle=False, **_loader_kwargs)
|
|
|
|
| |
|
|
# Optimisation schedule.
EPOCHS = 200           # number of passes over the training set
LR = 1e-3              # base learning rate handed to the optimizer
LABEL_SMOOTH = 0.1     # label_smoothing for the cross-entropy criterion

# CutMix augmentation.
CUTMIX_ALPHA = 1.0     # Beta(alpha, alpha) parameter; a value <= 0 disables CutMix
CUTMIX_PROB = 0.5      # per-batch probability of applying CutMix

# "Soft hand" CV regularisation used in the training loop.
SIGMA = 0.15           # width of the Gaussian proximity of EMA CV to target CV
BOOST = 1.0            # extra weight on the CE loss when EMA CV is on target
CV_PENALTY_W = 0.01    # maximum weight of the squared CV deviation penalty
EMA_MOMENTUM = 0.99    # momentum of the running average of the batch CV
|
|
|
|
| |
|
|
# Per-cell matrix dimensions. They are named once here so that the model
# configuration, the log line, and the target-CV estimate below always agree.
CELL_V = 16
CELL_D = 16

model = SpectralViT(
    img_size=32,
    patch_size=4,
    in_channels=3,
    embed_dim=256,
    depth=6,
    cell_V=CELL_V,
    cell_D=CELL_D,
    cell_hidden=256,
    cell_depth=2,
    n_cross=2,
    n_heads=4,
    n_classes=N_CLASSES,
    dropout=0.1,
).to(device)

# Parameters that receive gradient-norm clipping during training.
cross_attn_params = model.get_cross_attn_params()

# Reference CV: the average over random unit-norm (CELL_V, CELL_D) point sets.
print(f"Computing target CV for V={CELL_V}, D={CELL_D} on S^{CELL_D - 1}...")
target_cv = compute_target_cv(CELL_V, CELL_D, n_trials=20, device=device)
|
|
# Header, model summary, then one line per training-recipe setting.
banner_rule = '═' * 70
print("\n".join([
    f"\n{banner_rule}",
    " SpectralViT — Pure SpectralCell Transformer",
    banner_rule,
]))
model.summary()
print("\n".join([
    f" Soft hand: target_cv={target_cv:.4f} σ={SIGMA} boost={BOOST}",
    f" CV penalty: {CV_PENALTY_W} (differentiable through cm_vol2)",
    f" EMA momentum: {EMA_MOMENTUM}",
    " Grad clip: 0.5 cross-attn only, uncapped otherwise",
    f" CutMix: α={CUTMIX_ALPHA} prob={CUTMIX_PROB}",
    f" Optimizer: Adam lr={LR}",
    f" CIFAR-100, {EPOCHS} epochs",
]))
|
|
# Cross-entropy with label smoothing; used for both plain and CutMix batches.
criterion = torch.nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)

# Adam over every model parameter; no weight_decay argument is passed.
opt = torch.optim.Adam(params=model.parameters(), lr=LR)
|
|
def cosine_lr(epoch):
    """Return the learning-rate multiplier for ``epoch`` on a cosine curve.

    The multiplier is 1.0 at epoch 0 and decays to ``1e-5 / LR`` (an absolute
    learning rate of 1e-5) at epoch ``EPOCHS - 1``.
    """
    progress = epoch / float(max(1, EPOCHS - 1))
    floor_ratio = 1e-5 / LR
    cosine_term = 1.0 + math.cos(math.pi * progress)
    return floor_ratio + 0.5 * (1.0 - floor_ratio) * cosine_term


# LambdaLR multiplies the optimizer's base learning rate by cosine_lr(epoch).
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=cosine_lr)
|
|
|
|
| |
|
|
@torch.no_grad()
def profile_cell_internals(cell, tokens, device):
    """Profile each stage inside a single SpectralCell.format() call.

    The stages are re-executed by hand on the cell's sub-modules, each one
    bracketed by a device synchronisation so the wall-clock time covers the
    work actually performed. NOTE(review): this is assumed to mirror the real
    ``format()`` implementation, which is not visible in this file.

    Args:
        cell: Object exposing ``enc_in``, ``enc_blocks``, ``enc_out``, ``V``,
            ``D``, ``cm_enabled``, ``cm``, ``_cm_k``, ``_triu_i``, ``_triu_j``,
            ``patchwork``, ``cross_attn``, ``out_in``, ``out_blocks`` and
            ``out_proj``.
        tokens: 3-D tensor; its first two dimensions are flattened together
            before the encoder runs.
        device: Device the tensors live on (string or ``torch.device``).

    Returns:
        Dict mapping stage name to elapsed milliseconds, in execution order.
    """
    on_cuda = torch.device(device).type == 'cuda'

    def tick():
        # torch.cuda.synchronize() raises when CUDA is unavailable, so only
        # wait for queued kernels when the profiled work runs on a GPU.
        if on_cuda:
            torch.cuda.synchronize()
        return time.perf_counter()

    def tock(start):
        if on_cuda:
            torch.cuda.synchronize()
        return (time.perf_counter() - start) * 1000

    B, N, _ = tokens.shape
    timings = {}

    # Encoder MLP: residual blocks, then reshape to one (V, D) matrix per token.
    flat = tokens.reshape(B * N, -1)
    t = tick()
    h = F.gelu(cell.enc_in(flat))
    for block in cell.enc_blocks:
        h = h + block(h)
    M = cell.enc_out(h).reshape(B * N, cell.V, cell.D)
    timings['enc_mlp'] = tock(t)

    # Row magnitudes are kept as features; rows are then normalised to unit length.
    t = tick()
    row_mag = M.norm(dim=-1)
    M = F.normalize(M, dim=-1)
    timings['normalize'] = tock(t)

    # Optional `cell.cm` pass on evenly spaced rows; outputs are discarded
    # because only the cost is of interest here.
    if cell.cm_enabled and cell.cm is not None:
        nv = cell._cm_k + 1
        cm_idx = torch.linspace(0, cell.V - 1, nv).long().to(device)
        t = tick()
        cm_verts = M[:, cm_idx, :]
        cell.cm(cm_verts)
        timings['cm_validation'] = tock(t)
    else:
        timings['cm_validation'] = 0.0

    # Squared pairwise distances between unit rows via the Gram matrix
    # (d^2 = 2 - 2 * <a, b>), gathered at the precomputed index pairs.
    t = tick()
    gram = torch.bmm(M, M.transpose(1, 2))
    d2_full = 2.0 - 2.0 * gram
    d2_pairs = d2_full[:, cell._triu_i, cell._triu_j]
    timings['pairwise_d2'] = tock(t)

    t = tick()
    pw_features = cell.patchwork(d2_pairs)
    timings['patchwork'] = tock(t)

    # batched_svd is defined outside this file.
    t = tick()
    U, S, Vt = batched_svd(M)
    timings['svd_eigh'] = tock(t)

    # The cross-attention layers operate on S reshaped to (B, N, D).
    S = S.reshape(B, N, cell.D)
    t = tick()
    for layer in cell.cross_attn:
        S = layer(S)
    timings['cross_attn'] = tock(t)

    # Recompose M_hat = U @ diag(S) @ Vt with the updated S.
    U_flat = U.reshape(B * N, cell.V, cell.D)
    S_flat = S.reshape(B * N, cell.D)
    Vt_flat = Vt.reshape(B * N, cell.D, cell.D)
    t = tick()
    M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)
    timings['recompose'] = tock(t)

    # Decoder MLP over the concatenated features; the output is discarded.
    out_features = torch.cat([M_hat.reshape(B * N, -1), pw_features, row_mag], dim=-1)
    t = tick()
    h = F.gelu(cell.out_in(out_features))
    for block in cell.out_blocks:
        h = h + block(h)
    cell.out_proj(h)
    timings['dec_mlp'] = tock(t)

    return timings
|
|
|
|
@torch.no_grad()
def profile_forward(model, images, device):
    """Profile each component of SpectralViT forward pass.

    The forward pass is replayed stage by stage in eval mode. The model's
    previous train/eval mode is restored before returning, so profiling does
    not leak a mode change into the caller.

    Args:
        model: Object exposing ``patch_embed``, ``pos_enc``, ``depth``,
            ``norms``, ``cells`` (the last of which has ``format`` returning a
            dict with an ``'output'`` entry), ``final_norm`` and
            ``classifier``.
        images: Input batch; moved to ``device`` before timing starts.
        device: Target device (string or ``torch.device``).

    Returns:
        Dict of component name → elapsed milliseconds, in execution order.
    """
    on_cuda = torch.device(device).type == 'cuda'

    def tick():
        # torch.cuda.synchronize() raises when CUDA is unavailable, so only
        # wait for queued kernels when the profiled work runs on a GPU.
        if on_cuda:
            torch.cuda.synchronize()
        return time.perf_counter()

    def tock(start):
        if on_cuda:
            torch.cuda.synchronize()
        return (time.perf_counter() - start) * 1000

    was_training = model.training
    model.eval()
    try:
        images = images.to(device)
        timings = {}

        t = tick()
        tokens = model.patch_embed(images)
        timings['patch_embed'] = tock(t)

        t = tick()
        tokens = model.pos_enc(tokens)
        timings['cayley_pe'] = tock(t)

        # Every cell except the last: pre-norm, cell call, residual add.
        for i in range(model.depth - 1):
            t = tick()
            normed = model.norms[i](tokens)
            tokens = tokens + model.cells[i](normed)
            timings[f'cell_{i}'] = tock(t)

        # The last cell is driven through format(), which returns a dict.
        t = tick()
        normed = model.norms[-1](tokens)
        last_out = model.cells[-1].format(normed)
        tokens = tokens + last_out['output']
        timings[f'cell_{model.depth-1}_format'] = tock(t)

        # Final norm, mean-pool over dim 1, classifier head (logits discarded).
        t = tick()
        tokens = model.final_norm(tokens)
        pooled = tokens.mean(dim=1)
        model.classifier(pooled)
        timings['classify'] = tock(t)
    finally:
        model.train(was_training)

    return timings
|
|
|
|
def profile_full_step(model, images, labels, criterion, opt, cross_attn_params, device):
    """Profile full train step: forward + loss + backward + optimizer.

    This is a real update, not a dry run: the model is switched to train
    mode, gradients are computed, and ``opt.step()`` changes the weights and
    optimizer state. The loss is the plain ``criterion`` on ``out['logits']``.

    Args:
        model: Callable returning a dict with a ``'logits'`` entry.
        images: Input batch; moved to ``device``.
        labels: Targets for ``criterion``; moved to ``device``.
        criterion: Loss callable taking ``(logits, labels)``.
        opt: Optimizer over the model's parameters.
        cross_attn_params: Parameters whose gradient norm is clipped to 0.5.
        device: Target device (string or ``torch.device``).

    Returns:
        Dict of phase name → elapsed milliseconds, in execution order.
    """
    on_cuda = torch.device(device).type == 'cuda'

    def tick():
        # torch.cuda.synchronize() raises when CUDA is unavailable, so only
        # wait for queued kernels when the profiled work runs on a GPU.
        if on_cuda:
            torch.cuda.synchronize()
        return time.perf_counter()

    def tock(start):
        if on_cuda:
            torch.cuda.synchronize()
        return (time.perf_counter() - start) * 1000

    model.train()
    images, labels = images.to(device), labels.to(device)
    timings = {}

    t = tick()
    out = model(images)
    timings['forward'] = tock(t)

    t = tick()
    ce_loss = criterion(out['logits'], labels)
    timings['loss'] = tock(t)

    # Clearing stale gradients is counted as part of the backward phase.
    t = tick()
    opt.zero_grad(set_to_none=True)
    ce_loss.backward()
    timings['backward'] = tock(t)

    t = tick()
    nn.utils.clip_grad_norm_(cross_attn_params, max_norm=0.5)
    timings['grad_clip'] = tock(t)

    t = tick()
    opt.step()
    timings['optim_step'] = tock(t)

    return timings
|
|
|
|
def print_profile(timings, label=""):
    """Print a boxed table of timings, slowest entry first.

    Each row shows the entry name, its time in milliseconds, its share of the
    total, and a bar of one block character per 3 percentage points.

    Args:
        timings: Mapping of name → milliseconds.
        label: Text appended to the ``PROFILE`` header.
    """
    total = sum(timings.values())
    print(f"\n ┌─ PROFILE {label} {'─' * max(0, 48 - len(label))}┐")
    for name, ms in sorted(timings.items(), key=lambda x: -x[1]):
        # A zero total (all stages measured as 0 ms) would otherwise divide
        # by zero; report every share as 0% in that case.
        pct = ms / total * 100 if total > 0 else 0.0
        bar = '█' * int(pct / 3)
        print(f" │ {name:<22s} {ms:7.1f}ms {pct:5.1f}% {bar}")
    print(f" │ {'TOTAL':<22s} {total:7.1f}ms")
    print(f" └{'─' * 55}┘")
|
|
|
|
| |
|
|
| |
# One-off profile before training: warm up, then time the forward pass, a full
# train step, and the internals of the first cell on a single training batch.
print("\n Initial profiling (3 warmup + 1 measured)...")
sample_batch = next(iter(train_loader))
for _ in range(3):
    _ = model(sample_batch[0].to(device))
# `device` falls back to the CPU when CUDA is missing, and
# torch.cuda.synchronize() raises in that case, so only sync on a GPU.
if device.type == 'cuda':
    torch.cuda.synchronize()

fwd_profile = profile_forward(model, sample_batch[0], device)
print_profile(fwd_profile, "FORWARD COMPONENTS")

# NOTE: profile_full_step performs a real optimizer update on `model`.
step_profile = profile_full_step(model, sample_batch[0], sample_batch[1],
                                 criterion, opt, cross_attn_params, device)
print_profile(step_profile, "FULL TRAIN STEP")

# The cell profile is fed patch-embedded, position-encoded tokens.
with torch.no_grad():
    tokens = model.pos_enc(model.patch_embed(sample_batch[0].to(device)))
    cell_profile = profile_cell_internals(model.cells[0], tokens, device)
print_profile(cell_profile, "CELL INTERNALS (cell_0)")
|
|
best_acc = 0
t0 = time.time()
# The running CV estimate starts on target, so the proximity term begins at 1.
ema_cv = target_cv

for epoch in range(1, EPOCHS + 1):
    # ── Train ──
    model.train()
    correct, total = 0, 0
    ep_boost_sum, ep_n = 0, 0
    ep_cv_penalty_sum = 0  # accumulated for inspection; not printed below

    for images, labels in tqdm(train_loader, desc=f"Ep {epoch:3d}", leave=False):
        images, labels = images.to(device), labels.to(device)

        use_cutmix = CUTMIX_ALPHA > 0 and np.random.random() < CUTMIX_PROB
        if use_cutmix:
            images, labels_b, lam = cutmix_data(images, labels, CUTMIX_ALPHA)
        else:
            labels_b, lam = labels, 1.0

        out = model(images)

        if use_cutmix:
            ce_loss = cutmix_criterion(criterion, out['logits'], labels, labels_b, lam)
        else:
            ce_loss = criterion(out['logits'], labels)

        # batch_cv is built from the last cell's cm_vol2 with tensor ops, so
        # the penalty below back-propagates into the model.
        last_cell = out['last_cell']
        batch_cv, cv_valid = batch_cv_from_cm(last_cell['cm_vol2'])

        # .item() detaches the value, so the EMA itself carries no gradient.
        if cv_valid:
            ema_cv = EMA_MOMENTUM * ema_cv + (1.0 - EMA_MOMENTUM) * batch_cv.item()

        # Gaussian proximity of the EMA CV to the target: near the target the
        # CE loss is boosted and the CV penalty fades; far away the reverse.
        prox = math.exp(-(ema_cv - target_cv) ** 2 / (2 * SIGMA ** 2))
        primary_w = 1.0 + BOOST * prox
        cv_w = CV_PENALTY_W * (1.0 - prox)
        if cv_valid:
            diff_cv_loss = (batch_cv - target_cv).pow(2)
        else:
            diff_cv_loss = torch.tensor(0.0, device=device)

        loss = primary_w * ce_loss + cv_w * diff_cv_loss

        opt.zero_grad(set_to_none=True)
        loss.backward()
        # Only the cross-attention parameters are clipped.
        nn.utils.clip_grad_norm_(cross_attn_params, max_norm=0.5)
        opt.step()

        # Train accuracy is scored against the un-shuffled labels, including
        # on CutMix batches.
        correct += (out['logits'].argmax(-1) == labels).sum().item()
        total += images.shape[0]
        ep_boost_sum += primary_w
        ep_cv_penalty_sum += cv_w * diff_cv_loss.item() if cv_valid else 0
        ep_n += 1

    # Read the learning rate before stepping the scheduler so the log shows
    # the rate that was used during this epoch, not the next one.
    lr = opt.param_groups[0]['lr']
    sched.step()
    train_acc = correct / total
    avg_boost = ep_boost_sum / ep_n

    # ── Validate ──
    # Per-class hit/total counts are gathered in the same pass with bincount,
    # so the per-class report below needs no second sweep over the test set.
    model.eval()
    val_correct, val_total = 0, 0
    pcc = torch.zeros(N_CLASSES)
    pct = torch.zeros(N_CLASSES)
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            out = model(images)
            hits = out['logits'].argmax(-1) == labels
            val_correct += hits.sum().item()
            val_total += images.shape[0]
            labels_cpu = labels.cpu()
            pct += torch.bincount(labels_cpu, minlength=N_CLASSES)[:N_CLASSES].float()
            pcc += torch.bincount(labels_cpu[hits.cpu()], minlength=N_CLASSES)[:N_CLASSES].float()

    val_acc = val_correct / val_total
    star = ' ★' if val_acc > best_acc else ''
    if val_acc > best_acc:
        best_acc = val_acc

    # ── Periodic log: accuracy, CV tracking, S values, PE angles ──
    if epoch <= 5 or epoch % 5 == 0 or epoch == EPOCHS:
        with torch.no_grad():
            # `out` is the output for the last validation batch.
            last_cell = out['last_cell']
            S = last_cell['S_orig'].mean(dim=(0, 1))
            s_str = ', '.join(f'{v:.2f}' for v in S.tolist()[:4]) + f'...{S[-1]:.2f}'

            angles = model.pos_enc.angles
            angle_mean = angles.abs().mean().item()
            angle_max = angles.abs().max().item()

        print(f" ep{epoch:3d} acc={val_acc:.1%}{star} train={train_acc:.1%} "
              f"ema_cv={ema_cv:.4f} boost={avg_boost:.3f} lr={lr:.6f}")
        print(f" S=[{s_str}] PE angles: mean={angle_mean:.4f} max={angle_max:.4f}")

    # ── Periodic per-class accuracy report ──
    if epoch <= 3 or epoch % 20 == 0 or epoch == EPOCHS:
        pca = pcc / (pct + 1e-8)
        sorted_idx = pca.argsort(descending=True)
        print(f" Top 5: ", end='')
        for i in range(5):
            c = sorted_idx[i].item()
            print(f"c{c}={pca[c]:.0%} ", end='')
        print(f"\n Bot 5: ", end='')
        for i in range(5):
            c = sorted_idx[-(i+1)].item()
            print(f"c{c}={pca[c]:.0%} ", end='')
        print(f"\n Mean: {pca.mean():.1%} Std: {pca.std():.1%}")

    # ── Periodic profiling ──
    # NOTE: profile_full_step performs a real optimizer update on `model`.
    if epoch == 1 or epoch % 50 == 0:
        sb = next(iter(train_loader))
        fwd_p = profile_forward(model, sb[0], device)
        step_p = profile_full_step(model, sb[0], sb[1], criterion, opt, cross_attn_params, device)
        print_profile(fwd_p, f"FORWARD ep{epoch}")
        print_profile(step_p, f"STEP ep{epoch}")
        with torch.no_grad():
            toks = model.pos_enc(model.patch_embed(sb[0].to(device)))
            cell_p = profile_cell_internals(model.cells[-1], toks, device)
        print_profile(cell_p, f"CELL INTERNALS ep{epoch}")
|
|
# Wall-clock time since t0 was taken, just before the training loop.
elapsed = time.time() - t0

# Total element count over all model parameters.
n_params = sum(param.numel() for param in model.parameters())

closing_rule = '═' * 70
print("\n".join([
    f"\n{closing_rule}",
    " COMPLETE — SpectralViT",
    closing_rule,
    f" Best val acc: {best_acc:.1%}",
    f" Final EMA CV: {ema_cv:.4f} (target: {target_cv:.4f})",
    f" Time: {elapsed:.0f}s ({elapsed/EPOCHS:.1f}s/epoch)",
    f" Total params: {n_params:,}",
    " Architecture: PatchEmbed(4×4) → CayleyPE → 6× SpectralCell → pool → classify",
    " No conv. No external attention. Pure geometric formatting.",
]))