| """Construction A — density-sparse lesion subspace (DACS-style), label-free. |
| |
| Formalization §2: lesions are rare, so lesion-bearing tokens occupy locally sparse regions |
of feature space relative to abundant normal-tissue tokens. Estimate token density rho(z_i)
inversely from the mean k-NN distance over a held-out CT token bank (larger distance => lower
density); the candidate lesion subspace is spanned by the low-density tokens:
| L(x) = span{ z_i : rho(z_i) <= q_alpha } |
| with q_alpha the alpha-quantile of density (alpha ~ 0.1). |
| |
| membership_score(z) = mean k-NN distance to the bank (LOW density => HIGH lesion score), |
which is what Gate 1 correlates against held-out lesion masks. project(Z) = P_L Z, where P_L
is the orthogonal projector onto the top-`rank` principal directions of the mean-centred
low-density bank tokens (the projector itself is applied to raw, uncentred Z). Strictly
label-free: only pixel-derived token features and geometry are used.
| """ |
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| from sklearn.neighbors import NearestNeighbors |
|
|
| from data.leak_guard import subspace_construction_guard |
|
|
|
|
class DensitySubspace:
    """Label-free candidate lesion subspace spanned by low-density (sparse) bank tokens.

    Density is measured inversely by the mean distance to the ``k`` nearest neighbours in a
    (possibly subsampled) reference bank: a larger mean k-NN distance means a sparser, i.e.
    lower-density, region and therefore a higher lesion score.

    Args:
        alpha: fraction of reference tokens treated as low-density (the sparsest ``alpha``
            fraction by mean k-NN distance).
        k: number of nearest neighbours used for the density estimate and the score.
        rank: maximum number of principal directions kept for the projector.
        reference_size: maximum number of bank tokens kept as the k-NN reference; larger
            banks are subsampled uniformly without replacement.
        seed: seed for the subsampling RNG.
    """

    def __init__(self, alpha: float = 0.1, k: int = 10, rank: int = 64,
                 reference_size: int = 100_000, seed: int = 0):
        self.alpha = alpha
        self.k = k
        self.rank = rank
        self.reference_size = reference_size
        self.seed = seed
        # Fitted state; every attribute stays None until fit() is called.
        # Float32 reference tokens the k-NN index is built over.
        self.reference_: np.ndarray | None = None
        # k-NN index over reference_, built with k + 1 neighbours (see fit()).
        self.nn_: NearestNeighbors | None = None
        # Mean-k-NN-distance threshold; reference tokens at/above it are "low density".
        self.q_alpha_: float | None = None
        # (D, D) projector onto the top-`rank` principal directions of low-density tokens.
        self.P_L_: torch.Tensor | None = None

    def fit(self, token_bank: torch.Tensor) -> "DensitySubspace":
        """Build the k-NN reference, the low-density threshold and the projector ``P_L``.

        Args:
            token_bank: 2-D (N, D) tensor of token features; it is detached, cast to
                float32 and copied to CPU.

        Returns:
            ``self``, for chaining.
        """
        with subspace_construction_guard():
            # detach(): Tensor.numpy() refuses tensors that require grad.
            X = token_bank.detach().float().cpu().numpy()
            rng = np.random.default_rng(self.seed)
            if X.shape[0] > self.reference_size:
                idx = rng.choice(X.shape[0], self.reference_size, replace=False)
                ref = X[idx]
            else:
                ref = X
            self.reference_ = ref
            # k + 1 neighbours: when the reference is queried with itself, the first hit is
            # the point itself (or an exact duplicate) at distance 0, and it is dropped below.
            self.nn_ = NearestNeighbors(n_neighbors=self.k + 1).fit(ref)

            d, _ = self.nn_.kneighbors(ref)
            # Mean distance to the k true neighbours: large => sparse => low density.
            sparsity = d[:, 1:].mean(axis=1)
            # rho <= alpha-quantile of density  <=>  sparsity >= (1 - alpha)-quantile.
            self.q_alpha_ = float(np.quantile(sparsity, 1.0 - self.alpha))
            low_density = ref[sparsity >= self.q_alpha_]

            Xc = low_density - low_density.mean(axis=0, keepdims=True)
            _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
            # Rows of Vt are orthonormal, so basis.T @ basis is an orthogonal projector.
            basis = torch.from_numpy(Vt[: self.rank]).float()
            self.P_L_ = basis.T @ basis
        return self

    def membership_score(self, Z: torch.Tensor) -> torch.Tensor:
        """Per-token lesion score = mean k-NN distance to the bank (sparser => higher).

        Exactly ``k`` neighbours are averaged, matching :meth:`membership_score_torch`.
        The index was built with ``k + 1`` neighbours only so that bank points could skip
        themselves inside :meth:`fit`; queries here keep every returned neighbour.

        Args:
            Z: 2-D (num_queries, D) tensor of token features.

        Returns:
            (num_queries,) float32 CPU tensor of scores.
        """
        assert self.nn_ is not None, "fit() first"
        queries = Z.detach().float().cpu().numpy()
        d, _ = self.nn_.kneighbors(queries, n_neighbors=self.k)
        return torch.from_numpy(d.mean(axis=1)).float()

    def membership_score_torch(self, Z: torch.Tensor, device=None,
                               ref_chunk: int = 20000) -> torch.Tensor:
        """GPU/torch equivalent of :meth:`membership_score`.

        Computes the mean of the ``k`` smallest Euclidean distances from each row of ``Z``
        to the reference bank. It streams over the bank in chunks of ``ref_chunk`` rows
        with ``torch.cdist``, so only a (num_queries, k + ref_chunk) distance buffer is
        held at a time.

        Args:
            Z: 2-D (num_queries, D) tensor of token features.
            device: device to compute on; defaults to ``Z.device``.
            ref_chunk: number of reference rows per ``torch.cdist`` call.

        Returns:
            (num_queries,) float32 CPU tensor of scores. No autograd graph is recorded.
        """
        assert self.reference_ is not None, "fit() first"
        device = device or Z.device
        # Scoring only: without no_grad, every cdist chunk would be kept alive for backward.
        with torch.no_grad():
            ref = torch.as_tensor(self.reference_, dtype=torch.float32, device=device)
            q = Z.float().to(device)
            # Running buffer holding each query's best-so-far (at most k) distances.
            kth = torch.empty((q.shape[0], 0), device=device)
            for i in range(0, ref.shape[0], ref_chunk):
                dchunk = torch.cdist(q, ref[i:i + ref_chunk])
                kth = torch.cat([kth, dchunk], dim=1)
                if kth.shape[1] > self.k:
                    kth, _ = torch.topk(kth, self.k, dim=1, largest=False)
            # The bank may hold fewer than k tokens in total; take what is available.
            kmin, _ = torch.topk(kth, min(self.k, kth.shape[1]), dim=1, largest=False)
            return kmin.mean(dim=1).cpu()

    def project(self, Z: torch.Tensor) -> torch.Tensor:
        """Project tokens onto the lesion subspace, row-wise ``P_L z`` (i.e. ``Z @ P_L^T``).

        ``Z`` is used as given (not mean-centred). The result is float32 on ``Z.device``.
        """
        assert self.P_L_ is not None, "fit() first"
        return Z.float() @ self.P_L_.T.to(Z.device)
|
|