Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Dataset | |
| import scanpy as sc | |
| import numpy as np | |
| import scipy.sparse | |
class CellDreamerDataset(Dataset):
    """Pairs of single-cell expression profiles with their time difference.

    Loads an AnnData file into a dense float32 tensor and serves the
    (current cell, next cell) index pairs listed in a ``.npy`` file.

    Args:
        data_path: Path to the AnnData file, read with ``scanpy.read``.
        pairs_path: Path to a ``.npy`` file holding an integer array of shape
            ``(n_pairs, 2)``; each row is ``(current_idx, next_idx)`` and
            indexes rows of the expression matrix ``adata.X``.
        normalize: If True, apply ``sc.pp.normalize_total(target_sum=1e4)``
            followed by ``sc.pp.log1p`` before converting to a tensor.
        time_key: Column of ``adata.obs`` holding the per-cell time value used
            to compute ``dt``. Defaults to ``"dpt_pseudotime"``.

    Raises:
        ValueError: If the pairs array is not an integer array of shape
            ``(n_pairs, 2)`` or contains indices outside ``[0, n_cells)``.
    """

    def __init__(
        self,
        data_path="celldreamer/data/processed/cleaned.h5ad",
        pairs_path="celldreamer/data/processed/train_pairs.npy",
        normalize=False,
        *,
        time_key="dpt_pseudotime",
    ):
        adata = sc.read(data_path)

        # Range of the raw matrix; reported before any optional normalization.
        data_min = adata.X.min()
        data_max = adata.X.max()
        print(f"min: {data_min:.4f}, max: {data_max:.4f}")

        if normalize:
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)

        # Validate before densifying: torch silently wraps negative indices,
        # and an out-of-range index would otherwise only surface mid-epoch.
        self.pairs = self._validate_pairs(np.load(pairs_path), adata.X.shape[0])
        self.data = self._to_float32_tensor(adata.X)
        self.times = torch.tensor(adata.obs[time_key].to_numpy(), dtype=torch.float32)

    @staticmethod
    def _to_float32_tensor(matrix):
        """Return ``matrix`` (dense array or scipy sparse) as a dense float32 tensor."""
        if scipy.sparse.issparse(matrix):
            # Cast while still sparse so the dense buffer is allocated once,
            # directly in float32, rather than as float64 and then copied.
            return torch.from_numpy(matrix.astype(np.float32).toarray())
        return torch.tensor(matrix, dtype=torch.float32)

    @staticmethod
    def _validate_pairs(pairs, n_cells):
        """Check ``pairs`` is an ``(n_pairs, 2)`` integer array indexing ``n_cells`` rows.

        Returns the array unchanged; raises ValueError on any violation.
        """
        pairs = np.asarray(pairs)
        if pairs.ndim != 2 or pairs.shape[1] != 2:
            raise ValueError(f"pairs must have shape (n_pairs, 2), got {pairs.shape}")
        if not np.issubdtype(pairs.dtype, np.integer):
            raise ValueError(f"pairs must be an integer array, got dtype {pairs.dtype}")
        if pairs.size and (pairs.min() < 0 or pairs.max() >= n_cells):
            raise ValueError(
                f"pairs indices must lie in [0, {n_cells}), "
                f"got range [{pairs.min()}, {pairs.max()}]"
            )
        return pairs

    def __len__(self):
        """Return the number of (current, next) pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return pair ``idx`` as a dict of tensors.

        Keys: ``x_t`` and ``x_next`` (expression rows of the current and next
        cell), ``delta`` (``x_next - x_t``) and ``dt`` (time of the next cell
        minus time of the current cell).
        """
        curr_idx, next_idx = self.pairs[idx]
        x_t = self.data[curr_idx]
        x_next = self.data[next_idx]
        return {
            "x_t": x_t,
            "x_next": x_next,
            "delta": x_next - x_t,
            "dt": self.times[next_idx] - self.times[curr_idx],
        }