Update modeling_core.py
Browse files- modeling_core.py +89 -4
modeling_core.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
GeoLIP Core — Back to Basics
|
| 3 |
==============================
|
|
@@ -222,25 +223,110 @@ class GeoLIPCore(nn.Module):
|
|
| 222 |
ld['nce'] = l_nce
|
| 223 |
ld['nce_acc'] = nce_acc
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
# CV
|
| 226 |
l_cv = self._cv_loss(emb)
|
| 227 |
ld['cv'] = l_cv
|
| 228 |
|
| 229 |
# Anchor spread
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
mask = ~torch.eye(an.shape[0], dtype=torch.bool, device=an.device)
|
| 233 |
l_spread = F.relu(sim_a[mask]).mean()
|
| 234 |
ld['spread'] = l_spread
|
| 235 |
|
| 236 |
# Total
|
| 237 |
loss = (l_ce
|
| 238 |
+ ld.get('nce', 0.0) * 1.0
|
|
|
|
| 239 |
+ l_cv * 0.01
|
| 240 |
+ l_spread * 0.001)
|
| 241 |
ld['total'] = loss
|
| 242 |
return loss, ld
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
def _cv_loss(self, emb, n_samples=64, n_points=5):
|
| 245 |
B = emb.shape[0]
|
| 246 |
if B < n_points: return torch.tensor(0.0, device=emb.device)
|
|
@@ -265,4 +351,3 @@ class GeoLIPCore(nn.Module):
|
|
| 265 |
vt = torch.stack(vols)
|
| 266 |
cv = vt.std() / (vt.mean() + 1e-8)
|
| 267 |
return (cv - self.cv_target).pow(2)
|
| 268 |
-
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
GeoLIP Core — Back to Basics
|
| 4 |
==============================
|
|
|
|
| 223 |
ld['nce'] = l_nce
|
| 224 |
ld['nce_acc'] = nce_acc
|
| 225 |
|
| 226 |
+
# ── Anchor attraction: pull each embedding toward its nearest anchor ──
|
| 227 |
+
anchors_n = F.normalize(self.constellation.anchors, dim=-1)
|
| 228 |
+
cos_to_anchors = emb @ anchors_n.T # (B, n_anchors)
|
| 229 |
+
nearest_cos = cos_to_anchors.max(dim=1).values # (B,)
|
| 230 |
+
l_attract = (1.0 - nearest_cos).mean() # 0 when on top of anchor
|
| 231 |
+
ld['attract'] = l_attract
|
| 232 |
+
ld['nearest_cos'] = nearest_cos.mean().item()
|
| 233 |
+
|
| 234 |
# CV
|
| 235 |
l_cv = self._cv_loss(emb)
|
| 236 |
ld['cv'] = l_cv
|
| 237 |
|
| 238 |
# Anchor spread
|
| 239 |
+
sim_a = anchors_n @ anchors_n.T
|
| 240 |
+
mask = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=anchors_n.device)
|
|
|
|
| 241 |
l_spread = F.relu(sim_a[mask]).mean()
|
| 242 |
ld['spread'] = l_spread
|
| 243 |
|
| 244 |
# Total
|
| 245 |
loss = (l_ce
|
| 246 |
+ ld.get('nce', 0.0) * 1.0
|
| 247 |
+
+ l_attract * 0.5
|
| 248 |
+ l_cv * 0.01
|
| 249 |
+ l_spread * 0.001)
|
| 250 |
ld['total'] = loss
|
| 251 |
return loss, ld
|
| 252 |
|
| 253 |
+
@torch.no_grad()
def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
    """Nudge constellation anchors toward per-CLASS embedding centroids.

    Unlike a nearest-anchor (k-means style) update, this uses the labels,
    so it works even when all anchors start bunched together.

    Phase 1: build one unit-norm centroid per class present in the buffer.
    Phase 2: greedily assign anchors to classes by descending cosine
             affinity, capping each class at ``n_anchors // n_classes + 1``
             (the ``+1`` absorbs the division remainder); leftovers snap
             to their nearest centroid.
    Phase 3: blend every anchor toward its class centroid by ``lr`` and
             re-project onto the unit sphere, adding a small tangential
             jitter to co-class anchors so they do not collapse onto
             a single point.

    Args:
        emb_buffer:   (B, D) embedding buffer; normalized internally.
        label_buffer: (B,) integer class labels aligned with ``emb_buffer``.
        lr:           blend factor toward the target centroid.

    Returns:
        int: number of anchors moved (0 when no centroid could be formed).
    """
    anchor_data = self.constellation.anchors.data   # (A, D), edited in place
    num_anchors = anchor_data.shape[0]
    dev = anchor_data.device
    unit_emb = F.normalize(emb_buffer, dim=-1)

    # ── Phase 1: one unit-length centroid per class in the buffer ──
    present = label_buffer.unique()
    class_means = [
        F.normalize(unit_emb[label_buffer == cls].mean(0, keepdim=True), dim=-1)
        for cls in present
        if (label_buffer == cls).sum() > 0
    ]
    if not class_means:
        return 0
    centroids = torch.cat(class_means, dim=0)       # (C, D)
    num_classes = present.shape[0]

    # ── Phase 2: greedy anchor→class assignment, highest cosine first ──
    unit_anchors = F.normalize(anchor_data, dim=-1)
    affinity = unit_anchors @ centroids.T           # (A, C)
    quota = num_anchors // num_classes              # base anchors per class
    owner = torch.full((num_anchors,), -1, dtype=torch.long, device=dev)
    taken = torch.zeros(num_classes, dtype=torch.long, device=dev)

    _, order = affinity.flatten().sort(descending=True)
    for flat in order:
        a_idx, c_idx = divmod(flat.item(), num_classes)
        if owner[a_idx] >= 0:
            continue                                # anchor already claimed
        if taken[c_idx] >= quota + 1:               # +1 for the remainder
            continue
        owner[a_idx] = c_idx
        taken[c_idx] += 1
        if (owner >= 0).all():
            break

    # Any anchor still unowned falls back to its nearest centroid.
    leftovers = (owner < 0).nonzero(as_tuple=True)[0]
    if len(leftovers) > 0:
        owner[leftovers] = (unit_anchors[leftovers] @ centroids.T).argmax(dim=1)

    # ── Phase 3: blend each anchor toward its centroid, stay on sphere ──
    for a_idx in range(num_anchors):
        c_idx = owner[a_idx].item()
        goal = centroids[c_idx]
        # Anchors after the first in a class get a tangential jitter so
        # co-class anchors spread instead of stacking on the centroid.
        rank_in_class = (owner[:a_idx] == c_idx).sum().item()
        if quota > 1 and rank_in_class > 0:
            jitter = torch.randn_like(goal) * 0.05
            jitter = jitter - (jitter * goal).sum() * goal  # drop radial part
            goal = F.normalize((goal + jitter).unsqueeze(0), dim=-1).squeeze(0)

        blended = unit_anchors[a_idx] + lr * (goal - unit_anchors[a_idx])
        anchor_data[a_idx] = F.normalize(blended.unsqueeze(0), dim=-1).squeeze(0)

    # Every anchor is blended exactly once, so all of them count as moved.
    return num_anchors
|
| 329 |
+
|
| 330 |
def _cv_loss(self, emb, n_samples=64, n_points=5):
|
| 331 |
B = emb.shape[0]
|
| 332 |
if B < n_points: return torch.tensor(0.0, device=emb.device)
|
|
|
|
| 351 |
vt = torch.stack(vols)
|
| 352 |
cv = vt.std() / (vt.mean() + 1e-8)
|
| 353 |
return (cv - self.cv_target).pow(2)
|
|
|