# NOTE: hosting-page scrape artifact ("Spaces: Running on CPU Upgrade")
# removed so the file parses as Python.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
class ConvNorm(torch.nn.Module):
    """1-D convolution with Xavier-uniform weight initialization.

    Thin wrapper around ``torch.nn.Conv1d`` that, when ``padding`` is not
    supplied, chooses the "same"-length padding for an odd kernel size and
    scales the Xavier init by the gain associated with ``w_init_gain``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            # "same" output length is only well-defined for odd kernels
            assert (kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=gain)

    def forward(self, signal):
        """Apply the convolution to ``signal`` (B x C_in x T)."""
        return self.conv(signal)
class Invertible1x1ConvLUS(torch.nn.Module):
    """Invertible 1x1 convolution with an LU-decomposed weight matrix.

    The c x c mixing matrix W is kept as P @ L @ U, where P is a fixed
    permutation (buffer) and the strictly-lower part of L plus U's upper
    triangle/diagonal are learned.  With this parameterization the
    log-determinant reduces to ``sum(log|diag(U)|)``, so the flow's
    Jacobian term is cheap to compute.

    Args:
        c (int): number of channels to mix.
    """

    def __init__(self, c):
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W, _ = torch.linalg.qr(torch.randn(c, c))
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1*W[:, 0]
        # Fix: torch.lu() is deprecated (removed in recent PyTorch);
        # torch.linalg.lu_factor is the supported replacement and is
        # consistent with the torch.linalg.qr call above.
        LU, pivots = torch.linalg.lu_factor(W)
        p, lower, upper = torch.lu_unpack(LU, pivots)
        self.register_buffer('p', p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer('lower_diag', lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))

    def forward(self, z, reverse=False):
        """Apply (or invert) the 1x1 convolution.

        Args:
            z (torch.Tensor): B x c x T activations.
            reverse (bool): if True, apply W^-1 (inference direction).

        Returns:
            z if ``reverse``; otherwise ``(z, log_det_W)`` where
            ``log_det_W`` is the scalar log|det W|.
        """
        U = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        L = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        W = torch.mm(self.p, torch.mm(L, U))
        if reverse:
            if not hasattr(self, 'W_inverse'):
                # Invert once in float32 and cache the result; the cache is
                # only valid while the parameters are frozen (inference).
                W_inverse = W.float().inverse()
                # Fix: compare dtypes instead of the string
                # 'torch.cuda.HalfTensor', which missed CPU half tensors.
                if z.dtype == torch.half:
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            return z
        else:
            W = W[..., None]
            z = F.conv1d(z, W, bias=None, stride=1, padding=0)
            # |det W| = |det L| * |det U| = prod|diag(U)| since diag(L)=1
            log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
            return z, log_det_W
class ConvAttention(torch.nn.Module):
    """Convolutional attention that aligns mel frames to text tokens.

    Queries (mel frames) and keys (text embeddings) are projected into a
    shared attention space; each (frame, token) pair is scored with a
    negative squared L2 distance (an isotropic-Gaussian log-likelihood),
    optionally combined with a prior and masked, then softmax-normalized
    over the text axis.

    Args:
        n_mel_channels (int): channels of the mel query input.
        n_speaker_dim (int): unused in this class; kept for interface
            compatibility.
        n_text_channels (int): channels of the text key input.
        n_att_channels (int): dimensionality of the shared attention space.
        temperature (float): stored scaling knob (not applied in forward).
        n_mel_convs (int): unused in this class; kept for interface
            compatibility.
        align_query_enc_type (str): 'inv_conv' or '3xconv' query encoder.
        use_query_proj (bool): if False, raw queries are used directly.
    """

    def __init__(self, n_mel_channels=80, n_speaker_dim=128,
                 n_text_channels=512, n_att_channels=80, temperature=1.0,
                 n_mel_convs=2, align_query_enc_type='3xconv',
                 use_query_proj=True):
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.att_scaling_factor = np.sqrt(n_att_channels)
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
        self.align_query_enc_type = align_query_enc_type
        self.use_query_proj = bool(use_query_proj)

        # Text-side encoder: widen, nonlinearity, project to attn space.
        self.key_proj = nn.Sequential(
            ConvNorm(n_text_channels,
                     n_text_channels * 2,
                     kernel_size=3,
                     bias=True,
                     w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2,
                     n_att_channels,
                     kernel_size=1,
                     bias=True))

        # Mel-side encoder.  Fix: the original unconditionally built an
        # Invertible1x1ConvLUS here and then always overwrote it in this
        # branch (and assigned align_query_enc_type twice); the redundant
        # statements were removed.
        if align_query_enc_type == "inv_conv":
            self.query_proj = Invertible1x1ConvLUS(n_mel_channels)
        elif align_query_enc_type == "3xconv":
            self.query_proj = nn.Sequential(
                ConvNorm(n_mel_channels,
                         n_mel_channels * 2,
                         kernel_size=3,
                         bias=True,
                         w_init_gain='relu'),
                torch.nn.ReLU(),
                ConvNorm(n_mel_channels * 2,
                         n_mel_channels,
                         kernel_size=1,
                         bias=True),
                torch.nn.ReLU(),
                ConvNorm(n_mel_channels,
                         n_att_channels,
                         kernel_size=1,
                         bias=True))
        else:
            raise ValueError("Unknown query encoder type specified")

    def run_padded_sequence(self, sorted_idx, unsort_idx, lens, padded_data,
                            recurrent_model):
        """Sorts input data by provided ordering (and un-ordering) and runs the
        packed data through the recurrent model

        Args:
            sorted_idx (torch.tensor): 1D sorting index
            unsort_idx (torch.tensor): 1D unsorting index (inverse of sorted_idx)
            lens: lengths of input data (sorted in descending order)
            padded_data (torch.tensor): input sequences (padded)
            recurrent_model (nn.Module): recurrent model to run data through
        Returns:
            hidden_vectors (torch.tensor): outputs of the RNN, in the original,
            unsorted, ordering
        """
        # sort the data by decreasing length using provided index
        # we assume batch index is in dim=1
        padded_data = padded_data[:, sorted_idx]
        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
        hidden_vectors = recurrent_model(padded_data)[0]
        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
        # unsort the results at dim=1 and return
        hidden_vectors = hidden_vectors[:, unsort_idx]
        return hidden_vectors

    def encode_query(self, query, query_lens):
        """Run queries through a recurrent encoder in length-sorted order.

        NOTE(review): this relies on ``self.query_lstm``, which is never
        defined in this class — confirm it is attached externally before
        calling this method.
        """
        query = query.permute(2, 0, 1)  # seq_len, batch, feature dim
        lens, ids = torch.sort(query_lens, descending=True)
        # Build the inverse permutation so outputs can be unsorted.
        original_ids = [0] * lens.size(0)
        for i in range(len(ids)):
            original_ids[ids[i]] = i
        query_encoded = self.run_padded_sequence(ids, original_ids, lens,
                                                 query, self.query_lstm)
        query_encoded = query_encoded.permute(1, 2, 0)
        return query_encoded

    def forward(self, queries, keys, query_lens, mask=None, key_lens=None,
                keys_encoded=None, attn_prior=None):
        """Attention mechanism for flowtron parallel
        Unlike in Flowtron, we have no restrictions such as causality etc,
        since we only need this during training.

        Args:
            queries (torch.tensor): B x C x T1 tensor
                (probably going to be mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: lengths for sorting the queries in descending order
            mask (torch.tensor): uint8 binary mask for variable length entries
                (should be in the T2 domain)
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask.
                Final dim T2 should sum to 1
        """
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        # Beware can only do this since query_dim = attn_dim = n_mel_channels
        if self.use_query_proj:
            if self.align_query_enc_type == "inv_conv":
                queries_enc, log_det_W = self.query_proj(queries)
            elif self.align_query_enc_type == "3xconv":
                queries_enc = self.query_proj(queries)
                log_det_W = 0.0
            else:
                queries_enc, log_det_W = self.query_proj(queries)
        else:
            queries_enc, log_det_W = queries, 0.0

        # different ways of computing attn,
        # one is isotopic gaussians (per phoneme)
        # Simplistic Gaussian Isotopic Attention
        # B x n_attn_dims x T1 x T2
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
        # compute log likelihood from a gaussian
        attn = -0.0005 * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            # epsilon keeps log() finite where the prior is exactly zero
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]+1e-8)

        attn_logprob = attn.clone()

        if mask is not None:
            attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2),
                                   -float("inf"))

        attn = self.softmax(attn)  # Softmax along T2
        return attn, attn_logprob