GreenGenomicsLab committed on
Commit
9fc25e6
·
verified ·
1 Parent(s): a609de0

Upload scripts/vicreg_loss.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/vicreg_loss.py +294 -0
scripts/vicreg_loss.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ VICReg Loss Function for Joint Embedding Learning.
4
+
5
+ Implements the Variance-Invariance-Covariance Regularization loss from:
6
+ Bardes, Ponce & LeCun, "VICReg: Variance-Invariance-Covariance
7
+ Regularization for Self-Supervised Learning", ICLR 2022.
8
+
9
+ Three terms:
10
+ 1. Invariance: MSE between paired embeddings (push co-located pairs together)
11
+ 2. Variance: Hinge loss on per-dimension std dev (prevent collapse)
12
+ 3. Covariance: Penalize off-diagonal covariance (decorrelate dimensions)
13
+
14
+ Usage:
15
+ loss_fn = VICRegLoss(lambda_inv=25.0, lambda_var=25.0, lambda_cov=1.0)
16
+ total_loss, components = loss_fn(z_a, z_b)
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+
23
class VICRegLoss(nn.Module):
    """VICReg: Variance-Invariance-Covariance Regularization Loss.

    Parameters
    ----------
    lambda_inv : float
        Weight for invariance term (MSE between paired embeddings).
    lambda_var : float
        Weight for variance term (hinge loss on per-dimension std dev).
    lambda_cov : float
        Weight for covariance term (off-diagonal covariance penalty).
    gamma : float
        Target standard deviation for variance hinge (default 1.0).
    eps : float
        Constant added to the per-dimension variance before taking the
        square root in the variance term (default 1e-4).
    """

    def __init__(self, lambda_inv=25.0, lambda_var=25.0, lambda_cov=1.0,
                 gamma=1.0, eps=1e-4):
        super().__init__()
        self.lambda_inv = lambda_inv
        self.lambda_var = lambda_var
        self.lambda_cov = lambda_cov
        self.gamma = gamma
        self.eps = eps

    def invariance_loss(self, z_a: torch.Tensor,
                        z_b: torch.Tensor) -> torch.Tensor:
        """MSE between paired embeddings.

        Parameters
        ----------
        z_a, z_b : torch.Tensor, shape (N, D)
            Paired embedding vectors.

        Returns
        -------
        torch.Tensor
            Scalar invariance loss.
        """
        return nn.functional.mse_loss(z_a, z_b)

    def variance_loss(self, z: torch.Tensor) -> torch.Tensor:
        """Hinge loss on per-dimension standard deviation.

        Encourages each dimension to have std >= gamma, preventing
        embedding collapse where all points map to the same vector.

        Parameters
        ----------
        z : torch.Tensor, shape (N, D)
            Embedding matrix (single modality).

        Returns
        -------
        torch.Tensor
            Scalar variance loss.
        """
        # eps keeps the sqrt (and its gradient) finite when a dimension
        # has zero variance.
        std_z = torch.sqrt(z.var(dim=0) + self.eps)
        # Hinge: only dimensions whose std falls below gamma contribute.
        return torch.mean(torch.relu(self.gamma - std_z))

    def covariance_loss(self, z: torch.Tensor) -> torch.Tensor:
        """Off-diagonal covariance penalty.

        Decorrelates embedding dimensions by penalizing off-diagonal
        elements of the covariance matrix.

        Parameters
        ----------
        z : torch.Tensor, shape (N, D)
            Embedding matrix (single modality).

        Returns
        -------
        torch.Tensor
            Scalar covariance loss.
        """
        N, D = z.shape
        # Center the embeddings along the batch dimension.
        z_centered = z - z.mean(dim=0)
        # Covariance matrix, normalized by N - 1.
        cov = (z_centered.T @ z_centered) / (N - 1)
        # Zero out the diagonal: only cross-dimension terms are penalized.
        cov_offdiag = cov - torch.diag(cov.diag())
        # Sum of squared off-diagonal elements, normalized by D.
        return (cov_offdiag ** 2).sum() / D

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor):
        """Compute total VICReg loss.

        Parameters
        ----------
        z_a : torch.Tensor, shape (N, D)
            Embeddings from modality A (e.g., environment encoder).
        z_b : torch.Tensor, shape (N, D)
            Embeddings from modality B (e.g., PFAM module encoder).

        Returns
        -------
        total_loss : torch.Tensor
            Weighted sum of invariance, variance, and covariance terms.
        components : dict
            Individual loss components for logging:
            - 'invariance': float
            - 'variance_a': float (variance loss for z_a)
            - 'variance_b': float (variance loss for z_b)
            - 'covariance_a': float (covariance loss for z_a)
            - 'covariance_b': float (covariance loss for z_b)
            - 'total': float

        Raises
        ------
        ValueError
            If the shapes differ, the inputs are not 2-D, or the batch
            size is smaller than 2.
        """
        # Input validation
        if z_a.shape != z_b.shape:
            raise ValueError(
                f"Shape mismatch: z_a {z_a.shape} vs z_b {z_b.shape}"
            )
        # Shapes are equal at this point, so checking z_a covers both.
        # Without this, non-2-D input fails later with an opaque
        # "values to unpack" error inside covariance_loss.
        if z_a.dim() != 2:
            raise ValueError(
                f"Expected 2-D (N, D) embeddings, got shape {tuple(z_a.shape)}"
            )
        if z_a.shape[0] < 2:
            raise ValueError(
                f"Batch size must be >= 2, got {z_a.shape[0]}"
            )

        # Compute individual terms
        inv_loss = self.invariance_loss(z_a, z_b)
        var_loss_a = self.variance_loss(z_a)
        var_loss_b = self.variance_loss(z_b)
        cov_loss_a = self.covariance_loss(z_a)
        cov_loss_b = self.covariance_loss(z_b)

        # Combine: variance and covariance applied to BOTH modalities
        total = (self.lambda_inv * inv_loss
                 + self.lambda_var * (var_loss_a + var_loss_b)
                 + self.lambda_cov * (cov_loss_a + cov_loss_b))

        # Move all logging scalars to the host in one transfer instead of
        # one .item() call per value.
        keys = ('invariance', 'variance_a', 'variance_b',
                'covariance_a', 'covariance_b', 'total')
        terms = (inv_loss, var_loss_a, var_loss_b,
                 cov_loss_a, cov_loss_b, total)
        values = torch.stack([term.detach() for term in terms]).tolist()
        components = dict(zip(keys, values))

        return total, components
