# FastSpeech2 model definition (Conformer encoder/decoder variant).
# NOTE: the original paste carried rendering artifacts ("Spaces:" / "Runtime error")
# above this file; they were not part of the source code.
| from abc import ABC | |
| import torch | |
| from Layers.Conformer import Conformer | |
| from Layers.DurationPredictor import DurationPredictor | |
| from Layers.LengthRegulator import LengthRegulator | |
| from Layers.PostNet import PostNet | |
| from Layers.VariancePredictor import VariancePredictor | |
| from Utility.utils import make_non_pad_mask | |
| from Utility.utils import make_pad_mask | |
class FastSpeech2(torch.nn.Module, ABC):
    """
    FastSpeech 2 text-to-speech acoustic model (Conformer variant).

    Maps a sequence of input symbols to a mel spectrogram.  The pipeline is:
    Conformer encoder -> duration / pitch / energy predictors -> length
    regulator (upsampling to frame rate) -> Conformer decoder -> linear
    projection -> residual PostNet refinement.

    Pitch and energy are predicted per input token and added to the encoder
    output as learned embeddings before length regulation, following the
    FastSpeech 2 paper.  The model is constructed directly from a pretrained
    ``state_dict`` (``weights``) and is intended for inference.
    """

    def __init__(self,  # network structure related
                 weights,
                 idim=66,
                 odim=80,
                 adim=384,
                 aheads=4,
                 elayers=6,
                 eunits=1536,
                 dlayers=6,
                 dunits=1536,
                 postnet_layers=5,
                 postnet_chans=256,
                 postnet_filts=5,
                 positionwise_conv_kernel_size=1,
                 use_scaled_pos_enc=True,
                 use_batch_norm=True,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 encoder_concat_after=False,
                 decoder_concat_after=False,
                 reduction_factor=1,
                 # encoder / decoder
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,
                 conformer_enc_kernel_size=7,
                 conformer_dec_kernel_size=31,
                 # duration predictor
                 duration_predictor_layers=2,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 # energy predictor
                 energy_predictor_layers=2,
                 energy_predictor_chans=256,
                 energy_predictor_kernel_size=3,
                 energy_predictor_dropout=0.5,
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,
                 stop_gradient_from_energy_predictor=True,
                 # pitch predictor
                 pitch_predictor_layers=5,
                 pitch_predictor_chans=256,
                 pitch_predictor_kernel_size=5,
                 pitch_predictor_dropout=0.5,
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,
                 stop_gradient_from_pitch_predictor=True,
                 # training related
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,
                 duration_predictor_dropout_rate=0.2,
                 postnet_dropout_rate=0.5,
                 # additional features
                 utt_embed_dim=704,
                 connect_utt_emb_at_encoder_out=True,
                 lang_embs=100):
        """
        Build the network and load the pretrained parameters.

        Args:
            weights: state_dict with pretrained parameters; loaded at the end
                of construction, so all shape-determining arguments must match
                the checkpoint.
            idim: dimensionality of the input symbol representation.
            odim: number of mel bins in the output spectrogram.
            adim: attention (hidden) dimension used throughout the model.
            reduction_factor: number of frames produced per decoder step.
            utt_embed_dim: size of the utterance-wide style embedding.
            lang_embs: number of language-embedding entries in the encoder.

        The remaining arguments configure the sub-modules (Conformer
        encoder/decoder, variance predictors, PostNet) and mirror the
        FastSpeech 2 / ESPnet hyperparameter names.
        """
        super().__init__()
        self.idim = idim
        self.odim = odim
        self.reduction_factor = reduction_factor
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        # Symbol-embedding front-end fed into the encoder as its input layer.
        embed = torch.nn.Sequential(torch.nn.Linear(idim, 100),
                                    torch.nn.Tanh(),
                                    torch.nn.Linear(100, adim))
        self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers,
                                 input_layer=embed, dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before, concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=False,
                                 utt_embed=utt_embed_dim, connect_utt_emb_at_encoder_out=connect_utt_emb_at_encoder_out, lang_embs=lang_embs)
        self.duration_predictor = DurationPredictor(idim=adim, n_layers=duration_predictor_layers,
                                                    n_chans=duration_predictor_chans,
                                                    kernel_size=duration_predictor_kernel_size,
                                                    dropout_rate=duration_predictor_dropout_rate, )
        self.pitch_predictor = VariancePredictor(idim=adim, n_layers=pitch_predictor_layers,
                                                 n_chans=pitch_predictor_chans,
                                                 kernel_size=pitch_predictor_kernel_size,
                                                 dropout_rate=pitch_predictor_dropout)
        # Continuous pitch scalar -> adim-dimensional embedding (1x1 conv by default).
        self.pitch_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
                                                               kernel_size=pitch_embed_kernel_size,
                                                               padding=(pitch_embed_kernel_size - 1) // 2),
                                               torch.nn.Dropout(pitch_embed_dropout))
        self.energy_predictor = VariancePredictor(idim=adim, n_layers=energy_predictor_layers,
                                                  n_chans=energy_predictor_chans,
                                                  kernel_size=energy_predictor_kernel_size,
                                                  dropout_rate=energy_predictor_dropout)
        # Continuous energy scalar -> adim-dimensional embedding.
        self.energy_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1, out_channels=adim,
                                                                kernel_size=energy_embed_kernel_size,
                                                                padding=(energy_embed_kernel_size - 1) // 2),
                                                torch.nn.Dropout(energy_embed_dropout))
        # Upsamples token-level features to frame level according to durations.
        self.length_regulator = LengthRegulator()
        # Decoder is a Conformer operating directly on the regulated features
        # (idim=0 / input_layer=None means no embedding front-end).
        self.decoder = Conformer(idim=0,
                                 attention_dim=adim,
                                 attention_heads=aheads,
                                 linear_units=dunits,
                                 num_blocks=dlayers,
                                 input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate,
                                 positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate,
                                 normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_dec_kernel_size)
        # Projects decoder states to (reduction_factor) mel frames per step.
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
        self.postnet = PostNet(idim=idim,
                               odim=odim,
                               n_layers=postnet_layers,
                               n_chans=postnet_chans,
                               n_filts=postnet_filts,
                               use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)
        self.load_state_dict(weights)

    def _forward(self, text_tensors, text_lens, gold_speech=None, speech_lens=None,
                 gold_durations=None, gold_pitch=None, gold_energy=None,
                 is_inference=False, alpha=1.0, utterance_embedding=None, lang_ids=None):
        """
        Shared forward pass for training and inference.

        Args:
            text_tensors: padded batch of input symbol tensors (B, Tmax, idim).
            text_lens: lengths of the unpadded input sequences (B,).
            gold_speech: unused here; kept for interface compatibility.
            speech_lens: target frame lengths, used to build decoder masks
                during training.
            gold_durations / gold_pitch / gold_energy: ground-truth variance
                targets.  In training they replace the predictions; in
                inference any of them, when given, overrides the corresponding
                prediction (useful for prosody cloning).
            is_inference: selects predicted vs. ground-truth variances.
            alpha: speed control factor for the length regulator.
            utterance_embedding: optional utterance-wide style embedding.
            lang_ids: optional language-id tensor for the encoder.

        Returns:
            Tuple of (before_outs, after_outs, duration_predictions,
            pitch_predictions, energy_predictions).
        """
        # forward encoder
        text_masks = self._source_mask(text_lens)
        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)  # (B, Tmax, adim)
        # forward duration predictor and variance predictors
        duration_masks = make_pad_mask(text_lens, device=text_lens.device)
        # Optionally detach so variance-predictor gradients do not flow back
        # into the encoder (stabilizes training per FastSpeech 2).
        if self.stop_gradient_from_pitch_predictor:
            pitch_predictions = self.pitch_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
        else:
            pitch_predictions = self.pitch_predictor(encoded_texts, duration_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            energy_predictions = self.energy_predictor(encoded_texts.detach(), duration_masks.unsqueeze(-1))
        else:
            energy_predictions = self.energy_predictor(encoded_texts, duration_masks.unsqueeze(-1))
        if is_inference:
            # Any provided gold value overrides the corresponding prediction.
            if gold_durations is not None:
                duration_predictions = gold_durations
            else:
                duration_predictions = self.duration_predictor.inference(encoded_texts, duration_masks)
            if gold_pitch is not None:
                pitch_predictions = gold_pitch
            if gold_energy is not None:
                energy_predictions = gold_energy
            pitch_embeddings = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
            energy_embeddings = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
            encoded_texts = self.length_regulator(encoded_texts, duration_predictions, alpha)
        else:
            duration_predictions = self.duration_predictor(encoded_texts, duration_masks)
            # use groundtruth in training
            pitch_embeddings = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
            energy_embeddings = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + energy_embeddings + pitch_embeddings
            encoded_texts = self.length_regulator(encoded_texts, gold_durations)  # (B, Lmax, adim)
        # forward decoder
        if speech_lens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
            else:
                olens_in = speech_lens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(encoded_texts, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)
        # postnet -> (B, Lmax//r * r, odim)
        after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
        return before_outs, after_outs, duration_predictions, pitch_predictions, energy_predictions

    def forward(self,
                text,
                speech=None,
                durations=None,
                pitch=None,
                energy=None,
                utterance_embedding=None,
                return_duration_pitch_energy=False,
                lang_id=None):
        """
        Generate the sequence of features given the sequences of characters.

        Args:
            text: Input sequence of characters (single utterance, no batch axis)
            speech: Feature sequence to extract style
            durations: Groundtruth of duration
            pitch: Groundtruth of token-averaged pitch
            energy: Groundtruth of token-averaged energy
            return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting
            utterance_embedding: embedding of utterance wide parameters
            lang_id: optional language-id tensor for multilingual models

        Returns:
            Mel Spectrogram
        """
        self.eval()
        # setup batch axis
        ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
        if speech is not None:
            gold_speech = speech.unsqueeze(0)
        else:
            gold_speech = None
        if durations is not None:
            durations = durations.unsqueeze(0)
        if pitch is not None:
            pitch = pitch.unsqueeze(0)
        if energy is not None:
            energy = energy.unsqueeze(0)
        if lang_id is not None:
            lang_id = lang_id.unsqueeze(0)
        # BUGFIX: utterance_embedding defaults to None; unsqueezing it
        # unconditionally raised AttributeError when no embedding was given.
        if utterance_embedding is not None:
            utterance_embedding = utterance_embedding.unsqueeze(0)
        before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = self._forward(text.unsqueeze(0),
                                                                                              ilens,
                                                                                              gold_speech=gold_speech,
                                                                                              gold_durations=durations,
                                                                                              is_inference=True,
                                                                                              gold_pitch=pitch,
                                                                                              gold_energy=energy,
                                                                                              utterance_embedding=utterance_embedding,
                                                                                              lang_ids=lang_id)
        self.train()
        if return_duration_pitch_energy:
            return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
        return after_outs[0]

    def _source_mask(self, ilens):
        """Build a (B, 1, Tmax) boolean non-pad mask for attention from lengths."""
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)