File size: 1,710 Bytes
b9bac12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import numpy as np
import torch.utils.data as data


class DijkprofileDataset(data.Dataset):
    """Pytorch custom dataset class to use with the pytorch dataloader."""
    
    def __init__(self, profile_dict, partition, custom_scaler_path=None):
        """Dijkprofile Dataset, provides profiles and labels to pytorch model.

        Args:
            profile_dict (dict): dict containing the profiles and labels
            partition (list): list used to split the dataset into train and test
            sets. list contains ids to use for this dataset, format is
            as returned by sklearn.model_selection.train_test_split
        """
        self.data_dict = profile_dict
        self.list_IDs = partition

        print("scaler in dataset class is depracated and moved to preprocessing")
        # load scaler
        # if custom_scaler_path:
        #     self.scaler = joblib.load(custom_scaler_path)
        # else:
        #     self.scaler = joblib.load(os.path.join(dir, config.SCALER_PATH))
        # # rescale all profiles profiles
        # for key in profile_dict.keys():
            # profile_dict[key]['profile'] = self.scaler.transform(
            #     profile_dict[key]['profile'].reshape(-1, 1)).reshape(-1)
            # profile_dict[key]['profile'] = profile_dict[key]['profile'] / 10
    
    def __len__(self):
        return len(self.list_IDs)
    
    def __getitem__(self, index):
        id = self.list_IDs[index]
        X = self.data_dict[id]['profile'].reshape(1,-1).astype(np.float32)
        y = self.data_dict[id]['label'].reshape(1,-1)
        return X, y
    
    def __str__(self):
        return "<Dijkprofile dataset: datapoints={}>".format(len(self.list_IDs))