# NOTE(review): original paste carried a "Runtime error" annotation from the
# hosting site; the Python 2 style `iter(...).next()` call in __main__ is the
# likely cause under Python 3.
| import torch | |
| from torch.utils.data import Dataset | |
| from DatasetLoader import AugmentWAV, loadWAV | |
| import os | |
| import numpy as np | |
| import random | |
class TrainDataset(Dataset):
    """Speaker-identification training set.

    Each item is a ``(waveform_tensor, speaker_label)`` pair loaded from a
    manifest file whose lines look like ``<speaker_id> <relative_wav_path>``.
    When ``augment`` is true, one of four augmentations (or none) is applied
    to the clip at random.
    """

    def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames,):
        self.train_list = train_list
        self.max_frames = max_frames
        # Augmentation helper handles both reverberation and additive noise.
        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
        self.augment = augment
        self.musan_path = musan_path
        self.rir_path = rir_path

        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        # Map each distinct speaker id to a dense integer label; sorting the
        # id set first makes the mapping deterministic across runs.
        speakers = sorted({row.split()[0] for row in lines})
        label_of = {spk: idx for idx, spk in enumerate(speakers)}

        # Fixed seed so the shuffled file order is reproducible.
        np.random.seed(100)
        np.random.shuffle(lines)

        self.data_list = []
        self.data_label = []
        for row in lines:
            parts = row.strip().split()
            self.data_list.append(os.path.join(train_path, parts[1]))
            self.data_label.append(label_of[parts[0]])

    def __getitem__(self, index):
        audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
        if self.augment:
            # Draw in [0, 4] inclusive; 0 means "leave the clip clean".
            augtype = random.randint(0, 4)
            if augtype == 1:
                audio = self.augment_wav.reverberate(audio)
            elif augtype in (2, 3, 4):
                noise_kind = {2: 'music', 3: 'speech', 4: 'noise'}[augtype]
                audio = self.augment_wav.additive_noise(noise_kind, audio)
        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        return len(self.data_list)
if __name__ == "__main__":
    # Smoke test: build the dataset, pull one batch, and print its shapes.
    train_dataset = TrainDataset(train_list="data/train_list.txt", augment=True,
                                 musan_path="data/musan_split", rir_path="data/RIRS_NOISES/simulated_rirs",
                                 max_frames=300, train_path="data/voxceleb2")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        pin_memory=False,
        drop_last=True,
    )
    # BUGFIX: `iter(train_loader).next()` is the Python 2 iterator protocol and
    # raises AttributeError on Python 3; use the builtin next() instead.
    x, y = next(iter(train_loader))
    print("x:", x.shape, "y:", y.shape)