#!/usr/bin/env python3
"""
VICReg Loss Function for Joint Embedding Learning.
Implements the Variance-Invariance-Covariance Regularization loss from:
Bardes, Ponce & LeCun, "VICReg: Variance-Invariance-Covariance
Regularization for Self-Supervised Learning", ICLR 2022.
Three terms:
1. Invariance: MSE between paired embeddings (push co-located pairs together)
2. Variance: Hinge loss on per-dimension std dev (prevent collapse)
3. Covariance: Penalize off-diagonal covariance (decorrelate dimensions)
Usage:
loss_fn = VICRegLoss(lambda_inv=25.0, lambda_var=25.0, lambda_cov=1.0)
total_loss, components = loss_fn(z_a, z_b)
"""
import torch
import torch.nn as nn
class VICRegLoss(nn.Module):
    """VICReg: Variance-Invariance-Covariance Regularization Loss.

    Combines three terms computed on a batch of paired embeddings: an
    invariance term (MSE between the pairs), a variance term (hinge on the
    per-dimension standard deviation) and a covariance term (squared
    off-diagonal covariance). The variance and covariance terms are applied
    to each of the two embedding sets separately and summed.

    Parameters
    ----------
    lambda_inv : float
        Weight for invariance term (MSE between paired embeddings).
    lambda_var : float
        Weight for variance term (hinge loss on per-dimension std dev).
    lambda_cov : float
        Weight for covariance term (off-diagonal covariance penalty).
    gamma : float
        Target standard deviation for variance hinge (default 1.0).
    eps : float
        Constant added to the per-dimension variance before taking the
        square root, for numerical stability (default 1e-4).
    """

    def __init__(self, lambda_inv: float = 25.0, lambda_var: float = 25.0,
                 lambda_cov: float = 1.0, gamma: float = 1.0,
                 eps: float = 1e-4):
        super().__init__()
        self.lambda_inv = lambda_inv
        self.lambda_var = lambda_var
        self.lambda_cov = lambda_cov
        self.gamma = gamma
        self.eps = eps

    def extra_repr(self) -> str:
        """Return the hyperparameters shown inside ``repr(module)``."""
        return (
            f"lambda_inv={self.lambda_inv}, lambda_var={self.lambda_var}, "
            f"lambda_cov={self.lambda_cov}, gamma={self.gamma}, "
            f"eps={self.eps}"
        )

    def invariance_loss(self, z_a: torch.Tensor,
                        z_b: torch.Tensor) -> torch.Tensor:
        """MSE between paired embeddings.

        Parameters
        ----------
        z_a, z_b : torch.Tensor, shape (N, D)
            Paired embedding vectors.

        Returns
        -------
        torch.Tensor
            Scalar invariance loss (mean squared difference over all
            elements).
        """
        return torch.nn.functional.mse_loss(z_a, z_b)

    def variance_loss(self, z: torch.Tensor) -> torch.Tensor:
        """Hinge loss on per-dimension standard deviation.

        Encourages each dimension to have std >= gamma, preventing
        embedding collapse where all points map to the same vector.

        Parameters
        ----------
        z : torch.Tensor, shape (N, D)
            Embedding matrix (single modality). The variance is taken over
            the batch axis, so N must be at least 2; ``forward`` enforces
            this, direct callers must do so themselves.

        Returns
        -------
        torch.Tensor
            Scalar variance loss.
        """
        # eps keeps the sqrt (and its gradient) finite when a dimension
        # has collapsed to zero variance.
        std_z = torch.sqrt(z.var(dim=0) + self.eps)
        # Hinge: only dimensions whose std is below gamma contribute.
        return torch.mean(torch.relu(self.gamma - std_z))

    def covariance_loss(self, z: torch.Tensor) -> torch.Tensor:
        """Off-diagonal covariance penalty.

        Decorrelates embedding dimensions by penalizing off-diagonal
        elements of the covariance matrix.

        Parameters
        ----------
        z : torch.Tensor, shape (N, D)
            Embedding matrix (single modality). The covariance is divided
            by N - 1, so N must be at least 2.

        Returns
        -------
        torch.Tensor
            Scalar covariance loss: sum of squared off-diagonal covariance
            entries divided by D.
        """
        n_samples, n_dims = z.shape
        # Center each dimension over the batch.
        centered = z - z.mean(dim=0)
        # (D, D) covariance matrix.
        cov = (centered.T @ centered) / (n_samples - 1)
        # Remove the diagonal: per-dimension variances are handled by the
        # variance term, only cross-dimension covariances are penalized.
        off_diag = cov - torch.diag(cov.diag())
        return off_diag.pow(2).sum() / n_dims

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor):
        """Compute total VICReg loss.

        Parameters
        ----------
        z_a : torch.Tensor, shape (N, D)
            Embeddings from modality A (e.g., environment encoder).
        z_b : torch.Tensor, shape (N, D)
            Embeddings from modality B (e.g., PFAM module encoder).

        Returns
        -------
        total_loss : torch.Tensor
            Weighted sum of invariance, variance, and covariance terms.
        components : dict
            Individual loss components for logging (Python floats,
            detached from the graph):
            - 'invariance': float
            - 'variance_a': float (variance loss for z_a)
            - 'variance_b': float (variance loss for z_b)
            - 'covariance_a': float (covariance loss for z_a)
            - 'covariance_b': float (covariance loss for z_b)
            - 'total': float

        Raises
        ------
        ValueError
            If the two inputs differ in shape, are not 2-D, or contain
            fewer than 2 samples.
        """
        # Input validation
        if z_a.shape != z_b.shape:
            raise ValueError(
                f"Shape mismatch: z_a {z_a.shape} vs z_b {z_b.shape}"
            )
        # Checked before indexing shape[0] so that scalar or higher-rank
        # inputs fail with a clear message instead of an IndexError or a
        # tuple-unpacking error deep inside covariance_loss.
        if z_a.dim() != 2:
            raise ValueError(
                f"Embeddings must be 2-D (N, D), got shape {tuple(z_a.shape)}"
            )
        if z_a.shape[0] < 2:
            raise ValueError(
                f"Batch size must be >= 2, got {z_a.shape[0]}"
            )

        # Compute individual terms
        inv_loss = self.invariance_loss(z_a, z_b)
        var_loss_a = self.variance_loss(z_a)
        var_loss_b = self.variance_loss(z_b)
        cov_loss_a = self.covariance_loss(z_a)
        cov_loss_b = self.covariance_loss(z_b)

        # Combine: variance and covariance applied to BOTH modalities
        total = (self.lambda_inv * inv_loss
                 + self.lambda_var * (var_loss_a + var_loss_b)
                 + self.lambda_cov * (cov_loss_a + cov_loss_b))

        components = {
            'invariance': inv_loss.item(),
            'variance_a': var_loss_a.item(),
            'variance_b': var_loss_b.item(),
            'covariance_a': cov_loss_a.item(),
            'covariance_b': cov_loss_b.item(),
            'total': total.item(),
        }
        return total, components
