# NOTE: removed web-scrape residue ("Spaces: / Runtime error / Runtime error"
# page header) — it was not part of this module's source.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fast_decoders import DecodeFunc_Sp
from model_sea import Encoder_2 as Encoder_Code_2
from onmt_modules.embeddings import PositionalEncoding
from onmt_modules.encoder_transformer import TransformerEncoder as OnmtEncoder
from onmt_modules.misc import sequence_mask
from override_decoder import OnmtDecoder_1 as OnmtDecoder
from utils import filter_bank_mean
class Prenet(nn.Module):
    """Linear projection followed by positional encoding.

    Maps ``dim_input`` features to ``dim_output`` with a linear layer, then
    applies the project's ``PositionalEncoding`` (max length 1600).  Packaged
    as an embedding-like module so the OpenNMT encoder/decoder can call it
    with a ``step`` argument during incremental decoding.
    """

    def __init__(self, dim_input, dim_output, dropout=0.1):
        super().__init__()
        mlp = nn.Linear(dim_input, dim_output, bias=True)
        pe = PositionalEncoding(dropout, dim_output, 1600)
        self.make_prenet = nn.Sequential()
        self.make_prenet.add_module('mlp', mlp)
        self.make_prenet.add_module('pe', pe)
        # OpenNMT embeddings expose a padding index; presumably 1 matches the
        # project's padding convention — confirm against the data pipeline.
        self.word_padding_idx = 1

    def forward(self, source, step=None):
        """Apply each sub-module in order; only the last receives ``step``.

        The final module (the positional encoding) takes ``step`` so it can
        emit the correct positional offset during step-by-step inference;
        the preceding linear layer does not accept that keyword.
        """
        # Iterate the Sequential's public iterator instead of the private
        # ``_modules`` dict, and hoist the list so the "last module" check
        # is not recomputed every iteration.
        modules = list(self.make_prenet)
        for module in modules[:-1]:
            source = module(source)
        if modules:  # empty Sequential: return input unchanged, as before
            source = modules[-1](source, step=step)
        return source
class Decoder_Sp(nn.Module):
    """
    Speech Decoder
    """

    def __init__(self, hparams):
        super().__init__()
        self.dim_freq = hparams.dim_freq
        self.max_decoder_steps = hparams.dec_steps_sp
        self.gate_threshold = hparams.gate_threshold
        self.decoder = OnmtDecoder.from_opt(
            hparams, Prenet(hparams.dim_freq, hparams.dec_rnn_size))
        # Single head emitting one stop-gate logit plus a dim_freq frame.
        self.postnet = nn.Linear(hparams.dec_rnn_size,
                                 hparams.dim_freq + 1, bias=True)

    def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
        """Teacher-forced decode; returns ``(spect, gate)``.

        Channel 0 of the projected output is the stop gate; the remaining
        channels are the spectrogram frame.
        """
        hidden, _ = self.decoder(tgt, memory_bank, step=None,
                                 memory_lengths=memory_lengths,
                                 tgt_lengths=tgt_lengths)
        projected = self.postnet(hidden)
        gate = projected[:, :, :1]
        spect = projected[:, :, 1:]
        return spect, gate
class Encoder_Tx_Spk(nn.Module):
    """
    Text Encoder
    """

    def __init__(self, hparams):
        super().__init__()
        embed = Prenet(hparams.dim_code + hparams.dim_spk,
                       hparams.enc_rnn_size)
        self.encoder = OnmtEncoder.from_opt(hparams, embed)

    def forward(self, src, src_lengths, spk_emb):
        """Tile the speaker embedding over time, concatenate, and encode.

        Assumes ``src`` is time-major ``(T, B, dim_code)`` — callers pass
        ``x.transpose(1, 0)`` — and ``spk_emb`` is ``(B, dim_spk)``.
        Returns the encoder's ``(enc_states, memory_bank, lengths)`` tuple.
        """
        timesteps = src.size(0)
        tiled_spk = spk_emb.unsqueeze(0).expand(timesteps, -1, -1)
        return self.encoder(torch.cat((src, tiled_spk), dim=-1), src_lengths)
