"""Construction A — density-sparse lesion subspace (DACS-style), label-free. Formalization §2: lesions are rare, so lesion-bearing tokens occupy locally sparse regions of feature space relative to abundant normal-tissue tokens. Estimate token density rho(z_i) via k-NN distance over a held-out CT token bank; the candidate lesion subspace is spanned by the low-density tokens: L(x) = span{ z_i : rho(z_i) <= q_alpha } with q_alpha the alpha-quantile of density (alpha ~ 0.1). membership_score(z) = mean k-NN distance to the bank (LOW density => HIGH lesion score), which is what Gate 1 correlates against held-out lesion masks. project(Z) = P_L Z where P_L is the orthonormal projector onto the top-`rank` principal directions of the low-density bank tokens. Strictly label-free: only pixels-derived token features + geometry are used. """ from __future__ import annotations import numpy as np import torch from sklearn.neighbors import NearestNeighbors from data.leak_guard import subspace_construction_guard class DensitySubspace: def __init__(self, alpha: float = 0.1, k: int = 10, rank: int = 64, reference_size: int = 100_000, seed: int = 0): self.alpha = alpha self.k = k self.rank = rank self.reference_size = reference_size self.seed = seed self.reference_: np.ndarray | None = None self.nn_: NearestNeighbors | None = None self.q_alpha_: float | None = None self.P_L_: torch.Tensor | None = None # (d, d) projector def fit(self, token_bank: torch.Tensor) -> "DensitySubspace": with subspace_construction_guard(): # no label/mask may be read in here X = token_bank.float().cpu().numpy() rng = np.random.default_rng(self.seed) if X.shape[0] > self.reference_size: idx = rng.choice(X.shape[0], self.reference_size, replace=False) ref = X[idx] else: ref = X self.reference_ = ref self.nn_ = NearestNeighbors(n_neighbors=self.k + 1).fit(ref) # density proxy: mean distance to k nearest neighbors within the bank d, _ = self.nn_.kneighbors(ref) dens = d[:, 1:].mean(axis=1) # exclude self self.q_alpha_ = float(np.quantile(dens, 1.0 - self.alpha)) # high-dist threshold low_density = ref[dens >= self.q_alpha_] # sparse (lesion-candidate) tokens # principal directions spanning the low-density tokens -> L(x) Xc = low_density - low_density.mean(axis=0, keepdims=True) U, S, Vt = np.linalg.svd(Xc, full_matrices=False) basis = torch.from_numpy(Vt[: self.rank]).float() # (r, d) self.P_L_ = basis.T @ basis # (d, d) projector onto span return self def membership_score(self, Z: torch.Tensor) -> torch.Tensor: """Per-token lesion score = mean k-NN distance to the bank (sparser => higher).""" assert self.nn_ is not None, "fit() first" d, _ = self.nn_.kneighbors(Z.float().cpu().numpy()) return torch.from_numpy(d.mean(axis=1)).float() def membership_score_torch(self, Z: torch.Tensor, device=None, ref_chunk: int = 20000) -> torch.Tensor: """GPU/torch equivalent of membership_score: mean of k smallest distances to the reference bank, computed with chunked torch.cdist (fast on GPU for large eval sets).""" assert self.reference_ is not None, "fit() first" device = device or Z.device ref = torch.as_tensor(self.reference_, dtype=torch.float32, device=device) q = Z.float().to(device) out = torch.full((q.shape[0],), float("inf"), device=device) # accumulate the k smallest distances across reference chunks kth = torch.empty((q.shape[0], 0), device=device) for i in range(0, ref.shape[0], ref_chunk): dchunk = torch.cdist(q, ref[i:i + ref_chunk]) # (Nq, chunk) kth = torch.cat([kth, dchunk], dim=1) if kth.shape[1] > self.k: kth, _ = torch.topk(kth, self.k, dim=1, largest=False) kmin, _ = torch.topk(kth, min(self.k, kth.shape[1]), dim=1, largest=False) return kmin.mean(dim=1).cpu() def project(self, Z: torch.Tensor) -> torch.Tensor: assert self.P_L_ is not None, "fit() first" return Z.float() @ self.P_L_.T.to(Z.device)