PIVOT / src /data /preprocess.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
8.9 kB
"""preprocess a raw perturb-seq .h5ad (scperturb format) into a compact,
reusable cache.

does normalization, hvg selection, and the pca cell-state encoder (the default
invertible phi used for gene-space metrics).

outputs, under data/processed/<name>/:
    meta.json            dataset summary (counts, operation, control label, ...)
    genes_hvg.txt        hvg gene symbols (defines gene-space for de metrics)
    Xhvg.npz             scipy csr, log1p(cp10k) on hvgs (n_cells x n_hvg)
    obs.parquet          per-cell metadata + parsed perturbation
    pca_emb.npy          pca cell-state embedding (n_cells x d)
    pca_components.npy   pca basis (d x n_hvg)
    pca_mean.npy         feature means (n_hvg,)
    pseudobulk.npz       per-perturbation mean hvg-log vectors + control mean

the pca is fit on log1p(cp10k) hvg features (zero-centered, unscaled) so that
x_hat = emb @ components + mean reconstructs gene-space expression - this is what
lets us evaluate de-gene correlation on predictions made in embedding space.
"""
from __future__ import annotations
import argparse
import os
import sys
import numpy as np
import pandas as pd
import scipy.sparse as sp
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from src.utils.common import save_json, set_seed # noqa: E402
def parse_perturbation(label: str, control_label: str, sep: str = "_"):
    """split a perturbation label into its perturbed gene names.

    the label is stringified, then split on `sep`. a label equal to
    `control_label` yields an empty list; otherwise empty tokens and tokens
    equal to `control_label` are discarded and the rest are returned in order.
    """
    text = str(label)
    if text == control_label:
        return []

    targets = []
    for token in text.split(sep):
        # skip empty pieces (e.g. doubled separators) and control placeholders
        if not token or token == control_label:
            continue
        targets.append(token)
    return targets
