|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from nemo.collections.tts.modules.submodules import ConvNorm |
|
|
from nemo.collections.tts.parts.utils.helpers import binarize_attention_parallel |
|
|
|
|
|
|
|
|
class AlignmentEncoder(torch.nn.Module):
    """Module for aligning text and mel spectrogram.

    Text (keys) and mel (queries) are each passed through their own ``ConvNorm``/``ReLU`` stack that ends in
    ``n_att_channels`` channels. Alignment scores are derived from the squared L2 distance between every
    projected query frame and every projected key token.
    """

    def __init__(
        self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=0.0005,
    ):
        """Build the key (text) and query (mel) projection stacks.

        Args:
            n_mel_channels (int): number of channels of the query (mel) input.
            n_text_channels (int): number of channels of the key (text) input.
            n_att_channels (int): number of output channels of both projections.
            temperature (float): factor the summed squared distance is multiplied by (negated) in ``forward``.
        """
        super().__init__()
        self.temperature = temperature
        # Both are applied over the last dimension (T2) of the B x 1 x T1 x T2 score tensor.
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_proj = nn.Sequential(
            ConvNorm(n_text_channels, n_text_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
        )

        self.query_proj = nn.Sequential(
            ConvNorm(n_mel_channels, n_mel_channels * 2, kernel_size=3, bias=True, w_init_gain='relu'),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
        )

    def _squared_l2(self, queries, keys):
        """Project queries and keys and return their pairwise squared L2 distances, summed over channels.

        The result keeps the reduced channel dimension, i.e. it is B x 1 x T1 x T2 (assuming the projections
        preserve the time dimensions of their inputs).
        """
        keys_enc = self.key_proj(keys)
        queries_enc = self.query_proj(queries)
        # Broadcast (B, C, T1, 1) against (B, C, 1, T2) to compare every query frame with every key token.
        diff = queries_enc[:, :, :, None] - keys_enc[:, :, None]
        return (diff ** 2).sum(1, keepdim=True)

    def get_dist(self, keys, queries, mask=None):
        """Calculation of distance matrix.

        Args:
            queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
            keys (torch.tensor): B x C2 x T2 tensor (text data).
            mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries and also can be used
                for ignoring unnecessary elements from keys in the resulting distance matrix
                (True = mask element, False = leave unchanged).
        Output:
            dist (torch.tensor): B x T1 x T2 tensor. Masked key positions are set to ``inf``.
        """
        dist = self._squared_l2(queries, keys)

        if mask is not None:
            # (B, T2, 1) -> (B, 1, 1, T2) so it broadcasts over the frame dimension. Out-of-place so that
            # autograd sees the masking (no ``.data`` mutation).
            dist = dist.masked_fill(mask.permute(0, 2, 1).unsqueeze(2), float("inf"))

        return dist.squeeze(1)

    @staticmethod
    def get_durations(attn_soft, text_len, spect_len):
        """Calculation of durations.

        Args:
            attn_soft (torch.tensor): B x 1 x T1 x T2 tensor.
            text_len (torch.tensor): B tensor, lengths of text.
            spect_len (torch.tensor): B tensor, lengths of mel spectrogram.
        Output:
            durations (torch.tensor): the binarized attention summed over dim 2 (T1); B x T2 assuming
                ``binarize_attention_parallel`` keeps the shape of ``attn_soft``. Each row sums to the
                corresponding ``spect_len`` entry.
        """
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
        durations = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(durations.sum(dim=1), spect_len)), "durations must sum to spect_len"
        return durations

    @staticmethod
    def get_mean_dist_by_durations(dist, durations, mask=None):
        """Select elements from the distance matrix for the given durations and mask and return mean distance.

        Frame ``f`` of an item is paired with the token that the durations assign it to (token ``j`` owns
        ``durations[j]`` consecutive frames); the selected distances are averaged over all T1 frames.

        Args:
            dist (torch.tensor): B x T1 x T2 tensor.
            durations (torch.tensor): B x T2 tensor. Dim T2 should sum to T1.
            mask (torch.tensor): B x T2 x 1 binary mask for variable length entries and also can be used
                for ignoring unnecessary elements in dist by T2 dim (True = mask element, False = leave unchanged).
                Masked tokens contribute a distance of 0.
        Output:
            mean_dist (torch.tensor): B tensor (detached from the autograd graph).
        """
        batch_size, t1_size, t2_size = dist.size()
        assert torch.all(torch.eq(durations.sum(dim=1), t1_size)), "durations must sum to T1 for every item"

        if mask is not None:
            # (B, T2, 1) -> (B, 1, T2). ``dist`` is 3-D here, so no extra unsqueeze: a 4-D mask would
            # broadcast the result to B x B x T1 x T2 and break the indexing below.
            dist = dist.masked_fill(mask.permute(0, 2, 1), 0)

        if batch_size == 0:
            return dist.new_empty(0)

        # Index tensors are loop-invariant and must live on the same device as ``dist``.
        frame_idx = torch.arange(t1_size, device=dist.device)
        token_ids = torch.arange(t2_size, device=dist.device)
        repeats = durations.to(device=dist.device, dtype=torch.long)

        mean_dist_by_durations = [
            dist[item, frame_idx, torch.repeat_interleave(token_ids, repeats=repeats[item])].mean()
            for item in range(batch_size)
        ]

        # The result has always been a plain value tensor without gradient history; keep it that way.
        return torch.stack(mean_dist_by_durations).detach()

    @staticmethod
    def get_mean_distance_for_word(l2_dists, durs, start_token, num_tokens):
        """Calculates the mean distance between text and audio embeddings given a range of text tokens.

        Args:
            l2_dists (torch.tensor): L2 distance matrix from Aligner inference. T1 x T2 tensor.
            durs (torch.tensor): Durations corresponding to each text token. T2 tensor. Should sum to T1.
            start_token (int): Index of the starting token for the word of interest.
            num_tokens (int): Length (in tokens) of the word of interest.
        Output:
            mean_dist_for_word (torch.tensor): Scalar tensor, mean embedding distance between the word indicated
                and its predicted audio frames.
        """
        # Plain Python ints keep the frame bookkeeping cheap and avoid mutating tensors in place.
        frames_per_token = durs.long().tolist()
        start_frame = sum(frames_per_token[:start_token])

        total_frames = 0
        dist_sum = 0

        for token_ind in range(start_token, start_token + num_tokens):
            n_frames = frames_per_token[token_ind]
            # All frames owned by this token are contiguous, so one slice replaces the per-frame loop.
            dist_sum = dist_sum + l2_dists[start_frame : start_frame + n_frames, token_ind].sum()
            total_frames += n_frames
            start_frame += n_frames

        return dist_sum / total_frames

    def forward(self, queries, keys, mask=None, attn_prior=None, conditioning=None):
        """Forward pass of the aligner encoder.

        Args:
            queries (torch.tensor): B x C x T1 tensor (probably going to be mel data).
            keys (torch.tensor): B x C2 x T2 tensor (text data).
            mask (torch.tensor): B x T2 x 1 tensor, binary mask for variable length entries
                (True = mask element, False = leave unchanged).
            attn_prior (torch.tensor): prior for attention matrix.
            conditioning (torch.tensor): B x T2 x 1 conditioning embedding; it is transposed to B x 1 x T2 and
                added to the keys.
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
            attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask (computed before masking).
        """
        if conditioning is not None:
            keys = keys + conditioning.transpose(1, 2)

        # Closer projections -> larger (less negative) score.
        attn = -self.temperature * self._squared_l2(queries, keys)

        if attn_prior is not None:
            # The epsilon keeps log() finite where the prior is exactly zero.
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

        # Masking below is out of place, so the unmasked scores can be returned without a copy.
        attn_logprob = attn

        if mask is not None:
            attn = attn.masked_fill(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))

        attn = self.softmax(attn)
        return attn, attn_logprob
|
|
|