Spaces:
Runtime error
Runtime error
SpeechCloning
/
TrainingInterfaces
/Text_to_Spectrogram
/FastSpeech2
/FastSpeechDatasetLanguageID.py
| import os | |
| import statistics | |
| import torch | |
| from torch.utils.data import Dataset | |
| from tqdm import tqdm | |
| from Preprocessing.ArticulatoryCombinedTextFrontend import get_language_id | |
| from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor | |
| from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner | |
| from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.AlignerDataset import AlignerDataset | |
| from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.DurationCalculator import DurationCalculator | |
| from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.EnergyCalculator import EnergyCalculator | |
| from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.PitchCalculator import Dio | |
class FastSpeechDataset(Dataset):
    """
    Cached FastSpeech 2 training data for a single language.

    Each stored datapoint is a list of eight entries. The first four are
    copied unchanged from the aligner dataset cache (index 0 is used as the
    phoneme/text tensor, index 2 as the mel spectrogram and index 3 as its
    length; index 1 is presumably the text length, confirm against
    AlignerDataset). The remaining four are computed here: phoneme durations,
    energy, pitch and a prosodic condition vector. ``__getitem__`` appends the
    language ID as a ninth element.
    """

    def __init__(self,
                 path_to_transcript_dict,
                 acoustic_checkpoint_path,
                 cache_dir,
                 lang,
                 loading_processes=40,
                 min_len_in_seconds=1,
                 max_len_in_seconds=20,
                 cut_silence=False,
                 reduction_factor=1,
                 device=torch.device("cpu"),
                 rebuild_cache=False,
                 ctc_selection=True,
                 save_imgs=False):
        """
        Load ``fast_train_cache.pt`` from ``cache_dir`` or build it.

        The cache is built when it does not exist or ``rebuild_cache`` is set.
        Building starts from the aligner cache (``aligner_train_cache.pt``),
        which is itself created through ``AlignerDataset`` if needed, and
        augments every sample with durations, energy, pitch and a prosodic
        condition.

        Args:
            path_to_transcript_dict: forwarded to ``AlignerDataset``.
            acoustic_checkpoint_path: checkpoint whose ``"asr_model"`` entry
                holds the aligner weights.
            cache_dir: directory holding both cache files (created if absent).
            lang: language passed to ``AlignerDataset`` and ``get_language_id``.
            loading_processes: forwarded to ``AlignerDataset``.
            min_len_in_seconds: forwarded to ``AlignerDataset``; also used to
                skip short waves here, but only if ``ctc_selection`` is set.
            max_len_in_seconds: forwarded to ``AlignerDataset``.
            cut_silence: forwarded to ``AlignerDataset`` as ``cut_silences``.
            reduction_factor: forwarded to the pitch, energy and duration
                calculators.
            device: device for the aligner and prosodic condition extractor.
            rebuild_cache: force rebuilding both caches.
            ctc_selection: drop samples whose CTC loss exceeds the mean by
                more than one standard deviation.
            save_imgs: write alignment plots to ``<cache_dir>/duration_vis``.
        """
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        fast_cache_path = os.path.join(cache_dir, "fast_train_cache.pt")
        sample_rate = 16000  # the calculators and the extractor below are all configured for 16 kHz

        if rebuild_cache or not os.path.exists(fast_cache_path):
            aligner_cache_path = os.path.join(cache_dir, "aligner_train_cache.pt")
            # One shared argument set, so that both AlignerDataset invocations
            # below are configured identically (including the device).
            aligner_kwargs = dict(path_to_transcript_dict=path_to_transcript_dict,
                                  cache_dir=cache_dir,
                                  lang=lang,
                                  loading_processes=loading_processes,
                                  min_len_in_seconds=min_len_in_seconds,
                                  max_len_in_seconds=max_len_in_seconds,
                                  cut_silences=cut_silence,
                                  device=device)
            if rebuild_cache or not os.path.exists(aligner_cache_path):
                AlignerDataset(rebuild_cache=rebuild_cache, **aligner_kwargs)
            datapoints = torch.load(aligner_cache_path, map_location='cpu')
            # we use the aligner dataset as basis and augment it to contain the additional information we need for fastspeech.
            if not isinstance(datapoints, tuple):  # check for backwards compatibility
                print(f"It seems like the Aligner dataset in {cache_dir} is not a tuple. Regenerating it, since we need the preprocessed waves.")
                AlignerDataset(rebuild_cache=True, **aligner_kwargs)
                datapoints = torch.load(aligner_cache_path, map_location='cpu')
            dataset = datapoints[0]
            norm_waves = datapoints[1]

            # build cache
            print("... building dataset cache ...")
            self.datapoints = list()
            self.ctc_losses = list()

            acoustic_model = Aligner()
            acoustic_model.load_state_dict(torch.load(acoustic_checkpoint_path, map_location='cpu')["asr_model"])

            # ==========================================
            # actual creation of datapoints starts here
            # ==========================================

            acoustic_model = acoustic_model.to(device)
            dio = Dio(reduction_factor=reduction_factor, fs=sample_rate)
            energy_calc = EnergyCalculator(reduction_factor=reduction_factor, fs=sample_rate)
            dc = DurationCalculator(reduction_factor=reduction_factor)
            vis_dir = os.path.join(cache_dir, "duration_vis")
            os.makedirs(vis_dir, exist_ok=True)
            pros_cond_ext = ProsodicConditionExtractor(sr=sample_rate, device=device)

            for index in tqdm(range(len(dataset))):
                norm_wave = norm_waves[index]
                # the length filter is only active together with the CTC based selection
                if len(norm_wave) / sample_rate < min_len_in_seconds and ctc_selection:
                    continue
                norm_wave_length = torch.LongTensor([len(norm_wave)])

                entry = dataset[index]
                text = entry[0]
                melspec = entry[2]
                melspec_length = entry[3]

                alignment_path, ctc_loss = acoustic_model.inference(mel=melspec.to(device),
                                                                    tokens=text.to(device),
                                                                    save_img_for_debug=os.path.join(vis_dir, f"{index}.png") if save_imgs else None,
                                                                    return_ctc=True)

                cached_duration = dc(torch.LongTensor(alignment_path), vis=None).cpu()
                self._repair_repeated_phoneme_durations(text, cached_duration)
                duration_length = torch.LongTensor([len(cached_duration)])

                cached_energy = energy_calc(input_waves=norm_wave.unsqueeze(0),
                                            input_waves_lengths=norm_wave_length,
                                            feats_lengths=melspec_length,
                                            durations=cached_duration.unsqueeze(0),
                                            durations_lengths=duration_length)[0].squeeze(0).cpu()

                cached_pitch = dio(input_waves=norm_wave.unsqueeze(0),
                                   input_waves_lengths=norm_wave_length,
                                   feats_lengths=melspec_length,
                                   durations=cached_duration.unsqueeze(0),
                                   durations_lengths=duration_length)[0].squeeze(0).cpu()

                try:
                    prosodic_condition = pros_cond_ext.extract_condition_from_reference_wave(norm_wave, already_normalized=True).cpu()
                except RuntimeError:
                    # if there is an audio without any voiced segments whatsoever we have to skip it.
                    continue

                self.datapoints.append([entry[0],
                                        entry[1],
                                        entry[2],
                                        entry[3],
                                        cached_duration,
                                        cached_energy,
                                        cached_pitch,
                                        prosodic_condition])
                self.ctc_losses.append(ctc_loss)

            # =============================
            # done with datapoint creation
            # =============================

            if ctc_selection:
                # now we can filter out some bad datapoints based on the CTC scores we collected
                self._drop_ctc_outliers()

            # save to cache
            if len(self.datapoints) > 0:
                torch.save(self.datapoints, fast_cache_path)
            else:
                import sys
                print("No datapoints were prepared! Exiting...")
                sys.exit()
        else:
            # just load the datapoints from cache
            self.datapoints = torch.load(fast_cache_path, map_location='cpu')

        self.language_id = get_language_id(lang)
        print(f"Prepared a FastSpeech dataset with {len(self.datapoints)} datapoints in {cache_dir}.")

    @staticmethod
    def _repair_repeated_phoneme_durations(text, durations):
        """
        Redistribute durations of directly repeated phonemes in place.

        Whenever two consecutive rows of ``text`` are identical, the first of
        the two receives 3/5 of their summed duration (truncated to an int)
        and the second one receives the rest.

        Args:
            text: tensor that yields one phoneme vector per iteration step.
            durations: 1D tensor with one duration per phoneme; modified in place.

        Returns:
            The number of repaired phoneme pairs.
        """
        repairs = 0
        previous_vec = None
        for phoneme_index, vec in enumerate(text):
            if previous_vec is not None and torch.equal(previous_vec, vec):
                # we found a case of repeating phonemes!
                total_dur = durations[phoneme_index - 1] + durations[phoneme_index]
                new_dur_1 = int((total_dur / 5) * 3)
                durations[phoneme_index - 1] = new_dur_1
                durations[phoneme_index] = total_dur - new_dur_1
                repairs += 1
            previous_vec = vec
        return repairs

    def _drop_ctc_outliers(self):
        """
        Remove datapoints whose CTC loss exceeds mean + one standard deviation.

        ``self.ctc_losses`` must be index-aligned with ``self.datapoints`` and
        is left untouched. With fewer than two losses neither a mean nor a
        sample standard deviation is meaningful, so nothing is removed.
        """
        if len(self.ctc_losses) < 2:
            return
        mean_ctc = sum(self.ctc_losses) / len(self.ctc_losses)
        std_dev = statistics.stdev(self.ctc_losses)
        threshold = mean_ctc + std_dev
        # walk backwards, so that popping does not shift the indexes still to be visited
        for index in range(len(self.ctc_losses) - 1, -1, -1):
            if self.ctc_losses[index] > threshold:
                self.datapoints.pop(index)
                print(
                    f"Removing datapoint {index}, because the CTC loss is one standard deviation higher than the mean. \n ctc: {round(self.ctc_losses[index], 4)} vs. mean: {round(mean_ctc, 4)}")

    def __getitem__(self, index):
        """Return the eight cached entries of a datapoint followed by the language ID."""
        datapoint = self.datapoints[index]
        return datapoint[0], \
               datapoint[1], \
               datapoint[2], \
               datapoint[3], \
               datapoint[4], \
               datapoint[5], \
               datapoint[6], \
               datapoint[7], \
               self.language_id

    def __len__(self):
        """Return the number of cached datapoints."""
        return len(self.datapoints)

    def remove_samples(self, list_of_samples_to_remove):
        """
        Delete the datapoints at the given indexes and overwrite the cache file.

        The indexes are processed in descending order, so that they all refer
        to positions in the list as it was before the call.
        """
        for remove_id in sorted(list_of_samples_to_remove, reverse=True):
            self.datapoints.pop(remove_id)
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")

    def fix_repeating_phones(self):
        """
        The viterbi decoding of the durations cannot
        handle repetitions. This is now solved heuristically,
        but if you have a cache from before March 2022,
        use this method to postprocess those cases.
        """
        for datapoint in tqdm(self.datapoints):
            repairs = self._repair_repeated_phoneme_durations(datapoint[0], datapoint[4])
            for _ in range(repairs):
                print("fix applied")
        torch.save(self.datapoints, os.path.join(self.cache_dir, "fast_train_cache.pt"))
        print("Dataset updated!")