# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
from typing import Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from torch.cuda import amp
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F
from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm, MaskedInstanceNorm1d
from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths, sort_tensor, unsort_tensor
from nemo.collections.tts.parts.utils.splines import (
piecewise_linear_inverse_transform,
piecewise_linear_transform,
unbounded_piecewise_quadratic_transform,
)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b):
t_act = torch.tanh(input_a)
s_act = torch.sigmoid(input_b)
acts = t_act * s_act
return acts
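# Example (illustrative sketch; shapes are assumptions). Despite the "add" in
# the name, this variant takes the filter and gate branches separately and
# computes the gated activation tanh(a) * sigmoid(b):
#
#     a = torch.randn(2, 256, 100)  # filter branch: (batch, channels, time)
#     b = torch.randn(2, 256, 100)  # gate branch, same shape
#     gated = fused_add_tanh_sigmoid_multiply(a, b)  # (2, 256, 100)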
class ExponentialClass(torch.nn.Module):
def __init__(self):
super(ExponentialClass, self).__init__()
def forward(self, x):
return torch.exp(x)
class DenseLayer(nn.Module):
def __init__(self, in_dim=1024, sizes=[1024, 1024]):
super(DenseLayer, self).__init__()
in_sizes = [in_dim] + sizes[:-1]
self.layers = nn.ModuleList(
[LinearNorm(in_size, out_size, bias=True) for (in_size, out_size) in zip(in_sizes, sizes)]
)
def forward(self, x):
for linear in self.layers:
x = torch.tanh(linear(x))
return x
class BiLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers=1, lstm_norm_fn="spectral", max_batch_size=64):
super().__init__()
self.bilstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        if lstm_norm_fn is not None:
            if 'spectral' in lstm_norm_fn:
                print("Applying spectral norm to LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
            elif 'weight' in lstm_norm_fn:
                print("Applying weight norm to LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.weight_norm
            else:
                raise ValueError(f"Unsupported lstm_norm_fn: {lstm_norm_fn}")
            lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0')
            lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0_reverse')
self.real_hidden_size: int = self.bilstm.proj_size if self.bilstm.proj_size > 0 else self.bilstm.hidden_size
self.bilstm.flatten_parameters()
def lstm_sorted(self, context: Tensor, lens: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
seq = nn.utils.rnn.pack_padded_sequence(context, lens.long().cpu(), batch_first=True, enforce_sorted=True)
ret, _ = self.bilstm(seq, hx)
return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)[0]
def lstm(self, context: Tensor, lens: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
        # To be ONNX-exportable, we need to sort here rather than while packing
context, lens, unsort_ids = sort_tensor(context, lens)
ret = self.lstm_sorted(context, lens, hx=hx)
return unsort_tensor(ret, unsort_ids)
def lstm_nocast(self, context: Tensor, lens: Tensor) -> Tensor:
dtype = context.dtype
# autocast guard is only needed for Torchscript to run in Triton
# (https://github.com/pytorch/pytorch/issues/89241)
with torch.cuda.amp.autocast(enabled=False):
# Calculate sizes and prepare views to our zero buffer to pass as hx
max_batch_size = context.shape[0]
context = context.to(dtype=torch.float32)
common_shape = (self.bilstm.num_layers * 2, max_batch_size)
hx = (
context.new_zeros(*common_shape, self.real_hidden_size),
context.new_zeros(*common_shape, self.bilstm.hidden_size),
)
return self.lstm(context, lens, hx=hx).to(dtype=dtype)
def forward(self, context: Tensor, lens: Tensor) -> Tensor:
self.bilstm.flatten_parameters()
if torch.jit.is_tracing():
return self.lstm_nocast(context, lens)
return self.lstm(context, lens)
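# Example usage (illustrative sketch; sizes are assumptions):
#
#     lstm = BiLSTM(input_size=512, hidden_size=256)
#     x = torch.randn(4, 100, 512)            # (batch, time, features)
#     lens = torch.tensor([100, 75, 50, 25])  # any order; sorted internally
#     out = lstm(x, lens)                     # (4, 100, 2 * 256)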
class ConvLSTMLinear(nn.Module):
def __init__(
self,
in_dim=None,
out_dim=None,
n_layers=2,
n_channels=256,
kernel_size=3,
p_dropout=0.1,
use_partial_padding=False,
norm_fn=None,
):
super(ConvLSTMLinear, self).__init__()
self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1)
self.convolutions = nn.ModuleList()
if n_layers > 0:
self.dropout = nn.Dropout(p=p_dropout)
use_weight_norm = norm_fn is None
for i in range(n_layers):
conv_layer = ConvNorm(
in_dim if i == 0 else n_channels,
n_channels,
kernel_size=kernel_size,
stride=1,
padding=int((kernel_size - 1) / 2),
dilation=1,
w_init_gain='relu',
use_weight_norm=use_weight_norm,
use_partial_padding=use_partial_padding,
norm_fn=norm_fn,
)
if norm_fn is not None:
print("Applying {} norm to {}".format(norm_fn, conv_layer))
else:
print("Applying weight norm to {}".format(conv_layer))
self.convolutions.append(conv_layer)
self.dense = None
if out_dim is not None:
self.dense = nn.Linear(n_channels, out_dim)
def forward(self, context: Tensor, lens: Tensor) -> Tensor:
mask = get_mask_from_lengths(lens, context)
mask = mask.to(dtype=context.dtype).unsqueeze(1)
for conv in self.convolutions:
context = self.dropout(F.relu(conv(context, mask)))
# Apply Bidirectional LSTM
context = self.bilstm(context.transpose(1, 2), lens=lens)
if self.dense is not None:
context = self.dense(context).permute(0, 2, 1)
return context
def get_radtts_encoder(
encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, norm_fn=MaskedInstanceNorm1d,
):
return ConvLSTMLinear(
in_dim=encoder_embedding_dim,
n_layers=encoder_n_convolutions,
n_channels=encoder_embedding_dim,
kernel_size=encoder_kernel_size,
p_dropout=0.5,
use_partial_padding=True,
norm_fn=norm_fn,
)
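# Example usage (illustrative sketch; sizes are assumptions). With out_dim=None,
# ConvLSTMLinear has no dense projection and returns the BiLSTM output in
# (batch, time, channels) layout:
#
#     encoder = get_radtts_encoder()
#     text_emb = torch.randn(2, 512, 50)  # (batch, encoder_embedding_dim, text_len)
#     lens = torch.tensor([50, 30])
#     out = encoder(text_emb, lens)       # (2, 50, 512)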
class Invertible1x1ConvLUS(torch.nn.Module):
def __init__(self, c):
super(Invertible1x1ConvLUS, self).__init__()
# Sample a random orthonormal matrix to initialize weights
W, _ = torch.linalg.qr(torch.FloatTensor(c, c).normal_())
# Ensure determinant is 1.0 not -1.0
if torch.det(W) < 0:
W[:, 0] = -1 * W[:, 0]
        # LU-decompose W = P @ L @ U; torch.lu is deprecated in favor of torch.linalg.lu
        p, lower, upper = torch.linalg.lu(W)
        self.register_buffer('p', p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.ones(c)
self.register_buffer('lower_diag', lower_diag)
self.lower = nn.Parameter(lower)
self.upper_diag = nn.Parameter(torch.diag(upper))
self.upper = nn.Parameter(torch.triu(upper, 1))
@amp.autocast(False)
def forward(self, z, inverse=False):
U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
W = torch.mm(self.p, torch.mm(L, U))
if inverse:
if not hasattr(self, 'W_inverse'):
# inverse computation
W_inverse = W.float().inverse().to(dtype=z.dtype)
self.W_inverse = W_inverse[..., None]
z = F.conv1d(z, self.W_inverse.to(dtype=z.dtype), bias=None, stride=1, padding=0)
return z
else:
W = W[..., None]
z = F.conv1d(z, W, bias=None, stride=1, padding=0)
log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
return z, log_det_W
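# Round-trip sketch (illustrative; sizes are assumptions). Note that W_inverse
# is cached on the first inverse call, so the inverse path assumes frozen
# weights, i.e. inference:
#
#     conv = Invertible1x1ConvLUS(8)
#     z = torch.randn(2, 8, 40)  # (batch, channels, time)
#     y, log_det_W = conv(z)
#     z_rec = conv(y, inverse=True)
#     assert torch.allclose(z, z_rec, atol=1e-4)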
class Invertible1x1Conv(torch.nn.Module):
"""
The layer outputs both the convolution, and the log determinant
of its weight matrix. If inverse=True it does convolution with
inverse
"""
def __init__(self, c):
super(Invertible1x1Conv, self).__init__()
self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False)
# Sample a random orthonormal matrix to initialize weights
        W, _ = torch.linalg.qr(torch.FloatTensor(c, c).normal_())
# Ensure determinant is 1.0 not -1.0
if torch.det(W) < 0:
W[:, 0] = -1 * W[:, 0]
W = W.view(c, c, 1)
self.conv.weight.data = W
def forward(self, z, inverse=False):
# DO NOT apply n_of_groups, as it doesn't account for padded sequences
W = self.conv.weight.squeeze()
if inverse:
if not hasattr(self, 'W_inverse'):
# Inverse computation
W_inverse = W.float().inverse().to(dtype=z.dtype)
self.W_inverse = W_inverse[..., None]
z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
return z
else:
# Forward computation
log_det_W = torch.logdet(W).clone()
z = self.conv(z)
return z, log_det_W
class SimpleConvNet(torch.nn.Module):
def __init__(
self,
n_mel_channels,
n_context_dim,
final_out_channels,
n_layers=2,
kernel_size=5,
with_dilation=True,
max_channels=1024,
zero_init=True,
use_partial_padding=True,
):
super(SimpleConvNet, self).__init__()
self.layers = torch.nn.ModuleList()
self.n_layers = n_layers
in_channels = n_mel_channels + n_context_dim
out_channels = -1
self.use_partial_padding = use_partial_padding
for i in range(n_layers):
dilation = 2 ** i if with_dilation else 1
padding = int((kernel_size * dilation - dilation) / 2)
out_channels = min(max_channels, in_channels * 2)
self.layers.append(
ConvNorm(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
dilation=dilation,
bias=True,
w_init_gain='relu',
use_partial_padding=use_partial_padding,
)
)
in_channels = out_channels
self.last_layer = torch.nn.Conv1d(out_channels, final_out_channels, kernel_size=1)
if zero_init:
self.last_layer.weight.data *= 0
self.last_layer.bias.data *= 0
def forward(self, z_w_context, seq_lens: Optional[Tensor] = None):
        # seq_lens: tensor of sequence lengths
        # output should be b x n_mel_channels x z_w_context.shape[2]
mask = get_mask_from_lengths(seq_lens, z_w_context).unsqueeze(1).to(dtype=z_w_context.dtype)
for i in range(self.n_layers):
z_w_context = self.layers[i](z_w_context, mask)
z_w_context = torch.relu(z_w_context)
z_w_context = self.last_layer(z_w_context)
return z_w_context
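# Example usage (illustrative sketch; sizes are assumptions). The input is the
# channel-wise concatenation of z and its conditioning context:
#
#     net = SimpleConvNet(n_mel_channels=40, n_context_dim=512, final_out_channels=80)
#     z_w_context = torch.randn(2, 40 + 512, 30)  # (batch, channels, time)
#     lens = torch.tensor([30, 20])
#     out = net(z_w_context, lens)                # (2, 80, 30)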
class WN(torch.nn.Module):
"""
Adapted from WN() module in WaveGlow with modififcations to variable names
"""
def __init__(
self,
n_in_channels,
n_context_dim,
n_layers,
n_channels,
kernel_size=5,
affine_activation='softplus',
use_partial_padding=True,
):
super(WN, self).__init__()
assert kernel_size % 2 == 1
assert n_channels % 2 == 0
self.n_layers = n_layers
self.n_channels = n_channels
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
start = torch.nn.utils.weight_norm(start, name='weight')
self.start = start
self.softplus = torch.nn.Softplus()
self.affine_activation = affine_activation
self.use_partial_padding = use_partial_padding
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
end.weight.data.zero_()
end.bias.data.zero_()
self.end = end
for i in range(n_layers):
dilation = 2 ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = ConvNorm(
n_channels,
n_channels,
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
use_partial_padding=use_partial_padding,
use_weight_norm=True,
)
self.in_layers.append(in_layer)
res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
res_skip_layer = nn.utils.weight_norm(res_skip_layer)
self.res_skip_layers.append(res_skip_layer)
    def forward(self, forward_input: Tuple[Tensor, Tensor], seq_lens: Optional[Tensor] = None):
z, context = forward_input
z = torch.cat((z, context), 1) # append context to z as well
z = self.start(z)
output = torch.zeros_like(z)
mask = None
if self.use_partial_padding:
mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
non_linearity = torch.relu
if self.affine_activation == 'softplus':
non_linearity = self.softplus
for i in range(self.n_layers):
z = non_linearity(self.in_layers[i](z, mask))
res_skip_acts = non_linearity(self.res_skip_layers[i](z))
output = output + res_skip_acts
output = self.end(output) # [B, dim, seq_len]
return output
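# Example usage (illustrative sketch; sizes are assumptions). The output packs
# per-channel affine parameters, hence 2 * n_in_channels output channels:
#
#     wn = WN(n_in_channels=40, n_context_dim=512, n_layers=2, n_channels=256)
#     z = torch.randn(2, 40, 30)                # (batch, n_in_channels, time)
#     context = torch.randn(2, 512, 30)         # (batch, n_context_dim, time)
#     lens = torch.tensor([30, 20])
#     params = wn((z, context), seq_lens=lens)  # (2, 80, 30)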
# Affine Coupling Layers
class SplineTransformationLayerAR(torch.nn.Module):
def __init__(
self,
n_in_channels,
n_context_dim,
n_layers,
affine_model='simple_conv',
kernel_size=1,
scaling_fn='exp',
affine_activation='softplus',
n_channels=1024,
n_bins=8,
left=-6,
right=6,
bottom=-6,
top=6,
use_quadratic=False,
):
super(SplineTransformationLayerAR, self).__init__()
self.n_in_channels = n_in_channels # input dimensions
self.left = left
self.right = right
self.bottom = bottom
self.top = top
self.n_bins = n_bins
self.spline_fn = piecewise_linear_transform
self.inv_spline_fn = piecewise_linear_inverse_transform
self.use_quadratic = use_quadratic
if self.use_quadratic:
self.spline_fn = unbounded_piecewise_quadratic_transform
self.inv_spline_fn = unbounded_piecewise_quadratic_transform
self.n_bins = 2 * self.n_bins + 1
final_out_channels = self.n_in_channels * self.n_bins
# autoregressive flow, kernel size 1 and no dilation
self.param_predictor = SimpleConvNet(
n_context_dim,
0,
final_out_channels,
n_layers,
with_dilation=False,
kernel_size=1,
zero_init=True,
use_partial_padding=False,
)
# output is unnormalized bin weights
def normalize(self, z, inverse):
# normalize to [0, 1]
if inverse:
z = (z - self.bottom) / (self.top - self.bottom)
else:
z = (z - self.left) / (self.right - self.left)
return z
def denormalize(self, z, inverse):
if inverse:
z = z * (self.right - self.left) + self.left
else:
z = z * (self.top - self.bottom) + self.bottom
return z
def forward(self, z, context, inverse=False):
b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)
z = self.normalize(z, inverse)
if z.min() < 0.0 or z.max() > 1.0:
print('spline z scaled beyond [0, 1]', z.min(), z.max())
z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1)
affine_params = self.param_predictor(context)
q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1)
with amp.autocast(enabled=False):
if self.use_quadratic:
w = q_tilde[:, :, : self.n_bins // 2]
v = q_tilde[:, :, self.n_bins // 2 :]
z_tformed, log_s = self.spline_fn(z_reshaped.float(), w.float(), v.float(), inverse=inverse)
            else:
                if inverse:
                    # the linear spline needs the explicit inverse transform (cf. SplineTransformationLayer)
                    z_tformed, _dc = self.inv_spline_fn(z_reshaped.float(), q_tilde.float(), False)
                else:
                    z_tformed, log_s = self.spline_fn(z_reshaped.float(), q_tilde.float())
z = z_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
z = self.denormalize(z, inverse)
if inverse:
return z
log_s = log_s.reshape(b_s, t_s, -1)
log_s = log_s.permute(0, 2, 1)
log_s = log_s + c_s * (np.log(self.top - self.bottom) - np.log(self.right - self.left))
return z, log_s
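# Example usage (illustrative sketch; sizes are assumptions). The spline acts
# elementwise on z conditioned only on context, so inputs should lie inside
# [left, right] (forward) or [bottom, top] (inverse):
#
#     layer = SplineTransformationLayerAR(n_in_channels=1, n_context_dim=64, n_layers=2)
#     z = torch.randn(2, 1, 30).clamp(-5.9, 5.9)
#     context = torch.randn(2, 64, 30)
#     z_out, log_s = layer(z, context)             # forward, with log-jacobian term
#     z_rec = layer(z_out, context, inverse=True)  # inverse returns z only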
class SplineTransformationLayer(torch.nn.Module):
def __init__(
self,
n_mel_channels,
n_context_dim,
n_layers,
with_dilation=True,
kernel_size=5,
scaling_fn='exp',
affine_activation='softplus',
n_channels=1024,
n_bins=8,
left=-4,
right=4,
bottom=-4,
top=4,
use_quadratic=False,
):
super(SplineTransformationLayer, self).__init__()
self.n_mel_channels = n_mel_channels # input dimensions
self.half_mel_channels = int(n_mel_channels / 2) # half, because we split
self.left = left
self.right = right
self.bottom = bottom
self.top = top
self.n_bins = n_bins
self.spline_fn = piecewise_linear_transform
self.inv_spline_fn = piecewise_linear_inverse_transform
self.use_quadratic = use_quadratic
if self.use_quadratic:
self.spline_fn = unbounded_piecewise_quadratic_transform
self.inv_spline_fn = unbounded_piecewise_quadratic_transform
self.n_bins = 2 * self.n_bins + 1
final_out_channels = self.half_mel_channels * self.n_bins
self.param_predictor = SimpleConvNet(
self.half_mel_channels,
n_context_dim,
final_out_channels,
n_layers,
with_dilation=with_dilation,
kernel_size=kernel_size,
zero_init=False,
)
# output is unnormalized bin weights
def forward(self, z, context, inverse=False, seq_lens=None):
b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)
# condition on z_0, transform z_1
n_half = self.half_mel_channels
z_0, z_1 = z[:, :n_half], z[:, n_half:]
# normalize to [0,1]
if inverse:
z_1 = (z_1 - self.bottom) / (self.top - self.bottom)
else:
z_1 = (z_1 - self.left) / (self.right - self.left)
z_w_context = torch.cat((z_0, context), 1)
affine_params = self.param_predictor(z_w_context, seq_lens)
z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1)
q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins)
with autocast(enabled=False):
if self.use_quadratic:
w = q_tilde[:, :, : self.n_bins // 2]
v = q_tilde[:, :, self.n_bins // 2 :]
z_1_tformed, log_s = self.spline_fn(z_1_reshaped.float(), w.float(), v.float(), inverse=inverse)
if not inverse:
log_s = torch.sum(log_s, 1)
else:
if inverse:
z_1_tformed, _dc = self.inv_spline_fn(z_1_reshaped.float(), q_tilde.float(), False)
else:
z_1_tformed, log_s = self.spline_fn(z_1_reshaped.float(), q_tilde.float())
z_1 = z_1_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
# undo [0, 1] normalization
if inverse:
z_1 = z_1 * (self.right - self.left) + self.left
z = torch.cat((z_0, z_1), dim=1)
return z
else: # training
z_1 = z_1 * (self.top - self.bottom) + self.bottom
z = torch.cat((z_0, z_1), dim=1)
log_s = log_s.reshape(b_s, t_s).unsqueeze(1) + n_half * (
np.log(self.top - self.bottom) - np.log(self.right - self.left)
)
return z, log_s
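# Example usage (illustrative sketch; sizes are assumptions). This is the
# coupling variant: the first half of the mel channels conditions the spline
# applied to the second half:
#
#     layer = SplineTransformationLayer(n_mel_channels=80, n_context_dim=64, n_layers=2)
#     z = torch.randn(2, 80, 30).clamp(-3.9, 3.9)  # keep inside [left, right]
#     context = torch.randn(2, 64, 30)
#     lens = torch.tensor([30, 20])
#     z_out, log_s = layer(z, context, seq_lens=lens)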
class AffineTransformationLayer(torch.nn.Module):
def __init__(
self,
n_mel_channels,
n_context_dim,
n_layers,
affine_model='simple_conv',
with_dilation=True,
kernel_size=5,
scaling_fn='exp',
affine_activation='softplus',
n_channels=1024,
use_partial_padding=False,
):
super(AffineTransformationLayer, self).__init__()
if affine_model not in ("wavenet", "simple_conv"):
raise Exception("{} affine model not supported".format(affine_model))
if isinstance(scaling_fn, list):
if not all([x in ("translate", "exp", "tanh", "sigmoid") for x in scaling_fn]):
raise Exception("{} scaling fn not supported".format(scaling_fn))
else:
if scaling_fn not in ("translate", "exp", "tanh", "sigmoid"):
raise Exception("{} scaling fn not supported".format(scaling_fn))
self.affine_model = affine_model
self.scaling_fn = scaling_fn
if affine_model == 'wavenet':
self.affine_param_predictor = WN(
int(n_mel_channels / 2),
n_context_dim,
n_layers=n_layers,
n_channels=n_channels,
affine_activation=affine_activation,
use_partial_padding=use_partial_padding,
)
elif affine_model == 'simple_conv':
self.affine_param_predictor = SimpleConvNet(
int(n_mel_channels / 2),
n_context_dim,
n_mel_channels,
n_layers,
with_dilation=with_dilation,
kernel_size=kernel_size,
use_partial_padding=use_partial_padding,
)
self.n_mel_channels = n_mel_channels
def get_scaling_and_logs(self, scale_unconstrained):
# (rvalle) re-write this
if self.scaling_fn == 'translate':
s = torch.exp(scale_unconstrained * 0)
log_s = scale_unconstrained * 0
elif self.scaling_fn == 'exp':
s = torch.exp(scale_unconstrained)
            log_s = scale_unconstrained  # log(exp(x)) = x
elif self.scaling_fn == 'tanh':
s = torch.tanh(scale_unconstrained) + 1 + 1e-6
log_s = torch.log(s)
elif self.scaling_fn == 'sigmoid':
s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
log_s = torch.log(s)
elif isinstance(self.scaling_fn, list):
s_list, log_s_list = [], []
for i in range(scale_unconstrained.shape[1]):
scaling_i = self.scaling_fn[i]
if scaling_i == 'translate':
                    s_i = torch.exp(scale_unconstrained[:, i] * 0)
log_s_i = scale_unconstrained[:, i] * 0
elif scaling_i == 'exp':
s_i = torch.exp(scale_unconstrained[:, i])
log_s_i = scale_unconstrained[:, i]
elif scaling_i == 'tanh':
s_i = torch.tanh(scale_unconstrained[:, i]) + 1 + 1e-6
log_s_i = torch.log(s_i)
elif scaling_i == 'sigmoid':
s_i = torch.sigmoid(scale_unconstrained[:, i]) + 1e-6
log_s_i = torch.log(s_i)
s_list.append(s_i[:, None])
log_s_list.append(log_s_i[:, None])
s = torch.cat(s_list, dim=1)
log_s = torch.cat(log_s_list, dim=1)
return s, log_s
def forward(self, z, context, inverse=False, seq_lens=None):
n_half = int(self.n_mel_channels / 2)
z_0, z_1 = z[:, :n_half], z[:, n_half:]
if self.affine_model == 'wavenet':
affine_params = self.affine_param_predictor((z_0, context), seq_lens=seq_lens)
elif self.affine_model == 'simple_conv':
z_w_context = torch.cat((z_0, context), 1)
affine_params = self.affine_param_predictor(z_w_context, seq_lens=seq_lens)
scale_unconstrained = affine_params[:, :n_half, :]
b = affine_params[:, n_half:, :]
s, log_s = self.get_scaling_and_logs(scale_unconstrained)
if inverse:
z_1 = (z_1 - b) / s
z = torch.cat((z_0, z_1), dim=1)
return z
else:
z_1 = s * z_1 + b
z = torch.cat((z_0, z_1), dim=1)
return z, log_s
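# Round-trip sketch (illustrative; sizes are assumptions). With the default
# zero-initialized parameter predictor, the layer starts as the identity:
#
#     layer = AffineTransformationLayer(n_mel_channels=80, n_context_dim=64, n_layers=2)
#     z = torch.randn(2, 80, 30)
#     context = torch.randn(2, 64, 30)
#     lens = torch.tensor([30, 20])
#     z_out, log_s = layer(z, context, seq_lens=lens)
#     z_rec = layer(z_out, context, inverse=True, seq_lens=lens)
#     assert torch.allclose(z, z_rec, atol=1e-4)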
class ConvAttention(torch.nn.Module):
def __init__(self, n_mel_channels=80, n_speaker_dim=128, n_text_channels=512, n_att_channels=80, temperature=1.0):
super(ConvAttention, self).__init__()
self.temperature = temperature
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.key_proj = nn.Sequential(
ConvNorm(n_text_channels, n_text_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
torch.nn.ReLU(),
ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
)
self.query_proj = nn.Sequential(
ConvNorm(n_mel_channels, n_mel_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
torch.nn.ReLU(),
ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
torch.nn.ReLU(),
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
)
def forward(self, queries, keys, query_lens, mask=None, key_lens=None, attn_prior=None):
"""Attention mechanism for radtts. Unlike in Flowtron, we have no
restrictions such as causality etc, since we only need this during
training.
Args:
queries (torch.tensor): B x C x T1 tensor (likely mel data)
keys (torch.tensor): B x C2 x T2 tensor (text data)
query_lens: lengths for sorting the queries in descending order
mask (torch.tensor): uint8 binary mask for variable length entries
(should be in the T2 domain)
Output:
attn (torch.tensor): B x 1 x T1 x T2 attention mask.
Final dim T2 should sum to 1
"""
temp = 0.0005
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
# Beware can only do this since query_dim = attn_dim = n_mel_channels
queries_enc = self.query_proj(queries)
        # Gaussian Isotropic Attention
# B x n_attn_dims x T1 x T2
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
# compute log-likelihood from gaussian
eps = 1e-8
attn = -temp * attn.sum(1, keepdim=True)
if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + eps)
attn_logprob = attn.clone()
if mask is not None:
attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))
attn = self.softmax(attn) # softmax along T2
return attn, attn_logprob
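# Example usage (illustrative sketch; sizes are assumptions; mask is True at
# padded key positions, shaped (B, T2, 1) so it broadcasts over T1):
#
#     attn_module = ConvAttention(n_mel_channels=80, n_text_channels=512)
#     queries = torch.randn(2, 80, 100)               # (B, n_mel_channels, T1)
#     keys = torch.randn(2, 512, 40)                  # (B, n_text_channels, T2)
#     mask = torch.zeros(2, 40, 1, dtype=torch.bool)  # no padding in this sketch
#     attn, attn_logprob = attn_module(queries, keys, query_lens=None, mask=mask)
#     # attn: (2, 1, 100, 40); sums to 1 over the final (T2) dim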