"""Preprocess a raw Perturb-seq .h5ad (scPerturb format) into a compact, reusable cache.

Does normalization, HVG selection, and the PCA cell-state encoder (the default
invertible phi used for gene-space metrics).

Outputs, under data/processed/<name>/:
  meta.json            dataset summary (counts, operation, control label, ...)
  genes_hvg.txt        HVG gene symbols (defines gene-space for DE metrics)
  Xhvg.npz             scipy CSR, log1p(CP10k) on HVGs (n_cells x n_hvg)
  obs.parquet          per-cell metadata + parsed perturbation
  pca_emb.npy          PCA cell-state embedding (n_cells x d)
  pca_components.npy   PCA basis (d x n_hvg)
  pca_mean.npy         feature means (n_hvg,)
  pseudobulk.npz       per-perturbation mean HVG-log vectors + control mean

The PCA is fit on log1p(CP10k) HVG features (zero-centered, unscaled) so that
    x_hat = emb @ components + mean
reconstructs gene-space expression - this is what lets us evaluate DE-gene
correlation on predictions made in embedding space.
"""

from __future__ import annotations

import argparse
import os
import sys

import numpy as np
import pandas as pd
import scipy.sparse as sp

# Make the repository root importable so `src.*` resolves when run as a script.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from src.utils.common import save_json, set_seed  # noqa: E402


def parse_perturbation(label: str, control_label: str, sep: str = "_") -> list[str]:
    """Return the list of perturbed genes for a perturbation label.

    Args:
        label: Raw perturbation label; it is converted with ``str()`` first.
        control_label: Label (and token) that denotes the control condition.
        sep: Separator between gene tokens in a combinatorial label.

    Returns:
        The gene tokens of ``label`` with empty tokens and control tokens
        removed. ``[]`` when the label is the control label itself or contains
        nothing but control/empty tokens.
    """
    text = str(label)
    # Checked before splitting so a control label that itself contains `sep`
    # is still recognised as control.
    if text == control_label:
        return []
    return [tok for tok in text.split(sep) if tok and tok != control_label]


def _build_obs(
    cell_meta: pd.DataFrame,
    name: str,
    control_label: str,
    pert_col: str,
    batch_col: str | None,
    celltype_col: str | None,
    sep: str,
) -> pd.DataFrame:
    """Build the per-cell table (perturbation label, parsed genes, batch, celltype)."""
    labels = cell_meta[pert_col].astype(str).values
    targets = [parse_perturbation(p, control_label, sep) for p in labels]
    n_targets = np.array([len(g) for g in targets])
    return pd.DataFrame(
        {
            "perturbation": labels,
            "n_pert_genes": n_targets,
            # A cell is control iff its label parses to zero target genes.
            "is_control": n_targets == 0,
            "pert_genes": [";".join(g) for g in targets],
            # Scalars broadcast to every row when the column is unavailable.
            "batch": cell_meta[batch_col].astype(str).values if batch_col else "0",
            "celltype": (
                cell_meta[celltype_col].astype(str).values if celltype_col else name
            ),
        }
    )


def _fit_centered_pca(Xhvg, n_pca: int, seed: int):
    """Fit PCA on mean-centered (unscaled) features; return (emb, components, mean, evr)."""
    from sklearn.decomposition import PCA

    dense = Xhvg.toarray()
    # float64 accumulator: a float32 column mean over many rows loses precision.
    mean = dense.mean(axis=0, dtype=np.float64).astype(np.float32)
    # Center in place; `dense` is not used again, so no second dense copy is needed.
    dense -= mean
    pca = PCA(n_components=n_pca, random_state=seed)
    emb = pca.fit_transform(dense).astype(np.float32)
    components = pca.components_.astype(np.float32)  # (d, h)
    evr = float(pca.explained_variance_ratio_.sum())
    return emb, components, mean, evr


def _pseudobulk(Xhvg, obs: pd.DataFrame):
    """Mean HVG vector per non-control perturbation, plus the control mean.

    ``obs`` must carry a positional (0..n-1) index aligned with the rows of
    ``Xhvg``; group indices are used directly as row positions.
    """
    control_rows = obs.is_control.values
    control_mean = np.asarray(Xhvg[control_rows].mean(axis=0)).ravel().astype(np.float32)

    labels, vecs = [], []
    # Group only non-control cells so the control definition matches `is_control`
    # (labels made only of control/empty tokens are never emitted as perturbations).
    for label, rows in obs.loc[~obs.is_control].groupby("perturbation"):
        labels.append(label)
        vecs.append(np.asarray(Xhvg[rows.index.values].mean(axis=0)).ravel())
    # The reshape keeps a well-formed (0, n_hvg) array when no perturbation remains.
    vecs = np.asarray(vecs, dtype=np.float32).reshape(len(labels), Xhvg.shape[1])
    return labels, vecs, control_mean


