Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved. | |
| # This program is free software; you can redistribute it and/or modify | |
| # it under the terms of the MIT License. | |
| # This program is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # MIT License for more details. | |
| import random | |
| import numpy as np | |
| import torch | |
| import torchaudio as ta | |
| from text import text_to_sequence, cmudict | |
| from text.symbols import symbols | |
| from utils import parse_filelist, intersperse | |
| from model.utils import fix_len_compatibility | |
| from params import seed as random_seed | |
| import sys | |
| sys.path.insert(0, 'hifi-gan') | |
| from meldataset import mel_spectrogram | |
class TextMelDataset(torch.utils.data.Dataset):
    """Single-speaker dataset of (token sequence, mel-spectrogram) pairs.

    Every filelist entry supplies an audio file path (field 0) and its
    transcript (field 1). Items are returned as ``{'y': mel, 'x': text}``.

    Args:
        filelist_path: path handed to ``parse_filelist``.
        cmudict_path: path handed to ``cmudict.CMUDict``.
        add_blank: when true, a blank token is interspersed in every
            token sequence.
        n_fft, n_mels, sample_rate, hop_length, win_length, f_min, f_max:
            forwarded unchanged to ``mel_spectrogram``.
    """

    def __init__(self, filelist_path, cmudict_path, add_blank=True,
                 n_fft=1024, n_mels=80, sample_rate=22050,
                 hop_length=256, win_length=1024, f_min=0., f_max=8000):
        super().__init__()
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.cmudict = cmudict.CMUDict(cmudict_path)
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        # The global RNG is (re)seeded on purpose: this keeps the shuffled
        # order identical between runs and leaves the module-level `random`
        # state exactly as callers have always observed it.
        random.seed(random_seed)
        random.shuffle(self.filepaths_and_text)

    def get_pair(self, filepath_and_text):
        """Return ``(text, mel)`` for one ``[filepath, text, ...]`` entry."""
        filepath, text = filepath_and_text[0], filepath_and_text[1]
        text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)
        return (text, mel)

    def get_mel(self, filepath):
        """Load ``filepath`` and return its mel-spectrogram.

        Raises:
            AssertionError: if the file's sample rate differs from
                ``self.sample_rate``.
        """
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate, (
            '{}: sample rate {} does not match the expected {}'.format(
                filepath, sr, self.sample_rate))
        mel = mel_spectrogram(audio, self.n_fft, self.n_mels, self.sample_rate,
                              self.hop_length, self.win_length, self.f_min,
                              self.f_max, center=False)
        # Drop only the leading axis. A bare squeeze() would also collapse a
        # length-1 time axis (very short clip) and break code that reads
        # shape[-2] / shape[-1].
        # NOTE(review): assumes mel_spectrogram keeps the leading channel
        # axis of `audio` -- confirm against hifi-gan's meldataset.
        return mel.squeeze(0)

    def get_text(self, text, add_blank=None):
        """Convert ``text`` to a ``torch.IntTensor`` of symbol ids.

        Args:
            text: transcript passed to ``text_to_sequence``.
            add_blank: whether to intersperse the blank token. ``None``
                (the default) falls back to ``self.add_blank``.
        """
        if add_blank is None:
            add_blank = self.add_blank
        text_norm = text_to_sequence(text, dictionary=self.cmudict)
        if add_blank:
            # The blank token uses id len(symbols), i.e. one past the last
            # regular symbol id.
            text_norm = intersperse(text_norm, len(symbols))
        return torch.IntTensor(text_norm)

    def __getitem__(self, index):
        """Return ``{'y': mel, 'x': text}`` for the entry at ``index``."""
        text, mel = self.get_pair(self.filepaths_and_text[index])
        return {'y': mel, 'x': text}

    def __len__(self):
        """Return the number of filelist entries."""
        return len(self.filepaths_and_text)

    def sample_test_batch(self, size):
        """Return a list of ``size`` distinct items drawn with ``np.random``.

        Sampling is without replacement, so ``size`` must not exceed
        ``len(self)`` (``np.random.choice`` raises ``ValueError`` otherwise).
        """
        # An int population is sampled as if it were np.arange(len(self)).
        idx = np.random.choice(len(self), size=size, replace=False)
        return [self[index] for index in idx]
