"""PerturbData: in-memory representation of a preprocessed Perturb-seq dataset.

Holds the data model, the perturbation vocabulary, and configurable
control-cell matching (Table 18).
"""
| from __future__ import annotations |
|
|
| import os |
|
|
| import numpy as np |
| import pandas as pd |
| import scipy.sparse as sp |
|
|
| from src.utils.common import load_json |
|
|
# Control-matching strategies accepted by PerturbData.sample_controls:
# "random" (unconditioned), covariate-bucketed ("batch", "celltype",
# "batch_celltype"), and embedding-based ("nearest", "ot").
MATCH_STRATEGIES: tuple[str, ...] = (
    "random", "batch", "celltype", "batch_celltype", "nearest", "ot",
)
|
|
|
|
class PerturbData:
    """In-memory view of one preprocessed perturb-seq cache directory.

    ``__init__`` reads ``meta.json``, ``obs.parquet``, ``genes_hvg.txt``,
    ``Xhvg.npz``, ``pca_components.npy``, ``pca_mean.npy``,
    ``pseudobulk.npz`` and the requested embedding file from ``cache_dir``,
    then builds the perturbation / gene / operation vocabularies.

    Every cell index this class produces (``control_idx``, the values of
    ``pert_to_idx``, the output of ``sample_controls``) is a *positional* row
    index into ``obs`` / ``Xhvg`` / ``emb``.
    """

    def __init__(self, cache_dir: str, embedding: str = "pca"):
        self.dir = cache_dir
        self.meta = load_json(os.path.join(cache_dir, "meta.json"))
        self.obs = pd.read_parquet(os.path.join(cache_dir, "obs.parquet"))
        # One gene name per line. splitlines() (unlike split("\n")) does not
        # yield a spurious trailing "" when the file ends with a newline, and
        # the `with` block closes the handle.
        with open(os.path.join(cache_dir, "genes_hvg.txt"), encoding="utf-8") as fh:
            self.genes = fh.read().splitlines()
        self.Xhvg = sp.load_npz(os.path.join(cache_dir, "Xhvg.npz")).tocsr()
        self.pca_components = np.load(os.path.join(cache_dir, "pca_components.npy"))
        self.pca_mean = np.load(os.path.join(cache_dir, "pca_mean.npy"))
        # An NpzFile keeps the archive open until closed, hence the context
        # manager. allow_pickle=True unpickles object arrays (presumably the
        # labels) -- only point this at trusted caches.
        with np.load(os.path.join(cache_dir, "pseudobulk.npz"), allow_pickle=True) as pb:
            self.pb_labels = list(map(str, pb["labels"]))
            self.pb_vecs = pb["vecs"].astype(np.float32)
            self.control_mean = pb["control_mean"].astype(np.float32)

        self.embedding = embedding
        self.emb = self._load_embedding(embedding)
        self.d = self.emb.shape[1]  # embedding dimensionality

        self.sep = self.meta["sep"]  # separator between gene names in a label
        self.control_label = self.meta["control_label"]  # label of control cells
        self.operation = self.meta["operation"]  # op tagged on every gene by pert_gene_op_ids

        # Per-cell covariates, row-aligned with obs.
        self.is_control = self.obs["is_control"].values
        self.control_idx = np.where(self.is_control)[0]
        self.batch = self.obs["batch"].values
        self.celltype = self.obs["celltype"].values

        # Perturbation label -> positional row indices of its cells. Grouping a
        # Series of positions by the label values (instead of reading
        # `sub.index`) keeps these positional even when obs is indexed by
        # something else (e.g. cell barcodes), so they agree with control_idx
        # and can index emb / Xhvg directly. Grouping by raw values also avoids
        # empty groups for unobserved categories of a categorical column.
        positions = pd.Series(np.arange(len(self.obs), dtype=np.int64))
        self.pert_to_idx: dict[str, np.ndarray] = {
            p: rows.to_numpy()
            for p, rows in positions.groupby(self.obs["perturbation"].to_numpy())
            if p != self.control_label
        }
        self.perturbations = sorted(self.pert_to_idx)

        # Gene vocabulary: every gene named in any perturbation label.
        self.genes_vocab = sorted({g for p in self.perturbations for g in self.parse(p)})
        self.gene_to_id = {g: i for i, g in enumerate(self.genes_vocab)}
        # Operation vocabulary: id 0 is "none" (presumably a padding / no-op
        # slot -- confirm against the model code), id 1 the dataset's operation.
        self.op_vocab = ["none", self.operation]
        self.op_to_id = {o: i for i, o in enumerate(self.op_vocab)}

        self.singles = [p for p in self.perturbations if len(self.parse(p)) == 1]
        self.combos = [p for p in self.perturbations if len(self.parse(p)) >= 2]

        self._pb_index = {p: i for i, p in enumerate(self.pb_labels)}
        # NOTE(review): not read by any method of this class.
        self._nn_control_cache: dict[str, np.ndarray] = {}

    def _load_embedding(self, embedding: str) -> np.ndarray:
        """Load the cell embedding as float32.

        The file is ``pca_emb.npy`` for "pca" and ``emb_<embedding>.npy``
        otherwise, both inside the cache directory.

        Raises:
            FileNotFoundError: if the embedding file does not exist.
        """
        path = os.path.join(self.dir, f"emb_{embedding}.npy")
        if embedding == "pca":
            path = os.path.join(self.dir, "pca_emb.npy")
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"embedding '{embedding}' not found ({path}); build it first"
            )
        return np.load(path).astype(np.float32)

    def decode_to_genes(self, emb: np.ndarray) -> np.ndarray:
        """Decode embedding(s) back to HVG gene space (only exact for PCA).

        Computes ``emb @ pca_components + pca_mean``; presumably
        ``pca_components`` is (d, n_hvg) -- confirm against the builder.

        Raises:
            NotImplementedError: if the loaded embedding is not "pca".
        """
        if self.embedding != "pca":
            raise NotImplementedError(
                f"gene-space decode only defined for PCA, not '{self.embedding}'"
            )
        return emb @ self.pca_components + self.pca_mean

    def parse(self, label: str) -> list[str]:
        """Split a perturbation label into gene names.

        The control label itself parses to ``[]``; empty tokens and
        control-label tokens inside a combo label are dropped.
        """
        if str(label) == self.control_label:
            return []
        return [g for g in str(label).split(self.sep) if g and g != self.control_label]

    def pert_gene_op_ids(self, label: str):
        """Return ``(gene_ids, op_ids)`` int64 arrays for a perturbation label.

        Genes missing from ``gene_to_id`` are silently dropped; every kept gene
        is tagged with the id of ``self.operation``.
        """
        genes = self.parse(label)
        gids = np.array([self.gene_to_id[g] for g in genes if g in self.gene_to_id], dtype=np.int64)
        oids = np.full(len(gids), self.op_to_id[self.operation], dtype=np.int64)
        return gids, oids

    def effect_vector(self, label: str) -> np.ndarray:
        """True perturbation effect = mean(perturbed) - mean(control), gene space.

        Raises:
            KeyError: if ``label`` has no pseudobulk row.
        """
        return self.pb_vecs[self._pb_index[label]] - self.control_mean

    def all_effects(self) -> tuple[list[str], np.ndarray]:
        """Return ``(pb_labels, effect matrix)``; row i is the effect of label i."""
        return self.pb_labels, self.pb_vecs - self.control_mean[None, :]

    def sample_controls(self, target_idx: np.ndarray, strategy: str, rng: np.random.Generator):
        """For each perturbed cell in ``target_idx`` return a matched control row index.

        Args:
            target_idx: positional indices of the perturbed cells.
            strategy: one of ``MATCH_STRATEGIES``:
                "random" -- uniform draw (with replacement) from all controls;
                "nearest" -- nearest control in embedding space;
                "ot" -- control assigned by ``precompute_ot_matching`` (run with
                its defaults on first use); cells absent from the OT map get a
                uniform random control;
                "batch" / "celltype" / "batch_celltype" -- uniform draw from
                controls sharing that key, falling back to all controls when
                no control shares it.
            rng: numpy generator used for every random draw.

        Returns:
            Integer array of control row indices, one per ``target_idx`` entry.

        Raises:
            ValueError: for an unknown ``strategy``.
        """
        if strategy not in MATCH_STRATEGIES:
            raise ValueError(f"unknown matching strategy {strategy}")
        cidx = self.control_idx
        if strategy == "random":
            return rng.choice(cidx, size=len(target_idx), replace=True)

        if strategy == "nearest":
            return self._nearest_controls(target_idx)

        if strategy == "ot":
            if not hasattr(self, "_ot_map"):
                self.precompute_ot_matching()
            out = np.empty(len(target_idx), dtype=np.int64)
            for j, i in enumerate(target_idx):
                matched = self._ot_map.get(int(i))
                # Draw the random fallback only on a miss; passing it as the
                # dict.get default would consume an rng draw for every cell.
                out[j] = cidx[rng.integers(len(cidx))] if matched is None else matched
            return out

        def key(i):
            if strategy == "batch":
                return self.batch[i]
            if strategy == "celltype":
                return self.celltype[i]
            return (self.batch[i], self.celltype[i])

        # Bucket control cells by their covariate key.
        buckets: dict = {}
        for i in cidx:
            buckets.setdefault(key(i), []).append(i)
        buckets = {k: np.asarray(v) for k, v in buckets.items()}
        out = np.empty(len(target_idx), dtype=np.int64)
        for j, i in enumerate(target_idx):
            pool = buckets.get(key(i))
            if pool is None or len(pool) == 0:
                pool = cidx  # no control shares this key: use all controls
            out[j] = pool[rng.integers(len(pool))]
        return out

    def _nearest_controls(self, target_idx: np.ndarray) -> np.ndarray:
        """Return, per target cell, the control row nearest to it in ``emb``.

        Fits a fresh sklearn ``NearestNeighbors`` index on the control
        embeddings on every call.
        """
        from sklearn.neighbors import NearestNeighbors

        nn = NearestNeighbors(n_neighbors=1).fit(self.emb[self.control_idx])
        _, j = nn.kneighbors(self.emb[target_idx])
        return self.control_idx[j.ravel()]

    def precompute_ot_matching(self, max_ctrl: int = 800, max_pert: int = 1200,
                               eps: float = 0.05, iters: int = 150, seed: int = 0):
        """Assign perturbed cells to control cells via entropic OT (Sinkhorn).

        Controls are subsampled once to at most ``max_ctrl`` cells; for each
        perturbation at most ``max_pert`` of its cells are coupled to them under
        a squared-L2 embedding cost normalised by its median, with entropic
        regularisation ``eps`` and ``iters`` Sinkhorn iterations. Each coupled
        cell then gets a control sampled from its row of the coupling. This is
        a distribution-preserving alternative to random matching (cf. CellOT /
        OT-CFM). Cells left out by subsampling are absent from the map.

        Both the numpy subsampling and the torch row sampling are seeded with
        ``seed``, so the result is reproducible for a fixed seed.

        Returns:
            ``self._ot_map`` (perturbed row index -> control row index), which
            is also cached on the instance.
        """
        import torch

        rng = np.random.default_rng(seed)
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        # Seeded generator so the coupling-row sampling is reproducible too
        # (torch.multinomial otherwise draws from the global torch RNG).
        gen = torch.Generator(device=dev)
        gen.manual_seed(seed)

        cidx = self.control_idx
        csamp = cidx if len(cidx) <= max_ctrl else cidx[rng.choice(len(cidx), max_ctrl, replace=False)]
        C = torch.as_tensor(self.emb[csamp], dtype=torch.float32, device=dev)
        self._ot_map = {}
        for idx in self.pert_to_idx.values():
            if len(idx) == 0:
                continue  # nothing to couple (median() of an empty cost would fail)
            t_idx = idx if len(idx) <= max_pert else idx[rng.choice(len(idx), max_pert, replace=False)]
            T = torch.as_tensor(self.emb[t_idx], dtype=torch.float32, device=dev)
            cost = torch.cdist(T, C).pow(2)
            cost = cost / (cost.median() + 1e-8)
            # Log-domain Sinkhorn with uniform marginals. exp(-cost / eps)
            # underflows to 0 in float32 once cost/eps exceeds ~100, which
            # gave all-zero coupling rows for cells far from every control
            # (rejected by torch.multinomial) and let additive guards distort
            # the scalings; logsumexp avoids both.
            log_K = -cost / eps
            n, m = log_K.shape
            log_a = torch.full((n,), -float(np.log(n)), device=dev)
            log_b = torch.full((m,), -float(np.log(m)), device=dev)
            log_v = log_b.clone()
            for _ in range(iters):
                log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
                log_v = log_b - torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)
            # Row-normalised coupling: u_i cancels, leaving softmax_j(log K_ij + log v_j).
            P = torch.softmax(log_K + log_v.unsqueeze(0), dim=1)

            choice = torch.multinomial(P, 1, generator=gen).squeeze(1).cpu().numpy()
            for c_local, cell in zip(choice, t_idx):
                self._ot_map[int(cell)] = int(csamp[c_local])
        return self._ot_map

    def functional_clusters(self, n_clusters: int = 15, seed: int = 0) -> dict[str, int]:
        """Cluster single-gene perturbations by effect-vector correlation.

        Data-driven proxy for 'same pathway': perturbations with similar
        transcriptional effects get grouped. Used for functional top-k and
        pathway-NDCG (Table 6). Distance is 1 - Pearson correlation of effect
        vectors, clustered with average-linkage agglomerative clustering into
        ``min(n_clusters, #singles)`` clusters. ``seed`` is currently unused
        (the clustering is deterministic).

        Returns:
            Mapping gene name -> cluster id. With fewer than two singles no
            clustering is possible, so every gene (if any) gets cluster 0.
        """
        labels = self.singles
        if len(labels) < 2:
            # np.stack rejects [] and AgglomerativeClustering needs >= 2 samples.
            return {self.parse(p)[0]: 0 for p in labels}

        from sklearn.cluster import AgglomerativeClustering

        E = np.stack([self.effect_vector(p) for p in labels])
        # Centre and L2-normalise each row so En @ En.T is Pearson correlation.
        En = E - E.mean(1, keepdims=True)
        En = En / (np.linalg.norm(En, axis=1, keepdims=True) + 1e-8)
        sim = np.clip(En @ En.T, -1, 1)
        dist = 1 - sim
        k = min(n_clusters, len(labels))
        cl = AgglomerativeClustering(n_clusters=k, metric="precomputed", linkage="average")
        ids = cl.fit_predict(dist)

        out = {}
        for p, c in zip(labels, ids):
            g = self.parse(p)[0]
            out[g] = int(c)
        return out
|
|
|
|
def load_dataset(name: str, embedding: str = "pca", root: str = "data/processed") -> PerturbData:
    """Load the preprocessed cache directory ``<root>/<name>`` as a PerturbData.

    Args:
        name: dataset sub-directory under ``root``.
        embedding: embedding to load, forwarded to ``PerturbData``.
        root: parent directory holding the processed caches.
    """
    cache_dir = os.path.join(root, name)
    return PerturbData(cache_dir, embedding=embedding)
|
|