def preprocess(
    raw_path: str,
    out_dir: str,
    name: str,
    operation: str,
    control_label: str = "control",
    pert_col: str = "perturbation",
    batch_col: str | None = "gemgroup",
    celltype_col: str | None = "celltype",
    sep: str = "_",
    n_hvg: int = 2000,
    n_pca: int = 50,
    min_cells_per_pert: int = 20,
    max_cells: int | None = None,
    seed: int = 0,
) -> dict:
    """Preprocess one dataset and write its cache files to ``out_dir``.

    Args:
        raw_path: Path to the raw .h5ad file.
        out_dir: Output directory (created if missing).
        name: Dataset name; used in log lines, in meta.json, and as the
            celltype value when no celltype column is available.
        operation: CRISPR modality, e.g. 'activation' (CRISPRa, Norman) or
            'interference' (CRISPRi, Replogle). Stored in meta.json only.
        control_label: Perturbation label of control cells.
        pert_col: obs column holding the perturbation label (required).
        batch_col: obs column holding the batch; falls back to a single batch
            "0" when None or missing.
        celltype_col: obs column holding the cell type; falls back to ``name``
            when None or missing.
        sep: Separator between genes in combinatorial labels.
        n_hvg: Number of highly variable genes to keep.
        n_pca: Number of PCA components.
        min_cells_per_pert: Non-control perturbations with fewer cells are dropped.
        max_cells: If set and exceeded, cells are subsampled without
            replacement using ``seed``.
        seed: Seed for ``set_seed``, the subsample, and the PCA.

    Returns:
        The meta dictionary that is also written to meta.json.

    Raises:
        KeyError: If ``pert_col`` is not a column of obs.
        ValueError: If no control cells are present (after subsampling).
    """
    # Function-scope import: scanpy is only needed once preprocessing runs.
    import scanpy as sc

    set_seed(seed)
    os.makedirs(out_dir, exist_ok=True)

    print(f"[{name}] reading {raw_path}")
    adata = sc.read_h5ad(raw_path)

    # optional subsample for tractability (deterministic)
    if max_cells is not None and adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False))
        adata = adata[idx].copy()
        print(f"[{name}] subsampled to {adata.n_obs} cells")

    # resolve columns with graceful fallback
    if pert_col not in adata.obs:
        raise KeyError(f"pert_col '{pert_col}' not in obs: {list(adata.obs.columns)}")
    if batch_col is not None and batch_col not in adata.obs:
        print(f"[{name}] batch_col '{batch_col}' missing -> single batch")
        batch_col = None
    if celltype_col is not None and celltype_col not in adata.obs:
        celltype_col = None

    # --- parse perturbations ---
    # Done before the expensive steps so a dataset without controls fails fast;
    # the later steps never reorder or drop cells before the low-count filter.
    obs = _build_obs(
        adata.obs, name, control_label, pert_col, batch_col, celltype_col, sep
    )
    if not obs.is_control.any():
        raise ValueError(
            f"[{name}] no control cells (label '{control_label}') found in "
            f"obs['{pert_col}']; cannot compute the control mean"
        )

    # --- normalization: CP10k + log1p ---
    if not sp.issparse(adata.X):
        adata.X = sp.csr_matrix(adata.X)
    # guard: ensure raw-ish counts (integer). if already normalized, skip.
    # Heuristic on the first 200 cells only: integer-valued with a max above 30.
    x0 = adata.X[:200].toarray()
    looks_counts = np.allclose(x0, np.round(x0)) and x0.max() > 30
    if looks_counts:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        print(f"[{name}] applied CP10k + log1p")
    else:
        print(f"[{name}] data appears pre-normalized (max={x0.max():.2f}); skipping norm")

    # --- HVG selection (on all cells) ---
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()
    genes = list(map(str, adata.var_names))
    Xhvg = adata.X.tocsr().astype(np.float32)
    print(f"[{name}] HVG matrix: {Xhvg.shape}")

    # --- PCA (centered, invertible); fit on all cells, before the low-count filter ---
    emb, components, mean, evr = _fit_centered_pca(Xhvg, n_pca, seed)
    print(f"[{name}] PCA d={n_pca} explains {evr:.3f} variance")

    # drop perturbations (non-control) with too few cells
    counts = obs.loc[~obs.is_control, "perturbation"].value_counts()
    keep_perts = set(counts[counts >= min_cells_per_pert].index)
    keep_mask = obs.is_control.values | obs.perturbation.isin(keep_perts).values
    if keep_mask.sum() < len(obs):
        # reset_index keeps obs positional, which _pseudobulk relies on.
        obs = obs.loc[keep_mask].reset_index(drop=True)
        Xhvg = Xhvg[keep_mask]
        emb = emb[keep_mask]
    print(
        f"[{name}] dropped low-count perts -> {keep_mask.sum()} cells, "
        f"{len(keep_perts)} perturbations"
    )

    # --- pseudobulk per perturbation (gene space, HVG-log) ---
    pb_labels, pb_vecs, control_mean = _pseudobulk(Xhvg, obs)

    # --- write cache ---
    sp.save_npz(os.path.join(out_dir, "Xhvg.npz"), Xhvg)
    np.save(os.path.join(out_dir, "pca_emb.npy"), emb)
    np.save(os.path.join(out_dir, "pca_components.npy"), components)
    np.save(os.path.join(out_dir, "pca_mean.npy"), mean)
    np.savez(
        os.path.join(out_dir, "pseudobulk.npz"),
        labels=np.array(pb_labels),
        vecs=pb_vecs,
        control_mean=control_mean,
    )
    obs.to_parquet(os.path.join(out_dir, "obs.parquet"))
    with open(os.path.join(out_dir, "genes_hvg.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(genes))

    # Parse each kept label once for the summary statistics.
    parsed = [parse_perturbation(p, control_label, sep) for p in pb_labels]
    n_singles = sum(1 for g in parsed if len(g) == 1)
    n_combos = sum(1 for g in parsed if len(g) >= 2)
    all_genes = {gene for g in parsed for gene in g}

    meta = {
        "name": name,
        "operation": operation,
        "control_label": control_label,
        "sep": sep,
        "n_cells": int(obs.shape[0]),
        "n_control": int(obs.is_control.sum()),
        "n_hvg": len(genes),
        "n_pca": n_pca,
        "pca_explained_var": evr,
        "n_perturbations": len(pb_labels),
        "n_singles": n_singles,
        "n_combos": n_combos,
        "n_unique_target_genes": len(all_genes),
        "n_batches": int(obs.batch.nunique()),
        "n_celltypes": int(obs.celltype.nunique()),
    }
    save_json(meta, os.path.join(out_dir, "meta.json"))
    print(f"[{name}] DONE -> {out_dir}")
    print(meta)
    return meta


# Per-dataset keyword arguments for preprocess(); "raw" is the input path and is
# popped before the remaining keys are forwarded.
DATASETS = {
    "norman": dict(
        raw="data/raw/NormanWeissman2019_filtered.h5ad",
        name="norman",
        operation="activation",
        control_label="control",
        batch_col="gemgroup",
        celltype_col="celltype",
    ),
    "replogle_k562": dict(
        raw="data/raw/ReplogleWeissman2022_K562_essential.h5ad",
        name="replogle_k562",
        operation="interference",
        control_label="control",
        batch_col="gemgroup",
        celltype_col="celltype",
        max_cells=120000,
    ),
    "replogle_rpe1": dict(
        raw="data/raw/ReplogleWeissman2022_rpe1.h5ad",
        name="replogle_rpe1",
        operation="interference",
        control_label="control",
        batch_col="gemgroup",
        celltype_col="celltype",
        max_cells=120000,
    ),
}


def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: preprocess one of the registered DATASETS."""
    ap = argparse.ArgumentParser()
    ap.add_argument("dataset", choices=list(DATASETS.keys()))
    ap.add_argument("--out-root", default="data/processed")
    ap.add_argument("--n-hvg", type=int, default=2000)
    ap.add_argument("--n-pca", type=int, default=50)
    args = ap.parse_args(argv)

    cfg = dict(DATASETS[args.dataset])  # copy so the registry entry is not mutated
    raw = cfg.pop("raw")
    out = os.path.join(args.out_root, cfg["name"])
    preprocess(raw, out, n_hvg=args.n_hvg, n_pca=args.n_pca, **cfg)


if __name__ == "__main__":
    main()