AbstractPhil committed on
Commit
3eebacd
·
verified ·
1 Parent(s): 8d7f7cd

Update trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +62 -57
trainer.py CHANGED
@@ -1,10 +1,12 @@
1
  #!/usr/bin/env python3
2
  """
3
- CIFAR-10 β€” Tri-Stream GeoLIP ViT β€” Experiment 7
4
- ==================================================
5
- Stream A (CE), Stream B (BCE), GAL (geometric arbitration).
6
- GAL anchors updated via Procrustes every M batches.
7
- Whitened Procrustes toggleable.
 
 
8
  """
9
 
10
  import torch
@@ -29,8 +31,8 @@ STREAM_DIM = 192
29
  N_BLOCKS = 9
30
  N_HEADS = 8
31
  OUTPUT_DIM = 256
32
- N_ANCHORS = 128 # constellation (sphere space)
33
- N_GAL_ANCHORS = 64 # GAL anchors (stream space)
34
  N_COMP = 16
35
  D_COMP = 128
36
  ANCHOR_DROP = 0.10
@@ -47,11 +49,15 @@ BCE_WEIGHT = 1.0
47
  CM_WEIGHT = 0.1
48
  INFONCE_TEMP = 0.07
49
 
50
- # ── GAL ──
51
- GAL_UPDATE_INTERVAL = 50 # batches between Procrustes updates
52
- GAL_LR = 0.01 # step size for anchor rotation
 
 
 
 
53
  GAL_BUFFER_SIZE = 50000
54
- USE_WHITENED_PROCRUSTES = False # toggle for benchmarking
55
 
56
  # ── Mastery queue ──
57
  MASTERY_PATIENCE = 50
@@ -71,18 +77,17 @@ EPOCHS = 100
71
  LR = 3e-4
72
  WARMUP = 5
73
  GRAD_CLIP = 1.0
74
- V1_CKPT = ""
75
 
76
  print("=" * 60)
77
- print("CIFAR-10 β€” Tri-Stream GeoLIP ViT β€” EXP 7 (GAL)")
78
  print(f" Architecture: {N_BLOCKS}Γ— TriStreamBlock")
79
  print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}Γ—{D_COMP} pw")
80
  print(f" GAL: {N_GAL_ANCHORS} anchors, Procrustes every {GAL_UPDATE_INTERVAL} "
81
  f"batches (lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES})")
82
- print(f" GAL buffer: {GAL_BUFFER_SIZE}")
83
- print(f" InfoNCE={INFONCE_WEIGHT} on emb+geo")
84
- print(f" CE(stream A) + BCE(stream B) + BCE(geo)")
85
- print(f" LS={LABEL_SMOOTHING}, CV={CV_WEIGHT}")
86
  print(f" Device: {DEVICE}")
87
  print("=" * 60)
88
 
@@ -145,12 +150,17 @@ model = create_tri_stream_vit(
145
  autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
146
  enable_autograd=ENABLE_AUTOGRAD,
147
  label_smoothing=LABEL_SMOOTHING,
 
 
148
  ).to(DEVICE)
149
 
150
  if V1_CKPT and os.path.exists(V1_CKPT):
151
  ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
152
- model.load_state_dict(ckpt["state_dict"], strict=False)
153
- print(f" βœ“ Loaded weights: epoch {ckpt['epoch']}")
 
 
 
154
  else:
155
  print(f" Training from scratch")
156
 
@@ -165,9 +175,6 @@ print(f"\n{'='*60}")
165
  print(f"TRAINING β€” {EPOCHS} epochs, lr={LR}, batch={BATCH}")
166
  print(f" GAL Procrustes: every {GAL_UPDATE_INTERVAL} batches, "
167
  f"lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES}")
168
- print(f" Mastery: patience={MASTERY_PATIENCE}, queue adaptive "
169
- f"[{MASTERY_MIN_SIZE}–{MASTERY_MAX_SIZE}]")
170
- print(f" Optimizer: Adam")
171
  print(f"{'='*60}")
172
 
173
  optimizer = torch.optim.Adam(model.parameters(), lr=LR)
