File size: 1,469 Bytes
e59f78e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from torch.utils.data import Dataset
import scanpy as sc
import numpy as np
import scipy.sparse

class CellDreamerDataset(Dataset):
    """Map-style dataset that serves pairs of rows from an expression matrix.

    The expression matrix and a per-row ``dpt_pseudotime`` column are read
    from a file with ``sc.read`` and held in memory as float32 tensors. A
    separate ``.npy`` file supplies the pairs; each entry must unpack into
    two indices (current row, next row) into that matrix.
    """

    def __init__(
        self,
        data_path="celldreamer/data/processed/cleaned.h5ad",
        pairs_path="celldreamer/data/processed/train_pairs.npy",
        normalize=False
    ):
        """Load the expression data, the index pairs and the pseudotimes.

        Args:
            data_path: File handed to ``sc.read``.
            pairs_path: ``.npy`` file handed to ``np.load``; its length is
                the dataset length.
            normalize: When true, apply ``sc.pp.normalize_total`` with
                ``target_sum=1e4`` followed by ``sc.pp.log1p`` before the
                matrix is converted to a tensor.
        """
        adata = sc.read(data_path)

        # The value range is reported on the matrix as loaded, i.e. before
        # the optional normalization below.
        data_min, data_max = adata.X.min(), adata.X.max()
        print(f"min: {data_min:.4f}, max: {data_max:.4f}")

        if normalize:
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)

        self.pairs = np.load(pairs_path)

        # torch.tensor needs a dense array, so sparse storage is expanded
        # first; the whole matrix lives in memory afterwards.
        dense = adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X
        self.data = torch.tensor(dense, dtype=torch.float32)

        pseudotime = adata.obs['dpt_pseudotime'].values
        self.times = torch.tensor(pseudotime, dtype=torch.float32)

    def __len__(self):
        """Return the number of index pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the sample for the ``idx``-th pair.

        Returns:
            A dict with ``x_t`` and ``x_next`` (the two selected rows),
            ``delta`` (``x_next - x_t``) and ``dt`` (pseudotime of the
            next row minus pseudotime of the current row).
        """
        src, dst = self.pairs[idx]

        current = self.data[src]
        following = self.data[dst]
        elapsed = self.times[dst] - self.times[src]

        return {
            "x_t": current,
            "x_next": following,
            "delta": following - current,
            "dt": elapsed,
        }