def preprocess(
    raw_path: str,
    out_dir: str,
    name: str,
    operation: str,
    control_label: str = "control",
    pert_col: str = "perturbation",
    batch_col: str | None = "gemgroup",
    celltype_col: str | None = "celltype",
    sep: str = "_",
    n_hvg: int = 2000,
    n_pca: int = 50,
    min_cells_per_pert: int = 20,
    max_cells: int | None = None,
    seed: int = 0,
):
    """preprocess one dataset and write its cache files into `out_dir`.

    operation is the crispr modality, e.g. 'activation' (crispra, norman) or
    'interference' (crispri, replogle); it is only recorded in meta.json.

    steps: read the .h5ad, optionally subsample to `max_cells`, apply
    cp10k + log1p when the matrix looks like raw counts, select `n_hvg` highly
    variable genes, fit a centered `n_pca`-component pca, parse perturbation
    labels, drop non-control perturbations with fewer than
    `min_cells_per_pert` cells, compute per-perturbation pseudobulk means, and
    write everything to disk.

    args:
        raw_path: path of the input .h5ad file.
        out_dir: output directory (created if missing).
        name: dataset name; used in log lines, in meta.json, and as the
            celltype value when no celltype column is available.
        operation: free-form modality string stored in meta.json.
        control_label: perturbation label that marks control cells.
        pert_col: obs column holding the perturbation label (required).
        batch_col: obs column for batch; if None or missing, every cell gets
            batch "0".
        celltype_col: obs column for cell type; if None or missing, every
            cell gets `name`.
        sep: separator between gene names inside a perturbation label.
        n_hvg: number of highly variable genes to keep.
        n_pca: number of pca components.
        min_cells_per_pert: minimum cell count for a non-control perturbation
            to be kept.
        max_cells: if set and the dataset is larger, subsample this many
            cells (seeded by `seed`).
        seed: seed for `set_seed`, the subsample and the pca.

    returns:
        the meta dict that is also written to meta.json.

    raises:
        KeyError: if `pert_col` is not a column of obs.
        ValueError: if no cell is recognised as a control, since the control
            mean would then be undefined.
    """
    import scanpy as sc

    set_seed(seed)
    os.makedirs(out_dir, exist_ok=True)
    print(f"[{name}] reading {raw_path}")
    adata = sc.read_h5ad(raw_path)

    # optional subsample for tractability (deterministic)
    if max_cells is not None and adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False))
        adata = adata[idx].copy()
        print(f"[{name}] subsampled to {adata.n_obs} cells")

    # resolve columns with graceful fallback
    if pert_col not in adata.obs:
        raise KeyError(f"pert_col '{pert_col}' not in obs: {list(adata.obs.columns)}")
    if batch_col is not None and batch_col not in adata.obs:
        print(f"[{name}] batch_col '{batch_col}' missing -> single batch")
        batch_col = None
    if celltype_col is not None and celltype_col not in adata.obs:
        celltype_col = None

    # --- parse perturbations ---
    # done before the expensive steps so that a mis-specified control label
    # fails immediately instead of after normalization / hvg / pca.
    pert = adata.obs[pert_col].astype(str).values
    genes_per_cell = [parse_perturbation(p, control_label, sep) for p in pert]
    nperts = np.array([len(g) for g in genes_per_cell])
    # a cell is a control when its label parses to no perturbed genes
    is_control = nperts == 0
    if not is_control.any():
        raise ValueError(
            f"[{name}] no control cells found in obs['{pert_col}'] "
            f"(control_label='{control_label}', sep='{sep}')"
        )

    # --- normalization: cp10k + log1p ---
    if not sp.issparse(adata.X):
        adata.X = sp.csr_matrix(adata.X)
    # guard: ensure raw-ish counts (integer). if already normalized, skip.
    # heuristic on the first 200 cells: all values integral and max above 30.
    x0 = adata.X[:200].toarray()
    looks_counts = np.allclose(x0, np.round(x0)) and x0.max() > 30
    if looks_counts:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        print(f"[{name}] applied CP10k + log1p")
    else:
        print(f"[{name}] data appears pre-normalized (max={x0.max():.2f}); skipping norm")

    # --- hvg selection (on all cells) ---
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()
    genes = list(map(str, adata.var_names))
    Xhvg = adata.X.tocsr().astype(np.float32)
    print(f"[{name}] HVG matrix: {Xhvg.shape}")

    # --- pca (centered, invertible) ---
    from sklearn.decomposition import PCA

    Xd = Xhvg.toarray()
    # the mean is saved separately so x_hat = emb @ components + mean
    mean = Xd.mean(axis=0).astype(np.float32)
    pca = PCA(n_components=n_pca, random_state=seed)
    emb = pca.fit_transform(Xd - mean).astype(np.float32)
    components = pca.components_.astype(np.float32)  # (d, h)
    evr = float(pca.explained_variance_ratio_.sum())
    print(f"[{name}] PCA d={n_pca} explains {evr:.3f} variance")

    # --- per-cell metadata ---
    obs = pd.DataFrame(
        {
            "perturbation": pert,
            "n_pert_genes": nperts,
            "is_control": is_control,
            "pert_genes": [";".join(g) for g in genes_per_cell],
            "batch": (adata.obs[batch_col].astype(str).values if batch_col else "0"),
            "celltype": (
                adata.obs[celltype_col].astype(str).values if celltype_col else name
            ),
        }
    )

    # drop perturbations (non-control) with too few cells
    vc = obs.loc[~obs.is_control, "perturbation"].value_counts()
    keep_perts = set(vc[vc >= min_cells_per_pert].index)
    keep_mask = obs.is_control.values | obs.perturbation.isin(keep_perts).values
    if keep_mask.sum() < len(obs):
        # reset the index so row labels stay positional for Xhvg / emb
        obs = obs.loc[keep_mask].reset_index(drop=True)
        Xhvg = Xhvg[keep_mask]
        emb = emb[keep_mask]
    print(f"[{name}] dropped low-count perts -> {keep_mask.sum()} cells, "
          f"{len(keep_perts)} perturbations")

    # --- pseudobulk per perturbation (gene space, hvg-log) ---
    control_rows = obs.is_control.values
    control_mean = np.asarray(Xhvg[control_rows].mean(axis=0)).ravel().astype(np.float32)
    pb_labels, pb_vecs = [], []
    # select by the is_control mask (not by label equality) so every cell that
    # feeds control_mean is excluded here, whatever its literal label is.
    for p, sub in obs.loc[~control_rows].groupby("perturbation"):
        rows = sub.index.values
        pb_labels.append(p)
        pb_vecs.append(np.asarray(Xhvg[rows].mean(axis=0)).ravel())
    if pb_vecs:
        pb_vecs = np.asarray(pb_vecs, dtype=np.float32)
    else:
        # keep a 2-d (0, n_hvg) array even when no perturbation survived
        pb_vecs = np.zeros((0, Xhvg.shape[1]), dtype=np.float32)

    # --- write cache ---
    sp.save_npz(os.path.join(out_dir, "Xhvg.npz"), Xhvg)
    np.save(os.path.join(out_dir, "pca_emb.npy"), emb)
    np.save(os.path.join(out_dir, "pca_components.npy"), components)
    np.save(os.path.join(out_dir, "pca_mean.npy"), mean)
    np.savez(
        os.path.join(out_dir, "pseudobulk.npz"),
        labels=np.array(pb_labels),
        vecs=pb_vecs,
        control_mean=control_mean,
    )
    obs.to_parquet(os.path.join(out_dir, "obs.parquet"))
    with open(os.path.join(out_dir, "genes_hvg.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(genes))

    singles = [p for p in pb_labels if len(parse_perturbation(p, control_label, sep)) == 1]
    combos = [p for p in pb_labels if len(parse_perturbation(p, control_label, sep)) >= 2]
    all_genes = sorted({g for p in pb_labels for g in parse_perturbation(p, control_label, sep)})
    meta = {
        "name": name,
        "operation": operation,
        "control_label": control_label,
        "sep": sep,
        "n_cells": int(obs.shape[0]),
        "n_control": int(obs.is_control.sum()),
        "n_hvg": len(genes),
        "n_pca": n_pca,
        "pca_explained_var": evr,
        "n_perturbations": len(pb_labels),
        "n_singles": len(singles),
        "n_combos": len(combos),
        "n_unique_target_genes": len(all_genes),
        "n_batches": int(obs.batch.nunique()),
        "n_celltypes": int(obs.celltype.nunique()),
    }
    save_json(meta, os.path.join(out_dir, "meta.json"))
    print(f"[{name}] DONE -> {out_dir}")
    print(meta)
    return meta
# registry of known datasets: cli key -> keyword arguments for preprocess(),
# plus "raw", the path of the input .h5ad file.
DATASETS = {
    "norman": {
        "raw": "data/raw/NormanWeissman2019_filtered.h5ad",
        "name": "norman",
        "operation": "activation",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
    },
    "replogle_k562": {
        "raw": "data/raw/ReplogleWeissman2022_K562_essential.h5ad",
        "name": "replogle_k562",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
    "replogle_rpe1": {
        "raw": "data/raw/ReplogleWeissman2022_rpe1.h5ad",
        "name": "replogle_rpe1",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
}
if __name__ == "__main__":
    # command-line entry point: preprocess one of the registered datasets
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", choices=list(DATASETS))
    parser.add_argument("--out-root", default="data/processed")
    parser.add_argument("--n-hvg", type=int, default=2000)
    parser.add_argument("--n-pca", type=int, default=50)
    cli = parser.parse_args()

    # work on a copy so that popping "raw" never mutates the DATASETS registry
    settings = {**DATASETS[cli.dataset]}
    raw_file = settings.pop("raw")
    target_dir = os.path.join(cli.out_root, settings["name"])
    preprocess(raw_file, target_dir, n_hvg=cli.n_hvg, n_pca=cli.n_pca, **settings)