class Decoder_Tx(nn.Module):
    """
    Text Decoder with stop
    and num_rep prediction
    """

    def __init__(self, hparams):
        super().__init__()
        self.dim_code = hparams.dim_code
        self.max_decoder_steps = hparams.dec_steps_tx
        self.gate_threshold = hparams.gate_threshold
        self.dim_rep = hparams.dim_rep
        self.decoder = OnmtDecoder.from_opt(
            hparams, Prenet(hparams.dim_code, hparams.dec_rnn_size))
        # Head 1: one stop-gate logit plus dim_code text channels.
        self.postnet_1 = nn.Linear(hparams.dec_rnn_size,
                                   hparams.dim_code + 1, bias=True)
        # Head 2: repetition-count logits (dim_rep classes per step).
        self.postnet_2 = nn.Linear(hparams.dec_rnn_size,
                                   self.dim_rep, bias=True)

    def forward(self, tgt, tgt_lengths, memory_bank, memory_lengths):
        """Teacher-forced decode; returns ``(text, gate, rep)``.

        Channel 0 of the first head is the stop gate; the rest is the
        predicted code sequence.  The second head predicts num_rep.
        """
        hidden, _ = self.decoder(tgt, memory_bank, step=None,
                                 memory_lengths=memory_lengths,
                                 tgt_lengths=tgt_lengths)
        fused = self.postnet_1(hidden)
        rep = self.postnet_2(hidden)
        return fused[:, :, 1:], fused[:, :, :1], rep
class Generator_1(nn.Module):
    '''
    sync stage 1
    '''
    def __init__(self, hparams):
        super().__init__()
        self.encoder_cd = Encoder_Code_2(hparams)   # code encoder over cep_in
        self.encoder_tx = Encoder_Tx_Spk(hparams)   # text encoder (codes + speaker)
        self.decoder_sp = Decoder_Sp(hparams)       # spectrogram decoder
        # Projects the speaker embedding to the encoder width so it can be
        # prepended to the transformer memory as one extra slot.
        self.encoder_spk = nn.Linear(hparams.dim_spk,
                                     hparams.enc_rnn_size, bias=True)
        # Step-by-step decoding helper used only by infer_onmt.
        self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')

    def pad_sequences_rnn(self, cd_short, num_rep, len_long):
        # Expand each short code sequence by its per-token repeat counts and
        # right-pad the batch into a zero-initialised (B, len_long.max(), C)
        # tensor on the same device as the input.
        B, L, C = cd_short.size()
        out_tensor = torch.zeros((B, len_long.max(), C), device=cd_short.device)
        '''
        len_long = len_spect + 1
        '''
        # NOTE(review): the string above says len_long = len_spect + 1, yet
        # forward() passes len_spect directly; the `-1` in the slice below only
        # lines up if sum(num_rep[i]) == len_long[i] - 1 — confirm against the
        # data pipeline before changing anything here.
        for i in range(B):
            code_sync = cd_short[i].repeat_interleave(num_rep[i], dim=0)
            out_tensor[i, :len_long[i]-1, :] = code_sync
        return out_tensor

    def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
                tgt_spect, len_spect,
                spk_emb):
        """Teacher-forced "sync" pass: returns (spect_out, gate_sp_out)."""
        cd_long = self.encoder_cd(cep_in, mask_long)
        # filter_bank_mean builds a pooling matrix; bmm averages frames of the
        # same token.  fb is detached, so no gradient flows into the filter
        # bank itself (cd_long still receives gradient here).
        fb = filter_bank_mean(num_rep, codes_mask, cd_long.size(1))
        cd_short = torch.bmm(fb.detach(), cd_long)
        # Re-expand token codes back to frame rate so encoder memory and the
        # decoder target are time-aligned ("sync" stage).
        cd_short_sync = self.pad_sequences_rnn(cd_short, num_rep, len_spect)
        spk_emb_1 = self.encoder_spk(spk_emb)
        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_short_sync.transpose(1,0), len_spect,
                                          spk_emb)
        # Prepend the projected speaker embedding as the first memory slot;
        # memory length is therefore len_spect + 1 below.
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        spect_out, gate_sp_out \
            = self.decoder_sp(tgt_spect, len_spect, memory_tx_spk, len_spect+1)
        return spect_out, gate_sp_out

    def infer_onmt(self, cep_in, mask_long,
                   len_spect,
                   spk_emb):
        """Free-running inference: returns (spect_output, len_spect_out).

        Skips the pooling/re-expansion of forward() and feeds frame-rate
        codes straight to the text encoder.
        """
        cd_long = self.encoder_cd(cep_in, mask_long)
        spk_emb_1 = self.encoder_spk(spk_emb)
        # text to speech
        _, memory_tx, _ = self.encoder_tx(cd_long.transpose(1,0), len_spect,
                                          spk_emb)
        memory_tx_spk = torch.cat((spk_emb_1.unsqueeze(0), memory_tx), dim=0)
        self.decoder_sp.decoder.init_state(memory_tx_spk, None, None)
        # stop_sp_output is discarded; only the frames and lengths are used.
        spect_output, len_spect_out, stop_sp_output \
            = self.fast_dec_sp.infer(None, memory_tx_spk, len_spect+1,
                                     self.decoder_sp.decoder,
                                     self.decoder_sp.postnet)
        return spect_output, len_spect_out
