| """ |
| Batched Pentachoron CV — Fast Geometric Volume Measurement |
| ============================================================ |
| Replaces the sequential Python loop with fully batched operations. |
| One torch.linalg.det call on (n_samples, 6, 6) tensor. |
| |
| Usage: |
| from cv_batch import cv_metric, cv_loss |
| |
| # Non-differentiable monitoring (fast) |
| cv_value = cv_metric(embeddings, n_samples=200) |
| |
| # Differentiable loss (fast, for training) |
| loss = cv_loss(embeddings, target=0.22, n_samples=64) |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import math |
|
|
|
|
| def _batch_pentachoron_volumes(emb, n_samples=200, n_points=5): |
| """Compute pentachoron volumes in one batched operation. |
| |
| Args: |
| emb: (N, D) L2-normalized embeddings on S^(d-1) |
| n_samples: number of random pentachora to sample |
| n_points: points per simplex (5 = pentachoron) |
| |
| Returns: |
| volumes: (n_valid,) tensor of simplex volumes (may be < n_samples if some degenerate) |
| """ |
| N, D = emb.shape |
| device = emb.device |
| dtype = emb.dtype |
|
|
| |
| |
| pool = min(N, 512) |
| rand_keys = torch.rand(n_samples, pool, device=device) |
| indices = rand_keys.argsort(dim=1)[:, :n_points] |
|
|
| |
| pts = emb[:pool][indices] |
|
|
| |
| gram = torch.bmm(pts, pts.transpose(1, 2)) |
|
|
| |
| norms = torch.diagonal(gram, dim1=1, dim2=2) |
| d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram |
| d2 = F.relu(d2) |
|
|
| |
| M = n_points + 1 |
| cm = torch.zeros(n_samples, M, M, device=device, dtype=dtype) |
| cm[:, 0, 1:] = 1.0 |
| cm[:, 1:, 0] = 1.0 |
| cm[:, 1:, 1:] = d2 |
|
|
| |
| k = n_points - 1 |
| pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2)) |
|
|
| |
| dets = pf * torch.linalg.det(cm.float()) |
|
|
| |
| valid_mask = dets > 1e-20 |
| volumes = dets[valid_mask].to(dtype).sqrt() |
|
|
| return volumes |
|
|
|
|
def cv_metric(emb, n_samples=200, n_points=5):
    """Non-differentiable CV for monitoring. Target band: 0.20–0.23.

    Args:
        emb: (N, D) embeddings (L2-normalized internally)
        n_samples: pentachora to sample (200 is robust, 100 is fast)
        n_points: points per simplex (5 = pentachoron)

    Returns:
        float: coefficient of variation of simplex volumes, or 0.0 when
            fewer than 10 non-degenerate simplices were found.
    """
    with torch.no_grad():
        # The docstring promises internal normalization; the previous
        # version skipped it, so unnormalized inputs were measured on the
        # wrong scale. Idempotent for already-normalized embeddings.
        emb = F.normalize(emb, dim=-1)
        vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=n_points)
        if vols.shape[0] < 10:
            return 0.0
        return (vols.std() / (vols.mean() + 1e-8)).item()
|
|
|
|
def cv_loss(emb, target=0.22, n_samples=64, n_points=5):
    """Differentiable CV loss for training. Weight: 0.01 or below.

    Args:
        emb: (N, D) L2-normalized embeddings
        target: CV target value
        n_samples: pentachora to sample (32-64 for training)
        n_points: points per simplex

    Returns:
        scalar tensor: (CV - target)^2, differentiable w.r.t. emb
    """
    vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=n_points)
    n_valid = vols.shape[0]
    if n_valid < 5:
        # Too few non-degenerate simplices: return a zero loss that still
        # carries requires_grad so training code can backprop unconditionally.
        return torch.tensor(0.0, device=emb.device, requires_grad=True)
    mean_vol = vols.mean() + 1e-8
    cv = vols.std() / mean_vol
    deviation = cv - target
    return deviation * deviation
|
|
|
|
def cv_multi_scale(emb, scales=(3, 4, 5, 6, 7, 8), n_samples=100):
    """CV at multiple simplex sizes. Returns dict: {n_points: cv_value}.

    Useful for diagnosing whether geometry is scale-invariant.
    Target: all scales in [0.18, 0.25] for healthy geometry.
    A scale maps to None when too few valid simplices were found.
    """
    out = {}
    with torch.no_grad():
        for size in scales:
            vols = _batch_pentachoron_volumes(emb, n_samples=n_samples, n_points=size)
            enough = vols.shape[0] >= 10
            out[size] = (
                round((vols.std() / (vols.mean() + 1e-8)).item(), 4) if enough else None
            )
    return out