Update trainer.py
trainer.py: +62 -57

trainer.py
CHANGED
@@ -1,10 +1,12 @@
 #!/usr/bin/env python3
 """
-CIFAR-10 → Tri-Stream GeoLIP ViT
-=====================================
-
-
-
 """

 import torch
@@ -29,8 +31,8 @@ STREAM_DIM = 192
 N_BLOCKS = 9
 N_HEADS = 8
 OUTPUT_DIM = 256
-N_ANCHORS = 128
-N_GAL_ANCHORS = 64
 N_COMP = 16
 D_COMP = 128
 ANCHOR_DROP = 0.10
@@ -47,11 +49,15 @@ BCE_WEIGHT = 1.0
 CM_WEIGHT = 0.1
 INFONCE_TEMP = 0.07

-# ──
-
-
 GAL_BUFFER_SIZE = 50000
-USE_WHITENED_PROCRUSTES = False

 # ── Mastery queue ──
 MASTERY_PATIENCE = 50
@@ -71,18 +77,17 @@ EPOCHS = 100
 LR = 3e-4
 WARMUP = 5
 GRAD_CLIP = 1.0
-V1_CKPT = ""

 print("=" * 60)
-print("CIFAR-10 → Tri-Stream GeoLIP ViT
 print(f" Architecture: {N_BLOCKS}× TriStreamBlock")
 print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}×{D_COMP} pw")
 print(f" GAL: {N_GAL_ANCHORS} anchors, Procrustes every {GAL_UPDATE_INTERVAL} "
       f"batches (lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES})")
-print(f"
-print(f" InfoNCE={
-
-print(f" LS={LABEL_SMOOTHING}, CV={CV_WEIGHT}")
 print(f" Device: {DEVICE}")
 print("=" * 60)

@@ -145,12 +150,17 @@ model = create_tri_stream_vit(
     autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
     enable_autograd=ENABLE_AUTOGRAD,
     label_smoothing=LABEL_SMOOTHING,
 ).to(DEVICE)

 if V1_CKPT and os.path.exists(V1_CKPT):
     ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
-    model.load_state_dict(
-
 else:
     print(f" Training from scratch")

@@ -165,9 +175,6 @@ print(f"\n{'='*60}")
 print(f"TRAINING — {EPOCHS} epochs, lr={LR}, batch={BATCH}")
 print(f" GAL Procrustes: every {GAL_UPDATE_INTERVAL} batches, "
       f"lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES}")
-print(f" Mastery: patience={MASTERY_PATIENCE}, queue adaptive "
-      f"[{MASTERY_MIN_SIZE}→{MASTERY_MAX_SIZE}]")
-print(f" Optimizer: Adam")
 print(f"{'='*60}")

 optimizer = torch.optim.Adam(model.parameters(), lr=LR)
@@ -184,7 +191,7 @@ scheduler = torch.optim.lr_scheduler.SequentialLR(

 scaler = torch.amp.GradScaler("cuda")
 os.makedirs("checkpoints", exist_ok=True)
-writer = SummaryWriter("runs/
 best_acc = 0.0
 gs = 0

@@ -201,7 +208,7 @@ mastery = MasteryQueue(
 simplex_buf = SimplexBuffer(
     dim=STREAM_DIM, max_size=GAL_BUFFER_SIZE, device=DEVICE)

-gal_update_count = 0

 # ──────────────────────────────────────────────────────────────────
 # TRAINING LOOP
@@ -214,7 +221,9 @@ for epoch in range(EPOCHS):
     acc_dict = {
         "loss": 0, "ce": 0, "bce": 0, "geo_bce": 0,
         "acc_a": 0, "acc_b": 0, "geo_acc": 0,
-        "nce": 0, "nce_acc": 0,
         "cm": 0, "cm_valid": 0, "cv": 0, "cv_main": 0, "cv_geo": 0,
         "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
         "correct": 0, "total": 0, "n": 0}
@@ -240,31 +249,28 @@ for epoch in range(EPOCHS):
         scaler.step(optimizer); scaler.update()
         scheduler.step()

-        # ── Mastery activation check ──
         mastery.check_activation(ld.get('nce_acc', 0))

-        # ── Accumulate geo features into simplex buffer ──
         pool_geo = out1.get('pool_geo')
         if pool_geo is not None:
             simplex_buf.push(pool_geo.float(), targets)

-        # ── Periodic GAL Procrustes update ──
         gs += 1
-        if gs % GAL_UPDATE_INTERVAL == 0 and simplex_buf.size >
             score = model.update_gal_anchors(
                 simplex_buf, lr=GAL_LR, whiten=USE_WHITENED_PROCRUSTES)
             if score is not None:
                 gal_update_count += 1
                 writer.add_scalar("step/procrustes_score", score, gs)

-        #
         preds = out1['logits_a'].argmax(-1)
         correct = (preds == targets).sum().item()
         acc_dict["correct"] += correct
         acc_dict["total"] += targets.shape[0]
         acc_dict["loss"] += loss.item()

-        for k in ["ce", "bce", "geo_bce", "nce", "geo_nce",
                   "cm", "cv", "spread", "mastery"]:
             v = ld.get(k, 0)
             acc_dict[k] += v.item() if torch.is_tensor(v) else v
@@ -273,6 +279,7 @@ for epoch in range(EPOCHS):
         acc_dict["acc_b"] += ld.get("acc_b", 0)
         acc_dict["geo_acc"] += ld.get("geo_acc", 0)
         acc_dict["nce_acc"] += ld.get("nce_acc", 0)
         acc_dict["geo_nce_acc"] += ld.get("geo_nce_acc", 0)
         acc_dict["cm_valid"] += ld.get("cm_valid", 0)
         acc_dict["cv_main"] += ld.get("cv_main", 0)
@@ -285,22 +292,22 @@ for epoch in range(EPOCHS):
         d = acc_dict["n"]
         ta = 100 * acc_dict["correct"] / acc_dict["total"]
         ga = 100 * acc_dict["geo_acc"] / d
-
         stg = "M" if mastery.active else "S1"
         pbar.set_postfix(
             loss=f"{acc_dict['loss']/d:.4f}",
             a=f"{ta:.0f}%",
             ga=f"{ga:.0f}%",
-
             stg=stg,
             gal=gal_update_count,
             ordered=True)

-        # Step-level TB
         if gs % 20 == 0:
             writer.add_scalar("step/loss", loss.item(), gs)
             writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
-
             gates_a = out1.get('gates_a', [])
             if gates_a:
                 writer.add_scalar("step/gate_a_mean",
@@ -319,19 +326,18 @@ for epoch in range(EPOCHS):
     writer.add_scalar("epoch/acc_b", 100 * acc_dict["acc_b"] / d, epoch + 1)
     writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
     writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
     writer.add_scalar("epoch/geo_nce_acc", acc_dict["geo_nce_acc"] / d, epoch + 1)
     writer.add_scalar("epoch/cv_main", acc_dict["cv_main"] / d, epoch + 1)
     writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
     writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
-    writer.add_scalar("epoch/margin", mastery.current_margin, epoch + 1)
-    writer.add_scalar("epoch/queue_max", mastery._current_max, epoch + 1)
-    writer.add_scalar("epoch/simplex_buf", simplex_buf.size, epoch + 1)
     writer.add_scalar("epoch/gal_updates", gal_update_count, epoch + 1)

     # ── Validation ──
     model.eval()
     val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
     val_geo_correct = 0
     all_embs = []

     with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
@@ -341,8 +347,8 @@ for epoch in range(EPOCHS):
             out = model(images, apply_autograd=False)
             preds = out['logits_a'].argmax(dim=-1)
             val_correct += (preds == labels_v).sum().item()
-
-            val_geo_correct += (
             val_total += labels_v.shape[0]
             loss_v = F.cross_entropy(out['logits_a'], labels_v)
             val_loss_sum += loss_v.item()
@@ -350,10 +356,11 @@ for epoch in range(EPOCHS):
             all_embs.append(out['embedding'].float().cpu())

     val_acc = 100 * val_correct / val_total
     val_geo_acc = 100 * val_geo_correct / val_total
     val_loss = val_loss_sum / max(val_n, 1)

-    #
     embs = torch.cat(all_embs)
     with torch.no_grad():
         sample = embs[:2000].to(DEVICE)
@@ -372,18 +379,17 @@ for epoch in range(EPOCHS):
             vols.append(v2[0].sqrt())
         v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0.0

-    # Anchor utilization
     with torch.no_grad():
         _, v_np = model.constellation.triangulate(
             embs[:2000].to(DEVICE), training=False)
         n_active = v_np.cpu().unique().numel()

     writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
     writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
     writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
     writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)

-    # ── Adaptive queue resize ──
     mastery.update_size(train_acc, val_acc, epoch + 1)

     # ── Checkpoint ──
@@ -395,10 +401,11 @@ for epoch in range(EPOCHS):
             "config": model.config,
             "epoch": epoch + 1,
             "val_acc": val_acc,
             "val_geo_acc": val_geo_acc,
             "mastery": mastery.state_dict(),
             "gal_updates": gal_update_count,
-        }, "checkpoints/
         mk = " ✓"

     if (epoch + 1) % 10 == 0:
@@ -408,39 +415,37 @@ for epoch in range(EPOCHS):
             "epoch": epoch + 1,
             "val_acc": val_acc,
             "optimizer": optimizer.state_dict(),
-        }, f"checkpoints/

-    # ── Epoch print ──
     ga = 100 * acc_dict["geo_acc"] / d
     ab = 100 * acc_dict["acc_b"] / d
-
-
-    hp = acc_dict["hard_pos"] / d if mastery.active else 0
     cvf = acc_dict["cv_main"] / d
     cvg = acc_dict["cv_geo"] / d
     cmv = acc_dict["cm_valid"] / d
     stage = "MASTERY" if mastery.active else "stage1"

-    #
-
     try:
         model.eval()
         with torch.no_grad():
             sample_imgs = next(iter(val_loader))[0][:4].to(DEVICE)
             sample_out = model(sample_imgs, apply_autograd=False)
-
     except:
         pass
-    gate_str = f"g={np.mean(

     print(f" E{epoch+1:3d}: A={train_acc:.1f}% B={ab:.0f}% "
-          f"val={val_acc:.1f}%
           f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
           f"cv={v_cv:.4f}(m={cvf:.5f} g={cvg:.5f}) "
           f"cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
-          f"[{stage}]
-          f"hn={hn:.3f} hp={hp:.3f} "
-          f"q={mastery.size}/{mastery._current_max} "
           f"gal={gal_update_count} ({elapsed:.0f}s){mk}")

 writer.close()
trainer.py
UPDATED

 #!/usr/bin/env python3
 """
+CIFAR-10 → Tri-Stream GeoLIP ViT v8
+=====================================
+v7→v8 changes:
+1. GAL_UPDATE_INTERVAL: 50 → 25 (2× more frequent)
+2. GAL_LR: 0.01 → 0.015 (+50% response)
+3. Tracks nce_b and geo_nce_acc separately
+4. stream_b_nce_weight=0.5, geo_nce_weight=0.5
 """

 import torch
 N_BLOCKS = 9
 N_HEADS = 8
 OUTPUT_DIM = 256
+N_ANCHORS = 128
+N_GAL_ANCHORS = 64
 N_COMP = 16
 D_COMP = 128
 ANCHOR_DROP = 0.10
 CM_WEIGHT = 0.1
 INFONCE_TEMP = 0.07

+# ── v8: Stream B + Geo NCE weights ──
+STREAM_B_NCE_WEIGHT = 0.5
+GEO_NCE_WEIGHT = 0.5
+
+# ── v8: GAL — faster updates, stronger response ──
+GAL_UPDATE_INTERVAL = 25   # was 50
+GAL_LR = 0.015             # was 0.01 (+50%)
 GAL_BUFFER_SIZE = 50000
+USE_WHITENED_PROCRUSTES = False

 # ── Mastery queue ──
 MASTERY_PATIENCE = 50
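
Note: STREAM_B_NCE_WEIGHT and GEO_NCE_WEIGHT scale two extra contrastive terms that this commit adds; the loss module itself is not part of the diff. A minimal sketch of a batch InfoNCE consistent with INFONCE_TEMP = 0.07 (the function name and the paired-view setup are assumptions):

import torch
import torch.nn.functional as F

def info_nce(z_a, z_b, temp=0.07):
    # Symmetric batch InfoNCE over paired embeddings; returns the loss and
    # the top-1 retrieval accuracy that fields like nce_acc appear to track.
    z_a = F.normalize(z_a, dim=-1)
    z_b = F.normalize(z_b, dim=-1)
    logits = z_a @ z_b.t() / temp                        # (B, B) similarities
    labels = torch.arange(z_a.size(0), device=z_a.device)
    loss = 0.5 * (F.cross_entropy(logits, labels)
                  + F.cross_entropy(logits.t(), labels))
    acc = (logits.argmax(dim=1) == labels).float().mean().item()
    return loss, acc

# Weighted as in this config (sketch):
#   nce_total = nce_a + STREAM_B_NCE_WEIGHT * nce_b + GEO_NCE_WEIGHT * geo_nce
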
 LR = 3e-4
 WARMUP = 5
 GRAD_CLIP = 1.0
+V1_CKPT = ""   # set to checkpoint path for warm start

 print("=" * 60)
+print("CIFAR-10 → Tri-Stream GeoLIP ViT v8")
 print(f" Architecture: {N_BLOCKS}× TriStreamBlock")
 print(f" Sphere: {OUTPUT_DIM}-d, {N_ANCHORS} anchors, {N_COMP}×{D_COMP} pw")
 print(f" GAL: {N_GAL_ANCHORS} anchors, Procrustes every {GAL_UPDATE_INTERVAL} "
       f"batches (lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES})")
+print(f" v8 fixes: uniform hypersphere init, gate_init=1/(2×{N_BLOCKS})")
+print(f" v8 fixes: InfoNCE on emb_b (w={STREAM_B_NCE_WEIGHT}) "
+      f"+ geo_emb (w={GEO_NCE_WEIGHT})")
 print(f" Device: {DEVICE}")
 print("=" * 60)
     autograd_tang=AUTOGRAD_TANG, autograd_sep=AUTOGRAD_SEP,
     enable_autograd=ENABLE_AUTOGRAD,
     label_smoothing=LABEL_SMOOTHING,
+    stream_b_nce_weight=STREAM_B_NCE_WEIGHT,
+    geo_nce_weight=GEO_NCE_WEIGHT,
 ).to(DEVICE)

 if V1_CKPT and os.path.exists(V1_CKPT):
     ckpt = torch.load(V1_CKPT, map_location="cpu", weights_only=False)
+    missing, unexpected = model.load_state_dict(
+        ckpt["state_dict"], strict=False)
+    print(f" → Loaded weights: epoch {ckpt.get('epoch', '?')}")
+    if missing:
+        print(f" New params (expected): {len(missing)}")
 else:
     print(f" Training from scratch")
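
Note: strict=False tolerates missing and unexpected keys, which is what lets a v7 checkpoint warm-start the v8 model, but it does not tolerate shape changes; a tensor whose shape differs between checkpoint and model still raises a RuntimeError. A shape-filtered variant (sketch, not part of the commit):

state = ckpt["state_dict"]
own = model.state_dict()
compatible = {k: v for k, v in state.items()
              if k in own and own[k].shape == v.shape}   # drop resized tensors
missing, unexpected = model.load_state_dict(compatible, strict=False)
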
 print(f"TRAINING — {EPOCHS} epochs, lr={LR}, batch={BATCH}")
 print(f" GAL Procrustes: every {GAL_UPDATE_INTERVAL} batches, "
       f"lr={GAL_LR}, whiten={USE_WHITENED_PROCRUSTES}")
 print(f"{'='*60}")

 optimizer = torch.optim.Adam(model.parameters(), lr=LR)

 scaler = torch.amp.GradScaler("cuda")
 os.makedirs("checkpoints", exist_ok=True)
+writer = SummaryWriter("runs/cifar10_tri_stream_v8")
 best_acc = 0.0
 gs = 0
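
Note: the hunk header in the CHANGED section shows the elided scheduler is a torch.optim.lr_scheduler.SequentialLR, but its construction is outside this diff. A warmup-plus-cosine setup consistent with WARMUP = 5 and the per-batch scheduler.step() in the loop below would look like this (steps_per_epoch and the train_loader name are assumptions):

steps_per_epoch = len(train_loader)
warmup_steps = WARMUP * steps_per_epoch
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.01, total_iters=warmup_steps),
        torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=EPOCHS * steps_per_epoch - warmup_steps),
    ],
    milestones=[warmup_steps],
)
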

 simplex_buf = SimplexBuffer(
     dim=STREAM_DIM, max_size=GAL_BUFFER_SIZE, device=DEVICE)

+gal_update_count = 0
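
Note: SimplexBuffer is defined elsewhere; only its interface is visible here (.push(feats, labels) and .size). A minimal ring buffer matching that interface (sketch; the real class may store more state):

import torch

class FeatureRingBuffer:
    def __init__(self, dim, max_size, device):
        self.feats = torch.zeros(max_size, dim, device=device)
        self.labels = torch.zeros(max_size, dtype=torch.long, device=device)
        self.max_size, self.ptr, self.size = max_size, 0, 0

    def push(self, feats, labels):
        # Overwrite the oldest entries once the buffer wraps around.
        n = feats.size(0)
        idx = (self.ptr + torch.arange(n, device=feats.device)) % self.max_size
        self.feats[idx] = feats.detach()
        self.labels[idx] = labels
        self.ptr = (self.ptr + n) % self.max_size
        self.size = min(self.size + n, self.max_size)
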

 # ──────────────────────────────────────────────────────────────────
 # TRAINING LOOP
     acc_dict = {
         "loss": 0, "ce": 0, "bce": 0, "geo_bce": 0,
         "acc_a": 0, "acc_b": 0, "geo_acc": 0,
+        "nce": 0, "nce_acc": 0,
+        "nce_b": 0, "nce_b_acc": 0,
+        "geo_nce": 0, "geo_nce_acc": 0,
         "cm": 0, "cm_valid": 0, "cv": 0, "cv_main": 0, "cv_geo": 0,
         "spread": 0, "mastery": 0, "hard_neg": 0, "hard_pos": 0,
         "correct": 0, "total": 0, "n": 0}
         scaler.step(optimizer); scaler.update()
         scheduler.step()

         mastery.check_activation(ld.get('nce_acc', 0))

         pool_geo = out1.get('pool_geo')
         if pool_geo is not None:
             simplex_buf.push(pool_geo.float(), targets)

         gs += 1
+        if gs % GAL_UPDATE_INTERVAL == 0 and simplex_buf.size > 500:
             score = model.update_gal_anchors(
                 simplex_buf, lr=GAL_LR, whiten=USE_WHITENED_PROCRUSTES)
             if score is not None:
                 gal_update_count += 1
                 writer.add_scalar("step/procrustes_score", score, gs)
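
Note: model.update_gal_anchors is not in this diff; only its call signature (buffer, lr, whiten) and the fact that it returns a score are visible. A sketch of the kind of orthogonal Procrustes step the names suggest, where the anchor set is rotated a fraction lr of the way toward its best alignment with class means drawn from the buffer (all internals here are assumptions):

import torch
import torch.nn.functional as F

def procrustes_step(anchors, targets, lr=0.015):
    # anchors, targets: (K, D) row-matched point sets on the unit sphere.
    # Schönemann: rot = U @ Vt minimizes ||anchors @ rot - targets||_F
    u, s, vt = torch.linalg.svd(anchors.t() @ targets)
    rot = u @ vt
    new_anchors = F.normalize((1 - lr) * anchors + lr * (anchors @ rot), dim=-1)
    score = s.sum().item()   # nuclear norm of the cross-covariance: alignment quality
    return new_anchors, score
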

+        # Track
         preds = out1['logits_a'].argmax(-1)
         correct = (preds == targets).sum().item()
         acc_dict["correct"] += correct
         acc_dict["total"] += targets.shape[0]
         acc_dict["loss"] += loss.item()

+        for k in ["ce", "bce", "geo_bce", "nce", "nce_b", "geo_nce",
                   "cm", "cv", "spread", "mastery"]:
             v = ld.get(k, 0)
             acc_dict[k] += v.item() if torch.is_tensor(v) else v
         acc_dict["acc_b"] += ld.get("acc_b", 0)
         acc_dict["geo_acc"] += ld.get("geo_acc", 0)
         acc_dict["nce_acc"] += ld.get("nce_acc", 0)
+        acc_dict["nce_b_acc"] += ld.get("nce_b_acc", 0)
         acc_dict["geo_nce_acc"] += ld.get("geo_nce_acc", 0)
         acc_dict["cm_valid"] += ld.get("cm_valid", 0)
         acc_dict["cv_main"] += ld.get("cv_main", 0)
         d = acc_dict["n"]
         ta = 100 * acc_dict["correct"] / acc_dict["total"]
         ga = 100 * acc_dict["geo_acc"] / d
+        nb = acc_dict["nce_b_acc"] / d
         stg = "M" if mastery.active else "S1"
         pbar.set_postfix(
             loss=f"{acc_dict['loss']/d:.4f}",
             a=f"{ta:.0f}%",
             ga=f"{ga:.0f}%",
+            nb=f"{nb:.2f}",
             stg=stg,
             gal=gal_update_count,
             ordered=True)

         if gs % 20 == 0:
             writer.add_scalar("step/loss", loss.item(), gs)
             writer.add_scalar("step/geo_acc", ld.get("geo_acc", 0), gs)
+            writer.add_scalar("step/nce_b_acc", ld.get("nce_b_acc", 0), gs)
+            writer.add_scalar("step/geo_nce_acc", ld.get("geo_nce_acc", 0), gs)
             gates_a = out1.get('gates_a', [])
             if gates_a:
                 writer.add_scalar("step/gate_a_mean",
     writer.add_scalar("epoch/acc_b", 100 * acc_dict["acc_b"] / d, epoch + 1)
     writer.add_scalar("epoch/geo_acc", 100 * acc_dict["geo_acc"] / d, epoch + 1)
     writer.add_scalar("epoch/nce_acc", acc_dict["nce_acc"] / d, epoch + 1)
+    writer.add_scalar("epoch/nce_b_acc", acc_dict["nce_b_acc"] / d, epoch + 1)
     writer.add_scalar("epoch/geo_nce_acc", acc_dict["geo_nce_acc"] / d, epoch + 1)
     writer.add_scalar("epoch/cv_main", acc_dict["cv_main"] / d, epoch + 1)
     writer.add_scalar("epoch/cv_geo", acc_dict["cv_geo"] / d, epoch + 1)
     writer.add_scalar("epoch/cm_valid", acc_dict["cm_valid"] / d, epoch + 1)
     writer.add_scalar("epoch/gal_updates", gal_update_count, epoch + 1)

     # ── Validation ──
     model.eval()
     val_correct, val_total, val_loss_sum, val_n = 0, 0, 0, 0
     val_geo_correct = 0
+    val_b_correct = 0
     all_embs = []

     with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
             out = model(images, apply_autograd=False)
             preds = out['logits_a'].argmax(dim=-1)
             val_correct += (preds == labels_v).sum().item()
+            val_b_correct += (out['logits_b'].argmax(-1) == labels_v).sum().item()
+            val_geo_correct += (out['geo_logits'].argmax(-1) == labels_v).sum().item()
             val_total += labels_v.shape[0]
             loss_v = F.cross_entropy(out['logits_a'], labels_v)
             val_loss_sum += loss_v.item()
             all_embs.append(out['embedding'].float().cpu())

     val_acc = 100 * val_correct / val_total
+    val_b_acc = 100 * val_b_correct / val_total
     val_geo_acc = 100 * val_geo_correct / val_total
     val_loss = val_loss_sum / max(val_n, 1)

+    # ── Val embedding diagnostics ──
     embs = torch.cat(all_embs)
     with torch.no_grad():
         sample = embs[:2000].to(DEVICE)
             vols.append(v2[0].sqrt())
         v_cv = (torch.stack(vols).std() / (torch.stack(vols).mean() + 1e-8)).item() if len(vols) > 10 else 0.0

     with torch.no_grad():
         _, v_np = model.constellation.triangulate(
             embs[:2000].to(DEVICE), training=False)
         n_active = v_np.cpu().unique().numel()

     writer.add_scalar("epoch/val_acc", val_acc, epoch + 1)
+    writer.add_scalar("epoch/val_b_acc", val_b_acc, epoch + 1)
     writer.add_scalar("epoch/val_geo_acc", val_geo_acc, epoch + 1)
     writer.add_scalar("epoch/val_cv", v_cv, epoch + 1)
     writer.add_scalar("epoch/val_anchors", n_active, epoch + 1)

     mastery.update_size(train_acc, val_acc, epoch + 1)
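
Note: MasteryQueue.update_size is defined elsewhere. The v7 print removed in the CHANGED section described the queue as "adaptive [MIN→MAX]", so one plausible rule is to widen the queue as the train/val gap (a proxy for overfitting) grows; this is a guess at the mechanism, not the commit's code:

def update_size(self, train_acc, val_acc, epoch):
    gap = max(0.0, train_acc - val_acc)        # percentage points of overfit
    frac = min(gap / 10.0, 1.0)                # saturate at a 10-point gap
    self._current_max = int(MASTERY_MIN_SIZE
                            + frac * (MASTERY_MAX_SIZE - MASTERY_MIN_SIZE))
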

     # ── Checkpoint ──
             "config": model.config,
             "epoch": epoch + 1,
             "val_acc": val_acc,
+            "val_b_acc": val_b_acc,
             "val_geo_acc": val_geo_acc,
             "mastery": mastery.state_dict(),
             "gal_updates": gal_update_count,
+        }, "checkpoints/tri_stream_v8_best.pt")
         mk = " ✓"

     if (epoch + 1) % 10 == 0:
             "epoch": epoch + 1,
             "val_acc": val_acc,
             "optimizer": optimizer.state_dict(),
+        }, f"checkpoints/tri_stream_v8_e{epoch+1:03d}.pt")
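
Note: these periodic files carry optimizer state, and the warm-start loader above reads model weights from a "state_dict" key, so resuming looks like this (sketch; the epoch-010 filename is just one instance of the pattern):

ckpt = torch.load("checkpoints/tri_stream_v8_e010.pt",
                  map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"]
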

+    # ── Epoch print — v8: shows B acc + nce_b + geo_nce ──
     ga = 100 * acc_dict["geo_acc"] / d
     ab = 100 * acc_dict["acc_b"] / d
+    nb_acc = acc_dict["nce_b_acc"] / d
+    gn_acc = acc_dict["geo_nce_acc"] / d
     cvf = acc_dict["cv_main"] / d
     cvg = acc_dict["cv_geo"] / d
     cmv = acc_dict["cm_valid"] / d
     stage = "MASTERY" if mastery.active else "stage1"

+    # Gate check
+    last_gates = []
     try:
         model.eval()
         with torch.no_grad():
             sample_imgs = next(iter(val_loader))[0][:4].to(DEVICE)
             sample_out = model(sample_imgs, apply_autograd=False)
+            last_gates = sample_out.get('gates_a', [])
     except:
         pass
+    gate_str = f"g={np.mean(last_gates):.4f}" if last_gates else "g=?"

     print(f" E{epoch+1:3d}: A={train_acc:.1f}% B={ab:.0f}% "
+          f"val={val_acc:.1f}%/{val_b_acc:.1f}%/{val_geo_acc:.1f}% "
           f"loss={acc_dict['loss']/d:.4f}/{val_loss:.4f} "
+          f"nb={nb_acc:.2f} gn={gn_acc:.2f} "
           f"cv={v_cv:.4f}(m={cvf:.5f} g={cvg:.5f}) "
           f"cm={cmv:.0%} anch={n_active}/{N_ANCHORS} "
+          f"[{stage}] {gate_str} "
           f"gal={gal_update_count} ({elapsed:.0f}s){mk}")

 writer.close()