| from typing import Sequence |
|
|
| import math |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from typeguard import check_argument_types |
|
|
|
|
class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck with a straight-through gradient estimator.

    Every vector along the last dimension of the input is replaced by its
    nearest codebook entry (squared Euclidean distance).

    Reference:
        [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py

    Args:
        num_embeddings (int): Number of codebook entries (K).
        hidden_dim (int): Dimensionality of each codebook entry (D).
        beta (float, optional): Weight of the commitment loss term.
    """

    def __init__(self,
                 num_embeddings: int,
                 hidden_dim: int,
                 beta: float = 0.25):
        super().__init__()
        self.K = num_embeddings
        self.D = hidden_dim
        # Honor the caller's commitment weight (it used to be hard-coded to
        # 0.05, which silently ignored this argument).
        self.beta = beta

        self.embedding = nn.Embedding(self.K, self.D)
        # Codebook init: normal(mean=0.8, std=0.1).
        self.embedding.weight.data.normal_(0.8, 0.1)

    def forward(
        self, latents: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor, nn.Embedding, torch.Tensor]":
        """Quantize ``latents`` against the codebook.

        Args:
            latents (Tensor): Input vectors of shape (..., D), e.g. (B, Tmax, D).

        Returns:
            Tensor: Quantized latents, same shape as ``latents``; gradients pass
                straight through to ``latents``.
            Tensor: VQ loss = beta * commitment loss + embedding loss.
            Tensor: Codebook indices of shape ``latents.shape[:-1]``.
            Embedding: The codebook module.
            Tensor: Codebook-usage perplexity of this batch.

        Raises:
            ValueError: If the last dimension of ``latents`` is not D.
        """
        if latents.size(-1) != self.D:
            raise ValueError(
                f"expected last dimension {self.D}, got {latents.size(-1)}"
            )
        latents_shape = latents.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        flat_latents = latents.reshape(-1, self.D)  # (N, D)
        codebook = self.embedding.weight  # (K, D)

        # ||x||^2 + ||e||^2 - 2 x.e  -> (N, K)
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook ** 2, dim=1) - \
            2 * torch.matmul(flat_latents, codebook.t())

        encoding_inds = torch.argmin(dist, dim=1)  # (N,)
        output_inds = encoding_inds.view(latents_shape[:-1])

        encoding_one_hot = torch.zeros(
            encoding_inds.size(0), self.K,
            device=latents.device, dtype=codebook.dtype,
        )
        encoding_one_hot.scatter_(1, encoding_inds.unsqueeze(1), 1)

        # One-hot matmul picks the codebook rows and keeps them differentiable
        # w.r.t. the codebook.
        quantized_latents = torch.matmul(encoding_one_hot, codebook)
        quantized_latents = quantized_latents.view(latents_shape)

        # Commitment loss pulls encoder outputs toward the (frozen) codes;
        # embedding loss pulls the codes toward the (frozen) encoder outputs.
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())
        vq_loss = commitment_loss * self.beta + embedding_loss

        # Straight-through estimator: forward uses the codes, backward is identity.
        quantized_latents = latents + (quantized_latents - latents).detach()

        avg_probs = torch.mean(encoding_one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return quantized_latents, vq_loss, output_inds, self.embedding, perplexity
|
|
|
|
class ProsodyEncoder(nn.Module):
    """VQ-VAE prosody encoder module.

    A convolutional reference encoder turns the target spectrogram into frame
    hiddens. A GRU summarizes them into a global embedding, and a fine-grained
    encoder averages them per phoneme before vector quantization. An optional
    autoregressive prior learns to predict the codebook indices from the
    phoneme embeddings.

    Args:
        odim (int): Number of input channels (mel spectrogram channels).
        adim (int, optional): Size of the returned prosody embeddings; must
            equal the last dimension of ``hs``. This value is not that
            important: it does not change the capacity of the information
            bottleneck.
        num_embeddings (int, optional): Codebook size. The higher this value,
            the higher the capacity of the information bottleneck.
        hidden_dim (int, optional): Dimension of each codebook vector.
        beta (float, optional): Commitment loss weight of the vector quantizer.
        ref_enc_conv_layers (int, optional):
            The number of conv layers in the reference encoder.
        ref_enc_conv_chans_list (Sequence[int], optional):
            List of the number of channels of conv layers in the reference encoder.
        ref_enc_conv_kernel_size (int, optional):
            Kernel size of conv layers in the reference encoder (must be odd).
        ref_enc_conv_stride (int, optional):
            Stride size of conv layers in the reference encoder.
        global_enc_gru_layers (int, optional):
            The number of GRU layers in the global encoder.
        global_enc_gru_units (int, optional):
            The number of GRU units in the global encoder.
        global_emb_integration_type (str, optional): How to integrate the
            global embedding: "add" or "concat".
    """

    def __init__(
        self,
        odim: int,
        adim: int = 64,
        num_embeddings: int = 10,
        hidden_dim: int = 3,
        beta: float = 0.25,
        ref_enc_conv_layers: int = 2,
        ref_enc_conv_chans_list: Sequence[int] = (32, 32),
        ref_enc_conv_kernel_size: int = 3,
        ref_enc_conv_stride: int = 1,
        global_enc_gru_layers: int = 1,
        global_enc_gru_units: int = 32,
        global_emb_integration_type: str = "add",
    ) -> None:
        assert check_argument_types()
        super().__init__()

        self.global_emb_integration_type = global_emb_integration_type

        # "Same" padding for odd kernel sizes.
        padding = (ref_enc_conv_kernel_size - 1) // 2

        self.ref_encoder = RefEncoder(
            ref_enc_conv_layers=ref_enc_conv_layers,
            ref_enc_conv_chans_list=ref_enc_conv_chans_list,
            ref_enc_conv_kernel_size=ref_enc_conv_kernel_size,
            ref_enc_conv_stride=ref_enc_conv_stride,
            ref_enc_conv_padding=padding,
        )

        # Feature size of the reference encoder output: the frequency axis
        # shrinks per the conv output-size formula, then is flattened together
        # with the channels of the last conv layer.
        ref_enc_output_units = odim
        for _ in range(ref_enc_conv_layers):
            ref_enc_output_units = (
                ref_enc_output_units - ref_enc_conv_kernel_size + 2 * padding
            ) // ref_enc_conv_stride + 1
        ref_enc_output_units *= ref_enc_conv_chans_list[-1]

        # The FG encoder sees frame hiddens concatenated with the global embedding.
        self.fg_encoder = FGEncoder(
            ref_enc_output_units + global_enc_gru_units,
            hidden_dim=hidden_dim,
        )

        self.global_encoder = GlobalEncoder(
            ref_enc_output_units,
            global_enc_gru_layers=global_enc_gru_layers,
            global_enc_gru_units=global_enc_gru_units,
        )

        if self.global_emb_integration_type == "add":
            self.global_projection = nn.Linear(global_enc_gru_units, adim)
        else:
            self.global_projection = nn.Linear(
                adim + global_enc_gru_units, adim
            )

        self.ar_prior = ARPrior(
            adim,
            num_embeddings=num_embeddings,
            hidden_dim=hidden_dim,
        )

        self.vq_layer = VectorQuantizer(num_embeddings, hidden_dim, beta)

        # Lifts quantized latents (hidden_dim) to the phoneme-embedding size.
        self.qfg_projection = nn.Linear(hidden_dim, adim)

    def forward(
        self,
        ys: torch.Tensor,
        ds: torch.Tensor,
        hs: torch.Tensor,
        global_embs: torch.Tensor = None,
        train_ar_prior: bool = False,
        ar_prior_inference: bool = False,
        fg_inds: torch.Tensor = None,
    ) -> Sequence[torch.Tensor]:
        """Calculate forward propagation.

        Args:
            ys (Tensor): Batch of padded target features (B, Lmax, odim), or
                None when sampling from the AR prior.
            ds (LongTensor): Batch of padded durations (B, Tmax).
            hs (Tensor): Batch of phoneme embeddings (B, Tmax, adim).
            global_embs (Tensor, optional): Global embeddings
                (B, global_enc_gru_units); used only when ``ys`` is None.
            train_ar_prior (bool): Also train the AR prior and return its
                predictions as the prosody embeddings.
            ar_prior_inference (bool): Sample prosody from the AR prior.
            fg_inds (Tensor, optional): Known leading codebook indices passed
                to the AR prior during inference.

        Returns:
            When ``ar_prior_inference`` is True:
                (prosody embeddings (B, Tmax, adim), 0, 0, 0, AR prior top indices).
            Otherwise:
                Tensor: Prosody embeddings (B, Tmax, adim).
                Tensor: VQ loss.
                Tensor or int: AR prior loss (0 unless ``train_ar_prior``).
                Tensor: Codebook perplexity.
                Tensor: Global prosody embeddings (B, global_enc_gru_units).

        Raises:
            ValueError: If ``ys`` is None outside AR prior inference, or no
                global embedding is available.
        """
        if ys is not None:
            ref_embs = self.ref_encoder(ys)
            global_embs = self.global_encoder(ref_embs)
        elif not ar_prior_inference:
            raise ValueError("ys is required unless ar_prior_inference is True.")
        if global_embs is None:
            raise ValueError("global_embs must be given when ys is None.")

        if ar_prior_inference:
            hs_integrated = self._integrate_with_global_embs(hs, global_embs)
            qs, top_inds = self.ar_prior.inference(
                hs_integrated, fg_inds, self.vq_layer.embedding
            )
            p_embs = self._project_and_integrate(qs, hs, global_embs)
            return p_embs, 0, 0, 0, top_inds

        # Append the global embedding to every frame hidden.
        global_embs_expanded = global_embs.unsqueeze(1).expand(-1, ref_embs.size(1), -1)
        ref_embs_integrated = torch.cat([ref_embs, global_embs_expanded], dim=-1)

        # Average frames per phoneme, then quantize: (B, Tmax, hidden_dim).
        fg_embs = self.fg_encoder(ref_embs_integrated, ds, ys.size(1))
        qs, vq_loss, inds, codebook, perplexity = self.vq_layer(fg_embs)
        assert hs.size(1) == qs.size(1)

        p_embs = self._project_and_integrate(qs, hs, global_embs)

        ar_prior_loss = 0
        if train_ar_prior:
            hs_integrated = self._integrate_with_global_embs(hs, global_embs)
            qs, ar_prior_loss = self.ar_prior(hs_integrated, inds, codebook)
            p_embs = self._project_and_integrate(qs, hs, global_embs)

        return p_embs, vq_loss, ar_prior_loss, perplexity, global_embs

    def _project_and_integrate(
        self,
        qs: torch.Tensor,
        hs: torch.Tensor,
        global_embs: torch.Tensor,
    ) -> torch.Tensor:
        """Project latents to adim, integrate the global embedding, check shapes against hs."""
        qs = self.qfg_projection(qs)
        assert hs.size(2) == qs.size(2)
        p_embs = self._integrate_with_global_embs(qs, global_embs)
        assert hs.shape == p_embs.shape
        return p_embs

    def _integrate_with_global_embs(
        self,
        qs: torch.Tensor,
        global_embs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate global embeddings into a hidden-state sequence.

        Args:
            qs (Tensor): Batch of hidden states (B, Tmax, adim).
            global_embs (Tensor): Batch of global embeddings (B, global_enc_gru_units).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).

        Raises:
            NotImplementedError: For an unknown integration type.
        """
        if self.global_emb_integration_type == "add":
            # Project to adim and broadcast-add over time.
            global_embs = self.global_projection(global_embs)
            res = qs + global_embs.unsqueeze(1)
        elif self.global_emb_integration_type == "concat":
            # Concatenate per time step, then project back to adim. The layer is
            # `global_projection` (the old `prosody_projection` never existed).
            global_embs = global_embs.unsqueeze(1).expand(-1, qs.size(1), -1)
            res = self.global_projection(torch.cat([qs, global_embs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return res
|
|
|
|
class RefEncoder(nn.Module):
    """Convolutional reference encoder over spectrogram frames.

    Args:
        ref_enc_conv_layers (int, optional): Number of Conv2d + ReLU stages.
        ref_enc_conv_chans_list (Sequence[int], optional): Output channels of
            each stage; its length must equal ``ref_enc_conv_layers``.
        ref_enc_conv_kernel_size (int, optional): Conv kernel size (must be odd).
        ref_enc_conv_stride (int, optional): Stride of every conv.
        ref_enc_conv_padding (int, optional): Padding of every conv.
    """

    def __init__(
        self,
        ref_enc_conv_layers: int = 2,
        ref_enc_conv_chans_list: Sequence[int] = (32, 32),
        ref_enc_conv_kernel_size: int = 3,
        ref_enc_conv_stride: int = 1,
        ref_enc_conv_padding: int = 1,
    ):
        """Initialize reference encoder module."""
        assert check_argument_types()
        super().__init__()

        assert ref_enc_conv_kernel_size % 2 == 1, "kernel size must be odd."
        assert (
            len(ref_enc_conv_chans_list) == ref_enc_conv_layers
        ), "the number of conv layers and length of channels list must be the same."

        # The first stage reads a single "image" channel; each later stage reads
        # the previous stage's output channels.
        input_channels = [1] + list(ref_enc_conv_chans_list[:-1])
        stages = []
        for c_in, c_out in zip(input_channels, ref_enc_conv_chans_list):
            stages.append(
                nn.Conv2d(
                    c_in,
                    c_out,
                    kernel_size=ref_enc_conv_kernel_size,
                    stride=ref_enc_conv_stride,
                    padding=ref_enc_conv_padding,
                )
            )
            stages.append(nn.ReLU(inplace=True))
        self.convs = nn.Sequential(*stages)

    def forward(self, ys: torch.Tensor) -> torch.Tensor:
        """Encode a spectrogram batch into per-frame hiddens.

        Args:
            ys (Tensor): Batch of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of spectrogram hiddens (B, L', ref_enc_output_units),
                where the last axis flattens (channels, reduced odim).
        """
        batch_size = ys.size(0)
        # Add a channel axis, convolve -> (B, C, L', odim'), then move time forward.
        feature_maps = self.convs(ys.unsqueeze(1)).transpose(1, 2).contiguous()
        return feature_maps.view(batch_size, feature_maps.size(1), -1)
|
|
|
|
class GlobalEncoder(nn.Module):
    """Summarize a spectrogram hidden sequence into one global embedding.

    Args:
        ref_enc_output_units (int): Feature size of each input frame hidden.
        global_enc_gru_layers (int, optional): Number of stacked GRU layers.
        global_enc_gru_units (int, optional): GRU hidden size (embedding size).
    """

    def __init__(
        self,
        ref_enc_output_units: int,
        global_enc_gru_layers: int = 1,
        global_enc_gru_units: int = 32,
    ):
        super().__init__()
        self.gru = torch.nn.GRU(
            input_size=ref_enc_output_units,
            hidden_size=global_enc_gru_units,
            num_layers=global_enc_gru_layers,
            batch_first=True,
        )

    def forward(
        self,
        hs: torch.Tensor,
    ):
        """Return the top GRU layer's final hidden state.

        Args:
            hs (Tensor): Batch of spectrogram hiddens (B, L', ref_enc_output_units).

        Returns:
            Tensor: Global embedding (B, global_enc_gru_units).
        """
        self.gru.flatten_parameters()
        final_hidden = self.gru(hs)[1]  # (num_layers, B, units)
        return final_hidden[-1]
|
|
|
|
class FGEncoder(nn.Module):
    """Fine-grained encoder: spectrogram-to-phoneme alignment plus projection.

    Averages the spectrogram hiddens over each phoneme's duration, then
    projects each phoneme vector down to ``hidden_dim``.

    Args:
        input_units (int): Feature size of each spectrogram hidden.
        hidden_dim (int, optional): Output feature size per phoneme.
    """

    def __init__(
        self,
        input_units: int,
        hidden_dim: int = 3,
    ):
        assert check_argument_types()
        super().__init__()

        self.projection = nn.Sequential(
            nn.Sequential(
                nn.Linear(input_units, input_units // 2),
                nn.ReLU(),
                nn.Dropout(p=0.2),
            ),
            nn.Sequential(
                nn.Linear(input_units // 2, hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=0.2),
            )
        )

    def forward(
        self,
        hs: torch.Tensor,
        ds: torch.Tensor,
        Lmax: int
    ):
        """Calculate forward propagation.

        Args:
            hs (Tensor): Batch of spectrogram hiddens
                (B, L', ref_enc_output_units + global_enc_gru_units).
            ds (LongTensor): Batch of padded durations (B, Tmax), in frames of
                the original Lmax-long input.
            Lmax (int): Length of the input the durations refer to.

        Returns:
            Tensor: aligned spectrogram hiddens (B, Tmax, hidden_dim).
        """
        hs = self._align_durations(hs, ds, Lmax)
        return self.projection(hs)

    def _align_durations(self, hs, ds, Lmax):
        """Average the spectrogram hiddens over each phoneme's duration,
        producing one hidden per phoneme.

        Gradients flow back into ``hs`` (the old implementation computed this
        under ``torch.no_grad()``, cutting the upstream encoders off).

        Args:
            hs (Tensor): Batch of spectrogram hidden state sequences (B, L', D).
            ds (LongTensor): Batch of padded durations (B, Tmax).
            Lmax (int): Length of the input the durations refer to.

        Returns:
            Tensor: Batch of averaged hidden state sequences (B, Tmax, D).
                Phonemes with non-positive duration, or whose window falls past
                the end of ``hs``, get a zero vector (instead of NaN).
        """
        batch_size, num_frames, feat_dim = hs.shape
        max_phonemes = ds.size(1)
        if batch_size == 0 or max_phonemes == 0:
            return hs.new_zeros(batch_size, max_phonemes, feat_dim)

        # Durations count Lmax-resolution frames; rescale to the hiddens' length.
        multiplier = num_frames / Lmax
        empty = hs.new_zeros(feat_dim)
        # One host transfer instead of a .item() sync per phoneme.
        all_durations = ds.tolist()

        aligned = []
        for b_i, durations in enumerate(all_durations):
            rows = []
            start = 0
            for dur in durations:
                if dur <= 0:
                    # Zero-length (or padding) phoneme: consumes no frames.
                    rows.append(empty)
                    continue
                width = max(math.floor(dur * multiplier), 1)
                segment = hs[b_i, start:start + width]
                rows.append(segment.mean(0) if segment.size(0) > 0 else empty)
                start += width
            aligned.append(torch.stack(rows))
        return torch.stack(aligned)
|
|
|
|
class ARPrior(nn.Module):
    """Autoregressive prior.

    This module is inspired by the AR prior described in `Generating diverse and
    natural text-to-speech samples using a quantized fine-grained VAE and
    auto-regressive prosody prior`. At each phoneme step an LSTM cell reads the
    phoneme embedding plus a projection of the previous step's quantized
    latent, and its hidden state is used directly as logits over the codebook.

    Args:
        adim (int): Size of the input phoneme embeddings.
        num_embeddings (int, optional): Codebook size (also the LSTM hidden size).
        hidden_dim (int, optional): Dimension of each codebook vector.
    """

    def __init__(
        self,
        adim: int,
        num_embeddings: int = 10,
        hidden_dim: int = 3,
    ):
        assert check_argument_types()
        super().__init__()

        self.adim = adim
        self.hidden_dim = hidden_dim
        self.num_embeddings = num_embeddings

        # Maps the previous quantized latent (hidden_dim) into the input space (adim).
        self.qs_projection = nn.Linear(hidden_dim, adim)

        # Hidden size == num_embeddings: the hidden state doubles as codebook logits.
        self.lstm = nn.LSTMCell(
            self.adim,
            self.num_embeddings,
        )

        self.criterion = nn.NLLLoss()

    def inds_to_embs(self, inds, codebook, device):
        """Returns the quantized embeddings from the codebook,
        corresponding to the indices.

        Args:
            inds (Tensor): Batch of indices (B, Tmax) or (B, Tmax, 1).
            codebook (Embedding): (num_embeddings, D).
            device: Device on which to build the one-hot matrix.

        Returns:
            Tensor: Quantized embeddings (B, Tmax, D).
        """
        flat_inds = torch.flatten(inds).unsqueeze(1)

        encoding_one_hot = torch.zeros(
            flat_inds.size(0),
            self.num_embeddings,
            device=device,
            dtype=codebook.weight.dtype,
        )
        encoding_one_hot.scatter_(1, flat_inds, 1)

        # One-hot matmul selects codebook rows while keeping them differentiable.
        quantized_embs = torch.matmul(encoding_one_hot, codebook.weight)
        return quantized_embs.view(
            inds.size(0), inds.size(1), self.hidden_dim
        )

    def top_embeddings(self, emb_scores: torch.Tensor, codebook):
        """Returns the top quantized embeddings from the codebook using the scores.

        Args:
            emb_scores (Tensor): Batch of embedding scores (B, Tmax, num_embeddings).
            codebook (Embedding): (num_embeddings, D).

        Returns:
            Tensor: Top-1 quantized embeddings (B, Tmax, D).
            Tensor: Top-k indices (B, Tmax, k) with k = min(3, num_embeddings).
        """
        _, top_inds = emb_scores.topk(1, dim=-1)
        quantized_embs = self.inds_to_embs(
            top_inds,
            codebook,
            emb_scores.device,
        )
        # Clamp k so small codebooks (< 3 entries) do not make topk fail.
        k = min(3, emb_scores.size(-1))
        _, topk_inds = emb_scores.topk(k, dim=-1)
        return quantized_embs, topk_inds

    def _forward(self, hs_ref_embs, codebook, fg_inds=None):
        """Run the prior step by step, feeding back its own top-1 predictions.

        Args:
            hs_ref_embs (Tensor): Phoneme embeddings with integrated global
                embeddings (B, Tmax, adim).
            codebook (Embedding): (num_embeddings, hidden_dim).
            fg_inds (Tensor, optional): Known leading indices (B, T0); their
                embeddings are used as-is and generation starts at step T0.

        Returns:
            Tensor: Quantized embeddings for all steps (B, Tmax, hidden_dim).
            Tensor or list: Log-probabilities of generated steps
                (B, Tmax - T0, num_embeddings), or ``[]`` if none were generated.
            Tensor: Top-k indices of generated steps, or ``fg_inds`` if none
                were generated.
        """
        inds = []
        scores = []
        embs = []

        if fg_inds is not None:
            init_embs = self.inds_to_embs(fg_inds, codebook, hs_ref_embs.device)
            embs = [init_emb.unsqueeze(1) for init_emb in init_embs.transpose(1, 0)]

        start = fg_inds.size(1) if fg_inds is not None else 0
        # NOTE(review): the LSTM state starts from zeros even when a prefix is
        # given; the prefix only influences step `start` via embs[-1]. Confirm
        # this is intended rather than warming the state up over the prefix.
        hidden = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size)
        cell = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size)

        for i in range(start, hs_ref_embs.size(1)):
            step_input = hs_ref_embs[:, i]
            if i != 0:
                # (B, 1, hidden_dim) -> (B, adim); squeeze only the time axis so
                # batch size 1 keeps its batch dimension.
                qs = self.qs_projection(embs[-1])
                step_input = hs_ref_embs[:, i] + qs.squeeze(1)
            hidden, cell = self.lstm(step_input, (hidden, cell))
            out = hidden.unsqueeze(1)  # (B, 1, num_embeddings)

            emb_scores = F.log_softmax(out, dim=2)
            quantized_embs, top_inds = self.top_embeddings(emb_scores, codebook)

            embs.append(quantized_embs)
            scores.append(emb_scores)
            inds.append(top_inds)

        out_embs = torch.cat(embs, dim=1)
        assert out_embs.size(0) == hs_ref_embs.size(0)
        assert out_embs.size(1) == hs_ref_embs.size(1)
        generated = start < hs_ref_embs.size(1)
        out_emb_scores = torch.cat(scores, dim=1) if generated else scores
        out_inds = torch.cat(inds, dim=1) if generated else fg_inds

        return out_embs, out_emb_scores, out_inds

    def forward(self, hs_ref_embs, inds, codebook):
        """Calculate forward propagation.

        Args:
            hs_ref_embs (Tensor): Batch of phoneme embeddings
                with integrated global prosody embeddings (B, Tmax, adim).
            inds (Tensor): Batch of ground-truth codebook indices (B, Tmax).
            codebook (Embedding): (num_embeddings, hidden_dim).

        Returns:
            Tensor: Batch of predicted quantized latents (B, Tmax, hidden_dim).
            Tensor: NLL loss of the ground-truth indices.
        """
        quantized_embs, emb_scores, _ = self._forward(hs_ref_embs, codebook)
        # NLLLoss expects the class axis second: (B, num_embeddings, Tmax).
        emb_scores = emb_scores.permute(0, 2, 1).contiguous()
        loss = self.criterion(emb_scores, inds)
        return quantized_embs, loss

    def inference(self, hs_ref_embs, fg_inds, codebook):
        """Sample quantized latents from the prior.

        Args:
            hs_ref_embs (Tensor): Batch of phoneme embeddings
                with integrated global prosody embeddings (B, Tmax, adim).
            fg_inds (Tensor or None): Known leading codebook indices (B, T0).
            codebook (Embedding): (num_embeddings, hidden_dim).

        Returns:
            Tensor: Batch of predicted quantized latents (B, Tmax, hidden_dim).
            Tensor: Top-k indices of the generated steps.
        """
        quantized_embs, _, top_inds = self._forward(hs_ref_embs, codebook, fg_inds)
        return quantized_embs, top_inds
|
|