@@ -184,7 +191,7 @@ scheduler = torch.optim.lr_scheduler.SequentialLR(
184
 
185
  scaler = torch.amp.GradScaler("cuda")
186
  os.makedirs("checkpoints", exist_ok=True)
187
- writer = SummaryWriter("runs/cifar10_tri_stream_v7_gal")
188
  best_acc = 0.0
189
  gs = 0
190
 
@@ -201,7 +208,7 @@ mastery = MasteryQueue(
201
  simplex_buf = SimplexBuffer(
202
  dim=STREAM_DIM, max_size=GAL_BUFFER_SIZE, device=DEVICE)
203
 
204
- gal_update_count = 0 # track Procrustes updates
205
 
206
  # ══════════════════════════════════════════════════════════════════
207
  # TRAINING LOOP
@@ -214,7 +221,9 @@ for epoch in range(EPOCHS):
214
  acc_dict = {
215
  "loss": 0, "ce": 0, "bce": 0, "geo_bce": 0,
216
  "acc_a": 0, "acc_b": 0, "geo_acc": 0,
217
- "nce": 0, "nce_acc": 0, "geo_nce": 0, "geo_nce_acc": 0,
 
 
218
  "cm": 0, "cm_valid": 0, "cv": 0, "cv_main": 0, "cv_geo": 0,
219
  "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
220
  "correct": 0, "total": 0, "n": 0}
@@ -240,31 +249,28 @@ for epoch in range(EPOCHS):
240
  scaler.step(optimizer); scaler.update()
241
  scheduler.step()
242
 
243
- # ── Mastery activation check ──
244
  mastery.check_activation(ld.get('nce_acc', 0))
245
 
246
- # ── Accumulate geo features into simplex buffer ──
247
  pool_geo = out1.get('pool_geo')
248
  if pool_geo is not None:
249
  simplex_buf.push(pool_geo.float(), targets)
250
 
251
- # ── Periodic GAL Procrustes update ──
252
  gs += 1
253
- if gs % GAL_UPDATE_INTERVAL == 0 and simplex_buf.size > 1000:
254
  score = model.update_gal_anchors(
255
  simplex_buf, lr=GAL_LR, whiten=USE_WHITENED_PROCRUSTES)
256
  if score is not None:
257
  gal_update_count += 1
258
  writer.add_scalar("step/procrustes_score", score, gs)
259
 
260
- # ── Track metrics ──
261
  preds = out1['logits_a'].argmax(-1)
262
  correct = (preds == targets).sum().item()
263
  acc_dict["correct"] += correct
264
  acc_dict["total"] += targets.shape[0]
265
  acc_dict["loss"] += loss.item()
266
 
267
- for k in ["ce", "bce", "geo_bce", "nce", "geo_nce",
268
  "cm", "cv", "spread", "mastery"]:
269
  v = ld.get(k, 0)
270
  acc_dict[k] += v.item() if torch.is_tensor(v) else v
@@ -273,6 +279,7 @@ for epoch in range(EPOCHS):
273
  acc_dict["acc_b"] += ld.get("acc_b", 0)
274
  acc_dict["geo_acc"] += ld.get("geo_acc", 0)
275
  acc_dict["nce_acc"] += ld.get("nce_acc", 0)
 
276
  acc_dict["geo_nce_acc"] += ld.get("geo_nce_acc", 0)
277
  acc_dict["cm_valid"] += ld.get("cm_valid", 0)
278
  acc_dict["cv_main"] += ld.get("cv_main", 0)
@@ -285,22 +292,22 @@ for epoch in range(EPOCHS):
285
  d = acc_dict["n"]
286
  ta = 100 * acc_dict["correct"] / acc_dict["total"]
287
  ga = 100 * acc_dict["geo_acc"] / d
288
- mst = acc_dict["mastery"] / d
289
  stg = "M" if mastery.active else "S1"
290
  pbar.set_postfix(
291
  loss=f"{acc_dict['loss']/d:.4f}",
292
  a=f"{ta:.0f}%",
293
  ga=f"{ga:.0f}%",
294
- mst=f"{mst:.3f}",
295
  stg=stg,
296
  gal=gal_update_count,
297
  ordered=True)
298
 
299
- # Step-level TB
300
  if gs % 20 == 0:
301
  writer.add_scalar("step/loss", loss.item(), gs)
302
  writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
303
- # Log gate values
 
304
  gates_a = out1.get('gates_a', [])
305
  if gates_a:
306
  writer.add_scalar("step/gate_a_mean",
@@ -319,19 +326,18 @@ for epoch in range(EPOCHS):
319
  writer.add_scalar("epoch/acc_b", 100 * acc_dict["acc_b"] / d, epoch + 1)
320
  writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
321
  writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
 
322
  writer.add_scalar("epoch/geo_nce_acc", acc_dict["geo_nce_acc"] / d, epoch + 1)
323
  writer.add_scalar("epoch/cv_main", acc_dict["cv_main"] / d, epoch + 1)
324
  writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
325
  writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
326
- writer.add_scalar("epoch/margin", mastery.current_margin, epoch + 1)
327
- writer.add_scalar("epoch/queue_max", mastery._current_max, epoch + 1)
328
- writer.add_scalar("epoch/simplex_buf", simplex_buf.size, epoch + 1)
329
  writer.add_scalar("epoch/gal_updates", gal_update_count, epoch + 1)
330
 
331
  # ── Validation ──
332
  model.eval()
333
  val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
334
  val_geo_correct = 0
 
335
  all_embs = []
336
 
337
  with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
@@ -341,8 +347,8 @@ for epoch in range(EPOCHS):
341
  out = model(images, apply_autograd=False)
342
  preds = out['logits_a'].argmax(dim=-1)
343
  val_correct += (preds == labels_v).sum().item()
344
- geo_preds = out['geo_logits'].argmax(dim=-1)
345
- val_geo_correct += (geo_preds == labels_v).sum().item()
346
  val_total += labels_v.shape[0]
347
  loss_v = F.cross_entropy(out['logits_a'], labels_v)
348
  val_loss_sum += loss_v.item()
@@ -350,10 +356,11 @@ for epoch in range(EPOCHS):
350
  all_embs.append(out['embedding'].float().cpu())
351
 
352
  val_acc = 100 * val_correct / val_total
 
353
  val_geo_acc = 100 * val_geo_correct / val_total
354
  val_loss = val_loss_sum / max(val_n, 1)
355
 
356
- # CV on val embeddings
357
  embs = torch.cat(all_embs)
358
  with torch.no_grad():
359
  sample = embs[:2000].to(DEVICE)
@@ -372,18 +379,17 @@ for epoch in range(EPOCHS):
372
  vols.append(v2[0].sqrt())
373
  v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0.0
374
 
375
- # Anchor utilization
376
  with torch.no_grad():
377
  _, v_np = model.constellation.triangulate(
378
  embs[:2000].to(DEVICE), training=False)
379
  n_active = v_np.cpu().unique().numel()
380
 
381
  writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
 
382
  writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
383
  writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
384
  writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)
385
 
386
- # ── Adaptive queue resize ──
387
  mastery.update_size(train_acc, val_acc, epoch + 1)
388
 
389
  # ── Checkpoint ──
@@ -395,10 +401,11 @@ for epoch in range(EPOCHS):
395
  "config": model.config,
396
  "epoch": epoch + 1,
397
  "val_acc": val_acc,
 
398
  "val_geo_acc": val_geo_acc,
399
  "mastery": mastery.state_dict(),
400
  "gal_updates": gal_update_count,
401
- }, "checkpoints/tri_stream_v7_best.pt")
402
  mk = " β˜…"
