# geolip-cv-experiments / reusable_losses.py
# Author: AbstractPhil — commit 73183fe (verified), "Create reusable_losses.py"
"""
Batched Pentachoron CV — Fast Geometric Volume Measurement
============================================================
Replaces the sequential Python loop with fully batched operations.
One torch.linalg.det call on (n_samples, 6, 6) tensor.
Usage:
from cv_batch import cv_metric, cv_loss
# Non-differentiable monitoring (fast)
cv_value = cv_metric(embeddings, n_samples=200)
# Differentiable loss (fast, for training)
loss = cv_loss(embeddings, target=0.22, n_samples=64)
"""
import torch
import torch.nn.functional as F
import math
def _batch_pentachoron_volumes(emb, n_samples=200, n_points=5, max_pool=512):
    """Compute volumes of randomly sampled simplices in one batched operation.

    Each of the ``n_samples`` simplices uses ``n_points`` distinct rows of
    ``emb`` as vertices. Its volume comes from the Cayley-Menger determinant,
    evaluated for all samples with a single batched ``torch.linalg.det`` call.

    Args:
        emb: (N, D) embeddings. The callers in this module intend these to be
            L2-normalized (points on S^(D-1)), but the volume computation
            itself works for arbitrary points.
        n_samples: number of random simplices to sample.
        n_points: vertices per simplex (5 = pentachoron, a 4-simplex).
            Must be >= 2.
        max_pool: vertices are drawn from a random subset of at most this many
            rows. This bounds the (n_samples, pool) sort used for sampling.

    Returns:
        (n_valid,) tensor of simplex volumes in ``emb.dtype``. ``n_valid`` may
        be smaller than ``n_samples`` because simplices whose squared volume is
        <= 1e-20 are dropped. It is 0 when ``emb`` has fewer than ``n_points``
        rows, since no simplex with distinct vertices can be formed.

    Raises:
        ValueError: if ``emb`` is not 2-D or ``n_points < 2``.
    """
    if emb.dim() != 2:
        raise ValueError(f"emb must be 2-D (N, D), got shape {tuple(emb.shape)}")
    if n_points < 2:
        raise ValueError(f"n_points must be >= 2, got {n_points}")

    N = emb.shape[0]
    device = emb.device
    dtype = emb.dtype

    # Not enough rows for even one simplex with distinct vertices.
    if N < n_points:
        return emb.new_zeros(0)

    # Sampling pool. When N exceeds the cap, pick a random subset instead of
    # the first rows, so ordered batches (e.g. sorted by label) are not biased.
    pool = min(N, max(max_pool, n_points))
    if N > pool:
        candidates = emb[torch.randperm(N, device=device)[:pool]]
    else:
        candidates = emb

    # Batched sampling without replacement: argsort of i.i.d. uniform keys is
    # an independent random permutation per row; keep the first n_points.
    rand_keys = torch.rand(n_samples, pool, device=device)
    indices = rand_keys.argsort(dim=1)[:, :n_points]  # (n_samples, n_points)
    pts = candidates[indices]  # (n_samples, n_points, D)

    # Pairwise squared distances via the Gram matrix:
    # ||p_i - p_j||^2 = |p_i|^2 + |p_j|^2 - 2 <p_i, p_j>
    gram = torch.bmm(pts, pts.transpose(1, 2))  # (n_samples, n_points, n_points)
    sq_norms = torch.diagonal(gram, dim1=1, dim2=2)  # (n_samples, n_points)
    d2 = sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * gram
    d2 = F.relu(d2)  # clamp tiny negative values caused by rounding

    # Evaluate the determinant in at least float32: low-precision dtypes
    # (fp16/bf16) are not accurate enough, while float64 inputs keep their
    # full precision instead of being silently downcast.
    det_dtype = torch.float64 if dtype == torch.float64 else torch.float32

    # Cayley-Menger matrix: 0 in the corner, a border of ones, then the
    # squared-distance block.
    M = n_points + 1
    cm = torch.zeros(n_samples, M, M, device=device, dtype=det_dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2.to(det_dtype)

    # Squared volume of a k-simplex: V^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM)
    k = n_points - 1
    pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
    sq_vol = pf * torch.linalg.det(cm)  # (n_samples,)

    # Drop (near-)degenerate simplices, then take the root in det_dtype
    # before casting back to the input dtype.
    valid = sq_vol > 1e-20
    return sq_vol[valid].sqrt().to(dtype)
def cv_metric(emb, n_samples=200, n_points=5):
    """Non-differentiable CV of simplex volumes, for monitoring. Target band: 0.20–0.23.

    Args:
        emb: (N, D) embeddings. Rows are L2-normalized internally before
            sampling, so the input scale does not matter.
        n_samples: simplices to sample (200 is robust, 100 is fast).
        n_points: points per simplex (5 = pentachoron).

    Returns:
        float: coefficient of variation (std / mean) of the sampled simplex
        volumes. Returns 0.0 when fewer than 10 non-degenerate simplices were
        obtained, which is too few for a stable estimate. Note that this 0.0
        cannot be distinguished from perfectly uniform volumes.
    """
    with torch.no_grad():
        # Project onto the unit sphere, as promised above, so that raw
        # embeddings are measured with the same geometry as normalized ones.
        unit = F.normalize(emb, p=2, dim=-1)
        vols = _batch_pentachoron_volumes(unit, n_samples=n_samples, n_points=n_points)
        if vols.shape[0] < 10:
            return 0.0
        # 1e-8 guards against division by zero for a vanishing mean volume.
        return (vols.std() / (vols.mean() + 1e-8)).item()
def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
    """Differentiable CV loss for training. Weight: 0.01 or below.

    Args:
        emb: (N, D) embeddings, expected to be L2-normalized by the caller.
            They are not normalized here, so gradients flow through emb as
            given.
        target: CV target value.
        n_samples: simplices to sample (32-64 for training).
        n_points: points per simplex.

    Returns:
        scalar tensor: (CV - target)^2, differentiable w.r.t. emb. When fewer
        than 5 non-degenerate simplices were sampled, it is instead a zero
        scalar with emb's dtype and device that requires grad.
    """
    vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=n_points)
    if vols.shape[0] < 5:
        # Too few simplices for a meaningful CV: contribute nothing. Matching
        # emb's dtype/device keeps both branches returning the same kind of
        # tensor.
        return emb.new_zeros((), requires_grad=True)
    # 1e-8 guards against division by zero for a vanishing mean volume.
    cv = vols.std() / (vols.mean() + 1e-8)
    return (cv - target).pow(2)
def cv_multi_scale(emb, scales=(3, 4, 5, 6, 7, 8), n_samples=100):
    """Measure the simplex-volume CV at several simplex sizes.

    Useful for diagnosing whether the geometry is scale-invariant. For healthy
    geometry, every scale should fall within [0.18, 0.25].

    Args:
        emb: (N, D) embeddings, passed unchanged to the volume sampler.
        scales: simplex sizes (points per simplex) to evaluate, in order.
        n_samples: simplices sampled at each scale.

    Returns:
        dict mapping each entry of ``scales`` to its CV rounded to 4 decimals,
        or to None when fewer than 10 non-degenerate simplices were sampled.
    """
    def _cv_at(size):
        sample_vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=size)
        if sample_vols.shape[0] < 10:
            return None
        ratio = sample_vols.std() / (sample_vols.mean() + 1e-8)
        return round(ratio.item(), 4)

    with torch.no_grad():
        return {size: _cv_at(size) for size in scales}