def self_test():
    """Run self-tests for VICReg loss module. Returns True if all pass.

    Prints one PASS/FAIL line per check and a final summary. Tensors are
    created on CUDA when it is available, otherwise on CPU; the dedicated
    GPU test is counted as a pass when CUDA is absent.
    """
    tests_passed = 0
    tests_total = 0

    def check(name, condition):
        """Record one check result and print its PASS/FAIL line."""
        nonlocal tests_passed, tests_total
        tests_total += 1
        if condition:
            tests_passed += 1
            print(f" PASS: {name}")
        else:
            print(f" FAIL: {name}")

    def grad_has_no_nan(tensor):
        """True when a gradient exists and contains no NaN.

        Guarding on ``None`` lets a missing gradient show up as a FAIL
        line instead of a TypeError from ``torch.isnan(None)``.
        """
        return tensor.grad is not None and not torch.isnan(tensor.grad).any()

    def raises_value_error(fn, a, b):
        """True when ``fn(a, b)`` raises ValueError."""
        try:
            fn(a, b)
        except ValueError:
            return True
        return False

    print("=" * 60)
    print("VICReg Loss Self-Tests")
    print("=" * 60)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}\n")
    loss_fn = VICRegLoss(lambda_inv=25.0, lambda_var=25.0, lambda_cov=1.0)

    # Test 1: Gradient flow
    print("Test 1: Gradient flow")
    z_a = torch.randn(64, 16, device=device, requires_grad=True)
    z_b = torch.randn(64, 16, device=device, requires_grad=True)
    total, comp = loss_fn(z_a, z_b)
    total.backward()
    check("gradients computed for z_a", z_a.grad is not None)
    check("gradients computed for z_b", z_b.grad is not None)
    check("no NaN in z_a grad", grad_has_no_nan(z_a))
    check("no NaN in z_b grad", grad_has_no_nan(z_b))
    check("all components present",
          all(k in comp for k in ['invariance', 'variance_a', 'variance_b',
                                  'covariance_a', 'covariance_b', 'total']))

    # Test 2: Invariance = 0 for identical embeddings
    print("\nTest 2: Invariance = 0 for identical embeddings")
    z_same = torch.randn(32, 16, device=device)
    inv = loss_fn.invariance_loss(z_same, z_same)
    check("invariance is zero", inv.item() < 1e-7)

    # Test 3: Variance = 0 when std >= gamma
    print("\nTest 3: Variance = 0 when std >= gamma")
    z_spread = torch.randn(1000, 16, device=device) * 2.0  # std ~2.0 >> gamma=1.0
    var_loss = loss_fn.variance_loss(z_spread)
    check("variance is zero for high-spread embeddings", var_loss.item() < 1e-4)

    # Test 4: Variance > 0 for collapsed embeddings
    print("\nTest 4: Variance > 0 for collapsed embeddings")
    z_collapsed = torch.ones(32, 16, device=device) * 0.5  # constant -> std=0
    # Add tiny noise so std is very small but not exactly zero
    z_collapsed = z_collapsed + torch.randn_like(z_collapsed) * 1e-6
    var_loss_collapsed = loss_fn.variance_loss(z_collapsed)
    check("variance penalizes collapsed embeddings",
          var_loss_collapsed.item() > 0.5)

    # Test 5: Covariance ~ 0 for i.i.d. Gaussian
    print("\nTest 5: Covariance ~ 0 for i.i.d. Gaussian")
    z_iid = torch.randn(1000, 16, device=device)
    cov_loss_iid = loss_fn.covariance_loss(z_iid)
    check("covariance low for i.i.d. Gaussian (< 0.1)",
          cov_loss_iid.item() < 0.1)

    # Test 6: Covariance high for correlated dimensions
    print("\nTest 6: Covariance high for correlated dimensions")
    z_base = torch.randn(1000, 1, device=device)
    z_corr = z_base.repeat(1, 16) + torch.randn(1000, 16, device=device) * 0.01
    cov_loss_corr = loss_fn.covariance_loss(z_corr)
    check("covariance penalizes correlated dimensions (> 1.0)",
          cov_loss_corr.item() > 1.0)

    # Test 7: Three lambda configurations
    print("\nTest 7: Three lambda configurations")
    configs = {
        'default': VICRegLoss(25.0, 25.0, 1.0),
        'high_variance': VICRegLoss(10.0, 50.0, 1.0),
        'high_covariance': VICRegLoss(25.0, 25.0, 10.0),
    }
    z_a_test = torch.randn(64, 16, device=device)
    z_b_test = torch.randn(64, 16, device=device)
    for name, cfg in configs.items():
        total_loss, _ = cfg(z_a_test, z_b_test)
        check(f"{name} produces valid loss (> 0)",
              total_loss.item() > 0 and not torch.isnan(total_loss))

    # Test 8: Shape validation
    print("\nTest 8: Shape validation")
    check("shape mismatch caught",
          raises_value_error(loss_fn,
                             torch.randn(10, 16, device=device),
                             torch.randn(10, 32, device=device)))
    check("batch size < 2 caught",
          raises_value_error(loss_fn,
                             torch.randn(1, 16, device=device),
                             torch.randn(1, 16, device=device)))

    # Test 9: GPU computation (if available)
    print("\nTest 9: GPU computation")
    if torch.cuda.is_available():
        z_gpu_a = torch.randn(64, 16, device='cuda', requires_grad=True)
        z_gpu_b = torch.randn(64, 16, device='cuda', requires_grad=True)
        total_gpu, comp_gpu = loss_fn.to('cuda')(z_gpu_a, z_gpu_b)
        total_gpu.backward()
        check("GPU forward + backward succeeded", grad_has_no_nan(z_gpu_a))
    else:
        print(" SKIP: CUDA not available")
        tests_total += 1
        tests_passed += 1  # Skip counts as pass

    print(f"\n{'=' * 60}")
    print(f"Results: {tests_passed}/{tests_total} tests passed")
    print(f"{'=' * 60}")
    return tests_passed == tests_total
if __name__ == '__main__':
    import sys

    # Script entry point: run the self-tests and mirror the outcome in the
    # process exit status (0 when every check passed, 1 otherwise).
    all_passed = self_test()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)
|