AbstractPhil commited on
Commit
73183fe
·
verified ·
1 Parent(s): 7e1d149

Create reusable_losses.py

Browse files
Files changed (1) hide show
  1. reusable_losses.py +126 -0
reusable_losses.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batched Pentachoron CV — Fast Geometric Volume Measurement
3
+ ============================================================
4
+ Replaces the sequential Python loop with fully batched operations.
5
+ One torch.linalg.det call on (n_samples, 6, 6) tensor.
6
+
7
+ Usage:
8
+ from reusable_losses import cv_metric, cv_loss
9
+
10
+ # Non-differentiable monitoring (fast)
11
+ cv_value = cv_metric(embeddings, n_samples=200)
12
+
13
+ # Differentiable loss (fast, for training)
14
+ loss = cv_loss(embeddings, target=0.22, n_samples=64)
15
+ """
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import math
20
+
21
+
22
+ def _batch_pentachoron_volumes(emb, n_samples=200, n_points=5):
23
+ """Compute pentachoron volumes in one batched operation.
24
+
25
+ Args:
26
+ emb: (N, D) L2-normalized embeddings on S^(d-1)
27
+ n_samples: number of random pentachora to sample
28
+ n_points: points per simplex (5 = pentachoron)
29
+
30
+ Returns:
31
+ volumes: (n_valid,) tensor of simplex volumes (may be < n_samples if some degenerate)
32
+ """
33
+ N, D = emb.shape
34
+ device = emb.device
35
+ dtype = emb.dtype
36
+
37
+ # Sample all pentachora indices at once: (n_samples, n_points)
38
+ # Batched randperm via argsort on random values
39
+ pool = min(N, 512)
40
+ rand_keys = torch.rand(n_samples, pool, device=device)
41
+ indices = rand_keys.argsort(dim=1)[:, :n_points] # (n_samples, n_points)
42
+
43
+ # Gather points: (n_samples, n_points, D)
44
+ pts = emb[:pool][indices] # advanced indexing
45
+
46
+ # Gram matrices: (n_samples, n_points, n_points)
47
+ gram = torch.bmm(pts, pts.transpose(1, 2))
48
+
49
+ # Squared distance matrices: d2[i,j] = ||p_i - p_j||^2
50
+ norms = torch.diagonal(gram, dim1=1, dim2=2) # (n_samples, n_points)
51
+ d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram # (n_samples, n_points, n_points)
52
+ d2 = F.relu(d2) # numerical safety
53
+
54
+ # Build Cayley-Menger matrices: (n_samples, n_points+1, n_points+1)
55
+ M = n_points + 1
56
+ cm = torch.zeros(n_samples, M, M, device=device, dtype=dtype)
57
+ cm[:, 0, 1:] = 1.0
58
+ cm[:, 1:, 0] = 1.0
59
+ cm[:, 1:, 1:] = d2
60
+
61
+ # Prefactor for volume from CM determinant
62
+ k = n_points - 1 # dimension of simplex
63
+ pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
64
+
65
+ # Batched determinant — the one expensive call, fully parallel
66
+ dets = pf * torch.linalg.det(cm.float()) # (n_samples,)
67
+
68
+ # Filter valid (positive volume squared) and take sqrt
69
+ valid_mask = dets > 1e-20
70
+ volumes = dets[valid_mask].to(dtype).sqrt()
71
+
72
+ return volumes
73
+
74
+
75
def cv_metric(emb, n_samples=200, n_points=5):
    """Non-differentiable coefficient of variation for monitoring.

    Target band: 0.20-0.23.

    Args:
        emb: (N, D) embeddings (expected L2-normalized by the caller).
        n_samples: simplices to sample (200 is robust, 100 is fast).
        n_points: points per simplex (5 = pentachoron).

    Returns:
        float: std/mean of sampled simplex volumes, or 0.0 when fewer
        than 10 valid volumes were obtained.
    """
    with torch.no_grad():
        volumes = _batch_pentachoron_volumes(
            emb, n_samples=n_samples, n_points=n_points
        )
        if volumes.numel() < 10:
            return 0.0
        cv = volumes.std() / (volumes.mean() + 1e-8)
        return cv.item()
91
+
92
+
93
def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
    """Differentiable CV loss for training. Suggested weight: 0.01 or below.

    Args:
        emb: (N, D) L2-normalized embeddings.
        target: desired coefficient of variation.
        n_samples: simplices to sample (32-64 for training).
        n_points: points per simplex.

    Returns:
        scalar tensor: (CV - target)^2, differentiable w.r.t. emb.
    """
    volumes = _batch_pentachoron_volumes(
        emb, n_samples=n_samples, n_points=n_points
    )
    if volumes.shape[0] < 5:
        # Too few valid simplices: emit a zero that still supports backward().
        return torch.tensor(0.0, device=emb.device, requires_grad=True)
    cv = volumes.std() / (volumes.mean() + 1e-8)
    deviation = cv - target
    return deviation * deviation
110
+
111
+
112
def cv_multi_scale(emb, scales=(3, 4, 5, 6, 7, 8), n_samples=100):
    """CV measured at several simplex sizes. Returns {n_points: cv_value}.

    Useful for diagnosing whether the geometry is scale-invariant.
    Target: every scale in [0.18, 0.25] for healthy geometry.

    Args:
        emb: (N, D) embeddings (expected L2-normalized by the caller).
        scales: simplex sizes (n_points values) to evaluate.
        n_samples: simplices sampled per scale.

    Returns:
        dict mapping each scale to its CV rounded to 4 decimals, or None
        when fewer than 10 valid volumes were obtained at that scale.
    """
    out = {}
    with torch.no_grad():
        for size in scales:
            vols = _batch_pentachoron_volumes(
                emb, n_samples=n_samples, n_points=size
            )
            if vols.shape[0] < 10:
                out[size] = None
            else:
                cv = (vols.std() / (vols.mean() + 1e-8)).item()
                out[size] = round(cv, 4)
    return out