| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformer.Models import Encoder, Decoder |
| from transformer.Layers import PostNet |
| from modules import VarianceAdaptor |
| from utils import get_mask_from_lengths |
| import hparams as hp |
|
|
# Module-wide default device: CUDA when torch reports it is available,
# otherwise the CPU. Not referenced by the visible code of this module;
# presumably imported by other modules — confirm before removing.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
class FastSpeech2(nn.Module):
    """FastSpeech2 model.

    Data flow: encoder -> variance adaptor -> decoder -> linear projection to
    ``hp.n_mel_channels`` -> optional PostNet whose output is added back onto
    the projection (residual connection).
    """

    def __init__(self, use_postnet=True):
        """Create the sub-modules.

        Args:
            use_postnet: when True, a ``PostNet`` is built and ``forward`` adds
                its output to the projected mel output; when False no PostNet
                is created and the projection is returned as-is.
        """
        super().__init__()

        # NOTE: sub-modules are registered in this fixed order. PyTorch uses
        # registration order for ``parameters()``, and saved optimizer state
        # is matched to parameters by that order, so keep it stable.
        self.encoder = Encoder()
        self.variance_adaptor = VarianceAdaptor()
        self.decoder = Decoder()
        # Maps decoder features (hp.decoder_hidden wide) to hp.n_mel_channels.
        self.mel_linear = nn.Linear(hp.decoder_hidden, hp.n_mel_channels)

        self.use_postnet = use_postnet
        if use_postnet:
            self.postnet = PostNet()

    def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None,
                e_target=None, max_src_len=None, max_mel_len=None):
        """Run encoder, variance adaptor, decoder and (optionally) PostNet.

        Args:
            src_seq: input sequence fed to the encoder (presumably token /
                phoneme ids — confirm against ``Encoder``).
            src_len: lengths used to build ``src_mask`` via
                ``get_mask_from_lengths``.
            mel_len: optional mel lengths; when given, ``mel_mask`` is built
                from them, otherwise ``mel_mask`` starts out as None.
            d_target, p_target, e_target: optional targets forwarded unchanged
                to the variance adaptor (by naming, duration / pitch / energy
                — confirm against ``VarianceAdaptor``). ``d_target`` also
                selects the mode: when it is None, the ``mel_len`` and
                ``mel_mask`` returned by the variance adaptor replace the
                caller-derived values.
            max_src_len: forwarded to ``get_mask_from_lengths`` for ``src_mask``.
            max_mel_len: forwarded to ``get_mask_from_lengths`` for
                ``mel_mask`` and to the variance adaptor.

        Returns:
            Tuple ``(mel_output, mel_output_postnet, d_prediction,
            p_prediction, e_prediction, src_mask, mel_mask, mel_len)``.
            ``mel_output_postnet`` is ``mel_output`` itself when the PostNet
            is disabled.
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        mel_mask = None
        if mel_len is not None:
            mel_mask = get_mask_from_lengths(mel_len, max_mel_len)

        encoded = self.encoder(src_seq, src_mask)

        # The adaptor receives identical arguments in both modes; only the
        # handling of its trailing length/mask outputs differs.
        (adapted, d_prediction, p_prediction, e_prediction,
         adaptor_mel_len, adaptor_mel_mask) = self.variance_adaptor(
            encoded, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)
        if d_target is None:
            # No duration target given: adopt the lengths/mask the adaptor
            # produced instead of the caller-derived ones.
            mel_len, mel_mask = adaptor_mel_len, adaptor_mel_mask

        decoded = self.decoder(adapted, mel_mask)
        mel_output = self.mel_linear(decoded)

        if not self.use_postnet:
            mel_output_postnet = mel_output
        else:
            # Residual connection: PostNet output is added to its own input.
            mel_output_postnet = self.postnet(mel_output) + mel_output

        return (mel_output, mel_output_postnet, d_prediction, p_prediction,
                e_prediction, src_mask, mel_mask, mel_len)
|
|
|
|
if __name__ == "__main__":
    # Smoke check when run as a script: build the model without a PostNet,
    # print its module tree, then print the total parameter count.
    net = FastSpeech2(use_postnet=False)
    print(net)
    total_params = sum(p.numel() for p in net.parameters())
    print(total_params)
|
|