import pickle

import torch.utils.data.dataset as dataset


class DataSet(dataset.Dataset):
    """Dataset whose items are the unpickled contents of a list of files.

    Every file named in ``config["data_path_list"]`` is loaded eagerly in the
    constructor, so item ``i`` is the object stored in the ``i``-th file and
    the dataset length equals the number of paths.
    """

    def __init__(self, config: dict):
        """Load every pickle file listed in *config*.

        Args:
            config: Mapping that must contain the keys ``"data_path_list"``
                (an iterable of pickle file paths) and ``"subset"`` (stored
                unchanged as ``self.data_set_type``).

        Raises:
            KeyError: If either required key is missing from *config*.
            OSError: If one of the files cannot be opened.
        """
        super().__init__()
        data_path_list = config["data_path_list"]
        # Kept as-is for callers; not interpreted anywhere in this class.
        self.data_set_type = config["subset"]
        # Eager load: all file contents are held in memory for the lifetime
        # of the dataset, in the same order as the input paths.
        self.files = [self.read_data(fname) for fname in data_path_list]

    def read_data(self, data_path):
        """Return the object unpickled from *data_path*.

        NOTE(review): the payload is presumably a dict (the original local
        was named ``file_data_dict``), but nothing here enforces that.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only pass paths to files from a trusted source.
        with open(data_path, 'rb') as pickle_file:
            return pickle.load(pickle_file)

    def __len__(self) -> int:
        """Return the number of loaded files."""
        return len(self.files)

    def __getitem__(self, index):
        """Return the unpickled content of the file at position *index*."""
        return self.files[index]