| """alternative cell-state encoders phi (Table 16). |
| |
| pca is the default (built in preprocess.py). this module adds scvi, a standard deep |
| single-cell latent space. to keep cell ordering identical to the cached pca |
| embedding, we reproduce the exact filtered cell set from preprocess.py (same |
| subsampling seed, hvg selection, and min-cells-per-perturbation (>= 20) filter), |
| train scvi on the raw counts, and save emb_scvi.npy aligned 1:1 with pca_emb.npy. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import os |
|
|
| import numpy as np |
| import scipy.sparse as sp |
|
|
| from src.data.preprocess import DATASETS, parse_perturbation |
| from src.utils.common import load_json, set_seed |
|
|
|
|
def load_filtered_adata(
    dataset: str,
    n_hvg: int = 2000,
    *,
    seed: int = 0,
    min_cells_per_pert: int = 20,
):
    """Rebuild preprocess.py's filtered AnnData, with raw counts restored in ``.X``.

    The steps are intended to mirror preprocess.py so the returned cells come
    out in the same order as the cached PCA embedding:

    1. if the dataset config sets ``max_cells``, randomly subsample that many
       cells (indices are sorted, so the file's cell order is preserved);
    2. normalize_total (target 1e4) + log1p, used only to select ``n_hvg``
       highly variable genes with the "seurat" flavor;
    3. keep every control cell plus every non-control perturbation label seen
       in at least ``min_cells_per_pert`` cells.

    Args:
        dataset: key into ``DATASETS``.
        n_hvg: number of highly variable genes to keep.
        seed: seed for ``set_seed`` and the subsampling RNG. Must match the
            value preprocess.py uses (default 0).
        min_cells_per_pert: minimum number of cells a non-control perturbation
            label needs to be kept. Must match preprocess.py (default 20).

    Returns:
        AnnData restricted to the selected genes and cells, whose ``.X`` is the
        un-normalized count matrix saved before step 2.
    """
    import pandas as pd
    import scanpy as sc

    cfg = dict(DATASETS[dataset])
    set_seed(seed)
    adata = sc.read_h5ad(cfg["raw"])

    max_cells = cfg.get("max_cells")
    if max_cells is not None and adata.n_obs > max_cells:
        rng = np.random.default_rng(seed)
        # sorted so the subsample keeps the original cell order
        idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False))
        adata = adata[idx].copy()
    # stash the raw counts before normalize_total/log1p overwrite .X
    adata.layers["counts"] = adata.X.copy()

    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()

    control_label = cfg["control_label"]
    sep = cfg.get("sep", "_")
    pert = adata.obs[cfg.get("pert_col", "perturbation")].astype(str).values
    # a cell counts as control when its label parses to an empty perturbation list
    n_targets = np.array([len(parse_perturbation(p, control_label, sep)) for p in pert])
    is_ctrl = n_targets == 0
    label_counts = pd.Series(pert[~is_ctrl]).value_counts()
    frequent = label_counts[label_counts >= min_cells_per_pert].index
    mask = is_ctrl | np.isin(pert, list(frequent))
    adata = adata[mask].copy()

    # hand the raw counts back in .X (the layer was subset along with the cells/genes)
    adata.X = adata.layers["counts"]
    return adata
|
|
|
|
def build_scvi(dataset: str, n_latent: int = 50, max_epochs: int = 80, gpu: int = 0):
    """Train scVI on the filtered raw counts and cache its latent representation.

    Writes ``data/processed/<dataset>/emb_scvi.npy``, whose rows are aligned with
    the cached ``pca_emb.npy`` in the same directory.

    Args:
        dataset: key into ``DATASETS``. Its cache directory must already contain
            ``pca_emb.npy``.
        n_latent: scVI latent dimensionality.
        max_epochs: upper bound on training epochs (early stopping is enabled).
        gpu: CUDA device index to train on. Ignored when CUDA is unavailable.

    Returns:
        float32 latent matrix as returned by ``get_latent_representation()``.

    Raises:
        RuntimeError: if the rebuilt cell count differs from the cached PCA
            embedding's row count, since the rows would then not align.
    """
    import scvi
    import torch

    cache = os.path.join("data/processed", dataset)
    # Read only the .npy header via mmap: just the row count is needed. Doing
    # this first also fails fast if the cache is missing, before the expensive
    # AnnData rebuild.
    n_cached = np.load(os.path.join(cache, "pca_emb.npy"), mmap_mode="r").shape[0]
    adata = load_filtered_adata(dataset)
    if adata.n_obs != n_cached:
        # explicit raise (not assert) so the alignment check survives `python -O`
        raise RuntimeError(f"cell count mismatch {adata.n_obs} vs cached {n_cached}")

    scvi.settings.seed = 0
    # layer=None: load_filtered_adata already put the raw counts in .X
    scvi.model.SCVI.setup_anndata(adata, layer=None)
    model = scvi.model.SCVI(adata, n_latent=n_latent)
    # SCVI.train() chooses its own device. With the default devices="auto" it
    # uses the first GPU, so a prior model.to_device(gpu) would be overridden and
    # `gpu` ignored. Pin the device explicitly. NOTE(review): assumes an
    # scvi-tools version with the accelerator/devices train API (>= 1.0).
    if torch.cuda.is_available():
        device_kwargs = {"accelerator": "gpu", "devices": [gpu]}
    else:
        device_kwargs = {"accelerator": "cpu"}
    model.train(
        max_epochs=max_epochs,
        early_stopping=True,
        plan_kwargs={"lr": 1e-3},
        enable_progress_bar=False,
        **device_kwargs,
    )
    z = model.get_latent_representation().astype(np.float32)
    np.save(os.path.join(cache, "emb_scvi.npy"), z)
    print(f"[{dataset}] scVI latent {z.shape} -> {cache}/emb_scvi.npy")
    return z
|
|
|
|
if __name__ == "__main__":
    # CLI: train scVI for a single dataset and cache the latent space next to
    # the PCA embedding.
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--n-latent", type=int, default=50)
    parser.add_argument("--max-epochs", type=int, default=80)
    cli = parser.parse_args()
    build_scvi(
        cli.dataset,
        n_latent=cli.n_latent,
        max_epochs=cli.max_epochs,
        gpu=cli.gpu,
    )
|
|