# Spaces: Sleeping
import os
from pathlib import Path

import scanpy as sc
import torch

from celldreamer.models import load_config
from celldreamer.models.class_celldreamer import ClassCellDreamer
def solve_projector(
    *,
    adata_path="celldreamer/data/processed/cleaned.h5ad",
    stats_path="celldreamer/data/stats/stats.pt",
    config_path="celldreamer/config/evaluate_config.yml",
    checkpoint_path="celldreamer/checkpoints/best.pth",
    output_path="celldreamer/data/artifacts/projector_weights.pth",
):
    """Fit an affine map from the encoder's latent space to UMAP coordinates.

    Every cell in ``adata_path`` is preprocessed (log1p, z-score with the saved
    stats, upper clip at 10) and encoded with the checkpointed CellDreamer
    encoder on CPU. UMAP coordinates are computed first if the AnnData object
    has none. The ordinary least-squares problem
    ``Y ~= Z @ weight.T + bias`` is then solved jointly for ``weight`` and
    ``bias``, and ``{"weight": weight, "bias": bias}`` is saved with
    ``torch.save`` to ``output_path``.

    All paths are keyword-only and default to the locations previously
    hard-coded here, so ``solve_projector()`` keeps working as before.

    Returns:
        The saved state dict. ``weight`` has shape ``(n_umap_dims, latent_dim)``
        and ``bias`` has shape ``(n_umap_dims,)``, i.e. the layout of
        ``torch.nn.Linear`` parameters.

    Raises:
        ValueError: if the latent codes contain NaN/inf (for example because
            ``stats["std"]`` has a zero entry), which would otherwise yield
            NaN projector weights.
    """
    adata = sc.read(adata_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only point
    # this at trusted stats files.
    stats = torch.load(stats_path, weights_only=False)

    args = load_config(config_path)
    args.device = "cpu"
    wrapper = ClassCellDreamer(args)
    wrapper.model.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    )
    wrapper.model.eval()

    # Regression targets: reuse an existing embedding, otherwise compute one.
    if "X_umap" not in adata.obsm:
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
    y_umap = torch.tensor(adata.obsm["X_umap"], dtype=torch.float32)

    x_in = _preprocess(_expression_matrix(adata), stats)
    with torch.no_grad():
        z_latent, _ = wrapper.model.encoder(x_in)

    if not torch.isfinite(z_latent).all():
        raise ValueError(
            "encoder produced non-finite latent codes; check stats['mean'] / "
            "stats['std'] (a zero std entry yields inf/NaN after z-scoring)"
        )

    weight, bias = _fit_affine(z_latent, y_umap)
    state_dict = {"weight": weight, "bias": bias}

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(state_dict, output_path)
    return state_dict


def _expression_matrix(adata):
    """Return the expression matrix for ``adata.var_names`` as a dense array.

    Uses ``adata.raw`` restricted to the current genes when present, otherwise
    ``adata.X``.
    """
    if adata.raw is not None:
        data = adata.raw[:, adata.var_names].X
    else:
        data = adata.X
    # Sparse matrices expose ``toarray()``; dense arrays are returned unchanged.
    return data.toarray() if hasattr(data, "toarray") else data


def _preprocess(data, stats, clip_max=10.0):
    """Apply the log1p -> z-score -> upper-clip transform fed to the encoder.

    ``stats["mean"]`` and ``stats["std"]`` are used exactly as loaded; they are
    presumably per-gene tensors that broadcast over the cell axis (TODO confirm
    against the stats producer).
    """
    x = torch.log1p(torch.tensor(data, dtype=torch.float32))
    x = (x - stats["mean"]) / stats["std"]
    # Only the upper bound is clipped, matching the original script; changing
    # this would diverge from whatever preprocessing the encoder was trained on.
    return torch.clamp(x, max=clip_max)


def _fit_affine(z, y):
    """Least-squares fit of ``y ~= z @ weight.T + bias``.

    A column of ones is appended to ``z`` so the intercept is solved together
    with the slope instead of being forced to zero.

    Returns:
        ``(weight, bias)`` with shapes ``(y.shape[1], z.shape[1])`` and
        ``(y.shape[1],)``.
    """
    ones = torch.ones(z.shape[0], 1, dtype=z.dtype, device=z.device)
    solution = torch.linalg.lstsq(torch.cat([z, ones], dim=1), y).solution
    # ``solution`` is (latent_dim + 1, out_dim); its last row is the intercept.
    weight = solution[:-1].T.contiguous()
    bias = solution[-1].clone()
    return weight, bias
if __name__ == "__main__":
    # Script entry point: fit the latent->UMAP projector with default paths.
    solve_projector()