import torch from torch.utils.data import Dataset import scanpy as sc import numpy as np import scipy.sparse class CellDreamerDataset(Dataset): def __init__( self, data_path="celldreamer/data/processed/cleaned.h5ad", pairs_path="celldreamer/data/processed/train_pairs.npy", normalize=False ): adata = sc.read(data_path) data_min = adata.X.min() data_max = adata.X.max() print(f"min: {data_min:.4f}, max: {data_max:.4f}") if normalize: sc.pp.normalize_total(adata, target_sum=1e4) sc.pp.log1p(adata) self.pairs = np.load(pairs_path) if scipy.sparse.issparse(adata.X): self.data = torch.tensor(adata.X.toarray(), dtype=torch.float32) else: self.data = torch.tensor(adata.X, dtype=torch.float32) self.times = torch.tensor(adata.obs['dpt_pseudotime'].values, dtype=torch.float32) def __len__(self): return len(self.pairs) def __getitem__(self, idx): curr_idx, next_idx = self.pairs[idx] x_t = self.data[curr_idx] x_next = self.data[next_idx] return { "x_t": x_t, "x_next": x_next, "delta": x_next - x_t, "dt": self.times[next_idx] - self.times[curr_idx] }