class TextMelBatchCollate(object):
    """Collate function that zero-pads dataset items into batch tensors."""

    def __call__(self, batch):
        """Stack a list of ``{'y': mel, 'x': text}`` items.

        Returns a dict with:
            'x': long tensor (batch, longest text) of zero-padded token ids,
            'x_lengths': long tensor of the unpadded text lengths,
            'y': float32 tensor (batch, n_feats, padded mel length),
            'y_lengths': long tensor of the unpadded mel lengths.
        The mel length is the batch maximum passed through
        ``fix_len_compatibility``.
        """
        mels = [sample['y'] for sample in batch]
        tokens = [sample['x'] for sample in batch]
        mel_lens = [mel.shape[-1] for mel in mels]
        tok_lens = [seq.shape[-1] for seq in tokens]

        padded_mel_len = fix_len_compatibility(max(mel_lens))
        longest_text = max(tok_lens)
        # Feature dimension is taken from the first item of the batch.
        feat_dim = mels[0].shape[-2]

        y = torch.zeros((len(batch), feat_dim, padded_mel_len), dtype=torch.float32)
        x = torch.zeros((len(batch), longest_text), dtype=torch.long)
        for row, (mel, seq) in enumerate(zip(mels, tokens)):
            # Copy each item into the leading slice; the tail stays zero.
            y[row, :, :mel_lens[row]] = mel
            x[row, :tok_lens[row]] = seq

        return {
            'x': x,
            'x_lengths': torch.tensor(tok_lens, dtype=torch.long),
            'y': y,
            'y_lengths': torch.tensor(mel_lens, dtype=torch.long),
        }
