AbstractPhil committed on
Commit
e255399
·
verified ·
1 Parent(s): f469500

Update trainer_model.py

Browse files
Files changed (1) hide show
  1. trainer_model.py +125 -8
trainer_model.py CHANGED
@@ -223,25 +223,110 @@ class GeoLIPCore(nn.Module):
223
  ld['nce'] = l_nce
224
  ld['nce_acc'] = nce_acc
225
 
 
 
 
 
 
 
 
 
226
  # CV
227
  l_cv = self._cv_loss(emb)
228
  ld['cv'] = l_cv
229
 
230
  # Anchor spread
231
- an = F.normalize(self.constellation.anchors, dim=-1)
232
- sim_a = an @ an.T
233
- mask = ~torch.eye(an.shape[0], dtype=torch.bool, device=an.device)
234
  l_spread = F.relu(sim_a[mask]).mean()
235
  ld['spread'] = l_spread
236
 
237
  # Total
238
  loss = (l_ce
239
  + ld.get('nce', 0.0) * 1.0
 
240
  + l_cv * 0.01
241
  + l_spread * 0.001)
242
  ld['total'] = loss
243
  return loss, ld
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  def _cv_loss(self, emb, n_samples=64, n_points=5):
246
  B = emb.shape[0]
247
  if B < n_points: return torch.tensor(0.0, device=emb.device)
@@ -295,8 +380,8 @@ N_ANCHORS = 64
295
  N_COMP = 8
296
  D_COMP = 64
297
  BATCH = 256
298
- EPOCHS = 50
299
- LR = 3e-3
300
 
301
  print("=" * 60)
302
  print("GeoLIP Core β€” Conv + Constellation + Patchwork")
@@ -359,15 +444,24 @@ writer = SummaryWriter("runs/geolip_core")
359
  best_acc = 0.0
360
  gs = 0
361
 
 
 
 
 
 
 
 
 
362
  print(f"\n{'='*60}")
363
  print(f"TRAINING β€” {EPOCHS} epochs")
 
364
  print(f"{'='*60}")
365
 
366
  for epoch in range(EPOCHS):
367
  model.train()
368
  t0 = time.time()
369
  tot_loss, tot_ce, tot_nce, tot_cv = 0, 0, 0, 0
370
- tot_acc, tot_nce_acc, n = 0, 0, 0
371
  correct, total = 0, 0
372
 
373
  pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
@@ -389,11 +483,29 @@ for epoch in range(EPOCHS):
389
  scheduler.step()
390
  gs += 1
391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  preds = out1['logits'].argmax(-1)
393
  correct += (preds == targets).sum().item()
394
  total += targets.shape[0]
395
  tot_loss += loss.item()
396
  tot_nce_acc += ld.get('nce_acc', 0)
 
397
  n += 1
398
 
399
  if n % 10 == 0:
@@ -401,6 +513,8 @@ for epoch in range(EPOCHS):
401
  loss=f"{tot_loss/n:.4f}",
402
  acc=f"{100*correct/total:.0f}%",
403
  nce=f"{tot_nce_acc/n:.2f}",
 
 
404
  ordered=True)
405
 
406
  elapsed = time.time() - t0
@@ -449,6 +563,8 @@ for epoch in range(EPOCHS):
449
  writer.add_scalar("epoch/val_acc", val_acc, epoch+1)
450
  writer.add_scalar("epoch/val_cv", v_cv, epoch+1)
451
  writer.add_scalar("epoch/anchors", n_active, epoch+1)
 
 
452
 
453
  mk = ""
454
  if val_acc > best_acc:
@@ -462,11 +578,12 @@ for epoch in range(EPOCHS):
462
  mk = " β˜…"
463
 
464
  nce_m = tot_nce_acc / n
 
465
  cv_band = "βœ“" if 0.18 <= v_cv <= 0.25 else "βœ—"
466
  print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
467
- f"loss={tot_loss/n:.4f} nce={nce_m:.2f} "
468
  f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
469
- f"({elapsed:.0f}s){mk}")
470
 
471
  writer.close()
472
  print(f"\n Best val accuracy: {best_acc:.1f}%")
 
223
  ld['nce'] = l_nce
224
  ld['nce_acc'] = nce_acc
225
 
