Chucks90's picture
download
raw
1.76 kB
"""Coverage functional C(S;x) — effective rank (RankMe form) of projected retained tokens.
From the formalization §3:
C(S;x) = exp(-sum_j p_j log p_j), with p_j = sigma_j(P_L Z_S) / (sum_l sigma_l(P_L Z_S)) + eps
(the implementation additionally renormalizes p to sum to 1 before taking the entropy),
where sigma_j are singular values of the projected retained feature matrix P_L Z_S. This is
label-free and differentiable through the SVD (or use the coding-rate surrogate to avoid SVD
backprop). It measures how much of the lesion-relevant directions the kept tokens still span.
"""
from __future__ import annotations
import torch
def effective_rank(singular_values: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Return the RankMe effective rank of a vector of singular values.

    The values are turned into a distribution ``p`` (divide by their sum, add
    ``eps``, renormalize so ``p`` sums to 1) and the result is ``exp(H(p))``
    where ``H`` is the Shannon entropy.

    Args:
        singular_values: Tensor of singular values (expected non-negative).
        eps: Small constant added to the denominator and to every ``p_j`` so
            the division and the logarithm stay finite.

    Returns:
        Scalar tensor. It is about 1 when a single value dominates and about
        ``n`` when all ``n`` values are equal. It is exactly 0 when there are
        no values or when they sum to exactly zero.
    """
    s = singular_values
    total = s.sum()
    p = s / (total + eps) + eps
    p = p / p.sum()
    entropy = -(p * p.log()).sum()
    rank = entropy.exp()
    # With zero total mass every p_j equals eps, so renormalizing yields a
    # uniform distribution and the formula would report the *maximal* rank n
    # (or 1 for an empty vector). A zero spectrum spans nothing, so report 0.
    # torch.where keeps this branch-free; NaN totals compare unequal to 0 and
    # therefore still propagate as NaN.
    return torch.where(total == 0, torch.zeros_like(rank), rank)
def coverage(Z_retained: torch.Tensor, P_L: torch.Tensor | None = None,
eps: float = 1e-7) -> torch.Tensor:
"""C(S;x) for retained token features Z_retained (k, d).
P_L: optional (d, d) projection onto the lesion subspace L(x). If None, uses raw Z.
Returns a scalar tensor (differentiable through the SVD).
"""
Z = Z_retained
if Z.ndim != 2 or Z.shape[0] == 0:
return torch.zeros((), dtype=Z.dtype, device=Z.device)
PZ = Z @ P_L.T if P_L is not None else Z
# singular values of the projected retained feature matrix
s = torch.linalg.svdvals(PZ.float())
return effective_rank(s, eps)
def coverage_drop(Z_full: torch.Tensor, Z_retained: torch.Tensor,
                  P_L: torch.Tensor | None = None,
                  eps: float = 1e-7) -> torch.Tensor:
    """Return delta_C = C*(x) - C(S;x), the coverage lost by pruning.

    Args:
        Z_full: Features of the full (unpruned) token set; its coverage is
            used as C*(x).
        Z_retained: Features of the retained token subset S.
        P_L: Optional projection passed unchanged to both ``coverage`` calls.
        eps: Smoothing constant passed to both ``coverage`` calls, so the two
            terms are always computed with the same setting.

    Returns:
        Scalar tensor: coverage of the full set minus coverage of the
        retained set.
    """
    return coverage(Z_full, P_L, eps) - coverage(Z_retained, P_L, eps)

Xet Storage Details

Size:
1.76 kB
·
Xet hash:
6fa84e852f765d457896c77b6b11d7a44becb96798a4ab422e5e9b98a6ca17f0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.