|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch import Tensor, nn |
|
|
from torch.cuda import amp |
|
|
from torch.cuda.amp import autocast as autocast |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm, MaskedInstanceNorm1d |
|
|
from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths, sort_tensor, unsort_tensor |
|
|
from nemo.collections.tts.parts.utils.splines import ( |
|
|
piecewise_linear_inverse_transform, |
|
|
piecewise_linear_transform, |
|
|
unbounded_piecewise_quadratic_transform, |
|
|
) |
|
|
|
|
|
|
|
|
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b):
    """Gated activation unit: elementwise tanh(input_a) * sigmoid(input_b)."""
    return torch.tanh(input_a) * torch.sigmoid(input_b)
|
|
|
|
|
|
|
|
class ExponentialClass(torch.nn.Module):
    """Module wrapper around the elementwise exponential (usable in nn.Sequential)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Elementwise e ** x.
        return torch.exp(x)
|
|
|
|
|
|
|
|
class DenseLayer(nn.Module):
    """Stack of LinearNorm layers, each followed by a tanh activation.

    Args:
        in_dim: input feature dimension.
        sizes: output sizes of the successive linear layers; defaults to
            [1024, 1024] when not provided.
    """

    def __init__(self, in_dim=1024, sizes=None):
        super(DenseLayer, self).__init__()
        # Fix: avoid a mutable default argument (the previous default
        # ``sizes=[1024, 1024]`` was a single list shared across instances).
        if sizes is None:
            sizes = [1024, 1024]
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=True) for (in_size, out_size) in zip(in_sizes, sizes)]
        )

    def forward(self, x):
        """Apply each linear layer followed by tanh; returns the final activation."""
        for linear in self.layers:
            x = torch.tanh(linear(x))
        return x
|
|
|
|
|
|
|
|
class BiLSTM(nn.Module):
    """Bidirectional LSTM over variable-length, padded batches.

    Sequences are sorted, packed, run through the LSTM and unpacked/unsorted,
    so padding does not leak into the recurrence. ``lstm_norm_fn`` selects a
    norm applied to the layer-0 recurrent weights ('spectral' or 'weight');
    NOTE(review): any other non-None value leaves ``lstm_norm_fn_pntr``
    unbound and would raise — callers are expected to pass a supported name.
    ``max_batch_size`` is accepted but unused in this class.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, lstm_norm_fn="spectral", max_batch_size=64):
        super().__init__()
        self.bilstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        if lstm_norm_fn is not None:
            if 'spectral' in lstm_norm_fn:
                print("Applying spectral norm to LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.spectral_norm
            elif 'weight' in lstm_norm_fn:
                print("Applying weight norm to LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.weight_norm

            # Normalize the hidden-to-hidden weights of both directions (layer 0).
            lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0')
            lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0_reverse')

        # Actual per-direction hidden size (proj_size overrides hidden_size when > 0).
        self.real_hidden_size: int = self.bilstm.proj_size if self.bilstm.proj_size > 0 else self.bilstm.hidden_size

        self.bilstm.flatten_parameters()

    def lstm_sorted(self, context: Tensor, lens: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
        # Requires the batch already sorted by decreasing length
        # (enforce_sorted=True keeps this path trace/ONNX friendly).
        seq = nn.utils.rnn.pack_padded_sequence(context, lens.long().cpu(), batch_first=True, enforce_sorted=True)
        ret, _ = self.bilstm(seq, hx)
        return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)[0]

    def lstm(self, context: Tensor, lens: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
        # Sort by length, run the packed LSTM, then restore the original order.
        context, lens, unsort_ids = sort_tensor(context, lens)
        ret = self.lstm_sorted(context, lens, hx=hx)
        return unsort_tensor(ret, unsort_ids)

    def lstm_nocast(self, context: Tensor, lens: Tensor) -> Tensor:
        dtype = context.dtype
        # Run the LSTM in fp32 with autocast disabled, casting back afterwards.
        with torch.cuda.amp.autocast(enabled=False):
            max_batch_size = context.shape[0]
            context = context.to(dtype=torch.float32)
            common_shape = (self.bilstm.num_layers * 2, max_batch_size)
            # Explicit zero initial (h, c) states: h uses real_hidden_size,
            # c always uses hidden_size.
            hx = (
                context.new_zeros(*common_shape, self.real_hidden_size),
                context.new_zeros(*common_shape, self.bilstm.hidden_size),
            )
            return self.lstm(context, lens, hx=hx).to(dtype=dtype)

    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
        """Run the BiLSTM on (B, T, input_size) context with per-item lengths."""
        self.bilstm.flatten_parameters()
        # Tracing takes the autocast-free path with explicit initial states.
        if torch.jit.is_tracing():
            return self.lstm_nocast(context, lens)
        return self.lstm(context, lens)
|
|
|
|
|
|
|
|
class ConvLSTMLinear(nn.Module):
    """Masked 1D convolution stack followed by a BiLSTM and an optional
    linear projection.

    Args:
        in_dim: input channels of the first convolution.
        out_dim: if given, output size of a final linear projection; otherwise
            the BiLSTM output is returned directly.
        n_layers: number of convolution layers (0 disables convolutions).
        n_channels: width used by every convolution and the BiLSTM.
        kernel_size: convolution kernel width (odd, for symmetric padding).
        p_dropout: dropout applied after each ReLU (only built when n_layers > 0).
        use_partial_padding: forwarded to ConvNorm for mask-aware convolution.
        norm_fn: optional normalization constructor (e.g. MaskedInstanceNorm1d);
            when None, weight norm is applied to each convolution instead.
    """

    def __init__(
        self,
        in_dim=None,
        out_dim=None,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        use_partial_padding=False,
        norm_fn=None,
    ):
        super(ConvLSTMLinear, self).__init__()
        # n_channels // 2 per direction keeps the BiLSTM output width at n_channels.
        self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1)
        self.convolutions = nn.ModuleList()

        if n_layers > 0:
            self.dropout = nn.Dropout(p=p_dropout)

        # Fall back to weight norm when no explicit norm_fn is provided.
        use_weight_norm = norm_fn is None

        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain='relu',
                use_weight_norm=use_weight_norm,
                use_partial_padding=use_partial_padding,
                norm_fn=norm_fn,
            )
            if norm_fn is not None:
                print("Applying {} norm to {}".format(norm_fn, conv_layer))
            else:
                print("Applying weight norm to {}".format(conv_layer))
            self.convolutions.append(conv_layer)

        self.dense = None
        if out_dim is not None:
            self.dense = nn.Linear(n_channels, out_dim)

    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
        """Run masked convolutions, the BiLSTM, and the optional projection.

        Args:
            context: (B, in_dim, T) input features.
            lens: (B,) valid lengths per batch element.

        Returns:
            (B, out_dim, T) when ``out_dim`` is set, otherwise the raw BiLSTM
            output in (B, T, n_channels) layout.
        """
        mask = get_mask_from_lengths(lens, context)
        mask = mask.to(dtype=context.dtype).unsqueeze(1)
        for conv in self.convolutions:
            context = self.dropout(F.relu(conv(context, mask)))

        # BiLSTM is batch_first over time: convert (B, C, T) -> (B, T, C).
        context = self.bilstm(context.transpose(1, 2), lens=lens)
        if self.dense is not None:
            context = self.dense(context).permute(0, 2, 1)
        return context
|
|
|
|
|
|
|
|
def get_radtts_encoder(
    encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, norm_fn=MaskedInstanceNorm1d,
):
    """Build the RADTTS text encoder: masked convolutions followed by a BiLSTM."""
    encoder_config = dict(
        in_dim=encoder_embedding_dim,
        n_layers=encoder_n_convolutions,
        n_channels=encoder_embedding_dim,
        kernel_size=encoder_kernel_size,
        p_dropout=0.5,
        use_partial_padding=True,
        norm_fn=norm_fn,
    )
    return ConvLSTMLinear(**encoder_config)
|
|
|
|
|
|
|
|
class Invertible1x1ConvLUS(torch.nn.Module):
    """Invertible 1x1 convolution with an LU-decomposed weight.

    The weight ``W = P @ L @ U`` is kept as a fixed permutation ``P`` (buffer),
    a unit-lower-triangular ``L`` and an upper-triangular ``U`` with learned
    diagonal, so the log-determinant reduces to ``sum(log|diag(U)|)``.
    """

    def __init__(self, c):
        super(Invertible1x1ConvLUS, self).__init__()

        # Random orthogonal initialization via QR of a Gaussian matrix.
        W, _ = torch.linalg.qr(torch.FloatTensor(c, c).normal_())

        # Flip a column if needed so det(W) = +1 (initial log-determinant 0).
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        # NOTE(review): torch.lu is deprecated in newer PyTorch in favor of
        # torch.linalg.lu_factor / lu_factor_ex.
        p, lower, upper = torch.lu_unpack(*torch.lu(W))

        # Permutation is fixed (buffer), not trained.
        self.register_buffer('p', p)

        # Learn only the strictly-lower part of L; its unit diagonal is a buffer.
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer('lower_diag', lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))

    @amp.autocast(False)
    def forward(self, z, inverse=False):
        """Apply the 1x1 convolution (or its inverse) to z of shape (B, c, T).

        Returns z alone when ``inverse`` is True, otherwise (z, log_det_W).
        """
        # Reassemble W = P @ L @ U from the trained factors.
        U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        W = torch.mm(self.p, torch.mm(L, U))
        if inverse:
            if not hasattr(self, 'W_inverse'):
                # Invert in fp32 for stability; cached as a plain attribute, so
                # it is computed once and not tracked in the state dict.
                W_inverse = W.float().inverse().to(dtype=z.dtype)
                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse.to(dtype=z.dtype), bias=None, stride=1, padding=0)
            return z
        else:
            W = W[..., None]
            z = F.conv1d(z, W, bias=None, stride=1, padding=0)
            # det(P) = ±1 and det(L) = 1, so |det W| depends only on diag(U).
            log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
            return z, log_det_W
