File size: 3,442 Bytes
3b4941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""alternative cell-state encoders phi (Table 16).

pca is the default (built in preprocess.py). this adds scvi, a standard deep
single-cell latent space. to keep cell ordering identical to the cached pca
embedding, we reproduce the exact filtered cell set from preprocess.py (same
seed, hvgs, min-cells filter), train scvi on the raw counts, and save
emb_scvi.npy aligned 1:1.
"""
from __future__ import annotations

import argparse
import os

import numpy as np
import scipy.sparse as sp

from src.data.preprocess import DATASETS, parse_perturbation
from src.utils.common import load_json, set_seed


def load_filtered_adata(dataset: str, n_hvg: int = 2000, *, seed: int = 0, min_cells: int = 20):
    """reproduce preprocess.py's filtered anndata with raw counts kept in .X.

    the steps mirror preprocess.py so that the returned cells are in the same
    order as the cached pca embedding. changing ``n_hvg``, ``seed`` or
    ``min_cells`` away from the values preprocess.py used will break that
    1:1 alignment; the defaults are the values this module has always used.

    args:
        dataset: key into ``DATASETS``. its config supplies ``raw`` (the h5ad
            path), ``control_label`` and, optionally, ``max_cells``, ``sep``
            and ``pert_col``.
        n_hvg: number of highly variable genes to keep (seurat flavor).
        seed: seed passed to ``set_seed`` and used for the ``max_cells``
            subsample.
        min_cells: minimum number of cells a non-control perturbation needs in
            order to be kept; control cells are always kept.

    returns:
        an ``AnnData`` restricted to the hvgs and the kept cells. its ``.X``
        (and ``layers["counts"]``) holds the values read from ``raw`` before
        any normalization.
    """
    import pandas as pd
    import scanpy as sc

    cfg = dict(DATASETS[dataset])
    set_seed(seed)
    adata = sc.read_h5ad(cfg["raw"])

    # optional seeded subsample; indices are sorted so the original cell order is kept
    max_cells = cfg.get("max_cells")
    if max_cells is not None and adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False))
        adata = adata[idx].copy()

    # stash the raw values before normalize_total / log1p rewrite .X
    adata.layers["counts"] = adata.X.copy()
    # normalize only to pick the same hvgs as preprocess
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()

    # min-cells filter identical to preprocess (sep/pert_col use preprocess defaults)
    control_label = cfg["control_label"]
    sep = cfg.get("sep", "_")
    pert = adata.obs[cfg.get("pert_col", "perturbation")].astype(str).values
    # a label that parses to zero perturbed genes marks a control cell
    n_genes = np.array([len(parse_perturbation(p, control_label, sep)) for p in pert])
    is_ctrl = n_genes == 0
    counts = pd.Series(pert[~is_ctrl]).value_counts()
    keep = counts.index[counts >= min_cells]
    mask = is_ctrl | np.isin(pert, list(keep))
    adata = adata[mask].copy()

    # restore raw counts into X (the layer was subset alongside X above)
    adata.X = adata.layers["counts"]
    return adata


def build_scvi(dataset: str, n_latent: int = 50, max_epochs: int = 80, gpu: int = 0):
    """train scvi on the filtered raw counts and save its latent embedding.

    cells come from ``load_filtered_adata``. their number is checked against
    the row count of the cached ``pca_emb.npy`` so that ``emb_scvi.npy``
    lines up 1:1 with it.

    args:
        dataset: dataset key; inputs are read from and the output written to
            ``data/processed/<dataset>``.
        n_latent: scvi latent dimensionality.
        max_epochs: upper bound on training epochs (early stopping is enabled).
        gpu: cuda device index used for training when cuda is available;
            ignored on cpu-only machines.

    returns:
        the float32 latent matrix returned by ``get_latent_representation``,
        which is also saved to ``emb_scvi.npy``.

    raises:
        ValueError: if the filtered cell count differs from the cached pca
            embedding's row count.
    """
    import scvi
    import torch

    cache = os.path.join("data/processed", dataset)
    adata = load_filtered_adata(dataset)
    # only the row count is needed, so memory-map instead of reading the array
    n_cached = np.load(os.path.join(cache, "pca_emb.npy"), mmap_mode="r").shape[0]
    # explicit raise (not assert) so the guard survives `python -O`.
    # NOTE: this checks the cell count only, not the cell order.
    if adata.n_obs != n_cached:
        raise ValueError(f"cell count mismatch {adata.n_obs} vs cached {n_cached}")

    scvi.settings.seed = 0
    scvi.model.SCVI.setup_anndata(adata, layer=None)
    model = scvi.model.SCVI(adata, n_latent=n_latent)
    # pick the device through train(): the scvi trainer resolves its own device
    # from accelerator/devices (default "auto") and would override an earlier
    # model.to_device(gpu). NOTE(review): accelerator/devices need scvi-tools >= 0.20.
    if torch.cuda.is_available():
        device_kwargs = {"accelerator": "gpu", "devices": [gpu]}
    else:
        device_kwargs = {"accelerator": "cpu"}
    model.train(max_epochs=max_epochs, early_stopping=True,
                plan_kwargs={"lr": 1e-3}, enable_progress_bar=False,
                **device_kwargs)
    z = model.get_latent_representation().astype(np.float32)
    np.save(os.path.join(cache, "emb_scvi.npy"), z)
    print(f"[{dataset}] scVI latent {z.shape} -> {cache}/emb_scvi.npy")
    return z


if __name__ == "__main__":
    # cli entry point: train scVI for one dataset and write emb_scvi.npy
    ap = argparse.ArgumentParser(
        description="train scVI on the filtered raw counts and save emb_scvi.npy "
                    "aligned 1:1 with the cached pca embedding")
    ap.add_argument("dataset", help="dataset key (an entry of preprocess.DATASETS)")
    ap.add_argument("--gpu", type=int, default=0,
                    help="cuda device index; used only when cuda is available")
    ap.add_argument("--n-latent", type=int, default=50,
                    help="scVI latent dimensionality")
    ap.add_argument("--max-epochs", type=int, default=80,
                    help="maximum training epochs (early stopping is enabled)")
    args = ap.parse_args()
    build_scvi(args.dataset, n_latent=args.n_latent, max_epochs=args.max_epochs, gpu=args.gpu)