Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import text | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from utils import read_lines_from_file, progbar | |
| from utils.audio import MelSpectrogram | |
def text_mel_collate_fn(batch, pad_value=0):
    """Collate (text_ids, mel_spec) pairs into padded batch tensors.

    Samples are reordered by descending text length before padding.

    Args:
        batch: List[(text_ids, mel_spec)]; text_ids is sized along dim 0
            and mel_spec is indexed as (num_mels, num_frames).
        pad_value: fill value for mel positions past a sample's last frame.
    Returns:
        text_ids_pad: LongTensor (batch, max_text_len), zero padded
        input_lengths: LongTensor (batch,), text lengths sorted descending
        mel_pad: FloatTensor (batch, num_mels, max_mel_len)
        gate_pad: FloatTensor (batch, max_mel_len); 1 from the last real
            frame onwards, 0 before it
        output_lengths: LongTensor (batch,), mel frame count per sample
    """
    text_lens = torch.LongTensor([len(sample[0]) for sample in batch])
    input_lens_sorted, order = torch.sort(text_lens, dim=0, descending=True)

    n_items = len(batch)
    longest_text = int(input_lens_sorted[0])
    num_mels = batch[0][1].size(0)
    longest_mel = max(sample[1].size(1) for sample in batch)

    # Allocate the outputs already filled with their padding values.
    text_ids_pad = torch.zeros(n_items, longest_text, dtype=torch.long)
    mel_pad = torch.full((n_items, num_mels, longest_mel), pad_value,
                         dtype=torch.float)
    gate_pad = torch.zeros(n_items, longest_mel, dtype=torch.float)
    output_lengths = torch.zeros(n_items, dtype=torch.long)

    for row, src in enumerate(order.tolist()):
        text_ids, mel = batch[src]
        n_frames = mel.size(1)
        text_ids_pad[row, :text_ids.size(0)] = text_ids
        mel_pad[row, :, :n_frames] = mel
        # The gate switches on at the final real frame and stays on.
        gate_pad[row, n_frames - 1:] = 1
        output_lengths[row] = n_frames

    return (text_ids_pad, input_lens_sorted,
            mel_pad, gate_pad, output_lengths)
def normalize_pitch(pitch,
                    mean: float = 130.05478,
                    std: float = 22.86267):
    """Standardize a pitch contour in place, keeping zero entries at zero.

    Args:
        pitch: indexable numeric array; it is modified in place. Entries
            equal to 0.0 are reset to 0.0 after scaling (presumably these
            mark unvoiced frames; confirm against the pitch extractor).
        mean: value subtracted from every entry.
        std: value every entry is divided by.
    Returns:
        The same ``pitch`` object that was passed in.
    """
    # Capture the zero mask first: after shifting, zeros are no longer zero.
    was_zero = pitch == 0.0
    pitch -= mean
    pitch /= std
    pitch[was_zero] = 0.0
    return pitch
def remove_silence(energy_per_frame: torch.Tensor,
                   thresh: float = -10.0):
    """Build a boolean keep-mask that drops low-energy frames.

    Frames whose energy is above ``thresh`` are kept. The run of
    below-threshold frames at the very end is kept as well, so trailing
    silence survives. If no frame exceeds the threshold, every frame except
    the first is kept (index 0 is never forced on).

    Args:
        energy_per_frame: 1-D tensor with one energy value per frame.
        thresh: frames with energy strictly greater than this are kept.
    Returns:
        Boolean tensor of the same length; True marks frames to keep.
        An empty input yields an empty mask.
    """
    keep = energy_per_frame > thresh
    if keep.numel() == 0:
        # Nothing to scan; the original index-based scan raised here.
        return keep

    above = torch.nonzero(keep).flatten()
    # Position after which everything is trailing silence. With no frame
    # above the threshold, only index 0 is left untouched.
    last_above = int(above[-1]) if above.numel() > 0 else 0
    keep[last_above + 1:] = True
    return keep
