AbstractPhil commited on
Commit
584edb8
·
verified ·
1 Parent(s): a3197bd

Create trainer_model.py

Browse files
Files changed (1) hide show
  1. trainer_model.py +476 -0
trainer_model.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
GeoLIP Core — Back to Basics
==============================
Conv encoder → sphere → constellation → patchwork → classifier.
No streams. No GAL. No Procrustes. No mastery queue.
Just the geometric classification pipeline.

Two augmented views → InfoNCE + CE + CV.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os, time
# NOTE(review): numpy and itertools.combinations are imported but never used
# in this file — kept to avoid breaking any external expectation; verify.
import numpy as np
from itertools import combinations
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Enable TF32 matmuls/convolutions on Ampere+ GPUs (faster, slightly lower
# precision — acceptable for this training workload).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
26
+
27
+
28
+ # ══════════════════════════════════════════════════════════════════
29
+ # UNIFORM HYPERSPHERE INIT
30
+ # ══════════════════════════════════════════════════════════════════
31
+
32
def uniform_hypersphere_init(n, d, n_iters=200, step=0.05):
    """Return ``n`` roughly uniformly spread unit vectors in ``R^d``.

    For ``n <= d`` an exactly orthonormal set is returned, drawn Haar-uniformly
    from the orthogonal group (QR of a Gaussian matrix with the column signs
    fixed by ``diag(R)`` — plain QR output is sign-biased and not Haar-uniform).
    For ``n > d`` an orthonormal basis is padded with random unit vectors and
    relaxed by nearest-neighbour repulsion.

    Args:
        n: number of vectors to generate.
        d: ambient dimension.
        n_iters: repulsion iterations for the over-complete (n > d) case.
        step: repulsion step size for the over-complete case.

    Returns:
        Tensor of shape (n, d) whose rows are unit-norm.
    """
    if n <= d:
        M = torch.randn(d, n)
        Q, R = torch.linalg.qr(M)
        # Sign correction (Mezzadri): multiply each column of Q by the sign of
        # the corresponding diagonal entry of R so the result is Haar-uniform.
        signs = torch.sign(torch.diagonal(R))
        signs[signs == 0] = 1.0  # diag(R) is a.s. nonzero; guard anyway
        Q = Q * signs.unsqueeze(0)
        return Q.T.contiguous()
    else:
        # Over-complete case: start from a full orthonormal basis plus random
        # unit vectors, then iteratively push each vector away from its
        # nearest neighbour to spread them out.
        M = torch.randn(d, d)
        Q, _ = torch.linalg.qr(M)
        basis = Q.T
        extra = F.normalize(torch.randn(n - d, d), dim=-1)
        vecs = torch.cat([basis, extra], dim=0)
        for _ in range(n_iters):
            sim = vecs @ vecs.T
            sim.fill_diagonal_(-2.0)    # exclude self from the nearest search
            nn_idx = sim.argmax(dim=1)  # most-similar other vector
            # Repel from nearest neighbour, then re-project onto the sphere.
            vecs = F.normalize(vecs - step * vecs[nn_idx], dim=-1)
        return vecs
49
+
50
+
51
+ # ══════════════════════════════════════════════════════════════════
52
+ # CONSTELLATION + PATCHWORK
53
+ # ══════════════════════════════════════════════════════════════════
54
+
55
class Constellation(nn.Module):
    """Learnable set of unit anchors used to triangulate embeddings.

    ``triangulate`` returns, per embedding, the cosine *distance* (1 - cos)
    to every surviving anchor plus the index of the closest anchor (always
    a global anchor id, even when anchor dropout is active).
    """

    def __init__(self, n_anchors, dim, anchor_drop=0.0):
        super().__init__()
        # Anchors start uniformly spread on the sphere; they are re-normalised
        # on every use since optimisation moves them off the sphere.
        self.anchors = nn.Parameter(uniform_hypersphere_init(n_anchors, dim))
        self.anchor_drop = anchor_drop

    def triangulate(self, emb, training=False):
        """Return ``(tri, nearest)`` for a batch of unit-norm embeddings."""
        anchors = F.normalize(self.anchors, dim=-1)
        drop_active = training and self.anchor_drop > 0
        if drop_active:
            # Keep each anchor with probability (1 - anchor_drop), but
            # guarantee at least two survivors.
            keep = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop
            if keep.sum() < 2:
                keep[:2] = True
            kept_ids = keep.nonzero(as_tuple=True)[0]
            anchors = anchors[keep]
        cos = emb @ anchors.T
        tri = 1.0 - cos
        _, nearest = cos.max(dim=-1)
        if drop_active:
            # Map local (post-drop) indices back to global anchor ids.
            nearest = kept_ids[nearest]
        return tri, nearest
76
+
77
+
78
class Patchwork(nn.Module):
    """Splits anchor distances round-robin across small MLP components.

    Anchor ``i`` feeds component ``i % n_comp``; each component maps its slice
    of the triangulation to a ``d_comp`` vector and the results are
    concatenated into a single ``n_comp * d_comp`` feature.
    """

    def __init__(self, n_anchors, n_comp, d_comp):
        super().__init__()
        self.n_comp = n_comp
        # Round-robin anchor -> component assignment, kept as a buffer so it
        # follows the module across devices and into the state dict.
        self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)
        anchors_per = n_anchors // n_comp

        def make_component():
            # Two-layer MLP ending in a LayerNorm'd d_comp output.
            return nn.Sequential(
                nn.Linear(anchors_per, d_comp * 2), nn.GELU(),
                nn.Linear(d_comp * 2, d_comp), nn.LayerNorm(d_comp))

        self.comps = nn.ModuleList([make_component() for _ in range(n_comp)])

    def forward(self, tri):
        pieces = []
        for k in range(self.n_comp):
            cols = tri[:, self.asgn == k]
            pieces.append(self.comps[k](cols))
        return torch.cat(pieces, -1)