403
 
404
  if (epoch + 1) % 10 == 0:
@@ -408,39 +415,37 @@ for epoch in range(EPOCHS):
408
  "epoch": epoch + 1,
409
  "val_acc": val_acc,
410
  "optimizer": optimizer.state_dict(),
411
- }, f"checkpoints/tri_stream_v7_e{epoch+1:03d}.pt")
412
 
413
- # ── Epoch print ──
414
  ga = 100 * acc_dict["geo_acc"] / d
415
  ab = 100 * acc_dict["acc_b"] / d
416
- mst_m = acc_dict["mastery"] / d
417
- hn = acc_dict["hard_neg"] / d if mastery.active else 0
418
- hp = acc_dict["hard_pos"] / d if mastery.active else 0
419
  cvf = acc_dict["cv_main"] / d
420
  cvg = acc_dict["cv_geo"] / d
421
  cmv = acc_dict["cm_valid"] / d
422
  stage = "MASTERY" if mastery.active else "stage1"
423
 
424
- # Mean gate values
425
- last_out_gates = []
426
  try:
427
  model.eval()
428
  with torch.no_grad():
429
  sample_imgs = next(iter(val_loader))[0][:4].to(DEVICE)
430
  sample_out = model(sample_imgs, apply_autograd=False)
431
- last_out_gates = sample_out.get('gates_a', [])
432
  except:
433
  pass
434
- gate_str = f"g={np.mean(last_out_gates):.3f}" if last_out_gates else "g=?"
435
 
