Update trainer_model.py
Browse files- trainer_model.py +125 -8
trainer_model.py
CHANGED
|
@@ -223,25 +223,110 @@ class GeoLIPCore(nn.Module):
|
|
| 223 |
ld['nce'] = l_nce
|
| 224 |
ld['nce_acc'] = nce_acc
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
# CV
|
| 227 |
l_cv = self._cv_loss(emb)
|
| 228 |
ld['cv'] = l_cv
|
| 229 |
|
| 230 |
# Anchor spread
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
mask = ~torch.eye(an.shape[0], dtype=torch.bool, device=an.device)
|
| 234 |
l_spread = F.relu(sim_a[mask]).mean()
|
| 235 |
ld['spread'] = l_spread
|
| 236 |
|
| 237 |
# Total
|
| 238 |
loss = (l_ce
|
| 239 |
+ ld.get('nce', 0.0) * 1.0
|
|
|
|
| 240 |
+ l_cv * 0.01
|
| 241 |
+ l_spread * 0.001)
|
| 242 |
ld['total'] = loss
|
| 243 |
return loss, ld
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
def _cv_loss(self, emb, n_samples=64, n_points=5):
|
| 246 |
B = emb.shape[0]
|
| 247 |
if B < n_points: return torch.tensor(0.0, device=emb.device)
|
|
@@ -295,8 +380,8 @@ N_ANCHORS = 64
|
|
| 295 |
N_COMP = 8
|
| 296 |
D_COMP = 64
|
| 297 |
BATCH = 256
|
| 298 |
-
EPOCHS =
|
| 299 |
-
LR = 3e-
|
| 300 |
|
| 301 |
print("=" * 60)
|
| 302 |
print("GeoLIP Core β Conv + Constellation + Patchwork")
|
|
@@ -359,15 +444,24 @@ writer = SummaryWriter("runs/geolip_core")
|
|
| 359 |
best_acc = 0.0
|
| 360 |
gs = 0
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
print(f"\n{'='*60}")
|
| 363 |
print(f"TRAINING β {EPOCHS} epochs")
|
|
|
|
| 364 |
print(f"{'='*60}")
|
| 365 |
|
| 366 |
for epoch in range(EPOCHS):
|
| 367 |
model.train()
|
| 368 |
t0 = time.time()
|
| 369 |
tot_loss, tot_ce, tot_nce, tot_cv = 0, 0, 0, 0
|
| 370 |
-
tot_acc, tot_nce_acc, n = 0, 0, 0
|
| 371 |
correct, total = 0, 0
|
| 372 |
|
| 373 |
pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
|
|
@@ -389,11 +483,29 @@ for epoch in range(EPOCHS):
|
|
| 389 |
scheduler.step()
|
| 390 |
gs += 1
|
| 391 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
preds = out1['logits'].argmax(-1)
|
| 393 |
correct += (preds == targets).sum().item()
|
| 394 |
total += targets.shape[0]
|
| 395 |
tot_loss += loss.item()
|
| 396 |
tot_nce_acc += ld.get('nce_acc', 0)
|
|
|
|
| 397 |
n += 1
|
| 398 |
|
| 399 |
if n % 10 == 0:
|
|
@@ -401,6 +513,8 @@ for epoch in range(EPOCHS):
|
|
| 401 |
loss=f"{tot_loss/n:.4f}",
|
| 402 |
acc=f"{100*correct/total:.0f}%",
|
| 403 |
nce=f"{tot_nce_acc/n:.2f}",
|
|
|
|
|
|
|
| 404 |
ordered=True)
|
| 405 |
|
| 406 |
elapsed = time.time() - t0
|
|
@@ -449,6 +563,8 @@ for epoch in range(EPOCHS):
|
|
| 449 |
writer.add_scalar("epoch/val_acc", val_acc, epoch+1)
|
| 450 |
writer.add_scalar("epoch/val_cv", v_cv, epoch+1)
|
| 451 |
writer.add_scalar("epoch/anchors", n_active, epoch+1)
|
|
|
|
|
|
|
| 452 |
|
| 453 |
mk = ""
|
| 454 |
if val_acc > best_acc:
|
|
@@ -462,11 +578,12 @@ for epoch in range(EPOCHS):
|
|
| 462 |
mk = " β
"
|
| 463 |
|
| 464 |
nce_m = tot_nce_acc / n
|
|
|
|
| 465 |
cv_band = "β" if 0.18 <= v_cv <= 0.25 else "β"
|
| 466 |
print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
|
| 467 |
-
f"loss={tot_loss/n:.4f} nce={nce_m:.2f} "
|
| 468 |
f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
|
| 469 |
-
f"({elapsed:.0f}s){mk}")
|
| 470 |
|
| 471 |
writer.close()
|
| 472 |
print(f"\n Best val accuracy: {best_acc:.1f}%")
|
|
|
|
| 223 |
ld['nce'] = l_nce
|
| 224 |
ld['nce_acc'] = nce_acc
|
| 225 |
|
| 226 |
+
# ββ Anchor attraction: pull each embedding toward its nearest anchor ββ
|
| 227 |
+
anchors_n = F.normalize(self.constellation.anchors, dim=-1)
|
| 228 |
+
cos_to_anchors = emb @ anchors_n.T # (B, n_anchors)
|
| 229 |
+
nearest_cos = cos_to_anchors.max(dim=1).values # (B,)
|
| 230 |
+
l_attract = (1.0 - nearest_cos).mean() # 0 when on top of anchor
|
| 231 |
+
ld['attract'] = l_attract
|
| 232 |
+
ld['nearest_cos'] = nearest_cos.mean().item()
|
| 233 |
+
|
| 234 |
# CV
|
| 235 |
l_cv = self._cv_loss(emb)
|
| 236 |
ld['cv'] = l_cv
|
| 237 |
|
| 238 |
# Anchor spread
|
| 239 |
+
sim_a = anchors_n @ anchors_n.T
|
| 240 |
+
mask = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=anchors_n.device)
|
|
|
|
| 241 |
l_spread = F.relu(sim_a[mask]).mean()
|
| 242 |
ld['spread'] = l_spread
|
| 243 |
|
| 244 |
# Total
|
| 245 |
loss = (l_ce
|
| 246 |
+ ld.get('nce', 0.0) * 1.0
|
| 247 |
+
+ l_attract * 0.5
|
| 248 |
+ l_cv * 0.01
|
| 249 |
+ l_spread * 0.001)
|
| 250 |
ld['total'] = loss
|
| 251 |
return loss, ld
|
| 252 |
|
| 253 |
+
@torch.no_grad()
|
| 254 |
+
def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
|
| 255 |
+
"""
|
| 256 |
+
Push anchors toward CLASS centroids, not nearest-anchor centroids.
|
| 257 |
+
|
| 258 |
+
Phase 1: Compute class centroids from labels
|
| 259 |
+
Phase 2: Each class owns (n_anchors / n_classes) anchors
|
| 260 |
+
Phase 3: Assigned anchors blend toward their class centroid
|
| 261 |
+
with small angular offsets so they don't all collapse
|
| 262 |
+
|
| 263 |
+
This works even when anchors start bunched at origin.
|
| 264 |
+
"""
|
| 265 |
+
anchors = self.constellation.anchors.data # (A, D)
|
| 266 |
+
n_a = anchors.shape[0]
|
| 267 |
+
emb_n = F.normalize(emb_buffer, dim=-1)
|
| 268 |
+
device = anchors.device
|
| 269 |
+
|
| 270 |
+
# Phase 1: class centroids
|
| 271 |
+
classes = label_buffer.unique()
|
| 272 |
+
n_cls = classes.shape[0]
|
| 273 |
+
centroids = []
|
| 274 |
+
for c in classes:
|
| 275 |
+
mask = label_buffer == c
|
| 276 |
+
if mask.sum() > 0:
|
| 277 |
+
centroids.append(F.normalize(emb_n[mask].mean(0, keepdim=True), dim=-1))
|
| 278 |
+
if len(centroids) == 0:
|
| 279 |
+
return 0
|
| 280 |
+
centroids = torch.cat(centroids, dim=0) # (C, D)
|
| 281 |
+
|
| 282 |
+
# Phase 2: assign anchors to classes round-robin
|
| 283 |
+
# Sort anchors by cosine to each centroid, greedily assign
|
| 284 |
+
anchors_n = F.normalize(anchors, dim=-1)
|
| 285 |
+
cos = anchors_n @ centroids.T # (A, C)
|
| 286 |
+
anchors_per_class = n_a // n_cls
|
| 287 |
+
assigned_class = torch.full((n_a,), -1, dtype=torch.long, device=device)
|
| 288 |
+
class_count = torch.zeros(n_cls, dtype=torch.long, device=device)
|
| 289 |
+
|
| 290 |
+
# Greedy: for each anchor, assign to its best class if that class has room
|
| 291 |
+
_, flat_idx = cos.flatten().sort(descending=True)
|
| 292 |
+
for idx in flat_idx:
|
| 293 |
+
a = (idx // n_cls).item()
|
| 294 |
+
c = (idx % n_cls).item()
|
| 295 |
+
if assigned_class[a] >= 0:
|
| 296 |
+
continue
|
| 297 |
+
if class_count[c] >= anchors_per_class + 1: # +1 for remainder
|
| 298 |
+
continue
|
| 299 |
+
assigned_class[a] = c
|
| 300 |
+
class_count[c] += 1
|
| 301 |
+
if (assigned_class >= 0).all():
|
| 302 |
+
break
|
| 303 |
+
|
| 304 |
+
# Unassigned leftovers β nearest centroid
|
| 305 |
+
unassigned = (assigned_class < 0).nonzero(as_tuple=True)[0]
|
| 306 |
+
if len(unassigned) > 0:
|
| 307 |
+
leftover_cos = anchors_n[unassigned] @ centroids.T
|
| 308 |
+
assigned_class[unassigned] = leftover_cos.argmax(dim=1)
|
| 309 |
+
|
| 310 |
+
# Phase 3: push each anchor toward its class centroid
|
| 311 |
+
moved = 0
|
| 312 |
+
for a in range(n_a):
|
| 313 |
+
c = assigned_class[a].item()
|
| 314 |
+
target = centroids[c]
|
| 315 |
+
# Add small angular offset so co-class anchors don't collapse
|
| 316 |
+
rank_in_class = (assigned_class[:a] == c).sum().item()
|
| 317 |
+
if anchors_per_class > 1 and rank_in_class > 0:
|
| 318 |
+
# Tiny perpendicular perturbation
|
| 319 |
+
noise = torch.randn_like(target) * 0.05
|
| 320 |
+
noise = noise - (noise * target).sum() * target # project out radial
|
| 321 |
+
target = F.normalize((target + noise).unsqueeze(0), dim=-1).squeeze(0)
|
| 322 |
+
|
| 323 |
+
anchors[a] = F.normalize(
|
| 324 |
+
(anchors_n[a] + lr * (target - anchors_n[a])).unsqueeze(0),
|
| 325 |
+
dim=-1).squeeze(0)
|
| 326 |
+
moved += 1
|
| 327 |
+
|
| 328 |
+
return moved
|
| 329 |
+
|
| 330 |
def _cv_loss(self, emb, n_samples=64, n_points=5):
|
| 331 |
B = emb.shape[0]
|
| 332 |
if B < n_points: return torch.tensor(0.0, device=emb.device)
|
|
|
|
| 380 |
N_COMP = 8
|
| 381 |
D_COMP = 64
|
| 382 |
BATCH = 256
|
| 383 |
+
EPOCHS = 100
|
| 384 |
+
LR = 3e-4
|
| 385 |
|
| 386 |
print("=" * 60)
|
| 387 |
print("GeoLIP Core β Conv + Constellation + Patchwork")
|
|
|
|
| 444 |
best_acc = 0.0
|
| 445 |
gs = 0
|
| 446 |
|
| 447 |
+
# Anchor push config
|
| 448 |
+
PUSH_INTERVAL = 50 # batches between centroid pushes
|
| 449 |
+
PUSH_LR = 0.1 # blend rate toward centroid
|
| 450 |
+
PUSH_BUFFER_SIZE = 5000
|
| 451 |
+
emb_buffer = None # (N, D) accumulated embeddings
|
| 452 |
+
lbl_buffer = None # (N,) accumulated labels
|
| 453 |
+
push_count = 0
|
| 454 |
+
|
| 455 |
print(f"\n{'='*60}")
|
| 456 |
print(f"TRAINING β {EPOCHS} epochs")
|
| 457 |
+
print(f" Anchor push: every {PUSH_INTERVAL} batches, lr={PUSH_LR}")
|
| 458 |
print(f"{'='*60}")
|
| 459 |
|
| 460 |
for epoch in range(EPOCHS):
|
| 461 |
model.train()
|
| 462 |
t0 = time.time()
|
| 463 |
tot_loss, tot_ce, tot_nce, tot_cv = 0, 0, 0, 0
|
| 464 |
+
tot_acc, tot_nce_acc, tot_nearest_cos, n = 0, 0, 0, 0
|
| 465 |
correct, total = 0, 0
|
| 466 |
|
| 467 |
pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
|
|
|
|
| 483 |
scheduler.step()
|
| 484 |
gs += 1
|
| 485 |
|
| 486 |
+
# ββ Accumulate embeddings for anchor push ββ
|
| 487 |
+
with torch.no_grad():
|
| 488 |
+
batch_emb = out1['embedding'].detach().float()
|
| 489 |
+
if emb_buffer is None:
|
| 490 |
+
emb_buffer = batch_emb
|
| 491 |
+
lbl_buffer = targets.detach()
|
| 492 |
+
else:
|
| 493 |
+
emb_buffer = torch.cat([emb_buffer, batch_emb])[-PUSH_BUFFER_SIZE:]
|
| 494 |
+
lbl_buffer = torch.cat([lbl_buffer, targets.detach()])[-PUSH_BUFFER_SIZE:]
|
| 495 |
+
|
| 496 |
+
# ββ Periodic anchor push toward class centroids ββ
|
| 497 |
+
if gs % PUSH_INTERVAL == 0 and emb_buffer is not None and emb_buffer.shape[0] > 500:
|
| 498 |
+
moved = model.push_anchors_to_centroids(
|
| 499 |
+
emb_buffer, lbl_buffer, lr=PUSH_LR)
|
| 500 |
+
push_count += 1
|
| 501 |
+
writer.add_scalar("step/anchors_moved", moved, gs)
|
| 502 |
+
|
| 503 |
preds = out1['logits'].argmax(-1)
|
| 504 |
correct += (preds == targets).sum().item()
|
| 505 |
total += targets.shape[0]
|
| 506 |
tot_loss += loss.item()
|
| 507 |
tot_nce_acc += ld.get('nce_acc', 0)
|
| 508 |
+
tot_nearest_cos += ld.get('nearest_cos', 0)
|
| 509 |
n += 1
|
| 510 |
|
| 511 |
if n % 10 == 0:
|
|
|
|
| 513 |
loss=f"{tot_loss/n:.4f}",
|
| 514 |
acc=f"{100*correct/total:.0f}%",
|
| 515 |
nce=f"{tot_nce_acc/n:.2f}",
|
| 516 |
+
cos=f"{ld.get('nearest_cos', 0):.3f}",
|
| 517 |
+
push=push_count,
|
| 518 |
ordered=True)
|
| 519 |
|
| 520 |
elapsed = time.time() - t0
|
|
|
|
| 563 |
writer.add_scalar("epoch/val_acc", val_acc, epoch+1)
|
| 564 |
writer.add_scalar("epoch/val_cv", v_cv, epoch+1)
|
| 565 |
writer.add_scalar("epoch/anchors", n_active, epoch+1)
|
| 566 |
+
writer.add_scalar("epoch/nearest_cos", tot_nearest_cos / n, epoch+1)
|
| 567 |
+
writer.add_scalar("epoch/push_count", push_count, epoch+1)
|
| 568 |
|
| 569 |
mk = ""
|
| 570 |
if val_acc > best_acc:
|
|
|
|
| 578 |
mk = " β
"
|
| 579 |
|
| 580 |
nce_m = tot_nce_acc / n
|
| 581 |
+
cos_m = tot_nearest_cos / n
|
| 582 |
cv_band = "β" if 0.18 <= v_cv <= 0.25 else "β"
|
| 583 |
print(f" E{epoch+1:3d}: train={train_acc:.1f}% val={val_acc:.1f}% "
|
| 584 |
+
f"loss={tot_loss/n:.4f} nce={nce_m:.2f} cos={cos_m:.3f} "
|
| 585 |
f"cv={v_cv:.4f}({cv_band}) anch={n_active}/{N_ANCHORS} "
|
| 586 |
+
f"push={push_count} ({elapsed:.0f}s){mk}")
|
| 587 |
|
| 588 |
writer.close()
|
| 589 |
print(f"\n Best val accuracy: {best_acc:.1f}%")
|