Chucks90's picture
download
raw
4.45 kB
"""Construction A — density-sparse lesion subspace (DACS-style), label-free.
Formalization §2: lesions are rare, so lesion-bearing tokens occupy locally sparse regions
of feature space relative to abundant normal-tissue tokens. Estimate token density rho(z_i)
via k-NN distance over a held-out CT token bank; the candidate lesion subspace is spanned by
the low-density tokens:
L(x) = span{ z_i : rho(z_i) <= q_alpha }
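with q_alpha the alpha-quantile of density (alpha ~ 0.1). In the implementation density is proxied
inversely by the mean distance to the k nearest bank tokens, so the threshold is the
(1 - alpha)-quantile of that distance.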
membership_score(z) = mean k-NN distance to the bank (LOW density => HIGH lesion score),
which is what Gate 1 correlates against held-out lesion masks. project(Z) = P_L Z (applied to row
vectors as Z @ P_L^T) where P_L is the orthogonal projector onto the top-`rank` principal directions
of the low-density bank tokens. Strictly label-free: only pixel-derived token features + geometry are used.
"""
from __future__ import annotations
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from data.leak_guard import subspace_construction_guard
class DensitySubspace:
    """Label-free, density-sparse candidate lesion subspace (Construction A).

    ``fit`` sub-samples a reference bank of token features and indexes it for
    k-NN search. It uses the mean distance to the ``k`` nearest bank neighbours
    as an inverse density proxy and keeps the sparsest ``alpha`` fraction of
    bank tokens. Their top principal directions (at most ``rank``) span the
    candidate subspace.

    Args:
        alpha: fraction of bank tokens (by distance quantile) treated as
            low-density lesion candidates.
        k: number of nearest neighbours averaged in the density proxy.
        rank: maximum number of principal directions kept in the projector.
        reference_size: banks larger than this are randomly sub-sampled
            (without replacement) down to this many tokens.
        seed: RNG seed for that sub-sampling.
    """

    def __init__(self, alpha: float = 0.1, k: int = 10, rank: int = 64,
                 reference_size: int = 100_000, seed: int = 0):
        self.alpha = alpha
        self.k = k
        self.rank = rank
        self.reference_size = reference_size
        self.seed = seed
        # Fitted state, populated by fit().
        self.reference_: np.ndarray | None = None  # sub-sampled bank (float32 rows)
        self.nn_: NearestNeighbors | None = None  # k-NN index over reference_ (k + 1 neighbours)
        self.q_alpha_: float | None = None  # distance threshold for "low density"
        self.P_L_: torch.Tensor | None = None  # (d, d) projector

    def fit(self, token_bank: torch.Tensor) -> "DensitySubspace":
        """Build the reference bank, the density threshold and the projector ``P_L_``.

        Args:
            token_bank: token features, one token per row (assumed (N, d) -- TODO confirm).

        Returns:
            ``self``, fitted.

        Raises:
            ValueError: if the (sub-sampled) bank holds ``k`` tokens or fewer, so
                no token has ``k`` neighbours besides itself.
        """
        with subspace_construction_guard():  # no label/mask may be read in here
            # detach(): .numpy() refuses tensors that require grad.
            X = token_bank.detach().float().cpu().numpy()
            rng = np.random.default_rng(self.seed)
            if X.shape[0] > self.reference_size:
                idx = rng.choice(X.shape[0], self.reference_size, replace=False)
                ref = X[idx]
            else:
                ref = X
            if ref.shape[0] <= self.k:
                raise ValueError(
                    f"token bank must contain more than k={self.k} tokens, "
                    f"got {ref.shape[0]}"
                )
            self.reference_ = ref
            self.nn_ = NearestNeighbors(n_neighbors=self.k + 1).fit(ref)
            # Density proxy: mean distance to the k nearest bank neighbours.
            # Each bank token is its own nearest match at distance 0, so query
            # k + 1 neighbours and drop column 0.
            d, _ = self.nn_.kneighbors(ref)
            dens = d[:, 1:].mean(axis=1)
            # Low density <=> large distance: the alpha-quantile of density is
            # the (1 - alpha)-quantile of this distance proxy.
            self.q_alpha_ = float(np.quantile(dens, 1.0 - self.alpha))
            low_density = ref[dens >= self.q_alpha_]  # sparse (lesion-candidate) tokens
            self.P_L_ = self._principal_projector(low_density, self.rank)
            return self

    @staticmethod
    def _principal_projector(tokens: np.ndarray, rank: int) -> torch.Tensor:
        """Return the (d, d) orthogonal projector onto the leading principal
        directions (at most ``rank``) of the rows of ``tokens``."""
        Xc = tokens - tokens.mean(axis=0, keepdims=True)
        _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
        # Centring removes one degree of freedom. With few tokens (or rank-deficient
        # features), trailing singular values are ~0 and their right-singular
        # vectors are arbitrary, so they must not enter the span. The tolerance
        # follows np.linalg.matrix_rank. S is sorted descending, so the count of
        # significant values is a prefix length.
        tol = S.max(initial=0.0) * max(Xc.shape) * np.finfo(S.dtype).eps
        keep = min(rank, int((S > tol).sum()))
        basis = torch.from_numpy(Vt[:keep]).float()  # (r, d), orthonormal rows
        return basis.T @ basis  # (d, d) projector onto span

    def membership_score(self, Z: torch.Tensor) -> torch.Tensor:
        """Per-token lesion score = mean distance to the ``k`` nearest bank tokens.

        Sparser (lower-density) tokens score higher. The score uses the same
        distance proxy that ``fit`` thresholds with ``q_alpha_``, and matches
        ``membership_score_torch``.
        """
        assert self.nn_ is not None, "fit() first"
        # The index was built with k + 1 neighbours so fit() could drop each bank
        # token's self-match. Queries are scored against exactly k neighbours.
        d, _ = self.nn_.kneighbors(Z.detach().float().cpu().numpy(),
                                   n_neighbors=self.k)
        return torch.from_numpy(d.mean(axis=1)).float()

    def membership_score_torch(self, Z: torch.Tensor, device=None,
                               ref_chunk: int = 20000) -> torch.Tensor:
        """Torch equivalent of ``membership_score``: mean of the ``k`` smallest
        Euclidean distances to the reference bank.

        The bank is processed in chunks of ``ref_chunk`` rows with ``torch.cdist``,
        so only ``(len(Z), ref_chunk + k)`` distances are held at once. If the bank
        has fewer than ``k`` tokens, all of its distances are averaged. The result
        is returned on the CPU.
        """
        assert self.reference_ is not None, "fit() first"
        device = device or Z.device
        ref = torch.as_tensor(self.reference_, dtype=torch.float32, device=device)
        q = Z.float().to(device)
        # Running buffer of each query's smallest distances seen so far (<= k columns).
        best = torch.empty((q.shape[0], 0), dtype=q.dtype, device=device)
        for i in range(0, ref.shape[0], ref_chunk):
            dchunk = torch.cdist(q, ref[i:i + ref_chunk])  # (Nq, chunk)
            best = torch.cat([best, dchunk], dim=1)
            if best.shape[1] > self.k:
                best, _ = torch.topk(best, self.k, dim=1, largest=False)
        # `best` now holds exactly the min(k, n_ref) smallest distances per query.
        return best.mean(dim=1).cpu()

    def project(self, Z: torch.Tensor) -> torch.Tensor:
        """Project row-vector tokens onto the candidate subspace: ``Z @ P_L^T``."""
        assert self.P_L_ is not None, "fit() first"
        return Z.float() @ self.P_L_.T.to(Z.device)

Xet Storage Details

Size:
4.45 kB
·
Xet hash:
76b65d9dffcbe8f5f74449a0866b749473c90df58953e8c0709387a1292a6d56

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.