92
+
93
+
94
+ # ══════════════════════════════════════════════════════════════════
95
+ # CONV ENCODER
96
+ # ══════════════════════════════════════════════════════════════════
97
+
98
class ConvEncoder(nn.Module):
    """
    Simple conv backbone. No attention, no geometric layers.
    Just feature extraction into a flat vector.

    Three double-conv stages (64 -> 128 -> 256 channels), each halving the
    spatial size (32->16->8->4), then global average pooling and a linear
    projection with LayerNorm onto ``output_dim``.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        # Build the stages with a loop; module order is exactly
        # conv-bn-gelu, conv-bn-gelu, pool per stage.
        layers = []
        in_ch = 3
        for out_ch in (64, 128, 256):
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.GELU(),
                nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.GELU(),
                nn.MaxPool2d(2),
            ]
            in_ch = out_ch
        # 4x4 -> global pooled flat vector of 256 features.
        layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        self.features = nn.Sequential(*layers)

        self.proj = nn.Sequential(
            nn.Linear(256, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        feats = self.features(x)
        return self.proj(feats)
132
+
133
+
134
+ # ══════════════════════════════════════════════════════════════════
135
+ # GEOLIP CORE
136
+ # ══════════════════════════════════════════════════════════════════
137
+
138
class GeoLIPCore(nn.Module):
    """Geometric classification pipeline.

    Conv encoder -> unit-sphere embedding -> constellation triangulation ->
    patchwork features -> linear classifier.  Trained with cross-entropy,
    InfoNCE between two augmented views, a coefficient-of-variation (CV)
    regulariser on sampled simplex volumes, and an anchor-spread penalty.
    """

    def __init__(
        self,
        num_classes=10,       # classifier output size
        output_dim=128,       # embedding (sphere) dimension
        n_anchors=64,         # number of constellation anchors
        n_comp=8,             # number of patchwork components
        d_comp=64,            # per-component output width
        anchor_drop=0.15,     # anchor dropout probability during training
        cv_target=0.22,       # target CV of sampled simplex volumes
        infonce_temp=0.07,    # InfoNCE softmax temperature
    ):
        super().__init__()
        self.num_classes = num_classes
        self.output_dim = output_dim
        self.cv_target = cv_target
        self.infonce_temp = infonce_temp

        # Snapshot the constructor arguments for checkpointing.  Taken from
        # locals() *before* any sub-modules are created, so only the
        # arguments (minus 'self') are captured.
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}

        self.encoder = ConvEncoder(output_dim)
        self.constellation = Constellation(n_anchors, output_dim, anchor_drop)
        self.patchwork = Patchwork(n_anchors, n_comp, d_comp)
        pw_dim = n_comp * d_comp

        # Classifier sees the patchwork features concatenated with the raw
        # sphere embedding.
        self.classifier = nn.Sequential(
            nn.Linear(pw_dim + output_dim, pw_dim), nn.GELU(),
            nn.LayerNorm(pw_dim), nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes))

        self._init_weights()

    def _init_weights(self):
        # Truncated-normal linears, fan-out Kaiming convs, unit norms.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return dict with 'logits', unit 'embedding', 'triangulation', 'nearest'."""
        feat = self.encoder(x)
        emb = F.normalize(feat, dim=-1)

        # Full tri for patchwork (needs all anchor columns)
        tri, nearest = self.constellation.triangulate(emb, training=False)
        pw = self.patchwork(tri)

        # Dropout version for nearest tracking only — triangulation is
        # deliberately recomputed with anchor dropout so 'nearest' reflects
        # the dropped constellation, while 'tri'/pw stay full-width.
        if self.training:
            _, nearest = self.constellation.triangulate(emb, training=True)

        logits = self.classifier(torch.cat([pw, emb], dim=-1))

        return {
            'logits': logits,
            'embedding': emb,
            'triangulation': tri,
            'nearest': nearest,
        }

    def compute_loss(self, output, targets, output_aug=None):
        """Return (total_loss, loss_dict).

        ``output``/``output_aug`` are forward() dicts for the two augmented
        views; ``output_aug`` is optional (InfoNCE is skipped without it).
        """
        ld = {}
        emb = output['embedding']
        B = emb.shape[0]

        # CE — standard classification loss plus batch accuracy for logging.
        l_ce = F.cross_entropy(output['logits'], targets)
        ld['ce'] = l_ce
        ld['acc'] = (output['logits'].argmax(-1) == targets).float().mean().item()

        # InfoNCE — positives are the same index in the other view; the
        # diagonal of the similarity matrix is the target.
        if output_aug is not None:
            emb_aug = output_aug['embedding']
            labels_nce = torch.arange(B, device=emb.device)
            sim = emb @ emb_aug.T / self.infonce_temp
            l_nce = F.cross_entropy(sim, labels_nce)
            nce_acc = (sim.argmax(1) == labels_nce).float().mean().item()
            ld['nce'] = l_nce
            ld['nce_acc'] = nce_acc

        # CV — pull the coefficient of variation of sampled simplex volumes
        # toward cv_target (see _cv_loss).
        l_cv = self._cv_loss(emb)
        ld['cv'] = l_cv

        # Anchor spread — penalise positive cosine similarity between
        # distinct anchors to keep the constellation spread out.
        an = F.normalize(self.constellation.anchors, dim=-1)
        sim_a = an @ an.T
        mask = ~torch.eye(an.shape[0], dtype=torch.bool, device=an.device)
        l_spread = F.relu(sim_a[mask]).mean()
        ld['spread'] = l_spread

        # Total — fixed weights: CE 1.0, NCE 1.0, CV 0.01, spread 0.001.
        loss = (l_ce
                + ld.get('nce', 0.0) * 1.0
                + l_cv * 0.01
                + l_spread * 0.001)
        ld['total'] = loss
        return loss, ld

    def _cv_loss(self, emb, n_samples=64, n_points=5):
        """Squared deviation of the simplex-volume CV from ``cv_target``.

        Draws ``n_samples`` random ``n_points``-subsets of the embeddings
        (only the first min(B, 512) rows are sampled from), computes each
        simplex volume via the Cayley–Menger determinant, and compares the
        coefficient of variation of those volumes against the target.
        """
        B = emb.shape[0]
        if B < n_points: return torch.tensor(0.0, device=emb.device)
        vols = []
        for _ in range(n_samples):
            idx = torch.randperm(min(B, 512), device=emb.device)[:n_points]
            pts = emb[idx].unsqueeze(0)
            # Pairwise squared distances from the Gram matrix.
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)  # clamp tiny negatives from round-off
            N = n_points
            # Cayley–Menger matrix: bordered squared-distance matrix.
            cm = torch.zeros(1, N+1, N+1, device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            k = N - 1
            # Prefactor (-1)^(k+1) / (2^k (k!)^2) so v2 = volume^2.
            pf = ((-1.0)**(k+1)) / ((2.0**k) * (math.factorial(k)**2))
            v2 = pf * torch.linalg.det(cm.float())
            if v2[0].item() > 1e-20:  # keep only numerically valid volumes
                vols.append(v2[0].to(emb.dtype).sqrt())
        if len(vols) < 5:
            # Too few valid samples to estimate a CV — no penalty.
            return torch.tensor(0.0, device=emb.device)
        vt = torch.stack(vols)
        cv = vt.std() / (vt.mean() + 1e-8)
        return (cv - self.cv_target).pow(2)
269
+
270
+
271
+ # ══════════════════════════════════════════════════════════════════
272
+ # DATA
273
+ # ══════════════════════════════════════════════════════════════════
274
+
275
# CIFAR-10 per-channel statistics used for input normalisation.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)

class TwoViewDataset(torch.utils.data.Dataset):
    """Wraps a dataset so each item yields two independent augmentations.

    ``__getitem__`` applies the same stochastic ``transform`` twice to one
    underlying image and returns ``(view1, view2, label)``.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img, label = self.base[i]
        view_a = self.transform(img)
        view_b = self.transform(img)
        return view_a, view_b, label
285
+
286
+
287
+ # ══════════════════════════════════════════════════════════════════
288
+ # TRAINING
289
+ # ══════════════════════════════════════════════════════════════════
290
+
291
# Config — training hyperparameters for the CIFAR-10 run.
NUM_CLASSES = 10   # CIFAR-10 classes
OUTPUT_DIM = 128   # embedding / sphere dimension
N_ANCHORS = 64     # constellation anchors
N_COMP = 8         # patchwork components
D_COMP = 64        # width per patchwork component
BATCH = 256
EPOCHS = 50
LR = 3e-3          # peak LR (AdamW, warmup + cosine schedule below)

# Startup banner.  NOTE(review): several string literals below contain
# mojibake (e.g. "β€”" for an em-dash, "Γ—" for "×") — looks like an encoding
# artifact from a copy/paste; left byte-identical here, confirm intent.
print("=" * 60)
print("GeoLIP Core β€” Conv + Constellation + Patchwork")
print(f" Encoder: 6-layer conv β†’ {OUTPUT_DIM}-d sphere")
print(f" Constellation: {N_ANCHORS} anchors, {N_COMP}Γ—{D_COMP} patchwork")
print(f" Loss: CE + InfoNCE + CV(0.22)")
print(f" Batch: {BATCH}, LR: {LR}, Epochs: {EPOCHS}")
print(f" Device: {DEVICE}")
print("=" * 60)
309
+
310
# Training-view augmentation: crop/flip/colour-jitter, then normalise.
aug_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# Validation: deterministic — tensor conversion and normalisation only.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Raw (untransformed) train set; TwoViewDataset applies aug_transform twice
# per item to produce the two InfoNCE views.
raw_train = datasets.CIFAR10(root='./data', train=True, download=True)
train_ds = TwoViewDataset(raw_train, aug_transform)
val_ds = datasets.CIFAR10(root='./data', train=False,
                          download=True, transform=val_transform)

# drop_last keeps the InfoNCE batch square (B positives on the diagonal).
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=8, pin_memory=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=BATCH, shuffle=False,
    num_workers=2, pin_memory=True)

print(f" Train: {len(train_ds):,} Val: {len(val_ds):,}")
335
+
336
# Build the model and move it to the selected device.
model = GeoLIPCore(
    num_classes=NUM_CLASSES, output_dim=OUTPUT_DIM,
    n_anchors=N_ANCHORS, n_comp=N_COMP, d_comp=D_COMP,
).to(DEVICE)

n_params = sum(p.numel() for p in model.parameters())
print(f" Parameters: {n_params:,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.05)
total_steps = len(train_loader) * EPOCHS
warmup_steps = len(train_loader) * 3   # 3-epoch linear warmup
# Linear warmup followed by cosine decay down to eta_min (stepped per batch).
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    [torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.01, total_iters=warmup_steps),
     torch.optim.lr_scheduler.CosineAnnealingLR(
         optimizer, T_max=max(total_steps - warmup_steps, 1), eta_min=1e-6)],
    milestones=[warmup_steps])

# NOTE(review): GradScaler is effectively a no-op under bfloat16 autocast
# (bf16 needs no loss scaling) — harmless, only useful if switched to fp16.
scaler = torch.amp.GradScaler("cuda")
os.makedirs("checkpoints", exist_ok=True)
writer = SummaryWriter("runs/geolip_core")
best_acc = 0.0   # best validation accuracy seen so far
gs = 0           # global step counter

print(f"\n{'='*60}")
print(f"TRAINING β€” {EPOCHS} epochs")
print(f"{'='*60}")
365
+
366
for epoch in range(EPOCHS):
    # ---- Train pass ----
    model.train()
    t0 = time.time()
    # NOTE(review): tot_ce / tot_nce / tot_cv / tot_acc (and vl below) are
    # accumulated or initialised but never reported — dead logging state.
    tot_loss, tot_ce, tot_nce, tot_cv = 0, 0, 0, 0
    tot_acc, tot_nce_acc, n = 0, 0, 0
    correct, total = 0, 0

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for v1, v2, targets in pbar:
        v1 = v1.to(DEVICE, non_blocking=True)
        v2 = v2.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Forward both augmented views under bf16 autocast; CE/CV are
        # computed on view 1, InfoNCE pairs view 1 against view 2.
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            out1 = model(v1)
            out2 = model(v2)
            loss, ld = model.compute_loss(out1, targets, output_aug=out2)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the clip threshold applies to true grads.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()
        scheduler.step()  # per-batch LR schedule
        gs += 1

        # Running train statistics (accuracy from view 1).
        preds = out1['logits'].argmax(-1)
        correct += (preds == targets).sum().item()
        total += targets.shape[0]
        tot_loss += loss.item()
        tot_nce_acc += ld.get('nce_acc', 0)
        n += 1

        if n % 10 == 0:
            # NOTE(review): 'ordered=True' is passed as a postfix *stat*
            # (set_postfix has no such flag) — it renders as "ordered=True"
            # in the bar; likely unintended, confirm.
            pbar.set_postfix(
                loss=f"{tot_loss/n:.4f}",
                acc=f"{100*correct/total:.0f}%",
                nce=f"{tot_nce_acc/n:.2f}",
                ordered=True)

    elapsed = time.time() - t0
    train_acc = 100 * correct / total

    # ---- Validation pass ----
    model.eval()
    vc, vt_n, vl = 0, 0, 0
    all_embs = []  # embeddings collected for the CV / anchor diagnostics
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
        for imgs, lbls in val_loader:
            imgs = imgs.to(DEVICE)
            lbls = lbls.to(DEVICE)
            out = model(imgs)
            vc += (out['logits'].argmax(-1) == lbls).sum().item()
            vt_n += lbls.shape[0]
            vl += F.cross_entropy(out['logits'], lbls).item()
            all_embs.append(out['embedding'].float().cpu())

    val_acc = 100 * vc / vt_n

    # ---- CV diagnostic: Cayley–Menger volume CV over 200 random 5-point
    # simplices from the first 2000 validation embeddings (mirrors
    # GeoLIPCore._cv_loss; 9216 = 2^4 * (4!)^2, the N=5 prefactor).
    embs = torch.cat(all_embs)[:2000].to(DEVICE)
    with torch.no_grad():
        vols = []
        for _ in range(200):
            idx = torch.randperm(2000)[:5]
            pts = embs[idx].unsqueeze(0).float()
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)
            cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            # NOTE(review): this rebinds 'v2' (shadowing the second-view
            # batch tensor from the train loop) — harmless but confusing.
            v2 = -torch.linalg.det(cm) / 9216
            if v2[0].item() > 1e-20:
                vols.append(v2[0].sqrt())
    v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0

    # ---- Anchor usage: how many anchors are someone's nearest neighbour.
    with torch.no_grad():
        _, vnp = model.constellation.triangulate(embs, training=False)
        n_active = vnp.cpu().unique().numel()

    writer.add_scalar("epoch/train_acc", train_acc, epoch+1)
    writer.add_scalar("epoch/val_acc", val_acc, epoch+1)
    writer.add_scalar("epoch/val_cv", v_cv, epoch+1)
    writer.add_scalar("epoch/anchors", n_active, epoch+1)

    # ---- Checkpoint on new best validation accuracy.
    mk = ""
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "state_dict": model.state_dict(),
            "config": model.config,
            "epoch": epoch + 1,
            "val_acc": val_acc,
        }, "checkpoints/geolip_core_best.pt")
        mk = " β˜…"

    # ---- Epoch summary line ("βœ“"/"βœ—" flag CV inside the 0.18–0.25 band;
    # literals are mojibake, kept byte-identical).
    nce_m = tot_nce_acc / n
    cv_band = "βœ“" if 0.18 <= v_cv <= 0.25 else "βœ—"
    print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
          f"loss={tot_loss/n:.4f} nce={nce_m:.2f} "
          f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
          f"({elapsed:.0f}s){mk}")

writer.close()
print(f"\n Best val accuracy: {best_acc:.1f}%")
print(f" Parameters: {n_params:,}")
print(f"\n{'='*60}")
print("DONE")
print(f"{'='*60}")