"""alternative cell-state encoders phi (Table 16). pca is the default (built in preprocess.py). this adds scvi, a standard deep single-cell latent space. to keep cell ordering identical to the cached pca embedding, we reproduce the exact filtered cell set from preprocess.py (same seed, hvgs, min-cells filter), train scvi on the raw counts, and save emb_scvi.npy aligned 1:1. """ from __future__ import annotations import argparse import os import numpy as np import scipy.sparse as sp from src.data.preprocess import DATASETS, parse_perturbation from src.utils.common import load_json, set_seed def load_filtered_adata(dataset: str, n_hvg: int = 2000): """reproduce preprocess.py's filtered anndata (raw counts kept in .x), in the same cell order as the cached pca embedding.""" import scanpy as sc cfg = dict(DATASETS[dataset]) raw = cfg["raw"] seed = 0 set_seed(seed) adata = sc.read_h5ad(raw) max_cells = cfg.get("max_cells") if max_cells is not None and adata.n_obs > max_cells: rng = np.random.default_rng(seed) idx = np.sort(rng.choice(adata.n_obs, size=max_cells, replace=False)) adata = adata[idx].copy() adata.layers["counts"] = adata.X.copy() # normalize only to pick the same hvgs as preprocess sc.pp.normalize_total(adata, target_sum=1e4) sc.pp.log1p(adata) sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat") adata = adata[:, adata.var["highly_variable"]].copy() # min-cells filter identical to preprocess (sep/pert_col use preprocess defaults) control_label = cfg["control_label"]; sep = cfg.get("sep", "_") pert = adata.obs[cfg.get("pert_col", "perturbation")].astype(str).values npg = np.array([len(parse_perturbation(p, control_label, sep)) for p in pert]) is_ctrl = npg == 0 import pandas as pd vc = pd.Series(pert[~is_ctrl]).value_counts() keep = set(vc[vc >= 20].index) mask = is_ctrl | np.isin(pert, list(keep)) adata = adata[mask].copy() # restore raw counts into x adata.X = adata.layers["counts"] return adata def build_scvi(dataset: str, n_latent: int = 50, max_epochs: int = 80, gpu: int = 0): import scvi import torch cache = os.path.join("data/processed", dataset) meta = load_json(os.path.join(cache, "meta.json")) adata = load_filtered_adata(dataset) n_cached = np.load(os.path.join(cache, "pca_emb.npy")).shape[0] assert adata.n_obs == n_cached, f"cell count mismatch {adata.n_obs} vs cached {n_cached}" scvi.settings.seed = 0 scvi.model.SCVI.setup_anndata(adata, layer=None) model = scvi.model.SCVI(adata, n_latent=n_latent) model.to_device(gpu if torch.cuda.is_available() else "cpu") model.train(max_epochs=max_epochs, early_stopping=True, plan_kwargs={"lr": 1e-3}, enable_progress_bar=False) z = model.get_latent_representation().astype(np.float32) np.save(os.path.join(cache, "emb_scvi.npy"), z) print(f"[{dataset}] scVI latent {z.shape} -> {cache}/emb_scvi.npy") return z if __name__ == "__main__": ap = argparse.ArgumentParser() ap.add_argument("dataset") ap.add_argument("--gpu", type=int, default=0) ap.add_argument("--n-latent", type=int, default=50) ap.add_argument("--max-epochs", type=int, default=80) args = ap.parse_args() build_scvi(args.dataset, n_latent=args.n_latent, max_epochs=args.max_epochs, gpu=args.gpu)