163
+
164
+
165
def self_test():
    """Run self-tests for VICReg loss module. Returns True if all pass.

    Random inputs are drawn from a locally seeded generator, so repeated
    runs see the same data and the global RNG state is not modified.
    Results are printed to stdout as PASS/FAIL lines followed by a summary.
    """
    tests_passed = 0
    tests_total = 0

    def check(name, condition):
        # Record one test outcome and print it.
        nonlocal tests_passed, tests_total
        tests_total += 1
        if condition:
            tests_passed += 1
            print(f"  PASS: {name}")
        else:
            print(f"  FAIL: {name}")

    print("=" * 60)
    print("VICReg Loss Self-Tests")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}\n")

    # Local generator: reproducible inputs without reseeding global RNG.
    gen = torch.Generator(device=device)
    gen.manual_seed(0)

    def randn(*shape, **kwargs):
        # Seeded standard-normal tensor on the test device.
        return torch.randn(*shape, device=device, generator=gen, **kwargs)

    def grad_is_finite(t):
        # Guard against a missing gradient so the check reports FAIL
        # instead of raising inside torch.isnan.
        return t.grad is not None and not torch.isnan(t.grad).any()

    loss_fn = VICRegLoss(lambda_inv=25.0, lambda_var=25.0, lambda_cov=1.0)

    # Test 1: Gradient flow
    print("Test 1: Gradient flow")
    z_a = randn(64, 16, requires_grad=True)
    z_b = randn(64, 16, requires_grad=True)
    total, comp = loss_fn(z_a, z_b)
    total.backward()
    check("gradients computed for z_a", z_a.grad is not None)
    check("gradients computed for z_b", z_b.grad is not None)
    check("no NaN in z_a grad", grad_is_finite(z_a))
    check("no NaN in z_b grad", grad_is_finite(z_b))
    check("all components present",
          all(k in comp for k in ['invariance', 'variance_a', 'variance_b',
                                  'covariance_a', 'covariance_b', 'total']))

    # Test 2: Invariance = 0 for identical embeddings
    print("\nTest 2: Invariance = 0 for identical embeddings")
    z_same = randn(32, 16)
    inv = loss_fn.invariance_loss(z_same, z_same)
    check("invariance is zero", inv.item() < 1e-7)

    # Test 3: Variance = 0 when std >= gamma
    print("\nTest 3: Variance = 0 when std >= gamma")
    z_spread = randn(1000, 16) * 2.0  # std ~2.0 >> gamma=1.0
    var_loss = loss_fn.variance_loss(z_spread)
    check("variance is zero for high-spread embeddings", var_loss.item() < 1e-4)

    # Test 4: Variance > 0 for collapsed embeddings
    print("\nTest 4: Variance > 0 for collapsed embeddings")
    z_collapsed = torch.ones(32, 16, device=device) * 0.5  # constant -> std=0
    # Add tiny noise so std is very small but not exactly zero
    z_collapsed = z_collapsed + randn(32, 16) * 1e-6
    var_loss_collapsed = loss_fn.variance_loss(z_collapsed)
    check("variance penalizes collapsed embeddings",
          var_loss_collapsed.item() > 0.5)

    # Test 5: Covariance ~ 0 for i.i.d. Gaussian
    print("\nTest 5: Covariance ~ 0 for i.i.d. Gaussian")
    z_iid = randn(1000, 16)
    cov_loss_iid = loss_fn.covariance_loss(z_iid)
    check("covariance low for i.i.d. Gaussian (< 0.1)",
          cov_loss_iid.item() < 0.1)

    # Test 6: Covariance high for correlated dimensions
    print("\nTest 6: Covariance high for correlated dimensions")
    z_base = randn(1000, 1)
    z_corr = z_base.repeat(1, 16) + randn(1000, 16) * 0.01
    cov_loss_corr = loss_fn.covariance_loss(z_corr)
    check("covariance penalizes correlated dimensions (> 1.0)",
          cov_loss_corr.item() > 1.0)

    # Test 7: Three lambda configurations
    print("\nTest 7: Three lambda configurations")
    configs = {
        'default': VICRegLoss(25.0, 25.0, 1.0),
        'high_variance': VICRegLoss(10.0, 50.0, 1.0),
        'high_covariance': VICRegLoss(25.0, 25.0, 10.0),
    }
    z_a_test = randn(64, 16)
    z_b_test = randn(64, 16)
    for name, cfg in configs.items():
        total_loss, _ = cfg(z_a_test, z_b_test)
        check(f"{name} produces valid loss (> 0)",
              total_loss.item() > 0 and not torch.isnan(total_loss))

    # Test 8: Shape validation
    print("\nTest 8: Shape validation")
    try:
        loss_fn(randn(10, 16), randn(10, 32))
        check("shape mismatch caught", False)
    except ValueError:
        check("shape mismatch caught", True)

    try:
        loss_fn(randn(1, 16), randn(1, 16))
        check("batch size < 2 caught", False)
    except ValueError:
        check("batch size < 2 caught", True)

    # Test 9: GPU computation (if available)
    print("\nTest 9: GPU computation")
    if torch.cuda.is_available():
        # device is CUDA whenever this branch runs, so randn() lands there.
        z_gpu_a = randn(64, 16, requires_grad=True)
        z_gpu_b = randn(64, 16, requires_grad=True)
        total_gpu, comp_gpu = loss_fn.to('cuda')(z_gpu_a, z_gpu_b)
        total_gpu.backward()
        check("GPU forward + backward succeeded", grad_is_finite(z_gpu_a))
    else:
        print("  SKIP: CUDA not available")
        tests_total += 1
        tests_passed += 1  # Skip counts as pass

    print(f"\n{'=' * 60}")
    print(f"Results: {tests_passed}/{tests_total} tests passed")
    print(f"{'=' * 60}")

    return tests_passed == tests_total
289
+
290
+
291
if __name__ == '__main__':
    import sys

    # Exit status 0 when every self-test passes, 1 otherwise.
    sys.exit(0 if self_test() else 1)