AbstractPhil commited on
Commit
3bc2d4b
·
verified ·
1 Parent(s): 0c59bd3

Create 1_1_constellation_adapted_kymatio_projected.py

Browse files
spectral/experiment_1/1_1_constellation_adapted_kymatio_projected.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GeoLIP Scattering Constellation — Autopsy-Informed Prototype
4
+ ================================================================
5
+ kymatio scattering (frozen, zero params)
6
+ → BatchNorm2d(243) — 15x dimensionality expansion (dim_90: 31→463)
7
+ → FLATTEN to 15552-d (NEVER avg pool — destroys spatial structure)
8
+ → Learned projection 15552 → 512-d (captures full dim_90=463 effective space)
9
+ → L2 normalize → S^511
10
+ → Constellation (64 anchors on S^511)
11
+ → Patchwork (8×64 = 512-d)
12
+ → Classifier (patchwork + embedding → 10 classes)
13
+
14
+ Autopsy findings applied:
15
+ - ImageNet normalization (not CIFAR stats)
16
+ - BN variance ratios: o0/o1=136x, o0/o2=27x (deterministic constants)
17
+ - BN expands eff_dim 128.8→946, dim_90 31→463
18
+ - BN pushes CV from 0.29→0.24 (toward 0.20 attractor)
19
+ - Orders are independent subspaces (Procrustes o0↔o1=0.15)
20
+ - Class separation comes from classifier, not encoder (BN: 0.66→0.64)
21
+ - Augmentation stability: cos=0.574 (InfoNCE has signal)
22
+
23
+ Losses: CE + InfoNCE + attract + CV + spread
24
+ Optimizer: SGD lr=0.05, momentum=0.9, wd=5e-4, 5x decay every 20 epochs
25
+ """
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ import math
31
+ import os, time
32
+ from tqdm import tqdm
33
+ from kymatio.torch import Scattering2D
34
+ from torchvision import datasets, transforms
35
+
36
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Permit TF32 kernels for both cuDNN convolutions and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
39
+
40
+
41
+ # ══════════════════════════════════════════════════════════════════
42
+ # ACTIVATION — SquaredReLU proven superior for geometric paths
43
+ # ══════════════════════════════════════════════════════════════════
44
+
45
class SquaredReLU(nn.Module):
    """Element-wise activation computing ``relu(x) ** 2``."""

    def forward(self, x):
        # Clamp negatives to zero first, then square what is left.
        rectified = F.relu(x)
        return rectified ** 2
48
+
49
+
50
+ # ══════════════════════════════════════════════════════════════════
51
+ # UNIFORM HYPERSPHERE INIT
52
+ # ══════════════════════════════════════════════════════════════════
53
+
54
def uniform_hypersphere_init(n, d):
    """Return an ``(n, d)`` tensor of unit vectors spread over the sphere.

    When ``n <= d`` the rows are mutually orthonormal (taken from a QR
    factorisation of a Gaussian matrix).  When ``n > d`` an orthonormal basis
    is topped up with random unit vectors and the whole set is relaxed by
    repeatedly stepping every vector away from its most similar neighbour.
    """
    if n <= d:
        # Reduced QR of a (d, n) Gaussian yields n orthonormal columns.
        q_cols = torch.linalg.qr(torch.randn(d, n))[0]
        return q_cols.T.contiguous()

    # More points than dimensions: orthonormal basis + random unit extras.
    q_full = torch.linalg.qr(torch.randn(d, d))[0]
    surplus = F.normalize(torch.randn(n - d, d), dim=-1)
    points = torch.cat([q_full.T, surplus], dim=0)

    step = 0.05
    for _ in range(200):
        cosines = points @ points.T
        # -2 is below any cosine, so a point never selects itself.
        cosines.fill_diagonal_(-2.0)
        closest = cosines.argmax(dim=1)
        # Step away from the nearest neighbour, then project back to the sphere.
        points = F.normalize(points - step * points[closest], dim=-1)
    return points
71
+
72
+
73
+ # ══════════════════════════════════════════════════════════════════
74
+ # CONSTELLATION + PATCHWORK (proven)
75
+ # ══════════════════════════════════════════════════════════════════
76
+
77
class Constellation(nn.Module):
    """A learnable set of anchor directions that embeddings are measured against."""

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        self.n_anchors = n_anchors
        self.anchor_drop = anchor_drop
        self.anchors = nn.Parameter(uniform_hypersphere_init(n_anchors, dim))

    def triangulate(self, emb, training=False):
        """Compute ``1 - cosine`` from each embedding to each anchor.

        Args:
            emb: embeddings whose last dimension matches the anchor dimension.
            training: when true and ``anchor_drop > 0``, a random subset of
                anchors is dropped for this call (at least two are kept).

        Returns:
            ``(tri, nearest)``: ``tri`` holds ``1 - cos`` against the anchors
            that were kept, and ``nearest`` is the index, in the full anchor
            set, of the most similar kept anchor for every embedding.
        """
        anchors = F.normalize(self.anchors, dim=-1)

        kept_ids = None
        if training and self.anchor_drop > 0:
            keep = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
            if keep.sum() < 2:
                # Never let dropout leave fewer than two anchors.
                keep[:2] = True
            anchors = anchors[keep]
            kept_ids = keep.nonzero(as_tuple=True)[0]

        cos = emb @ anchors.T
        nearest = cos.max(dim=-1).indices
        if kept_ids is not None:
            # Translate positions within the kept subset back to global ids.
            nearest = kept_ids[nearest]
        return 1.0 - cos, nearest
100
+
101
+
102
class Patchwork(nn.Module):
    """Compartmentalized patchwork — interleaved anchor assignment.

    Anchor ``i`` is routed to compartment ``i % n_comp``.  Every compartment
    is a small MLP (Linear → SquaredReLU → Linear → LayerNorm) that maps the
    triangulation values of its own anchors to a ``d_comp``-dimensional
    output; the outputs are concatenated in compartment order.

    Args:
        n_anchors: number of columns in the triangulation tensor.
        n_comp: number of compartments.
        d_comp: output width of each compartment.

    Raises:
        ValueError: if ``n_comp`` is not positive or there are fewer anchors
            than compartments (some compartment would receive no input).
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        if n_comp <= 0 or n_anchors < n_comp:
            raise ValueError(
                f"need at least one anchor per compartment "
                f"(n_anchors={n_anchors}, n_comp={n_comp})")
        self.n_comp = n_comp
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        # Compartment k owns anchors k, k + n_comp, k + 2 * n_comp, ...  When
        # n_anchors is not a multiple of n_comp the first few compartments own
        # one anchor more, so size every input layer from its actual count
        # (this equals n_anchors // n_comp in the evenly divisible case).
        anchors_per = [len(range(k, n_anchors, n_comp)) for k in range(n_comp)]
        self.comps = nn.ModuleList([nn.Sequential(
            nn.Linear(anchors_per[k], d_comp * 2), SquaredReLU(),
            nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))
            for k in range(n_comp)])

    def forward(self, tri):
        """Map ``tri`` of shape ``(B, n_anchors)`` to ``(B, n_comp * d_comp)``."""
        return torch.cat([self.comps[k](tri[:, self.asgn == k])
                          for k in range(self.n_comp)], -1)
