AbstractPhil commited on
Commit
87a5790
·
verified ·
1 Parent(s): e2643f8

Create svd_conv_cifar100_train.py

Browse files
Files changed (1) hide show
  1. svd_conv_cifar100_train.py +254 -0
svd_conv_cifar100_train.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @title Experiment 8.21 — Pure SVD Classification Test
2
+ #
3
+ # Question: can SVD features alone drive classification?
4
+ # No constellation, no scatter, no patchwork. Just:
5
+ # Conv → project to 32ch → SVD → features → classify
6
+ #
7
+ # SVD of (B, H*W, 32) via gram_eigh: ~0.78ms
8
+ """
9
+ Expected Output:
10
+
11
+ [DATA] CIFAR-100: 50000 train, 10000 val
12
+ [MODEL] ConvSVDTest: 3,878,820 params
13
+ SVD feature dim per tap: 66 = 66
14
+ Total SVD features: 264 = 264
15
+ Conv features: 384
16
+ Classifier input: 648 = 648
17
+
18
+ ======================================================================
19
+ [EXP] SVD Classification Test | 3,878,820 params | 100 epochs
20
+ ======================================================================
21
+ E 1 | Tr 8.3% Va 17.4% | L=4.084 gap=-9.0 | Best 17.4%@E1 | 16.5s
22
+ E 2 | Tr 18.0% Va 27.3% | L=3.425 gap=-9.2 | Best 27.3%@E2 | 16.8s
23
+ E 3 | Tr 26.3% Va 34.3% | L=2.994 gap=-8.0 | Best 34.3%@E3 | 17.3s
24
+ E 4 | Tr 32.1% Va 38.1% | L=2.690 gap=-6.0 | Best 38.1%@E4 | 17.6s
25
+ E 5 | Tr 37.0% Va 41.3% | L=2.460 gap=-4.3 | Best 41.3%@E5 | 17.3s
26
+ E 10 | Tr 50.4% Va 52.4% | L=1.835 gap=-1.9 | Best 52.4%@E10 | 16.7s
27
+ E 15 | Tr 58.1% Va 58.5% | L=1.519 gap=-0.4 | Best 58.5%@E15 | 16.2s
28
+ E 20 | Tr 63.8% Va 61.2% | L=1.281 gap=+2.6 | Best 61.2%@E20 | 16.3s
29
+ E 25 | Tr 68.1% Va 62.9% | L=1.111 gap=+5.3 | Best 62.9%@E25 | 17.5s
30
+ E 30 | Tr 71.6% Va 64.7% | L=0.977 gap=+6.9 | Best 64.7%@E30 | 16.6s
31
+ E 35 | Tr 75.5% Va 65.6% | L=0.836 gap=+9.9 | Best 65.7%@E33 | 16.2s
32
+ E 40 | Tr 78.1% Va 66.3% | L=0.740 gap=+11.7 | Best 66.5%@E39 | 16.7s
33
+ E 45 | Tr 80.4% Va 67.3% | L=0.662 gap=+13.1 | Best 67.4%@E43 | 16.8s
34
+ E 50 | Tr 83.1% Va 67.8% | L=0.564 gap=+15.3 | Best 67.8%@E50 | 16.8s
35
+ E 55 | Tr 85.2% Va 68.2% | L=0.501 gap=+16.9 | Best 68.2%@E55 | 16.3s
36
+ E 60 | Tr 86.8% Va 69.1% | L=0.443 gap=+17.7 | Best 69.3%@E56 | 16.0s
37
+ E 65 | Tr 88.3% Va 69.3% | L=0.393 gap=+18.9 | Best 69.5%@E62 | 17.3s
38
+ E 70 | Tr 89.7% Va 69.6% | L=0.350 gap=+20.1 | Best 69.7%@E67 | 16.0s
39
+ E 75 | Tr 90.7% Va 70.0% | L=0.320 gap=+20.7 | Best 70.0%@E75 | 16.3s
40
+ E 80 | Tr 91.3% Va 70.5% | L=0.295 gap=+20.9 | Best 70.5%@E80 | 16.3s
41
+ E 85 | Tr 92.0% Va 70.5% | L=0.276 gap=+21.5 | Best 70.8%@E81 | 16.3s
42
+ E 90 | Tr 92.3% Va 70.7% | L=0.264 gap=+21.6 | Best 70.9%@E88 | 16.2s
43
+ E 95 | Tr 92.7% Va 70.8% | L=0.251 gap=+21.9 | Best 70.9%@E93 | 17.0s
44
+ E100 | Tr 92.8% Va 70.7% | L=0.254 gap=+22.2 | Best 70.9%@E93 | 16.8s
45
+
46
+ [RESULT] SVD Test: Best Val = 70.92% @E93 | Params: 3,878,820
47
+ """
48
+
49
+
50
# ── Simple Conv + SVD Model ──────────────────────────────────────────────────

