Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from tasks.tts.dataset_utils import FastSpeechWordDataset | |
| from tasks.tts.tts_utils import load_data_preprocessor | |
| import numpy as np | |
| from modules.FastDiff.module.util import compute_hyperparams_given_schedule, sampling_given_noise_schedule | |
| import os | |
| import torch | |
| from modules.FastDiff.module.FastDiff_model import FastDiff | |
| from utils.ckpt_utils import load_ckpt | |
| from utils.hparams import set_hparams | |
class BaseTTSInfer:
    """Base class for single-utterance TTS inference with a FastDiff vocoder.

    Subclasses implement ``build_model`` (return the acoustic model) and
    ``forward_model`` (map a preprocessed item to the final output).  The
    constructor loads the text preprocessor, phoneme dictionary and speaker
    map from ``hparams['binary_data_dir']``, builds the acoustic model and the
    FastDiff vocoder, and moves both to ``device`` in eval mode.
    """

    # Inference noise schedules for short reverse processes, keyed by the
    # number of reverse iterations N.  Per the upstream authors these were
    # derived by a noise predictor (training code not released).
    _PREDICTED_NOISE_SCHEDULES = {
        8: [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
            0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593,
            0.5],
        6: [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
            0.006634317338466644, 0.09357017278671265, 0.6000000238418579],
        4: [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01],
        3: [9.0000e-05, 9.0000e-03, 6.0000e-01],
    }

    def __init__(self, hparams, device=None):
        """Load preprocessing assets, the acoustic model and the vocoder.

        :param hparams: hyper-parameter mapping; must provide at least
            ``binary_data_dir`` and ``vocoder_ckpt`` (``hop_size`` is read
            later by ``run_vocoder``, ``N`` optionally by ``build_vocoder``).
        :param device: torch device string; defaults to ``'cuda'`` when CUDA
            is available, otherwise ``'cpu'``.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device
        self.data_dir = hparams['binary_data_dir']
        self.preprocessor, self.preprocess_args = load_data_preprocessor()
        self.ph_encoder = self.preprocessor.load_dict(self.data_dir)
        self.spk_map = self.preprocessor.load_spk_map(self.data_dir)
        self.ds_cls = FastSpeechWordDataset
        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        self.vocoder, self.diffusion_hyperparams, self.noise_schedule = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)

    def build_model(self):
        """Return the acoustic model; must be implemented by subclasses."""
        raise NotImplementedError

    def forward_model(self, inp):
        """Run inference on a preprocessed item; must be implemented by subclasses."""
        raise NotImplementedError

    def build_vocoder(self):
        """Build the FastDiff vocoder and its diffusion noise schedules.

        Reads ``<vocoder_ckpt>/config.yaml``, instantiates ``FastDiff`` from
        it and loads the ``'model'`` checkpoint from the same directory.

        :return: ``(vocoder, diffusion_hyperparams, noise_schedule)``.
            ``diffusion_hyperparams`` is always computed from the linear
            schedule ``linspace(beta_0, beta_T, T)`` in the config.
            ``noise_schedule`` is the schedule used for reverse sampling.
            It is taken from ``config['noise_schedule']`` when that is set and
            non-empty. Otherwise it is chosen from ``hparams['N']``.
        :raises NotImplementedError: if no schedule is configured and N is
            not one of 1000, 200, 8, 6, 4 or 3.
        """
        base_dir = self.hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        config = set_hparams(config_path, global_hparams=False)
        vocoder = FastDiff(audio_channels=config['audio_channels'],
                           inner_channels=config['inner_channels'],
                           cond_channels=config['cond_channels'],
                           upsample_ratios=config['upsample_ratios'],
                           lvc_layers_each_block=config['lvc_layers_each_block'],
                           lvc_kernel_size=config['lvc_kernel_size'],
                           kpnet_hidden_channels=config['kpnet_hidden_channels'],
                           kpnet_conv_size=config['kpnet_conv_size'],
                           dropout=config['dropout'],
                           diffusion_step_embed_dim_in=config['diffusion_step_embed_dim_in'],
                           diffusion_step_embed_dim_mid=config['diffusion_step_embed_dim_mid'],
                           diffusion_step_embed_dim_out=config['diffusion_step_embed_dim_out'],
                           use_weight_norm=config['use_weight_norm'])
        load_ckpt(vocoder, base_dir, 'model')

        # Init hyperparameters by linear schedule (independent of the
        # schedule chosen for reverse sampling below).
        linear_schedule = torch.linspace(float(config["beta_0"]), float(config["beta_T"]), int(config["T"]))
        diffusion_hyperparams = compute_hyperparams_given_schedule(linear_schedule)

        # A missing key is treated like an empty one: fall back to N.
        configured_schedule = config.get('noise_schedule', '')
        if configured_schedule != '':
            noise_schedule = configured_schedule
        else:
            noise_schedule = self._noise_schedule_for_steps(self._reverse_steps())
        if isinstance(noise_schedule, list):
            noise_schedule = torch.FloatTensor(noise_schedule)
        return vocoder, diffusion_hyperparams, noise_schedule

    def _reverse_steps(self):
        """Return ``int(hparams['N'])``, or 4 (with a warning) if it is missing or not an integer."""
        try:
            return int(self.hparams.get('N'))
        except (TypeError, ValueError):
            # TypeError: N is absent (None); ValueError: N is not parseable.
            print(
                'Please specify $N (the number of reverse iterations) in config file. Now denoise with 4 iterations.')
            return 4

    @classmethod
    def _noise_schedule_for_steps(cls, reverse_step):
        """Return the reverse-process noise schedule for ``reverse_step`` iterations.

        :return: a 1-D tensor for N in {1000, 200}, otherwise a list of floats
            copied from ``_PREDICTED_NOISE_SCHEDULES``.
        :raises NotImplementedError: for any other N.
        """
        if reverse_step == 1000:
            return torch.linspace(0.000001, 0.01, 1000)
        if reverse_step == 200:
            return torch.linspace(0.0001, 0.02, 200)
        try:
            # Copy so callers can never mutate the shared class-level table.
            return list(cls._PREDICTED_NOISE_SCHEDULES[reverse_step])
        except KeyError:
            raise NotImplementedError(
                f'No noise schedule available for N={reverse_step}; '
                f'supported: 1000, 200, {", ".join(map(str, cls._PREDICTED_NOISE_SCHEDULES))}') from None

    def run_vocoder(self, c):
        """Synthesise a waveform from acoustic features with the FastDiff vocoder.

        :param c: conditioning features. The last two axes are swapped before
            sampling, so this presumably is ``(batch, frames, mel_bins)``
            (TODO confirm). The output shape passed to the sampler has
            batch 1.
        :return: the result of ``sampling_given_noise_schedule`` for an output
            of shape ``(1, 1, frames * hop_size)``.
        """
        c = c.transpose(2, 1)
        # hop_size audio samples are generated per conditioning frame.
        audio_length = c.shape[-1] * self.hparams["hop_size"]
        y = sampling_given_noise_schedule(
            self.vocoder, (1, 1, audio_length), self.diffusion_hyperparams, self.noise_schedule,
            condition=c, ddim=False, return_sequence=False)
        return y

    def preprocess_input(self, inp):
        """Convert a raw text request into an encoded item.

        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
        :return: dict with keys ``item_name``, ``text``, ``ph``, ``spk_id``,
            ``ph_token`` and ``ph_len`` (the number of phoneme tokens).
        :raises KeyError: if ``spk_name`` (default ``'SPK1'``) is not in the
            speaker map.
        """
        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args
        text_raw = inp['text']
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'SPK1')
        ph, txt = preprocessor.txt_to_ph(
            preprocessor.txt_processor, text_raw, preprocess_args)
        ph_token = self.ph_encoder.encode(ph)
        spk_id = self.spk_map[spk_name]
        item = {'item_name': item_name, 'text': txt, 'ph': ph, 'spk_id': spk_id, 'ph_token': ph_token}
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        """Wrap one item from ``preprocess_input`` into a batch of size 1 on ``self.device``.

        :return: dict with list fields ``item_name``/``text``/``ph`` and
            tensors ``txt_tokens`` (1, T), ``txt_lengths`` (1,) and
            ``spk_ids`` (1,).
        """
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        # Wrap the scalar id in a list: torch.LongTensor(<int>) would allocate
        # an uninitialised tensor of that *size* instead of holding the id.
        spk_ids = torch.LongTensor([item['spk_id']]).to(self.device)
        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
        }
        return batch

    def postprocess_output(self, output):
        """Hook for subclasses to post-process model output; identity by default."""
        return output

    def infer_once(self, inp):
        """Run the full pipeline: ``preprocess_input`` -> ``forward_model`` -> ``postprocess_output``."""
        inp = self.preprocess_input(inp)
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output

    @classmethod
    def example_run(cls):
        """Synthesise ``hparams['text']`` with ``cls`` and save it under ``infer_out/``.

        Hyper-parameters are loaded via ``set_hparams()`` (presumably from the
        command line; verify against utils.hparams).
        """
        from utils.hparams import set_hparams
        from utils.hparams import hparams as hp
        from utils.audio import save_wav
        set_hparams()
        inp = {
            'text': hp['text']
        }
        infer_ins = cls(hp)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        # NOTE(review): the raw text is used verbatim as the file name; text
        # containing '/' or other path characters will yield an invalid path.
        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])