def make_dataset_from_subdirs(folder_path):
    """Recursively list every file under ``folder_path`` ending in '.wav'.

    Symbolic links to directories are followed. The suffix match is
    case-sensitive.

    Args:
        folder_path: root directory to walk.
    Returns:
        List of paths (directory joined with file name) in walk order.
    """
    return [
        os.path.join(dirpath, entry)
        for dirpath, _, entries in os.walk(folder_path, followlinks=True)
        for entry in entries
        if entry.endswith('.wav')
    ]
| def _process_line(label_pattern: str, line: str): | |
| match = re.search(label_pattern, line) | |
| if match is None: | |
| raise Exception(f'no match for line: {line}') | |
| res_dict = match.groupdict() | |
| if 'arabic' in res_dict: | |
| phonemes = text.arabic_to_phonemes(res_dict['arabic']) | |
| elif 'phonemes' in res_dict: | |
| phonemes = res_dict['phonemes'] | |
| elif 'buckwalter' in res_dict: | |
| phonemes = text.buckwalter_to_phonemes(res_dict['buckwalter']) | |
| if 'filename' in res_dict: | |
| filename = res_dict['filename'] | |
| elif 'filestem' in res_dict: | |
| filename = f"{res_dict['filestem']}.wav" | |
| return phonemes, filename | |
class ArabDataset(Dataset):
    """Dataset of (token-id tensor, log-mel spectrogram) pairs.

    The label file is parsed once at construction; audio is loaded and
    converted to a mel spectrogram lazily in ``__getitem__``.
    """

    def __init__(self,
                 txtpath: str = 'tts data sample/text.txt',
                 wavpath: str = './',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 sr_target: int = 22050
                 ):
        """
        Args:
            txtpath: path of the label file, one utterance per line.
            wavpath: directory the file names in the label file are
                joined onto.
            label_pattern: regex with named groups understood by
                ``_process_line``.
            sr_target: sample rate audio is resampled to when it differs.
        """
        super().__init__()

        self.mel_fn = MelSpectrogram()
        self.wav_path = wavpath
        self.label_pattern = label_pattern
        self.sr_target = sr_target

        self.data = self._process_textfile(txtpath)

    def _process_textfile(self, txtpath: str):
        """Parse the label file into a list of (token_ids, wav_path).

        Lines that do not parse, point to a missing file or contain
        unknown phonemes are reported with ``print`` and skipped.
        """
        lines = read_lines_from_file(txtpath)

        phoneme_mel_list = []

        for l_idx, line in enumerate(progbar(lines)):

            try:
                phonemes, filename = _process_line(
                    self.label_pattern, line)
            except Exception:
                # Not a bare except: Ctrl-C / SystemExit must still abort.
                print(f'invalid line {l_idx}: {line}')
                continue

            fpath = os.path.join(self.wav_path, filename)
            if not os.path.exists(fpath):
                print(f"{fpath} does not exist")
                continue

            try:
                tokens = text.phonemes_to_tokens(phonemes)
                token_ids = text.tokens_to_ids(tokens)
            except Exception:
                print(f'invalid phonemes at line {l_idx}: {line}')
                continue

            phoneme_mel_list.append((torch.LongTensor(token_ids), fpath))

        return phoneme_mel_list

    def _get_mel_from_fpath(self, fpath):
        """Load a wav file and return its silence-trimmed log-mel."""
        wave, sr = torchaudio.load(fpath)
        if sr != self.sr_target:
            # Fourth positional argument is lowpass_filter_width.
            wave = torchaudio.functional.resample(wave, sr, self.sr_target, 64)

        mel_raw = self.mel_fn(wave)
        # NOTE(review): squeeze() drops every size-1 dim; this presumably
        # expects mono audio giving (1, n_mels, frames) - confirm.
        mel_log = mel_raw.clamp_min(1e-5).log().squeeze()

        # Mean over dim 0 gives one energy value per frame.
        energy_per_frame = mel_log.mean(0)
        mel_log = mel_log[:, remove_silence(energy_per_frame)]

        return mel_log

    def __len__(self):
        """Number of usable utterances found in the label file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (token_ids, log_mel) for utterance ``idx``."""
        phonemes, fpath = self.data[idx]

        mel_log = self._get_mel_from_fpath(fpath)

        return phonemes, mel_log
