| import torch | |
| from torch.utils.data import DataLoader | |
| from data.dataset import load_ECG_Dataset | |
| import os | |
| import numpy as np | |
| import math | |
def get_class_weight(labels_dict):
    """Compute log-scaled class weights from per-class sample counts.

    Each class ``c`` with count ``n_c`` receives ``max(log(n_max / n_c), 1.0)``,
    where ``n_max`` is the largest count. This is the ``log(mu * total / n_c)``
    recipe with ``mu = n_max / total``, so the majority class scores
    ``log(1) = 0`` before flooring. Any class within a factor of ``e`` of the
    majority gets weight 1.0; rarer classes get a logarithmically larger weight.

    Args:
        labels_dict: mapping from class label to its sample count.

    Returns:
        dict with the same keys, in the same order, mapping to float weights.
        An empty input yields an empty dict.

    Raises:
        ValueError: if any count is not strictly positive (a zero count would
            otherwise divide by zero and a negative one has no logarithm).
    """
    if not labels_dict:
        return {}
    non_positive = {key: value for key, value in labels_dict.items() if value <= 0}
    if non_positive:
        raise ValueError(f"class counts must be positive, got {non_positive}")
    max_num = max(labels_dict.values())
    # Floor at 1.0 so common classes are never down-weighted below the baseline.
    return {
        key: max(math.log(max_num / float(value)), 1.0)
        for key, value in labels_dict.items()
    }
class ECGDataloader:
    """Create ``DataLoader`` objects for the train/validation/test ECG splits.

    The splits are read from ``<data_path>/<data_type>/train.pt``, ``val.pt``
    and ``test.pt`` with ``torch.load``, wrapped with ``load_ECG_Dataset``, and
    handed to ``torch.utils.data.DataLoader``.
    """

    # Annotations only; every value is assigned in __init__.
    traindata_path: str
    testdata_path: str
    validdata_path: str
    batch_size: int
    num_workers: int

    def __init__(self, data_path, data_type, hparams):
        """Resolve the split file paths and read the loader settings.

        Args:
            data_path: root directory of the datasets.
            data_type: sub-directory of ``data_path`` holding the split files.
            hparams: mapping that must contain ``'batch_size'``. An optional
                ``'num_workers'`` entry overrides the default of 4 loader
                worker processes.
        """
        split_dir = os.path.join(data_path, data_type)
        self.traindata_path = os.path.join(split_dir, 'train.pt')
        self.testdata_path = os.path.join(split_dir, 'test.pt')
        self.validdata_path = os.path.join(split_dir, 'val.pt')
        self.batch_size = hparams['batch_size']
        try:
            self.num_workers = hparams['num_workers']
        except KeyError:
            self.num_workers = 4

    @staticmethod
    def _load_dataset(path):
        """Deserialize the split stored at ``path`` and wrap it with ``load_ECG_Dataset``."""
        # NOTE(review): torch.load may unpickle arbitrary objects (unless
        # weights_only is in effect for the installed torch version); only
        # point this at trusted data files.
        return load_ECG_Dataset(torch.load(path))

    @staticmethod
    def _count_labels(labels):
        """Return ``{label: count}`` for every distinct value in ``labels``, sorted by label.

        Keys are the labels that actually occur, so non-contiguous or 1-based
        label sets are counted correctly. Assumes integer-valued class labels
        (keys are converted with ``int``) — TODO confirm against the dataset.
        """
        values, counts = np.unique(labels, return_counts=True)
        return {int(value): int(count) for value, count in zip(values, counts)}

    def _make_loader(self, dataset, *, shuffle, drop_last):
        """Wrap ``dataset`` in a ``DataLoader`` using this object's batch size and worker count."""
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle,
                          drop_last=drop_last, num_workers=self.num_workers)

    def train_dataloader(self):
        """Build the shuffled training loader together with its class weights.

        Returns:
            tuple ``(train_loader, class_weight)``. ``class_weight`` is
            ``get_class_weight`` applied to the per-label sample counts of
            ``train_dataset.y_data``. The last incomplete batch is dropped.
        """
        train_dataset = self._load_dataset(self.traindata_path)
        # y_data must expose .numpy() (e.g. a CPU tensor) — presumably integer labels.
        label_counts = self._count_labels(train_dataset.y_data.numpy())
        train_loader = self._make_loader(train_dataset, shuffle=True, drop_last=True)
        return train_loader, get_class_weight(label_counts)

    def test_dataloader(self):
        """Build the test loader: fixed sample order, last partial batch kept."""
        test_dataset = self._load_dataset(self.testdata_path)
        return self._make_loader(test_dataset, shuffle=False, drop_last=False)

    def valid_dataloader(self):
        """Build the validation loader: fixed sample order, last partial batch kept.

        Evaluation does not benefit from shuffling, so the order is kept fixed
        to make validation passes reproducible. The settings match
        ``test_dataloader``.
        """
        valid_dataset = self._load_dataset(self.validdata_path)
        return self._make_loader(valid_dataset, shuffle=False, drop_last=False)