Spaces:
Runtime error
Runtime error
| import torch | |
| from inference.base_tts_infer import BaseTTSInfer | |
| from utils.ckpt_utils import load_ckpt, get_last_checkpoint | |
| from utils.hparams import hparams | |
| from modules.ProDiff.model.ProDiff_teacher import GaussianDiffusion | |
| from usr.diff.net import DiffNet | |
| import os | |
| import numpy as np | |
class ProDiffTeacherInfer(BaseTTSInfer):
    """Inference wrapper for the ProDiff teacher model.

    ``build_model`` constructs a ``GaussianDiffusion`` model whose denoiser is a
    ``DiffNet`` and loads its weights from ``hparams['work_dir']``.
    ``forward_model`` turns one input into a waveform by running the model and
    then ``self.run_vocoder`` on the predicted mel output.
    """

    def build_model(self):
        """Construct the teacher diffusion model and load its checkpoint.

        Side effect: if ``{binary_data_dir}/train_f0s_mean_std.npy`` exists,
        its two stored values are written to the global ``hparams`` as
        ``'f0_mean'`` and ``'f0_std'`` (converted to Python ``float``).

        Returns:
            The ``GaussianDiffusion`` model, switched to eval mode, with the
            ``'model'`` checkpoint from ``hparams['work_dir']`` loaded.
        """
        f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
        if os.path.exists(f0_stats_fn):
            # The file holds exactly two values: mean and std.
            f0_mean, f0_std = np.load(f0_stats_fn)
            # Store plain Python floats rather than numpy scalars.
            hparams['f0_mean'] = float(f0_mean)
            hparams['f0_std'] = float(f0_std)

        # DiffNet is constructed from audio_num_mel_bins, so derive out_dims
        # from the same hparam instead of hard-coding 80; the two presumably
        # must agree for the denoiser output to match the diffusion target.
        num_mel_bins = hparams['audio_num_mel_bins']
        model = GaussianDiffusion(
            phone_encoder=self.ph_encoder,
            out_dims=num_mel_bins,
            denoise_fn=DiffNet(num_mel_bins),
            timesteps=hparams['timesteps'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        model.eval()
        load_ckpt(model, hparams['work_dir'], 'model')
        return model

    def forward_model(self, inp):
        """Synthesise a waveform for a single input.

        Args:
            inp: Raw input accepted by ``self.input_to_batch``.

        Returns:
            numpy.ndarray: the vocoder output with size-1 dimensions squeezed,
            moved to CPU.
        """
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']  # [B, T_t]
        # Inference only: skip autograd bookkeeping for both model and vocoder.
        with torch.no_grad():
            output = self.model(txt_tokens, infer=True)
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        return wav_out.squeeze().cpu().numpy()
if __name__ == "__main__":
    # Script entry point: hand control to the class-level `example_run`
    # (not defined in this class; presumably inherited from BaseTTSInfer).
    ProDiffTeacherInfer.example_run()