|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple |
|
|
|
|
|
import torch |
|
|
from torch import Tensor |
|
|
from torch.autograd import Variable |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
def masked_instance_norm( |
|
|
input: Tensor, mask: Tensor, weight: Tensor, bias: Tensor, momentum: float, eps: float = 1e-5, |
|
|
) -> Tensor: |
|
|
r"""Applies Masked Instance Normalization for each channel in each data sample in a batch. |
|
|
|
|
|
See :class:`~MaskedInstanceNorm1d` for details. |
|
|
""" |
|
|
lengths = mask.sum((-1,)) |
|
|
mean = (input * mask).sum((-1,)) / lengths |
|
|
var = (((input - mean[(..., None)]) * mask) ** 2).sum((-1,)) / lengths |
|
|
out = (input - mean[(..., None)]) / torch.sqrt(var[(..., None)] + eps) |
|
|
out = out * weight[None, :][(..., None)] + bias[None, :][(..., None)] |
|
|
|
|
|
return out |
|
|
|
|
|
|
|
|
class MaskedInstanceNorm1d(torch.nn.InstanceNorm1d): |
|
|
r"""Applies Instance Normalization over a masked 3D input |
|
|
(a mini-batch of 1D inputs with additional channel dimension).. |
|
|
|
|
|
See documentation of :class:`~torch.nn.InstanceNorm1d` for details. |
|
|
|
|
|
Shape: |
|
|
- Input: :math:`(N, C, L)` |
|
|
- Mask: :math:`(N, 1, L)` |
|
|
- Output: :math:`(N, C, L)` (same shape as input) |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
num_features: int, |
|
|
eps: float = 1e-5, |
|
|
momentum: float = 0.1, |
|
|
affine: bool = False, |
|
|
track_running_stats: bool = False, |
|
|
) -> None: |
|
|
super(MaskedInstanceNorm1d, self).__init__(num_features, eps, momentum, affine, track_running_stats) |
|
|
|
|
|
def forward(self, input: Tensor, mask: Tensor) -> Tensor: |
|
|
return masked_instance_norm(input, mask, self.weight, self.bias, self.momentum, self.eps,) |
|
|
|
|
|
|
|
|
class PartialConv1d(torch.nn.Conv1d): |
|
|
""" |
|
|
Zero padding creates a unique identifier for where the edge of the data is, such that the model can almost always identify |
|
|
exactly where it is relative to either edge given a sufficient receptive field. Partial padding goes to some lengths to remove |
|
|
this affect. |
|
|
""" |
|
|
|
|
|
__constants__ = ['slide_winsize'] |
|
|
slide_winsize: float |
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
super(PartialConv1d, self).__init__(*args, **kwargs) |
|
|
weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) |
|
|
self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) |
|
|
self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] |
|
|
|
|
|
def forward(self, input, mask_in): |
|
|
if mask_in is None: |
|
|
mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) |
|
|
else: |
|
|
mask = mask_in |
|
|
input = torch.mul(input, mask) |
|
|
with torch.no_grad(): |
|
|
update_mask = F.conv1d( |
|
|
mask, |
|
|
self.weight_maskUpdater, |
|
|
bias=None, |
|
|
stride=self.stride, |
|
|
padding=self.padding, |
|
|
dilation=self.dilation, |
|
|
groups=1, |
|
|
) |
|
|
update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) |
|
|
mask_ratio = self.slide_winsize / update_mask_filled |
|
|
update_mask = torch.clamp(update_mask, 0, 1) |
|
|
mask_ratio = torch.mul(mask_ratio, update_mask) |
|
|
|
|
|
raw_out = self._conv_forward(input, self.weight, self.bias) |
|
|
|
|
|
if self.bias is not None: |
|
|
bias_view = self.bias.view(1, self.out_channels, 1) |
|
|
output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view |
|
|
output = torch.mul(output, update_mask) |
|
|
else: |
|
|
output = torch.mul(raw_out, mask_ratio) |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
class LinearNorm(torch.nn.Module): |
|
|
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): |
|
|
super().__init__() |
|
|
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) |
|
|
|
|
|
torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.linear_layer(x) |
|
|
|
|
|
|
|
|
class ConvNorm(torch.nn.Module): |
|
|
__constants__ = ['use_partial_padding'] |
|
|
use_partial_padding: bool |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
in_channels, |
|
|
out_channels, |
|
|
kernel_size=1, |
|
|
stride=1, |
|
|
padding=None, |
|
|
dilation=1, |
|
|
bias=True, |
|
|
w_init_gain='linear', |
|
|
use_partial_padding=False, |
|
|
use_weight_norm=False, |
|
|
norm_fn=None, |
|
|
): |
|
|
super(ConvNorm, self).__init__() |
|
|
if padding is None: |
|
|
assert kernel_size % 2 == 1 |
|
|
padding = int(dilation * (kernel_size - 1) / 2) |
|
|
self.use_partial_padding = use_partial_padding |
|
|
conv_fn = torch.nn.Conv1d |
|
|
if use_partial_padding: |
|
|
conv_fn = PartialConv1d |
|
|
self.conv = conv_fn( |
|
|
in_channels, |
|
|
out_channels, |
|
|
kernel_size=kernel_size, |
|
|
stride=stride, |
|
|
padding=padding, |
|
|
dilation=dilation, |
|
|
bias=bias, |
|
|
) |
|
|
torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) |
|
|
if use_weight_norm: |
|
|
self.conv = torch.nn.utils.weight_norm(self.conv) |
|
|
if norm_fn is not None: |
|
|
self.norm = norm_fn(out_channels, affine=True) |
|
|
else: |
|
|
self.norm = None |
|
|
|
|
|
def forward(self, signal, mask=None): |
|
|
if self.use_partial_padding: |
|
|
ret = self.conv(signal, mask) |
|
|
if self.norm is not None: |
|
|
ret = self.norm(ret, mask) |
|
|
else: |
|
|
if mask is not None: |
|
|
signal = signal.mul(mask) |
|
|
ret = self.conv(signal) |
|
|
if self.norm is not None: |
|
|
ret = self.norm(ret) |
|
|
return ret |
|
|
|
|
|
|
|
|
class LocationLayer(torch.nn.Module): |
|
|
def __init__(self, attention_n_filters, attention_kernel_size, attention_dim): |
|
|
super().__init__() |
|
|
padding = int((attention_kernel_size - 1) / 2) |
|
|
self.location_conv = ConvNorm( |
|
|
2, |
|
|
attention_n_filters, |
|
|
kernel_size=attention_kernel_size, |
|
|
padding=padding, |
|
|
bias=False, |
|
|
stride=1, |
|
|
dilation=1, |
|
|
) |
|
|
self.location_dense = LinearNorm(attention_n_filters, attention_dim, bias=False, w_init_gain='tanh') |
|
|
|
|
|
def forward(self, attention_weights_cat): |
|
|
processed_attention = self.location_conv(attention_weights_cat) |
|
|
processed_attention = processed_attention.transpose(1, 2) |
|
|
processed_attention = self.location_dense(processed_attention) |
|
|
return processed_attention |
|
|
|
|
|
|
|
|
class Attention(torch.nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
attention_rnn_dim, |
|
|
embedding_dim, |
|
|
attention_dim, |
|
|
attention_location_n_filters, |
|
|
attention_location_kernel_size, |
|
|
): |
|
|
super().__init__() |
|
|
self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, bias=False, w_init_gain='tanh') |
|
|
self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, w_init_gain='tanh') |
|
|
self.v = LinearNorm(attention_dim, 1, bias=False) |
|
|
self.location_layer = LocationLayer( |
|
|
attention_location_n_filters, attention_location_kernel_size, attention_dim, |
|
|
) |
|
|
self.score_mask_value = -float("inf") |
|
|
|
|
|
def get_alignment_energies(self, query, processed_memory, attention_weights_cat): |
|
|
""" |
|
|
PARAMS |
|
|
------ |
|
|
query: decoder output (batch, n_mel_channels * n_frames_per_step) |
|
|
processed_memory: processed encoder outputs (B, T_in, attention_dim) |
|
|
attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) |
|
|
RETURNS |
|
|
------- |
|
|
alignment (batch, max_time) |
|
|
""" |
|
|
|
|
|
processed_query = self.query_layer(query.unsqueeze(1)) |
|
|
processed_attention_weights = self.location_layer(attention_weights_cat) |
|
|
energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_memory)) |
|
|
|
|
|
energies = energies.squeeze(-1) |
|
|
return energies |
|
|
|
|
|
def forward( |
|
|
self, attention_hidden_state, memory, processed_memory, attention_weights_cat, mask, |
|
|
): |
|
|
""" |
|
|
PARAMS |
|
|
------ |
|
|
attention_hidden_state: attention rnn last output |
|
|
memory: encoder outputs |
|
|
processed_memory: processed encoder outputs |
|
|
attention_weights_cat: previous and cummulative attention weights |
|
|
mask: binary mask for padded data |
|
|
""" |
|
|
alignment = self.get_alignment_energies(attention_hidden_state, processed_memory, attention_weights_cat) |
|
|
|
|
|
if mask is not None: |
|
|
alignment.data.masked_fill_(mask, self.score_mask_value) |
|
|
|
|
|
attention_weights = F.softmax(alignment, dim=1) |
|
|
attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) |
|
|
attention_context = attention_context.squeeze(1) |
|
|
|
|
|
return attention_context, attention_weights |
|
|
|
|
|
|
|
|
class Prenet(torch.nn.Module): |
|
|
def __init__(self, in_dim, sizes, p_dropout=0.5): |
|
|
super().__init__() |
|
|
in_sizes = [in_dim] + sizes[:-1] |
|
|
self.p_dropout = p_dropout |
|
|
self.layers = torch.nn.ModuleList( |
|
|
[LinearNorm(in_size, out_size, bias=False) for (in_size, out_size) in zip(in_sizes, sizes)] |
|
|
) |
|
|
|
|
|
def forward(self, x, inference=False): |
|
|
if inference: |
|
|
for linear in self.layers: |
|
|
x = F.relu(linear(x)) |
|
|
x0 = x[0].unsqueeze(0) |
|
|
mask = torch.autograd.Variable(torch.bernoulli(x0.data.new(x0.data.size()).fill_(1 - self.p_dropout))) |
|
|
mask = mask.expand(x.size(0), x.size(1)) |
|
|
x = x * mask * 1 / (1 - self.p_dropout) |
|
|
else: |
|
|
for linear in self.layers: |
|
|
x = F.dropout(F.relu(linear(x)), p=self.p_dropout, training=True) |
|
|
return x |
|
|
|
|
|
|
|
|
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels_int): |
|
|
in_act = input_a + input_b |
|
|
t_act = torch.tanh(in_act[:, :n_channels_int, :]) |
|
|
s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) |
|
|
acts = t_act * s_act |
|
|
return acts |
|
|
|
|
|
|
|
|
class Invertible1x1Conv(torch.nn.Module): |
|
|
""" |
|
|
The layer outputs both the convolution, and the log determinant |
|
|
of its weight matrix. If reverse=True it does convolution with |
|
|
inverse |
|
|
""" |
|
|
|
|
|
def __init__(self, c): |
|
|
super().__init__() |
|
|
self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, bias=False) |
|
|
|
|
|
|
|
|
W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0] |
|
|
|
|
|
|
|
|
if torch.det(W) < 0: |
|
|
W[:, 0] = -1 * W[:, 0] |
|
|
W = W.view(c, c, 1) |
|
|
self.conv.weight.data = W |
|
|
self.inv_conv = None |
|
|
|
|
|
def forward(self, z, reverse: bool = False): |
|
|
if reverse: |
|
|
if self.inv_conv is None: |
|
|
|
|
|
|
|
|
|
|
|
self.inv_conv = torch.nn.Conv1d( |
|
|
self.conv.in_channels, self.conv.out_channels, kernel_size=1, stride=1, padding=0, bias=False |
|
|
) |
|
|
W_inverse = self.conv.weight.squeeze().data.float().inverse() |
|
|
W_inverse = Variable(W_inverse[..., None]) |
|
|
self.inv_conv.weight.data = W_inverse |
|
|
self.inv_conv.to(device=self.conv.weight.device, dtype=self.conv.weight.dtype) |
|
|
return self.inv_conv(z) |
|
|
else: |
|
|
|
|
|
|
|
|
W = self.conv.weight.squeeze() |
|
|
batch_size, group_size, n_of_groups = z.size() |
|
|
log_det_W = batch_size * n_of_groups * torch.logdet(W.float()) |
|
|
z = self.conv(z) |
|
|
return ( |
|
|
z, |
|
|
log_det_W, |
|
|
) |
|
|
|
|
|
|
|
|
class WaveNet(torch.nn.Module): |
|
|
""" |
|
|
This is the WaveNet like layer for the affine coupling. The primary |
|
|
difference from WaveNet is the convolutions need not be causal. There is |
|
|
also no dilation size reset. The dilation only doubles on each layer |
|
|
""" |
|
|
|
|
|
def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size): |
|
|
super().__init__() |
|
|
assert kernel_size % 2 == 1 |
|
|
assert n_channels % 2 == 0 |
|
|
self.n_layers = n_layers |
|
|
self.n_channels = n_channels |
|
|
self.in_layers = torch.nn.ModuleList() |
|
|
self.res_skip_layers = torch.nn.ModuleList() |
|
|
|
|
|
start = torch.nn.Conv1d(n_in_channels, n_channels, 1) |
|
|
start = torch.nn.utils.weight_norm(start, name='weight') |
|
|
self.start = start |
|
|
|
|
|
|
|
|
|
|
|
end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1) |
|
|
end.weight.data.zero_() |
|
|
end.bias.data.zero_() |
|
|
self.end = end |
|
|
|
|
|
cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1) |
|
|
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') |
|
|
|
|
|
for i in range(n_layers): |
|
|
dilation = 2 ** i |
|
|
padding = int((kernel_size * dilation - dilation) / 2) |
|
|
in_layer = torch.nn.Conv1d(n_channels, 2 * n_channels, kernel_size, dilation=dilation, padding=padding,) |
|
|
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') |
|
|
self.in_layers.append(in_layer) |
|
|
|
|
|
|
|
|
if i < n_layers - 1: |
|
|
res_skip_channels = 2 * n_channels |
|
|
else: |
|
|
res_skip_channels = n_channels |
|
|
res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) |
|
|
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') |
|
|
self.res_skip_layers.append(res_skip_layer) |
|
|
|
|
|
def forward(self, forward_input: Tuple[torch.Tensor, torch.Tensor]): |
|
|
audio, spect = forward_input[0], forward_input[1] |
|
|
audio = self.start(audio) |
|
|
output = torch.zeros_like(audio) |
|
|
|
|
|
spect = self.cond_layer(spect) |
|
|
|
|
|
for i in range(self.n_layers): |
|
|
spect_offset = i * 2 * self.n_channels |
|
|
acts = fused_add_tanh_sigmoid_multiply( |
|
|
self.in_layers[i](audio), |
|
|
spect[:, spect_offset : spect_offset + 2 * self.n_channels, :], |
|
|
self.n_channels, |
|
|
) |
|
|
|
|
|
res_skip_acts = self.res_skip_layers[i](acts) |
|
|
if i < self.n_layers - 1: |
|
|
audio = audio + res_skip_acts[:, : self.n_channels, :] |
|
|
output = output + res_skip_acts[:, self.n_channels :, :] |
|
|
else: |
|
|
output = output + res_skip_acts |
|
|
|
|
|
return self.end(output) |
|
|
|