File size: 2,437 Bytes
b58a6ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
from torch.utils.data import Dataset
from DatasetLoader import AugmentWAV, loadWAV
import os
import numpy as np
import random


class TrainDataset(Dataset):
    def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames,):
        self.train_list = train_list
        self.max_frames = max_frames
        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
        self.augment = augment
        self.musan_path = musan_path
        self.rir_path = rir_path

        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        dictkeys = list(set([x.split()[0] for x in lines]))
        dictkeys.sort()
        dictkeys = {key: ii for ii, key in enumerate(dictkeys)}

        np.random.seed(100)
        np.random.shuffle(lines)

        self.data_list = []
        self.data_label = []

        for lidx, line in enumerate(lines):
            data = line.strip().split()
            speaker_label = dictkeys[data[0]]
            filename = os.path.join(train_path, data[1])

            self.data_list.append(filename)
            self.data_label.append(speaker_label)

    def __getitem__(self, index):

        audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
        if self.augment:
            augtype = random.randint(0, 4)  # 包括0,4
            if augtype == 1:
                audio = self.augment_wav.reverberate(audio)
            elif augtype == 2:
                audio = self.augment_wav.additive_noise('music', audio)
            elif augtype == 3:
                audio = self.augment_wav.additive_noise('speech', audio)
            elif augtype == 4:
                audio = self.augment_wav.additive_noise('noise', audio)

        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        return len(self.data_list)


if __name__ == "__main__":
    # Smoke test: build the dataset, wrap it in a DataLoader, and pull
    # a single batch to verify shapes.
    train_dataset = TrainDataset(train_list="data/train_list.txt", augment=True,
                                  musan_path="data/musan_split", rir_path="data/RIRS_NOISES/simulated_rirs",
                                  max_frames=300, train_path="data/voxceleb2")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        pin_memory=False,
        drop_last=True,
    )
    # Bug fix: `iter(loader).next()` is Python 2 syntax and raises
    # AttributeError on Python 3 — use the builtin next() instead.
    x, y = next(iter(train_loader))
    print("x:", x.shape, "y:", y.shape)