"""perturbdata: in-memory representation of a preprocessed perturb-seq dataset. holds the data model, perturbation vocabulary, and configurable control-matching (Table 18). """ from __future__ import annotations import os import numpy as np import pandas as pd import scipy.sparse as sp from src.utils.common import load_json MATCH_STRATEGIES = ( "random", "batch", "celltype", "batch_celltype", "nearest", "ot", # sinkhorn ot coupling control<->perturbed (distribution-preserving) ) class PerturbData: def __init__(self, cache_dir: str, embedding: str = "pca"): self.dir = cache_dir self.meta = load_json(os.path.join(cache_dir, "meta.json")) self.obs = pd.read_parquet(os.path.join(cache_dir, "obs.parquet")) self.genes = open(os.path.join(cache_dir, "genes_hvg.txt")).read().split("\n") self.Xhvg = sp.load_npz(os.path.join(cache_dir, "Xhvg.npz")).tocsr() self.pca_components = np.load(os.path.join(cache_dir, "pca_components.npy")) self.pca_mean = np.load(os.path.join(cache_dir, "pca_mean.npy")) pb = np.load(os.path.join(cache_dir, "pseudobulk.npz"), allow_pickle=True) self.pb_labels = list(map(str, pb["labels"])) self.pb_vecs = pb["vecs"].astype(np.float32) self.control_mean = pb["control_mean"].astype(np.float32) self.embedding = embedding self.emb = self._load_embedding(embedding) self.d = self.emb.shape[1] self.sep = self.meta["sep"] self.control_label = self.meta["control_label"] self.operation = self.meta["operation"] self.is_control = self.obs["is_control"].values self.control_idx = np.where(self.is_control)[0] self.batch = self.obs["batch"].values self.celltype = self.obs["celltype"].values # perturbation -> row indices self.pert_to_idx: dict[str, np.ndarray] = { p: sub.index.values for p, sub in self.obs.groupby("perturbation") if p != self.control_label } self.perturbations = sorted(self.pert_to_idx.keys()) # vocabulary: genes and operations self.genes_vocab = sorted({g for p in self.perturbations for g in self.parse(p)}) self.gene_to_id = {g: i for i, g in enumerate(self.genes_vocab)} # operations: one modality per dataset, plus a 'control'/none slot id 0 self.op_vocab = ["none", self.operation] self.op_to_id = {o: i for i, o in enumerate(self.op_vocab)} self.singles = [p for p in self.perturbations if len(self.parse(p)) == 1] self.combos = [p for p in self.perturbations if len(self.parse(p)) >= 2] self._pb_index = {p: i for i, p in enumerate(self.pb_labels)} self._nn_control_cache: dict[str, np.ndarray] = {} # ---- embeddings ---- def _load_embedding(self, embedding: str) -> np.ndarray: path = os.path.join(self.dir, f"emb_{embedding}.npy") if embedding == "pca": path = os.path.join(self.dir, "pca_emb.npy") if not os.path.exists(path): raise FileNotFoundError( f"embedding '{embedding}' not found ({path}); build it first" ) return np.load(path).astype(np.float32) def decode_to_genes(self, emb: np.ndarray) -> np.ndarray: """decode embedding(s) back to hvg gene-space (only exact for pca).""" if self.embedding != "pca": raise NotImplementedError( f"gene-space decode only defined for PCA, not '{self.embedding}'" ) return emb @ self.pca_components + self.pca_mean # ---- perturbation parsing / encoding ---- def parse(self, label: str) -> list[str]: if str(label) == self.control_label: return [] return [g for g in str(label).split(self.sep) if g and g != self.control_label] def pert_gene_op_ids(self, label: str): """return (gene_ids, op_ids) arrays for a perturbation label.""" genes = self.parse(label) gids = np.array([self.gene_to_id[g] for g in genes if g in self.gene_to_id], dtype=np.int64) oids = np.full(len(gids), self.op_to_id[self.operation], dtype=np.int64) return gids, oids # ---- pseudobulk / effects (gene space, hvg-log) ---- def effect_vector(self, label: str) -> np.ndarray: """true perturbation effect = mean(perturbed) - mean(control), gene space.""" return self.pb_vecs[self._pb_index[label]] - self.control_mean def all_effects(self) -> tuple[list[str], np.ndarray]: return self.pb_labels, self.pb_vecs - self.control_mean[None, :] # ---- control matching (Table 18) ---- def sample_controls(self, target_idx: np.ndarray, strategy: str, rng: np.random.Generator): """for each perturbed cell in target_idx return a matched control row index.""" if strategy not in MATCH_STRATEGIES: raise ValueError(f"unknown matching strategy {strategy}") cidx = self.control_idx if strategy == "random": return rng.choice(cidx, size=len(target_idx), replace=True) if strategy == "nearest": return self._nearest_controls(target_idx) if strategy == "ot": if not hasattr(self, "_ot_map"): self.precompute_ot_matching() return np.array([self._ot_map.get(int(i), self.control_idx[rng.integers(len(self.control_idx))]) for i in target_idx], dtype=np.int64) # bucketed matching by batch / celltype / both def key(i): if strategy == "batch": return self.batch[i] if strategy == "celltype": return self.celltype[i] return (self.batch[i], self.celltype[i]) buckets: dict = {} for i in cidx: buckets.setdefault(key(i), []).append(i) buckets = {k: np.asarray(v) for k, v in buckets.items()} out = np.empty(len(target_idx), dtype=np.int64) for j, i in enumerate(target_idx): pool = buckets.get(key(i)) if pool is None or len(pool) == 0: pool = cidx # fall back to any control out[j] = pool[rng.integers(len(pool))] return out def _nearest_controls(self, target_idx: np.ndarray) -> np.ndarray: from sklearn.neighbors import NearestNeighbors nn = NearestNeighbors(n_neighbors=1).fit(self.emb[self.control_idx]) _, j = nn.kneighbors(self.emb[target_idx]) return self.control_idx[j.ravel()] def precompute_ot_matching(self, max_ctrl: int = 800, max_pert: int = 1200, eps: float = 0.05, iters: int = 150, seed: int = 0): """for each perturbation, couple its cells to control cells via entropic ot (sinkhorn) on embedding l2 cost, and assign each perturbed cell a control by sampling its coupling row. distribution-preserving alternative to random matching (cf. cellot / ot-cfm). caches self._ot_map (perturbed idx -> control idx).""" import torch rng = np.random.default_rng(seed) dev = "cuda" if torch.cuda.is_available() else "cpu" cidx = self.control_idx csamp = cidx if len(cidx) <= max_ctrl else cidx[rng.choice(len(cidx), max_ctrl, replace=False)] C = torch.as_tensor(self.emb[csamp], dtype=torch.float32, device=dev) self._ot_map = {} for p, idx in self.pert_to_idx.items(): t_idx = idx if len(idx) <= max_pert else idx[rng.choice(len(idx), max_pert, replace=False)] T = torch.as_tensor(self.emb[t_idx], dtype=torch.float32, device=dev) cost = torch.cdist(T, C).pow(2) cost = cost / (cost.median() + 1e-8) K = torch.exp(-cost / eps) n, m = K.shape u = torch.ones(n, device=dev) / n v = torch.ones(m, device=dev) / m a = torch.full((n,), 1.0 / n, device=dev) b = torch.full((m,), 1.0 / m, device=dev) for _ in range(iters): u = a / (K @ v + 1e-8) v = b / (K.t() @ u + 1e-8) P = (u.unsqueeze(1) * K) * v.unsqueeze(0) # coupling (n, m) P = P / (P.sum(1, keepdim=True) + 1e-12) # sample a control per perturbed cell from its coupling row choice = torch.multinomial(P, 1, generator=None).squeeze(1).cpu().numpy() for c_local, cell in zip(choice, t_idx): self._ot_map[int(cell)] = int(csamp[c_local]) return self._ot_map # ---- functional clusters (for pathway/functional recovery) ---- def functional_clusters(self, n_clusters: int = 15, seed: int = 0) -> dict[str, int]: """cluster single-gene perturbations by effect-vector correlation. data-driven proxy for 'same pathway': perturbations with similar transcriptional effects get grouped. used for functional top-k and pathway-ndcg (Table 6).""" from sklearn.cluster import AgglomerativeClustering labels = self.singles E = np.stack([self.effect_vector(p) for p in labels]) # correlation distance En = E - E.mean(1, keepdims=True) En = En / (np.linalg.norm(En, axis=1, keepdims=True) + 1e-8) sim = np.clip(En @ En.T, -1, 1) dist = 1 - sim k = min(n_clusters, len(labels)) cl = AgglomerativeClustering(n_clusters=k, metric="precomputed", linkage="average") ids = cl.fit_predict(dist) # map by single-gene name out = {} for p, c in zip(labels, ids): g = self.parse(p)[0] out[g] = int(c) return out def load_dataset(name: str, embedding: str = "pca", root: str = "data/processed") -> PerturbData: return PerturbData(os.path.join(root, name), embedding=embedding)