File size: 1,241 Bytes
3510f1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Coding-rate coverage surrogate C_cr(S;x) — avoids SVD backprop instability.

From the formalization §3:
    C_cr(S;x) = 1/2 * log det( I + (d / (|S| eps^2)) * P_L Z_S Z_S^T P_L )
A smooth, differentiable lower-bound-style surrogate for the lesion-subspace coverage; used
when SVD gradients in rankme are unstable (Gate 2 fallback per IMPLEMENTATION_SPEC §Gate 2).
"""
from __future__ import annotations

import torch


def coding_rate(Z_retained: torch.Tensor, P_L: torch.Tensor | None = None,
                eps: float = 0.5) -> torch.Tensor:
    """Return the coding-rate surrogate C_cr(S;x) for retained token features.

    Evaluates ``0.5 * logdet(I + (d / (k * eps**2)) * PZ^T PZ)``, where ``PZ``
    is ``Z_retained`` (optionally right-multiplied by ``P_L.T``) in float32 and
    ``(k, d)`` is the shape of ``PZ``.

    Args:
        Z_retained: Retained token features, expected shape ``(k, d)``.
        P_L: Optional matrix applied as ``Z_retained @ P_L.T``. When ``None``
            the features are used unprojected.
        eps: Distortion parameter; it only enters through ``eps * eps``.

    Returns:
        A 0-dim float32 tensor. If the input is not 2-D or has zero rows, a
        freshly created zero scalar on the input's device is returned instead.
    """
    Z = Z_retained
    if Z.ndim != 2 or Z.shape[0] == 0:
        # Same dtype as the regular path (which always computes in float32),
        # so callers never see the result dtype depend on the input being empty.
        return torch.zeros((), dtype=torch.float32, device=Z.device)

    # Upcast *before* projecting: the projection should not be carried out in
    # reduced precision, and Z / P_L may arrive with different dtypes.
    Z = Z.float()
    PZ = Z @ P_L.float().T if P_L is not None else Z
    k, d = PZ.shape
    scale = d / (k * eps * eps)

    # Sylvester's determinant identity: det(I_d + s*A^T A) == det(I_k + s*A A^T).
    # Use whichever Gram matrix is smaller; with few retained tokens (k < d)
    # this avoids a (d, d) factorisation of a rank-deficient update.
    gram = PZ @ PZ.T if k < d else PZ.T @ PZ
    identity = torch.eye(gram.shape[0], device=PZ.device, dtype=PZ.dtype)
    return 0.5 * torch.logdet(identity + scale * gram)


def coding_rate_drop(Z_full: torch.Tensor, Z_retained: torch.Tensor,
                     P_L: torch.Tensor | None = None, eps: float = 0.5) -> torch.Tensor:
    """Return the coding rate of ``Z_full`` minus that of ``Z_retained``.

    Both terms are evaluated by ``coding_rate`` with the same ``P_L`` and
    ``eps``; the full-set term is computed first.
    """
    rate_full = coding_rate(Z_full, P_L, eps)
    rate_kept = coding_rate(Z_retained, P_L, eps)
    return rate_full - rate_kept