import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from numba import jit
from scipy.stats import betabinom


class AlignmentModule(nn.Module):
    """Alignment Learning Framework proposed for parallel TTS models in:

    https://arxiv.org/abs/2108.10447

    """

    def __init__(self, adim, odim, cache_prior=True):
        """Initialize AlignmentModule.

        Args:
            adim (int): Dimension of attention.
            odim (int): Dimension of feats.
            cache_prior (bool): Whether to cache beta-binomial prior.

        """
        super().__init__()
        self.cache_prior = cache_prior
        self._cache = {}

        # 1D convolution stack applied to the text embedding.
        self.t_conv1 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
        self.t_conv2 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)

        # 1D convolution stack projecting the acoustic features to adim.
        self.f_conv1 = nn.Conv1d(odim, adim, kernel_size=3, padding=1)
        self.f_conv2 = nn.Conv1d(adim, adim, kernel_size=3, padding=1)
        self.f_conv3 = nn.Conv1d(adim, adim, kernel_size=1, padding=0)

    def forward(self, text, feats, text_lengths, feats_lengths, x_masks=None):
        """Calculate log probability of the attention (alignment) matrix.

        Args:
            text (Tensor): Batched text embedding (B, T_text, adim).
            feats (Tensor): Batched acoustic feature (B, T_feats, odim).
            text_lengths (Tensor): Text length tensor (B,).
            feats_lengths (Tensor): Feature length tensor (B,).
            x_masks (Tensor): Mask tensor (B, T_text).

        Returns:
            Tensor: Log probability of attention matrix (B, T_feats, T_text).

        """
        # Encode the text embedding: (B, T_text, adim).
        text = text.transpose(1, 2)
        text = F.relu(self.t_conv1(text))
        text = self.t_conv2(text)
        text = text.transpose(1, 2)

        # Encode the acoustic features into the same space: (B, T_feats, adim).
        feats = feats.transpose(1, 2)
        feats = F.relu(self.f_conv1(feats))
        feats = F.relu(self.f_conv2(feats))
        feats = self.f_conv3(feats)
        feats = feats.transpose(1, 2)

        # Pairwise L2 distance between frames and tokens -> (B, T_feats, T_text).
        dist = feats.unsqueeze(2) - text.unsqueeze(1)
        dist = torch.norm(dist, p=2, dim=3)
        score = -dist

        # Mask out padded text positions before the softmax.
        if x_masks is not None:
            x_masks = x_masks.unsqueeze(-2)
            score = score.masked_fill(x_masks, -np.inf)

        log_p_attn = F.log_softmax(score, dim=-1)

        # Add the static beta-binomial alignment prior.
        bb_prior = self._generate_prior(
            text_lengths,
            feats_lengths,
        ).to(dtype=log_p_attn.dtype, device=log_p_attn.device)
        log_p_attn = log_p_attn + bb_prior

        return log_p_attn
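    # Illustrative usage sketch (assumed tensor shapes and dimensions, not executed here):
    #
    #   alignment = AlignmentModule(adim=256, odim=80)
    #   log_p_attn = alignment(text, feats, text_lengths, feats_lengths)
    #   durations, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
    #
    # where ``text`` is (B, T_text, 256), ``feats`` is (B, T_feats, 80), and the
    # hard durations come from ``viterbi_decode`` defined later in this module.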

    def _generate_prior(self, text_lengths, feats_lengths, w=1) -> torch.Tensor:
        """Generate alignment prior formulated as beta-binomial distribution.

        Args:
            text_lengths (Tensor): Batch of the lengths of each input (B,).
            feats_lengths (Tensor): Batch of the lengths of each target (B,).
            w (float): Scaling factor; lower values give a wider prior.

        Returns:
            Tensor: Batched 2d static prior matrix (B, T_feats, T_text).

        """
        B = len(text_lengths)
        T_text = text_lengths.max()
        T_feats = feats_lengths.max()

        bb_prior = torch.full((B, T_feats, T_text), fill_value=-np.inf)
        for bidx in range(B):
            T = feats_lengths[bidx].item()
            N = text_lengths[bidx].item()

            key = str(T) + "," + str(N)
            if self.cache_prior and key in self._cache:
                prob = self._cache[key]
            else:
                alpha = w * np.arange(1, T + 1, dtype=float)  # (T,)
                beta = w * np.array([T - t + 1 for t in alpha])  # (T,)
                k = np.arange(N)
                batched_k = k[..., None]  # (N, 1)
                prob = betabinom.logpmf(batched_k, N, alpha, beta)  # (N, T)

            # Store the prior so identical (T, N) pairs are reused across batches.
            if self.cache_prior and key not in self._cache:
                self._cache[key] = prob

            prob = torch.from_numpy(prob).transpose(0, 1)  # (N, T) -> (T, N)
            bb_prior[bidx, :T, :N] = prob

        return bb_prior
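    # Rough illustration of the prior (values assumed, not computed here): for a
    # single utterance with T feature frames and N tokens, row t of the (T, N)
    # prior peaks around token index t * N / T, so it softly encourages a
    # near-diagonal, monotonic alignment before any training signal is available.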


@jit(nopython=True)
def _monotonic_alignment_search(log_p_attn):
    """Monotonic alignment search (hard alignment via dynamic programming).

    Args:
        log_p_attn (ndarray): Log probability of attention matrix (T_feats, T_text).

    Returns:
        ndarray: Aligned token index for each feature frame (T_feats,).

    """
    T_mel = log_p_attn.shape[0]
    T_inp = log_p_attn.shape[1]
    Q = np.full((T_inp, T_mel), fill_value=-np.inf)

    log_prob = log_p_attn.transpose(1, 0)  # -> (T_text, T_feats)
    # 1. Forward pass: Q[i, j] is the best cumulative log probability of any
    #    monotonic path that assigns frame j to token i.
    for j in range(T_mel):
        Q[0, j] = log_prob[0, : j + 1].sum()

    for j in range(1, T_mel):
        for i in range(1, min(j + 1, T_inp)):
            Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j]

    # 2. Backtracking: the last frame aligns to the last token; moving backwards,
    #    each frame either stays on the same token or steps down by one.
    A = np.full((T_mel,), fill_value=T_inp - 1)
    for j in range(T_mel - 2, -1, -1):
        # 'i' in {A[j + 1] - 1, A[j + 1]}
        i_a = A[j + 1] - 1
        i_b = A[j + 1]
        if i_b == 0:
            argmax_i = 0
        elif Q[i_a, j] >= Q[i_b, j]:
            argmax_i = i_a
        else:
            argmax_i = i_b
        A[j] = argmax_i
    return A
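# Tiny worked example (assumed numbers, for intuition only): with T_feats=4 frames
# and T_text=2 tokens, a returned path A = [0, 0, 1, 1] means the first two frames
# belong to token 0 and the last two to token 1; np.bincount(A) then yields the
# durations [2, 2] used by viterbi_decode() below.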


def viterbi_decode(log_p_attn, text_lengths, feats_lengths):
    """Extract duration from an attention probability matrix.

    Args:
        log_p_attn (Tensor): Batched log probability of attention
            matrix (B, T_feats, T_text).
        text_lengths (Tensor): Text length tensor (B,).
        feats_lengths (Tensor): Feature length tensor (B,).

    Returns:
        Tensor: Batched token duration extracted from `log_p_attn` (B, T_text).
        Tensor: Binarization loss tensor ().

    """
    B = log_p_attn.size(0)
    T_text = log_p_attn.size(2)
    device = log_p_attn.device

    bin_loss = 0
    ds = torch.zeros((B, T_text), device=device)
    for b in range(B):
        cur_log_p_attn = log_p_attn[b, : feats_lengths[b], : text_lengths[b]]
        viterbi = _monotonic_alignment_search(cur_log_p_attn.detach().cpu().numpy())
        _ds = np.bincount(viterbi)
        ds[b, : len(_ds)] = torch.from_numpy(_ds).to(device)

        # Binarization loss: negative log-likelihood of the soft attention
        # evaluated along the hard (Viterbi) path.
        t_idx = torch.arange(feats_lengths[b])
        bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean()
    bin_loss = bin_loss / B
    return ds, bin_loss
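# Note: for each utterance b the extracted durations satisfy
# ds[b].sum() == feats_lengths[b], so they can be used directly as a
# duration-predictor target without further normalization.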


@jit(nopython=True)
def _average_by_duration(ds, xs, text_lengths, feats_lengths):
    """Numba-accelerated kernel behind ``average_by_duration``.

    Args:
        ds (ndarray): Batched token duration (B, T_text).
        xs (ndarray): Batched frame-level feature sequences (B, T_feats).
        text_lengths (ndarray): Text length array (B,).
        feats_lengths (ndarray): Feature length array (B,).

    Returns:
        ndarray: Batched token-averaged features (B, T_text).

    """
    B = ds.shape[0]
    xs_avg = np.zeros_like(ds)
    ds = ds.astype(np.int32)
    for b in range(B):
        t_text = text_lengths[b]
        t_feats = feats_lengths[b]
        d = ds[b, :t_text]
        d_cumsum = d.cumsum()
        d_cumsum = [0] + list(d_cumsum)
        x = xs[b, :t_feats]
        # Average the frames assigned to each token; zero-duration tokens get 0.
        for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])):
            if len(x[start:end]) != 0:
                xs_avg[b, n] = x[start:end].mean()
            else:
                xs_avg[b, n] = 0
    return xs_avg


def average_by_duration(ds, xs, text_lengths, feats_lengths):
    """Average frame-level features into token-level according to durations.

    Args:
        ds (Tensor): Batched token duration (B, T_text).
        xs (Tensor): Batched feature sequences to be averaged (B, T_feats).
        text_lengths (Tensor): Text length tensor (B,).
        feats_lengths (Tensor): Feature length tensor (B,).

    Returns:
        Tensor: Batched feature averaged according to the token duration (B, T_text).

    """
    device = ds.device
    args = [ds, xs, text_lengths, feats_lengths]
    args = [arg.detach().cpu().numpy() for arg in args]
    xs_avg = _average_by_duration(*args)
    xs_avg = torch.from_numpy(xs_avg).to(device)
    return xs_avg
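# Illustrative use (assumed context): frame-level pitch or energy contours can be
# reduced to one value per token with the durations from viterbi_decode(), e.g.
#
#   p_token = average_by_duration(ds, pitch, text_lengths, feats_lengths)
#
# where ``pitch`` is an assumed (B, T_feats) tensor.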


def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
        maxlen (int, optional): Length of the returned mask.
            If set, `xs` must be None and `maxlen` must be >= max(lengths).

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch < 1.2
            dtype=torch.bool in PyTorch >= 1.2

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0],
                 [0, 0, 0, 0]],
                [[0, 0, 0, 1],
                 [0, 0, 0, 1]],
                [[0, 0, 1, 1],
                 [0, 0, 1, 1]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_pad_mask(lengths, xs)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_pad_mask(lengths, xs, 1)
        tensor([[[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]],
                [[0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
        >>> make_pad_mask(lengths, xs, 2)
        tensor([[[0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1],
                 [0, 0, 0, 0, 0, 1]],
                [[0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1],
                 [0, 0, 0, 1, 1, 1]],
                [[0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1],
                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # Build an index that broadcasts the (B, maxlen) mask over the other dims of xs.
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask


def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.

    Returns:
        ByteTensor: Mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch < 1.2
            dtype=torch.bool in PyTorch >= 1.2

    Examples:
        With only lengths.

        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]

        With the reference tensor.

        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1],
                 [1, 1, 1, 1]],
                [[1, 1, 1, 0],
                 [1, 1, 1, 0]],
                [[1, 1, 0, 0],
                 [1, 1, 0, 0]]], dtype=torch.uint8)
        >>> xs = torch.zeros((3, 2, 6))
        >>> make_non_pad_mask(lengths, xs)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

        With the reference tensor and dimension indicator.

        >>> xs = torch.zeros((3, 6, 6))
        >>> make_non_pad_mask(lengths, xs, 1)
        tensor([[[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]],
                [[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1, 1],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
        >>> make_non_pad_mask(lengths, xs, 2)
        tensor([[[1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0],
                 [1, 1, 1, 1, 1, 0]],
                [[1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 0, 0, 0]],
                [[1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0],
                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)

    """
    return ~make_pad_mask(lengths, xs, length_dim)


def get_random_segments(
    x: torch.Tensor,
    x_lengths: torch.Tensor,
    segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get random segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        x_lengths (Tensor): Length tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
        Tensor: Start index tensor (B,).

    """
    b, c, t = x.size()
    # Draw a random start index per utterance within its valid (unpadded) range.
    max_start_idx = x_lengths - segment_size
    start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
        dtype=torch.long,
    )
    segments = get_segments(x, start_idxs, segment_size)
    return segments, start_idxs
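# Note (assumed usage, not dictated by this module): random segments like these are
# typically used during adversarial training to crop short windows, e.g.
#
#   segments, start_idxs = get_random_segments(z, feats_lengths, segment_size=32)
#
# where ``z`` and ``segment_size=32`` are placeholder values for illustration.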


def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Get segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Start index tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).

    """
    b, c, t = x.size()
    segments = x.new_zeros(b, c, segment_size)
    # Slice a fixed-size window from each batch element at its start index.
    for i, start_idx in enumerate(start_idxs):
        segments[i] = x[i, :, start_idx : start_idx + segment_size]
    return segments
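# Minimal illustration (assumed shapes): for x of shape (2, C, T),
# get_segments(x, start_idxs=torch.tensor([0, 5]), segment_size=4) returns
# x[0, :, 0:4] for the first item and x[1, :, 5:9] for the second.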