"""
The standard DeiT-Small baseline (no SVD) tops out at 42.74% validation
accuracy over 100 epochs. The SVD-tap variant overtakes that mark by roughly
epoch 52 and finishes at 46.90%.

The per-epoch wall-clock cost of the SVD variant is substantially higher
(~22s vs. ~5.4s), but there is considerable room to reduce that overhead as
the SVD pipeline is optimized.

[DATA] CIFAR-100: 50000 train, 10000 val, bs=256

======================================================================
 PHASE 1: DeiT-Small BASELINE (no SVD)
======================================================================
[MODEL] DeiT-Small Baseline: 21,526,372 params

======================================================================
[EXP] deit_small_baseline | 21,526,372 params | 100 epochs
======================================================================
 E  1 | Tr   1.0% Va   1.5% | L=4.636 gap=-0.5 | Best 1.5%@E1 | 5.8s
 E  2 | Tr   1.8% Va   3.1% | L=4.526 gap=-1.3 | Best 3.1%@E2 | 5.4s
 E  3 | Tr   2.9% Va   3.9% | L=4.389 gap=-1.0 | Best 3.9%@E3 | 5.4s
 E  4 | Tr   3.6% Va   4.9% | L=4.335 gap=-1.3 | Best 4.9%@E4 | 5.4s
 E  5 | Tr   4.1% Va   5.3% | L=4.282 gap=-1.2 | Best 5.3%@E5 | 5.4s
 E 10 | Tr  10.5% Va  13.7% | L=3.861 gap=-3.2 | Best 13.7%@E10 | 5.4s
 E 15 | Tr  15.2% Va  20.5% | L=3.585 gap=-5.2 | Best 20.5%@E15 | 5.4s
 E 20 | Tr  18.8% Va  25.2% | L=3.375 gap=-6.4 | Best 25.2%@E20 | 5.4s
 E 25 | Tr  21.9% Va  26.9% | L=3.218 gap=-5.1 | Best 26.9%@E25 | 5.4s
 E 30 | Tr  24.4% Va  30.1% | L=3.075 gap=-5.7 | Best 30.1%@E30 | 5.4s
 E 35 | Tr  27.2% Va  32.8% | L=2.941 gap=-5.6 | Best 32.8%@E35 | 5.4s
 E 40 | Tr  29.5% Va  34.8% | L=2.808 gap=-5.2 | Best 34.8%@E40 | 5.4s
 E 45 | Tr  31.9% Va  36.1% | L=2.685 gap=-4.2 | Best 36.1%@E45 | 5.4s
 E 50 | Tr  35.1% Va  37.5% | L=2.547 gap=-2.4 | Best 37.5%@E50 | 5.4s
 E 55 | Tr  37.6% Va  38.8% | L=2.428 gap=-1.2 | Best 38.8%@E55 | 5.4s
 E 60 | Tr  40.1% Va  40.0% | L=2.313 gap=+0.1 | Best 40.0%@E60 | 5.4s
 E 65 | Tr  42.5% Va  41.1% | L=2.202 gap=+1.4 | Best 41.1%@E65 | 5.4s
 E 70 | Tr  44.8% Va  41.2% | L=2.104 gap=+3.6 | Best 41.2%@E70 | 5.4s
 E 75 | Tr  46.9% Va  41.5% | L=2.023 gap=+5.3 | Best 41.6%@E72 | 5.4s
 E 80 | Tr  48.5% Va  41.7% | L=1.955 gap=+6.8 | Best 41.9%@E79 | 5.4s
 E 85 | Tr  49.6% Va  42.3% | L=1.914 gap=+7.2 | Best 42.5%@E83 | 5.4s
 E 90 | Tr  50.4% Va  42.3% | L=1.877 gap=+8.1 | Best 42.7%@E87 | 5.4s
 E 95 | Tr  50.6% Va  42.5% | L=1.860 gap=+8.1 | Best 42.7%@E87 | 5.4s
 E100 | Tr  50.7% Va  42.5% | L=1.855 gap=+8.2 | Best 42.7%@E87 | 5.4s

[RESULT] deit_small_baseline: Best Val = 42.74% @E87 | Params: 21,526,372

======================================================================
 PHASE 2: DeiT-Small + SVD TAPS
======================================================================
[MODEL] DeiT-Small + SVD: 21,676,900 params
  SVD taps at layers: (3, 6, 9, 12)
  SVD features: 264 = 4×66
  Classifier input: 384 + 264 = 648

======================================================================
[EXP] deit_small_svd | 21,676,900 params | 100 epochs
======================================================================
 E  1 | Tr   2.3% Va   3.7% | L=4.471 gap=-1.3 | Best 3.7%@E1 | 21.4s
 E  2 | Tr   4.4% Va   7.9% | L=4.261 gap=-3.5 | Best 7.9%@E2 | 22.7s
 E  3 | Tr   6.4% Va   9.0% | L=4.108 gap=-2.6 | Best 9.0%@E3 | 21.1s
 E  4 | Tr   8.4% Va  12.0% | L=3.972 gap=-3.6 | Best 12.0%@E4 | 21.1s
 E  5 | Tr  10.1% Va  14.6% | L=3.878 gap=-4.6 | Best 14.6%@E5 | 21.1s
 E 10 | Tr  15.9% Va  20.7% | L=3.533 gap=-4.8 | Best 20.7%@E10 | 21.5s
 E 15 | Tr  19.7% Va  24.7% | L=3.321 gap=-5.0 | Best 24.7%@E15 | 21.4s
 E 20 | Tr  22.8% Va  28.5% | L=3.170 gap=-5.7 | Best 28.5%@E20 | 22.1s
 E 25 | Tr  25.3% Va  30.5% | L=3.023 gap=-5.2 | Best 30.5%@E25 | 21.8s
 E 30 | Tr  28.0% Va  33.2% | L=2.882 gap=-5.2 | Best 33.2%@E30 | 22.0s
 E 35 | Tr  31.0% Va  35.2% | L=2.741 gap=-4.2 | Best 35.2%@E35 | 22.0s
 E 40 | Tr  34.0% Va  37.6% | L=2.588 gap=-3.6 | Best 37.6%@E40 | 21.7s
 E 45 | Tr  37.5% Va  39.1% | L=2.432 gap=-1.6 | Best 39.4%@E44 | 22.0s
 E 50 | Tr  40.1% Va  41.0% | L=2.301 gap=-0.9 | Best 41.2%@E49 | 21.9s
 E 55 | Tr  43.4% Va  43.5% | L=2.138 gap=-0.1 | Best 43.5%@E55 | 22.2s
 E 60 | Tr  46.9% Va  44.1% | L=2.007 gap=+2.8 | Best 44.1%@E60 | 22.2s
 E 65 | Tr  50.0% Va  44.7% | L=1.869 gap=+5.3 | Best 45.0%@E64 | 22.2s
 E 70 | Tr  52.5% Va  45.8% | L=1.768 gap=+6.7 | Best 45.9%@E69 | 23.1s
 E 75 | Tr  54.7% Va  46.3% | L=1.669 gap=+8.4 | Best 46.3%@E75 | 21.8s
 E 80 | Tr  57.3% Va  45.8% | L=1.574 gap=+11.5 | Best 46.3%@E79 | 22.4s
 E 85 | Tr  58.4% Va  46.7% | L=1.532 gap=+11.7 | Best 46.8%@E84 | 22.4s
 E 90 | Tr  59.5% Va  46.8% | L=1.490 gap=+12.7 | Best 46.9%@E87 | 22.4s
 E 95 | Tr  60.0% Va  46.9% | L=1.475 gap=+13.2 | Best 46.9%@E87 | 22.2s
 E100 | Tr  60.6% Va  46.8% | L=1.452 gap=+13.8 | Best 46.9%@E87 | 21.9s

[RESULT] deit_small_svd: Best Val = 46.90% @E87 | Params: 21,676,900

======================================================================
 HEAD-TO-HEAD COMPARISON
======================================================================
 Model                             Val%       Params
 -------------------------------------------------------
 DeiT-Small baseline             42.74%   21,526,372
 DeiT-Small + SVD                46.90%   21,676,900

 SVD contribution: +4.16 points

================================================================================
SCOREBOARD
================================================================================
Experiment                                      Val%     Params   Epoch
--------------------------------------------- ------- ---------- ------
svd_classification_test                        70.92%  3,878,820      93
vit_svd_classification_test                    53.57%  6,705,828      86
deit_small_svd                                 46.90% 21,676,900      87
deit_small_baseline                            42.74% 21,526,372      87
================================================================================

"""
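The ~150k parameter gap between the two models in the table above can be reconstructed from the architecture alone. A quick arithmetic check, assuming (as the code below implies) that the only differences are the four bias-free tap projections and the wider first head layer:

```python
# Extra parameters in the SVD variant vs. the baseline:
#   - four bias-free tap projections, 384 -> 32 each
#   - the first head Linear's input grows from 384 to 384 + 264
embed_dim, k, n_taps = 384, 32, 4

svd_feat = n_taps * (2 * k + 2)        # 4 taps x 66 features = 264
proj_params = n_taps * embed_dim * k   # 4 x 384 x 32 = 49,152
head_extra = svd_feat * embed_dim      # 264 x 384 = 101,376
extra = proj_params + head_extra       # 150,528 = 21,676,900 - 21,526,372
```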
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

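The forward pass below calls `gram_eigh_svd`, which is not defined in this file (it comes from the earlier experiments in this series). A minimal sketch consistent with how it is used here: a thin SVD of each (tokens × rank) matrix, computed via `eigh` on the small k×k Gram matrix rather than a full SVD, which is cheaper when the token count exceeds the rank.

```python
import torch

def gram_eigh_svd(x, eps=1e-12):
    """Thin SVD of x (B, N, k) via eigendecomposition of the k x k Gram matrix.

    Sketch of the helper assumed by DeiTSmallSVD.forward; cheaper than
    torch.linalg.svd when N >> k because eigh runs on a (k, k) matrix.
    """
    gram = x.transpose(-2, -1) @ x           # (B, k, k), symmetric PSD
    evals, evecs = torch.linalg.eigh(gram)   # eigenvalues in ascending order
    evals = evals.flip(-1).clamp(min=eps)    # descending, numerically safe
    evecs = evecs.flip(-1)                   # reorder columns to match
    S = evals.sqrt()                         # singular values (B, k)
    Vh = evecs.transpose(-2, -1)             # right singular vectors as rows
    U = x @ evecs / S.unsqueeze(-2)          # left singular vectors (B, N, k)
    return U, S, Vh
```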
class DeiTSmallSVD(nn.Module):
    """DeiT-Small with SVD observation taps.

    Architecture:
        Patch embed (4×4) → 64 tokens + CLS → 384-d
        12 transformer layers (384-d, 6 heads, MLP ratio 4)
        SVD tap after layers 3, 6, 9, 12
        CLS token + SVD features → classify

    SVD observes token-space structure at 4 depths.
    Detached → no gradient through eigh; the transformer learns normally.
    SVD provides complementary structural features to the classifier.
    """

    def __init__(self, num_classes=100, img_size=32, patch_size=4,
                 embed_dim=384, depth=12, n_heads=6, mlp_ratio=4.0,
                 dropout=0.1, svd_rank=32, tap_layers=(3, 6, 9, 12)):
        super().__init__()
        self.embed_dim = embed_dim
        self.svd_rank = svd_rank
        self.tap_layers = tap_layers
        self.n_taps = len(tap_layers)
        k = svd_rank

        self.n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Sequential(
            nn.Conv2d(3, embed_dim, patch_size, stride=patch_size),
            nn.Flatten(2),
        )
        self.patch_norm = nn.LayerNorm(embed_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout, activation='gelu',
                batch_first=True, norm_first=True)
            for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        # One rank-k projection per tap: tokens are projected to k dims before SVD.
        self.svd_projs = nn.ModuleList([
            nn.Linear(embed_dim, k, bias=False) for _ in range(self.n_taps)])

        # Per tap: k normalized singular values + k Vh diagonal entries
        # + off-diagonal energy + spectral entropy = 2k + 2 features.
        svd_feat_dim = 2 * k + 2
        total_svd_feat = svd_feat_dim * self.n_taps

        total_dim = embed_dim + total_svd_feat
        self.head = nn.Sequential(
            nn.Linear(total_dim, embed_dim), nn.GELU(),
            nn.LayerNorm(embed_dim), nn.Dropout(0.1),
            nn.Linear(embed_dim, num_classes))

        self.n_params = sum(p.numel() for p in self.parameters())
        self._init_weights()

    def _init_weights(self):
        # ViT-style init: truncated-normal linears, unit LayerNorm, kaiming conv.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _extract_svd_features(self, S, Vh):
        """Compact SVD summary (same featurization as experiment 8.21)."""
        B, k = S.shape
        S_safe = S.clamp(min=1e-6)
        s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8)  # spectrum as a distribution
        vh_diag = Vh.diagonal(dim1=-2, dim2=-1)                      # alignment with projection axes
        vh_offdiag = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0)
        s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True)  # spectral entropy
        out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1)
        # Guard against NaN/Inf leaking into the classifier.
        return torch.where(torch.isfinite(out), out, torch.zeros_like(out))

    def forward(self, x):
        B = x.shape[0]

        # Patchify: (B, 3, 32, 32) → (B, 64, 384), prepend CLS, add positions.
        patches = self.patch_embed(x).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, patches], dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        svd_feats = []
        tap_idx = 0

        for layer_idx, layer in enumerate(self.layers):
            tokens = layer(tokens)

            if tap_idx < self.n_taps and (layer_idx + 1) == self.tap_layers[tap_idx]:
                # Project patch tokens (CLS excluded) down to rank k.
                patch_tokens = tokens[:, 1:]
                h_proj = self.svd_projs[tap_idx](patch_tokens)

                # fp32 + no_grad: eigh is fragile in half precision, and the
                # tap is a detached observation — no gradient flows through it.
                with torch.amp.autocast('cuda', enabled=False):
                    with torch.no_grad():
                        _, S, Vh = gram_eigh_svd(h_proj.float())
                        S = S.clamp(min=1e-6)
                        S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
                        Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))

                svd_feats.append(self._extract_svd_features(S, Vh))
                tap_idx += 1

        tokens = self.norm(tokens)
        cls_out = tokens[:, 0]

        # Classifier sees CLS plus all tap summaries: 384 + 4×66 = 648 dims.
        all_feats = torch.cat([cls_out] + svd_feats, dim=-1)
        return self.head(all_feats)

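For reference, each tap's summary vector packs 2k+2 numbers: the normalized spectrum (k), the diagonal of Vh (k), the total off-diagonal energy of Vh (1), and the spectral entropy (1). A standalone sketch of the same arithmetic on dummy inputs, outside the module:

```python
import torch

# Illustrative only: fake S and Vh standing in for a tap's SVD output.
k = 32
S = torch.rand(4, k) + 0.1    # dummy singular values, batch of 4
Vh = torch.randn(4, k, k)     # dummy right singular vectors

s_norm = S / S.sum(-1, keepdim=True)             # spectrum as a distribution
vh_diag = Vh.diagonal(dim1=-2, dim2=-1)          # alignment with projection axes
vh_off = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1)
s_ent = -(s_norm * s_norm.clamp(min=1e-8).log()).sum(-1, keepdim=True)

feat = torch.cat([s_norm, vh_diag, vh_off, s_ent], dim=-1)  # (4, 2k + 2) = (4, 66)
```

With four taps this gives the 4×66 = 264 extra classifier inputs reported in the log.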

class DeiTSmallBaseline(nn.Module):
    """Same DeiT-Small, no SVD taps. CLS → classify."""

    def __init__(self, num_classes=100, img_size=32, patch_size=4,
                 embed_dim=384, depth=12, n_heads=6, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Sequential(
            nn.Conv2d(3, embed_dim, patch_size, stride=patch_size),
            nn.Flatten(2))
        self.patch_norm = nn.LayerNorm(embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(dropout)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout, activation='gelu',
                batch_first=True, norm_first=True)
            for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.GELU(),
            nn.LayerNorm(embed_dim), nn.Dropout(0.1),
            nn.Linear(embed_dim, num_classes))
        self.n_params = sum(p.numel() for p in self.parameters())
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')

    def forward(self, x):
        B = x.shape[0]
        patches = self.patch_embed(x).transpose(1, 2)
        cls = self.cls_token.expand(B, -1, -1)
        tokens = self.pos_drop(torch.cat([cls, patches], dim=1) + self.pos_embed)
        for layer in self.layers:
            tokens = layer(tokens)
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)

def train_model(model, train_loader, val_loader, device, epochs=100, lr=3e-4, label=""):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    best_val = 0.0
    best_epoch = 0
    print(f"\n{'='*70}")
    print(f"[EXP] {label} | {model.n_params:,} params | {epochs} epochs")
    print(f"{'='*70}")
    for epoch in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        correct = total = 0
        loss_sum = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', dtype=amp_dtype):
                logits = model(images)
                loss = F.cross_entropy(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            correct += (logits.argmax(-1) == labels).sum().item()
            total += labels.size(0)
            loss_sum += loss.item()
        scheduler.step()
        train_acc = 100.0 * correct / total
        model.eval()
        val_correct = val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                with torch.amp.autocast('cuda', dtype=amp_dtype):
                    logits = model(images)
                val_correct += (logits.argmax(-1) == labels).sum().item()
                val_total += labels.size(0)
        val_acc = 100.0 * val_correct / val_total
        if val_acc > best_val:
            best_val = val_acc
            best_epoch = epoch
        elapsed = time.time() - t0
        gap = train_acc - val_acc
        if epoch <= 5 or epoch % 5 == 0 or epoch == epochs:
            print(f" E{epoch:>3} | Tr {train_acc:5.1f}% Va {val_acc:5.1f}%"
                  f" | L={loss_sum/len(train_loader):.3f} gap={gap:+.1f}"
                  f" | Best {best_val:.1f}%@E{best_epoch} | {elapsed:.1f}s")
    print(f"\n[RESULT] {label}: Best Val = {best_val:.2f}% @E{best_epoch} | Params: {model.n_params:,}")
    return {'experiment': label, 'best_val_acc': best_val, 'best_epoch': best_epoch, 'params': model.n_params}


tf_train = T.Compose([
    T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(),
    T.autoaugment.RandAugment(num_ops=2, magnitude=9), T.ToTensor()])
tf_val = T.Compose([T.ToTensor()])
train_ds = torchvision.datasets.CIFAR100(root="./data", train=True, download=True, transform=tf_train)
val_ds = torchvision.datasets.CIFAR100(root="./data", train=False, download=True, transform=tf_val)
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4,
                          pin_memory=True, drop_last=True, persistent_workers=True)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4,
                        pin_memory=True, persistent_workers=True)
print(f"[DATA] CIFAR-100: {len(train_ds)} train, {len(val_ds)} val, bs=256")

print("\n" + "="*70)
print(" PHASE 1: DeiT-Small BASELINE (no SVD)")
print("="*70)
model_baseline = DeiTSmallBaseline(num_classes=100)
print(f"[MODEL] DeiT-Small Baseline: {model_baseline.n_params:,} params")
result_baseline = train_model(model_baseline, train_loader, val_loader, device,
                              epochs=100, label="deit_small_baseline")

print("\n" + "="*70)
print(" PHASE 2: DeiT-Small + SVD TAPS")
print("="*70)
model_svd = DeiTSmallSVD(num_classes=100, svd_rank=32)
print(f"[MODEL] DeiT-Small + SVD: {model_svd.n_params:,} params")
print(f"  SVD taps at layers: {model_svd.tap_layers}")
print(f"  SVD features: {model_svd.n_taps * 66} = {model_svd.n_taps}×66")
print(f"  Classifier input: {model_svd.embed_dim} + {model_svd.n_taps * 66} = {model_svd.embed_dim + model_svd.n_taps * 66}")
result_svd = train_model(model_svd, train_loader, val_loader, device,
                         epochs=100, label="deit_small_svd")

print(f"\n{'='*70}")
print(" HEAD-TO-HEAD COMPARISON")
print(f"{'='*70}")
print(f" {'Model':<30} {'Val%':>7} {'Params':>12}")
print(f" {'-'*55}")
print(f" {'DeiT-Small baseline':<30} {result_baseline['best_val_acc']:>6.2f}% {result_baseline['params']:>12,}")
print(f" {'DeiT-Small + SVD':<30} {result_svd['best_val_acc']:>6.2f}% {result_svd['params']:>12,}")
delta = result_svd['best_val_acc'] - result_baseline['best_val_acc']
print(f"\n SVD contribution: {delta:+.2f} points")