# CellDreamer-API: celldreamer/models/least_squares_umap.py
# Uploaded by RobroKools ("Upload 44 files"), commit e59f78e (verified).
import torch
import scanpy as sc
import os
from celldreamer.models.class_celldreamer import ClassCellDreamer
from celldreamer.models import load_config
def _load_model(config_path, checkpoint_path):
    """Build ClassCellDreamer on CPU, load checkpoint weights, return the model in eval mode."""
    args = load_config(config_path)
    args.device = "cpu"
    wrapper = ClassCellDreamer(args)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    wrapper.model.load_state_dict(state)
    wrapper.model.eval()
    return wrapper.model


def _preprocess(counts, stats):
    """Densify a count matrix and apply log1p, standardization with ``stats`` and clipping at 10.

    NOTE(review): presumably mirrors the preprocessing the model was trained
    with -- confirm against the training pipeline.
    """
    if hasattr(counts, "toarray"):  # scipy sparse -> dense ndarray
        counts = counts.toarray()
    x = torch.tensor(counts, dtype=torch.float32)
    x = torch.log1p(x)
    x = (x - stats["mean"]) / stats["std"]
    return torch.clamp(x, max=10.0)


def _encode(model, counts, stats, batch_size=None):
    """Encode cells into latent codes, optionally ``batch_size`` rows at a time to bound memory."""
    n_cells = counts.shape[0]
    if n_cells == 0:
        raise ValueError("cannot fit a projector: the dataset contains no cells")
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    with torch.no_grad():
        if batch_size is None:
            latents, _ = model.encoder(_preprocess(counts, stats))
            return latents
        chunks = []
        for start in range(0, n_cells, batch_size):
            # Slice before densifying so only one chunk is ever dense in memory.
            z, _ = model.encoder(_preprocess(counts[start:start + batch_size], stats))
            chunks.append(z)
    return torch.cat(chunks, dim=0)


def _fit_linear_projector(latents, targets, fit_intercept=True):
    """Least-squares fit of ``targets ~= latents @ weight.T + bias``; returns an nn.Linear-style state dict."""
    if fit_intercept:
        # Append a constant column so the last solution row is the bias;
        # UMAP coordinates are not centred, so a zero bias would bias the fit.
        ones = torch.ones(latents.shape[0], 1, dtype=latents.dtype, device=latents.device)
        design = torch.cat([latents, ones], dim=1)
    else:
        design = latents
    # Solves min ||A x - y||^2 (the normal equations A^T A x = A^T y)
    # without forming A^T A explicitly, which is numerically stabler.
    solution = torch.linalg.lstsq(design, targets).solution
    if fit_intercept:
        weight, bias = solution[:-1], solution[-1]
    else:
        weight = solution
        bias = torch.zeros(targets.shape[1], dtype=targets.dtype)
    return {"weight": weight.T.contiguous(), "bias": bias.contiguous()}


def solve_projector(
    data_path="celldreamer/data/processed/cleaned.h5ad",
    stats_path="celldreamer/data/stats/stats.pt",
    config_path="celldreamer/config/evaluate_config.yml",
    checkpoint_path="celldreamer/checkpoints/best.pth",
    output_path="celldreamer/data/artifacts/projector_weights.pth",
    *,
    fit_intercept=True,
    batch_size=None,
):
    """Fit a linear map from CellDreamer latent codes to UMAP coordinates and save it.

    Every cell in ``data_path`` is encoded with the checkpointed model. An
    ordinary least-squares problem ``Z @ weight.T + bias ~= Y`` is then solved,
    where ``Z`` holds the latent codes and ``Y`` is ``adata.obsm['X_umap']``.
    The result is written as a ``{"weight", "bias"}`` state dict with the
    layout of ``torch.nn.Linear(latent_dim, n_umap_dims)``.

    Args:
        data_path: processed ``.h5ad`` file. If ``obsm['X_umap']`` is missing,
            neighbors and UMAP are computed in memory; the file is not rewritten.
        stats_path: file holding a dict with "mean" and "std" used to
            standardize log1p counts.
        config_path: evaluation config passed to ``load_config``.
        checkpoint_path: weights loaded into ``ClassCellDreamer(...).model``.
        output_path: destination of the state dict. Parent directories are created.
        fit_intercept: if True (default), solve for the bias jointly with the
            weights. If False, fix the bias to zero (the previous behaviour).
        batch_size: encode this many cells per chunk to bound peak memory.
            ``None`` encodes everything in one pass.

    Returns:
        The saved state dict, with ``weight`` of shape (n_umap_dims, latent_dim)
        and ``bias`` of shape (n_umap_dims,).

    Raises:
        ValueError: if there are no cells, ``batch_size`` is not positive, or
            the encoder yields non-finite latent codes (e.g. from a zero std).
    """
    model = _load_model(config_path, checkpoint_path)

    adata = sc.read(data_path)
    if "X_umap" not in adata.obsm:
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
    targets = torch.tensor(adata.obsm["X_umap"], dtype=torch.float32)

    # Prefer raw counts restricted to the processed gene set; otherwise use X.
    counts = adata.raw[:, adata.var_names].X if adata.raw is not None else adata.X

    # weights_only=False unpickles arbitrary objects: only load trusted stats files.
    stats = torch.load(stats_path, weights_only=False)

    # Assumes the encoder returns (latents, extra) with latents of shape
    # (n_cells, latent_dim) -- TODO confirm against the model definition.
    latents = _encode(model, counts, stats, batch_size)
    if not torch.isfinite(latents).all():
        raise ValueError(
            "encoder produced non-finite latent codes; check stats['std'] for zeros"
        )

    state_dict = _fit_linear_projector(latents, targets, fit_intercept)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(state_dict, output_path)
    return state_dict
if __name__ == "__main__":
    # Script entry point: fit the latent-to-UMAP projector and write it to disk.
    solve_projector()