import os
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from DatasetLoader import AugmentWAV, loadWAV


class TrainDataset(Dataset):
    """Speaker-classification training dataset.

    Reads a train list of ``<speaker_id> <relative_wav_path>`` lines, maps each
    speaker id to a dense integer label, and yields fixed-length waveform
    tensors with optional MUSAN-noise / RIR-reverberation augmentation.
    """

    def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames,):
        """
        Args:
            train_list: Path to a text file; each line is "<speaker> <wav path>".
            train_path: Root directory prepended to every wav path.
            augment:    If True, randomly augment each sample in __getitem__.
            musan_path: Root of the MUSAN noise corpus (used by AugmentWAV).
            rir_path:   Root of the simulated room-impulse-response corpus.
            max_frames: Number of frames loadWAV crops/pads each utterance to.
        """
        self.train_list = train_list
        self.max_frames = max_frames
        # NOTE: AugmentWAV is built even when augment=False; kept as-is to
        # preserve the original behavior (it may scan the corpora on init).
        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
        self.augment = augment
        self.musan_path = musan_path
        self.rir_path = rir_path

        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        # Map each distinct speaker id to a dense integer label, sorted so the
        # mapping is deterministic across runs.
        dictkeys = list(set([x.split()[0] for x in lines]))
        dictkeys.sort()
        dictkeys = {key: ii for ii, key in enumerate(dictkeys)}

        # Fixed seed so the shuffled sample order is reproducible.
        np.random.seed(100)
        np.random.shuffle(lines)

        self.data_list = []
        self.data_label = []
        for lidx, line in enumerate(lines):
            data = line.strip().split()
            speaker_label = dictkeys[data[0]]
            filename = os.path.join(train_path, data[1])
            self.data_list.append(filename)
            self.data_label.append(speaker_label)

    def __getitem__(self, index):
        """Return (waveform FloatTensor, integer speaker label) for *index*."""
        audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)

        if self.augment:
            augtype = random.randint(0, 4)  # inclusive of both 0 and 4; 0 = no augmentation
            if augtype == 1:
                audio = self.augment_wav.reverberate(audio)
            elif augtype == 2:
                audio = self.augment_wav.additive_noise('music', audio)
            elif augtype == 3:
                audio = self.augment_wav.additive_noise('speech', audio)
            elif augtype == 4:
                audio = self.augment_wav.additive_noise('noise', audio)

        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        return len(self.data_list)


if __name__ == "__main__":
    # Smoke test: build the dataset and pull one batch.
    train_dataset = TrainDataset(train_list="data/train_list.txt",
                                 augment=True,
                                 musan_path="data/musan_split",
                                 rir_path="data/RIRS_NOISES/simulated_rirs",
                                 max_frames=300,
                                 train_path="data/voxceleb2")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        pin_memory=False,
        drop_last=True,
    )
    # BUGFIX: `iter(loader).next()` is Python 2 syntax and raises
    # AttributeError in Python 3 — use the next() builtin instead.
    x, y = next(iter(train_loader))
    print("x:", x.shape, "y:", y.shape)