File size: 8,904 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""preprocess a raw perturb-seq .h5ad (scperturb format) into a compact,
reusable cache.

does normalization, hvg selection, and the pca cell-state encoder (the default
invertible phi used for gene-space metrics).

outputs, under data/processed/<name>/:
  meta.json            dataset summary (counts, operation, control label, ...)
  genes_hvg.txt        hvg gene symbols (defines gene-space for de metrics)
  Xhvg.npz             scipy csr, log1p(cp10k) on hvgs  (n_cells x n_hvg)
  obs.parquet          per-cell metadata + parsed perturbation
  pca_emb.npy          pca cell-state embedding          (n_cells x d)
  pca_components.npy   pca basis                          (d x n_hvg)
  pca_mean.npy         feature means                      (n_hvg,)
  pseudobulk.npz       per-perturbation mean hvg-log vectors + control mean

the pca is fit on log1p(cp10k) hvg features (zero-centered, unscaled) so that
x_hat = emb @ components + mean gives a rank-d approximate reconstruction of
gene-space expression - this is what lets us evaluate de-gene correlation on
predictions made in embedding space.
"""
from __future__ import annotations

import argparse
import os
import sys

import numpy as np
import pandas as pd
import scipy.sparse as sp

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from src.utils.common import save_json, set_seed  # noqa: E402


def parse_perturbation(label: str, control_label: str, sep: str = "_"):
    """split a perturbation label into its list of target genes.

    the label is coerced to str first. an exact match with control_label means
    a control cell and yields []. otherwise the label is split on sep, and empty
    tokens plus tokens equal to control_label (e.g. the control half of
    'GENE_control') are discarded.
    """
    text = str(label)
    if text == control_label:
        return []
    tokens = text.split(sep)
    # a token is kept only if it is non-empty and not the control placeholder
    return [tok for tok in tokens if tok not in ("", control_label)]


def preprocess(
    raw_path: str,
    out_dir: str,
    name: str,
    operation: str,
    control_label: str = "control",
    pert_col: str = "perturbation",
    batch_col: str | None = "gemgroup",
    celltype_col: str | None = "celltype",
    sep: str = "_",
    n_hvg: int = 2000,
    n_pca: int = 50,
    min_cells_per_pert: int = 20,
    max_cells: int | None = None,
    seed: int = 0,
):
    """preprocess one dataset into the on-disk cache under out_dir.

    operation is the crispr modality, e.g. 'activation' (crispra, norman) or
    'interference' (crispri, replogle); it is only recorded in meta.json.

    pipeline: optional deterministic subsample -> cp10k + log1p (skipped when
    the matrix does not look like raw counts) -> seurat-flavor hvg selection ->
    centered pca on the hvg features -> perturbation parsing and low-count
    filtering -> per-perturbation pseudobulk -> write cache + meta.json.

    a cell is a control iff its label parses to no target genes (see
    parse_perturbation); batch_col / celltype_col fall back to a single constant
    value when missing from obs.

    returns the meta dict that is also written to meta.json.

    raises KeyError if pert_col is not an obs column, and ValueError if no
    control cells are present (the control mean would be undefined).
    """
    import scanpy as sc

    set_seed(seed)
    os.makedirs(out_dir, exist_ok=True)
    print(f"[{name}] reading {raw_path}")
    adata = sc.read_h5ad(raw_path)

    # optional subsample for tractability (deterministic given seed)
    if max_cells is not None and adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False))
        adata = adata[idx].copy()
        print(f"[{name}] subsampled to {adata.n_obs} cells")

    # resolve columns with graceful fallback
    if pert_col not in adata.obs:
        raise KeyError(f"pert_col '{pert_col}' not in obs: {list(adata.obs.columns)}")
    if batch_col is not None and batch_col not in adata.obs:
        print(f"[{name}] batch_col '{batch_col}' missing -> single batch")
        batch_col = None
    if celltype_col is not None and celltype_col not in adata.obs:
        celltype_col = None

    # --- normalization: cp10k + log1p ---
    if not sp.issparse(adata.X):
        adata.X = sp.csr_matrix(adata.X)
    # heuristic on the first <=200 cells: (approximately) integer values with a
    # max above 30 are taken as raw counts; otherwise assume already normalized.
    x0 = adata.X[:200].toarray()
    looks_counts = np.allclose(x0, np.round(x0)) and x0.max() > 30
    if looks_counts:
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        print(f"[{name}] applied CP10k + log1p")
    else:
        print(f"[{name}] data appears pre-normalized (max={x0.max():.2f}); skipping norm")

    # --- hvg selection (on all cells) ---
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()
    genes = list(map(str, adata.var_names))
    Xhvg = adata.X.tocsr().astype(np.float32)
    print(f"[{name}] HVG matrix: {Xhvg.shape}")

    # --- pca (centered, invertible) ---
    from sklearn.decomposition import PCA

    Xd = Xhvg.toarray()
    mean = Xd.mean(axis=0).astype(np.float32)
    # center in place: same values as `Xd - mean` without a second dense copy
    Xd -= mean
    pca = PCA(n_components=n_pca, random_state=seed)
    emb = pca.fit_transform(Xd).astype(np.float32)
    del Xd  # release the dense matrix before the remaining steps
    components = pca.components_.astype(np.float32)  # (d, h)
    evr = float(pca.explained_variance_ratio_.sum())
    print(f"[{name}] PCA d={n_pca} explains {evr:.3f} variance")

    # --- parse perturbations ---
    pert = adata.obs[pert_col].astype(str).values
    genes_per_cell = [parse_perturbation(p, control_label, sep) for p in pert]
    nperts = np.array([len(g) for g in genes_per_cell])
    is_control = nperts == 0

    obs = pd.DataFrame(
        {
            "perturbation": pert,
            "n_pert_genes": nperts,
            "is_control": is_control,
            "pert_genes": [";".join(g) for g in genes_per_cell],
            "batch": (adata.obs[batch_col].astype(str).values if batch_col else "0"),
            "celltype": (
                adata.obs[celltype_col].astype(str).values if celltype_col else name
            ),
        }
    )

    # drop perturbations (non-control) with too few cells
    vc = obs.loc[~obs.is_control, "perturbation"].value_counts()
    keep_perts = set(vc[vc >= min_cells_per_pert].index)
    keep_mask = obs.is_control.values | obs.perturbation.isin(keep_perts).values
    if keep_mask.sum() < len(obs):
        # reset_index keeps obs row positions aligned with Xhvg / emb rows
        obs = obs.loc[keep_mask].reset_index(drop=True)
        Xhvg = Xhvg[keep_mask]
        emb = emb[keep_mask]
        print(f"[{name}] dropped low-count perts -> {keep_mask.sum()} cells, "
              f"{len(keep_perts)} perturbations")

    # --- pseudobulk per perturbation (gene space, hvg-log) ---
    ctrl_mask = obs.is_control.values
    if not ctrl_mask.any():
        raise ValueError(
            f"[{name}] no control cells found (control_label={control_label!r}, "
            f"pert_col={pert_col!r}); cannot compute control mean"
        )
    control_mean = np.asarray(Xhvg[ctrl_mask].mean(axis=0)).ravel().astype(np.float32)
    pb_labels, pb_rows = [], []
    # group only non-control cells: control-ness is defined by parsing, so labels
    # that parse to no target genes are excluded, not just an exact control_label.
    for p, sub in obs.loc[~ctrl_mask].groupby("perturbation"):
        pb_labels.append(p)
        # obs has a 0..n-1 RangeIndex, so index values are Xhvg row positions
        pb_rows.append(np.asarray(Xhvg[sub.index.values].mean(axis=0)).ravel())
    if pb_rows:
        pb_vecs = np.asarray(pb_rows, dtype=np.float32)
    else:
        # keep the (n_perts, n_hvg) shape even when nothing survives filtering
        pb_vecs = np.zeros((0, Xhvg.shape[1]), dtype=np.float32)

    # --- write cache ---
    sp.save_npz(os.path.join(out_dir, "Xhvg.npz"), Xhvg)
    np.save(os.path.join(out_dir, "pca_emb.npy"), emb)
    np.save(os.path.join(out_dir, "pca_components.npy"), components)
    np.save(os.path.join(out_dir, "pca_mean.npy"), mean)
    np.savez(
        os.path.join(out_dir, "pseudobulk.npz"),
        labels=np.array(pb_labels),
        vecs=pb_vecs,
        control_mean=control_mean,
    )
    obs.to_parquet(os.path.join(out_dir, "obs.parquet"))
    with open(os.path.join(out_dir, "genes_hvg.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(genes))

    singles = [p for p in pb_labels if len(parse_perturbation(p, control_label, sep)) == 1]
    combos = [p for p in pb_labels if len(parse_perturbation(p, control_label, sep)) >= 2]
    all_genes = sorted({g for p in pb_labels for g in parse_perturbation(p, control_label, sep)})
    meta = {
        "name": name,
        "operation": operation,
        "control_label": control_label,
        "sep": sep,
        "n_cells": int(obs.shape[0]),
        "n_control": int(obs.is_control.sum()),
        "n_hvg": len(genes),
        "n_pca": n_pca,
        "pca_explained_var": evr,
        "n_perturbations": len(pb_labels),
        "n_singles": len(singles),
        "n_combos": len(combos),
        "n_unique_target_genes": len(all_genes),
        "n_batches": int(obs.batch.nunique()),
        "n_celltypes": int(obs.celltype.nunique()),
    }
    save_json(meta, os.path.join(out_dir, "meta.json"))
    print(f"[{name}] DONE -> {out_dir}")
    print(meta)
    return meta


# preset per-dataset arguments for preprocess(); "raw" is the input .h5ad path
# and is popped off by the CLI before the remaining keys are forwarded.
DATASETS = {
    # crispra combinatorial screen, processed in full
    "norman": {
        "raw": "data/raw/NormanWeissman2019_filtered.h5ad",
        "name": "norman",
        "operation": "activation",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
    },
    # crispri screens, subsampled to 120k cells
    "replogle_k562": {
        "raw": "data/raw/ReplogleWeissman2022_K562_essential.h5ad",
        "name": "replogle_k562",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
    "replogle_rpe1": {
        "raw": "data/raw/ReplogleWeissman2022_rpe1.h5ad",
        "name": "replogle_rpe1",
        "operation": "interference",
        "control_label": "control",
        "batch_col": "gemgroup",
        "celltype_col": "celltype",
        "max_cells": 120000,
    },
}


if __name__ == "__main__":
    # cli: preprocess one preset dataset from DATASETS into <out-root>/<name>/.
    # every option defaults to preprocess()'s own default, so a plain
    # `<script> norman` behaves exactly as before.
    ap = argparse.ArgumentParser(
        description="preprocess a raw scperturb .h5ad into <out-root>/<name>/"
    )
    ap.add_argument("dataset", choices=list(DATASETS.keys()))
    ap.add_argument("--out-root", default="data/processed")
    ap.add_argument("--n-hvg", type=int, default=2000)
    ap.add_argument("--n-pca", type=int, default=50)
    ap.add_argument("--min-cells-per-pert", type=int, default=20)
    ap.add_argument(
        "--max-cells", type=int, default=None,
        help="override the preset subsample size (default: use the preset)",
    )
    ap.add_argument("--seed", type=int, default=0)
    args = ap.parse_args()
    cfg = dict(DATASETS[args.dataset])  # copy so the preset table is not mutated
    raw = cfg.pop("raw")
    # only override max_cells when given, so presets without it stay unsubsampled
    if args.max_cells is not None:
        cfg["max_cells"] = args.max_cells
    out = os.path.join(args.out_root, cfg["name"])
    preprocess(
        raw,
        out,
        n_hvg=args.n_hvg,
        n_pca=args.n_pca,
        min_cells_per_pert=args.min_cells_per_pert,
        seed=args.seed,
        **cfg,
    )