xuesongyan committed on
Commit
b58a6ff
·
1 Parent(s): ee4b9b7

Upload dataloader.py

Browse files
Files changed (1) hide show
  1. dataloader.py +75 -0
dataloader.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from DatasetLoader import AugmentWAV, loadWAV
4
+ import os
5
+ import numpy as np
6
+ import random
7
+
8
+
9
class TrainDataset(Dataset):
    """Speaker-labelled training set with optional random augmentation.

    Every non-blank line of ``train_list`` is split on whitespace: the first
    field is the speaker identifier and the second is the audio path relative
    to ``train_path``.  Speaker identifiers are mapped to integer labels
    according to their sorted order.
    """

    # Augmentation codes that add background audio, mapped to the category
    # name passed to ``AugmentWAV.additive_noise``.
    _NOISE_CATEGORIES = {2: 'music', 3: 'speech', 4: 'noise'}

    def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames,):
        """Read the training list and build the file/label tables.

        Args:
            train_list: Path of the text file listing ``speaker path`` pairs.
            train_path: Directory that the listed audio paths are joined to.
            augment: When truthy, ``__getitem__`` may augment each sample.
            musan_path: Forwarded to ``AugmentWAV`` (presumably the noise
                corpus location -- confirm against ``DatasetLoader``).
            rir_path: Forwarded to ``AugmentWAV`` (presumably the impulse
                response location -- confirm against ``DatasetLoader``).
            max_frames: Forwarded to both ``AugmentWAV`` and ``loadWAV``.

        Note:
            Reseeds NumPy's *global* random generator with 100 so that the
            shuffled sample order is the same on every run.
        """
        self.train_list = train_list
        self.max_frames = max_frames
        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
        self.augment = augment
        self.musan_path = musan_path
        self.rir_path = rir_path

        lines = self._read_entries(train_list)

        # Sorting the unique speaker ids makes the id -> label mapping
        # independent of the order of the lines in the file.
        speakers = sorted({line.split()[0] for line in lines})
        label_of = {speaker: label for label, speaker in enumerate(speakers)}

        # Fixed seed: deterministic sample order across runs.
        np.random.seed(100)
        np.random.shuffle(lines)

        self.data_list = []
        self.data_label = []

        for line in lines:
            fields = line.strip().split()
            self.data_list.append(os.path.join(train_path, fields[1]))
            self.data_label.append(label_of[fields[0]])

    @staticmethod
    def _read_entries(train_list):
        """Return the non-blank lines of ``train_list`` with newlines kept.

        Blank and whitespace-only lines (for example a trailing empty line)
        are dropped so they cannot break the field parsing in ``__init__``.
        """
        with open(train_list) as dataset_file:
            return [line for line in dataset_file if line.strip()]

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(FloatTensor audio, int label)``.

        With augmentation enabled, one of five equally likely outcomes is
        drawn per call: 0 leaves the audio untouched, 1 applies
        reverberation, and 2/3/4 add 'music'/'speech'/'noise' respectively.
        """
        audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
        if self.augment:
            augtype = random.randint(0, 4)  # both 0 and 4 are possible
            if augtype == 1:
                audio = self.augment_wav.reverberate(audio)
            elif augtype in self._NOISE_CATEGORIES:
                audio = self.augment_wav.additive_noise(self._NOISE_CATEGORIES[augtype], audio)

        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        """Return the number of audio files listed for training."""
        return len(self.data_list)
57
+
58
+
59
if __name__ == "__main__":
    # Smoke test: build the dataset from the default data layout, pull one
    # batch through a DataLoader and print the batch shapes.
    train_dataset = TrainDataset(
        train_list="data/train_list.txt",
        augment=True,
        musan_path="data/musan_split",
        rir_path="data/RIRS_NOISES/simulated_rirs",
        max_frames=300,
        train_path="data/voxceleb2",
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        pin_memory=False,
        drop_last=True,
    )
    # Use the built-in next(): the Python 2-style ``.next()`` method is not
    # available on the iterator returned by current PyTorch DataLoaders.
    x, y = next(iter(train_loader))
    print("x:", x.shape, "y:", y.shape)
71
+
72
+
73
+
74
+
75
+