Spaces:
Sleeping
Sleeping
File size: 4,344 Bytes
e59f78e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import torch
import os
import scanpy as sc
import numpy as np
import json
import scipy
from celldreamer.data.download import collect_data
from celldreamer.data.process import process
from celldreamer.data.plots import validate
from celldreamer.data.class_celldreamerDataset import CellDreamerDataset
def create_data():
    """Run the data pipeline and serialize the train/val/test datasets.

    Calls ``collect_data``, ``process`` and ``validate`` in that order, then
    wraps each ``*_pairs.npy`` file in a ``CellDreamerDataset`` and writes the
    resulting objects with ``torch.save`` into ``celldreamer/data/datasets``.
    """
    for stage in (collect_data, process, validate):
        stage()

    # (pairs file to load, destination of the serialized dataset)
    split_files = (
        (
            "celldreamer/data/processed/train_pairs.npy",
            "celldreamer/data/datasets/train.pt",
        ),
        (
            "celldreamer/data/processed/val_pairs.npy",
            "celldreamer/data/datasets/val.pt",
        ),
        (
            "celldreamer/data/processed/test_pairs.npy",
            "celldreamer/data/datasets/test.pt",
        ),
    )

    # All datasets are constructed before anything is written to disk.
    built = [CellDreamerDataset(pairs_path=source) for source, _ in split_files]

    os.makedirs("celldreamer/data/datasets", exist_ok=True)
    for dataset, (_, destination) in zip(built, split_files):
        torch.save(dataset, destination)
def _column_mean_std(matrix):
    """Return the per-column mean and standard deviation of *matrix*.

    Sparse input is densified first. A standard deviation of exactly zero is
    replaced by 1.0 so constant columns do not end up with a zero scale factor.
    """
    if scipy.sparse.issparse(matrix):
        matrix = matrix.toarray()
    mean = np.mean(matrix, axis=0)
    std = np.std(matrix, axis=0)
    std[std == 0] = 1.0
    return mean, std


def _write_json(payload, path):
    """Serialize *payload* as JSON to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f)


def get_data_stats(n_background_points=5000, *, seed=None):
    """Export per-gene statistics and the JSON artifacts used by the React app.

    Reads ``celldreamer/data/processed/cleaned.h5ad`` and writes:

    * ``celldreamer/data/stats/stats.pt`` -- dict with ``"mean"`` and ``"std"``
      tensors, computed from ``adata.raw`` (restricted to ``adata.var_names``)
      when it exists, otherwise from ``adata.X``.
    * ``celldreamer/data/artifacts/gene_map.json`` -- gene names and a
      name -> index map.
    * ``celldreamer/data/artifacts/background_map.json`` -- a subsample of
      cells with UMAP coordinates, pseudotime and (if available) cell type.
    * ``celldreamer/data/artifacts/default_stem_cell.json`` -- the mean
      expression vector of cells whose ``celltype`` contains "ductal"
      (all cells if there are none, or if ``celltype`` is missing).

    Args:
        n_background_points: Maximum number of cells written to
            ``background_map.json``. If the dataset has more cells, that many
            are sampled without replacement.
        seed: Optional seed for the subsampling. ``None`` (default) keeps the
            previous behavior of drawing from NumPy's global random state.

    Raises:
        KeyError: If ``adata.obs`` has no ``dpt_pseudotime`` column.
    """
    data_path = "celldreamer/data/processed/cleaned.h5ad"
    adata = sc.read(data_path)

    # --- per-gene mean / std -------------------------------------------------
    if adata.raw is not None:
        X_source = adata.raw[:, adata.var_names].X
    else:
        X_source = adata.X
    mean, std = _column_mean_std(X_source)

    stats = {
        "mean": torch.tensor(mean),
        "std": torch.tensor(std),
    }
    os.makedirs("celldreamer/data/stats", exist_ok=True)
    torch.save(stats, "celldreamer/data/stats/stats.pt")

    # --- artifacts for the React application ---------------------------------
    output_dir = "celldreamer/data/artifacts"
    os.makedirs(output_dir, exist_ok=True)

    # Index <-> gene name map.
    gene_names = adata.var_names.tolist()
    gene_map_payload = {
        "gene_names": gene_names,  # dropdown
        "indices": {name: i for i, name in enumerate(gene_names)},  # model gene perturbation
    }
    _write_json(gene_map_payload, f"{output_dir}/gene_map.json")

    # Random subsample of cells for showing cell type clusters.
    if 'X_umap' not in adata.obsm:
        if 'neighbors' not in adata.uns:
            sc.pp.neighbors(adata)
        sc.tl.umap(adata)

    total_cells = adata.shape[0]
    if total_cells > n_background_points:
        # Without a seed, keep using the global NumPy state so callers that
        # rely on np.random.seed(...) still get the same draws as before.
        rng = np.random if seed is None else np.random.default_rng(seed)
        indices = rng.choice(total_cells, n_background_points, replace=False)
        indices.sort()
    else:
        indices = np.arange(total_cells)

    # Pull the columns out once instead of doing a pandas lookup per cell.
    umap_coords = adata.obsm['X_umap']
    pseudotime = adata.obs['dpt_pseudotime'].to_numpy()
    has_celltype = 'celltype' in adata.obs
    labels = adata.obs['celltype'].to_numpy() if has_celltype else None

    background_payload = []
    for idx in indices:
        idx = int(idx)
        point = {
            "id": idx,
            "x": round(float(umap_coords[idx, 0]), 3),
            "y": round(float(umap_coords[idx, 1]), 3),
            "t": round(float(pseudotime[idx]), 3),
        }
        if has_celltype:
            point["label"] = str(labels[idx])
        background_payload.append(point)
    _write_json(background_payload, f"{output_dir}/background_map.json")

    # Mean ductal cell that can be used as a starting point for people to perturb.
    if has_celltype:
        # na=False keeps the mask strictly boolean when some labels are missing;
        # a NumPy mask (rather than a pandas Series) indexes dense and sparse
        # matrices alike.
        stem_mask = (
            adata.obs['celltype']
            .str.contains('ductal', case=False, na=False)
            .to_numpy(dtype=bool)
        )
    else:
        # No cell type annotation: fall back to all cells, like the
        # "no ductal cells found" case.
        stem_mask = np.zeros(total_cells, dtype=bool)

    stem_data = adata.X[stem_mask] if stem_mask.any() else adata.X
    # asarray + ravel flattens ndarray, np.matrix and sparse-array results
    # alike (the previous `.A1` only exists on np.matrix).
    mean_stem_z_score = np.asarray(stem_data.mean(axis=0)).ravel()

    # Un-scale the data so the UI gets usable numbers (not -1.7).
    # NOTE(review): assumes adata.X holds values z-scored with the mean/std
    # computed above -- confirm against the processing step.
    usable_stem_vector = (mean_stem_z_score * std) + mean
    usable_stem_vector = np.maximum(usable_stem_vector, 0.0)
    _write_json(usable_stem_vector.tolist(), f"{output_dir}/default_stem_cell.json")
if __name__ == "__main__":
    # Script entry point: build and save the datasets, then export the
    # statistics and UI artifacts.
    for step in (create_data, get_data_stats):
        step()