|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
from DatasetLoader import AugmentWAV, loadWAV |
|
|
import os |
|
|
import numpy as np |
|
|
import random |
|
|
|
|
|
|
|
|
class TrainDataset(Dataset):
    """Map-style dataset yielding ``(audio tensor, speaker label)`` pairs.

    The training list is a text file with one utterance per line in the form
    ``<speaker_id> <relative_audio_path>``. Speaker ids are mapped to
    consecutive integer labels in sorted order, and the utterances are
    shuffled once with a fixed seed at construction time.
    """

    def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames):
        """Read the training list and build the file/label index.

        Args:
            train_list: Path of the text file listing the utterances.
            train_path: Directory that each listed audio path is joined onto.
            augment: When truthy, ``__getitem__`` applies a random augmentation.
            musan_path: Forwarded to ``AugmentWAV`` (stored as an attribute).
            rir_path: Forwarded to ``AugmentWAV`` (stored as an attribute).
            max_frames: Forwarded to ``AugmentWAV`` and ``loadWAV``.
        """
        self.train_list = train_list
        self.max_frames = max_frames
        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
        self.augment = augment
        self.musan_path = musan_path
        self.rir_path = rir_path

        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        # NOTE: this reseeds NumPy's *global* RNG so the shuffled order is
        # reproducible from run to run.
        np.random.seed(100)
        np.random.shuffle(lines)

        self.data_list, self.data_label = self._index_entries(lines, train_path)

    @staticmethod
    def _index_entries(lines, train_path):
        """Turn ``<speaker> <path>`` lines into parallel file and label lists.

        Blank or whitespace-only lines are skipped, so a stray empty line in
        the list file no longer raises ``IndexError``.

        Args:
            lines: Iterable of raw text lines, in the order to be kept.
            train_path: Directory prefix joined onto every audio path.

        Returns:
            ``(files, labels)`` where ``files[i]`` is the joined audio path and
            ``labels[i]`` is the index of its speaker id among the sorted
            unique speaker ids.
        """
        entries = [fields for fields in (line.split() for line in lines) if fields]

        # Sorting makes the label of a speaker independent of line order.
        speakers = sorted({fields[0] for fields in entries})
        label_of = {speaker: idx for idx, speaker in enumerate(speakers)}

        files = [os.path.join(train_path, fields[1]) for fields in entries]
        labels = [label_of[fields[0]] for fields in entries]
        return files, labels

    def __getitem__(self, index):
        """Load one utterance, optionally augment it, and return it with its label.

        Args:
            index: Position in the shuffled utterance index.

        Returns:
            Tuple of a ``torch.FloatTensor`` built from the loaded (and possibly
            augmented) audio, and the integer speaker label.
        """
        audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)

        if self.augment:
            # randint is inclusive on both ends; a draw of 0 matches no branch
            # and leaves the audio unaugmented.
            augtype = random.randint(0, 4)
            if augtype == 1:
                audio = self.augment_wav.reverberate(audio)
            elif augtype == 2:
                audio = self.augment_wav.additive_noise('music', audio)
            elif augtype == 3:
                audio = self.augment_wav.additive_noise('speech', audio)
            elif augtype == 4:
                audio = self.augment_wav.additive_noise('noise', audio)

        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        """Return the number of indexed utterances."""
        return len(self.data_list)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset from the hard-coded data paths, pull a
    # single batch through a DataLoader, and print the batch shapes.
    train_dataset = TrainDataset(
        train_list="data/train_list.txt",
        augment=True,
        musan_path="data/musan_split",
        rir_path="data/RIRS_NOISES/simulated_rirs",
        max_frames=300,
        train_path="data/voxceleb2",
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        pin_memory=False,
        drop_last=True,
    )
    # Use the built-in next(): the Python 2-style ``.next()`` method is not
    # part of the Python 3 iterator protocol.
    x, y = next(iter(train_loader))
    print("x:", x.shape, "y:", y.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|