import numpy as np
import torch
from torch.utils.data import Dataset

from tf_data_process import DataProcessorForPad
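

# Illustrative sketch (hypothetical helper, not called by DatasetMultiPad):
# __getitem__ maps a global index onto (file_index, row_idx) using the
# cumulative row counts of the AnnData objects. With side="right",
# np.searchsorted returns the first file whose cumulative row count is
# strictly greater than idx, so idx == cumsum_lengths[k] lands on row 0 of
# file k + 1.
import numpy as np  # repeated from above so this sketch is self-contained


def locate_row(cumsum_lengths, idx):
    """Return (file_index, row_idx) for a global index ``idx``."""
    cumsum_lengths = np.asarray(cumsum_lengths)
    if not 0 <= idx < cumsum_lengths[-1]:
        raise IndexError(f"index {idx} out of range for {cumsum_lengths[-1]} rows")
    file_index = int(np.searchsorted(cumsum_lengths, idx, side="right"))
    # Subtract the rows that belong to earlier files (none for file 0).
    offset = 0 if file_index == 0 else int(cumsum_lengths[file_index - 1])
    return file_index, int(idx) - offset


# Example: files with 3 and 2 rows give cumsum_lengths == [3, 5], so
# locate_row([3, 5], 2) == (0, 2) and locate_row([3, 5], 3) == (1, 0).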


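class DatasetMultiPad(Dataset):
    """Map-style dataset that concatenates the rows of several AnnData objects.

    Positional arguments are AnnData objects (typically opened in backed
    mode); keyword arguments carry the label maps and column names and are
    also forwarded to ``DataProcessorForPad``.
    """

    def __init__(self, *args, **kwargs):
        # Set to True after construction to prepend the sample index to each item.
        self.return_index = False
        self.return_batch_label = kwargs['return_batch_label']

        # Cell-type labels are always returned; batch labels only on request.
        self.cell_type_map = kwargs['cell_type_map']
        self.cell_type_col = kwargs['cell_type_col']
        if self.return_batch_label:
            self.batch_label_map = kwargs['batch_label_map']
            self.batch_label_col = kwargs['batch_label_col']
        else:
            self.batch_label_map = None
            self.batch_label_col = None
        # All keyword arguments, including the label settings above, are
        # forwarded to the processor.
        self.data_args = kwargs
        self.data_processor = DataProcessorForPad(**self.data_args)
        self.adata_list = args
        # cumsum_lengths[k] is the total number of rows in files 0..k; it is
        # used to map a global index onto (file_index, row_idx).
        self.cumsum_lengths = np.cumsum([adata.shape[0] for adata in self.adata_list])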

    def __len__(self):
        # Cast NumPy's int64 to a plain Python int for len().
        return int(self.cumsum_lengths[-1])

    def process_data(self, sparse_matrix, file_index):
        """Densify one sparse row and pass it, together with the per-feature
        genomic coordinates of its source file, to the pad processor."""
        var = self.adata_list[file_index].var
        value_list, chromosome_list, hg38_start, hg38_end = \
            self.data_processor.process(
                value_data=sparse_matrix.toarray()[0].tolist(),
                chromosome=var["#Chromosome"].tolist(),
                hg38_start=var["hg38_Start"].tolist(),
                hg38_end=var["hg38_End"].tolist()
            )
        return value_list, chromosome_list, hg38_start, hg38_end

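    def __getitem__(self, idx):
        # Locate the source file: the first one whose cumulative row count
        # exceeds idx (side="right" sends idx == cumsum_lengths[k] to file k + 1).
        file_index = np.searchsorted(self.cumsum_lengths, idx, side="right")
        # Convert the global index into a row index within that file.
        if file_index == 0:
            row_idx = idx
        else:
            row_idx = idx - self.cumsum_lengths[file_index - 1]
        x_idx = self.adata_list[file_index][row_idx].X
        value_list, chromosome_list, hg38_start, hg38_end = self.process_data(
            x_idx, file_index)
        cell_type = self.adata_list[file_index].obs.iloc[row_idx][self.cell_type_col]
        # Cell types missing from cell_type_map fall back to a fixed label.
        if cell_type not in self.cell_type_map:
            cell_type = 'Astrocyte 1'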
        res = [
            torch.tensor(value_list),
            torch.tensor(chromosome_list),
            torch.tensor(hg38_start),
            torch.tensor(hg38_end),
            torch.tensor(self.cell_type_map[cell_type])
        ]
        if self.return_index:
            res.insert(0, torch.tensor(idx))
        if self.return_batch_label:
            batch_label = self.adata_list[file_index].obs.iloc[row_idx][self.batch_label_col]
            res.append(torch.tensor(self.batch_label_map[batch_label]))
        return res

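    def __del__(self):
        # Close the backed .h5ad files when the dataset is garbage-collected.
        # getattr() guards against __init__ having failed before adata_list was set.
        for data in getattr(self, 'adata_list', ()):
            if data.isbacked:
                data.file.close()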
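

# Illustrative sketch (hypothetical helper, not called by DatasetMultiPad):
# __getitem__ replaces any cell type missing from cell_type_map with the
# hard-coded label 'Astrocyte 1' before encoding it. Passing the fallback as
# an explicit parameter keeps that choice visible, and a missing fallback
# raises instead of silently relabelling unknown cells.
def encode_label(label, label_map, fallback=None):
    """Map ``label`` to its integer id, substituting ``fallback`` if unknown."""
    if label not in label_map:
        if fallback is None:
            raise KeyError(f"label {label!r} not in label map and no fallback given")
        label = fallback
    # Raises KeyError if the fallback itself is not in the map.
    return label_map[label]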