AbstractPhil committed on
Commit
e93e5ed
·
verified ·
1 Parent(s): 2fc8ef5

Create geolip_loss.py

Browse files
Files changed (1) hide show
  1. geolip_loss.py +481 -0
geolip_loss.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GeoLIP Losses & Regularization
3
+ =================================
4
+ Every loss and metric in the GeoLIP pipeline, with uniform interfaces.
5
+
6
+ All loss functions: (inputs) → scalar tensor (differentiable)
7
+ All metrics: (inputs) → float (non-differentiable, for monitoring)
8
+
9
+ CV functions default to batched computation (141x speedup).
10
+ Set batched=False for sequential fallback.
11
+
12
+ Loss Spectrum (3 domains):
13
+ EXTERNAL: ce_loss, nce_loss (embedding-level)
14
+ GEOMETRIC: nce_loss (patchwork), bridge_loss
15
+ INTERNAL: assign_bce, assign_nce, nce_loss (triangulation),
16
+ attraction_loss, cv_loss, spread_loss
17
+
18
+ Metrics:
19
+ cv_metric, cv_multi_scale, cayley_menger_vol2
20
+
21
+ Compound:
22
+ three_domain_loss — the full cooperative loss from InternalConstellationCore
23
+
24
+ Usage:
25
+ from geolip_loss import cv_loss, cv_metric, nce_loss, three_domain_loss
26
+ """
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import math
32
+
33
+
34
+ # ══════════════════════════════════════════════════════════════════
35
+ # CV — Coefficient of Variation of Pentachoron Volumes
36
+ # ══════════════════════════════════════════════════════════════════
37
+
38
def _batch_pentachoron_volumes(emb, n_samples=200, n_points=5):
    """Compute simplex volumes for random point subsets in one batched pass.

    Only the first ``min(N, 512)`` rows of ``emb`` form the sampling pool.
    Every sample draws ``n_points`` distinct rows from that pool; the squared
    volume comes from the Cayley-Menger determinant of their pairwise squared
    distances. No Python-level loop is used.

    Args:
        emb: (N, D) embeddings.
        n_samples: number of random simplices to sample.
        n_points: points per simplex (5 = pentachoron).

    Returns:
        (n_valid,) tensor of simplex volumes in ``emb``'s dtype. Samples whose
        squared volume is not above 1e-20 are dropped. The result is empty
        when the pool holds fewer than ``n_points`` rows.
    """
    N, _ = emb.shape
    device, dtype = emb.device, emb.dtype
    pool = min(N, 512)

    # A simplex needs n_points distinct rows. Without enough of them the
    # index slice below comes back short and filling the Cayley-Menger
    # matrix fails on a shape mismatch, so report "no volumes" instead.
    if pool < n_points:
        return emb.new_zeros(0)

    # Batched randperm: argsort of uniform noise yields one permutation per row.
    indices = torch.rand(n_samples, pool, device=device).argsort(dim=1)[:, :n_points]
    pts = emb[:pool][indices]  # (n_samples, n_points, D)

    # Pairwise squared distances from the Gram matrix; relu clips tiny
    # negative values produced by rounding.
    gram = torch.bmm(pts, pts.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

    # Bordered (Cayley-Menger) matrix: zero corner, ones on the border,
    # squared distances in the interior.
    M = n_points + 1
    cm = torch.zeros(n_samples, M, M, device=device, dtype=dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    # V^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM) for a k-dimensional simplex.
    k = n_points - 1
    pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
    dets = pf * torch.linalg.det(cm.float())

    valid = dets > 1e-20
    return dets[valid].to(dtype).sqrt()
73
+
74
+
75
def _sequential_pentachoron_volumes(emb, n_samples=200, n_points=5):
    """Sequential fallback: one determinant call per sampled simplex.

    Each iteration draws ``n_points`` distinct row indices from the first
    ``min(N, 512)`` rows and evaluates the Cayley-Menger determinant.

    Args:
        emb: (N, D) embeddings.
        n_samples: number of random simplices to try.
        n_points: points per simplex (5 = pentachoron).

    Returns:
        (n_valid,) tensor of simplex volumes. An empty tensor is returned
        when there are fewer than ``n_points`` rows to draw from, or when
        fewer than 5 samples produced a squared volume above 1e-20.
    """
    N = emb.shape[0]
    device, dtype = emb.device, emb.dtype
    pool = min(N, 512)

    # Too few rows to form a single simplex: the distance block would not
    # fit the Cayley-Menger matrix, so bail out with the empty result.
    if pool < n_points:
        return torch.tensor([], device=device, dtype=dtype)

    # Loop-invariant sizes and the Cayley-Menger prefactor
    # (-1)^(k+1) / (2^k (k!)^2).
    M = n_points + 1
    k = n_points - 1
    pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))

    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(pool, device=device)[:n_points]
        pts = emb[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
        cm = torch.zeros(1, M, M, device=device, dtype=dtype)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        v2 = pf * torch.linalg.det(cm.float())
        # Keep only non-degenerate simplices.
        if v2[0].item() > 1e-20:
            vols.append(v2[0].to(dtype).sqrt())
    if len(vols) < 5:
        return torch.tensor([], device=device, dtype=dtype)
    return torch.stack(vols)
97
+
98
+
99
def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):
    """Differentiable CV loss: squared gap between the volume CV and ``target``.

    The coefficient of variation (std / mean) is taken over the volumes of
    randomly sampled simplices built from rows of ``emb``.

    Args:
        emb: (N, D) embeddings (the author documents them as L2-normalized).
        target: CV value to regress toward (author's note: 0.22 is the
            natural basin of S^(d-1) at eff_dim ~16).
        n_samples: number of simplices to sample (32-64 suggested for training).
        n_points: points per simplex.
        batched: choose the batched sampler (True) or the sequential one.

    Returns:
        Scalar tensor. A fresh zero tensor with ``requires_grad=True`` is
        returned when there are fewer than ``n_points`` rows or fewer than
        5 valid volumes.
    """
    def _zero():
        return torch.tensor(0.0, device=emb.device, requires_grad=True)

    if emb.shape[0] < n_points:
        return _zero()

    sampler = _batch_pentachoron_volumes if batched else _sequential_pentachoron_volumes
    volumes = sampler(emb, n_samples, n_points)
    if volumes.shape[0] < 5:
        return _zero()

    # 1e-8 keeps the ratio finite when the mean volume is ~0.
    coeff_var = volumes.std() / (volumes.mean() + 1e-8)
    return (coeff_var - target).pow(2)
124
+
125
+
126
def cv_metric(emb, n_samples=200, n_points=5, batched=True):
    """Non-differentiable CV for monitoring. Target band: 0.20-0.23.

    Args:
        emb: (N, D) embeddings.
        n_samples: number of simplices to sample.
        n_points: points per simplex.
        batched: choose the batched sampler (True) or the sequential one.

    Returns:
        float: coefficient of variation (std / mean) of the sampled simplex
        volumes, or 0.0 when there are fewer than ``n_points`` rows or fewer
        than 10 valid volumes.
    """
    with torch.no_grad():
        # Same small-input guard as cv_loss: a simplex cannot be formed
        # from fewer than n_points rows, and the samplers would fail.
        if emb.shape[0] < n_points:
            return 0.0
        if batched:
            vols = _batch_pentachoron_volumes(emb, n_samples, n_points)
        else:
            vols = _sequential_pentachoron_volumes(emb, n_samples, n_points)
        if vols.shape[0] < 10:
            return 0.0
        return (vols.std() / (vols.mean() + 1e-8)).item()
140
+
141
+
142
def cv_multi_scale(emb, scales=(3, 4, 5, 6, 7, 8), n_samples=100, batched=True):
    """CV at multiple simplex sizes. Returns dict: {n_points: cv_value}.

    Healthy geometry (per the author): all scales in [0.18, 0.25].

    Args:
        emb: (N, D) embeddings.
        scales: iterable of simplex sizes (points per simplex) to evaluate.
        n_samples: number of simplices to sample per scale.
        batched: choose the batched sampler (True) or the sequential one.

    Returns:
        dict mapping each scale to its CV rounded to 4 decimals, or to None
        when the scale needs more rows than ``emb`` has or fewer than 10
        valid volumes were obtained.
    """
    results = {}
    n_rows = emb.shape[0]
    with torch.no_grad():
        for n_pts in scales:
            # A scale larger than the number of rows cannot be sampled;
            # record it as unavailable instead of letting the sampler fail.
            if n_rows < n_pts:
                results[n_pts] = None
                continue
            if batched:
                vols = _batch_pentachoron_volumes(emb, n_samples, n_pts)
            else:
                vols = _sequential_pentachoron_volumes(emb, n_samples, n_pts)
            if vols.shape[0] >= 10:
                results[n_pts] = round((vols.std() / (vols.mean() + 1e-8)).item(), 4)
            else:
                results[n_pts] = None
    return results
159
+
160
+
161
def cayley_menger_vol2(points):
    """Squared simplex volume via the Cayley-Menger determinant.

    Args:
        points: (B, N, D) tensor, N vertices per simplex.

    Returns:
        (B,) tensor of squared volumes in ``points``' dtype. The determinant
        itself is evaluated on a float32 copy of the matrix.
    """
    batch, n_vertices, _ = points.shape

    # Pairwise squared distances from the Gram matrix (clipped at zero).
    inner = torch.bmm(points, points.transpose(1, 2))
    sq_norms = torch.diagonal(inner, dim1=1, dim2=2)
    sq_dist = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * inner)

    # Bordered matrix: zero corner, ones along first row/column.
    size = n_vertices + 1
    bordered = torch.zeros(batch, size, size, device=points.device, dtype=points.dtype)
    bordered[:, 0, 1:] = 1
    bordered[:, 1:, 0] = 1
    bordered[:, 1:, 1:] = sq_dist

    k = n_vertices - 1
    sign = (-1.0) ** (k + 1)
    denom = (2 ** k) * (math.factorial(k) ** 2)
    det = torch.linalg.det(bordered.float()).to(points.dtype)
    return sign * det / denom
173
+
174
+
175
+ # ══════════════════════════════════════════════════════════════════
176
+ # NCE — InfoNCE contrastive loss
177
+ # ══════════════════════════════════════════════════════════════════
178
+
179
def nce_loss(z1, z2, temperature=0.07, normalize=True, symmetric=False):
    """InfoNCE between two views, with row i of z1 paired to row i of z2.

    By default the loss is one-directional: each z1 row must pick out its
    partner among all z2 rows. Pass ``symmetric=True`` to average that with
    the reverse direction (z2 rows picking among z1 rows).

    Args:
        z1, z2: (B, D) embeddings from two augmented views
        temperature: softmax temperature (lower = sharper)
        normalize: L2-normalize before computing similarity
        symmetric: average the z1->z2 and z2->z1 cross-entropies
            (default False keeps the one-directional loss)

    Returns:
        scalar loss, float accuracy (z1->z2 retrieval accuracy)
    """
    if normalize:
        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)
    B = z1.shape[0]
    # Positive pairs sit on the diagonal of the similarity matrix.
    labels = torch.arange(B, device=z1.device)
    sim = z1 @ z2.T / temperature
    loss = F.cross_entropy(sim, labels)
    if symmetric:
        loss = 0.5 * (loss + F.cross_entropy(sim.T, labels))
    acc = (sim.argmax(1) == labels).float().mean().item()
    return loss, acc
199
+
200
+
201
+ # ══════════════════════════════════════════════════════════════════
202
+ # CLASSIFICATION
203
+ # ══════════════════════════════════════════════════════════════════
204
+
205
def ce_loss(logits, targets):
    """Cross-entropy classification loss plus top-1 accuracy.

    Args:
        logits: (B, C) raw logits
        targets: (B,) class indices

    Returns:
        scalar loss, float accuracy
    """
    predicted = logits.argmax(-1)
    hit_rate = (predicted == targets).float().mean().item()
    return F.cross_entropy(logits, targets), hit_rate
218
+
219
+
220
def ce_loss_paired(logits1, logits2, targets):
    """Cross-entropy averaged over two views that share the same targets.

    Args:
        logits1, logits2: raw logits for view 1 and view 2
        targets: class indices shared by both views

    Returns:
        scalar loss, float accuracy (from view 1)
    """
    view_losses = [F.cross_entropy(view, targets) for view in (logits1, logits2)]
    view1_correct = logits1.argmax(-1) == targets
    acc = view1_correct.float().mean().item()
    return (view_losses[0] + view_losses[1]) / 2, acc
230
+
231
+
232
+ # ══════════════════════════════════════════════════════════════════
233
+ # BRIDGE — patchwork predicts constellation's assignment
234
+ # ══════════════════════════════════════════════════════════════════
235
+
236
def bridge_loss(bridge_logits, assign_targets, detach_targets=True):
    """Soft cross-entropy of bridge logits against a soft assignment.

    One-way teaching (constellation -> patchwork): with ``detach_targets``
    the targets are cut from the graph, so no gradient reaches whatever
    produced them through this loss.

    Args:
        bridge_logits: (B, A) raw logits from bridge head
        assign_targets: (B, A) soft assignment from constellation
        detach_targets: detach targets from graph (default True)

    Returns:
        scalar loss, float accuracy (argmax agreement between the two)
    """
    targets = assign_targets.detach() if detach_targets else assign_targets
    log_probs = F.log_softmax(bridge_logits, dim=-1)
    per_sample = (targets * log_probs).sum(-1)
    loss = -per_sample.mean()
    agree = bridge_logits.argmax(-1) == targets.argmax(-1)
    return loss, agree.float().mean().item()
255
+
256
+
257
def bridge_loss_paired(bridge1, bridge2, assign1, assign2, detach_targets=True):
    """Bridge loss averaged over two views.

    Args:
        bridge1, bridge2: bridge logits for view 1 and view 2
        assign1, assign2: soft assignments for view 1 and view 2
        detach_targets: forwarded to bridge_loss for both views

    Returns:
        scalar loss, float accuracy (from view 1)
    """
    first_loss, first_acc = bridge_loss(bridge1, assign1, detach_targets)
    second_loss = bridge_loss(bridge2, assign2, detach_targets)[0]
    return (first_loss + second_loss) / 2, first_acc
266
+
267
+
268
+ # ══════════════════════════════════════════════════════════════════
269
+ # ASSIGNMENT — internal constellation self-organization
270
+ # ══════════════════════════════════════════════════════════════════
271
+
272
def assign_bce_loss(soft_assign, cos_to_anchors):
    """Assignment crispness: BCE toward a one-hot nearest-anchor target.

    The target row is one-hot at the anchor with the highest cosine.

    Args:
        soft_assign: (B, A) softmax assignment
        cos_to_anchors: (B, A) cosine similarities to anchors

    Returns:
        scalar loss, float entropy (mean per-row entropy of soft_assign)
    """
    winner = cos_to_anchors.argmax(dim=-1, keepdim=True)
    one_hot = torch.zeros_like(soft_assign).scatter_(1, winner, 1.0)

    # BCE is evaluated in float32 with CUDA autocast switched off; inputs
    # are clamped away from 0 and 1 so the logs stay finite.
    with torch.amp.autocast("cuda", enabled=False):
        probs = soft_assign.float().clamp(1e-7, 1 - 1e-7)
        loss = F.binary_cross_entropy(probs, one_hot.float(), reduction='mean')

    log_assign = soft_assign.clamp(min=1e-8).log()
    entropy = -(soft_assign * log_assign).sum(-1).mean().item()
    return loss, entropy
293
+
294
+
295
def assign_nce_loss(assign1, assign2, temperature=0.1):
    """Assignment consistency: NCE across two views.

    Row i of assign1 is treated as the positive for row i of assign2; the
    dot products (not re-normalized) divided by the temperature are the logits.

    Args:
        assign1, assign2: (B, A) soft assignments from two views
        temperature: softmax temperature

    Returns:
        scalar loss, float accuracy
    """
    logits = (assign1 @ assign2.T) / temperature
    diag = torch.arange(assign1.shape[0], device=assign1.device)
    matched = logits.argmax(1) == diag
    return F.cross_entropy(logits, diag), matched.float().mean().item()
311
+
312
+
313
+ # ══════════════════════════════════════════════════════════════════
314
+ # ATTRACTION β€” embeddings near their anchors
315
+ # ══════════════════════════════════════════════════════════════════
316
+
317
def attraction_loss(cos_to_anchors):
    """Pull embeddings toward their nearest anchor (highest cosine).

    Args:
        cos_to_anchors: (B, A) cosine similarities

    Returns:
        scalar loss (mean of 1 - best cosine per row), float mean nearest cosine
    """
    best, _ = cos_to_anchors.max(dim=1)
    return (1.0 - best).mean(), best.mean().item()
329
+
330
+
331
+ # ══════════════════════════════════════════════════════════════════
332
+ # SPREAD — anchor repulsion
333
+ # ══════════════════════════════════════════════════════════════════
334
+
335
def spread_loss(anchors, target_cos=0.0):
    """Repulsion loss keeping anchors spread on S^(d-1).

    Anchors are L2-normalized, and every off-diagonal pairwise cosine above
    ``target_cos`` is penalized linearly.

    Args:
        anchors: (A, D) anchor parameters
        target_cos: cosine threshold (0.0 = orthogonal target)

    Returns:
        scalar loss; exactly zero (still attached to ``anchors``) when there
        are fewer than two anchors and hence no pairs to repel
    """
    a = F.normalize(anchors, dim=-1)
    sim = a @ a.T
    # With fewer than two anchors the off-diagonal selection is empty and
    # its mean would be NaN, poisoning the total loss.
    if a.shape[0] < 2:
        return sim.sum() * 0.0
    mask = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)
    return F.relu(sim[mask] - target_cos).mean()
349
+
350
+
351
+ # ══════════════════════════════════════════════════════════════════
352
+ # kNN — non-differentiable validation metric
353
+ # ══════════════════════════════════════════════════════════════════
354
+
355
@torch.no_grad()
def knn_accuracy(embeddings, targets, k=1):
    """k-NN classification accuracy in embedding space (leave-one-out).

    Similarity is the raw dot product, so it equals cosine similarity only
    for L2-normalized embeddings. Each point's own entry is excluded.

    Args:
        embeddings: (N, D) floating-point embeddings, normally L2-normalized
        targets: (N,) class labels
        k: number of neighbors (1 for simple NN); clamped to N - 1

    Returns:
        float accuracy; 0.0 when there are fewer than two points (no
        neighbor exists to vote)

    Raises:
        ValueError: if ``k`` is less than 1
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    n = embeddings.shape[0]
    if n < 2:
        return 0.0

    sim = embeddings @ embeddings.T
    # Mask self-similarity with -inf rather than -1: for embeddings that are
    # not unit-norm, dot products can fall below -1 and a point would
    # otherwise be returned as its own nearest neighbor.
    sim.fill_diagonal_(float('-inf'))

    # Never request more neighbors than there are other points.
    k = min(k, n - 1)
    if k == 1:
        nn_idx = sim.argmax(dim=1)
        return (targets[nn_idx] == targets).float().mean().item()

    _, topk_idx = sim.topk(k, dim=1)
    nn_labels = targets[topk_idx]  # (N, k)
    # Majority vote
    pred = nn_labels.mode(dim=1).values
    return (pred == targets).float().mean().item()
