PIVOT / src /data /perturb_data.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
9.88 kB
"""perturbdata: in-memory representation of a preprocessed perturb-seq dataset.
holds the data model, perturbation vocabulary, and configurable control-matching
(Table 18).
"""
from __future__ import annotations
import os
import numpy as np
import pandas as pd
import scipy.sparse as sp
from src.utils.common import load_json
# names accepted by PerturbData.sample_controls(strategy=...)
MATCH_STRATEGIES = (
    "random",          # uniform draw (with replacement) from all control cells
    "batch",           # control drawn from the same batch, else any control
    "celltype",        # control drawn from the same cell type, else any control
    "batch_celltype",  # control sharing both batch and cell type, else any control
    "nearest",         # nearest control cell in the loaded embedding
    "ot",              # sinkhorn ot coupling control<->perturbed (distribution-preserving)
)
class PerturbData:
    """in-memory view of one preprocessed perturb-seq cache directory.

    loads the cell table, hvg expression matrix, pca basis, pseudobulk
    profiles and one cell embedding, derives the perturbation / gene /
    operation vocabularies, and exposes control-matching strategies (Table 18).
    """

    def __init__(self, cache_dir: str, embedding: str = "pca"):
        """load every cached artefact found under `cache_dir`.

        args:
            cache_dir: directory holding meta.json, obs.parquet, genes_hvg.txt,
                Xhvg.npz, pca_components.npy, pca_mean.npy, pseudobulk.npz and
                the embedding file.
            embedding: embedding name; "pca" reads pca_emb.npy, any other name
                reads emb_<name>.npy.

        raises:
            FileNotFoundError: if the requested embedding file does not exist.
        """
        self.dir = cache_dir
        self.meta = load_json(os.path.join(cache_dir, "meta.json"))
        self.obs = pd.read_parquet(os.path.join(cache_dir, "obs.parquet"))
        # close the handle and drop the trailing newline so no empty "gene"
        # is appended to the list
        with open(os.path.join(cache_dir, "genes_hvg.txt"), encoding="utf-8") as fh:
            self.genes = fh.read().rstrip("\r\n").splitlines()
        self.Xhvg = sp.load_npz(os.path.join(cache_dir, "Xhvg.npz")).tocsr()
        self.pca_components = np.load(os.path.join(cache_dir, "pca_components.npy"))
        self.pca_mean = np.load(os.path.join(cache_dir, "pca_mean.npy"))
        # NOTE(review): allow_pickle=True can execute pickled payloads; only
        # point this at cache directories you produced yourself.
        with np.load(os.path.join(cache_dir, "pseudobulk.npz"), allow_pickle=True) as pb:
            self.pb_labels = list(map(str, pb["labels"]))
            self.pb_vecs = pb["vecs"].astype(np.float32)
            self.control_mean = pb["control_mean"].astype(np.float32)

        self.embedding = embedding
        self.emb = self._load_embedding(embedding)
        self.d = self.emb.shape[1]

        self.sep = self.meta["sep"]
        self.control_label = self.meta["control_label"]
        self.operation = self.meta["operation"]

        self.is_control = self.obs["is_control"].values
        self.control_idx = np.where(self.is_control)[0]
        self.batch = self.obs["batch"].values
        self.celltype = self.obs["celltype"].values

        # perturbation -> positional row indices. grouping happens on a
        # positionally re-indexed view so the stored values are valid offsets
        # into self.emb / self.batch even if obs carries a non-range index.
        obs_by_position = self.obs.reset_index(drop=True)
        self.pert_to_idx: dict[str, np.ndarray] = {
            p: sub.index.values
            for p, sub in obs_by_position.groupby("perturbation")
            if p != self.control_label
        }
        self.perturbations = sorted(self.pert_to_idx.keys())

        # vocabulary: genes and operations
        self.genes_vocab = sorted({g for p in self.perturbations for g in self.parse(p)})
        self.gene_to_id = {g: i for i, g in enumerate(self.genes_vocab)}
        # operations: one modality per dataset, plus a 'control'/none slot id 0
        self.op_vocab = ["none", self.operation]
        self.op_to_id = {o: i for i, o in enumerate(self.op_vocab)}

        self.singles = [p for p in self.perturbations if len(self.parse(p)) == 1]
        self.combos = [p for p in self.perturbations if len(self.parse(p)) >= 2]

        self._pb_index = {p: i for i, p in enumerate(self.pb_labels)}
        self._nn_control_cache: dict[str, np.ndarray] = {}
        # fitted nearest-neighbour index over control embeddings (built lazily)
        self._nn_index = None

    # ---- embeddings ----
    def _load_embedding(self, embedding: str) -> np.ndarray:
        """read the named embedding from the cache directory as float32.

        raises:
            FileNotFoundError: if the embedding file is absent.
        """
        path = os.path.join(self.dir, f"emb_{embedding}.npy")
        if embedding == "pca":
            path = os.path.join(self.dir, "pca_emb.npy")
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"embedding '{embedding}' not found ({path}); build it first"
            )
        return np.load(path).astype(np.float32)

    def decode_to_genes(self, emb: np.ndarray) -> np.ndarray:
        """decode embedding(s) back to hvg gene-space (only exact for pca).

        raises:
            NotImplementedError: if the loaded embedding is not "pca".
        """
        if self.embedding != "pca":
            raise NotImplementedError(
                f"gene-space decode only defined for PCA, not '{self.embedding}'"
            )
        return emb @ self.pca_components + self.pca_mean

    # ---- perturbation parsing / encoding ----
    def parse(self, label: str) -> list[str]:
        """split a perturbation label into its gene names.

        the control label yields an empty list; empty tokens and control
        tokens inside a combined label are dropped.
        """
        if str(label) == self.control_label:
            return []
        return [g for g in str(label).split(self.sep) if g and g != self.control_label]

    def pert_gene_op_ids(self, label: str):
        """return (gene_ids, op_ids) arrays for a perturbation label.

        genes missing from the vocabulary are skipped; every remaining gene
        gets the dataset's single operation id.
        """
        genes = self.parse(label)
        gids = np.array([self.gene_to_id[g] for g in genes if g in self.gene_to_id], dtype=np.int64)
        oids = np.full(len(gids), self.op_to_id[self.operation], dtype=np.int64)
        return gids, oids

    # ---- pseudobulk / effects (gene space, hvg-log) ----
    def effect_vector(self, label: str) -> np.ndarray:
        """true perturbation effect = mean(perturbed) - mean(control), gene space.

        raises:
            KeyError: if `label` has no pseudobulk profile.
        """
        return self.pb_vecs[self._pb_index[label]] - self.control_mean

    def all_effects(self) -> tuple[list[str], np.ndarray]:
        """return (pseudobulk labels, effect matrix) with one row per label."""
        return self.pb_labels, self.pb_vecs - self.control_mean[None, :]

    # ---- control matching (Table 18) ----
    def sample_controls(self, target_idx: np.ndarray, strategy: str, rng: np.random.Generator):
        """for each perturbed cell in target_idx return a matched control row index.

        args:
            target_idx: row indices of the perturbed cells to match.
            strategy: one of MATCH_STRATEGIES.
            rng: numpy generator used for every random draw.

        raises:
            ValueError: if `strategy` is not in MATCH_STRATEGIES.
        """
        if strategy not in MATCH_STRATEGIES:
            raise ValueError(f"unknown matching strategy {strategy}")
        cidx = self.control_idx
        if strategy == "random":
            return rng.choice(cidx, size=len(target_idx), replace=True)
        if strategy == "nearest":
            return self._nearest_controls(target_idx)
        if strategy == "ot":
            if not hasattr(self, "_ot_map"):
                self.precompute_ot_matching()
            # the random fallback is the dict.get default, so it is drawn for
            # every cell (hit or miss); kept as-is so the rng stream is unchanged
            return np.array(
                [self._ot_map.get(int(i), cidx[rng.integers(len(cidx))]) for i in target_idx],
                dtype=np.int64,
            )

        # bucketed matching by batch / celltype / both
        def key(i):
            if strategy == "batch":
                return self.batch[i]
            if strategy == "celltype":
                return self.celltype[i]
            return (self.batch[i], self.celltype[i])

        buckets: dict = {}
        for i in cidx:
            buckets.setdefault(key(i), []).append(i)
        buckets = {k: np.asarray(v) for k, v in buckets.items()}

        out = np.empty(len(target_idx), dtype=np.int64)
        for j, i in enumerate(target_idx):
            pool = buckets.get(key(i))
            if pool is None or len(pool) == 0:
                pool = cidx  # fall back to any control
            out[j] = pool[rng.integers(len(pool))]
        return out

    def _nearest_controls(self, target_idx: np.ndarray) -> np.ndarray:
        """return, per target cell, the control row closest in embedding space."""
        from sklearn.neighbors import NearestNeighbors

        # fit once and reuse: the control embeddings do not change between calls
        nn = getattr(self, "_nn_index", None)
        if nn is None:
            nn = NearestNeighbors(n_neighbors=1).fit(self.emb[self.control_idx])
            self._nn_index = nn
        _, j = nn.kneighbors(self.emb[target_idx])
        return self.control_idx[j.ravel()]

    def precompute_ot_matching(self, max_ctrl: int = 800, max_pert: int = 1200,
                               eps: float = 0.05, iters: int = 150, seed: int = 0):
        """for each perturbation, couple its cells to control cells via entropic ot
        (sinkhorn) on embedding l2 cost, and assign each perturbed cell a control by
        sampling its coupling row. distribution-preserving alternative to random
        matching (cf. cellot / ot-cfm). caches self._ot_map (perturbed idx -> control idx).

        args:
            max_ctrl: at most this many control cells are used (subsampled once).
            max_pert: at most this many cells per perturbation are coupled.
            eps: entropic regularisation applied to the median-normalised cost.
            iters: number of sinkhorn iterations.
            seed: seeds both the numpy subsampling and the torch sampling of
                coupling rows, so a given seed reproduces the same map.

        returns:
            the cached dict mapping perturbed row index -> control row index.
        """
        import torch

        rng = np.random.default_rng(seed)
        dev = "cuda" if torch.cuda.is_available() else "cpu"
        # dedicated generator so `seed` also controls the multinomial draws
        gen = torch.Generator(device=dev)
        gen.manual_seed(seed)

        cidx = self.control_idx
        csamp = cidx if len(cidx) <= max_ctrl else cidx[rng.choice(len(cidx), max_ctrl, replace=False)]
        C = torch.as_tensor(self.emb[csamp], dtype=torch.float32, device=dev)
        self._ot_map = {}
        for idx in self.pert_to_idx.values():
            if len(idx) == 0:
                continue  # nothing to couple; avoids 1/0 in the marginals
            t_idx = idx if len(idx) <= max_pert else idx[rng.choice(len(idx), max_pert, replace=False)]
            T = torch.as_tensor(self.emb[t_idx], dtype=torch.float32, device=dev)
            cost = torch.cdist(T, C).pow(2)
            # normalise by the median so eps acts on a unit-scale cost
            cost = cost / (cost.median() + 1e-8)
            K = torch.exp(-cost / eps)
            n, m = K.shape
            u = torch.ones(n, device=dev) / n
            v = torch.ones(m, device=dev) / m
            a = torch.full((n,), 1.0 / n, device=dev)
            b = torch.full((m,), 1.0 / m, device=dev)
            for _ in range(iters):
                u = a / (K @ v + 1e-8)
                v = b / (K.t() @ u + 1e-8)
            P = (u.unsqueeze(1) * K) * v.unsqueeze(0)  # coupling (n, m)
            row_mass = P.sum(1)
            # exp(-cost/eps) underflows to an all-zero row for cells far from
            # every sampled control; multinomial would reject such rows
            dead_rows = ((row_mass <= 0) | ~torch.isfinite(row_mass)).nonzero(as_tuple=True)[0]
            P = P / (row_mass.unsqueeze(1) + 1e-12)
            if dead_rows.numel() > 0:
                # deterministic fallback: put all mass on the nearest control
                P[dead_rows] = 0.0
                P[dead_rows, cost[dead_rows].argmin(dim=1)] = 1.0
            # sample a control per perturbed cell from its coupling row
            choice = torch.multinomial(P, 1, generator=gen).squeeze(1).cpu().numpy()
            for c_local, cell in zip(choice, t_idx):
                self._ot_map[int(cell)] = int(csamp[c_local])
        return self._ot_map

    # ---- functional clusters (for pathway/functional recovery) ----
    def functional_clusters(self, n_clusters: int = 15, seed: int = 0) -> dict[str, int]:
        """cluster single-gene perturbations by effect-vector correlation.

        data-driven proxy for 'same pathway': perturbations with similar
        transcriptional effects get grouped. used for functional top-k and
        pathway-ndcg (Table 6).

        args:
            n_clusters: requested number of clusters (capped at the number of
                single-gene perturbations).
            seed: accepted for interface compatibility; not used by the
                clustering below.

        returns:
            dict mapping gene name -> cluster id.
        """
        labels = self.singles
        if not labels:
            return {}
        if len(labels) == 1:
            # a single sample cannot be clustered; it forms cluster 0
            return {self.parse(labels[0])[0]: 0}

        from sklearn.cluster import AgglomerativeClustering

        E = np.stack([self.effect_vector(p) for p in labels])
        # correlation distance
        En = E - E.mean(1, keepdims=True)
        En = En / (np.linalg.norm(En, axis=1, keepdims=True) + 1e-8)
        sim = np.clip(En @ En.T, -1, 1)
        dist = 1 - sim
        k = min(n_clusters, len(labels))
        cl = AgglomerativeClustering(n_clusters=k, metric="precomputed", linkage="average")
        ids = cl.fit_predict(dist)
        # map by single-gene name
        return {self.parse(p)[0]: int(c) for p, c in zip(labels, ids)}
def load_dataset(name: str, embedding: str = "pca", root: str = "data/processed") -> PerturbData:
    """build a PerturbData from the cache directory `<root>/<name>`."""
    cache_dir = os.path.join(root, name)
    dataset = PerturbData(cache_dir, embedding=embedding)
    return dataset