class ArabDataset4FastPitch(Dataset):
    """Dataset yielding the per-utterance tuple used for FastPitch training.

    Besides token ids and the log-mel spectrogram, each item carries a
    normalized pitch contour looked up in a precomputed dictionary, a
    per-frame energy and an attention prior.
    """

    def __init__(self,
                 txtpath: str = './data/train_phon.txt',
                 wavpath: str = 'G:/data/arabic-speech-corpus/wav_new',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 f0_dict_path: str = './data/pitch_dict.pt',
                 f0_mean: float = 130.05478,
                 f0_std: float = 22.86267,
                 sr_target: int = 22050
                 ):
        """
        Args:
            txtpath: path of the label file, one utterance per line.
            wavpath: directory the file names in the label file are
                joined onto.
            label_pattern: regex with named groups understood by
                ``_process_line``.
            f0_dict_path: file loaded with ``torch.load``; indexed by wav
                base name to obtain the pitch contour.
            f0_mean: mean used to normalize pitch.
            f0_std: standard deviation used to normalize pitch.
            sr_target: sample rate audio is resampled to when it differs.
        """
        super().__init__()
        # Imported here so the module loads without the fastpitch package.
        from models.fastpitch.fastpitch.data_function import BetaBinomialInterpolator

        self.mel_fn = MelSpectrogram()
        self.wav_path = wavpath
        self.label_pattern = label_pattern
        self.sr_target = sr_target

        # NOTE(security): torch.load unpickles the file; only point
        # f0_dict_path at trusted data.
        self.f0_dict = torch.load(f0_dict_path)
        self.f0_mean = f0_mean
        self.f0_std = f0_std
        self.betabinomial_interpolator = BetaBinomialInterpolator()

        self.data = self._process_textfile(txtpath)

    def _process_textfile(self, txtpath: str):
        """Parse the label file into (token_ids, wav_path, pitch) triples.

        Lines that do not parse, point to a missing file, contain unknown
        phonemes or have no pitch entry are reported with ``print`` and
        skipped.
        """
        lines = read_lines_from_file(txtpath)

        phoneme_mel_pitch_list = []

        for l_idx, line in enumerate(progbar(lines)):

            try:
                phonemes, filename = _process_line(
                    self.label_pattern, line)
            except Exception:
                # Not a bare except: Ctrl-C / SystemExit must still abort.
                print(f'invalid line {l_idx}: {line}')
                continue

            fpath = os.path.join(self.wav_path, filename)
            if not os.path.exists(fpath):
                print(f"{fpath} does not exist")
                continue

            try:
                tokens = text.phonemes_to_tokens(phonemes)
                token_ids = text.tokens_to_ids(tokens)
            except Exception:
                print(f'invalid phonemes at line {l_idx}: {line}')
                continue

            wav_name = os.path.basename(fpath)
            try:
                # [None] prepends a leading axis to the stored contour.
                pitch_mel = self.f0_dict[wav_name][None]
            except KeyError:
                # Skip like every other per-line failure instead of
                # aborting the whole load.
                print(f'no pitch entry for {wav_name} at line {l_idx}')
                continue

            phoneme_mel_pitch_list.append(
                (torch.LongTensor(token_ids), fpath, pitch_mel))

        return phoneme_mel_pitch_list

    def __len__(self):
        """Number of usable utterances found in the label file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load utterance ``idx`` and assemble its training tuple.

        Returns:
            ``(token_ids, log_mel, text_len, pitch, energy, speaker,
            attn_prior, wav_path)`` where ``speaker`` is always None.
        """
        phonemes, fpath, pitch_mel = self.data[idx]

        wave, sr = torchaudio.load(fpath)
        if sr != self.sr_target:
            # Fourth positional argument is lowpass_filter_width.
            wave = torchaudio.functional.resample(wave, sr, self.sr_target, 64)

        mel_raw = self.mel_fn(wave)
        # NOTE(review): squeeze() drops every size-1 dim; this presumably
        # expects mono audio giving (1, n_mels, frames) - confirm.
        mel_log = mel_raw.clamp_min(1e-5).log().squeeze()

        keep = remove_silence(mel_log.mean(0))
        mel_log = mel_log[:, keep]
        # The same frame mask is applied to pitch, so the stored contour
        # must have one value per mel frame.
        pitch_mel = normalize_pitch(pitch_mel[:, keep], self.f0_mean, self.f0_std)

        # Per-frame energy: L2 norm across dim 0 of the log-mel.
        energy = torch.norm(mel_log.float(), dim=0, p=2)
        attn_prior = torch.from_numpy(
            self.betabinomial_interpolator(mel_log.size(1), len(phonemes)))

        speaker = None
        return (phonemes, mel_log, len(phonemes), pitch_mel,
                energy, speaker, attn_prior,
                fpath)
