"""preprocess a raw perturb-seq .h5ad (scperturb format) into a compact,
reusable cache.

does normalization, hvg selection, and the pca cell-state encoder (the default,
approximately invertible phi used for gene-space metrics).

outputs, under data/processed/<name>/:
    meta.json            dataset summary (counts, operation, control label, ...)
    genes_hvg.txt        hvg gene symbols (defines gene-space for de metrics)
    Xhvg.npz             scipy csr, log1p(cp10k) on hvgs (n_cells x n_hvg)
    obs.parquet          per-cell metadata + parsed perturbation
    pca_emb.npy          pca cell-state embedding (n_cells x d)
    pca_components.npy   pca basis (d x n_hvg)
    pca_mean.npy         feature means (n_hvg,)
    pseudobulk.npz       per-perturbation mean hvg-log vectors + control mean

the pca is fit on log1p(cp10k) hvg features (zero-centered, unscaled) so that
x_hat = emb @ components + mean reconstructs gene-space expression (up to the
variance discarded by keeping only d components) - this is what lets us
evaluate de-gene correlation on predictions made in embedding space.

perturbations with too few cells are dropped from the cached matrices; the pca
basis and feature means are computed before that filter is applied.
"""
| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import sys |
|
|
| import numpy as np |
| import pandas as pd |
| import scipy.sparse as sp |
|
|
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) |
| from src.utils.common import save_json, set_seed |
|
|
|
|
def parse_perturbation(label: str, control_label: str, sep: str = "_"):
    """split a perturbation label into the list of genes it targets.

    the label is coerced to str first. a label equal to control_label yields
    []. otherwise the label is split on sep, and both empty tokens and tokens
    equal to control_label are discarded, so "A_control" parses to ["A"].
    """
    text = str(label)
    if text == control_label:
        return []

    targets = []
    for token in text.split(sep):
        # skip blanks produced by doubled separators and control placeholders
        if not token or token == control_label:
            continue
        targets.append(token)
    return targets