class Generator_2(nn.Module):
    '''
    async stage 2
    '''

    def __init__(self, hparams):
        super().__init__()
        self.encoder_cd = Encoder_Code_2(hparams)
        self.encoder_tx = Encoder_Tx_Spk(hparams)
        self.decoder_sp = Decoder_Sp(hparams)
        # Maps the speaker embedding to encoder width so it can be prepended
        # to the transformer memory as one extra slot.
        self.encoder_spk = nn.Linear(hparams.dim_spk,
                                     hparams.enc_rnn_size, bias=True)
        self.fast_dec_sp = DecodeFunc_Sp(hparams, 'Sp')

    def forward(self, cep_in, mask_long, codes_mask, num_rep, len_short,
                tgt_spect, len_spect,
                spk_emb):
        """Teacher-forced "async" pass: returns (spect_out, gate_sp_out)."""
        long_codes = self.encoder_cd(cep_in, mask_long)
        # Pool frames of the same token into one short code per token.
        pooling = filter_bank_mean(num_rep, codes_mask, long_codes.size(1))
        # Both factors are detached: no gradient flows back into the code
        # encoder through this product.
        short_codes = torch.bmm(pooling.detach(), long_codes.detach())
        spk_slot = self.encoder_spk(spk_emb)
        # text to speech
        _, text_memory, _ = self.encoder_tx(short_codes.transpose(1, 0),
                                            len_short, spk_emb)
        # Speaker slot first, then the text memory — hence len_short + 1.
        memory = torch.cat((spk_slot.unsqueeze(0), text_memory), dim=0)
        self.decoder_sp.decoder.init_state(memory, None, None)
        spect_out, gate_sp_out = self.decoder_sp(
            tgt_spect, len_spect, memory, len_short + 1)
        return spect_out, gate_sp_out

    def infer_onmt(self, cep_in, mask_long, len_spect,
                   spk_emb):
        """Free-running inference: returns (spect_output, len_spect_out)."""
        long_codes = self.encoder_cd(cep_in, mask_long)
        spk_slot = self.encoder_spk(spk_emb)
        # text to speech
        _, text_memory, _ = self.encoder_tx(long_codes.transpose(1, 0),
                                            len_spect, spk_emb)
        memory = torch.cat((spk_slot.unsqueeze(0), text_memory), dim=0)
        self.decoder_sp.decoder.init_state(memory, None, None)
        # The stop-gate output of the fast decoder is discarded.
        spect_output, len_spect_out, _ = self.fast_dec_sp.infer(
            None, memory, len_spect + 1,
            self.decoder_sp.decoder,
            self.decoder_sp.postnet)
        return spect_output, len_spect_out