AbstractPhil committed on
Commit
095ed4f
·
verified ·
1 Parent(s): e255399

Update modeling_core.py

Browse files
Files changed (1) hide show
  1. modeling_core.py +89 -4
modeling_core.py CHANGED
@@ -1,3 +1,4 @@
 
1
  """
2
  GeoLIP Core — Back to Basics
3
  ==============================
@@ -222,25 +223,110 @@ class GeoLIPCore(nn.Module):
222
  ld['nce'] = l_nce
223
  ld['nce_acc'] = nce_acc
224
 
 
 
 
 
 
 
 
 
225
  # CV
226
  l_cv = self._cv_loss(emb)
227
  ld['cv'] = l_cv
228
 
229
  # Anchor spread
230
- an = F.normalize(self.constellation.anchors, dim=-1)
231
- sim_a = an @ an.T
232
- mask = ~torch.eye(an.shape[0], dtype=torch.bool, device=an.device)
233
  l_spread = F.relu(sim_a[mask]).mean()
234
  ld['spread'] = l_spread
235
 
236
  # Total
237
  loss = (l_ce
238
  + ld.get('nce', 0.0) * 1.0
 
239
  + l_cv * 0.01
240
  + l_spread * 0.001)
241
  ld['total'] = loss
242
  return loss, ld
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def _cv_loss(self, emb, n_samples=64, n_points=5):
245
  B = emb.shape[0]
246
  if B < n_points: return torch.tensor(0.0, device=emb.device)
@@ -265,4 +351,3 @@ class GeoLIPCore(nn.Module):
265
  vt = torch.stack(vols)
266
  cv = vt.std() / (vt.mean() + 1e-8)
267
  return (cv - self.cv_target).pow(2)
268
-
 
1
+ #!/usr/bin/env python3
2
  """
3
  GeoLIP Core — Back to Basics
4
  ==============================
 
223
  ld['nce'] = l_nce
224
  ld['nce_acc'] = nce_acc
225
 
226
+ # ── Anchor attraction: pull each embedding toward its nearest anchor ──
227
+ anchors_n = F.normalize(self.constellation.anchors, dim=-1)
228
+ cos_to_anchors = emb @ anchors_n.T # (B, n_anchors)
229
+ nearest_cos = cos_to_anchors.max(dim=1).values # (B,)
230
+ l_attract = (1.0 - nearest_cos).mean() # 0 when on top of anchor
231
+ ld['attract'] = l_attract
232
+ ld['nearest_cos'] = nearest_cos.mean().item()
233
+
234
  # CV
235
  l_cv = self._cv_loss(emb)
236
  ld['cv'] = l_cv
237
 
238
  # Anchor spread
239
+ sim_a = anchors_n @ anchors_n.T
240
+ mask = ~torch.eye(anchors_n.shape[0], dtype=torch.bool, device=anchors_n.device)
 
241
  l_spread = F.relu(sim_a[mask]).mean()
242
  ld['spread'] = l_spread
243
 
244
  # Total
245
  loss = (l_ce
246
  + ld.get('nce', 0.0) * 1.0
247
+ + l_attract * 0.5
248
  + l_cv * 0.01
249
  + l_spread * 0.001)
250
  ld['total'] = loss
251
  return loss, ld
252
 
253
@torch.no_grad()
def push_anchors_to_centroids(self, emb_buffer, label_buffer, lr=0.1):
    """Nudge constellation anchors toward label-derived class centroids.

    Runs in three phases:
      1. Build one unit-norm centroid per class present in ``label_buffer``.
      2. Greedily pair anchors with classes, highest anchor-centroid cosine
         first, capping each class at ``n_anchors // n_classes + 1`` anchors
         (the +1 absorbs the division remainder); anything still unpaired
         falls back to its nearest centroid.
      3. Blend each anchor toward its class centroid by factor ``lr`` and
         re-normalize.  Co-class anchors after the first get a small random
         tangential jitter so they do not all collapse onto one point.

    This works even when anchors start bunched near the origin.

    Args:
        emb_buffer: embedding buffer, shape (N, D); normalized internally.
        label_buffer: integer class label per embedding, shape (N,).
        lr: blend factor toward the (possibly jittered) centroid.

    Returns:
        int: number of anchors moved (0 when no centroid can be built,
        i.e. the buffers are empty).
    """
    anchor_data = self.constellation.anchors.data          # (A, D), updated in place
    num_anchors = anchor_data.shape[0]
    norm_emb = F.normalize(emb_buffer, dim=-1)
    dev = anchor_data.device

    # Phase 1: a unit-norm centroid for every class seen in the buffer.
    present = label_buffer.unique()
    num_classes = present.shape[0]
    cents = [
        F.normalize(norm_emb[label_buffer == cls].mean(0, keepdim=True), dim=-1)
        for cls in present
        if (label_buffer == cls).sum() > 0
    ]
    if not cents:
        return 0
    cents = torch.cat(cents, dim=0)                        # (C, D)

    # Phase 2: capacity-limited greedy matching of anchors to classes.
    norm_anchors = F.normalize(anchor_data, dim=-1)
    affinity = norm_anchors @ cents.T                      # (A, C)
    per_class = num_anchors // num_classes
    owner = torch.full((num_anchors,), -1, dtype=torch.long, device=dev)
    fill = torch.zeros(num_classes, dtype=torch.long, device=dev)

    # Walk (anchor, class) pairs from highest cosine down; claim the first
    # pairing whose anchor is free and whose class still has room.
    order = affinity.flatten().sort(descending=True).indices
    for flat in order:
        a_idx = (flat // num_classes).item()
        c_idx = (flat % num_classes).item()
        if owner[a_idx] >= 0 or fill[c_idx] >= per_class + 1:
            continue
        owner[a_idx] = c_idx
        fill[c_idx] += 1
        if (owner >= 0).all():
            break

    # Safety net: any still-unowned anchor snaps to its closest centroid.
    leftovers = (owner < 0).nonzero(as_tuple=True)[0]
    if len(leftovers) > 0:
        owner[leftovers] = (norm_anchors[leftovers] @ cents.T).argmax(dim=1)

    # Phase 3: blend every anchor toward its class centroid.
    n_moved = 0
    for a_idx in range(num_anchors):
        c_idx = owner[a_idx].item()
        goal = cents[c_idx]
        # Count earlier anchors sharing this class; later ones get jitter
        # so co-class anchors spread instead of stacking.
        slot = (owner[:a_idx] == c_idx).sum().item()
        if per_class > 1 and slot > 0:
            jitter = torch.randn_like(goal) * 0.05
            jitter = jitter - (jitter * goal).sum() * goal   # keep tangential only
            goal = F.normalize((goal + jitter).unsqueeze(0), dim=-1).squeeze(0)
        blended = norm_anchors[a_idx] + lr * (goal - norm_anchors[a_idx])
        anchor_data[a_idx] = F.normalize(blended.unsqueeze(0), dim=-1).squeeze(0)
        n_moved += 1

    return n_moved
329
+
330
  def _cv_loss(self, emb, n_samples=64, n_points=5):
331
  B = emb.shape[0]
332
  if B < n_points: return torch.tensor(0.0, device=emb.device)
 
351
  vt = torch.stack(vols)
352
  cv = vt.std() / (vt.mean() + 1e-8)
353
  return (cv - self.cv_target).pow(2)