class TextMelSpeakerDataset(torch.utils.data.Dataset):
    """Multi-speaker dataset of (token sequence, mel-spectrogram, speaker id).

    Every filelist entry supplies an audio file path (field 0), its
    transcript (field 1) and a speaker id (field 2); the filelist is parsed
    with ``split_char='|'``. Items are returned as
    ``{'y': mel, 'x': text, 'spk': speaker}``.

    Args:
        filelist_path: path handed to ``parse_filelist``.
        cmudict_path: path handed to ``cmudict.CMUDict``.
        add_blank: when true, a blank token is interspersed in every
            token sequence.
        n_fft, n_mels, sample_rate, hop_length, win_length, f_min, f_max:
            forwarded unchanged to ``mel_spectrogram``.
    """

    def __init__(self, filelist_path, cmudict_path, add_blank=True,
                 n_fft=1024, n_mels=80, sample_rate=22050,
                 hop_length=256, win_length=1024, f_min=0., f_max=8000):
        super().__init__()
        self.filelist = parse_filelist(filelist_path, split_char='|')
        self.cmudict = cmudict.CMUDict(cmudict_path)
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        self.add_blank = add_blank
        # The global RNG is (re)seeded on purpose: this keeps the shuffled
        # order identical between runs and leaves the module-level `random`
        # state exactly as callers have always observed it.
        random.seed(random_seed)
        random.shuffle(self.filelist)

    def get_triplet(self, line):
        """Return ``(text, mel, speaker)`` for one filelist entry."""
        filepath, text, speaker = line[0], line[1], line[2]
        text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)
        speaker = self.get_speaker(speaker)
        return (text, mel, speaker)

    def get_mel(self, filepath):
        """Load ``filepath`` and return its mel-spectrogram.

        Raises:
            AssertionError: if the file's sample rate differs from
                ``self.sample_rate``.
        """
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate, (
            '{}: sample rate {} does not match the expected {}'.format(
                filepath, sr, self.sample_rate))
        mel = mel_spectrogram(audio, self.n_fft, self.n_mels, self.sample_rate,
                              self.hop_length, self.win_length, self.f_min,
                              self.f_max, center=False)
        # Drop only the leading axis. A bare squeeze() would also collapse a
        # length-1 time axis (very short clip) and break code that reads
        # shape[-2] / shape[-1].
        # NOTE(review): assumes mel_spectrogram keeps the leading channel
        # axis of `audio` -- confirm against hifi-gan's meldataset.
        return mel.squeeze(0)

    def get_text(self, text, add_blank=None):
        """Convert ``text`` to a ``torch.LongTensor`` of symbol ids.

        Args:
            text: transcript passed to ``text_to_sequence``.
            add_blank: whether to intersperse the blank token. ``None``
                (the default) falls back to ``self.add_blank``.
        """
        if add_blank is None:
            add_blank = self.add_blank
        text_norm = text_to_sequence(text, dictionary=self.cmudict)
        if add_blank:
            # The blank token uses id len(symbols), i.e. one past the last
            # regular symbol id.
            text_norm = intersperse(text_norm, len(symbols))
        return torch.LongTensor(text_norm)

    def get_speaker(self, speaker):
        """Return the speaker id as a one-element ``torch.LongTensor``."""
        return torch.LongTensor([int(speaker)])

    def __getitem__(self, index):
        """Return ``{'y': mel, 'x': text, 'spk': speaker}`` for ``index``."""
        text, mel, speaker = self.get_triplet(self.filelist[index])
        return {'y': mel, 'x': text, 'spk': speaker}

    def __len__(self):
        """Return the number of filelist entries."""
        return len(self.filelist)

    def sample_test_batch(self, size):
        """Return a list of ``size`` distinct items drawn with ``np.random``.

        Sampling is without replacement, so ``size`` must not exceed
        ``len(self)`` (``np.random.choice`` raises ``ValueError`` otherwise).
        """
        # An int population is sampled as if it were np.arange(len(self)).
        idx = np.random.choice(len(self), size=size, replace=False)
        return [self[index] for index in idx]
class TextMelSpeakerBatchCollate(object):
    """Collate function that zero-pads multi-speaker items into batch tensors."""

    def __call__(self, batch):
        """Stack a list of ``{'y': mel, 'x': text, 'spk': speaker}`` items.

        Returns a dict with:
            'x': long tensor (batch, longest text) of zero-padded token ids,
            'x_lengths': long tensor of the unpadded text lengths,
            'y': float32 tensor (batch, n_feats, padded mel length),
            'y_lengths': long tensor of the unpadded mel lengths,
            'spk': the per-item speaker tensors concatenated along dim 0.
        The mel length is the batch maximum passed through
        ``fix_len_compatibility``.
        """
        mels = [sample['y'] for sample in batch]
        tokens = [sample['x'] for sample in batch]
        mel_lens = [mel.shape[-1] for mel in mels]
        tok_lens = [seq.shape[-1] for seq in tokens]

        padded_mel_len = fix_len_compatibility(max(mel_lens))
        longest_text = max(tok_lens)
        # Feature dimension is taken from the first item of the batch.
        feat_dim = mels[0].shape[-2]

        y = torch.zeros((len(batch), feat_dim, padded_mel_len), dtype=torch.float32)
        x = torch.zeros((len(batch), longest_text), dtype=torch.long)
        speakers = []
        for row, sample in enumerate(batch):
            # Copy each item into the leading slice; the tail stays zero.
            y[row, :, :mel_lens[row]] = mels[row]
            x[row, :tok_lens[row]] = tokens[row]
            speakers.append(sample['spk'])

        return {
            'x': x,
            'x_lengths': torch.tensor(tok_lens, dtype=torch.long),
            'y': y,
            'y_lengths': torch.tensor(mel_lens, dtype=torch.long),
            'spk': torch.cat(speakers, dim=0),
        }