|
|
|
|
|
|
|
|
class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix. If inverse=True it does convolution with
    inverse
    """

    def __init__(self, c):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False)

        # Random orthogonal initialization. Fix: torch.linalg.qr replaces the
        # deprecated torch.qr (consistent with Invertible1x1ConvLUS above).
        W, _ = torch.linalg.qr(torch.FloatTensor(c, c).normal_())

        # Flip a column if needed so det(W) = +1 (initial log-determinant 0).
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W

    def forward(self, z, inverse=False):
        """Apply the 1x1 convolution to z of shape (B, c, T).

        Returns z alone when ``inverse`` is True, otherwise (z, log_det_W).
        """
        # (c, c) weight matrix of the 1x1 convolution.
        W = self.conv.weight.squeeze()

        if inverse:
            if not hasattr(self, 'W_inverse'):
                # Invert in fp32 for numerical stability; cached as a plain
                # attribute so it is computed only once (inference-time use).
                W_inverse = W.float().inverse().to(dtype=z.dtype)
                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            # Per-forward log-determinant of the current weight.
            log_det_W = torch.logdet(W).clone()
            z = self.conv(z)
            return z, log_det_W
|
|
|
|
|
|
|
|
class SimpleConvNet(torch.nn.Module):
    """Stack of (optionally dilated) 1D convolutions with ReLU, doubling
    channels up to ``max_channels``, followed by a 1x1 output projection.

    Used as a parameter predictor for the affine/spline coupling layers;
    ``zero_init`` zeroes the last layer so the predicted parameters start at 0.
    """

    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        final_out_channels,
        n_layers=2,
        kernel_size=5,
        with_dilation=True,
        max_channels=1024,
        zero_init=True,
        use_partial_padding=True,
    ):
        super(SimpleConvNet, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.n_layers = n_layers
        in_channels = n_mel_channels + n_context_dim
        # NOTE(review): stays -1 when n_layers == 0, which would break the
        # final Conv1d below — n_layers >= 1 appears to be assumed.
        out_channels = -1
        self.use_partial_padding = use_partial_padding
        for i in range(n_layers):
            dilation = 2 ** i if with_dilation else 1
            # "Same" padding for the dilated kernel.
            padding = int((kernel_size * dilation - dilation) / 2)
            out_channels = min(max_channels, in_channels * 2)
            self.layers.append(
                ConvNorm(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                    dilation=dilation,
                    bias=True,
                    w_init_gain='relu',
                    use_partial_padding=use_partial_padding,
                )
            )
            in_channels = out_channels

        self.last_layer = torch.nn.Conv1d(out_channels, final_out_channels, kernel_size=1)

        # Zero-init makes the surrounding coupling layer start as the identity.
        if zero_init:
            self.last_layer.weight.data *= 0
            self.last_layer.bias.data *= 0

    def forward(self, z_w_context, seq_lens: Optional[Tensor] = None):
        """Apply the conv stack to (B, n_mel_channels + n_context_dim, T).

        NOTE(review): ``seq_lens`` is typed Optional but is passed straight to
        get_mask_from_lengths — confirm the helper accepts None before calling
        without lengths.
        """
        mask = get_mask_from_lengths(seq_lens, z_w_context).unsqueeze(1).to(dtype=z_w_context.dtype)

        for i in range(self.n_layers):
            z_w_context = self.layers[i](z_w_context, mask)
            z_w_context = torch.relu(z_w_context)

        z_w_context = self.last_layer(z_w_context)
        return z_w_context
|
|
|
|
|
|
|
|
class WN(torch.nn.Module):
    """WaveNet-like parameter predictor, adapted from WaveGlow's WN module
    with variable-name modifications.

    Unlike WaveGlow's WN, the per-layer path here applies a plain
    non-linearity (relu or softplus) rather than the gated tanh-sigmoid unit,
    and skip contributions are summed into the output of a zero-initialized
    1x1 projection.
    """

    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        n_channels,
        kernel_size=5,
        affine_activation='softplus',
        use_partial_padding=True,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        assert n_channels % 2 == 0
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        # Project concatenated (z, context) channels down to n_channels.
        start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
        start = torch.nn.utils.weight_norm(start, name='weight')
        self.start = start
        self.softplus = torch.nn.Softplus()
        self.affine_activation = affine_activation
        self.use_partial_padding = use_partial_padding

        # Zero-init the output so the predicted affine parameters start at 0.
        end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        for i in range(n_layers):
            dilation = 2 ** i
            # "Same" padding for the dilated kernel.
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = ConvNorm(
                n_channels,
                n_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=padding,
                use_partial_padding=use_partial_padding,
                use_weight_norm=True,
            )
            self.in_layers.append(in_layer)
            res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
            res_skip_layer = nn.utils.weight_norm(res_skip_layer)
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, forward_input: Tuple[Tensor, Tensor], seq_lens: Optional[Tensor] = None):
        """Predict 2 * n_in_channels affine parameters from (z, context).

        Args:
            forward_input: tuple of z (B, n_in_channels, T) and context
                (B, n_context_dim, T), concatenated on the channel axis.
            seq_lens: (B,) valid lengths; only used to build the mask when
                ``use_partial_padding`` is enabled.

        Returns:
            (B, 2 * n_in_channels, T) tensor of predicted parameters.
        """
        z, context = forward_input
        z = torch.cat((z, context), 1)
        z = self.start(z)
        output = torch.zeros_like(z)
        mask = None
        if self.use_partial_padding:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
        non_linearity = torch.relu
        if self.affine_activation == 'softplus':
            non_linearity = self.softplus

        # Accumulate skip contributions from every dilated layer.
        for i in range(self.n_layers):
            z = non_linearity(self.in_layers[i](z, mask))
            res_skip_acts = non_linearity(self.res_skip_layers[i](z))
            output = output + res_skip_acts

        output = self.end(output)
        return output
|
|
|
|
|
|
|
|
|
|
|
class SplineTransformationLayerAR(torch.nn.Module):
    """Autoregressive spline transformation layer.

    All channels of ``z`` are transformed by monotonic piecewise splines whose
    parameters are predicted from ``context`` alone (via SimpleConvNet), so
    the transform is invertible given the context. Forward maps values from
    [left, right] to [bottom, top]; inverse maps back.
    """

    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        affine_model='simple_conv',
        kernel_size=1,
        scaling_fn='exp',
        affine_activation='softplus',
        n_channels=1024,
        n_bins=8,
        left=-6,
        right=6,
        bottom=-6,
        top=6,
        use_quadratic=False,
    ):
        super(SplineTransformationLayerAR, self).__init__()
        self.n_in_channels = n_in_channels
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            # The quadratic spline handles both directions through its own
            # ``inverse`` flag, so both attributes point at the same function.
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.n_in_channels * self.n_bins

        # Spline parameters depend on the context only (autoregressive
        # conditioning); zero_init makes the initial transform the identity.
        self.param_predictor = SimpleConvNet(
            n_context_dim,
            0,
            final_out_channels,
            n_layers,
            with_dilation=False,
            kernel_size=1,
            zero_init=True,
            use_partial_padding=False,
        )

    def normalize(self, z, inverse):
        # Map z into [0, 1] from the output interval (inverse) or the input
        # interval (forward).
        if inverse:
            z = (z - self.bottom) / (self.top - self.bottom)
        else:
            z = (z - self.left) / (self.right - self.left)

        return z

    def denormalize(self, z, inverse):
        # Map z from [0, 1] back to the input interval (inverse) or the
        # output interval (forward).
        if inverse:
            z = z * (self.right - self.left) + self.left
        else:
            z = z * (self.top - self.bottom) + self.bottom

        return z

    def forward(self, z, context, inverse=False):
        """Transform z (B, C, T) given context (B, n_context_dim, T).

        Returns z alone when ``inverse`` is True, otherwise (z, log_s).
        """
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        z = self.normalize(z, inverse)

        if z.min() < 0.0 or z.max() > 1.0:
            print('spline z scaled beyond [0, 1]', z.min(), z.max())

        # Fold time into the batch so splines act per (frame, channel).
        z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1)
        affine_params = self.param_predictor(context)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1)
        # Splines are evaluated in fp32 for numerical stability.
        with amp.autocast(enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_tformed, log_s = self.spline_fn(z_reshaped.float(), w.float(), v.float(), inverse=inverse)
            else:
                # NOTE(review): unlike SplineTransformationLayer, this linear
                # branch calls spline_fn even when inverse=True — confirm
                # inv_spline_fn isn't needed here.
                z_tformed, log_s = self.spline_fn(z_reshaped.float(), q_tilde.float())

        z = z_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
        z = self.denormalize(z, inverse)
        if inverse:
            return z

        log_s = log_s.reshape(b_s, t_s, -1)
        log_s = log_s.permute(0, 2, 1)
        # Jacobian correction for the normalize/denormalize interval rescaling.
        log_s = log_s + c_s * (np.log(self.top - self.bottom) - np.log(self.right - self.left))
        return z, log_s
|
|
|
|
|
|
|
|
class SplineTransformationLayer(torch.nn.Module):
    """Spline coupling layer: the second half of the mel channels is
    transformed by monotonic piecewise splines parameterized from the first
    half plus the context. Forward maps values from [left, right] to
    [bottom, top]; inverse maps back.
    """

    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        with_dilation=True,
        kernel_size=5,
        scaling_fn='exp',
        affine_activation='softplus',
        n_channels=1024,
        n_bins=8,
        left=-4,
        right=4,
        bottom=-4,
        top=4,
        use_quadratic=False,
    ):
        super(SplineTransformationLayer, self).__init__()
        self.n_mel_channels = n_mel_channels
        # Coupling split point: only the upper half of the channels is transformed.
        self.half_mel_channels = int(n_mel_channels / 2)
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic

        if self.use_quadratic:
            # The quadratic spline handles both directions through its own
            # ``inverse`` flag, so both attributes point at the same function.
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.half_mel_channels * self.n_bins

        self.param_predictor = SimpleConvNet(
            self.half_mel_channels,
            n_context_dim,
            final_out_channels,
            n_layers,
            with_dilation=with_dilation,
            kernel_size=kernel_size,
            zero_init=False,
        )

    def forward(self, z, context, inverse=False, seq_lens=None):
        """Transform z (B, n_mel_channels, T) given context.

        Returns z alone when ``inverse`` is True, otherwise (z, log_s) with
        log_s summed over the transformed half of the channels.
        """
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)

        # Coupling split: z_0 conditions the transform applied to z_1.
        n_half = self.half_mel_channels
        z_0, z_1 = z[:, :n_half], z[:, n_half:]

        # Normalize z_1 into [0, 1] from the relevant interval.
        if inverse:
            z_1 = (z_1 - self.bottom) / (self.top - self.bottom)
        else:
            z_1 = (z_1 - self.left) / (self.right - self.left)

        z_w_context = torch.cat((z_0, context), 1)
        affine_params = self.param_predictor(z_w_context, seq_lens)
        # Fold time into the batch: each (frame, channel) gets its own spline.
        z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins)

        # Splines are evaluated in fp32 for numerical stability.
        with autocast(enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_1_tformed, log_s = self.spline_fn(z_1_reshaped.float(), w.float(), v.float(), inverse=inverse)
                if not inverse:
                    log_s = torch.sum(log_s, 1)
            else:
                if inverse:
                    z_1_tformed, _dc = self.inv_spline_fn(z_1_reshaped.float(), q_tilde.float(), False)
                else:
                    z_1_tformed, log_s = self.spline_fn(z_1_reshaped.float(), q_tilde.float())

        z_1 = z_1_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)

        # Denormalize back to the target interval and reassemble the channels.
        if inverse:
            z_1 = z_1 * (self.right - self.left) + self.left
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:
            z_1 = z_1 * (self.top - self.bottom) + self.bottom
            z = torch.cat((z_0, z_1), dim=1)
            # Jacobian correction for the interval rescaling.
            log_s = log_s.reshape(b_s, t_s).unsqueeze(1) + n_half * (
                np.log(self.top - self.bottom) - np.log(self.right - self.left)
            )
            return z, log_s
|
|
|
|
|
|
|
|
class AffineTransformationLayer(torch.nn.Module):
    """Affine coupling layer: the second half of the mel channels is scaled
    and shifted using parameters predicted from the first half plus context.

    Args:
        n_mel_channels: total channels; split in half for the coupling.
        n_context_dim: channels of the conditioning context.
        n_layers: depth of the parameter-predictor network.
        affine_model: 'wavenet' (WN predictor) or 'simple_conv' (SimpleConvNet).
        scaling_fn: how the raw scale prediction is constrained ('translate',
            'exp', 'tanh', 'sigmoid'), or a per-channel list of those names.
        use_partial_padding: forwarded to the predictor for mask-aware convs.
    """

    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        affine_model='simple_conv',
        with_dilation=True,
        kernel_size=5,
        scaling_fn='exp',
        affine_activation='softplus',
        n_channels=1024,
        use_partial_padding=False,
    ):
        super(AffineTransformationLayer, self).__init__()
        if affine_model not in ("wavenet", "simple_conv"):
            raise Exception("{} affine model not supported".format(affine_model))
        if isinstance(scaling_fn, list):
            if not all([x in ("translate", "exp", "tanh", "sigmoid") for x in scaling_fn]):
                raise Exception("{} scaling fn not supported".format(scaling_fn))
        else:
            if scaling_fn not in ("translate", "exp", "tanh", "sigmoid"):
                raise Exception("{} scaling fn not supported".format(scaling_fn))

        self.affine_model = affine_model
        self.scaling_fn = scaling_fn
        if affine_model == 'wavenet':
            self.affine_param_predictor = WN(
                int(n_mel_channels / 2),
                n_context_dim,
                n_layers=n_layers,
                n_channels=n_channels,
                affine_activation=affine_activation,
                use_partial_padding=use_partial_padding,
            )
        elif affine_model == 'simple_conv':
            self.affine_param_predictor = SimpleConvNet(
                int(n_mel_channels / 2),
                n_context_dim,
                n_mel_channels,
                n_layers,
                with_dilation=with_dilation,
                kernel_size=kernel_size,
                use_partial_padding=use_partial_padding,
            )
        self.n_mel_channels = n_mel_channels

    def get_scaling_and_logs(self, scale_unconstrained):
        """Constrain the raw scale prediction and return (s, log_s).

        ``scale_unconstrained`` is (B, C, T); when ``scaling_fn`` is a list,
        entry i is applied to channel i.
        """
        if self.scaling_fn == 'translate':
            # Pure translation: scale is identically 1, log-scale 0.
            s = torch.exp(scale_unconstrained * 0)
            log_s = scale_unconstrained * 0
        elif self.scaling_fn == 'exp':
            s = torch.exp(scale_unconstrained)
            log_s = scale_unconstrained  # log(exp(x)) = x
        elif self.scaling_fn == 'tanh':
            # Keeps s strictly positive, in (1e-6, 2 + 1e-6).
            s = torch.tanh(scale_unconstrained) + 1 + 1e-6
            log_s = torch.log(s)
        elif self.scaling_fn == 'sigmoid':
            # +10 shifts the sigmoid so s is close to 1 when the raw prediction is 0.
            s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
            log_s = torch.log(s)
        elif isinstance(self.scaling_fn, list):
            s_list, log_s_list = [], []
            for i in range(scale_unconstrained.shape[1]):
                scaling_i = self.scaling_fn[i]
                if scaling_i == 'translate':
                    # Fix: was ``scale_unconstrained[:i]`` (a batch slice),
                    # producing a wrongly shaped s_i for 'translate' entries.
                    s_i = torch.exp(scale_unconstrained[:, i] * 0)
                    log_s_i = scale_unconstrained[:, i] * 0
                elif scaling_i == 'exp':
                    s_i = torch.exp(scale_unconstrained[:, i])
                    log_s_i = scale_unconstrained[:, i]
                elif scaling_i == 'tanh':
                    s_i = torch.tanh(scale_unconstrained[:, i]) + 1 + 1e-6
                    log_s_i = torch.log(s_i)
                elif scaling_i == 'sigmoid':
                    s_i = torch.sigmoid(scale_unconstrained[:, i]) + 1e-6
                    log_s_i = torch.log(s_i)
                s_list.append(s_i[:, None])
                log_s_list.append(log_s_i[:, None])
            s = torch.cat(s_list, dim=1)
            log_s = torch.cat(log_s_list, dim=1)
        return s, log_s

    def forward(self, z, context, inverse=False, seq_lens=None):
        """Apply (or invert) the affine coupling to z (B, n_mel_channels, T).

        Returns z alone when ``inverse`` is True, otherwise (z, log_s).
        """
        n_half = int(self.n_mel_channels / 2)
        z_0, z_1 = z[:, :n_half], z[:, n_half:]
        if self.affine_model == 'wavenet':
            affine_params = self.affine_param_predictor((z_0, context), seq_lens=seq_lens)
        elif self.affine_model == 'simple_conv':
            z_w_context = torch.cat((z_0, context), 1)
            affine_params = self.affine_param_predictor(z_w_context, seq_lens=seq_lens)

        # First half of the prediction is the raw scale, second half the shift.
        scale_unconstrained = affine_params[:, :n_half, :]
        b = affine_params[:, n_half:, :]
        s, log_s = self.get_scaling_and_logs(scale_unconstrained)

        if inverse:
            z_1 = (z_1 - b) / s
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:
            z_1 = s * z_1 + b
            z = torch.cat((z_0, z_1), dim=1)
            return z, log_s
|
|
|
|
|
|
|
|
class ConvAttention(torch.nn.Module):
    """Convolutional attention aligning mel frames (queries) with text (keys).

    Fix: the original __init__ first assigned ``self.query_proj`` an
    Invertible1x1ConvLUS and immediately overwrote it with the conv stack
    below; the dead module only wasted construction and is removed (the
    registered 'query_proj' submodule — and hence the state dict — is the
    Sequential in both versions).
    """

    def __init__(self, n_mel_channels=80, n_speaker_dim=128, n_text_channels=512, n_att_channels=80, temperature=1.0):
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        # Project text features to the shared attention dimension.
        self.key_proj = nn.Sequential(
            ConvNorm(n_text_channels, n_text_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
        )

        # Project mel features to the shared attention dimension.
        self.query_proj = nn.Sequential(
            ConvNorm(n_mel_channels, n_mel_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
        )

    def forward(self, queries, keys, query_lens, mask=None, key_lens=None, attn_prior=None):
        """Attention mechanism for radtts. Unlike in Flowtron, we have no
        restrictions such as causality etc, since we only need this during
        training.

        Args:
            queries (torch.tensor): B x C x T1 tensor (likely mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: lengths for sorting the queries in descending order
            mask (torch.tensor): uint8 binary mask for variable length entries
                (should be in the T2 domain)
            attn_prior (torch.tensor): B x T1 x T2 prior probabilities, combined
                with the scores in log space when provided
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask.
                Final dim T2 should sum to 1
            attn_logprob (torch.tensor): the pre-mask scores, returned for loss
                computation by the caller
        """
        temp = 0.0005
        keys_enc = self.key_proj(keys)  # B x n_att_channels x T2
        queries_enc = self.query_proj(queries)  # B x n_att_channels x T1

        # Negative scaled squared L2 distance between every (query, key) pair.
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2

        eps = 1e-8
        attn = -temp * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            # Combine with the prior in log space.
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + eps)

        attn_logprob = attn.clone()

        if mask is not None:
            # Exclude padded key positions from the softmax over T2.
            attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))

        attn = self.softmax(attn)  # softmax along T2
        return attn, attn_logprob
|
|
|