class DynBatchDataset(ArabDataset4FastPitch):
    """Dataset whose items are whole batches, sized by utterance length.

    Utterances are grouped into length buckets; each bucket has its own
    batch size. Indexing returns a list of samples, one pre-built batch.
    Call ``shuffle()`` to redraw the batches.
    """

    def __init__(self,
                 txtpath: str = './data/train_phon.txt',
                 wavpath: str = 'G:/data/arabic-speech-corpus/wav_new',
                 label_pattern: str = '"(?P<filename>.*)" "(?P<phonemes>.*)"',
                 f0_dict_path: str = './data/pitch_dict.pt',
                 f0_mean: float = 130.05478,
                 f0_std: float = 22.86267,
                 max_lengths: list[int] = [1000, 1300, 1850, 30000],
                 batch_sizes: list[int] = [10, 8, 6, 4],
                 ):
        """
        Args:
            txtpath, wavpath, label_pattern, f0_dict_path, f0_mean, f0_std:
                forwarded unchanged to ``ArabDataset4FastPitch``.
            max_lengths: exclusive upper length bound of each bucket;
                bucket k covers ``[max_lengths[k-1], max_lengths[k])``
                with an implicit lower bound of 0 for the first bucket.
            batch_sizes: batch size used for the bucket at the same
                position in ``max_lengths``.
        """
        super().__init__(txtpath=txtpath, wavpath=wavpath,
                         label_pattern=label_pattern,
                         f0_dict_path=f0_dict_path,
                         f0_mean=f0_mean, f0_std=f0_std)

        # Copy both lists so the shared default objects are never aliased.
        self.max_lens = [0] + list(max_lengths)
        self.b_sizes = list(batch_sizes)

        self.id_batches = []
        self.shuffle()

    def shuffle(self):
        """Rebuild ``id_batches`` with a fresh random grouping.

        Raises:
            ValueError: if an utterance length falls outside every bucket.
        """
        lens = [item[2].size(1) for item in self.data]  # item[2]: pitch

        ids_per_bs = {b: [] for b in self.b_sizes}

        for sample_id, mel_len in enumerate(lens):
            b_idx = next(
                (k for k in range(len(self.max_lens) - 1)
                 if self.max_lens[k] <= mel_len < self.max_lens[k + 1]),
                None)
            if b_idx is None:
                # Previously leaked a bare StopIteration from next().
                raise ValueError(
                    f'sample {sample_id} has length {mel_len}, outside '
                    f'all buckets (bounds {self.max_lens})')
            ids_per_bs[self.b_sizes[b_idx]].append(sample_id)

        id_batches = []

        for bs, ids in ids_per_bs.items():
            np.random.shuffle(ids)
            # The last chunk of a bucket may hold fewer than bs ids.
            id_batches += [ids[k:k + bs] for k in range(0, len(ids), bs)]

        self.id_batches = id_batches

    def __len__(self):
        """Number of batches currently available."""
        return len(self.id_batches)

    def __getitem__(self, idx):
        """Return batch ``idx`` as a list of per-utterance tuples."""
        # Zero-argument super() is unavailable inside a comprehension on
        # older Python versions, so bind the parent method up front.
        load_sample = super(DynBatchDataset, self).__getitem__
        return [load_sample(sample_id) for sample_id in self.id_batches[idx]]