| from modules.fastspeech.tts_modules import FastspeechDecoder
|
|
|
|
|
|
|
| import torch
|
| from torch.nn import functional as F
|
| import torch.nn as nn
|
| import math
|
| from utils.hparams import hparams
|
| from .diffusion import Mish
|
| Linear = nn.Linear
|
|
|
|
|
| class SinusoidalPosEmb(nn.Module):
|
| def __init__(self, dim):
|
| super().__init__()
|
| self.dim = dim
|
|
|
| def forward(self, x):
|
| device = x.device
|
| half_dim = self.dim // 2
|
| emb = math.log(10000) / (half_dim - 1)
|
| emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
| emb = x[:, None] * emb[None, :]
|
| emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| return emb
|
|
|
|
|
| def Conv1d(*args, **kwargs):
|
| layer = nn.Conv1d(*args, **kwargs)
|
| nn.init.kaiming_normal_(layer.weight)
|
| return layer
|
|
|
|
|
| class FFT(FastspeechDecoder):
|
| def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
|
| super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
|
| dim = hparams['residual_channels']
|
| self.input_projection = Conv1d(hparams['audio_num_mel_bins'], dim, 1)
|
| self.diffusion_embedding = SinusoidalPosEmb(dim)
|
| self.mlp = nn.Sequential(
|
| nn.Linear(dim, dim * 4),
|
| Mish(),
|
| nn.Linear(dim * 4, dim)
|
| )
|
| self.get_mel_out = Linear(hparams['hidden_size'], 80, bias=True)
|
| self.get_decode_inp = Linear(hparams['hidden_size'] + dim + dim,
|
| hparams['hidden_size'])
|
|
|
| def forward(self, spec, diffusion_step, cond, padding_mask=None, attn_mask=None, return_hiddens=False):
|
| """
|
| :param spec: [B, 1, 80, T]
|
| :param diffusion_step: [B, 1]
|
| :param cond: [B, M, T]
|
| :return:
|
| """
|
| x = spec[:, 0]
|
| x = self.input_projection(x).permute([0, 2, 1])
|
| diffusion_step = self.diffusion_embedding(diffusion_step)
|
| diffusion_step = self.mlp(diffusion_step)
|
| cond = cond.permute([0, 2, 1])
|
|
|
| seq_len = cond.shape[1]
|
| time_embed = diffusion_step[:, None, :]
|
| time_embed = time_embed.repeat([1, seq_len, 1])
|
|
|
| decoder_inp = torch.cat([x, cond, time_embed], dim=-1)
|
| decoder_inp = self.get_decode_inp(decoder_inp)
|
| x = decoder_inp
|
|
|
| '''
|
| Required x: [B, T, C]
|
| :return: [B, T, C] or [L, B, T, C]
|
| '''
|
| padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
|
| nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]
|
| if self.use_pos_embed:
|
| positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
|
| x = x + positions
|
| x = F.dropout(x, p=self.dropout, training=self.training)
|
|
|
| x = x.transpose(0, 1) * nonpadding_mask_TB
|
| hiddens = []
|
| for layer in self.layers:
|
| x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
|
| hiddens.append(x)
|
| if self.use_last_norm:
|
| x = self.layer_norm(x) * nonpadding_mask_TB
|
| if return_hiddens:
|
| x = torch.stack(hiddens, 0)
|
| x = x.transpose(1, 2)
|
| else:
|
| x = x.transpose(0, 1)
|
|
|
| x = self.get_mel_out(x).permute([0, 2, 1])
|
| return x[:, None, :, :] |