Chucks90's picture
download
raw
3.56 kB
"""Construction B — residual-from-normal-manifold lesion subspace, label-free.
Formalization §2: fit a low-rank normal-tissue subspace U_norm on a large normal CT bank
(PCA / coding rate). Lesion-relevant directions are the high-residual ones:
r_i = || z_i - U_norm U_norm^T z_i ||, L(x) = span{ z_i : r_i >= tau }.
Because lesions are rare, PCA over the whole bank approximates the normal manifold, so the
top-`rank` principal directions are U_norm and the lesion subspace is the residual (orthogonal
complement) where pathology concentrates.
membership_score(z) = residual norm ||(I - U U^T)(z - mu)||, where mu is the mean of the fitted
bank (HIGH residual => HIGH lesion score).
project(Z) = (I - U U^T) Z, applied row-wise WITHOUT mean-centering: the linear projection onto
the lesion (residual) subspace. Label-free:
only token-feature geometry is used; tau is a quantile of residuals, not a label.
"""
from __future__ import annotations
import numpy as np
import torch
from data.leak_guard import subspace_construction_guard
class ResidualSubspace:
    """Label-free lesion subspace: the residual of a PCA normal-tissue manifold.

    ``fit`` estimates an affine normal-tissue subspace from a token-feature bank.
    It is described by the mean ``mean_`` plus the top-``rank`` principal
    directions ``U_norm_``. Tokens are then scored by the norm of their
    component outside that subspace.

    Args:
        rank: number of principal directions kept as the normal-manifold basis.
            If the (subsampled) bank has fewer than ``rank`` singular
            directions, all available ones are kept.
        tau_quantile: quantile of the fitted rows' residual norms stored as
            ``tau_`` (this class does not apply it itself).
        reference_size: maximum number of bank rows used for fitting; larger
            banks are subsampled without replacement.
        seed: seed for the subsampling RNG.
    """

    class NotFittedError(RuntimeError, AssertionError):
        """Raised when a method that needs fitted state is called before ``fit``.

        Also an ``AssertionError`` so callers written against the former
        ``assert``-based check keep working (and it survives ``python -O``).
        """

    def __init__(self, rank: int = 64, tau_quantile: float = 0.9,
                 reference_size: int = 200_000, seed: int = 0):
        self.rank = rank
        self.tau_quantile = tau_quantile
        self.reference_size = reference_size
        self.seed = seed
        self.U_norm_: torch.Tensor | None = None  # (d, k), k <= rank: normal-manifold basis
        self.mean_: torch.Tensor | None = None  # (1, d) mean of the rows used for fitting
        self.tau_: float | None = None  # residual-norm quantile over the fitted rows
        self.P_L_: torch.Tensor | None = None  # (d, d) residual projector I - U U^T

    def fit(self, token_bank: torch.Tensor) -> "ResidualSubspace":
        """Fit ``mean_``, ``U_norm_``, ``P_L_`` and ``tau_`` from unlabeled tokens.

        Args:
            token_bank: 2-D ``(n, d)`` tensor of token features.

        Returns:
            ``self``, for chaining.

        Raises:
            ValueError: if ``token_bank`` is not 2-D or has no rows, if
                ``rank`` is negative, or if ``reference_size`` is not positive.
        """
        # Validate before any heavy work (and before entering the guard,
        # which only concerns label access).
        if token_bank.dim() != 2:
            raise ValueError(
                f"token_bank must be 2-D (n, d), got shape {tuple(token_bank.shape)}")
        if token_bank.shape[0] == 0:
            raise ValueError("token_bank is empty")
        if self.rank < 0:
            # Vt[:negative] would silently drop the *trailing* components.
            raise ValueError(f"rank must be >= 0, got {self.rank}")
        if self.reference_size < 1:
            raise ValueError(
                f"reference_size must be >= 1, got {self.reference_size}")

        with subspace_construction_guard():  # no label/mask may be read in here
            X = token_bank.float()
            Xs = self._subsample(X)
            self.mean_ = Xs.mean(dim=0, keepdim=True)
            Xc = Xs - self.mean_
            # PCA via SVD: rows of Vt are principal directions in order of
            # decreasing singular value; the top-`rank` span U_norm.
            _, _, Vt = torch.linalg.svd(Xc, full_matrices=False)
            U = Vt[: self.rank].T.contiguous()  # (d, min(rank, n_sub, d))
            self.U_norm_ = U
            d = X.shape[1]
            # Identity must share U's device/dtype: a bare torch.eye(d) is a
            # CPU tensor in the default dtype and breaks for CUDA-resident U.
            eye = torch.eye(d, device=U.device, dtype=U.dtype)
            self.P_L_ = eye - U @ U.T  # residual projector
            res = self._residual(Xs)
            self.tau_ = float(torch.quantile(res, self.tau_quantile))
        return self

    def _subsample(self, X: torch.Tensor) -> torch.Tensor:
        """Return at most ``reference_size`` rows of ``X``, drawn without replacement."""
        n = X.shape[0]
        if n <= self.reference_size:
            return X
        rng = np.random.default_rng(self.seed)
        idx = torch.from_numpy(rng.choice(n, self.reference_size, replace=False))
        return X[idx.to(X.device)]

    def _check_fitted(self, *names: str) -> None:
        """Raise ``NotFittedError`` unless every named fitted attribute is set."""
        if any(getattr(self, name, None) is None for name in names):
            raise self.NotFittedError("fit() first")

    def _residual(self, Z: torch.Tensor, device=None) -> torch.Tensor:
        """Per-row norm of ``(Z - mean_)`` outside span(``U_norm_``), computed on ``device``.

        ``device`` defaults to the device of the fitted statistics.
        """
        device = self.mean_.device if device is None else device
        U = self.U_norm_.to(device)
        Zc = Z.float().to(device) - self.mean_.to(device)
        # (Zc @ U) @ U.T keeps the intermediate at (n, k) rather than (d, d).
        return (Zc - Zc @ U @ U.T).norm(dim=1)

    def membership_score(self, Z: torch.Tensor) -> torch.Tensor:
        """Per-token lesion score = normal-manifold residual norm (higher => more lesion-like).

        Rows of ``Z`` are centered by ``mean_`` before taking the residual;
        the scores are returned on the CPU.

        Raises:
            NotFittedError: if ``fit`` has not been called.
        """
        self._check_fitted("U_norm_", "mean_")
        return self._residual(Z).cpu()

    def membership_score_torch(self, Z: torch.Tensor, device=None) -> torch.Tensor:
        """GPU/torch residual scoring for large eval sets.

        The computation runs on ``device`` (default: ``Z``'s device); the
        scores are returned on the CPU.

        Raises:
            NotFittedError: if ``fit`` has not been called.
        """
        self._check_fitted("U_norm_", "mean_")
        return self._residual(Z, device=Z.device if device is None else device).cpu()

    def project(self, Z: torch.Tensor) -> torch.Tensor:
        """Apply the residual projector ``I - U U^T`` to each row of ``Z``.

        Unlike the scoring methods, ``Z`` is NOT centered by ``mean_`` first,
        so ``project(Z).norm(dim=1)`` generally differs from
        ``membership_score(Z)``. The result is on ``Z``'s device.

        Raises:
            NotFittedError: if ``fit`` has not been called.
        """
        self._check_fitted("P_L_")
        return Z.float() @ self.P_L_.T.to(Z.device)

Xet Storage Details

Size:
3.56 kB
·
Xet hash:
74ecc0d3ca022051a40ccbc47de21b90e3422610200ecc59a0c1697077837d43

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.