| """Coding-rate coverage surrogate C_cr(S;x) — avoids SVD backprop instability. |
| |
| From the formalization §3: |
| C_cr(S;x) = 1/2 * log det( I + (d / (|S| eps^2)) * P_L Z_S Z_S^T P_L ) |
| A smooth, differentiable lower-bound-style surrogate for the lesion-subspace coverage; used |
| when SVD gradients in rankme are unstable (Gate 2 fallback per IMPLEMENTATION_SPEC §Gate 2). |
| """ |
| from __future__ import annotations |
|
|
| import torch |
|
|
|
|
def coding_rate(Z_retained: torch.Tensor, P_L: torch.Tensor | None = None,
                eps: float = 0.5) -> torch.Tensor:
    """Return the coding-rate surrogate C_cr(S;x) for retained token features.

    Evaluates ``0.5 * logdet(I + (d / (k * eps**2)) * PZ^T PZ)`` where ``PZ``
    is ``Z_retained`` (right-multiplied by ``P_L.T`` when ``P_L`` is given),
    ``k`` is the number of rows of ``PZ`` and ``d`` its number of columns.

    Args:
        Z_retained: Feature matrix of shape (k, d), one retained token per row.
        P_L: Optional matrix applied as ``Z_retained @ P_L.T``; ``None`` skips
            the projection. Presumably the lesion-subspace projector from the
            formalization -- confirm against the caller.
        eps: Distortion parameter appearing squared in the scale factor; must
            be non-zero.

    Returns:
        A 0-dim float32 tensor. Inputs that are not 2-D, or that have zero
        rows, yield a constant zero (also float32, so the result dtype does
        not depend on which path was taken).
    """
    Z = Z_retained
    if Z.ndim != 2 or Z.shape[0] == 0:
        # Same dtype as the regular path so callers can mix both results.
        return torch.zeros((), dtype=torch.float32, device=Z.device)

    # Upcast *before* projecting: the matmul then runs in float32 (no
    # half-precision rounding baked into PZ) and a float32 P_L can be combined
    # with half/bfloat16 features without a dtype-mismatch error.
    Z = Z.float()
    PZ = Z @ P_L.float().T if P_L is not None else Z

    k, d = PZ.shape
    scale = d / (k * eps * eps)

    # Sylvester's determinant identity: det(I_d + a*PZ^T PZ) == det(I_k + a*PZ PZ^T).
    # Factorizing the smaller of the two matrices gives the same log-det at
    # lower cost when there are fewer tokens than feature dimensions.
    if k < d:
        gram = PZ @ PZ.T
        size = k
    else:
        gram = PZ.T @ PZ
        size = d

    mat = torch.eye(size, device=PZ.device, dtype=PZ.dtype) + scale * gram
    return 0.5 * torch.logdet(mat)
|
|
|
|
def coding_rate_drop(Z_full: torch.Tensor, Z_retained: torch.Tensor,
                     P_L: torch.Tensor | None = None, eps: float = 0.5) -> torch.Tensor:
    """Return ``coding_rate`` of the full features minus that of the retained ones.

    Both terms are evaluated with the same ``P_L`` and ``eps``.
    """
    rate_full = coding_rate(Z_full, P_L, eps)
    rate_kept = coding_rate(Z_retained, P_L, eps)
    return rate_full - rate_kept
|
|