# svd-triton / svd_deit_small_cifar100_train.py (AbstractPhil)
# @title Experiment 8.23: DeiT-Small + SVD Taps
#
# Proper transformer: 12 layers, 384-d, 6 heads, patch 4×4
# SVD taps at layers 3, 6, 9, 12 (every 3 layers)
# Same detached SVD observation as 8.21, which already proved to work there
#
# DeiT-Small on CIFAR-100 typically hits ~75-78% from scratch.
# Question: does SVD structural observation push it further?
"""
The standard DeiT baseline (no SVD features) trained for 100 epochs was beaten:
the SVD-feature variant overtook its best validation accuracy by epoch 52.
The wall-clock cost is not negligible (about 22 s/epoch vs 5.4 s/epoch, roughly 4×),
but it should drop considerably as the SVD pipeline is optimized.
[DATA] CIFAR-100: 50000 train, 10000 val, bs=256
======================================================================
PHASE 1: DeiT-Small BASELINE (no SVD)
======================================================================
[MODEL] DeiT-Small Baseline: 21,526,372 params
======================================================================
[EXP] deit_small_baseline | 21,526,372 params | 100 epochs
======================================================================
E 1 | Tr 1.0% Va 1.5% | L=4.636 gap=-0.5 | Best 1.5%@E1 | 5.8s
E 2 | Tr 1.8% Va 3.1% | L=4.526 gap=-1.3 | Best 3.1%@E2 | 5.4s
E 3 | Tr 2.9% Va 3.9% | L=4.389 gap=-1.0 | Best 3.9%@E3 | 5.4s
E 4 | Tr 3.6% Va 4.9% | L=4.335 gap=-1.3 | Best 4.9%@E4 | 5.4s
E 5 | Tr 4.1% Va 5.3% | L=4.282 gap=-1.2 | Best 5.3%@E5 | 5.4s
E 10 | Tr 10.5% Va 13.7% | L=3.861 gap=-3.2 | Best 13.7%@E10 | 5.4s
E 15 | Tr 15.2% Va 20.5% | L=3.585 gap=-5.2 | Best 20.5%@E15 | 5.4s
E 20 | Tr 18.8% Va 25.2% | L=3.375 gap=-6.4 | Best 25.2%@E20 | 5.4s
E 25 | Tr 21.9% Va 26.9% | L=3.218 gap=-5.1 | Best 26.9%@E25 | 5.4s
E 30 | Tr 24.4% Va 30.1% | L=3.075 gap=-5.7 | Best 30.1%@E30 | 5.4s
E 35 | Tr 27.2% Va 32.8% | L=2.941 gap=-5.6 | Best 32.8%@E35 | 5.4s
E 40 | Tr 29.5% Va 34.8% | L=2.808 gap=-5.2 | Best 34.8%@E40 | 5.4s
E 45 | Tr 31.9% Va 36.1% | L=2.685 gap=-4.2 | Best 36.1%@E45 | 5.4s
E 50 | Tr 35.1% Va 37.5% | L=2.547 gap=-2.4 | Best 37.5%@E50 | 5.4s
E 55 | Tr 37.6% Va 38.8% | L=2.428 gap=-1.2 | Best 38.8%@E55 | 5.4s
E 60 | Tr 40.1% Va 40.0% | L=2.313 gap=+0.1 | Best 40.0%@E60 | 5.4s
E 65 | Tr 42.5% Va 41.1% | L=2.202 gap=+1.4 | Best 41.1%@E65 | 5.4s
E 70 | Tr 44.8% Va 41.2% | L=2.104 gap=+3.6 | Best 41.2%@E70 | 5.4s
E 75 | Tr 46.9% Va 41.5% | L=2.023 gap=+5.3 | Best 41.6%@E72 | 5.4s
E 80 | Tr 48.5% Va 41.7% | L=1.955 gap=+6.8 | Best 41.9%@E79 | 5.4s
E 85 | Tr 49.6% Va 42.3% | L=1.914 gap=+7.2 | Best 42.5%@E83 | 5.4s
E 90 | Tr 50.4% Va 42.3% | L=1.877 gap=+8.1 | Best 42.7%@E87 | 5.4s
E 95 | Tr 50.6% Va 42.5% | L=1.860 gap=+8.1 | Best 42.7%@E87 | 5.4s
E100 | Tr 50.7% Va 42.5% | L=1.855 gap=+8.2 | Best 42.7%@E87 | 5.4s
[RESULT] deit_small_baseline: Best Val = 42.74% @E87 | Params: 21,526,372
======================================================================
PHASE 2: DeiT-Small + SVD TAPS
======================================================================
[MODEL] DeiT-Small + SVD: 21,676,900 params
SVD taps at layers: (3, 6, 9, 12)
SVD features: 264 = 4×66
Classifier input: 384 + 264 = 648
======================================================================
[EXP] deit_small_svd | 21,676,900 params | 100 epochs
======================================================================
E 1 | Tr 2.3% Va 3.7% | L=4.471 gap=-1.3 | Best 3.7%@E1 | 21.4s
E 2 | Tr 4.4% Va 7.9% | L=4.261 gap=-3.5 | Best 7.9%@E2 | 22.7s
E 3 | Tr 6.4% Va 9.0% | L=4.108 gap=-2.6 | Best 9.0%@E3 | 21.1s
E 4 | Tr 8.4% Va 12.0% | L=3.972 gap=-3.6 | Best 12.0%@E4 | 21.1s
E 5 | Tr 10.1% Va 14.6% | L=3.878 gap=-4.6 | Best 14.6%@E5 | 21.1s
E 10 | Tr 15.9% Va 20.7% | L=3.533 gap=-4.8 | Best 20.7%@E10 | 21.5s
E 15 | Tr 19.7% Va 24.7% | L=3.321 gap=-5.0 | Best 24.7%@E15 | 21.4s
E 20 | Tr 22.8% Va 28.5% | L=3.170 gap=-5.7 | Best 28.5%@E20 | 22.1s
E 25 | Tr 25.3% Va 30.5% | L=3.023 gap=-5.2 | Best 30.5%@E25 | 21.8s
E 30 | Tr 28.0% Va 33.2% | L=2.882 gap=-5.2 | Best 33.2%@E30 | 22.0s
E 35 | Tr 31.0% Va 35.2% | L=2.741 gap=-4.2 | Best 35.2%@E35 | 22.0s
E 40 | Tr 34.0% Va 37.6% | L=2.588 gap=-3.6 | Best 37.6%@E40 | 21.7s
E 45 | Tr 37.5% Va 39.1% | L=2.432 gap=-1.6 | Best 39.4%@E44 | 22.0s
E 50 | Tr 40.1% Va 41.0% | L=2.301 gap=-0.9 | Best 41.2%@E49 | 21.9s
E 55 | Tr 43.4% Va 43.5% | L=2.138 gap=-0.1 | Best 43.5%@E55 | 22.2s
E 60 | Tr 46.9% Va 44.1% | L=2.007 gap=+2.8 | Best 44.1%@E60 | 22.2s
E 65 | Tr 50.0% Va 44.7% | L=1.869 gap=+5.3 | Best 45.0%@E64 | 22.2s
E 70 | Tr 52.5% Va 45.8% | L=1.768 gap=+6.7 | Best 45.9%@E69 | 23.1s
E 75 | Tr 54.7% Va 46.3% | L=1.669 gap=+8.4 | Best 46.3%@E75 | 21.8s
E 80 | Tr 57.3% Va 45.8% | L=1.574 gap=+11.5 | Best 46.3%@E79 | 22.4s
E 85 | Tr 58.4% Va 46.7% | L=1.532 gap=+11.7 | Best 46.8%@E84 | 22.4s
E 90 | Tr 59.5% Va 46.8% | L=1.490 gap=+12.7 | Best 46.9%@E87 | 22.4s
E 95 | Tr 60.0% Va 46.9% | L=1.475 gap=+13.2 | Best 46.9%@E87 | 22.2s
E100 | Tr 60.6% Va 46.8% | L=1.452 gap=+13.8 | Best 46.9%@E87 | 21.9s
[RESULT] deit_small_svd: Best Val = 46.90% @E87 | Params: 21,676,900
======================================================================
HEAD-TO-HEAD COMPARISON
======================================================================
Model Val% Params
-------------------------------------------------------
DeiT-Small baseline 42.74% 21,526,372
DeiT-Small + SVD 46.90% 21,676,900
SVD contribution: +4.16 points
================================================================================
SCOREBOARD
================================================================================
Experiment Val% Params Epoch
--------------------------------------------- ------- ---------- ------
svd_classification_test 70.92% 3,878,820 93
vit_svd_classification_test 53.57% 6,705,828 86
deit_small_baseline 42.74% 21,526,372 87
================================================================================
================================================================================
SCOREBOARD
================================================================================
Experiment Val% Params Epoch
--------------------------------------------- ------- ---------- ------
svd_classification_test 70.92% 3,878,820 93
vit_svd_classification_test 53.57% 6,705,828 86
deit_small_svd 46.90% 21,676,900 87
deit_small_baseline 42.74% 21,526,372 87
================================================================================
"""
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# NOTE: gram_eigh_svd is not defined in this file; it must be provided by an
# earlier cell (presumably the Experiment 8.21 / svd-triton setup).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DeiTSmallSVD(nn.Module):
    """DeiT-Small with SVD observation taps.

    Architecture:
        Patch embed (4×4) → 64 tokens + CLS → 384-d
        12 transformer layers (384-d, 6 heads, MLP ratio 4)
        SVD tap after layers 3, 6, 9, 12
        CLS token + SVD features → classify

    SVD observes token-space structure at 4 depths.
    Detached: no gradient flows through eigh, so the transformer learns normally.
    SVD provides complementary structural features to the classifier.
    """

    def __init__(self, num_classes=100, img_size=32, patch_size=4,
                 embed_dim=384, depth=12, n_heads=6, mlp_ratio=4.0,
                 dropout=0.1, svd_rank=32, tap_layers=(3, 6, 9, 12)):
        super().__init__()
        self.embed_dim = embed_dim
        self.svd_rank = svd_rank
        self.tap_layers = tap_layers
        self.n_taps = len(tap_layers)
        k = svd_rank

        # ── Patch embedding ──
        self.n_patches = (img_size // patch_size) ** 2  # 64
        self.patch_embed = nn.Sequential(
            nn.Conv2d(3, embed_dim, patch_size, stride=patch_size),
            nn.Flatten(2),  # (B, embed_dim, n_patches)
        )
        # patch_norm is never applied in forward(); it only adds 2 * embed_dim
        # parameters to the reported count.
        self.patch_norm = nn.LayerNorm(embed_dim)

        # CLS token + positional embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(dropout)

        # ── Transformer layers ──
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout, activation='gelu',
                batch_first=True, norm_first=True)
            for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        # ── SVD projections at tap points ──
        # These only feed the no_grad SVD in forward(), so they never receive a
        # gradient and keep their initial random weights (AdamW skips them).
        self.svd_projs = nn.ModuleList([
            nn.Linear(embed_dim, k, bias=False) for _ in range(self.n_taps)])

        # ── SVD feature extraction ──
        # Per-tap: S_norm(k) + Vh_diag(k) + offdiag(1) + entropy(1) = 2k+2
        svd_feat_dim = 2 * k + 2
        total_svd_feat = svd_feat_dim * self.n_taps

        # ── Classifier: CLS token (384-d) + SVD features (264-d) ──
        total_dim = embed_dim + total_svd_feat
        self.head = nn.Sequential(
            nn.Linear(total_dim, embed_dim), nn.GELU(),
            nn.LayerNorm(embed_dim), nn.Dropout(0.1),
            nn.Linear(embed_dim, num_classes))

        self.n_params = sum(p.numel() for p in self.parameters())
        self._init_weights()

    def _init_weights(self):
        # Trunc normal for linear, ones/zeros for norms, Kaiming for the patch conv
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _extract_svd_features(self, S, Vh):
        """Compact SVD summary. Same as 8.21.

        Returns (B, 2k+2): normalized spectrum (k), diagonal of Vh (k),
        off-diagonal energy of Vh (1) and spectral entropy (1).
        """
        B, k = S.shape
        S_safe = S.clamp(min=1e-6)
        s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8)  # spectrum as a distribution
        vh_diag = Vh.diagonal(dim1=-2, dim2=-1)                        # (B, k)
        vh_offdiag = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0)
        s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True)
        out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1)
        return torch.where(torch.isfinite(out), out, torch.zeros_like(out))  # zero out NaN/inf

    def forward(self, x):
        B = x.shape[0]

        # Patch embed + CLS
        patches = self.patch_embed(x).transpose(1, 2)  # (B, n_patches, embed_dim)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, patches], dim=1)  # (B, n_patches+1, embed_dim)
        tokens = self.pos_drop(tokens + self.pos_embed)

        # Transformer layers with SVD taps
        svd_feats = []
        tap_idx = 0
        for layer_idx, layer in enumerate(self.layers):
            tokens = layer(tokens)

            # SVD tap at designated layers (1-indexed: after layer 3, 6, 9, 12)
            if tap_idx < self.n_taps and (layer_idx + 1) == self.tap_layers[tap_idx]:
                # Project tokens (excluding CLS) to SVD space
                patch_tokens = tokens[:, 1:]  # (B, n_patches, embed_dim)
                h_proj = self.svd_projs[tap_idx](patch_tokens)  # (B, n_patches, k)

                # Decompose in fp32 with autograd off (the "detached" observation)
                with torch.amp.autocast('cuda', enabled=False):
                    with torch.no_grad():
                        _, S, Vh = gram_eigh_svd(h_proj.float())
                        S = S.clamp(min=1e-6)
                        S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
                        Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))
                        svd_feats.append(self._extract_svd_features(S, Vh))
                tap_idx += 1

        # Final norm + CLS token
        tokens = self.norm(tokens)
        cls_out = tokens[:, 0]  # (B, embed_dim)

        # Concatenate CLS + SVD features
        all_feats = torch.cat([cls_out] + svd_feats, dim=-1)
        return self.head(all_feats)
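

# ── Sketch: Gram-matrix eigh SVD (hypothetical reference) ───────────────────
# The SVD taps call gram_eigh_svd, which this file does not define. Judging only
# by its name, it computes a thin SVD from the eigendecomposition of the small
# k×k Gram matrix A^T A instead of decomposing the (n_patches × k) matrix
# directly. The function below is a hypothetical pure-PyTorch sketch of that
# technique, assuming a (B, n, k) input with n >= k and a (U, S, Vh) return with
# singular values in descending order. It is not the svd-triton kernel.
import torch


def gram_eigh_svd_reference(A, eps=1e-12):
    """Thin SVD of a batch A of shape (B, n, k), n >= k, via eigh of A^T A."""
    G = A.transpose(-2, -1) @ A                   # (B, k, k) symmetric PSD Gram matrix
    evals, V = torch.linalg.eigh(G)               # ascending eigenvalues, orthonormal eigenvectors
    evals, V = evals.flip(-1), V.flip(-1)         # reorder to descending, like torch.linalg.svd
    S = evals.clamp(min=0).sqrt()                 # singular values = sqrt(eigenvalues of A^T A)
    U = (A @ V) / S.clamp(min=eps).unsqueeze(-2)  # from A V = U diag(S)
    Vh = V.transpose(-2, -1)
    return U, S, Vh

# Design note: eigh only sees a k×k matrix (32×32 here), which is cheap, but
# forming A^T A squares the condition number, so the smallest singular values
# are less accurate than with a direct SVD. Running it in fp32, as the taps do,
# limits that loss.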
# ── Also build a baseline without SVD for fair comparison ────────────────────
class DeiTSmallBaseline(nn.Module):
    """Same DeiT-Small, no SVD taps. CLS → classify."""

    def __init__(self, num_classes=100, img_size=32, patch_size=4,
                 embed_dim=384, depth=12, n_heads=6, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Sequential(
            nn.Conv2d(3, embed_dim, patch_size, stride=patch_size),
            nn.Flatten(2))
        self.patch_norm = nn.LayerNorm(embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout, activation='gelu',
                batch_first=True, norm_first=True)
            for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.GELU(),
            nn.LayerNorm(embed_dim), nn.Dropout(0.1),
            nn.Linear(embed_dim, num_classes))
        self.n_params = sum(p.numel() for p in self.parameters())
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')

    def forward(self, x):
        B = x.shape[0]
        patches = self.patch_embed(x).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = self.pos_drop(torch.cat([cls, patches], dim=1) + self.pos_embed)
        for layer in self.layers:
            tokens = layer(tokens)
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)
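

# ── Sketch: what "detached" means for the tap projections ──────────────────
# In DeiTSmallSVD the SVD runs under torch.no_grad(), so the SVD features carry
# no gradient and the svd_projs layers that feed them never get a .grad either.
# This toy check uses small hypothetical stand-in modules and uses
# torch.linalg.svdvals in place of gram_eigh_svd. Since AdamW skips parameters
# whose .grad is None, in the real model those projections would stay at their
# initial random values while the transformer trains normally.
import torch
import torch.nn as nn


def tap_projection_gets_grad():
    """Return (trunk_has_grad, proj_has_grad) after one backward pass."""
    torch.manual_seed(0)
    trunk = nn.Linear(8, 8)                # stand-in for the transformer layers
    proj = nn.Linear(8, 4, bias=False)     # stand-in for one svd_projs entry
    head = nn.Linear(8 + 4, 3)             # stand-in for the classifier head
    x = torch.randn(2, 5, 8)               # (B, tokens, dim)
    tokens = trunk(x)
    h_proj = proj(tokens)                  # computed with autograd enabled...
    with torch.no_grad():                  # ...but only consumed without it
        s = torch.linalg.svdvals(h_proj)   # (B, 4) detached spectral features
    feats = torch.cat([tokens.mean(dim=1), s], dim=-1)
    head(feats).sum().backward()
    return trunk.weight.grad is not None, proj.weight.grad is not None

# Expected result: (True, False).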
# ── Training loop ────────────────────────────────────────────────────────────
def train_model(model, train_loader, val_loader, device, epochs=100, lr=3e-4, label=""):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    # fp16 autocast needs loss scaling to avoid gradient underflow. With bf16 the
    # scaler is disabled and every scaler call below is a pass-through.
    scaler = torch.amp.GradScaler('cuda', enabled=(amp_dtype == torch.float16))
    best_val = 0.0
    best_epoch = 0

    print(f"\n{'='*70}")
    print(f"[EXP] {label} | {model.n_params:,} params | {epochs} epochs")
    print(f"{'='*70}")

    for epoch in range(1, epochs + 1):
        # ── Train ──
        model.train()
        t0 = time.time()
        correct = total = 0
        loss_sum = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', dtype=amp_dtype):
                logits = model(images)
                loss = F.cross_entropy(logits, labels)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            correct += (logits.argmax(-1) == labels).sum().item()
            total += labels.size(0)
            loss_sum += loss.item()
        scheduler.step()
        train_acc = 100.0 * correct / total

        # ── Validate ──
        model.eval()
        val_correct = val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                with torch.amp.autocast('cuda', dtype=amp_dtype):
                    logits = model(images)
                val_correct += (logits.argmax(-1) == labels).sum().item()
                val_total += labels.size(0)
        val_acc = 100.0 * val_correct / val_total
        if val_acc > best_val:
            best_val = val_acc
            best_epoch = epoch

        elapsed = time.time() - t0  # includes validation time
        gap = train_acc - val_acc
        if epoch <= 5 or epoch % 5 == 0 or epoch == epochs:
            print(f" E{epoch:>3} | Tr {train_acc:5.1f}% Va {val_acc:5.1f}%"
                  f" | L={loss_sum/len(train_loader):.3f} gap={gap:+.1f}"
                  f" | Best {best_val:.1f}%@E{best_epoch} | {elapsed:.1f}s")

    print(f"\n[RESULT] {label}: Best Val = {best_val:.2f}% @E{best_epoch} | Params: {model.n_params:,}")
    return {'experiment': label, 'best_val_acc': best_val, 'best_epoch': best_epoch, 'params': model.n_params}
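

# ── Sketch: the learning-rate schedule train_model uses ─────────────────────
# train_model steps CosineAnnealingLR once per epoch with T_max=epochs and the
# default eta_min=0, so 1-indexed epoch E trains at the closed-form rate
#     lr(E) = lr_max * (1 + cos(pi * (E - 1) / T_max)) / 2
# The helpers below (hypothetical names) compare that formula with what the
# PyTorch scheduler produces for lr=3e-4 and 100 epochs. Over the last 15
# epochs the rate stays below 6% of its peak.
import math

import torch


def cosine_lr(e, base_lr=3e-4, T_max=100, eta_min=0.0):
    """Closed-form cosine annealing rate for 0-indexed scheduler step e."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * e / T_max)) / 2


def scheduler_lrs(base_lr=3e-4, T_max=100, epochs=100):
    """Per-epoch learning rates from CosineAnnealingLR, stepped as in train_model."""
    p = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([p], lr=base_lr, weight_decay=0.05)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=T_max)
    lrs = []
    for _ in range(epochs):
        lrs.append(opt.param_groups[0]["lr"])  # rate used during this epoch
        opt.step()                             # no-op here because p.grad is None
        sched.step()
    return lrs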
# ── Launch ───────────────────────────────────────────────────────────────────
tf_train = T.Compose([
    T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(),
    T.autoaugment.RandAugment(num_ops=2, magnitude=9), T.ToTensor()])
tf_val = T.Compose([T.ToTensor()])
train_ds = torchvision.datasets.CIFAR100(root="./data", train=True, download=True, transform=tf_train)
val_ds = torchvision.datasets.CIFAR100(root="./data", train=False, download=True, transform=tf_val)
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4,
                          pin_memory=True, drop_last=True, persistent_workers=True)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4,
                        pin_memory=True, persistent_workers=True)
print(f"[DATA] CIFAR-100: {len(train_ds)} train, {len(val_ds)} val, bs=256")
# ── Run baseline first ──
print("\n" + "="*70)
print(" PHASE 1: DeiT-Small BASELINE (no SVD)")
print("="*70)
model_baseline = DeiTSmallBaseline(num_classes=100)
print(f"[MODEL] DeiT-Small Baseline: {model_baseline.n_params:,} params")
result_baseline = train_model(model_baseline, train_loader, val_loader, device,
                              epochs=100, label="deit_small_baseline")
# ── Then SVD version ──
print("\n" + "="*70)
print(" PHASE 2: DeiT-Small + SVD TAPS")
print("="*70)
model_svd = DeiTSmallSVD(num_classes=100, svd_rank=32)
print(f"[MODEL] DeiT-Small + SVD: {model_svd.n_params:,} params")
print(f" SVD taps at layers: {model_svd.tap_layers}")
feat_per_tap = 2 * model_svd.svd_rank + 2  # S_norm(k) + Vh_diag(k) + offdiag + entropy
n_svd_feats = model_svd.n_taps * feat_per_tap
print(f" SVD features: {n_svd_feats} = {model_svd.n_taps}×{feat_per_tap}")
print(f" Classifier input: {model_svd.embed_dim} + {n_svd_feats} = {model_svd.embed_dim + n_svd_feats}")
result_svd = train_model(model_svd, train_loader, val_loader, device,
                         epochs=100, label="deit_small_svd")
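

# ── Sketch: where the SVD feature sizes and extra parameters come from ──────
# A minimal sketch (not part of the original experiment) that re-derives the
# printed feature sizes and the gap between the two [MODEL] parameter counts
# from the DeiTSmallSVD architecture. The helper name svd_overhead is
# hypothetical; it assumes the defaults embed_dim=384, svd_rank k=32, 4 taps.
def svd_overhead(embed_dim=384, k=32, n_taps=4):
    feat_per_tap = 2 * k + 2               # S_norm(k) + Vh_diag(k) + offdiag + entropy
    svd_feats = n_taps * feat_per_tap      # extra classifier inputs
    proj_params = n_taps * embed_dim * k   # bias-free tap projections
    head_params = svd_feats * embed_dim    # wider first head Linear (bias size unchanged)
    return {
        "feat_per_tap": feat_per_tap,
        "svd_feats": svd_feats,
        "classifier_input": embed_dim + svd_feats,
        "extra_params": proj_params + head_params,
    }

# With the defaults: 66 features per tap, 264 in total, a 648-d classifier
# input, and 49,152 + 101,376 = 150,528 extra parameters, which equals
# 21,676,900 - 21,526,372 from the logged model sizes.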
# ── Compare ──
print(f"\n{'='*70}")
print(f" HEAD-TO-HEAD COMPARISON")
print(f"{'='*70}")
print(f" {'Model':<30} {'Val%':>7} {'Params':>12}")
print(f" {'-'*55}")
print(f" {'DeiT-Small baseline':<30} {result_baseline['best_val_acc']:>6.2f}% {result_baseline['params']:>12,}")
print(f" {'DeiT-Small + SVD':<30} {result_svd['best_val_acc']:>6.2f}% {result_svd['params']:>12,}")
delta = result_svd['best_val_acc'] - result_baseline['best_val_acc']
print(f"\n SVD contribution: {delta:+.2f} points")
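

# ── Sketch: locating the "overtaken by epoch N" point ───────────────────────
# The header note says the SVD run passed the baseline's best validation
# accuracy (42.74%) by epoch 52. train_model only returns the best value, so
# checking that claim needs a per-epoch val_acc history, which train_model would
# have to be extended to collect (an assumption, not current behavior). Given
# such a list, this hypothetical helper finds the crossing epoch.
def first_epoch_beating(val_history, target):
    """Return the 1-indexed epoch whose val accuracy first exceeds target, else None."""
    for epoch, acc in enumerate(val_history, start=1):
        if acc > target:
            return epoch
    return None

# Usage, assuming train_model were extended to return a hypothetical
# 'val_history' list:
#     first_epoch_beating(result_svd['val_history'], result_baseline['best_val_acc'])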