|
|
|
|
def preprocess(
    raw_path: str,
    out_dir: str,
    name: str,
    operation: str,
    control_label: str = "control",
    pert_col: str = "perturbation",
    batch_col: str | None = "gemgroup",
    celltype_col: str | None = "celltype",
    sep: str = "_",
    n_hvg: int = 2000,
    n_pca: int = 50,
    min_cells_per_pert: int = 20,
    max_cells: int | None = None,
    seed: int = 0,
):
    """preprocess one dataset and write its cache files into out_dir.

    operation is the crispr modality, e.g. 'activation' (crispra, norman) or
    'interference' (crispri, replogle); it is only recorded in meta.json.

    steps, in order:
      1. read raw_path with scanpy and optionally subsample to max_cells cells
      2. parse perturbation labels (a cell is control when its label parses to
         no target genes)
      3. cp10k + log1p, unless the matrix already looks normalized
      4. select n_hvg highly variable genes
      5. fit an n_pca-component pca on the mean-centered hvg matrix
      6. drop non-control perturbations with fewer than min_cells_per_pert cells
      7. compute the control mean and per-perturbation pseudobulk means
      8. write Xhvg.npz, pca_*.npy, pseudobulk.npz, obs.parquet, genes_hvg.txt
         and meta.json

    args:
        raw_path: path of the input .h5ad file.
        out_dir: output directory (created if missing).
        name: dataset name; used in log lines, in meta.json, and as the
            celltype value when celltype_col is unavailable.
        operation: crispr modality string stored in meta.json.
        control_label: label (or label token) that marks control cells.
        pert_col: obs column holding the perturbation label.
        batch_col: obs column holding the batch; None or a missing column
            puts every cell in batch "0".
        celltype_col: obs column holding the cell type; None or a missing
            column uses name for every cell.
        sep: separator between genes in a combinatorial label.
        n_hvg: number of highly variable genes to keep.
        n_pca: number of pca components.
        min_cells_per_pert: minimum cell count for a perturbation to be kept.
        max_cells: if set and exceeded, subsample to this many cells.
        seed: seed for set_seed, the subsampling rng, and the pca.

    returns:
        the meta dict that is also written to meta.json.

    raises:
        KeyError: if pert_col is not a column of adata.obs.
        ValueError: if no cell is recognised as a control.
    """
    import scanpy as sc

    # seeding is delegated to the project helper (see src.utils.common)
    set_seed(seed)
    os.makedirs(out_dir, exist_ok=True)
    print(f"[{name}] reading {raw_path}")
    adata = sc.read_h5ad(raw_path)

    # optional uniform subsample; indices are sorted to keep the original cell order
    if max_cells is not None and adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False))
        adata = adata[idx].copy()
        print(f"[{name}] subsampled to {adata.n_obs} cells")

    # column checks: the perturbation column is mandatory, the other two fall back
    if pert_col not in adata.obs:
        raise KeyError(f"pert_col '{pert_col}' not in obs: {list(adata.obs.columns)}")
    if batch_col is not None and batch_col not in adata.obs:
        print(f"[{name}] batch_col '{batch_col}' missing -> single batch")
        batch_col = None
    if celltype_col is not None and celltype_col not in adata.obs:
        print(f"[{name}] celltype_col '{celltype_col}' missing -> using '{name}' as celltype")
        celltype_col = None

    # parse labels up front so a dataset without controls fails before the
    # expensive normalization / hvg / pca steps (none of which touch obs rows)
    pert = adata.obs[pert_col].astype(str).values
    genes_per_cell = [parse_perturbation(p, control_label, sep) for p in pert]
    nperts = np.array([len(g) for g in genes_per_cell])
    is_control = nperts == 0
    if not is_control.any():
        raise ValueError(
            f"[{name}] no control cells found: no label in '{pert_col}' parses as "
            f"control (control_label='{control_label}', sep='{sep}')"
        )

    # normalization: work on a sparse matrix throughout
    if not sp.issparse(adata.X):
        adata.X = sp.csr_matrix(adata.X)

    # heuristic on the first 200 cells: all-integer values with a max above 30
    # are treated as raw counts, anything else as already normalized
    x0 = adata.X[:200].toarray()
    looks_counts = np.allclose(x0, np.round(x0)) and x0.max() > 30
    if looks_counts:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        print(f"[{name}] applied CP10k + log1p")
    else:
        print(f"[{name}] data appears pre-normalized (max={x0.max():.2f}); skipping norm")

    # hvg selection; the retained var_names define the gene space of every output
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()
    genes = list(map(str, adata.var_names))
    Xhvg = adata.X.tocsr().astype(np.float32)
    print(f"[{name}] HVG matrix: {Xhvg.shape}")

    # pca encoder: center explicitly so the saved mean pairs with the saved
    # components (x_hat = emb @ components + mean)
    from sklearn.decomposition import PCA

    Xd = Xhvg.toarray()
    mean = Xd.mean(axis=0).astype(np.float32)
    pca = PCA(n_components=n_pca, random_state=seed)
    emb = pca.fit_transform(Xd - mean).astype(np.float32)
    components = pca.components_.astype(np.float32)
    evr = float(pca.explained_variance_ratio_.sum())
    print(f"[{name}] PCA d={n_pca} explains {evr:.3f} variance")
    del Xd  # the dense copy is no longer needed; the sparse Xhvg is used below

    # per-cell metadata; the default RangeIndex makes obs index == row position
    obs = pd.DataFrame(
        {
            "perturbation": pert,
            "n_pert_genes": nperts,
            "is_control": is_control,
            "pert_genes": [";".join(g) for g in genes_per_cell],
            "batch": (adata.obs[batch_col].astype(str).values if batch_col else "0"),
            "celltype": (
                adata.obs[celltype_col].astype(str).values if celltype_col else name
            ),
        }
    )

    # drop rare perturbations; controls are always kept. the index is reset so
    # that obs index values stay valid row positions into Xhvg / emb
    vc = obs.loc[~obs.is_control, "perturbation"].value_counts()
    keep_perts = set(vc[vc >= min_cells_per_pert].index)
    keep_mask = obs.is_control.values | obs.perturbation.isin(keep_perts).values
    if keep_mask.sum() < len(obs):
        obs = obs.loc[keep_mask].reset_index(drop=True)
        Xhvg = Xhvg[keep_mask]
        emb = emb[keep_mask]
        print(f"[{name}] dropped low-count perts -> {keep_mask.sum()} cells, "
              f"{len(keep_perts)} perturbations")

    # pseudobulk means. groups are taken from non-control rows only, so the
    # control definition here is the same is_control flag used for
    # control_mean and for the filtering above
    ctrl_mask = obs.is_control.values
    control_mean = np.asarray(Xhvg[ctrl_mask].mean(axis=0)).ravel().astype(np.float32)
    pb_labels, pb_vecs = [], []
    for p, sub in obs.loc[~ctrl_mask].groupby("perturbation"):
        pb_labels.append(p)
        pb_vecs.append(np.asarray(Xhvg[sub.index.values].mean(axis=0)).ravel())
    # reshape keeps a 2-d (n_perts, n_hvg) array even when no perturbation survived
    pb_vecs = np.asarray(pb_vecs, dtype=np.float32).reshape(len(pb_labels), Xhvg.shape[1])

    # write the cache
    sp.save_npz(os.path.join(out_dir, "Xhvg.npz"), Xhvg)
    np.save(os.path.join(out_dir, "pca_emb.npy"), emb)
    np.save(os.path.join(out_dir, "pca_components.npy"), components)
    np.save(os.path.join(out_dir, "pca_mean.npy"), mean)
    np.savez(
        os.path.join(out_dir, "pseudobulk.npz"),
        labels=np.array(pb_labels),
        vecs=pb_vecs,
        control_mean=control_mean,
    )
    obs.to_parquet(os.path.join(out_dir, "obs.parquet"))
    with open(os.path.join(out_dir, "genes_hvg.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(genes))

    # summary; each surviving label is parsed once (labels are unique groupby keys)
    parsed = {p: parse_perturbation(p, control_label, sep) for p in pb_labels}
    singles = [p for p, g in parsed.items() if len(g) == 1]
    combos = [p for p, g in parsed.items() if len(g) >= 2]
    all_genes = sorted({g for gs in parsed.values() for g in gs})
    meta = {
        "name": name,
        "operation": operation,
        "control_label": control_label,
        "sep": sep,
        "n_cells": int(obs.shape[0]),
        "n_control": int(obs.is_control.sum()),
        "n_hvg": len(genes),
        "n_pca": n_pca,
        "pca_explained_var": evr,
        "n_perturbations": len(pb_labels),
        "n_singles": len(singles),
        "n_combos": len(combos),
        "n_unique_target_genes": len(all_genes),
        "n_batches": int(obs.batch.nunique()),
        "n_celltypes": int(obs.celltype.nunique()),
    }
    save_json(meta, os.path.join(out_dir, "meta.json"))
    print(f"[{name}] DONE -> {out_dir}")
    print(meta)
    return meta
|
|
|
|
# registry of known datasets, keyed by the name accepted on the command line.
# "raw" is the input .h5ad path; the command-line entry point pops it and
# forwards the remaining entries to preprocess() as keyword arguments.
DATASETS = {
    "norman": {
        "raw": "data/raw/NormanWeissman2019_filtered.h5ad",
        "name": "norman",
        "operation": "activation",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
    },
    "replogle_k562": {
        "raw": "data/raw/ReplogleWeissman2022_K562_essential.h5ad",
        "name": "replogle_k562",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
    "replogle_rpe1": {
        "raw": "data/raw/ReplogleWeissman2022_rpe1.h5ad",
        "name": "replogle_rpe1",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
}
|
|
|
|
if __name__ == "__main__":
    # cli: choose a registered dataset; output root and hvg / pca sizes are optional
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", choices=list(DATASETS.keys()))
    parser.add_argument("--out-root", default="data/processed")
    parser.add_argument("--n-hvg", type=int, default=2000)
    parser.add_argument("--n-pca", type=int, default=50)
    cli = parser.parse_args()

    # shallow copy so that popping "raw" leaves the registry entry untouched
    settings = {**DATASETS[cli.dataset]}
    raw_path = settings.pop("raw")
    out_dir = os.path.join(cli.out_root, settings["name"])
    preprocess(raw_path, out_dir, n_hvg=cli.n_hvg, n_pca=cli.n_pca, **settings)
|
|