436
  print(f" E{epoch+1:3d}: A={train_acc:.1f}% B={ab:.0f}% "
437
- f"val={val_acc:.1f}% geo={ga:.0f}/{val_geo_acc:.0f}% "
438
  f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
 
439
  f"cv={v_cv:.4f}(m={cvf:.5f} g={cvg:.5f}) "
440
  f"cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
441
- f"[{stage}] mst={mst_m:.3f} {gate_str} "
442
- f"hn={hn:.3f} hp={hp:.3f} "
443
- f"q={mastery.size}/{mastery._current_max} "
444
  f"gal={gal_update_count} ({elapsed:.0f}s){mk}")
445
 
446
  writer.close()
 
1
  #!/usr/bin/env python3
2
  """
3
+ CIFAR-10 — Tri-Stream GeoLIP ViT v8
4
+ =====================================
5
+ v7→v8 changes:
6
+ 1. GAL_UPDATE_INTERVAL: 50 → 25 (2× more frequent)
7
+ 2. GAL_LR: 0.01 → 0.015 (+50% response)
8
+ 3. Tracks nce_b and geo_nce_acc separately
9
+ 4. stream_b_nce_weight=0.5, geo_nce_weight=0.5
10
  """
11
 
12
  import torch
 
31
  N_BLOCKS = 9
32
  N_HEADS = 8
33
  OUTPUT_DIM = 256
34
+ N_ANCHORS = 128
35
+ N_GAL_ANCHORS = 64
36
  N_COMP = 16
37
  D_COMP = 128
38
  ANCHOR_DROP = 0.10
 
49
  CM_WEIGHT = 0.1
50
  INFONCE_TEMP = 0.07
51
 
52
+ # ── v8: Stream B + Geo NCE weights ──
53
+ STREAM_B_NCE_WEIGHT = 0.5
54
+ GEO_NCE_WEIGHT = 0.5
55
+
56
+ # ── v8: GAL — faster updates, stronger response ──
57
+ GAL_UPDATE_INTERVAL = 25 # was 50
58
+ GAL_LR = 0.015 # was 0.01 (+50%)
59
  GAL_BUFFER_SIZE = 50000
60
+ USE_WHITENED_PROCRUSTES = False
61
 
62
  # ── Mastery queue ──
63
  MASTERY_PATIENCE = 50
 
77
  LR = 3e-4
78
  WARMUP = 5
79
  GRAD_CLIP = 1.0
80
+ V1_CKPT = "" # set to checkpoint path for warm start
81
 
82
  print("=" * 60)
83
+ print("CIFAR-10 β€” Tri-Stream GeoLIP ViT v8")
84
  print(f" Architecture: {N_BLOCKS}Γ— TriStreamBlock")
85
  print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}Γ—{D_COMP} pw")
86
  print(f" GAL: {N_GAL_ANCHORS} anchors, Procrustes every {GAL_UPDATE_INTERVAL} "
87
  f"batches (lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES})")
88
+ print(f" v8 fixes: uniform hypersphere init, gate_init=1/(2Γ—{N_BLOCKS})")
89
+ print(f" v8 fixes: InfoNCE on emb_b (w={STREAM_B_NCE_WEIGHT}) "
90
+ f"+ geo_emb (w={GEO_NCE_WEIGHT})")
 
91
  print(f" Device: {DEVICE}")
92
  print("=" * 60)
93
 
 
150
  autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
151
  enable_autograd=ENABLE_AUTOGRAD,
152
  label_smoothing=LABEL_SMOOTHING,
153
+ stream_b_nce_weight=STREAM_B_NCE_WEIGHT,
154
+ geo_nce_weight=GEO_NCE_WEIGHT,
155
  ).to(DEVICE)
156
 
157
  if V1_CKPT and os.path.exists(V1_CKPT):
158
  ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
159
+ missing, unexpected = model.load_state_dict(
160
+ ckpt["state_dict"], strict=False)
161
+ print(f" βœ“ Loaded weights: epoch {ckpt.get('epoch', '?')}")
162
+ if missing:
163
+ print(f" New params (expected): {len(missing)}")
164
  else:
165
  print(f" Training from scratch")
166
 
 
175
  print(f"TRAINING β€” {EPOCHS} epochs, lr={LR}, batch={BATCH}")
176
  print(f" GAL Procrustes: every {GAL_UPDATE_INTERVAL} batches, "
177
  f"lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES}")
 
 
 
178
  print(f"{'='*60}")
179
 
180
  optimizer = torch.optim.Adam(model.parameters(), lr=LR)
 
191
 
192
  scaler = torch.amp.GradScaler("cuda")
193
  os.makedirs("checkpoints", exist_ok=True)
