import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from modules.fastspeech.tts_modules import FastspeechDecoder
from utils.hparams import hparams
from .diffusion import Mish

Linear = nn.Linear


class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of the diffusion step index."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # x: [B] step indices -> [B, dim] embedding (first half sin, second half cos)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
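
# Illustrative usage (the dim and step values here are examples only):
#
#   pos_emb = SinusoidalPosEmb(8)
#   out = pos_emb(torch.tensor([0., 10., 100.]))
#   # out.shape == (3, 8); out[:, :4] are sin terms, out[:, 4:] are cos terms.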


def Conv1d(*args, **kwargs):
    # Plain nn.Conv1d with Kaiming-normal weight initialization.
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class FFT(FastspeechDecoder):
    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
        dim = hparams['residual_channels']
        # Project the noisy mel input into the residual-channel space.
        self.input_projection = Conv1d(hparams['audio_num_mel_bins'], dim, 1)
        self.diffusion_embedding = SinusoidalPosEmb(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            Mish(),
            nn.Linear(dim * 4, dim)
        )
        self.get_mel_out = Linear(hparams['hidden_size'], 80, bias=True)
        # Fuse projected mel, conditioning and step embedding back to hidden_size.
        self.get_decode_inp = Linear(hparams['hidden_size'] + dim + dim,
                                     hparams['hidden_size'])

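    # Worked dimension example (numbers are illustrative, not config defaults):
    # with hidden_size=256 and residual_channels=256, get_decode_inp maps the
    # concatenated [mel-proj (256) | cond (256) | step-emb (256)] = 768-dim
    # features back to 256 before they enter the FFT blocks.
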
    def forward(self, spec, diffusion_step, cond, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param spec: [B, 1, 80, T] noisy mel spectrogram
        :param diffusion_step: [B] diffusion step indices
        :param cond: [B, M, T] conditioning features (M = hidden_size)
        :return: [B, 1, 80, T]
        """
        x = spec[:, 0]                                     # [B, 80, T]
        x = self.input_projection(x).permute([0, 2, 1])    # [B, T, residual_channels]
        diffusion_step = self.diffusion_embedding(diffusion_step)
        diffusion_step = self.mlp(diffusion_step)          # [B, residual_channels]
        cond = cond.permute([0, 2, 1])                     # [B, T, M]

        # Broadcast the step embedding across the time axis.
        seq_len = cond.shape[1]
        time_embed = diffusion_step[:, None, :]
        time_embed = time_embed.repeat([1, seq_len, 1])    # [B, T, residual_channels]

        # Fuse noisy mel, conditioning and step embedding into the decoder input.
        decoder_inp = torch.cat([x, cond, time_embed], dim=-1)
        decoder_inp = self.get_decode_inp(decoder_inp)     # [B, T, hidden_size]
        x = decoder_inp

        # Run the FFT-block stack inherited from FastspeechDecoder.
        # Expects x: [B, T, C]; yields [B, T, C] or, with return_hiddens, [L, B, T, C].
        padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
        nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]
        if self.use_pos_embed:
            positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
            x = x + positions
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(0, 1) * nonpadding_mask_TB         # [B, T, C] -> [T, B, C]
        hiddens = []
        for layer in self.layers:
            x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
            hiddens.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * nonpadding_mask_TB
        if return_hiddens:
            x = torch.stack(hiddens, 0)                    # [L, T, B, C]
            x = x.transpose(1, 2)                          # [L, B, T, C]
        else:
            x = x.transpose(0, 1)                          # [B, T, C]

        # Final projection back to mel bins (assumes x is [B, T, C], i.e. return_hiddens=False).
        x = self.get_mel_out(x).permute([0, 2, 1])         # [B, 80, T]
        return x[:, None, :, :]                            # [B, 1, 80, T]
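
# Sketch of one possible call (shapes follow the forward() docstring; the sizes
# here are illustrative, and the global hparams dict, e.g. 'residual_channels',
# 'audio_num_mel_bins', 'hidden_size', must already be populated from the project
# config before FFT is built; B, T, K stand for batch size, frame count and
# number of diffusion steps):
#
#   denoise_fn = FFT(hidden_size=256, num_layers=4, kernel_size=9, num_heads=2)
#   spec = torch.randn(B, 1, 80, T)          # noisy mel at the current step
#   t = torch.randint(0, K, (B,)).long()     # per-sample diffusion step indices
#   cond = torch.randn(B, 256, T)            # conditioning features [B, hidden_size, T]
#   out = denoise_fn(spec, t, cond)          # -> [B, 1, 80, T]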