Buckets:
| """Construction A — density-sparse lesion subspace (DACS-style), label-free. | |
| Formalization §2: lesions are rare, so lesion-bearing tokens occupy locally sparse regions | |
| of feature space relative to abundant normal-tissue tokens. Estimate token density rho(z_i) | |
| via k-NN distance over a held-out CT token bank; the candidate lesion subspace is spanned by | |
| the low-density tokens: | |
| L(x) = span{ z_i : rho(z_i) <= q_alpha } | |
| with q_alpha the alpha-quantile of density (alpha ~ 0.1). | |
| membership_score(z) = mean k-NN distance to the bank (LOW density => HIGH lesion score), | |
| which is what Gate 1 correlates against held-out lesion masks. project(Z) = P_L Z where P_L | |
| is the orthonormal projector onto the top-`rank` principal directions of the low-density | |
| bank tokens. Strictly label-free: only pixel-derived token features + geometry are used. | |
| """ | |
| from __future__ import annotations | |
| import numpy as np | |
| import torch | |
| from sklearn.neighbors import NearestNeighbors | |
| from data.leak_guard import subspace_construction_guard | |
class DensitySubspace:
    """Label-free lesion subspace built from the sparse (low-density) tail of a token bank.

    Density is proxied by the mean Euclidean distance to the ``k`` nearest bank tokens:
    a larger mean k-NN distance means a sparser neighbourhood, i.e. lower density.

    Args:
        alpha: approximate fraction of bank tokens (those with the largest mean k-NN
            distance) treated as lesion candidates when building the projector.
        k: number of nearest neighbours used in the density estimate and the score.
        rank: maximum dimension of the lesion subspace.
        reference_size: maximum number of bank tokens kept as the k-NN reference;
            larger banks are subsampled without replacement.
        seed: seed for that subsample.

    Fitted attributes (``None`` until :meth:`fit`):
        reference_: float32 reference tokens, (n_ref, d).
        nn_: k-NN index over ``reference_``.
        q_alpha_: (1 - alpha)-quantile of the in-bank mean k-NN distance.
        P_L_: (d, d) orthogonal projector onto the lesion subspace.
    """

    def __init__(self, alpha: float = 0.1, k: int = 10, rank: int = 64,
                 reference_size: int = 100_000, seed: int = 0):
        self.alpha = alpha
        self.k = k
        self.rank = rank
        self.reference_size = reference_size
        self.seed = seed
        self.reference_: np.ndarray | None = None
        self.nn_: NearestNeighbors | None = None
        self.q_alpha_: float | None = None
        self.P_L_: torch.Tensor | None = None  # (d, d) projector

    def fit(self, token_bank: torch.Tensor) -> "DensitySubspace":
        """Build the k-NN reference and the low-density projector from ``token_bank``.

        Args:
            token_bank: (N, d) label-free token features.

        Returns:
            ``self``, for chaining.

        Raises:
            ValueError: if the (possibly subsampled) bank has ``k`` or fewer tokens; every
                bank token needs ``k`` neighbours besides itself.
        """
        with subspace_construction_guard():  # no label/mask may be read in here
            # detach() so banks produced with autograd enabled can still be converted.
            X = token_bank.detach().float().cpu().numpy()
            rng = np.random.default_rng(self.seed)
            if X.shape[0] > self.reference_size:
                idx = rng.choice(X.shape[0], self.reference_size, replace=False)
                ref = X[idx]
            else:
                ref = X
            if ref.shape[0] <= self.k:
                raise ValueError(
                    f"token bank needs more than k={self.k} tokens, got {ref.shape[0]}"
                )
            self.reference_ = ref
            self.nn_ = NearestNeighbors(n_neighbors=self.k + 1).fit(ref)
            # In-bank density proxy. The nearest neighbour of a bank token is the token
            # itself (distance 0), so drop that first column and average the other k.
            dist, _ = self.nn_.kneighbors(ref)
            knn_dist = dist[:, 1:].mean(axis=1)
            # Large distance == low density: the low-density alpha tail is everything at
            # or above the (1 - alpha)-quantile of the distance.
            self.q_alpha_ = float(np.quantile(knn_dist, 1.0 - self.alpha))
            low_density = ref[knn_dist >= self.q_alpha_]  # sparse (lesion-candidate) tokens
            # Principal directions of the centred low-density tokens span L(x).
            # float64 keeps genuinely-zero singular values clearly separated from real ones.
            Xc = low_density.astype(np.float64)
            Xc -= Xc.mean(axis=0, keepdims=True)
            _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
            # Rows of Vt with (numerically) zero singular value are arbitrary completions of
            # the basis, not directions the data spans -- this happens whenever there are
            # no more low-density tokens than `rank`. Keep only the spanned directions
            # (same default tolerance as numpy.linalg.matrix_rank).
            tol = S.max(initial=0.0) * max(Xc.shape) * np.finfo(Xc.dtype).eps
            n_keep = min(self.rank, int(np.count_nonzero(S > tol)))
            basis = torch.from_numpy(Vt[:n_keep]).float()  # (r, d), r <= rank
            self.P_L_ = basis.T @ basis  # (d, d) orthogonal projector onto span
        return self

    def membership_score(self, Z: torch.Tensor) -> torch.Tensor:
        """Per-token lesion score = mean distance to the ``k`` nearest bank tokens.

        A sparser neighbourhood gives a higher score. Returns the same quantity as
        :meth:`membership_score_torch`. Meant for held-out tokens: a token that is itself
        in the reference bank has its zero self-distance counted among its k neighbours.

        Returns:
            (N,) float32 CPU tensor.
        """
        assert self.nn_ is not None, "fit() first"
        # The index was fitted with k + 1 neighbours (so fit() can drop self-matches);
        # ask for exactly k here so the score is a true k-NN mean.
        d, _ = self.nn_.kneighbors(Z.detach().float().cpu().numpy(), n_neighbors=self.k)
        return torch.from_numpy(d.mean(axis=1)).float()

    def membership_score_torch(self, Z: torch.Tensor, device=None,
                               ref_chunk: int = 20000) -> torch.Tensor:
        """Torch/GPU equivalent of :meth:`membership_score`.

        Computes the mean of the ``k`` smallest Euclidean distances from each row of ``Z``
        to the reference bank. The bank is processed in ``ref_chunk``-sized slices with
        ``torch.cdist``, so the full (N, n_ref) distance matrix is never built at once. If
        the bank has fewer than ``k`` tokens, the mean is over all of them.

        Args:
            Z: (N, d) query tokens.
            device: device to compute on; defaults to ``Z.device``.
            ref_chunk: number of reference tokens per ``cdist`` call.

        Returns:
            (N,) float32 CPU tensor of scores.
        """
        assert self.reference_ is not None, "fit() first"
        device = Z.device if device is None else device
        ref = torch.as_tensor(self.reference_, dtype=torch.float32, device=device)
        q = Z.float().to(device)
        # Running set of the k smallest distances per query, merged chunk by chunk.
        best = torch.empty((q.shape[0], 0), dtype=torch.float32, device=device)
        for start in range(0, ref.shape[0], ref_chunk):
            dchunk = torch.cdist(q, ref[start:start + ref_chunk])  # (N, chunk)
            best = torch.cat([best, dchunk], dim=1)
            if best.shape[1] > self.k:
                best, _ = torch.topk(best, self.k, dim=1, largest=False)
        # `best` now holds min(k, n_ref) distances per query; the mean needs no ordering.
        return best.mean(dim=1).cpu()

    def project(self, Z: torch.Tensor) -> torch.Tensor:
        """Project tokens onto the lesion subspace: ``Z @ P_L^T`` as float32 on ``Z.device``."""
        assert self.P_L_ is not None, "fit() first"
        return Z.float() @ self.P_L_.T.to(Z.device)
Xet Storage Details
- Size:
- 4.45 kB
- Xet hash:
- 76b65d9dffcbe8f5f74449a0866b749473c90df58953e8c0709387a1292a6d56
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.