File size: 4,622 Bytes
b37f199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torchaudio.datasets as datasets
import torchaudio.transforms as transforms
from collections import defaultdict
import random
import layers

import warnings

class SpeakerMelLoader(torch.utils.data.Dataset):
    """
    Wraps an audio dataset and serves mel-spectrograms together with the
    speaker ID of each utterance.

    Two access modes are available, selected through ``format``:

    * ``'utterance'`` -- item ``i`` is ``(speaker_id, mel)`` for utterance
      ``i`` of the wrapped dataset; ``mel`` is time-major (time, channels).
    * ``'speaker'`` -- item ``i`` is ``(speaker_id, mels)`` for the ``i``-th
      speaker; ``mels`` stacks ``speaker_utterances`` randomly drawn
      utterances, each randomly cropped or zero-padded to ``mel_length``
      frames.

    The wrapped dataset must return 6-tuples whose first, second and fourth
    entries are the waveform, the sample rate and the speaker ID. For the
    ``'speaker'`` format it must also expose a ``_walker`` sequence of
    utterance identifiers of the form ``'<speaker>-<rest>'``
    (presumably torchaudio's LIBRISPEECH -- confirm against the caller).
    """

    def __init__(self, dataset, format='speaker', speaker_utterances=4, mel_length = 128, mel_type = 'Tacotron'):
        """
        Args:
            dataset: the wrapped audio dataset (see class docstring).
            format: ``'speaker'`` or ``'utterance'``.
            speaker_utterances: utterances returned per speaker in
                ``'speaker'`` format.
            mel_length: number of frames every mel is cropped/padded to in
                ``'speaker'`` format.
            mel_type: ``'Tacotron'``, ``'Mel'`` or ``'MFCC'``.
        """
        self.dataset = dataset
        self.speaker_utterances = speaker_utterances
        self.mel_length = mel_length
        self.mel_type = mel_type
        # Spectrogram generators are cached per (sampling_rate, channels).
        self.mel_generators = dict()
        # Building the speaker index only needs self.dataset.
        self.set_format(format)

    def set_format(self,format):
        """Select the access mode; (re)builds the speaker index for 'speaker'."""
        self.format = format

        if format == 'speaker':
            self.create_speaker_index()

    def create_speaker_index(self):
        """
        Group utterance indexes by speaker.

        The speaker key is the text before the first ``'-'`` of each entry of
        ``dataset._walker``. Sets ``self.speaker_map`` (key -> list of
        utterance indexes) and ``self.speaker_keys`` (keys in first-seen
        order).
        """
        speaker_map = defaultdict(list)

        for utterance_index, utterance_id in enumerate(self.dataset._walker):
            speaker_key = utterance_id.split('-', 1)[0]
            speaker_map[speaker_key].append(utterance_index)

        self.speaker_map = speaker_map
        self.speaker_keys = list(speaker_map.keys())

    def apply_mel_gen(self, waveform, sampling_rate, channels=80):
        """
        Run the spectrogram generator selected by ``self.mel_type``.

        Generators are created lazily and cached per
        ``(sampling_rate, channels)``.

        Returns:
            The raw generator output (leading singleton axes not removed).

        Raises:
            NotImplementedError: if ``self.mel_type`` is not supported and no
                generator is cached for this key.
        """
        cache_key = (sampling_rate, channels)
        mel_gen = self.mel_generators.get(cache_key)
        if mel_gen is None:
            if self.mel_type == 'MFCC':
                mel_gen = transforms.MFCC(sample_rate=sampling_rate, n_mfcc=channels)
            elif self.mel_type == 'Mel':
                mel_gen = transforms.MelSpectrogram(sample_rate=sampling_rate, n_mels=channels)
            elif self.mel_type == 'Tacotron':
                mel_gen = layers.TacotronSTFT(sampling_rate=sampling_rate,n_mel_channels=channels)
            else:
                raise NotImplementedError(
                    'Unsupported mel_type in SpeakerMelLoader: {}'.format(self.mel_type))
            self.mel_generators[cache_key] = mel_gen

        # detach() is the modern spelling of the deprecated
        # torch.autograd.Variable(x, requires_grad=False).
        audio = waveform.detach()

        if self.mel_type == 'Tacotron':
            # Replicates the Tacotron2 data loader, except that its division
            # by max_wav_value (32768.0) is skipped: LibriSpeech data looks
            # pre-normalized (all vals between 0-1).
            return mel_gen.mel_spectrogram(audio)

        return mel_gen(audio.unsqueeze(0))

    @staticmethod
    def _to_time_major(melspec):
        """
        Strip leading singleton axes (down to 2-D) and swap first/last axes.

        Only *leading* axes are removed so that a spectrogram with a single
        frame (or a single channel) keeps both of its axes.
        """
        while melspec.dim() > 2 and melspec.shape[0] == 1:
            melspec = melspec.squeeze(0)
        return melspec.transpose(0, -1)

    def get_mel(self, waveform, sampling_rate, channels=80):
        """
        Compute the spectrogram of ``waveform`` and return it time-major.

        Returns:
            A (time, channels) tensor for single-channel audio.
        """
        # We previously identified that these warnings were ok.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message=r'At least one mel filterbank has all zero values.*', module=r'torchaudio.*')
            melspec = self.apply_mel_gen(waveform, sampling_rate, channels)
        # The generator output is expected to be (1, 1, channels, time) or
        # (1, channels, time); return (time, channels).
        return self._to_time_major(melspec)

    def _sample_utterances(self, utterance_indexes):
        """
        Draw ``self.speaker_utterances`` indexes from ``utterance_indexes``.

        Sampling is without replacement when the speaker has enough
        utterances; otherwise it falls back to sampling with replacement so
        that speakers with few utterances do not raise.
        """
        if len(utterance_indexes) >= self.speaker_utterances:
            return random.sample(utterance_indexes, self.speaker_utterances)
        return random.choices(utterance_indexes, k=self.speaker_utterances)

    def _fit_length(self, mel):
        """Zero-pad (at the end) or randomly crop ``mel`` to ``mel_length`` frames."""
        frames = mel.shape[0]
        if frames < self.mel_length:
            # pad tuple is (dn start, dn end, ..., d1 start, d1 end): leave
            # every trailing axis alone and extend only the time axis (d1).
            pad_tuple = (0, 0) * (mel.dim() - 1) + (0, self.mel_length - frames)
            return torch.nn.functional.pad(mel, pad_tuple)
        start = random.randint(0, frames - self.mel_length)
        return mel[start:start + self.mel_length]

    def __getitem__(self, index):
        """
        Return ``(speaker_id, mel)`` in 'utterance' format or
        ``(speaker_id, mels)`` in 'speaker' format (see class docstring).

        ``speaker_id`` is the value reported by the wrapped dataset.

        Raises:
            NotImplementedError: if ``self.format`` is not supported.
        """
        if self.format == 'utterance':
            (waveform, sample_rate, _, speaker_id, _, _) = self.dataset[index]
            return (speaker_id, self.get_mel(waveform, sample_rate))

        if self.format == 'speaker':
            speaker_key = self.speaker_keys[index]
            speaker_id = speaker_key
            mels = []
            for utterance_index in self._sample_utterances(self.speaker_map[speaker_key]):
                (waveform, sample_rate, _, speaker_id, _, _) = self.dataset[utterance_index]
                mels.append(self._fit_length(self.get_mel(waveform, sample_rate)))
            return (speaker_id, torch.stack(mels, 0))

        raise NotImplementedError(
            'Unsupported format in SpeakerMelLoader: {}'.format(self.format))

    def __len__(self):
        """Number of utterances ('utterance') or of speakers ('speaker')."""
        if self.format == 'utterance':
            return len(self.dataset)
        if self.format == 'speaker':
            return len(self.speaker_keys)
        raise NotImplementedError(
            'Unsupported format in SpeakerMelLoader: {}'.format(self.format))