File size: 8,904 Bytes
3b4941f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 | """preprocess a raw perturb-seq .h5ad (scperturb format) into a compact,
reusable cache.
does normalization, hvg selection, and the pca cell-state encoder (the default
invertible phi used for gene-space metrics).
outputs, under data/processed/<name>/:
meta.json dataset summary (counts, operation, control label, ...)
genes_hvg.txt hvg gene symbols (defines gene-space for de metrics)
Xhvg.npz scipy csr, log1p(cp10k) on hvgs (n_cells x n_hvg)
obs.parquet per-cell metadata + parsed perturbation
pca_emb.npy pca cell-state embedding (n_cells x d)
pca_components.npy pca basis (d x n_hvg)
pca_mean.npy feature means (n_hvg,)
pseudobulk.npz per-perturbation mean hvg-log vectors + control mean
the pca is fit on log1p(cp10k) hvg features (zero-centered, unscaled) so that
x_hat = emb @ components + mean maps embeddings back to a rank-d approximation of
hvg gene-space expression - this is what lets us evaluate de-gene correlation on
predictions made in embedding space.
"""
from __future__ import annotations
import argparse
import os
import sys
import numpy as np
import pandas as pd
import scipy.sparse as sp
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from src.utils.common import save_json, set_seed # noqa: E402
def parse_perturbation(label: str, control_label: str, sep: str = "_"):
    """return the genes targeted by a perturbation label.

    the label is stringified first. a label equal to `control_label` yields [];
    otherwise it is split on `sep`, and empty fragments (e.g. from doubled
    separators) and tokens equal to `control_label` are discarded.
    """
    text = str(label)
    if text == control_label:
        return []
    tokens = text.split(sep)
    return [tok for tok in tokens if tok not in ("", control_label)]
def _fit_centered_pca(Xhvg, n_pca: int, seed: int):
    """fit the invertible pca cell-state encoder on a sparse cells x hvg matrix.

    the matrix is densified and centered with its own column means before fitting,
    so emb @ components + mean maps embeddings back to (a rank-n_pca approximation
    of) the dense input. expects a float matrix (centering is done in place).

    returns:
        (emb, components, mean, evr): float32 (n_cells, n_pca) embedding, float32
        (n_pca, n_hvg) basis, float32 (n_hvg,) column means, and the summed
        explained-variance ratio as a python float.
    """
    from sklearn.decomposition import PCA

    Xd = Xhvg.toarray()
    mean = Xd.mean(axis=0).astype(np.float32)
    # center in place: numerically identical to `Xd - mean`, but avoids holding a
    # second dense n_cells x n_hvg copy in memory.
    Xd -= mean
    pca = PCA(n_components=n_pca, random_state=seed)
    emb = pca.fit_transform(Xd).astype(np.float32)
    components = pca.components_.astype(np.float32)  # (d, h)
    evr = float(pca.explained_variance_ratio_.sum())
    return emb, components, mean, evr


def _pseudobulk(Xhvg, obs: pd.DataFrame):
    """mean hvg vector for every non-control perturbation label.

    control cells are identified by obs["is_control"] (the same definition used
    for the control mean), not by string equality with the control label.
    rows of `obs` must align positionally with rows of `Xhvg`.

    returns:
        (labels, vecs): labels in sorted (groupby) order and a float32 array of
        shape (len(labels), Xhvg.shape[1]); (0, n_hvg) when there are none.
    """
    # positional index so group indices can be used directly as matrix rows
    pert_obs = obs.reset_index(drop=True)
    pert_obs = pert_obs.loc[~pert_obs["is_control"].to_numpy()]
    labels, vecs = [], []
    for p, sub in pert_obs.groupby("perturbation"):
        labels.append(p)
        vecs.append(np.asarray(Xhvg[sub.index.values].mean(axis=0)).ravel())
    vecs = np.asarray(vecs, dtype=np.float32).reshape(len(labels), Xhvg.shape[1])
    return labels, vecs


