| import torch
|
|
|
| import utils
|
| from utils.hparams import hparams
|
| from .diff.net import DiffNet
|
| from .diff.shallow_diffusion_tts import GaussianDiffusion
|
| from .task import DiffFsTask
|
| from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
|
| from utils.pitch_utils import denorm_f0
|
| from tasks.tts.fs2_utils import FastSpeechDataset
|
|
|
# Registry of denoiser factories for GaussianDiffusion, keyed by the
# ``diff_decoder_type`` hparam. Each factory takes the hparams mapping and
# returns the object passed to GaussianDiffusion as ``denoise_fn``.
DIFF_DECODERS = dict(
    wavenet=lambda hp: DiffNet(hp['audio_num_mel_bins']),
)
|
|
|
|
|
class DiffSpeechTask(DiffFsTask):
    """Task that trains and validates the ``GaussianDiffusion`` (DiffSpeech) model.

    The diffusion model owns an auxiliary FastSpeech2 branch (``model.fs2``)
    that can be warm-started from a pretrained checkpoint. During validation,
    vocoded audio and mel plots are logged for the first few batches.
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = FastSpeechDataset
        # Vocoder implementation is selected from hparams; within this class it
        # is only used by ``plot_wav`` to turn mel spectrograms into audio.
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()

    @staticmethod
    def _speaker_input(sample):
        """Return the speaker conditioning for a batch.

        Uses ``sample['spk_ids']`` when ``hparams['use_spk_id']`` is set and
        ``sample['spk_embed']`` otherwise. Either key may be absent, in which
        case None is returned.
        """
        if hparams['use_spk_id']:
            return sample.get('spk_ids')
        return sample.get('spk_embed')

    def build_tts_model(self):
        """Build ``self.model`` and optionally warm-start its FastSpeech2 branch.

        When ``hparams['fs2_ckpt']`` is set, the FastSpeech2 weights are loaded
        strictly from that checkpoint. Every FastSpeech2 parameter whose name
        does not contain ``'predictor'`` is then frozen. The remaining
        trainable FastSpeech2 parameters are presumably the variance predictors
        (verify against the fs2 module).
        """
        mel_bins = hparams['audio_num_mel_bins']
        self.model = GaussianDiffusion(
            phone_encoder=self.phone_encoder,
            out_dims=mel_bins,
            denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
            timesteps=hparams['timesteps'],
            K_step=hparams['K_step'],
            loss_type=hparams['diff_loss_type'],
            spec_min=hparams['spec_min'],
            spec_max=hparams['spec_max'],
        )
        if hparams['fs2_ckpt']:
            utils.load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
            # Freezing only makes sense for pretrained weights. Keep parameters
            # with 'predictor' in their name trainable and freeze the rest.
            for param_name, param in self.model.fs2.named_parameters():
                if 'predictor' not in param_name:
                    param.requires_grad = False

    def build_optimizer(self, model):
        """Create an AdamW optimizer over the trainable parameters of ``model``.

        Frozen parameters (``requires_grad == False``) are excluded. The
        optimizer is stored on ``self.optimizer`` and also returned.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        return optimizer

    def run_model(self, model, sample, return_output=False, infer=False):
        """Run ``model`` on a batch and collect the training losses.

        Args:
            model: the model to run (normally ``self.model``).
            sample: batch mapping. It must provide 'txt_tokens', 'mels',
                'mel2ph', 'f0', 'uv' and 'energy'. When
                ``hparams['pitch_type'] == 'cwt'`` it must also provide
                'cwt_spec', 'f0_mean' and 'f0_std'.
            return_output: when True, also return the model's output mapping.
            infer: forwarded unchanged to the model call.

        Returns:
            ``losses`` (a dict of named losses), or ``(losses, output)`` when
            ``return_output`` is True.

        Side effects:
            In CWT pitch mode, stores the f0 produced by
            ``model.cwt2f0_norm`` into ``sample['f0_cwt']``. That same f0 is
            the one fed to the model.
        """
        txt_tokens = sample['txt_tokens']
        target = sample['mels']
        mel2ph = sample['mel2ph']
        f0 = sample['f0']
        uv = sample['uv']
        energy = sample['energy']
        spk_embed = self._speaker_input(sample)

        if hparams['pitch_type'] == 'cwt':
            cwt_spec = sample['cwt_spec']
            f0_mean = sample['f0_mean']
            f0_std = sample['f0_std']
            sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)

        output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
                       ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)

        losses = {}
        # The diffusion loss is only present when the model returns it.
        if 'diff_loss' in output:
            losses['mel'] = output['diff_loss']
        self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], energy, losses)

        if return_output:
            return losses, output
        return losses

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and log audio/mel samples for early batches.

        Returns:
            A dict with 'losses', 'total_loss' (sum of all losses) and
            'nsamples', passed through ``utils.tensors_to_scalars``.
        """
        losses = self.run_model(self.model, sample, infer=False)
        outputs = {
            'losses': losses,
            'total_loss': sum(losses.values()),
            'nsamples': sample['nsamples'],
        }
        outputs = utils.tensors_to_scalars(outputs)

        if batch_idx < hparams['num_valid_plots']:
            # Inference pass conditioned on ground-truth durations, pitch and
            # energy, with no reference mel, to render samples for logging.
            model_out = self.model(
                sample['txt_tokens'], spk_embed=self._speaker_input(sample),
                mel2ph=sample['mel2ph'], f0=sample['f0'], uv=sample['uv'],
                energy=sample['energy'], ref_mels=None, infer=True)
            gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
            self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True,
                          gt_f0=gt_f0, f0=model_out.get('f0_denorm'))
            self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'])
        return outputs

    def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
        """Log ground-truth and generated audio for the first item of a batch.

        Args:
            batch_idx: used in the audio tags 'gt_{batch_idx}' / 'wav_{batch_idx}'.
            gt_wav: batched ground-truth input; only item 0 is used.
            wav_out: batched generated input; only item 0 is used. When
                ``is_mel`` is True, both inputs are spectrograms and are
                converted with ``self.vocoder.spec2wav`` before logging.
            is_mel: whether the inputs are spectrograms rather than waveforms.
            gt_f0: optional batched f0 passed to the vocoder for ``gt_wav``.
                May be None.
            f0: optional batched f0 passed to the vocoder for ``wav_out``.
                May be None.
            name: unused; kept for interface compatibility.
        """
        gt_wav = gt_wav[0].cpu().numpy()
        wav_out = wav_out[0].cpu().numpy()
        # Both f0 arguments default to None, so guard each one before indexing.
        gt_f0 = gt_f0[0].cpu().numpy() if gt_f0 is not None else None
        f0 = f0[0].cpu().numpy() if f0 is not None else None
        if is_mel:
            gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
            wav_out = self.vocoder.spec2wav(wav_out, f0=f0)

        writer = self.logger.experiment
        sample_rate = hparams['audio_sample_rate']
        writer.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=sample_rate, global_step=self.global_step)
        writer.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=sample_rate, global_step=self.global_step)
|
|
|
|
|