Spaces:
Runtime error
Runtime error
| """ | |
| Taken from ESPNet | |
| """ | |
| from abc import ABC | |
| import torch | |
| from Layers.Conformer import Conformer | |
| from Layers.DurationPredictor import DurationPredictor | |
| from Layers.LengthRegulator import LengthRegulator | |
| from Layers.PostNet import PostNet | |
| from Layers.VariancePredictor import VariancePredictor | |
| from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.FastSpeech2Loss import FastSpeech2Loss | |
| from Utility.SoftDTW.sdtw_cuda_loss import SoftDTW | |
| from Utility.utils import initialize | |
| from Utility.utils import make_non_pad_mask | |
| from Utility.utils import make_pad_mask | |
class FastSpeech2(torch.nn.Module, ABC):
    """
    FastSpeech 2 module.

    This is a module of FastSpeech 2 described in FastSpeech 2: Fast and
    High-Quality End-to-End Text to Speech. Instead of quantized pitch and
    energy, we use token-averaged value introduced in FastPitch: Parallel
    Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers
    instead of regular Transformers.

    https://arxiv.org/abs/2006.04558
    https://arxiv.org/abs/2006.06873
    https://arxiv.org/pdf/2005.08100
    """

    def __init__(self,
                 # network structure related
                 idim=66,
                 odim=80,
                 adim=384,
                 aheads=4,
                 elayers=6,
                 eunits=1536,
                 dlayers=6,
                 dunits=1536,
                 postnet_layers=5,
                 postnet_chans=256,
                 postnet_filts=5,
                 positionwise_layer_type="conv1d",
                 positionwise_conv_kernel_size=1,
                 use_scaled_pos_enc=True,
                 use_batch_norm=True,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 encoder_concat_after=False,
                 decoder_concat_after=False,
                 reduction_factor=1,
                 # encoder / decoder
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,
                 conformer_enc_kernel_size=7,
                 conformer_dec_kernel_size=31,
                 # duration predictor
                 duration_predictor_layers=2,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 # energy predictor
                 energy_predictor_layers=2,
                 energy_predictor_chans=256,
                 energy_predictor_kernel_size=3,
                 energy_predictor_dropout=0.5,
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,
                 stop_gradient_from_energy_predictor=False,
                 # pitch predictor
                 pitch_predictor_layers=5,
                 pitch_predictor_chans=256,
                 pitch_predictor_kernel_size=5,
                 pitch_predictor_dropout=0.5,
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,
                 stop_gradient_from_pitch_predictor=True,
                 # training related
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,
                 duration_predictor_dropout_rate=0.2,
                 postnet_dropout_rate=0.5,
                 init_type="xavier_uniform",
                 init_enc_alpha=1.0,
                 init_dec_alpha=1.0,
                 use_masking=False,
                 use_weighted_masking=True,
                 # additional features
                 use_dtw_loss=False,
                 utt_embed_dim=704,
                 connect_utt_emb_at_encoder_out=True,
                 lang_embs=100):
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.use_dtw_loss = use_dtw_loss
        self.eos = 1
        self.reduction_factor = reduction_factor
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        # passing None for lang_embs / utt_embed_dim disables the respective conditioning
        self.multilingual_model = lang_embs is not None
        self.multispeaker_model = utt_embed_dim is not None

        # define encoder
        # articulatory feature vectors are first projected into the attention dimension
        embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
                                    torch.nn.Tanh(),
                                    torch.nn.Linear(100, adim))
        self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
                                 input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
                                 utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)

        # define duration predictor
        self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans,
                                                    kernel_size=duration_predictor_kernel_size, dropout_rate=duration_predictor_dropout_rate, )

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers, n_chans=pitch_predictor_chans,
                                                 kernel_size=pitch_predictor_kernel_size, dropout_rate=pitch_predictor_dropout)
        # continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=pitch_embed_kernel_size, padding=(pitch_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(pitch_embed_dropout))

        # define energy predictor
        self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers, n_chans=energy_predictor_chans,
                                                  kernel_size=energy_predictor_kernel_size, dropout_rate=energy_predictor_dropout)
        # continuous energy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=adim, kernel_size=energy_embed_kernel_size, padding=(energy_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(energy_embed_dropout))

        # define length regulator
        self.length_regulator = LengthRegulator()

        # the decoder consumes already-embedded frames, so it needs no input layer (idim=0)
        self.decoder = Conformer(idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate, normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after, positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_dec_kernel_size)

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = PostNet(idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)

        # initialize parameters
        self._reset_parameters(init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha)

        # define criterions
        self.criterion = FastSpeech2Loss(use_masking=use_masking, use_weighted_masking=use_weighted_masking)
        self.dtw_criterion = SoftDTW(use_cuda=True, gamma=0.1)

    def forward(self,
                text_tensors,
                text_lengths,
                gold_speech,
                speech_lengths,
                gold_durations,
                gold_pitch,
                gold_energy,
                utterance_embedding,
                return_mels=False,
                lang_ids=None):
        """
        Calculate forward propagation.

        Args:
            return_mels: whether to return the predicted spectrogram
            text_tensors (LongTensor): Batch of padded text vectors (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            gold_speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1).
            gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1).
            gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1).
            utterance_embedding (Tensor): Batch of utterance embeddings for multispeaker conditioning.
            lang_ids (LongTensor, optional): Batch of language IDs for multilingual conditioning.

        Returns:
            Tensor: Loss scalar value (or a (loss, mels) tuple if return_mels is set).
        """
        # Texts include EOS token from the teacher model already in this version

        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(text_tensors, text_lengths, gold_speech, speech_lengths,
                                                                        gold_durations, gold_pitch, gold_energy, utterance_embedding=utterance_embedding,
                                                                        is_inference=False, lang_ids=lang_ids)

        # modify mod part of groundtruth (speaking pace)
        if self.reduction_factor > 1:
            speech_lengths = speech_lengths.new([olen - olen % self.reduction_factor for olen in speech_lengths])

        # calculate loss
        l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, p_outs=p_outs,
                                                                         e_outs=e_outs, ys=gold_speech, ds=gold_durations, ps=gold_pitch, es=gold_energy,
                                                                         ilens=text_lengths, olens=speech_lengths)
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        if self.use_dtw_loss:
            # soft-DTW allows comparing sequences that are not perfectly aligned in time
            dtw_loss = self.dtw_criterion(after_outs, gold_speech).mean() / 2000.0  # division to balance orders of magnitude
            loss = loss + dtw_loss

        if return_mels:
            return loss, after_outs
        return loss

    def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
                 gold_durations=None, gold_pitch=None, gold_energy=None,
                 is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
        """
        Shared forward pass for both training and inference.

        During training the gold durations/pitch/energy are used to expand the
        encoder outputs; during inference the model's own predictions are used.
        """
        # drop conditioning signals the model was not built to use
        if not self.multilingual_model:
            lang_ids = None

        if not self.multispeaker_model:
            utterance_embedding = None

        # forward encoder
        text_masks = self._source_mask(text_lens)

        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)  # (B, Tmax, adim)

        # forward duration predictor and variance predictors
        d_masks = make_pad_mask(text_lens, device=text_lens.device)

        # optionally detach so the variance predictors' losses don't influence the encoder
        if self.stop_gradient_from_pitch_predictor:
            pitch_predictions = self.pitch_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
        else:
            pitch_predictions = self.pitch_predictor(encoded_texts, d_masks.unsqueeze(-1))

        if self.stop_gradient_from_energy_predictor:
            energy_predictions = self.energy_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
        else:
            energy_predictions = self.energy_predictor(encoded_texts, d_masks.unsqueeze(-1))

        if is_inference:
            d_outs = self.duration_predictor.inference(encoded_texts, d_masks)  # (B, Tmax)
            # use prediction in inference
            p_embs = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, d_outs, alpha)  # (B, Lmax, adim)
        else:
            d_outs = self.duration_predictor(encoded_texts, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, gold_durations)  # (B, Lmax, adim)

        # forward decoder
        if speech_lens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
            else:
                olens_in = speech_lens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None

        zs, _ = self.decoder(encoded_texts, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs, pitch_predictions, energy_predictions

    def batch_inference(self, texts, text_lens, utt_emb):
        """
        Run inference on a whole batch of padded texts at once.

        Args:
            texts (LongTensor): Batch of padded text vectors (B, Tmax).
            text_lens (LongTensor): Batch of lengths of each input (B,).
            utt_emb (Tensor): Batch of utterance embeddings (B, utt_embed_dim).

        Returns:
            Tuple[Tensor, Tensor]: Predicted spectrograms and durations.
        """
        # FIX: previously utt_emb was accepted but never forwarded, so batch
        # inference silently ignored the speaker conditioning.
        _, after_outs, d_outs, _, _ = self._forward(texts,
                                                    text_lens,
                                                    None,
                                                    is_inference=True,
                                                    alpha=1.0,
                                                    utterance_embedding=utt_emb)
        return after_outs, d_outs

    def inference(self,
                  text,
                  speech=None,
                  durations=None,
                  pitch=None,
                  energy=None,
                  alpha=1.0,
                  use_teacher_forcing=False,
                  utterance_embedding=None,
                  return_duration_pitch_energy=False,
                  lang_id=None):
        """
        Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T,).
            speech (Tensor, optional): Feature sequence to extract style (N, idim).
            durations (LongTensor, optional): Groundtruth of duration (T + 1,).
            pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1).
            energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1).
            alpha (float, optional): Alpha to control the speed.
            use_teacher_forcing (bool, optional): Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.
            utterance_embedding (Tensor, optional): Utterance embedding (utt_embed_dim,).
            return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
            lang_id (LongTensor, optional): Language ID for multilingual conditioning.

        Returns:
            Tensor: Output sequence of features (L, odim).
        """
        self.eval()
        x, y = text, speech
        d, p, e = durations, pitch, energy

        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs, ys = x.unsqueeze(0), None
        if y is not None:
            ys = y.unsqueeze(0)
        if lang_id is not None:
            lang_id = lang_id.unsqueeze(0)
        # FIX: guard against None so single-speaker models (no embedding supplied)
        # no longer crash with an AttributeError on unsqueeze
        batched_utt_emb = utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
            before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
                                                                                                   ilens,
                                                                                                   ys,
                                                                                                   gold_durations=ds,
                                                                                                   gold_pitch=ps,
                                                                                                   gold_energy=es,
                                                                                                   utterance_embedding=batched_utt_emb,
                                                                                                   lang_ids=lang_id)  # (1, L, odim)
        else:
            before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(xs,
                                                                                                   ilens,
                                                                                                   ys,
                                                                                                   is_inference=True,
                                                                                                   alpha=alpha,
                                                                                                   utterance_embedding=batched_utt_emb,
                                                                                                   lang_ids=lang_id)  # (1, L, odim)
        self.train()
        if return_duration_pitch_energy:
            return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
        return after_outs[0]

    def _source_mask(self, ilens):
        """
        Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
        """
        x_masks = make_non_pad_mask(ilens, device=ilens.device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type, init_enc_alpha, init_dec_alpha):
        # initialize parameters; "pytorch" keeps the framework defaults
        if init_type != "pytorch":
            initialize(self, init_type)