194
+ writer = SummaryWriter("runs/cifar10_tri_stream_v8")
195
  best_acc = 0.0
196
  gs = 0
197
 
 
208
  simplex_buf = SimplexBuffer(
209
  dim=STREAM_DIM, max_size=GAL_BUFFER_SIZE, device=DEVICE)
210
 
211
+ gal_update_count = 0
212
 
213
  # ══════════════════════════════════════════════════════════════════
214
  # TRAINING LOOP
 
221
  acc_dict = {
222
  "loss": 0, "ce": 0, "bce": 0, "geo_bce": 0,
223
  "acc_a": 0, "acc_b": 0, "geo_acc": 0,
224
+ "nce": 0, "nce_acc": 0,
225
+ "nce_b": 0, "nce_b_acc": 0,
226
+ "geo_nce": 0, "geo_nce_acc": 0,
227
  "cm": 0, "cm_valid": 0, "cv": 0, "cv_main": 0, "cv_geo": 0,
228
  "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
229
  "correct": 0, "total": 0, "n": 0}
 
249
  scaler.step(optimizer); scaler.update()
250
  scheduler.step()
251
 
 
252
  mastery.check_activation(ld.get('nce_acc', 0))
253
 
 
254
  pool_geo = out1.get('pool_geo')
255
  if pool_geo is not None:
256
  simplex_buf.push(pool_geo.float(), targets)
257
 
 
258
  gs += 1
259
+ if gs % GAL_UPDATE_INTERVAL == 0 and simplex_buf.size > 500:
260
  score = model.update_gal_anchors(
261
  simplex_buf, lr=GAL_LR, whiten=USE_WHITENED_PROCRUSTES)
262
  if score is not None:
263
  gal_update_count += 1
264
  writer.add_scalar("step/procrustes_score", score, gs)
265
 
266
+ # Track
267
  preds = out1['logits_a'].argmax(-1)
268
  correct = (preds == targets).sum().item()
269
  acc_dict["correct"] += correct
270
  acc_dict["total"] += targets.shape[0]
271
  acc_dict["loss"] += loss.item()
272
 
273
+ for k in ["ce", "bce", "geo_bce", "nce", "nce_b", "geo_nce",
274
  "cm", "cv", "spread", "mastery"]:
275
  v = ld.get(k, 0)
276
  acc_dict[k] += v.item() if torch.is_tensor(v) else v
 
279
  acc_dict["acc_b"] += ld.get("acc_b", 0)
280
  acc_dict["geo_acc"] += ld.get("geo_acc", 0)
281
  acc_dict["nce_acc"] += ld.get("nce_acc", 0)
282
+ acc_dict["nce_b_acc"] += ld.get("nce_b_acc", 0)
283
  acc_dict["geo_nce_acc"] += ld.get("geo_nce_acc", 0)
284
  acc_dict["cm_valid"] += ld.get("cm_valid", 0)
285
  acc_dict["cv_main"] += ld.get("cv_main", 0)
 
292
  d = acc_dict["n"]
293
  ta = 100 * acc_dict["correct"] / acc_dict["total"]
294
  ga = 100 * acc_dict["geo_acc"] / d
295
+ nb = acc_dict["nce_b_acc"] / d
296
  stg = "M" if mastery.active else "S1"
297
  pbar.set_postfix(
298
  loss=f"{acc_dict['loss']/d:.4f}",
299
  a=f"{ta:.0f}%",
300
  ga=f"{ga:.0f}%",
301
+ nb=f"{nb:.2f}",
302
  stg=stg,
303
  gal=gal_update_count,
304
  ordered=True)
305
 
 
306
  if gs % 20 == 0:
307
  writer.add_scalar("step/loss", loss.item(), gs)
308
  writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
309
+ writer.add_scalar("step/nce_b_acc", ld.get("nce_b_acc", 0), gs)
310
+ writer.add_scalar("step/geo_nce_acc", ld.get("geo_nce_acc", 0), gs)
311
  gates_a = out1.get('gates_a', [])
312
  if gates_a:
