| import torch |
| from modules.GenerSpeech.model.glow_modules import Glow |
| from modules.fastspeech.tts_modules import PitchPredictor |
| import random |
| from modules.GenerSpeech.model.prosody_util import ProsodyAligner, LocalStyleAdaptor |
| from utils.pitch_utils import f0_to_coarse, denorm_f0 |
| from modules.commons.common_layers import * |
| import torch.distributions as dist |
| from utils.hparams import hparams |
| from modules.GenerSpeech.model.mixstyle import MixStyle |
| from modules.fastspeech.fs2 import FastSpeech2 |
| import json |
| from modules.fastspeech.tts_modules import DEFAULT_MAX_SOURCE_POSITIONS, DEFAULT_MAX_TARGET_POSITIONS |
|
|
class GenerSpeech(FastSpeech2):
    '''
    GenerSpeech: Towards Style Transfer for Generalizable Out-Of-Domain Text-to-Speech
    https://arxiv.org/abs/2205.07211
    '''
    def __init__(self, dictionary, out_dims=None):
        super().__init__(dictionary, out_dims)

        # Style-conditioned normalisation applied to the length-regulated decoder input.
        self.norm = MixStyle(p=0.5, alpha=0.1, eps=1e-6, hidden_size=self.hidden_size)

        # Projects the external emotion embedding into the model width.
        # NOTE(review): the emotion embedding is assumed to be 256-dim — confirm against the data pipeline.
        self.emo_embed_proj = Linear(256, self.hidden_size, bias=True)

        # Three prosody branches (utterance / phoneme / word level). Each has an extractor over the
        # reference mels, a projection of [embedding; positional encoding] back to hidden_size,
        # and an aligner that attends from the decoder input to the extracted prosody.
        self.prosody_extractor_utter = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
        self.l1_utter = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.align_utter = ProsodyAligner(num_layers=2)

        self.prosody_extractor_ph = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
        self.l1_ph = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.align_ph = ProsodyAligner(num_layers=2)

        self.prosody_extractor_word = LocalStyleAdaptor(self.hidden_size, hparams['nVQ'], self.padding_idx)
        self.l1_word = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.align_word = ProsodyAligner(num_layers=2)

        # Domain-specific pitch predictor; its output is added to the base-class pitch_predictor's
        # output in inpaint_pitch. odim=2 -> channel 0 is f0, channel 1 is the uv logit.
        self.pitch_inpainter_predictor = PitchPredictor(
            self.hidden_size, n_chans=self.hidden_size,
            n_layers=3, dropout_rate=0.1, odim=2,
            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])

        # Sinusoidal positions concatenated onto prosody embeddings before the l1_* projections.
        self.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        self.embed_positions = SinusoidalPositionalEmbedding(
            self.hidden_size, self.padding_idx,
            init_size=self.max_source_positions + self.padding_idx + 1,
        )

        # Post-flow conditioning width must match the tensor `g` assembled in run_post_glow:
        # 80 mel channels of the decoder output, optionally the decoder input (hidden_size),
        # plus spk_embed, emo_embed and ref_prosody (presumably hidden_size each).
        cond_hs = 80
        if hparams.get('use_txt_cond', True):
            cond_hs = cond_hs + hparams['hidden_size']
        cond_hs = cond_hs + hparams['hidden_size'] * 3
        self.post_flow = Glow(
            80, hparams['post_glow_hidden'], hparams['post_glow_kernel_size'], 1,
            hparams['post_glow_n_blocks'], hparams['post_glow_n_block_layers'],
            n_split=4, n_sqz=2,
            gin_channels=cond_hs,
            share_cond_layers=hparams['post_share_cond_layers'],
            share_wn_layers=hparams['share_wn_layers'],
            sigmoid_scale=hparams['sigmoid_scale']
        )
        self.prior_dist = dist.Normal(0, 1)

    def forward(self, txt_tokens, mel2ph=None, ref_mel2ph=None, ref_mel2word=None, spk_embed=None, emo_embed=None, ref_mels=None,
                f0=None, uv=None, skip_decoder=False, global_steps=0, infer=False, **kwargs):
        """Run the full GenerSpeech pipeline and return a dict of outputs and auxiliary losses.

        Args:
            txt_tokens: token ids; id 0 is treated as padding.
            mel2ph: frame-to-phoneme alignment handed to ``add_dur`` (base class), which returns the
                alignment actually used; 0 marks padded frames.
            ref_mel2ph / ref_mel2word: alignments of the reference mels, used by the phoneme- and
                word-level prosody extractors.
            spk_embed / emo_embed: external speaker / emotion embeddings (projected here).
            ref_mels: reference mels for prosody extraction; also the post-flow target when not ``infer``.
            f0 / uv: optional ground-truth pitch; predicted values are used when ``f0`` is None.
            skip_decoder: return right after computing ``ret['decoder_inp']``.
            global_steps: training step, gating VQ start and attention forcing.
            infer: inference mode (enables VQ and samples from the post-flow).

        Returns:
            ``ret`` dict populated by this method and the helpers it calls.
        """
        ret = {}
        encoder_out = self.encoder(txt_tokens)
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        # Project global style embeddings; [:, None, :] adds a broadcastable time axis.
        spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        emo_embed = self.emo_embed_proj(emo_embed)[:, None, :]

        # Duration prediction and length regulation.
        dur_inp = (encoder_out + spk_embed + emo_embed) * src_nonpadding
        mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
        tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
        decoder_inp = self.expand_states(encoder_out, mel2ph)
        decoder_inp = self.norm(decoder_inp, spk_embed + emo_embed)

        # Multi-level prosody extraction from the reference mels.
        ret['ref_mel2ph'] = ref_mel2ph
        ret['ref_mel2word'] = ref_mel2word
        prosody_utter_mel = self.get_prosody_utter(decoder_inp, ref_mels, ret, infer, global_steps)
        prosody_ph_mel = self.get_prosody_ph(decoder_inp, ref_mels, ret, infer, global_steps)
        prosody_word_mel = self.get_prosody_word(decoder_inp, ref_mels, ret, infer, global_steps)

        # Pitch: a domain-agnostic input (content only) and a domain-specific input (content + style).
        pitch_inp_domain_agnostic = decoder_inp * tgt_nonpadding
        pitch_inp_domain_specific = (decoder_inp + spk_embed + emo_embed + prosody_utter_mel + prosody_ph_mel + prosody_word_mel) * tgt_nonpadding
        predicted_pitch = self.inpaint_pitch(pitch_inp_domain_agnostic, pitch_inp_domain_specific, f0, uv, mel2ph, ret)

        # Decoder.
        decoder_inp = decoder_inp + spk_embed + emo_embed + predicted_pitch + prosody_utter_mel + prosody_ph_mel + prosody_word_mel
        ret['decoder_inp'] = decoder_inp = decoder_inp * tgt_nonpadding
        if skip_decoder:
            return ret
        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

        # Post-flow refinement; the values stashed in ret are read back by run_post_glow.
        is_training = self.training
        ret['x_mask'] = tgt_nonpadding
        ret['spk_embed'] = spk_embed
        ret['emo_embed'] = emo_embed
        ret['ref_prosody'] = prosody_utter_mel + prosody_ph_mel + prosody_word_mel
        self.run_post_glow(ref_mels, infer, is_training, ret)
        return ret

    def _get_prosody(self, level, extractor, proj, aligner, extractor_inputs, encoder_out, ret,
                     infer=False, global_steps=0):
        """Shared body of the utterance/phoneme/word prosody branches.

        Extracts a prosody embedding with ``extractor(*extractor_inputs, no_vq=...)``, appends
        sinusoidal positions, projects with ``proj`` and attends to it from ``encoder_out`` with
        ``aligner``. Writes ``gloss_<level>`` / ``attn_<level>`` (and ``vq_loss_<level>`` /
        ``ppl_<level>`` once VQ is active) into ``ret``. Returns the aligned prosody, batch-first.
        """
        # VQ is enabled after hparams['vq_start'] steps, and always at inference.
        if global_steps > hparams['vq_start'] or infer:
            prosody_embedding, vq_loss, ppl = extractor(*extractor_inputs, no_vq=False)
            ret[f'vq_loss_{level}'] = vq_loss
            ret[f'ppl_{level}'] = ppl
        else:
            prosody_embedding = extractor(*extractor_inputs, no_vq=True)

        positions = self.embed_positions(prosody_embedding[:, :, 0])
        prosody_embedding = proj(torch.cat([prosody_embedding, positions], dim=-1))

        # Padding masks: a position is treated as padding when its first channel equals padding_idx.
        # NOTE(review): for the projected prosody embedding this is an exact-float test — confirm intended.
        src_key_padding_mask = encoder_out[:, :, 0].eq(self.padding_idx).data
        prosody_key_padding_mask = prosody_embedding[:, :, 0].eq(self.padding_idx).data
        # The aligner takes time-first inputs; forcing is on for the first hparams['forcing'] steps.
        output, guided_loss, attn = aligner(
            encoder_out.transpose(0, 1), prosody_embedding.transpose(0, 1),
            src_key_padding_mask, prosody_key_padding_mask,
            forcing=global_steps < hparams['forcing'])
        ret[f'gloss_{level}'] = guided_loss
        ret[f'attn_{level}'] = attn
        return output.transpose(0, 1)

    def get_prosody_ph(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
        """Phoneme-level prosody branch (uses ``ret['ref_mel2ph']``); see ``_get_prosody``."""
        return self._get_prosody(
            'ph', self.prosody_extractor_ph, self.l1_ph, self.align_ph,
            (ref_mels, ret['ref_mel2ph']), encoder_out, ret, infer, global_steps)

    def get_prosody_word(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
        """Word-level prosody branch (uses ``ret['ref_mel2word']``); see ``_get_prosody``."""
        return self._get_prosody(
            'word', self.prosody_extractor_word, self.l1_word, self.align_word,
            (ref_mels, ret['ref_mel2word']), encoder_out, ret, infer, global_steps)

    def get_prosody_utter(self, encoder_out, ref_mels, ret, infer=False, global_steps=0):
        """Utterance-level prosody branch (no alignment passed to the extractor); see ``_get_prosody``."""
        return self._get_prosody(
            'utter', self.prosody_extractor_utter, self.l1_utter, self.align_utter,
            (ref_mels,), encoder_out, ret, infer, global_steps)

    def inpaint_pitch(self, pitch_inp_domain_agnostic, pitch_inp_domain_specific, f0, uv, mel2ph, ret):
        """Predict pitch as the sum of two predictors and return its embedding.

        ``self.pitch_predictor`` sees the domain-agnostic input; ``self.pitch_inpainter_predictor``
        sees the style-conditioned input. Ground-truth ``f0``/``uv`` are used when given, otherwise
        the prediction. Writes ``pitch_pred``, ``f0_denorm`` and ``f0_denorm_pred`` into ``ret``.
        """
        # Padding mask is only defined for frame-level pitch. It used to be left unbound for any
        # other pitch_type, which raised UnboundLocalError at the denorm_f0 call below.
        pitch_padding = (mel2ph == 0) if hparams['pitch_type'] == 'frame' else None
        if hparams['predictor_grad'] != 1:
            # Forward value is unchanged; gradient into the inputs is scaled by predictor_grad.
            pitch_inp_domain_agnostic = pitch_inp_domain_agnostic.detach() + hparams['predictor_grad'] * (
                pitch_inp_domain_agnostic - pitch_inp_domain_agnostic.detach())
            pitch_inp_domain_specific = pitch_inp_domain_specific.detach() + hparams['predictor_grad'] * (
                pitch_inp_domain_specific - pitch_inp_domain_specific.detach())

        pitch_domain_agnostic = self.pitch_predictor(pitch_inp_domain_agnostic)
        pitch_domain_specific = self.pitch_inpainter_predictor(pitch_inp_domain_specific)
        pitch_pred = pitch_domain_agnostic + pitch_domain_specific
        ret['pitch_pred'] = pitch_pred

        use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
        if f0 is None:
            # No ground truth: fall back to the predicted f0 (channel 0) and uv logit (channel 1).
            f0 = pitch_pred[:, :, 0]
            if use_uv:
                uv = pitch_pred[:, :, 1] > 0
        f0_denorm = denorm_f0(f0, uv if use_uv else None, hparams, pitch_padding=pitch_padding)
        pitch = f0_to_coarse(f0_denorm)
        ret['f0_denorm'] = f0_denorm
        ret['f0_denorm_pred'] = denorm_f0(
            pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None, hparams,
            pitch_padding=pitch_padding)
        if hparams['pitch_type'] == 'ph':
            # NOTE(review): the predictor inputs here are frame-level (length-regulated), so gathering
            # by mel2ph appears to assume phoneme-level pitch — confirm this path is intended.
            pitch = torch.gather(F.pad(pitch, [1, 0]), 1, mel2ph)
            ret['f0_denorm'] = torch.gather(F.pad(ret['f0_denorm'], [1, 0]), 1, mel2ph)
            ret['f0_denorm_pred'] = torch.gather(F.pad(ret['f0_denorm_pred'], [1, 0]), 1, mel2ph)
        pitch_embed = self.pitch_embed(pitch)
        return pitch_embed

    def run_post_glow(self, tgt_mels, infer, is_training, ret):
        """Run the post-net Glow.

        Training (``infer`` False): encodes ``tgt_mels`` and stores the flow NLL in
        ``ret['postflow']`` (plus ``z_pf`` / ``ldj_pf``). Inference: samples a scaled Gaussian,
        inverts the flow and overwrites ``ret['mel_out']``. The conditioning ``g`` is the decoder
        output, optionally the decoder input, and the spk/emo/prosody embeddings stacked on channels.
        """
        x_recon = ret['mel_out'].transpose(1, 2)
        g = x_recon
        T = g.shape[-1]
        if hparams.get('use_txt_cond', True):
            g = torch.cat([g, ret['decoder_inp'].transpose(1, 2)], 1)
        # Broadcast the per-utterance style vectors over all T frames.
        g_spk_embed = ret['spk_embed'].repeat(1, T, 1).transpose(1, 2)
        g_emo_embed = ret['emo_embed'].repeat(1, T, 1).transpose(1, 2)
        l_ref_prosody = ret['ref_prosody'].transpose(1, 2)
        g = torch.cat([g, g_spk_embed, g_emo_embed, l_ref_prosody], dim=1)
        prior_dist = self.prior_dist
        if not infer:
            if is_training:
                self.train()
            x_mask = ret['x_mask'].transpose(1, 2)
            # Per-utterance frame counts. Summing over channel and time gives a 1-D [B] tensor.
            # The old `.sum(-1)` kept a trailing dim ([B, 1]); dividing a per-utterance log-det by it
            # broadcast to [B, B] and mixed samples. NOTE(review): assumes post_flow returns a
            # per-utterance ldj of shape [B].
            y_lengths = x_mask.sum([1, 2])
            # The flow loss does not backpropagate into the conditioning network.
            g = g.detach()
            tgt_mels = tgt_mels.transpose(1, 2)
            z_postflow, ldj = self.post_flow(tgt_mels, x_mask, g=g)
            ldj = ldj / y_lengths / 80
            ret['z_pf'], ret['ldj_pf'] = z_postflow, ldj
            ret['postflow'] = -prior_dist.log_prob(z_postflow).mean() - ldj.mean()
        else:
            # NOTE(review): an all-ones mask ignores padding within a batch at inference.
            x_mask = torch.ones_like(x_recon[:, :1, :])
            z_post = prior_dist.sample(x_recon.shape).to(g.device) * hparams['noise_scale']
            x_recon_, _ = self.post_flow(z_post, x_mask, g, reverse=True)
            x_recon = x_recon_
            ret['mel_out'] = x_recon.transpose(1, 2)