import torch
import torchaudio.datasets as datasets
import torchaudio.transforms as transforms
from collections import defaultdict
import random
import layers
import warnings


class SpeakerMelLoader(torch.utils.data.Dataset):
    """
    Computes mel-spectrograms from audio files and pulls the speaker ID from
    the wrapped dataset.

    The wrapped ``dataset`` must expose a ``_walker`` sequence of string ids
    whose text before the first ``-`` identifies the speaker, and its items
    must unpack as ``(waveform, sample_rate, _, speaker_id, _, _)``.
    NOTE(review): this matches what the in-code comments call LibriSpeech
    data -- confirm against the caller.

    Two access formats are supported:

    * ``'utterance'``: item ``i`` is ``(speaker_id, mel)`` for dataset item ``i``.
    * ``'speaker'``: item ``i`` is ``(speaker_id, mels)`` for the ``i``-th
      speaker, where ``mels`` stacks ``speaker_utterances`` randomly chosen
      utterances, each cropped or zero-padded to ``mel_length`` frames.
    """

    def __init__(self, dataset, format='speaker', speaker_utterances=4, mel_length=128, mel_type='Tacotron'):
        """
        Args:
            dataset: the audio dataset to wrap (see the class docstring).
            format: ``'speaker'`` or ``'utterance'``; see ``set_format``.
            speaker_utterances: utterances sampled per speaker in
                ``'speaker'`` format.
            mel_length: number of frames every mel is cropped/padded to in
                ``'speaker'`` format.
            mel_type: ``'Tacotron'``, ``'Mel'`` or ``'MFCC'``; selects the
                feature extractor built in ``apply_mel_gen``.
        """
        self.dataset = dataset
        self.set_format(format)
        self.speaker_utterances = speaker_utterances
        self.mel_length = mel_length
        self.mel_type = mel_type
        # Cache of feature extractors keyed by (sampling_rate, channels).
        self.mel_generators = dict()

    def set_format(self, format):
        """Select the access format; (re)builds the speaker index for 'speaker'."""
        self.format = format
        if format == 'speaker':
            self.create_speaker_index()

    def create_speaker_index(self):
        """
        Group dataset indices by speaker.

        The speaker key is the text before the first ``-`` of each entry in
        ``self.dataset._walker``. Populates ``self.speaker_map`` (key -> list
        of dataset indices) and ``self.speaker_keys`` (keys in first-seen
        order).
        """
        speaker_map = defaultdict(list)
        for position, file_id in enumerate(self.dataset._walker):
            speaker_map[file_id.split('-', 1)[0]].append(position)
        self.speaker_map = speaker_map
        self.speaker_keys = list(speaker_map)

    def apply_mel_gen(self, waveform, sampling_rate, channels=80):
        """
        Run the configured feature extractor on ``waveform``.

        Extractors are created lazily and cached per
        ``(sampling_rate, channels)``.

        Args:
            waveform: audio tensor as returned by the wrapped dataset.
            sampling_rate: sample rate passed to the extractor.
            channels: number of mel channels (or MFCC coefficients).

        Returns:
            The raw extractor output; leading singleton axes are still present.

        Raises:
            NotImplementedError: if ``self.mel_type`` is not supported.
        """
        cache_key = (sampling_rate, channels)
        mel_gen = self.mel_generators.get(cache_key)
        if mel_gen is None:
            if self.mel_type == 'MFCC':
                mel_gen = transforms.MFCC(sample_rate=sampling_rate, n_mfcc=channels)
            elif self.mel_type == 'Mel':
                mel_gen = transforms.MelSpectrogram(sample_rate=sampling_rate, n_mels=channels)
            elif self.mel_type == 'Tacotron':
                mel_gen = layers.TacotronSTFT(sampling_rate=sampling_rate, n_mel_channels=channels)
            else:
                raise NotImplementedError(
                    f'Unsupported mel_type in SpeakerMelLoader: {self.mel_type}')
            self.mel_generators[cache_key] = mel_gen

        if self.mel_type == 'Tacotron':
            # Replicating the Tacotron2 data loader, except that its division
            # by max_wav_value (32768.0) is skipped: LibriSpeech data looks
            # pre-normalized (all vals between 0-1).
            audio_norm = waveform.detach()
            return mel_gen.mel_spectrogram(audio_norm)

        # The torchaudio transforms are fed the waveform with an extra
        # leading axis.
        audio = waveform.unsqueeze(0).detach()
        return mel_gen(audio)

    def get_mel(self, waveform, sampling_rate, channels=80):
        """
        Compute the features for one waveform and return them time-major.

        Returns:
            A tensor laid out as ``(time, channels)`` when the extractor
            output has only singleton axes in front of its last two axes
            (the original comment gives ``(1, 1, channels, time)`` as the
            default layout).
        """
        # We previously identified that these warnings were ok.
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message=r'At least one mel filterbank has all zero values.*', module=r'torchaudio.*')
            melspec = self.apply_mel_gen(waveform, sampling_rate, channels)

        # Drop only the leading singleton axes. A blanket squeeze would also
        # remove a size-1 time (or channel) axis and hand back a 1-D tensor,
        # which breaks the padding done in __getitem__.
        while melspec.dim() > 2 and melspec.shape[0] == 1:
            melspec = melspec[0]
        return melspec.T

    def _crop_or_pad(self, mel):
        """Return exactly ``self.mel_length`` frames of ``mel`` (random crop or right zero-pad)."""
        frames = mel.shape[0]
        if frames < self.mel_length:
            # Zero pad mel on the right to mel_length.
            # pad tuple is (dn start, dn end, dn-1 start, dn-1 end, ... , d1 start, d1 end)
            mel = torch.nn.functional.pad(mel, (0, 0, 0, self.mel_length - frames))
            start = 0
        else:
            start = random.randint(0, frames - self.mel_length)
        return mel[start:start + self.mel_length, :]

    def __getitem__(self, index):
        """
        Fetch one item in the current format.

        Returns:
            ``(speaker_id, mel)`` in ``'utterance'`` format, or
            ``(speaker_id, mels)`` in ``'speaker'`` format where ``mels`` is
            ``torch.stack`` of ``speaker_utterances`` fixed-length mels.

        Raises:
            ValueError: in ``'speaker'`` format, if the speaker has fewer
                utterances than ``speaker_utterances``.
            NotImplementedError: if the format is neither ``'utterance'``
                nor ``'speaker'``.
        """
        if self.format == 'utterance':
            (waveform, sample_rate, _, speaker_id, _, _) = self.dataset[index]
            return (speaker_id, self.get_mel(waveform, sample_rate))

        if self.format == 'speaker':
            speaker_id = self.speaker_keys[index]
            candidates = self.speaker_map[speaker_id]
            if len(candidates) < self.speaker_utterances:
                # random.sample would raise ValueError here anyway; say why.
                raise ValueError(
                    f'Speaker {speaker_id!r} has only {len(candidates)} utterance(s) '
                    f'but speaker_utterances={self.speaker_utterances} were requested')
            utter_indexes = random.sample(candidates, self.speaker_utterances)

            mels = []
            for i in utter_indexes:
                # speaker_id is rebound to the value stored in the dataset
                # item, so the id returned below comes from the last sampled
                # utterance rather than from the index key.
                (waveform, sample_rate, _, speaker_id, _, _) = self.dataset[i]
                mels.append(self._crop_or_pad(self.get_mel(waveform, sample_rate)))
            return (speaker_id, torch.stack(mels, 0))

        raise NotImplementedError()

    def __len__(self):
        """Number of utterances ('utterance' format) or speakers ('speaker' format)."""
        if self.format == 'utterance':
            return len(self.dataset)
        if self.format == 'speaker':
            return len(self.speaker_keys)
        raise NotImplementedError()