117
+
118
+
119
+ # ══════════════════════════════════════════════════════════════════
120
+ # GEOLIP SCATTERING CONSTELLATION
121
+ # ══════════════════════════════════════════════════════════════════
122
+
123
class GeoLIPScatteringConstellation(nn.Module):
    """Constellation classifier on top of pre-computed scattering features.

    Pipeline (see ``forward``): BatchNorm2d → flatten → two-layer projection
    → L2 normalisation → constellation triangulation → patchwork →
    classifier over ``[patchwork, embedding]``.

    Args:
        num_classes: number of output logits.
        proj_dim: dimension of the normalised embedding.
        n_anchors: number of constellation anchors.
        n_comp: number of patchwork compartments.
        d_comp: output width of each patchwork compartment.
        anchor_drop: anchor dropout rate used for the training-time
            nearest-anchor assignment.
        cv_target: target coefficient of variation for ``_cv_loss``.
        infonce_temp: temperature of the InfoNCE similarity.
        scat_channels: channel count of the scattering input (default 243).
        scat_size: spatial side length of the scattering input (default 8).
            Together with ``scat_channels`` this fixes the flattened input
            width; the defaults give 243 * 8 * 8 = 15552.
    """

    def __init__(
        self,
        num_classes=10,
        proj_dim=512,
        n_anchors=64,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        cv_target=0.22,
        infonce_temp=0.07,
        scat_channels=243,
        scat_size=8,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.proj_dim = proj_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp

        # Constructor arguments, stored alongside checkpoints so that the
        # model can be rebuilt later.  Spelled out explicitly rather than
        # harvested from locals(), which depends on interpreter details.
        self.config = {
            'num_classes': num_classes,
            'proj_dim': proj_dim,
            'n_anchors': n_anchors,
            'n_comp': n_comp,
            'd_comp': d_comp,
            'anchor_drop': anchor_drop,
            'cv_target': cv_target,
            'infonce_temp': infonce_temp,
            'scat_channels': scat_channels,
            'scat_size': scat_size,
        }

        # Stage 1: kymatio scattering (frozen, zero params) — built externally
        # Output: (B, scat_channels, scat_size, scat_size)

        # Stage 2: BatchNorm on scattering output
        # Autopsy: expands eff_dim 128.8→946, dim_90 31→463
        # Equalizes order 0/1/2 variance ratios (136x, 27x)
        self.bn = nn.BatchNorm2d(scat_channels)

        # Stage 3: Flatten → learned projection → S^(proj_dim-1)
        # FLATTEN not avg pool (keeps the spatial layout in the feature vector)
        flat_dim = scat_channels * scat_size * scat_size
        self.proj = nn.Sequential(
            nn.Linear(flat_dim, proj_dim * 2),
            SquaredReLU(),
            nn.LayerNorm(proj_dim * 2),
            nn.Linear(proj_dim * 2, proj_dim),
            nn.LayerNorm(proj_dim),
        )

        # Stage 4: Constellation on S^(proj_dim-1)
        self.constellation = Constellation(n_anchors, proj_dim, anchor_drop)

        # Stage 5: Patchwork
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp

        # Classifier reads patchwork + projected embedding
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + proj_dim, pw_dim), SquaredReLU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes))

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear (trunc-normal, zero bias) and norm layer (1/0)."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, scat_features):
        """Run the head on scattering features.

        Args:
            scat_features: ``(B, scat_channels, scat_size, scat_size)``
                tensor, ``(B, 243, 8, 8)`` with the default configuration.

        Returns:
            dict with ``'logits'``, ``'embedding'`` (unit-norm),
            ``'triangulation'`` (``1 - cos`` to every anchor) and
            ``'nearest'`` (index of the nearest anchor per sample).
        """
        # BN equalizes multi-scale features
        x = self.bn(scat_features)

        # FLATTEN — never avg pool
        x = x.flatten(1)

        # Learned projection → sphere
        feat = self.proj(x)
        emb = F.normalize(feat, dim=-1)  # → S^(proj_dim-1)

        # The patchwork needs one column per anchor, so the triangulation it
        # consumes is always computed without anchor dropout.
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)

        if self.training:
            # Only the reported nearest-anchor assignment uses anchor dropout.
            _, nearest = self.constellation.triangulate(emb, training=True)

        logits = self.classifier(torch.cat([pw, emb], dim=-1))

        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Combine the training losses.

        total = CE + 1.0 * InfoNCE (only when ``output_aug`` is given)
                + 0.5 * attract + 0.01 * CV + 0.001 * spread

        Args:
            output: dict returned by ``forward`` for the first view.
            targets: class indices for the batch.
            output_aug: optional ``forward`` output for a second view of the
                same samples, in the same order.

        Returns:
            ``(loss, ld)`` where ``ld`` maps names to individual loss tensors
            plus the float diagnostics ``'acc'``, ``'nce_acc'`` and
            ``'nearest_cos'``.
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # CE
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE between two augmented views: sample i in view 1 must match
        # sample i in view 2 against all other samples of the batch.
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc

        # Anchor attraction: pull every embedding toward its nearest anchor.
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos_to_anchors = emb @ anchors_n.T
        nearest_cos = cos_to_anchors.max(dim=1).values
        l_attract = (1.0 - nearest_cos).mean()
        ld['attract'] = l_attract
        ld['nearest_cos'] = nearest_cos.mean().item()

        # CV
        l_cv = self._cv_loss(emb)
        ld['cv'] = l_cv

        # Anchor spread: penalise positive cosine between distinct anchors.
        sim_a = anchors_n @ anchors_n.T
        mask_a = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=emb.device)
        l_spread = F.relu(sim_a[mask_a]).mean()
        ld['spread'] = l_spread

        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_attract * 0.5
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    @torch.no_grad()
    def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
        """Nudge anchors toward class centroids of buffered embeddings.

        Each anchor is assigned to one class (greedily by cosine similarity,
        with at most ``n_anchors // n_classes + 1`` anchors per class) and
        then moved a fraction ``lr`` of the way toward that class centroid,
        in place, and re-normalised.

        Args:
            emb_buffer: buffered embeddings, one row per sample.
            label_buffer: class label for every row of ``emb_buffer``.
            lr: interpolation step toward the target.

        Returns:
            Number of anchors updated (0 when the buffer holds no labels).
        """
        anchors = self.constellation.anchors.data
        n_a = anchors.shape[0]
        emb_n = F.normalize(emb_buffer, dim=-1)
        device = anchors.device

        # One unit-norm centroid per class present in the buffer; centroid
        # rows are indexed by position in `classes`, not by label value.
        classes = label_buffer.unique()
        n_cls = classes.shape[0]
        centroids = []
        for c in classes:
            mask = label_buffer == c
            if mask.sum() > 0:
                centroids.append(F.normalize(emb_n[mask].mean(0, keepdim=True), dim=-1))
        if len(centroids) == 0:
            return 0
        centroids = torch.cat(centroids, dim=0)

        anchors_n = F.normalize(anchors, dim=-1)
        cos = anchors_n @ centroids.T
        apc = n_a // n_cls
        assigned = torch.full((n_a,), -1, dtype=torch.long, device=device)
        cls_count = torch.zeros(n_cls, dtype=torch.long, device=device)

        # Greedy assignment: walk (anchor, class) pairs from most to least
        # similar, skipping anchors already placed and classes at capacity.
        _, flat_idx = cos.flatten().sort(descending=True)
        for idx in flat_idx:
            a = (idx // n_cls).item()
            c = (idx % n_cls).item()
            if assigned[a] >= 0:
                continue
            if cls_count[c] >= apc + 1:
                continue
            assigned[a] = c
            cls_count[c] += 1
            if (assigned >= 0).all():
                break

        # Anything still unplaced goes to its most similar centroid.
        unassigned = (assigned < 0).nonzero(as_tuple=True)[0]
        if len(unassigned) > 0:
            assigned[unassigned] = (anchors_n[unassigned] @ centroids.T).argmax(dim=1)

        moved = 0
        for a in range(n_a):
            c = assigned[a].item()
            target = centroids[c]
            rank = (assigned[:a] == c).sum().item()
            if apc > 1 and rank > 0:
                # Later anchors of the same class aim at a slightly perturbed
                # target (noise made orthogonal to the centroid) so that they
                # do not all collapse onto one point.
                noise = torch.randn_like(target) * 0.05
                noise = noise - (noise * target).sum() * target
                target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)
            anchors[a] = F.normalize(
                (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
                dim=-1).squeeze(0)
            moved += 1
        return moved

    def _cv_loss(self, emb, n_samples=64, n_points=5):
        """Squared gap between the simplex-volume CV and ``cv_target``.

        Draws ``n_samples`` random subsets of ``n_points`` embeddings (from
        the first 512 rows at most), computes the volume of the simplex each
        subset spans via the Cayley–Menger determinant, and compares the
        coefficient of variation (std / mean) of those volumes with
        ``self.cv_target``.

        Returns a zero tensor when the batch has fewer than ``n_points`` rows
        or fewer than five subsets yield a usable volume.
        """
        B = emb.shape[0]
        if B < n_points:
            return torch.tensor(0.0, device=emb.device)

        N = n_points
        k = N - 1
        # Cayley–Menger prefactor for the squared volume of a k-simplex;
        # it does not depend on the sample, so compute it once.
        pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))

        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            # Pairwise squared distances; relu clips tiny negative round-off.
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            # Bordered matrix: zero corner, ones along the first row/column.
            cm = torch.zeros(1, N + 1, N + 1, device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2
            v2 = pf * torch.linalg.det(cm.float())
            # Skip degenerate subsets whose squared volume is not positive.
            if v2[0].item() > 1e-20:
                vols.append(v2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            return torch.tensor(0.0, device=emb.device)
        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - self.cv_target).pow(2)
344
+
345
+
346
+ # ══════════════════════════════════════════════════════════════════
347
+ # DATA — ImageNet normalization (kymatio standard)
348
+ # ══════════════════════════════════════════════════════════════════
349
+
350
# Per-channel mean/std standardisation shared by the training and validation
# pipelines (the ImageNet statistics named in the section header above).
NORMALIZE = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
352
+
353
+
354
class TwoViewDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each item yields two transformed views plus its label."""

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        sample, target = self.base[i]
        # The transform is applied twice, independently, to the same sample.
        first_view = self.transform(sample)
        second_view = self.transform(sample)
        return first_view, second_view, target
363
+
364
+
365
+ # ══════════════════════════════════════════════════════════════════
366
+ # TRAINING
367
+ # ══════════════════════════════════════════════════════════════════
368
+
369
# Experiment hyper-parameters.
NUM_CLASSES = 10
PROJ_DIM = 512
N_ANCHORS = 64
N_COMP = 8
D_COMP = 64
BATCH = 128
EPOCHS = 90
K = 81 * 3  # 243 scattering channels

# Start-up banner describing the run configuration.
_banner = [
    "=" * 60,
    "GeoLIP Scattering Constellation — Autopsy-Informed",
    " Scattering: kymatio J=2, L=8, order 2 → (B, 243, 8, 8)",
    " BN(243) → FLATTEN(15552) → proj(512) → S^511",
    f" Constellation: {N_ANCHORS} anchors on S^511",
    f" Patchwork: {N_COMP}×{D_COMP} = {N_COMP*D_COMP}d",
    " Activation: SquaredReLU",
    " Loss: CE + InfoNCE + attract + CV(0.22) + spread",
    " Optimizer: SGD lr=0.05, momentum=0.9, wd=5e-4",
    f" Batch: {BATCH}, Epochs: {EPOCHS}",
    f" Device: {DEVICE}",
    "=" * 60,
]
for _line in _banner:
    print(_line)
390
+
391
# Training views: random flip + padded random crop, then tensor + normalise.
aug_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    NORMALIZE,
])
# Validation images are only converted and normalised.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    NORMALIZE,
])

# The raw training set carries no transform: TwoViewDataset applies the
# augmentation itself, twice per sample.
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=val_transform)

_loader_kwargs = dict(batch_size=BATCH, num_workers=4, pin_memory=True)
train_loader = torch.utils.data.DataLoader(
    train_ds, shuffle=True, drop_last=True, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(
    val_ds, shuffle=False, **_loader_kwargs)

print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")
415
+
416
+ # Scattering (frozen)
417
scat = Scattering2D(J=2, shape=(32, 32)).to(DEVICE)

# Probe the transform once with a dummy batch: if it returns a 5-D tensor,
# the two dimensions after the batch axis are folded into a single channel
# axis, and get_scat() does the same for real batches.
with torch.no_grad():
    _probe = scat(torch.randn(2, 3, 32, 32, device=DEVICE))
    USE_5D = _probe.dim() == 5
    if USE_5D:
        _probe = _probe.reshape(
            _probe.shape[0], -1, _probe.shape[-2], _probe.shape[-1])
    print(f" Scattering output: {_probe.shape} (5D={USE_5D})")
    del _probe
428
+
429
def get_scat(imgs):
    """Apply the scattering transform and return a 4-D feature map.

    When the probe at start-up found that the transform emits 5-D output
    (``USE_5D``), every dimension between the batch axis and the last two
    axes is merged into one channel axis.
    """
    coeffs = scat(imgs)
    if not USE_5D:
        return coeffs
    batch, height, width = coeffs.shape[0], coeffs.shape[-2], coeffs.shape[-1]
    return coeffs.reshape(batch, -1, height, width)
434
+
435
+ # Model
436
model = GeoLIPScatteringConstellation(
    num_classes=NUM_CLASSES,
    proj_dim=PROJ_DIM,
    n_anchors=N_ANCHORS,
    n_comp=N_COMP,
    d_comp=D_COMP,
)
model = model.to(DEVICE)


def _count_params(module):
    """Total number of scalar parameters held by ``module``."""
    return sum(p.numel() for p in module.parameters())


# Parameter budget, broken down by stage.
n_total = _count_params(model)
n_proj = _count_params(model.proj)
n_bn = _count_params(model.bn)
print(f" Total params: {n_total:,}")
print(f" BN: {n_bn:,}")
print(f" Projection: {n_proj:,}")
print(f" Constellation+PW+Clf: {n_total - n_proj - n_bn:,}")
448
+
449
+ # SGD with step decay (kymatio proven recipe)
450
# Mutable training state.
lr = 0.05          # learning rate handed to the next optimizer rebuild
best_acc = 0.0     # best validation accuracy seen so far
gs = 0             # global step counter (one per optimizer step)
os.makedirs("checkpoints", exist_ok=True)

# Anchor-push schedule and the rolling embedding buffer that feeds it.
PUSH_INTERVAL = 50
PUSH_LR = 0.1
PUSH_BUFFER_SIZE = 5000
emb_buffer = lbl_buffer = None
push_count = 0

_rule = '=' * 60
print("\n" + _rule)
print(f"TRAINING — {EPOCHS} epochs")
print(f" SGD lr={lr}, step decay 5x every 20 epochs")
print(f" Anchor push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
print(_rule)
467
+
468
for epoch in range(EPOCHS):
    # Step decay: every 20 epochs rebuild SGD at the current rate (which also
    # discards the momentum buffers), then shrink the rate 5x for the next
    # rebuild.
    if epoch % 20 == 0:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9, weight_decay=0.0005)
        lr *= 0.2

    model.train()
    t0 = time.time()
    tot_loss, tot_nce_acc, tot_nearest_cos, n = 0, 0, 0, 0
    correct, total = 0, 0

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # The scattering front-end is frozen, so no graph is needed for it.
        with torch.no_grad():
            s1 = get_scat(v1)
            s2 = get_scat(v2)

        out1 = model(s1)
        out2 = model(s2)
        loss, ld = model.compute_loss(out1, targets, output_aug=out2)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        gs += 1

        # Rolling buffer of the most recent embeddings/labels for anchor push.
        with torch.no_grad():
            batch_emb = out1['embedding'].detach().float()
            if emb_buffer is None:
                emb_buffer = batch_emb
                lbl_buffer = targets.detach()
            else:
                emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
                lbl_buffer = torch.cat([lbl_buffer, targets.detach()])[-PUSH_BUFFER_SIZE:]

        if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
            model.push_anchors_to_centroids(emb_buffer, lbl_buffer, lr=PUSH_LR)
            push_count += 1

        preds = out1['logits'].argmax(-1)
        correct += (preds == targets).sum().item()
        total += targets.shape[0]
        tot_loss += loss.item()
        tot_nce_acc += ld.get('nce_acc', 0)
        tot_nearest_cos += ld.get('nearest_cos', 0)
        n += 1

        if n % 10 == 0:
            # Count how many distinct anchors are nearest to some sample of
            # the current batch.
            with torch.no_grad():
                _an = F.normalize(model.constellation.anchors, dim=-1)
                _cos = out1['embedding'].detach() @ _an.T
                _act = _cos.argmax(-1).unique().numel()
            # Every keyword given to set_postfix is rendered as a progress-bar
            # field, so only real metrics are passed here.
            pbar.set_postfix(
                loss=f"{tot_loss/n:.4f}",
                acc=f"{100*correct/total:.0f}%",
                nce=f"{tot_nce_acc/n:.2f}",
                cos=f"{ld.get('nearest_cos', 0):.3f}",
                anch=f"{_act}/{N_ANCHORS}",
                push=push_count)

    elapsed = time.time() - t0
    train_acc = 100 * correct / total

    # Val
    model.eval()
    vc, vt_n = 0, 0
    all_embs = []
    with torch.no_grad():
        for imgs, lbls in val_loader:
            imgs = imgs.to(DEVICE)
            lbls = lbls.to(DEVICE)
            out = model(get_scat(imgs))
            vc += (out['logits'].argmax(-1) == lbls).sum().item()
            vt_n += lbls.shape[0]
            all_embs.append(out['embedding'].float().cpu())

    val_acc = 100 * vc / vt_n

    # CV measurement on (up to) the first 2000 validation embeddings: sample
    # 5-point subsets, take each simplex volume from the Cayley–Menger
    # determinant, and report std/mean of the volumes.
    embs = torch.cat(all_embs)[:2000].to(DEVICE)
    n_pool = embs.shape[0]
    with torch.no_grad():
        vols = []
        # Sample indices from what is actually available; a fixed range of
        # 2000 would index out of bounds on a smaller validation set.
        n_draws = 200 if n_pool >= 5 else 0
        for _ in range(n_draws):
            idx = torch.randperm(n_pool)[:5]
            pts = embs[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1
            cm[:, 1:, 0] = 1
            cm[:, 1:, 1:] = d2
            # 9216 = 2**4 * (4!)**2, the prefactor for a 4-simplex (5 points).
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:
                vols.append(v2[0].sqrt())
        if len(vols) > 10:
            _vol_t = torch.stack(vols)
            v_cv = (_vol_t.std() / (_vol_t.mean() + 1e-8)).item()
        else:
            v_cv = 0

    # Active anchors
    with torch.no_grad():
        _, vnp = model.constellation.triangulate(embs)
        n_active = vnp.cpu().unique().numel()

    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
        }, "checkpoints/geolip_scat_constellation_best.pt")
        mk = " ★"

    nce_m = tot_nce_acc / n
    cos_m = tot_nearest_cos / n
    cv_band = "✓" if 0.18 <= v_cv <= 0.25 else "✗"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
          f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
          f"push={push_count} ({elapsed:.0f}s){mk}")
595
+
596
# Final report: best validation accuracy against the linear baseline.
_closing_rule = '=' * 60
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f" Total params: {n_total:,}")
print(" Baseline (BN+linear): 70.8%")
print(" Target: >70.8% (constellation must add value over linear)")
print("\n" + _closing_rule)
print("DONE")
print(_closing_rule)