# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from nemo.collections.tts.modules.submodules import ConvNorm
from nemo.collections.tts.parts.utils.helpers import binarize_attention_parallel


class AlignmentEncoder(torch.nn.Module):
    """Module for aligning text and mel spectrogram.

    Text (keys) and mel (queries) features are projected into a shared space of
    ``n_att_channels`` channels; the squared L2 distance between every
    (mel frame, text token) pair is the basis for both the raw distance matrix
    (``get_dist``) and the soft attention (``forward``).
    """

    def __init__(
        self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=0.0005,
    ):
        """Build the key/query projection stacks.

        Args:
            n_mel_channels (int): number of channels of the query (mel) input.
            n_text_channels (int): number of channels of the key (text) input.
            n_att_channels (int): number of channels of the shared attention space.
            temperature (float): scale applied to the negated squared distance in ``forward``.
        """
        super().__init__()
        self.temperature = temperature
        # Both softmaxes normalize over the last (T2 / text) dim of a B x 1 x T1 x T2 tensor.
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_proj = nn.Sequential(
            ConvNorm(n_text_channels, n_text_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
        )

        self.query_proj = nn.Sequential(
            ConvNorm(n_mel_channels, n_mel_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
        )

    def _squared_distances(self, queries, keys):
        """Project queries/keys and return their pairwise squared L2 distance (B x 1 x T1 x T2)."""
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        queries_enc = self.query_proj(queries)  # B x n_attn_dims x T1
        diff = queries_enc[:, :, :, None] - keys_enc[:, :, None]  # B x n_attn_dims x T1 x T2
        return (diff ** 2).sum(1, keepdim=True)  # B x 1 x T1 x T2

    def get_dist(self, keys, queries, mask=None):
        """Calculation of distance matrix.

        Args:
            queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
            keys (torch.tensor): B x C2 x T2 tensor (text data).
            mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries and also can be used
                for ignoring unnecessary elements from keys in the resulting distance matrix (True = mask element, False = leave unchanged).
        Output:
            dist (torch.tensor): B x T1 x T2 tensor. Masked text positions are set to +inf.
        """
        dist = self._squared_distances(queries, keys)  # B x 1 x T1 x T2

        if mask is not None:
            # B x T2 x 1 -> B x 1 x 1 x T2 so that the mask broadcasts over the T1 dim.
            # Out-of-place fill (instead of mutating `.data`) keeps autograd aware of the masking.
            dist = dist.masked_fill(mask.permute(0, 2, 1).unsqueeze(2), float("inf"))

        return dist.squeeze(1)

    @staticmethod
    def get_durations(attn_soft, text_len, spect_len):
        """Calculation of durations.

        Args:
            attn_soft (torch.tensor): B x 1 x T1 x T2 tensor.
            text_len (torch.tensor): B tensor, lengths of text.
            spect_len (torch.tensor): B tensor, lengths of mel spectrogram.
        Output:
            durations (torch.tensor): B x T2 tensor, the hard alignment returned by
                ``binarize_attention_parallel`` summed over the T1 dim.
        Raises:
            AssertionError: if the durations of a sample do not sum to its ``spect_len``.
        """
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
        # Sum over T1 (B x 1 x T2), then drop the singleton channel dim.
        durations = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(durations.sum(dim=1), spect_len))
        return durations

    @staticmethod
    def get_mean_dist_by_durations(dist, durations, mask=None):
        """Select elements from the distance matrix for the given durations and mask and return mean distance.

        Args:
            dist (torch.tensor): B x T1 x T2 tensor.
            durations (torch.tensor): B x T2 tensor of integer durations. Dim T2 should sum to T1.
            mask (torch.tensor): B x T2 x 1 binary mask for variable length entries and also can be used
                for ignoring unnecessary elements in dist by T2 dim (True = mask element, False = leave unchanged).
        Output:
            mean_dist (torch.tensor): B tensor, one mean distance per batch element.
        Raises:
            AssertionError: if the durations of any sample do not sum to T1.
        """
        batch_size, t1_size, t2_size = dist.size()
        assert torch.all(torch.eq(durations.sum(dim=1), t1_size))

        if mask is not None:
            # `dist` is 3-D here, so the mask has to be B x 1 x T2. A 4-D mask would be broadcast
            # against `dist` and produce a B x B x T1 x T2 result that breaks the selection below.
            dist = dist.masked_fill(mask.permute(0, 2, 1), 0)

        # For every frame, find the index of the text token it is assigned to: token j of a sample is
        # repeated durations[b, j] times. Every row sums to T1 (asserted above), so the flat result
        # splits evenly into B rows of T1 entries.
        token_ids = torch.arange(t2_size, device=durations.device).repeat(batch_size)  # (B * T2,)
        frame_to_token = torch.repeat_interleave(token_ids, durations.reshape(-1))  # (B * T1,)
        frame_to_token = frame_to_token.reshape(batch_size, t1_size).to(dist.device)

        # Pick dist[b, t, frame_to_token[b, t]] for every frame and average over frames.
        selected = dist.gather(2, frame_to_token.unsqueeze(-1)).squeeze(-1)  # B x T1
        return selected.mean(dim=1)

    @staticmethod
    def get_mean_distance_for_word(l2_dists, durs, start_token, num_tokens):
        """Calculates the mean distance between text and audio embeddings given a range of text tokens.

        Args:
            l2_dists (torch.tensor): L2 distance matrix from Aligner inference. T1 x T2 tensor.
            durs (torch.tensor): List of durations corresponding to each text token. T2 tensor. Should sum to T1.
            start_token (int): Index of the starting token for the word of interest.
            num_tokens (int): Length (in tokens) of the word of interest.
        Output:
            mean_dist_for_word (torch.tensor): 0-dim tensor, mean embedding distance between the word indicated
                and its predicted audio frames (NaN when the word's tokens cover no frames).
        Raises:
            ZeroDivisionError: if ``num_tokens`` is 0.
        """
        # Need to calculate which audio frame we start on by summing all durations up to the start token's duration
        start_frame = int(torch.sum(durs[:start_token]))

        total_frames = 0
        dist_sum = 0

        # Loop through each text token of the word
        for token_ind in range(start_token, start_token + num_tokens):
            token_dur = int(durs[token_ind])
            # Recall that the L2 distance matrix is shape [spec_len, text_len]:
            # add up this token's column over the frames assigned to it.
            dist_sum = dist_sum + l2_dists[start_frame : start_frame + token_dur, token_ind].sum()

            # Update total frames so far & the starting frame for the next token
            total_frames += token_dur
            start_frame += token_dur

        return dist_sum / total_frames

    def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
        """Forward pass of the aligner encoder.

        Args:
            queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
            keys (torch.tensor): B x C2 x T2 tensor (text data).
            mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries (True = mask element, False = leave unchanged).
            attn_prior (torch.tensor): prior for attention matrix.
            conditioning (torch.tensor): B x T2 x 1 conditioning embedding
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
            attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask (taken before masking).
        """
        if conditioning is not None:
            # Transposed so that the T2 dim of the conditioning lines up with the last dim of keys.
            keys = keys + conditioning.transpose(1, 2)

        # Simplistic Gaussian Isotropic Attention
        attn = -self.temperature * self._squared_distances(queries, keys)  # B x 1 x T1 x T2

        if attn_prior is not None:
            # The small epsilon keeps the log finite where the prior is zero.
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

        # No copy needed: the masking and softmax below both return new tensors,
        # so this tensor is never modified afterwards.
        attn_logprob = attn

        if mask is not None:
            # B x T2 x 1 -> B x 1 x 1 x T2; masked text positions get zero weight after the softmax.
            attn = attn.masked_fill(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))

        attn = self.softmax(attn)  # softmax along T2
        return attn, attn_logprob