"""Build the CellDreamer datasets and export statistics and front-end artifacts."""

import json
import os

import torch
import scanpy as sc
import numpy as np
import scipy
import scipy.sparse  # explicit: ``import scipy`` alone need not load the submodule

from celldreamer.data.download import collect_data
from celldreamer.data.process import process
from celldreamer.data.plots import validate
from celldreamer.data.class_celldreamerDataset import CellDreamerDataset


def create_data():
    """Run the data pipeline and save the train/val/test datasets.

    Calls ``collect_data``, ``process`` and ``validate`` in that order, builds
    one ``CellDreamerDataset`` per split from
    ``celldreamer/data/processed/<split>_pairs.npy`` and saves each one with
    ``torch.save`` to ``celldreamer/data/datasets/<split>.pt``.
    """
    collect_data()
    process()
    validate()

    # Build every dataset before writing anything, so a failure while loading
    # one split does not leave a partially written output directory.
    datasets = {
        split: CellDreamerDataset(
            pairs_path=f"celldreamer/data/processed/{split}_pairs.npy"
        )
        for split in ("train", "val", "test")
    }

    os.makedirs("celldreamer/data/datasets", exist_ok=True)
    for split, dataset in datasets.items():
        torch.save(dataset, f"celldreamer/data/datasets/{split}.pt")


def _compute_gene_stats(adata):
    """Return per-gene ``(mean, std)`` arrays, with zero std replaced by 1."""
    # Use adata.raw (restricted to the current var_names) when it exists,
    # otherwise fall back to adata.X.
    if adata.raw is not None:
        X_source = adata.raw[:, adata.var_names].X
    else:
        X_source = adata.X
    if scipy.sparse.issparse(X_source):
        X_source = X_source.toarray()

    mean = np.mean(X_source, axis=0)
    std = np.std(X_source, axis=0)
    # Constant genes get std 1 -- presumably so consumers can divide by std
    # without producing inf/NaN.
    std[std == 0] = 1.0
    return mean, std


def _write_gene_map(adata, output_dir):
    """Write ``gene_map.json``: the gene name list and a name -> index map."""
    gene_names = adata.var_names.tolist()
    gene_map_payload = {
        "gene_names": gene_names,  # dropdown
        "indices": {name: i for i, name in enumerate(gene_names)},  # model gene perturbation
    }
    with open(f"{output_dir}/gene_map.json", "w") as f:
        json.dump(gene_map_payload, f)


def _write_background_map(adata, output_dir, n_background_points):
    """Write ``background_map.json``: a random sample of UMAP points."""
    # Compute the embedding (and the neighbour graph it needs) only if missing.
    if 'X_umap' not in adata.obsm:
        if 'neighbors' not in adata.uns:
            sc.pp.neighbors(adata)
        sc.tl.umap(adata)

    total_cells = adata.shape[0]
    if total_cells > n_background_points:
        indices = np.random.choice(total_cells, n_background_points, replace=False)
        indices.sort()
    else:
        indices = np.arange(total_cells)

    # Pull the columns out once instead of doing a pandas .iloc lookup per cell.
    umap_coords = adata.obsm['X_umap']
    pseudotime = adata.obs['dpt_pseudotime'].to_numpy()
    labels = adata.obs['celltype'].to_numpy() if 'celltype' in adata.obs else None

    # NOTE(review): json.dump emits NaN/Infinity for non-finite pseudotime
    # values, which strict JSON parsers reject -- confirm the processed data
    # never contains them.
    background_payload = []
    for idx in indices.tolist():
        point = {
            "id": idx,
            "x": round(float(umap_coords[idx, 0]), 3),
            "y": round(float(umap_coords[idx, 1]), 3),
            "t": round(float(pseudotime[idx]), 3),
        }
        if labels is not None:
            point["label"] = str(labels[idx])
        background_payload.append(point)

    with open(f"{output_dir}/background_map.json", "w") as f:
        json.dump(background_payload, f)


def _write_default_stem_cell(adata, mean, std, output_dir):
    """Write ``default_stem_cell.json``: the un-scaled mean 'ductal' cell."""
    if 'celltype' in adata.obs:
        # na=False keeps the mask strictly boolean when some labels are missing.
        stem_mask = (
            adata.obs['celltype']
            .str.contains('ductal', case=False, na=False)
            .to_numpy(dtype=bool)
        )
    else:
        # No cell-type annotation: use every cell, the same fallback as when
        # no cell matches 'ductal'.
        stem_mask = np.zeros(adata.shape[0], dtype=bool)

    stem_data = adata.X[stem_mask] if stem_mask.any() else adata.X
    # Sparse matrices return a 1 x n np.matrix here, dense input a 1-D array;
    # asarray + ravel flattens both (and sparse arrays, which have no ``.A1``).
    mean_stem_z_score = np.asarray(stem_data.mean(axis=0)).ravel()

    # Un-scale the data so the UI gets usable numbers (not -1.7)
    usable_stem_vector = np.maximum((mean_stem_z_score * std) + mean, 0.0)
    with open(f"{output_dir}/default_stem_cell.json", "w") as f:
        json.dump(usable_stem_vector.tolist(), f)


def get_data_stats(n_background_points=5000):
    """Compute normalisation statistics and export artifacts for the React app.

    Reads ``celldreamer/data/processed/cleaned.h5ad`` and writes:

    * ``celldreamer/data/stats/stats.pt`` -- dict with ``"mean"`` and ``"std"``
      tensors (one value per gene).
    * ``celldreamer/data/artifacts/gene_map.json`` -- gene names and their
      column indices.
    * ``celldreamer/data/artifacts/background_map.json`` -- up to
      ``n_background_points`` randomly sampled cells with UMAP coordinates,
      pseudotime and, when available, cell-type label.
    * ``celldreamer/data/artifacts/default_stem_cell.json`` -- mean expression
      vector of the cells whose type contains "ductal" (all cells when none
      match or no ``celltype`` column exists), un-scaled and clipped at 0.

    Args:
        n_background_points: Maximum number of cells sampled for the
            background map; all cells are used when there are fewer.

    Raises:
        KeyError: If ``adata.obs`` has no ``dpt_pseudotime`` column.
    """
    adata = sc.read("celldreamer/data/processed/cleaned.h5ad")

    mean, std = _compute_gene_stats(adata)
    stats = {
        "mean": torch.tensor(mean),
        "std": torch.tensor(std),
    }
    os.makedirs("celldreamer/data/stats", exist_ok=True)
    torch.save(stats, "celldreamer/data/stats/stats.pt")

    # create useful data for react application
    output_dir = "celldreamer/data/artifacts"
    os.makedirs(output_dir, exist_ok=True)

    # create index to gene name map
    _write_gene_map(adata, output_dir)

    # get random coords (5000 by default) for showing cell type clusters
    _write_background_map(adata, output_dir, n_background_points)

    # get mean ductal cell that can be used as a starting point for people to perturb
    _write_default_stem_cell(adata, mean, std, output_dir)


if __name__ == "__main__":
    create_data()
    get_data_stats()