File size: 645 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.utils.data.dataset as dataset
import pickle

class DataSet(dataset.Dataset):
    """Dataset whose items are the unpickled contents of a list of files.

    Every path in ``config["data_path_list"]`` is loaded eagerly at
    construction time; item ``i`` is whatever object was pickled in the
    ``i``-th file. The length of the dataset is the number of files.
    """

    def __init__(self, config: dict):
        """Load every pickle file named in *config*.

        Args:
            config: Mapping that must provide ``"data_path_list"`` (an
                iterable of file paths) and ``"subset"`` (stored verbatim
                as ``data_set_type``; not otherwise used by this class).

        Raises:
            KeyError: If either required key is missing from *config*.
        """
        # Keys are read in this order so a missing "data_path_list" is
        # reported before a missing "subset".
        paths = config["data_path_list"]
        self.data_set_type = config["subset"]
        self.files = []
        # extend() consumes the map lazily, so files are read one by one,
        # in list order, and appended as they are loaded.
        self.files.extend(map(self.read_data, paths))

    def read_data(self, data_path):
        """Return the object stored in the pickle file at *data_path*.

        NOTE(review): ``pickle.load`` can execute arbitrary code while
        deserializing, so only point this at trusted files.
        """
        with open(data_path, 'rb') as handle:
            return pickle.load(handle)

    def __len__(self):
        """Return the number of loaded files."""
        loaded = self.files
        return len(loaded)

    def __getitem__(self, index):
        """Return the unpickled content of the file at position *index*."""
        loaded = self.files
        return loaded[index]