| import torch
|
|
|
|
|
| from inference.svs.base_svs_infer import BaseSVSInfer
|
| from utils import load_ckpt
|
| from utils.hparams import hparams
|
| from usr.diff.shallow_diffusion_tts import GaussianDiffusion
|
| from usr.diffsinger_task import DIFF_DECODERS
|
|
|
class DiffSingerCascadeInfer(BaseSVSInfer):
    """Inference wrapper for the cascade DiffSinger model.

    Builds a ``GaussianDiffusion`` acoustic model from the global ``hparams``,
    runs it on a prepared input batch, and vocodes the predicted mel
    spectrogram (conditioned on the model's predicted F0) into a waveform.
    """

    def build_model(self):
        """Construct the diffusion model and load its trained weights.

        The denoiser class is looked up in ``DIFF_DECODERS`` by
        ``hparams['diff_decoder_type']`` and instantiated with ``hparams``.
        All other constructor arguments are read from ``hparams`` as well.
        Weights are loaded with ``load_ckpt`` from ``hparams['work_dir']``
        under the ``'model'`` prefix.

        Returns:
            The loaded ``GaussianDiffusion`` model, in eval mode.
        """
        denoise_fn = DIFF_DECODERS[hparams['diff_decoder_type']](hparams)
        model = GaussianDiffusion(
            # NOTE(review): presumably initialised by BaseSVSInfer before
            # build_model is called — TODO confirm.
            phone_encoder=self.ph_encoder,
            out_dims=hparams['audio_num_mel_bins'],
            denoise_fn=denoise_fn,
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        load_ckpt(model, hparams['work_dir'], 'model')
        # Switch to inference mode once the weights are in place, so the
        # returned model is ready for sampling.
        model.eval()
        return model

    def forward_model(self, inp):
        """Synthesize a waveform for a single input item.

        Args:
            inp: Raw input item. It is converted to a batch by
                ``self.input_to_batch``. That batch must provide
                ``'txt_tokens'``, ``'pitch_midi'``, ``'midi_dur'`` and
                ``'is_slur'``; ``'spk_ids'`` is optional.

        Returns:
            The first waveform in the batch, as a NumPy array.
        """
        sample = self.input_to_batch(inp)
        txt_tokens = sample['txt_tokens']
        # None when the batch carries no speaker ids.
        spk_id = sample.get('spk_ids')
        # Run both the acoustic model and the vocoder without autograd
        # tracking. Nothing here is trained, and .numpy() refuses tensors
        # that require grad.
        with torch.no_grad():
            output = self.model(
                txt_tokens, spk_id=spk_id, ref_mels=None, infer=True,
                pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'],
                is_slur=sample['is_slur'],
            )
            mel_out = output['mel_out']
            f0_pred = output['f0_denorm']
            wav_out = self.run_vocoder(mel_out, f0=f0_pred)
        # detach() is redundant if the vocoder honours no_grad. It is kept
        # as a cheap guard so .numpy() cannot fail on a grad-tracking tensor.
        wav_out = wav_out.detach().cpu().numpy()
        return wav_out[0]
|
|
|
|
|
if __name__ == '__main__':
    # Word-level demo input. `notes` and `notes_duration` hold one
    # '|'-separated group per lyric unit of `text` (14 units; 'AP' counts as
    # one). Groups with two notes line up with the is_slur=1 positions in the
    # phoneme-level version below.
    word_level_inp = {
        'text': '小酒窝长睫毛AP是你最美的记号',
        'notes': 'C#4/Db4 | F#4/Gb4 | G#4/Ab4 | A#4/Bb4 F#4/Gb4 | F#4/Gb4 C#4/Db4 | C#4/Db4 | rest | C#4/Db4 | A#4/Bb4 | G#4/Ab4 | A#4/Bb4 | G#4/Ab4 | F4 | C#4/Db4',
        'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
        'input_type': 'word',
    }

    # The same phrase spelled out per phoneme: ph_seq, note_seq, note_dur_seq
    # and is_slur_seq each have 29 space-separated entries.
    # NOTE(review): this dict is built but never run. Pass it to example_run
    # instead of word_level_inp to try the 'phoneme' input_type — presumably
    # supported by BaseSVSInfer; TODO confirm.
    phoneme_level_inp = {
        'text': '小酒窝长睫毛AP是你最美的记号',
        'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
        'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
        'note_dur_seq': '0.407140 0.407140 0.376190 0.376190 0.242180 0.242180 0.509550 0.509550 0.183420 0.315400 0.315400 0.235020 0.361660 0.361660 0.223070 0.377270 0.377270 0.340550 0.340550 0.299620 0.299620 0.344510 0.344510 0.283770 0.283770 0.323390 0.323390 0.360340 0.360340',
        'is_slur_seq': '0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0',
        'input_type': 'phoneme',
    }

    DiffSingerCascadeInfer.example_run(word_level_inp)
|
|
|