| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
| from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize |
|
|
|
|
class WaveRNNDataset(Dataset):
    """Dataset producing ``(mel, signal, wav_path)`` items for WaveRNN.

    ``items`` is either a sequence of wav paths, in which case the mel
    spectrogram is computed on the fly through ``ap``, or a sequence of
    ``(wav_path, feature_path)`` pairs pointing at pre-computed ``.npy``
    features. Which of the two applies is decided once, from the type of
    ``items[0]``.

    Args:
        ap: audio processor; must provide ``load_wav(path)`` and
            ``melspectrogram(wav)``.
        items: wav paths, or ``(wav_path, feature_path)`` pairs.
        seq_len: number of signal samples per training segment; must be a
            multiple of ``hop_len``.
        hop_len: number of signal samples per mel frame.
        pad: number of extra mel frames kept on each side of a segment.
        mode: ``"gauss"`` or ``"mold"`` to use the raw waveform as target, or
            an ``int`` giving the number of bits of the quantized target.
        mulaw: in integer mode, use ``mulaw_encode`` instead of ``quantize``.
        is_training: stored on the instance; not read by this class.
        verbose: stored on the instance; not read by this class.
        return_segments: when true, ``load_item`` pads audio so that at least
            two segments fit; when false, it only adds the ``pad`` context.
    """

    # Modes in which the raw waveform itself (not quantized classes) is used.
    _RAW_MODES = ("gauss", "mold")

    def __init__(
        self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
    ):
        super().__init__()
        self.ap = ap
        # Plain paths mean "compute features"; pairs mean "load features".
        self.compute_feat = not isinstance(items[0], (tuple, list))
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.mel_len = seq_len // hop_len
        self.pad = pad
        self.mode = mode
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose
        self.return_segments = return_segments

        # Kept as an assert so the exception type seen by callers is unchanged.
        assert self.seq_len % self.hop_len == 0, "seq_len must be a multiple of hop_len"

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.item_list)

    def __getitem__(self, index):
        """Return ``load_item(index)``."""
        return self.load_item(index)

    def load_test_samples(self, num_samples):
        """Load the first ``num_samples`` items as ``[mel, audio]`` pairs.

        ``return_segments`` is switched off while loading and restored
        afterwards, even if loading one of the items raises.
        """
        saved_flag = self.return_segments
        self.return_segments = False
        try:
            samples = []
            for idx in range(num_samples):
                mel, audio, _ = self.load_item(idx)
                samples.append([mel, audio])
            return samples
        finally:
            self.return_segments = saved_flag

    def load_item(self, index):
        """Return ``(mel, x_input, wavpath)`` for the item at ``index``.

        Features are computed from the wav file when the dataset was built
        from plain paths, and loaded from disk otherwise.

        Raises:
            RuntimeError: if ``self.mode`` is not a supported mode.
        """
        if self.compute_feat:
            return self._compute_item(index)
        return self._load_precomputed_item(index)

    def _compute_item(self, index):
        """Load the wav at ``index`` and compute its mel and target signal."""
        wavpath = self.item_list[index]
        audio = self.ap.load_wav(wavpath)

        context = 2 * self.pad * self.hop_len
        if self.return_segments:
            min_audio_len = 2 * self.seq_len + context
        else:
            min_audio_len = audio.shape[0] + context
        if audio.shape[0] < min_audio_len:
            print(" [!] Instance is too short! : {}".format(wavpath))
            # Zero-pad at the end only, with one extra hop of margin.
            audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
        mel = self.ap.melspectrogram(audio)

        if self.mode in self._RAW_MODES:
            x_input = audio
        elif isinstance(self.mode, int):
            if self.mulaw:
                x_input = mulaw_encode(wav=audio, mulaw_qc=self.mode)
            else:
                x_input = quantize(x=audio, quantize_bits=self.mode)
        else:
            raise RuntimeError("Unknown dataset mode - {}".format(self.mode))
        return mel, x_input, wavpath

    def _load_precomputed_item(self, index):
        """Load the pre-computed mel and target signal for the item at ``index``."""
        wavpath, feat_path = self.item_list[index]
        mel = np.load(feat_path.replace("/quant/", "/mel/"))

        if mel.shape[-1] < self.mel_len + 2 * self.pad:
            print(" [!] Instance is too short! : {}".format(wavpath))
            # Permanently substitute the following item (wrapping around at
            # the end of the list) and reload BOTH fields from it, so the
            # returned wav path matches the returned features.
            next_index = (index + 1) % len(self.item_list)
            self.item_list[index] = self.item_list[next_index]
            wavpath, feat_path = self.item_list[index]
            mel = np.load(feat_path.replace("/quant/", "/mel/"))

        if self.mode in self._RAW_MODES:
            x_input = self.ap.load_wav(wavpath)
        elif isinstance(self.mode, int):
            x_input = np.load(feat_path.replace("/mel/", "/quant/"))
        else:
            raise RuntimeError("Unknown dataset mode - {}".format(self.mode))
        return mel, x_input, wavpath

    def collate(self, batch):
        """Cut one random aligned segment out of every item and batch them.

        Args:
            batch: list of ``(mel, signal, wavpath)`` items as returned by
                ``load_item``; ``mel`` is sliced along its last axis.

        Returns:
            ``(x_input, mels, y_coarse)`` tensors. ``x_input`` holds the first
            ``seq_len`` samples of each signal segment (rescaled to
            ``[-1, 1]`` in integer mode), ``mels`` the matching mel windows,
            and ``y_coarse`` the same signal segment shifted by one sample.

        Raises:
            ValueError: if an item is too short to draw a segment from.
            RuntimeError: if ``self.mode`` is not a supported mode.
        """
        mel_win = self.seq_len // self.hop_len + 2 * self.pad
        max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
        if any(offset <= 0 for offset in max_offsets):
            # np.random.randint(0, offset) would fail with an opaque
            # "low >= high"; raise the same exception type with context.
            raise ValueError(
                "Batch contains an item with fewer than {} mel frames; "
                "cannot sample a segment from it".format(mel_win + 2 * self.pad + 1)
            )

        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
        # Skip the leading context frames when converting to sample offsets.
        sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]

        mels = [x[0][:, start : start + mel_win] for x, start in zip(batch, mel_offsets)]
        # One extra sample so input and target can be shifted by one step.
        coarse = [x[1][start : start + self.seq_len + 1] for x, start in zip(batch, sig_offsets)]

        mels = np.stack(mels).astype(np.float32)
        if self.mode in self._RAW_MODES:
            coarse = np.stack(coarse).astype(np.float32)
            coarse = torch.FloatTensor(coarse)
            x_input = coarse[:, : self.seq_len]
        elif isinstance(self.mode, int):
            coarse = np.stack(coarse).astype(np.int64)
            coarse = torch.LongTensor(coarse)
            # Map class indices in [0, 2**bits - 1] onto [-1, 1].
            x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0
        else:
            raise RuntimeError("Unknown dataset mode - {}".format(self.mode))
        y_coarse = coarse[:, 1:]
        mels = torch.FloatTensor(mels)
        return x_input, mels, y_coarse
|
|