# speaker_verification/dataloader.py
# Uploaded by xuesongyan (commit b58a6ff: "Upload dataloader.py")
import torch
from torch.utils.data import Dataset
from DatasetLoader import AugmentWAV, loadWAV
import os
import numpy as np
import random
class TrainDataset(Dataset):
    """Training dataset of (audio, speaker_label) pairs read from a list file.

    Each non-blank line of ``train_list`` holds whitespace-separated fields:
    the first is a speaker ID and the second is an audio path relative to
    ``train_path``. Speaker IDs are mapped to integer labels in sorted order.
    """

    # augtype drawn in __getitem__ -> noise category passed to additive_noise.
    _NOISE_TYPES = {2: 'music', 3: 'speech', 4: 'noise'}

    def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames,):
        """Read and shuffle the training list.

        Args:
            train_list: path to the list file (``speaker_id relative_path`` per line).
            train_path: directory that the relative audio paths are joined onto.
            augment: if truthy, ``__getitem__`` randomly applies augmentation.
            musan_path: forwarded to ``AugmentWAV``.
            rir_path: forwarded to ``AugmentWAV``.
            max_frames: forwarded to ``AugmentWAV`` and to ``loadWAV``.
        """
        self.train_list = train_list
        self.max_frames = max_frames
        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
        self.augment = augment
        self.musan_path = musan_path
        self.rir_path = rir_path

        with open(train_list) as dataset_file:
            # Skip blank/whitespace-only lines (e.g. a trailing empty line at EOF):
            # they have no fields and would otherwise raise IndexError below.
            lines = [line for line in dataset_file if line.strip()]

        # Dense integer labels in sorted speaker-ID order, so labels do not
        # depend on file order or on the shuffle below.
        speakers = sorted({line.split()[0] for line in lines})
        speaker_to_label = {speaker: label for label, speaker in enumerate(speakers)}

        # Fixed seed so the shuffled order is reproducible.
        # NOTE: this reseeds numpy's *global* RNG, so it also affects any later
        # np.random usage in the process (kept for backward compatibility).
        np.random.seed(100)
        np.random.shuffle(lines)

        self.data_list = []
        self.data_label = []
        for line in lines:
            fields = line.split()
            self.data_label.append(speaker_to_label[fields[0]])
            self.data_list.append(os.path.join(train_path, fields[1]))

    def __getitem__(self, index):
        """Load (and optionally augment) one utterance.

        Returns:
            ``(torch.FloatTensor(audio), label)`` where ``label`` is the int
            speaker label. The audio shape is whatever ``loadWAV`` returns
            for ``max_frames`` (presumably a fixed-length crop; TODO confirm).
        """
        audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
        if self.augment:
            # randint is inclusive of both 0 and 4; 0 means "no augmentation".
            augtype = random.randint(0, 4)
            if augtype == 1:
                audio = self.augment_wav.reverberate(audio)
            elif augtype in self._NOISE_TYPES:
                audio = self.augment_wav.additive_noise(self._NOISE_TYPES[augtype], audio)
        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        """Return the number of utterances in the training list."""
        return len(self.data_list)
if __name__ == "__main__":
    # Smoke test: build the dataset from local data paths and print the
    # shapes of a single batch.
    train_dataset = TrainDataset(train_list="data/train_list.txt", augment=True,
                                 musan_path="data/musan_split", rir_path="data/RIRS_NOISES/simulated_rirs",
                                 max_frames=300, train_path="data/voxceleb2")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        pin_memory=False,
        drop_last=True,
    )
    # Use the builtin next(): recent PyTorch DataLoader iterators only
    # implement __next__ and have no Python-2-style .next() method.
    x, y = next(iter(train_loader))
    print("x:", x.shape, "y:", y.shape)