def preprocess(
    raw_path: str,
    out_dir: str,
    name: str,
    operation: str,
    control_label: str = "control",
    pert_col: str = "perturbation",
    batch_col: str | None = "gemgroup",
    celltype_col: str | None = "celltype",
    sep: str = "_",
    n_hvg: int = 2000,
    n_pca: int = 50,
    min_cells_per_pert: int = 20,
    max_cells: int | None = None,
    seed: int = 0,
):
    """preprocess one dataset and write its cache files into out_dir.

    operation is the crispr modality, e.g. 'activation' (crispra, norman) or
    'interference' (crispri, replogle); it is only recorded in meta.json.

    args:
        raw_path: path to the raw .h5ad file.
        out_dir: output directory (created if missing).
        name: dataset name used in log lines and meta.json.
        operation: crispr modality label stored in meta.json.
        control_label: perturbation label / token that marks control cells.
        pert_col: obs column holding perturbation labels.
        batch_col: obs column for batch; a single batch "0" is used when it is
            None or missing.
        celltype_col: obs column for cell type; `name` is used when it is None
            or missing.
        sep: separator between genes in combinatorial perturbation labels.
        n_hvg: number of highly variable genes to keep.
        n_pca: pca embedding dimension d.
        min_cells_per_pert: non-control perturbations with fewer cells are dropped.
        max_cells: if set, deterministically subsample to at most this many cells.
        seed: seed for set_seed, subsampling and pca.

    returns:
        the meta dict that is also written to meta.json.

    raises:
        KeyError: pert_col is not an obs column.
        ValueError: no cell parses as control under control_label.
    """
    import scanpy as sc

    set_seed(seed)
    os.makedirs(out_dir, exist_ok=True)
    print(f"[{name}] reading {raw_path}")
    adata = sc.read_h5ad(raw_path)

    # optional subsample for tractability (deterministic given seed)
    if max_cells is not None and adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False))
        adata = adata[idx].copy()
        print(f"[{name}] subsampled to {adata.n_obs} cells")

    # resolve columns with graceful fallback
    if pert_col not in adata.obs:
        raise KeyError(f"pert_col '{pert_col}' not in obs: {list(adata.obs.columns)}")
    if batch_col is not None and batch_col not in adata.obs:
        print(f"[{name}] batch_col '{batch_col}' missing -> single batch")
        batch_col = None
    if celltype_col is not None and celltype_col not in adata.obs:
        celltype_col = None

    # --- parse perturbations ---
    # done before normalization/hvg/pca (it only needs obs, whose rows those steps
    # do not change) so a wrong control_label fails fast instead of silently
    # producing a NaN control mean at the end.
    pert = adata.obs[pert_col].astype(str).values
    genes_per_cell = [parse_perturbation(p, control_label, sep) for p in pert]
    nperts = np.array([len(g) for g in genes_per_cell])
    is_control = nperts == 0
    if not is_control.any():
        raise ValueError(
            f"[{name}] no control cells: no '{pert_col}' label parses as control "
            f"for control_label={control_label!r}"
        )
    obs = pd.DataFrame(
        {
            "perturbation": pert,
            "n_pert_genes": nperts,
            "is_control": is_control,
            "pert_genes": [";".join(g) for g in genes_per_cell],
            "batch": (adata.obs[batch_col].astype(str).values if batch_col else "0"),
            "celltype": (
                adata.obs[celltype_col].astype(str).values if celltype_col else name
            ),
        }
    )

    # --- normalization: cp10k + log1p ---
    if not sp.issparse(adata.X):
        adata.X = sp.csr_matrix(adata.X)
    # guard: only normalize raw-looking counts (integer-valued with max > 30 in the
    # first 200 cells); data that already looks normalized is left as is.
    x0 = adata.X[:200].toarray()
    looks_counts = np.allclose(x0, np.round(x0)) and x0.max() > 30
    if looks_counts:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        print(f"[{name}] applied CP10k + log1p")
    else:
        print(f"[{name}] data appears pre-normalized (max={x0.max():.2f}); skipping norm")

    # --- hvg selection (on all cells) ---
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()
    genes = list(map(str, adata.var_names))
    Xhvg = adata.X.tocsr().astype(np.float32)
    # everything needed from adata (obs table, gene names, hvg matrix) is extracted;
    # release it before the dense pca step.
    del adata
    print(f"[{name}] HVG matrix: {Xhvg.shape}")

    # --- pca (centered, invertible) ---
    emb, components, mean, evr = _fit_centered_pca(Xhvg, n_pca, seed)
    print(f"[{name}] PCA d={n_pca} explains {evr:.3f} variance")

    # drop perturbations (non-control) with too few cells
    vc = obs.loc[~obs.is_control, "perturbation"].value_counts()
    keep_perts = set(vc[vc >= min_cells_per_pert].index)
    keep_mask = obs.is_control.values | obs.perturbation.isin(keep_perts).values
    if keep_mask.sum() < len(obs):
        obs = obs.loc[keep_mask].reset_index(drop=True)
        Xhvg = Xhvg[keep_mask]
        emb = emb[keep_mask]
        print(f"[{name}] dropped low-count perts -> {keep_mask.sum()} cells, "
              f"{len(keep_perts)} perturbations")

    # --- pseudobulk per perturbation (gene space, hvg-log) ---
    control_mean = np.asarray(Xhvg[obs.is_control.values].mean(axis=0)).ravel().astype(np.float32)
    pb_labels, pb_vecs = _pseudobulk(Xhvg, obs)

    # --- write cache ---
    sp.save_npz(os.path.join(out_dir, "Xhvg.npz"), Xhvg)
    np.save(os.path.join(out_dir, "pca_emb.npy"), emb)
    np.save(os.path.join(out_dir, "pca_components.npy"), components)
    np.save(os.path.join(out_dir, "pca_mean.npy"), mean)
    np.savez(
        os.path.join(out_dir, "pseudobulk.npz"),
        # explicit str dtype keeps an empty label array loadable without pickle
        labels=np.array(pb_labels, dtype=str),
        vecs=pb_vecs,
        control_mean=control_mean,
    )
    obs.to_parquet(os.path.join(out_dir, "obs.parquet"))
    with open(os.path.join(out_dir, "genes_hvg.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(genes))

    # parse each pseudobulk label once for the summary counts
    pb_genes = [parse_perturbation(p, control_label, sep) for p in pb_labels]
    n_singles = sum(len(g) == 1 for g in pb_genes)
    n_combos = sum(len(g) >= 2 for g in pb_genes)
    all_genes = {g for gs in pb_genes for g in gs}
    meta = {
        "name": name,
        "operation": operation,
        "control_label": control_label,
        "sep": sep,
        "n_cells": int(obs.shape[0]),
        "n_control": int(obs.is_control.sum()),
        "n_hvg": len(genes),
        "n_pca": n_pca,
        "pca_explained_var": evr,
        "n_perturbations": len(pb_labels),
        "n_singles": n_singles,
        "n_combos": n_combos,
        "n_unique_target_genes": len(all_genes),
        "n_batches": int(obs.batch.nunique()),
        "n_celltypes": int(obs.celltype.nunique()),
    }
    save_json(meta, os.path.join(out_dir, "meta.json"))
    print(f"[{name}] DONE -> {out_dir}")
    print(meta)
    return meta
# dataset registry: each entry holds the raw .h5ad path ("raw") plus the keyword
# arguments forwarded to preprocess(). the two replogle screens are capped at
# 120k cells (max_cells) for tractability.
DATASETS = {
    "norman": {
        "raw": "data/raw/NormanWeissman2019_filtered.h5ad",
        "name": "norman",
        "operation": "activation",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
    },
    "replogle_k562": {
        "raw": "data/raw/ReplogleWeissman2022_K562_essential.h5ad",
        "name": "replogle_k562",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
    "replogle_rpe1": {
        "raw": "data/raw/ReplogleWeissman2022_rpe1.h5ad",
        "name": "replogle_rpe1",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
}
if __name__ == "__main__":
    # cli entry point: preprocess one registered dataset into <out-root>/<name>/.
    # the optional overrides below are only applied when given, so running with
    # just the dataset name behaves exactly like the registry entry.
    ap = argparse.ArgumentParser()
    ap.add_argument("dataset", choices=list(DATASETS.keys()))
    ap.add_argument("--out-root", default="data/processed")
    ap.add_argument("--n-hvg", type=int, default=2000)
    ap.add_argument("--n-pca", type=int, default=50)
    ap.add_argument("--raw", default=None,
                    help="override the registered raw .h5ad path")
    ap.add_argument("--seed", type=int, default=None,
                    help="override the seed (preprocess default: 0)")
    ap.add_argument("--min-cells-per-pert", type=int, default=None,
                    help="override the minimum cells per perturbation (default: 20)")
    ap.add_argument("--max-cells", type=int, default=None,
                    help="override the registered cell cap; <= 0 disables subsampling")
    args = ap.parse_args()
    cfg = dict(DATASETS[args.dataset])
    raw = cfg.pop("raw")
    if args.raw is not None:
        raw = args.raw
    # overrides go through cfg so they never collide with registry keys as
    # duplicate keyword arguments
    if args.seed is not None:
        cfg["seed"] = args.seed
    if args.min_cells_per_pert is not None:
        cfg["min_cells_per_pert"] = args.min_cells_per_pert
    if args.max_cells is not None:
        cfg["max_cells"] = args.max_cells if args.max_cells > 0 else None
    out = os.path.join(args.out_root, cfg["name"])
    preprocess(raw, out, n_hvg=args.n_hvg, n_pca=args.n_pca, **cfg)
|