PIVOT / src /data /embeddings.py
bryan7264's picture
pivot: code + trained checkpoints (norman, replogle k562)
3b4941f verified
Raw
History Blame
3.44 kB
"""alternative cell-state encoders phi (Table 16).
pca is the default (built in preprocess.py). this adds scvi, a standard deep
single-cell latent space. to keep cell ordering identical to the cached pca
embedding, we reproduce the exact filtered cell set from preprocess.py (same
seed, hvgs, min-cells filter), train scvi on the raw counts, and save
emb_scvi.npy aligned 1:1.
"""
from __future__ import annotations
import argparse
import os
import numpy as np
import scipy.sparse as sp
from src.data.preprocess import DATASETS, parse_perturbation
from src.utils.common import load_json, set_seed
def load_filtered_adata(dataset: str, n_hvg: int = 2000):
    """Rebuild the AnnData that preprocess.py filters, with raw counts in ``.X``.

    The steps run in this order so the resulting cell order can be compared
    with the cached PCA embedding:
      1. read the h5ad file at ``DATASETS[dataset]["raw"]``;
      2. if ``max_cells`` is configured and exceeded, keep a sorted random
         subset of exactly that many cells (numpy generator seeded with 0);
      3. stash the raw matrix in ``layers["counts"]``, then apply
         ``normalize_total`` + ``log1p`` purely to select ``n_hvg``
         Seurat-flavor highly variable genes, and subset to those genes;
      4. keep every control cell plus cells whose perturbation occurs in at
         least 20 non-control cells;
      5. put the stashed raw counts (now subset) back into ``.X``.

    Args:
        dataset: key into ``DATASETS``.
        n_hvg: number of highly variable genes to keep.

    Returns:
        The filtered AnnData (kept cells x selected genes) with raw counts in ``.X``.
    """
    import scanpy as sc

    config = dict(DATASETS[dataset])
    raw_path = config["raw"]
    seed = 0
    set_seed(seed)

    adata = sc.read_h5ad(raw_path)

    # optional down-sampling; indices are sorted so original cell order is kept
    cap = config.get("max_cells")
    if cap is not None and adata.n_obs > cap:
        generator = np.random.default_rng(seed)
        chosen = generator.choice(adata.n_obs, size=cap, replace=False)
        adata = adata[np.sort(chosen)].copy()

    # pristine copy of the raw matrix; the normalization below only exists to
    # reproduce the HVG choice and is discarded at the end
    adata.layers["counts"] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=n_hvg, flavor="seurat")
    adata = adata[:, adata.var["highly_variable"]].copy()

    # min-cells filter (sep / pert_col fall back to the same defaults as preprocess)
    control_label = config["control_label"]
    sep = config.get("sep", "_")
    labels = adata.obs[config.get("pert_col", "perturbation")].astype(str).values
    # a label that parses to zero perturbed genes is a control cell
    is_control = np.array(
        [len(parse_perturbation(label, control_label, sep)) == 0 for label in labels],
        dtype=bool,
    )
    import pandas as pd
    cells_per_pert = pd.Series(labels[~is_control]).value_counts()
    frequent = set(cells_per_pert[cells_per_pert >= 20].index)
    keep_cells = is_control | np.isin(labels, list(frequent))
    adata = adata[keep_cells].copy()

    # hand the raw counts (restricted to kept cells and HVGs) back via .X
    adata.X = adata.layers["counts"]
    return adata
def build_scvi(dataset: str, n_latent: int = 50, max_epochs: int = 80, gpu: int = 0):
    """Train scVI on the filtered raw counts of ``dataset`` and cache its latent space.

    The cell set is rebuilt with ``load_filtered_adata`` and its size is checked
    against the row count of the cached ``pca_emb.npy``. The check guards the
    1:1 row alignment between ``emb_scvi.npy`` and the PCA embedding.

    Args:
        dataset: dataset key; the cache directory is ``data/processed/<dataset>``.
        n_latent: dimensionality of the scVI latent space.
        max_epochs: upper bound on training epochs (early stopping may end sooner).
        gpu: device index passed to ``model.to_device`` when CUDA is available;
            otherwise the model is placed on CPU.

    Returns:
        float32 array of latent representations (one row per kept cell), also
        saved to ``data/processed/<dataset>/emb_scvi.npy``.

    Raises:
        FileNotFoundError: if the cached ``pca_emb.npy`` does not exist.
        RuntimeError: if the rebuilt cell count differs from the cached PCA
            embedding's row count (previously an ``assert``, which ``python -O``
            would silently skip).
    """
    import scvi
    import torch

    cache = os.path.join("data/processed", dataset)
    # memory-map so only the .npy header is read (we only need the row count),
    # and do it before the expensive rebuild so a missing cache fails fast
    n_cached = np.load(os.path.join(cache, "pca_emb.npy"), mmap_mode="r").shape[0]

    adata = load_filtered_adata(dataset)
    if adata.n_obs != n_cached:
        raise RuntimeError(f"cell count mismatch {adata.n_obs} vs cached {n_cached}")

    scvi.settings.seed = 0
    # layer=None -> scVI reads .X, where load_filtered_adata restored raw counts
    scvi.model.SCVI.setup_anndata(adata, layer=None)
    model = scvi.model.SCVI(adata, n_latent=n_latent)
    # NOTE(review): depending on the installed scvi-tools version, train() may
    # pick its own accelerator/devices and override this placement -- confirm
    # that `gpu` is actually honoured.
    model.to_device(gpu if torch.cuda.is_available() else "cpu")
    model.train(max_epochs=max_epochs, early_stopping=True,
                plan_kwargs={"lr": 1e-3}, enable_progress_bar=False)

    z = model.get_latent_representation().astype(np.float32)
    np.save(os.path.join(cache, "emb_scvi.npy"), z)
    print(f"[{dataset}] scVI latent {z.shape} -> {cache}/emb_scvi.npy")
    return z
if __name__ == "__main__":
    # CLI: positional dataset key plus integer knobs forwarded to build_scvi
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset")
    for flag, default_value in (("--gpu", 0), ("--n-latent", 50), ("--max-epochs", 80)):
        parser.add_argument(flag, type=int, default=default_value)
    cli = parser.parse_args()
    build_scvi(
        cli.dataset,
        n_latent=cli.n_latent,
        max_epochs=cli.max_epochs,
        gpu=cli.gpu,
    )