# Scraped page header: Spaces - Sleeping; file size: 1,469 bytes; revision e59f78e.
import torch
from torch.utils.data import Dataset
import scanpy as sc
import numpy as np
import scipy.sparse
class CellDreamerDataset(Dataset):
    """Map-style dataset of (current cell, next cell) transitions.

    Expression values are read from an AnnData file and kept in memory as a
    float32 tensor; a separate ``.npy`` file supplies index pairs
    ``(curr_idx, next_idx)`` that select rows of that tensor.  Each item
    exposes both rows, their difference, and the difference of the
    ``dpt_pseudotime`` values stored in ``adata.obs``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(
        self,
        data_path: str = "celldreamer/data/processed/cleaned.h5ad",
        pairs_path: str = "celldreamer/data/processed/train_pairs.npy",
        normalize: bool = False,
    ) -> None:
        """Load the expression matrix, pseudotime and index pairs.

        Args:
            data_path: Path handed to ``scanpy.read``.
            pairs_path: Path of a ``.npy`` file holding an ``(N, 2)`` array of
                row indices into the expression matrix.
            normalize: When true, apply ``sc.pp.normalize_total`` with
                ``target_sum=1e4`` followed by ``sc.pp.log1p`` before the
                matrix is converted to a tensor.

        Raises:
            KeyError: ``adata.obs`` has no ``dpt_pseudotime`` column.
            ValueError: The pairs array is not ``(N, 2)`` or holds
                non-integer values.
            IndexError: A pair index falls outside the expression matrix.
        """
        super().__init__()
        adata = sc.read(data_path)

        # The printed range describes the matrix as stored on disk, i.e.
        # before the optional normalisation below.
        data_min = adata.X.min()
        data_max = adata.X.max()
        print(f"min: {data_min:.4f}, max: {data_max:.4f}")

        if normalize:
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)

        raw_pairs = np.load(pairs_path)

        if "dpt_pseudotime" not in adata.obs.columns:
            raise KeyError(
                "adata.obs has no 'dpt_pseudotime' column; "
                f"pseudotime must be computed before loading {data_path!r}"
            )

        # torch.tensor cannot consume scipy sparse matrices, so densify first.
        if scipy.sparse.issparse(adata.X):
            matrix = adata.X.toarray()
        else:
            matrix = adata.X

        # float32 tensor indexed by row along its first axis.
        self.data = torch.tensor(matrix, dtype=torch.float32)
        # float32 pseudotime values, one entry per row of ``adata.obs``.
        self.times = torch.tensor(
            adata.obs["dpt_pseudotime"].values, dtype=torch.float32
        )
        # Validate once here so that a malformed pairs file fails at
        # construction time instead of on some later __getitem__ call.
        self.pairs = self._validate_pairs(raw_pairs, self.data.shape[0])

    @staticmethod
    def _validate_pairs(pairs, n_cells: int) -> np.ndarray:
        """Return ``pairs`` as an integer ``(N, 2)`` array of valid row indices.

        An empty input yields an empty ``(0, 2)`` array.  Integer-valued
        arrays of another dtype (for example floats) are cast to ``int64``.
        Negative indices are accepted as long as they stay within
        ``[-n_cells, n_cells)``, matching ordinary sequence indexing.

        Raises:
            ValueError: Wrong shape, or values that are not whole numbers.
            IndexError: An index lies outside ``[-n_cells, n_cells)``.
        """
        pairs = np.asarray(pairs)
        if pairs.size == 0:
            return np.empty((0, 2), dtype=np.int64)

        if pairs.ndim != 2 or pairs.shape[1] != 2:
            raise ValueError(
                f"pairs must have shape (N, 2), got {pairs.shape}"
            )

        if not np.issubdtype(pairs.dtype, np.integer):
            # NaN/inf would trigger a cast warning; they are rejected by the
            # equality check right after, so the warning is just noise.
            with np.errstate(invalid="ignore"):
                as_int = pairs.astype(np.int64)
            if not np.array_equal(as_int, pairs):
                raise ValueError("pairs must contain integer cell indices")
            pairs = as_int

        lowest, highest = int(pairs.min()), int(pairs.max())
        if lowest < -n_cells or highest >= n_cells:
            raise IndexError(
                f"pair indices span [{lowest}, {highest}] but the data "
                f"has only {n_cells} rows"
            )
        return pairs

    def __len__(self) -> int:
        """Return the number of index pairs."""
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Return the transition described by the ``idx``-th pair.

        Returns:
            A dict with keys ``"x_t"`` and ``"x_next"`` (rows of ``self.data``),
            ``"delta"`` (``x_next - x_t``) and ``"dt"`` (difference of the two
            entries of ``self.times``, next minus current).
        """
        curr_idx, next_idx = self.pairs[idx]
        x_t = self.data[curr_idx]
        x_next = self.data[next_idx]
        return {
            "x_t": x_t,
            "x_next": x_next,
            "delta": x_next - x_t,
            "dt": self.times[next_idx] - self.times[curr_idx],
        }