313
  writer.add_scalar("step/gate_a_mean",
 
326
  writer.add_scalar("epoch/acc_b", 100 * acc_dict["acc_b"] / d, epoch + 1)
327
  writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
328
  writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
329
+ writer.add_scalar("epoch/nce_b_acc", acc_dict["nce_b_acc"] / d, epoch + 1)
330
  writer.add_scalar("epoch/geo_nce_acc", acc_dict["geo_nce_acc"] / d, epoch + 1)
331
  writer.add_scalar("epoch/cv_main", acc_dict["cv_main"] / d, epoch + 1)
332
  writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
333
  writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
 
 
 
334
  writer.add_scalar("epoch/gal_updates", gal_update_count, epoch + 1)
335
 
336
  # ── Validation ──
337
  model.eval()
338
  val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
339
  val_geo_correct = 0
340
+ val_b_correct = 0
341
  all_embs = []
342
 
343
  with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
 
347
  out = model(images, apply_autograd=False)
348
  preds = out['logits_a'].argmax(dim=-1)
349
  val_correct += (preds == labels_v).sum().item()
350
+ val_b_correct += (out['logits_b'].argmax(-1) == labels_v).sum().item()
351
+ val_geo_correct += (out['geo_logits'].argmax(-1) == labels_v).sum().item()
352
  val_total += labels_v.shape[0]
353
  loss_v = F.cross_entropy(out['logits_a'], labels_v)
354
  val_loss_sum += loss_v.item()
 
356
  all_embs.append(out['embedding'].float().cpu())
357
 
358
  val_acc = 100 * val_correct / val_total
359
+ val_b_acc = 100 * val_b_correct / val_total
360
  val_geo_acc = 100 * val_geo_correct / val_total
361
  val_loss = val_loss_sum / max(val_n, 1)
362
 
363
+ # ── Val embedding diagnostics ──
364
  embs = torch.cat(all_embs)
365
  with torch.no_grad():
366
  sample = embs[:2000].to(DEVICE)
 
379
  vols.append(v2[0].sqrt())
380
  v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0.0
381
 
 
382
  with torch.no_grad():
383
  _, v_np = model.constellation.triangulate(
384
  embs[:2000].to(DEVICE), training=False)
385
  n_active = v_np.cpu().unique().numel()
386
 
387
  writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
388
+ writer.add_scalar("epoch/val_b_acc", val_b_acc, epoch + 1)
389
  writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
390
  writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
391
  writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)
392
 
 
393
  mastery.update_size(train_acc, val_acc, epoch + 1)
394
 
395
  # ── Checkpoint ──
 
401
  "config": model.config,
402
  "epoch": epoch + 1,
403
  "val_acc": val_acc,
404
+ "val_b_acc": val_b_acc,
405
  "val_geo_acc": val_geo_acc,
406
  "mastery": mastery.state_dict(),
407
  "gal_updates": gal_update_count,
408
+ }, "checkpoints/tri_stream_v8_best.pt")
409
  mk = " β˜…"
410
 
411
  if (epoch + 1) % 10 == 0:
 
415
  "epoch": epoch + 1,
416
  "val_acc": val_acc,
417
  "optimizer": optimizer.state_dict(),
418
+ }, f"checkpoints/tri_stream_v8_e{epoch+1:03d}.pt")
419
 
420
+ # ── Epoch print β€” v8: shows B acc + nce_b + geo_nce ──
421
  ga = 100 * acc_dict["geo_acc"] / d
422
  ab = 100 * acc_dict["acc_b"] / d
423
+ nb_acc = acc_dict["nce_b_acc"] / d
424
+ gn_acc = acc_dict["geo_nce_acc"] / d
 
425
  cvf = acc_dict["cv_main"] / d
426
  cvg = acc_dict["cv_geo"] / d
427
  cmv = acc_dict["cm_valid"] / d
428
  stage = "MASTERY" if mastery.active else "stage1"
429
 
430
+ # Gate check
431
+ last_gates = []
432
  try:
433
  model.eval()
434
  with torch.no_grad():
435
  sample_imgs = next(iter(val_loader))[0][:4].to(DEVICE)
436
  sample_out = model(sample_imgs, apply_autograd=False)
437
+ last_gates = sample_out.get('gates_a', [])
438
  except:
439
  pass
440
+ gate_str = f"g={np.mean(last_gates):.4f}" if last_gates else "g=?"
441
 
442
  print(f" E{epoch+1:3d}: A={train_acc:.1f}% B={ab:.0f}% "
443
+ f"val={val_acc:.1f}%/{val_b_acc:.1f}%/{val_geo_acc:.1f}% "
444
  f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
445
+ f"nb={nb_acc:.2f} gn={gn_acc:.2f} "
446
  f"cv={v_cv:.4f}(m={cvf:.5f} g={cvg:.5f}) "
447
  f"cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
448
+ f"[{stage}] {gate_str} "
 
 
449
  f"gal={gal_update_count} ({elapsed:.0f}s){mk}")
450
 
451
  writer.close()