378
+
379
+
380
+ # ══════════════════════════════════════════════════════════════════
381
+ # THREE-DOMAIN COMPOUND LOSS
382
+ # ══════════════════════════════════════════════════════════════════
383
+
384
def three_domain_loss(output, targets, constellation, cv_target=0.22,
                      infonce_temp=0.07, assign_temp=0.1,
                      w_ce=1.0, w_nce_emb=0.5,
                      w_nce_pw=1.0, w_bridge=1.0,
                      w_assign=0.5, w_assign_nce=0.25,
                      w_nce_tri=0.5, w_attract=0.25,
                      w_cv=0.01, w_spread=0.01,
                      cv_batched=True, tri_temp=0.1):
    """Full three-domain cooperative loss.

    EXTERNAL: CE + embedding NCE
    GEOMETRIC: patchwork NCE + bridge
    INTERNAL: assign BCE + assign NCE + tri NCE + attraction + CV + spread

    Args:
        output: dict from InternalConstellationCore.forward_paired(). Keys
            read here: 'embedding', 'embedding_aug', 'logits', 'logits_aug',
            'patchwork1', 'patchwork1_aug', 'bridge1', 'bridge2',
            'assign1', 'assign2', 'cos1', 'tri1', 'tri2'.
        targets: (B,) class labels
        constellation: Constellation module (its ``anchors`` attribute is used)
        cv_target: CV loss target
        infonce_temp: embedding NCE temperature
        assign_temp: assignment NCE / patchwork NCE temperature
        w_*: per-term weights
        cv_batched: use batched CV (default True)
        tri_temp: triangulation NCE temperature (default 0.1, the value
            that used to be hard-coded)

    Returns:
        total_loss: scalar tensor
        ld: dict with all per-term values and diagnostics. Un-prefixed loss
            entries ('ce', 'nce_emb', ..., 'total') are tensors; accuracy
            entries, 'loss_*' and 't_*' entries are Python floats.
    """
    ld = {}
    emb1, emb2 = output['embedding'], output['embedding_aug']

    # ── EXTERNAL ──
    l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)
    ld['ce'], ld['acc'] = l_ce, acc

    # Embeddings are passed through without re-normalization.
    l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, infonce_temp, normalize=False)
    ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc

    # ── GEOMETRIC ──
    l_nce_pw, nce_pw_acc = nce_loss(output['patchwork1'], output['patchwork1_aug'],
                                    assign_temp, normalize=True)
    ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc

    l_bridge, bridge_acc = bridge_loss_paired(
        output['bridge1'], output['bridge2'],
        output['assign1'], output['assign2'])
    ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc

    # ── INTERNAL ──
    # Crispness, attraction and CV are computed on view 1 only.
    l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])
    ld['assign'], ld['assign_entropy'] = l_assign, assign_ent

    l_assign_nce, assign_nce_acc = assign_nce_loss(
        output['assign1'], output['assign2'], assign_temp)
    ld['assign_nce'], ld['assign_nce_acc'] = l_assign_nce, assign_nce_acc

    l_nce_tri, nce_tri_acc = nce_loss(output['tri1'], output['tri2'], tri_temp,
                                      normalize=True)
    ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc

    l_attract, nearest_cos = attraction_loss(output['cos1'])
    ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos

    l_cv = cv_loss(emb1, target=cv_target, batched=cv_batched)
    ld['cv'] = l_cv

    l_spread = spread_loss(constellation.anchors)
    ld['spread'] = l_spread

    # ── kNN (non-differentiable) ──
    ld['knn_acc'] = knn_accuracy(emb1, targets)

    # ── TOTAL ──
    loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb
    loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge
    loss_internal = (w_assign * l_assign + w_assign_nce * l_assign_nce
                     + w_nce_tri * l_nce_tri + w_attract * l_attract
                     + w_cv * l_cv + w_spread * l_spread)

    loss = loss_external + loss_geometric + loss_internal

    ld['loss_external'] = loss_external.item()
    ld['loss_geometric'] = loss_geometric.item()
    ld['loss_internal'] = loss_internal.item()
    ld['total'] = loss

    # Per-term raw values for analysis
    ld['t_ce'] = l_ce.item()
    ld['t_nce_emb'] = l_nce_emb.item()
    ld['t_nce_pw'] = l_nce_pw.item()
    ld['t_bridge'] = l_bridge.item()
    ld['t_assign'] = l_assign.item()
    ld['t_assign_nce'] = l_assign_nce.item()
    ld['t_nce_tri'] = l_nce_tri.item()
    ld['t_attract'] = l_attract.item()

    return loss, ld