File size: 1,241 Bytes
"""Coding-rate coverage surrogate C_cr(S;x), which avoids SVD backprop instability.

From the formalization §3:

    C_cr(S;x) = 1/2 * log det( I + (d / (|S| eps^2)) * P_L Z_S Z_S^T P_L )

A smooth, differentiable, lower-bound-style surrogate for lesion-subspace coverage. It is
used when SVD gradients in rankme are unstable (the Gate 2 fallback per
IMPLEMENTATION_SPEC §Gate 2).
"""
from __future__ import annotations
import torch
def coding_rate(Z_retained: torch.Tensor, P_L: torch.Tensor | None = None,
                eps: float = 0.5) -> torch.Tensor:
    """Return the coding-rate coverage C_cr(S;x) of the retained token features.

    Computes ``0.5 * logdet(I + d / (k * eps**2) * PZ^T PZ)``, where
    ``PZ = Z_retained @ P_L.T`` (or ``Z_retained`` itself when ``P_L`` is None),
    ``k`` is the number of rows and ``d`` the number of columns of ``PZ``.

    Args:
        Z_retained: retained token features, shape ``(k, d)``, one row per token.
        P_L: optional matrix applied to every row as ``Z @ P_L.T``; presumably the
            lesion-subspace projector (NOTE(review): confirm against callers).
        eps: distortion parameter of the coding rate; must be non-zero
            (``eps == 0`` raises ZeroDivisionError).

    Returns:
        A 0-dim tensor computed in at least float32 precision. Half/bfloat16 inputs
        are upcast to float32, and float64 inputs stay float64. Returns 0 when
        ``Z_retained`` has no rows, because log det(I) == 0.
        NOTE(review): any input that is not 2-D (e.g. a batched 3-D tensor or a
        single 1-D token) also returns 0 silently. Callers must pass ``(k, d)``.
    """
    # float16/bfloat16/int -> float32; float32 -> float32; float64 -> float64.
    compute_dtype = torch.promote_types(Z_retained.dtype, torch.float32)
    Z = Z_retained
    if Z.ndim != 2 or Z.shape[0] == 0:
        return torch.zeros((), dtype=compute_dtype, device=Z.device)

    # Upcast BEFORE projecting, so the projection runs at full precision and a
    # low-precision Z can be combined with a float32 P_L without a dtype mismatch.
    PZ = Z.to(compute_dtype)
    if P_L is not None:
        PZ = PZ @ P_L.to(compute_dtype).T
    k, d = PZ.shape
    scale = d / (k * eps * eps)

    # Sylvester's determinant identity: det(I_d + a PZ^T PZ) == det(I_k + a PZ PZ^T).
    # Factor whichever Gram matrix is smaller; the log-det is the same either way.
    gram = PZ @ PZ.T if k < d else PZ.T @ PZ
    mat = torch.eye(gram.shape[0], device=PZ.device, dtype=PZ.dtype) + scale * gram
    return 0.5 * torch.logdet(mat)
def coding_rate_drop(Z_full: torch.Tensor, Z_retained: torch.Tensor,
                     P_L: torch.Tensor | None = None, eps: float = 0.5) -> torch.Tensor:
    """Return the coverage lost by keeping only a subset of tokens.

    Defined as ``coding_rate(Z_full) - coding_rate(Z_retained)``, with both terms
    evaluated under the same ``P_L`` and ``eps``.

    Args:
        Z_full: features of the full token set, expected shape ``(k_full, d)``.
        Z_retained: features of the retained subset, expected shape ``(k, d)``.
        P_L: optional matrix forwarded unchanged to ``coding_rate``.
        eps: distortion parameter forwarded unchanged to ``coding_rate``.

    Returns:
        A 0-dim tensor holding the difference of the two coding rates.
    """
    rate_full = coding_rate(Z_full, P_L, eps)
    rate_kept = coding_rate(Z_retained, P_L, eps)
    return rate_full - rate_kept
|