226
+ # ── Anchor attraction: pull each embedding toward its nearest anchor ──
227
+ anchors_n = F.normalize(self.constellation.anchors, dim=-1)
228
+ cos_to_anchors = emb @ anchors_n.T # (B, n_anchors)
229
+ nearest_cos = cos_to_anchors.max(dim=1).values # (B,)
230
+ l_attract = (1.0 - nearest_cos).mean() # 0 when on top of anchor
231
+ ld['attract'] = l_attract
232
+ ld['nearest_cos'] = nearest_cos.mean().item()
233
+
234
  # CV
235
  l_cv = self._cv_loss(emb)
236
  ld['cv'] = l_cv
237
 
238
  # Anchor spread
239
+ sim_a = anchors_n @ anchors_n.T
240
+ mask = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=anchors_n.device)
 
241
  l_spread = F.relu(sim_a[mask]).mean()
242
  ld['spread'] = l_spread
243
 
244
  # Total
245
  loss = (l_ce
246
  + ld.get('nce', 0.0) * 1.0
247
+ + l_attract * 0.5
248
  + l_cv * 0.01
249
  + l_spread * 0.001)
250
  ld['total'] = loss
251
  return loss, ld
252
 
253
+ @torch.no_grad()
254
+ def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
255
+ """
256
+ Push anchors toward CLASS centroids, not nearest-anchor centroids.
257
+
258
+ Phase 1: Compute class centroids from labels
259
+ Phase 2: Each class owns (n_anchors / n_classes) anchors
260
+ Phase 3: Assigned anchors blend toward their class centroid
261
+ with small angular offsets so they don't all collapse
262
+
263
+ This works even when anchors start bunched at origin.
264
+ """
265
+ anchors = self.constellation.anchors.data # (A, D)
266
+ n_a = anchors.shape[0]
267
+ emb_n = F.normalize(emb_buffer, dim=-1)
268
+ device = anchors.device
269
+
270
+ # Phase 1: class centroids
271
+ classes = label_buffer.unique()
272
+ n_cls = classes.shape[0]
273
+ centroids = []
274
+ for c in classes:
275
+ mask = label_buffer == c
276
+ if mask.sum() > 0:
277
+ centroids.append(F.normalize(emb_n[mask].mean(0, keepdim=True), dim=-1))
278
+ if len(centroids) == 0:
279
+ return 0
280
+ centroids = torch.cat(centroids, dim=0) # (C, D)
281
+
282
+ # Phase 2: assign anchors to classes round-robin
283
+ # Sort anchors by cosine to each centroid, greedily assign
284
+ anchors_n = F.normalize(anchors, dim=-1)
285
+ cos = anchors_n @ centroids.T # (A, C)
286
+ anchors_per_class = n_a // n_cls
287
+ assigned_class = torch.full((n_a,), -1, dtype=torch.long, device=device)
288
+ class_count = torch.zeros(n_cls, dtype=torch.long, device=device)
289
+
290
+ # Greedy: for each anchor, assign to its best class if that class has room
291
+ _, flat_idx = cos.flatten().sort(descending=True)
292
+ for idx in flat_idx:
293
+ a = (idx // n_cls).item()
294
+ c = (idx % n_cls).item()
295
+ if assigned_class[a] >= 0:
296
+ continue
297
+ if class_count[c] >= anchors_per_class + 1: # +1 for remainder
298
+ continue
299
+ assigned_class[a] = c
300
+ class_count[c] += 1
301
+ if (assigned_class >= 0).all():
302
+ break
303
+
304
+ # Unassigned leftovers β†’ nearest centroid
305
+ unassigned = (assigned_class < 0).nonzero(as_tuple=True)[0]
306
+ if len(unassigned) > 0:
307
+ leftover_cos = anchors_n[unassigned] @ centroids.T
308
+ assigned_class[unassigned] = leftover_cos.argmax(dim=1)
309
+
310
+ # Phase 3: push each anchor toward its class centroid
311
+ moved = 0
312
+ for a in range(n_a):
313
+ c = assigned_class[a].item()
314
+ target = centroids[c]
315
+ # Add small angular offset so co-class anchors don't collapse
316
+ rank_in_class = (assigned_class[:a] == c).sum().item()
317
+ if anchors_per_class > 1 and rank_in_class > 0:
318
+ # Tiny perpendicular perturbation
319
+ noise = torch.randn_like(target) * 0.05
320
+ noise = noise - (noise * target).sum() * target # project out radial
321
+ target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)
322
+
323
+ anchors[a] = F.normalize(
324
+ (anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
325
+ dim=-1).squeeze(0)
326
+ moved += 1
327
+
328
+ return moved
329
+
330
  def _cv_loss(self, emb, n_samples=64, n_points=5):
331
  B = emb.shape[0]
332
  if B < n_points: return torch.tensor(0.0, device=emb.device)
 
380
  N_COMP = 8
381
  D_COMP = 64
382
  BATCH = 256
383
+ EPOCHS = 100
384
+ LR = 3e-4
385
 
386
  print("=" * 60)
387
  print("GeoLIP Core β€” Conv + Constellation + Patchwork")
 
444
  best_acc = 0.0
445
  gs = 0
446
 
447
+ # Anchor push config
448
+ PUSH_INTERVAL = 50 # batches between centroid pushes
449
+ PUSH_LR = 0.1 # blend rate toward centroid
450
+ PUSH_BUFFER_SIZE = 5000
451
+ emb_buffer = None # (N, D) accumulated embeddings
452
+ lbl_buffer = None # (N,) accumulated labels
453
+ push_count = 0
454
+
455
  print(f"\n{'='*60}")
456
  print(f"TRAINING β€” {EPOCHS} epochs")
457
+ print(f" Anchor push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
458
  print(f"{'='*60}")
459
 
460
  for epoch in range(EPOCHS):
461
  model.train()
462
  t0 = time.time()
463
  tot_loss, tot_ce, tot_nce, tot_cv = 0, 0, 0, 0
464
+ tot_acc, tot_nce_acc, tot_nearest_cos, n = 0, 0, 0, 0
465
  correct, total = 0, 0
466
 
467
  pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
 
483
  scheduler.step()
484
  gs += 1
485
 
486
+ # ── Accumulate embeddings for anchor push ──
487
+ with torch.no_grad():
488
+ batch_emb = out1['embedding'].detach().float()
489
+ if emb_buffer is None:
490
+ emb_buffer = batch_emb
491
+ lbl_buffer = targets.detach()
492
+ else:
493
+ emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
494
+ lbl_buffer = torch.cat([lbl_buffer, targets.detach()])[-PUSH_BUFFER_SIZE:]
495
+
496
+ # ── Periodic anchor push toward class centroids ──
497
+ if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
498
+ moved = model.push_anchors_to_centroids(
499
+ emb_buffer, lbl_buffer, lr=PUSH_LR)
500
+ push_count += 1
501
+ writer.add_scalar("step/anchors_moved", moved, gs)
502
+
503
  preds = out1['logits'].argmax(-1)
504
  correct += (preds == targets).sum().item()
505
  total += targets.shape[0]
506
  tot_loss += loss.item()
507
  tot_nce_acc += ld.get('nce_acc', 0)
508
+ tot_nearest_cos += ld.get('nearest_cos', 0)
509
  n += 1
510
 
511
  if n % 10 == 0:
 
513
  loss=f"{tot_loss/n:.4f}",
514
  acc=f"{100*correct/total:.0f}%",
515
  nce=f"{tot_nce_acc/n:.2f}",
516
+ cos=f"{ld.get('nearest_cos', 0):.3f}",
517
+ push=push_count,
518
  ordered=True)
519
 
520
  elapsed = time.time() - t0
 
563
  writer.add_scalar("epoch/val_acc", val_acc, epoch+1)
564
  writer.add_scalar("epoch/val_cv", v_cv, epoch+1)
565
  writer.add_scalar("epoch/anchors", n_active, epoch+1)
566
+ writer.add_scalar("epoch/nearest_cos", tot_nearest_cos / n, epoch+1)
567
+ writer.add_scalar("epoch/push_count", push_count, epoch+1)
568
 
569
  mk = ""
570
  if val_acc > best_acc:
 
578
  mk = " β˜…"
579
 
580
  nce_m = tot_nce_acc / n
581
+ cos_m = tot_nearest_cos / n
582
  cv_band = "βœ“" if 0.18 <= v_cv <= 0.25 else "βœ—"
583
  print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
584
+ f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
585
  f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
586
+ f"push={push_count} ({elapsed:.0f}s){mk}")
587
 
588
  writer.close()
589
  print(f"\n Best val accuracy: {best_acc:.1f}%")