"""Batched Pentachoron CV — Fast Geometric Volume Measurement
============================================================

Replaces the sequential Python loop with fully batched operations:
one ``torch.linalg.det`` call on an ``(n_samples, 6, 6)`` tensor of
Cayley-Menger matrices (for the default 5-point pentachoron).

Usage:
    from cv_batch import cv_metric, cv_loss

    # Non-differentiable monitoring (fast)
    cv_value = cv_metric(embeddings, n_samples=200)

    # Differentiable loss (fast, for training)
    loss = cv_loss(embeddings, target=0.22, n_samples=64)
"""

import math

import torch
import torch.nn.functional as F


def _batch_pentachoron_volumes(emb, n_samples=200, n_points=5, max_pool=512):
    """Compute the volumes of randomly sampled simplices in one batched pass.

    Each simplex is made of ``n_points`` distinct rows of ``emb``. Its volume
    comes from the Cayley-Menger determinant of the pairwise squared
    distances between those rows.

    Args:
        emb: (N, D) embeddings, expected to be L2-normalized (points on
            S^(D-1)); this function does not normalize them itself.
        n_samples: number of random simplices to sample.
        n_points: points per simplex (5 = pentachoron).
        max_pool: upper bound on the number of candidate rows that simplices
            are drawn from. When N exceeds it, a random subset of this size
            is drawn first. Every row can then be sampled while the
            per-sample argsort stays O(max_pool).

    Returns:
        volumes: (n_valid,) tensor in ``emb.dtype``. ``n_valid`` may be
        smaller than ``n_samples`` because degenerate (near-zero volume)
        simplices are dropped; it is 0 when N < n_points.
    """
    n_rows, _ = emb.shape
    device = emb.device
    dtype = emb.dtype

    # Not enough rows to form even one simplex: nothing to measure.
    if n_rows < n_points:
        return emb.new_zeros(0)

    # Candidate pool. For large N, draw a random subset rather than always
    # taking the leading rows, so the estimate does not depend on row order.
    pool = min(n_rows, max_pool)
    if pool < n_rows:
        candidates = emb[torch.randperm(n_rows, device=device)[:pool]]
    else:
        candidates = emb

    # Batched randperm via argsort on random keys: each row of `indices`
    # holds n_points distinct pool positions. Shape (n_samples, n_points).
    rand_keys = torch.rand(n_samples, pool, device=device)
    indices = rand_keys.argsort(dim=1)[:, :n_points]

    # torch.linalg.det does not accept half precision, and the Cayley-Menger
    # determinant suffers from cancellation. Work in at least float32, but
    # keep float64 inputs in float64 instead of truncating them.
    compute_dtype = torch.promote_types(dtype, torch.float32)
    pts = candidates[indices].to(compute_dtype)  # (n_samples, n_points, D)

    # Pairwise squared distances via Gram matrices:
    # ||p_i - p_j||^2 = <p_i, p_i> + <p_j, p_j> - 2 <p_i, p_j>
    gram = torch.bmm(pts, pts.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)  # (n_samples, n_points)
    d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
    d2 = F.relu(d2)  # clamp tiny negatives caused by round-off

    # Bordered Cayley-Menger matrices: (n_samples, n_points+1, n_points+1).
    size = n_points + 1
    cm = torch.zeros(n_samples, size, size, device=device, dtype=compute_dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    # For a k-simplex: V^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM).
    k = n_points - 1
    pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
    dets = pf * torch.linalg.det(cm)  # (n_samples,) squared volumes

    # Keep non-degenerate simplices and turn squared volume into volume.
    valid_mask = dets > 1e-20
    return dets[valid_mask].sqrt().to(dtype)


def _coefficient_of_variation(vols, eps=1e-8):
    """Return std(vols) / mean(vols) as a 0-d tensor; eps guards mean == 0."""
    return vols.std() / (vols.mean() + eps)


def cv_metric(emb, n_samples=200, n_points=5):
    """Non-differentiable CV for monitoring. Target band: 0.20–0.23.

    Args:
        emb: (N, D) embeddings; rows are L2-normalized internally.
        n_samples: simplices to sample (200 is robust, 100 is fast).
        n_points: points per simplex (5 = pentachoron).

    Returns:
        float: coefficient of variation (std / mean) of the sampled simplex
        volumes, or 0.0 when fewer than 10 non-degenerate simplices were
        found (including the case N < n_points).
    """
    with torch.no_grad():
        # Project onto the unit sphere, as the volume estimate assumes.
        emb = F.normalize(emb, dim=-1)
        vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=n_points)
        if vols.shape[0] < 10:
            return 0.0
        return _coefficient_of_variation(vols).item()


def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
    """Differentiable CV loss for training. Weight: 0.01 or below.

    Args:
        emb: (N, D) L2-normalized embeddings (not normalized here).
        target: CV target value.
        n_samples: simplices to sample (32-64 for training).
        n_points: points per simplex.

    Returns:
        scalar tensor: (CV - target)^2, differentiable w.r.t. ``emb``. When
        fewer than 5 non-degenerate simplices were found (including
        N < n_points), returns a constant 0.0 tensor that requires grad but
        is not connected to ``emb``.
    """
    vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=n_points)
    if vols.shape[0] < 5:
        return torch.tensor(0.0, device=emb.device, requires_grad=True)
    cv = _coefficient_of_variation(vols)
    return (cv - target).pow(2)


def cv_multi_scale(emb, scales=(3, 4, 5, 6, 7, 8), n_samples=100):
    """CV at multiple simplex sizes, for checking scale invariance.

    Target: all scales in [0.18, 0.25] for healthy geometry.

    Args:
        emb: (N, D) embeddings, expected to be L2-normalized (not normalized
            here, unlike ``cv_metric``).
        scales: simplex sizes (points per simplex) to evaluate.
        n_samples: simplices sampled per scale.

    Returns:
        dict: {n_points: CV rounded to 4 decimals}, with None for any scale
        that produced fewer than 10 non-degenerate simplices (e.g. when
        N < n_points).
    """
    results = {}
    with torch.no_grad():
        for n_pts in scales:
            vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=n_pts)
            if vols.shape[0] >= 10:
                results[n_pts] = round(_coefficient_of_variation(vols).item(), 4)
            else:
                results[n_pts] = None
    return results