class ConvSVDTest(nn.Module):
    """Minimal test: conv backbone + SVD features → classify.

    4 conv stages (same as ConvScatterNet).
    After each stage: project to 32ch, SVD, extract S + Vh → features.
    Pool all SVD features across depth → classify.
    """

    def __init__(self, num_classes=100, svd_rank=32):
        """Build the 4-stage conv backbone, per-stage SVD taps, and classifier.

        Args:
            num_classes: number of output classes (CIFAR-100 → 100).
            svd_rank: channel count k each stage is projected down to before
                the SVD; also the number of singular values per tap.
        """
        super().__init__()
        self.num_classes = num_classes
        self.svd_rank = svd_rank
        k = svd_rank

        # Conv stages: two 3x3 Conv+BN+GELU per stage, widths 64→128→256→384.
        self.stages = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
                nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU()),
            nn.Sequential(
                nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
                nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU()),
            nn.Sequential(
                nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
                nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU()),
            nn.Sequential(
                nn.Conv2d(256, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(),
                nn.Conv2d(384, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU()),
        ])
        self.pools = nn.ModuleList([nn.MaxPool2d(2) for _ in range(4)])

        # SVD projections per stage: 1x1 conv down to k channels.
        # NOTE(review): forward() computes the SVD under torch.no_grad() and
        # h_svd has no other consumer, so these projections receive no
        # gradient and stay at their random initialization.
        channel_sizes = [64, 128, 256, 384]
        self.to_svd = nn.ModuleList([
            nn.Conv2d(ch, k, 1, bias=False) for ch in channel_sizes
        ])

        # Per-tap SVD feature dim: S(k) + Vh_diag(k) + Vh_offdiag_norm(1) + S_entropy(1) = 2k+2
        svd_feat_dim = 2 * k + 2
        total_svd_feat = svd_feat_dim * 4  # 4 depths

        # Also keep the conv pooled features
        self.final_pool = nn.AdaptiveAvgPool2d(1)
        conv_feat_dim = 384  # channel width of the last conv stage

        # Classifier: SVD features + conv features → classes
        total_dim = total_svd_feat + conv_feat_dim
        self.classifier = nn.Sequential(
            nn.Linear(total_dim, 512), nn.GELU(), nn.LayerNorm(512), nn.Dropout(0.1),
            nn.Linear(512, 256), nn.GELU(), nn.LayerNorm(256), nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        )

        # Cached parameter count; read by the training loop's banner print.
        self.n_params = sum(p.numel() for p in self.parameters())

    def _extract_svd_features(self, S, Vh):
        """Extract compact features from SVD output.
        S: (B, k), Vh: (B, k, k) → (B, 2k+2)"""
        B, k = S.shape
        # Singular values (energy distribution) — clamp before normalize
        S_safe = S.clamp(min=1e-6)
        s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8)

        # Vh diagonal (self-alignment per component)
        vh_diag = Vh.diagonal(dim1=-2, dim2=-1)  # (B, k)

        # Vh off-diagonal energy (cross-component mixing); clamp(min=0)
        # guards against tiny negative values from float cancellation.
        vh_offdiag = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0)

        # Spectral entropy — safe log (clamp avoids log(0) → -inf)
        s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True)

        out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1)
        # Final NaN guard: replace any non-finite entries with zeros.
        return torch.where(torch.isfinite(out), out, torch.zeros_like(out))

    def forward(self, x):
        """x: (B, 3, H, W) images → (B, num_classes) logits."""
        B = x.shape[0]
        svd_feats = []

        h = x
        for i, (stage, pool, proj) in enumerate(zip(self.stages, self.pools, self.to_svd)):
            h = stage(h)
            # SVD on projected features
            h_svd = proj(h)  # (B, k, H, W)
            H, W = h_svd.shape[2], h_svd.shape[3]
            h_flat = h_svd.permute(0, 2, 3, 1).reshape(B, H * W, self.svd_rank)
            # SVD in fp32 (autocast disabled) and detached: no gradient flows
            # through the decomposition, only through the classifier head.
            with torch.amp.autocast('cuda', enabled=False):
                with torch.no_grad():
                    h_f = h_flat.float()
                    # gram_eigh_svd is defined elsewhere in the project;
                    # assumed to return (U, S, Vh) like torch.linalg.svd —
                    # TODO(review): confirm against its definition.
                    _, S, Vh = gram_eigh_svd(h_f)
                    S = S.clamp(min=1e-6)
                    S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
                    Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))
            svd_feats.append(self._extract_svd_features(S, Vh))
            h = pool(h)

        # Conv pooled features
        conv_feat = self.final_pool(h).flatten(1)  # (B, 384)

        # Concatenate all SVD features + conv features
        all_feats = torch.cat(svd_feats + [conv_feat], dim=-1)

        return self.classifier(all_feats)
