|
|
import numpy as np |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
from tf_data_process import DataProcessorForPad |
|
|
|
|
|
|
|
|
class DatasetMultiPad(Dataset):
    """Map-style dataset over the concatenated rows of several AnnData-like objects.

    Each item is one row (cell) of one input object. The row's values are run
    through ``DataProcessorForPad.process`` together with that object's
    ``var["#Chromosome"]``, ``var["hg38_Start"]`` and ``var["hg38_End"]``
    columns. The cell's integer cell-type label is added, and optionally its
    integer batch label.

    Positional args:
        *args: the data objects. Each must expose ``shape``, ``var``, ``obs``
            and row indexing that returns an object with ``.X``. They are
            presumably AnnData objects (possibly in backed mode, since
            ``__del__`` calls ``.file.close()`` on them) -- TODO confirm.

    Keyword args (all except ``default_cell_type`` are also forwarded to
    ``DataProcessorForPad``):
        return_batch_label: required; when true, append a batch-label tensor.
        cell_type_map: required; maps ``obs[cell_type_col]`` values to ids.
        cell_type_col: required; ``obs`` column holding the cell type.
        batch_label_map, batch_label_col: required only when
            ``return_batch_label`` is true; same role for the batch label.
        default_cell_type: optional, defaults to ``'Astrocyte 1'``; the key of
            ``cell_type_map`` used for cell types missing from the map.
    """

    def __init__(self, *args, **kwargs):
        # Hard-coded off; can only be enabled by setting the attribute after
        # construction.
        self.return_index = False
        self.return_batch_label = kwargs['return_batch_label']
        self.cell_type_map = kwargs['cell_type_map']
        self.cell_type_col = kwargs['cell_type_col']
        # Popped (not read) so DataProcessorForPad never receives this key,
        # keeping its kwargs identical to what callers passed before.
        self.default_cell_type = kwargs.pop('default_cell_type', 'Astrocyte 1')
        if self.return_batch_label:
            self.batch_label_map = kwargs['batch_label_map']
            self.batch_label_col = kwargs['batch_label_col']
        else:
            self.batch_label_map = None
            self.batch_label_col = None
        self.data_args = kwargs
        self.data_processor = DataProcessorForPad(**self.data_args)
        self.adata_list = args
        # cumsum_lengths[i] is the global index one past the last row of file i.
        self.cumsum_lengths = np.cumsum([adata.shape[0] for adata in self.adata_list])

    def __len__(self):
        """Return the total number of rows across all data objects (0 if none)."""
        if len(self.cumsum_lengths) == 0:
            return 0
        return int(self.cumsum_lengths[-1])

    def process_data(self, sparse_matrix, file_index):
        """Run one row through ``DataProcessorForPad.process``.

        Args:
            sparse_matrix: the row's ``.X``. It may be a scipy sparse matrix
                (anything with ``toarray``) or an array-like. Only its first
                row is used.
            file_index: index into ``self.adata_list`` whose ``var`` columns
                supply the chromosome / start / end lists.

        Returns:
            The 4-tuple ``(value_list, chromosome_list, hg38_start, hg38_end)``
            produced by the data processor.
        """
        # Dense in-memory X has no .toarray(); accept both representations.
        if hasattr(sparse_matrix, 'toarray'):
            dense = sparse_matrix.toarray()
        else:
            dense = np.asarray(sparse_matrix)
        var = self.adata_list[file_index].var
        value_list, chromosome_list, hg38_start, hg38_end = \
            self.data_processor.process(
                value_data=np.atleast_2d(dense)[0].tolist(),
                chromosome=var["#Chromosome"].tolist(),
                hg38_start=var["hg38_Start"].tolist(),
                hg38_end=var["hg38_End"].tolist(),
            )
        return value_list, chromosome_list, hg38_start, hg38_end

    def __getitem__(self, idx):
        """Return the processed tensors for global row ``idx``.

        The result is a list ``[value, chromosome, start, end, cell_type_id]``.
        The global index is prepended when ``return_index`` is true, and the
        batch-label id is appended when ``return_batch_label`` is true.
        Negative indices count from the end of the concatenated dataset.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_total = len(self)
        if idx < 0:
            # Not `+=`: that would mutate a tensor index in place.
            idx = idx + n_total
        if not 0 <= idx < n_total:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {n_total}")

        # side="right": an index equal to a file's cumulative end belongs to
        # the next file.
        file_index = np.searchsorted(self.cumsum_lengths, idx, side="right")
        if file_index == 0:
            row_idx = idx
        else:
            row_idx = idx - self.cumsum_lengths[file_index - 1]

        adata = self.adata_list[file_index]
        value_list, chromosome_list, hg38_start, hg38_end = self.process_data(
            adata[row_idx].X, file_index)

        # Select the column first to avoid materializing a whole obs row.
        cell_type = adata.obs[self.cell_type_col].iloc[row_idx]
        if cell_type not in self.cell_type_map:
            cell_type = self.default_cell_type

        res = [
            torch.tensor(value_list),
            torch.tensor(chromosome_list),
            torch.tensor(hg38_start),
            torch.tensor(hg38_end),
            torch.tensor(self.cell_type_map[cell_type]),
        ]
        if self.return_index:
            res.insert(0, torch.tensor(idx))
        if self.return_batch_label:
            batch_label = adata.obs[self.batch_label_col].iloc[row_idx]
            res.append(torch.tensor(self.batch_label_map[batch_label]))
        return res

    def __del__(self):
        """Close each data object's backing file, if it has one."""
        # adata_list is absent if __init__ raised before assigning it.
        for adata in getattr(self, 'adata_list', ()):
            file_manager = getattr(adata, 'file', None)
            if file_manager is not None:
                file_manager.close()
|
|
|