156
+
157
+
158
+ # ── Training loop (simple, no paired views) ──────────────────────────────────
159
+
160
+ def train_svd_test(model, train_loader, val_loader, device, epochs=100, lr=3e-4):
161
+ model = model.to(device)
162
+ optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
163
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
164
+
165
+ amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
166
+ best_val = 0.0
167
+ best_epoch = 0
168
+
169
+ print(f"\n{'='*70}")
170
+ print(f"[EXP] SVD Classification Test | {model.n_params:,} params | {epochs} epochs")
171
+ print(f"{'='*70}")
172
+
173
+ for epoch in range(1, epochs + 1):
174
+ model.train()
175
+ t0 = time.time()
176
+ correct = total = 0
177
+ loss_sum = 0.0
178
+
179
+ for images, labels in train_loader:
180
+ images, labels = images.to(device), labels.to(device)
181
+ optimizer.zero_grad(set_to_none=True)
182
+
183
+ with torch.amp.autocast('cuda', dtype=amp_dtype):
184
+ logits = model(images)
185
+ loss = F.cross_entropy(logits, labels)
186
+
187
+ loss.backward()
188
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
189
+ optimizer.step()
190
+
191
+ correct += (logits.argmax(-1) == labels).sum().item()
192
+ total += labels.size(0)
193
+ loss_sum += loss.item()
194
+
195
+ scheduler.step()
196
+ train_acc = 100.0 * correct / total
197
+ train_loss = loss_sum / len(train_loader)
198
+
199
+ # Validation
200
+ model.eval()
201
+ val_correct = val_total = 0
202
+ with torch.no_grad():
203
+ for images, labels in val_loader:
204
+ images, labels = images.to(device), labels.to(device)
205
+ with torch.amp.autocast('cuda', dtype=amp_dtype):
206
+ logits = model(images)
207
+ val_correct += (logits.argmax(-1) == labels).sum().item()
208
+ val_total += labels.size(0)
209
+ val_acc = 100.0 * val_correct / val_total
210
+
211
+ if val_acc > best_val:
212
+ best_val = val_acc
213
+ best_epoch = epoch
214
+
215
+ elapsed = time.time() - t0
216
+ gap = train_acc - val_acc
217
+ if epoch <= 5 or epoch % 5 == 0 or epoch == epochs:
218
+ print(f" E{epoch:>3} | Tr {train_acc:5.1f}% Va {val_acc:5.1f}%"
219
+ f" | L={train_loss:.3f} gap={gap:+.1f}"
220
+ f" | Best {best_val:.1f}%@E{best_epoch} | {elapsed:.1f}s")
221
+
222
+ print(f"\n[RESULT] SVD Test: Best Val = {best_val:.2f}% @E{best_epoch} | Params: {model.n_params:,}")
223
+ return {'experiment': 'svd_classification_test', 'best_val_acc': best_val,
224
+ 'best_epoch': best_epoch, 'params': model.n_params}
225
+
226
+
227
# ── Launch ───────────────────────────────────────────────────────────────────

# Simple augmentation — single view, standard training
tf_train = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.autoaugment.RandAugment(num_ops=2, magnitude=9),
    T.ToTensor(),
])
# Validation: no augmentation, tensor conversion only (no normalization,
# matching the train pipeline).
tf_val = T.Compose([T.ToTensor()])

# CIFAR-100 datasets/loaders; downloads into ./data on first run.
train_ds = torchvision.datasets.CIFAR100(root="./data", train=True, download=True, transform=tf_train)
val_ds = torchvision.datasets.CIFAR100(root="./data", train=False, download=True, transform=tf_val)
train_loader = DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=4,
                          pin_memory=True, drop_last=True, persistent_workers=True)
val_loader = DataLoader(val_ds, batch_size=512, shuffle=False, num_workers=4,
                        pin_memory=True, persistent_workers=True)
print(f"[DATA] CIFAR-100: {len(train_ds)} train, {len(val_ds)} val")

model_svd_test = ConvSVDTest(num_classes=100, svd_rank=32)
print(f"[MODEL] ConvSVDTest: {model_svd_test.n_params:,} params")
# NOTE(review): the literals below assume svd_rank=32 (2*32+2 = 66 per tap,
# x4 taps = 264; 264 + 384 conv = 648) — they would go stale if svd_rank changes.
print(f" SVD feature dim per tap: {2*32+2} = 66")
print(f" Total SVD features: {66*4} = 264")
print(f" Conv features: 384")
print(f" Classifier input: {264+384} = 648")

# `device` is defined earlier in the notebook — presumably 'cuda'; confirm.
result_svd = train_svd_test(model_svd_test, train_loader, val_loader, device, epochs=100)
result_svd