| |
| |
|
|
| """TTS-Transformer related modules.""" |
|
|
| from typing import Dict |
| from typing import Sequence |
| from typing import Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from typeguard import check_argument_types |
|
|
| from espnet.nets.pytorch_backend.e2e_tts_transformer import GuidedMultiHeadAttentionLoss |
| from espnet.nets.pytorch_backend.e2e_tts_transformer import TransformerLoss |
| from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask |
| from espnet.nets.pytorch_backend.nets_utils import make_pad_mask |
| from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet |
| from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet |
| from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet |
| from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention |
| from espnet.nets.pytorch_backend.transformer.decoder import Decoder |
| from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding |
| from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding |
| from espnet.nets.pytorch_backend.transformer.encoder import Encoder |
| from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask |
| from espnet2.torch_utils.device_funcs import force_gatherable |
| from espnet2.torch_utils.initialize import initialize |
| from espnet2.tts.abs_tts import AbsTTS |
| from espnet2.tts.gst.style_encoder import StyleEncoder |
|
|
|
|
class Transformer(AbsTTS):
    """TTS-Transformer module.

    This is a module of text-to-speech Transformer described in `Neural Speech Synthesis
    with Transformer Network`_, which converts a sequence of tokens into a sequence
    of Mel-filterbanks.

    .. _`Neural Speech Synthesis with Transformer Network`:
        https://arxiv.org/pdf/1809.08895.pdf

    Args:
        idim (int): Dimension of the inputs (vocabulary size; the last id is <eos>).
        odim (int): Dimension of the outputs.
        embed_dim (int, optional): Dimension of character embedding.
        eprenet_conv_layers (int, optional):
            Number of encoder prenet convolution layers.
        eprenet_conv_chans (int, optional):
            Number of encoder prenet convolution channels.
        eprenet_conv_filts (int, optional):
            Filter size of encoder prenet convolution.
        dprenet_layers (int, optional): Number of decoder prenet layers.
        dprenet_units (int, optional): Number of decoder prenet hidden units.
        elayers (int, optional): Number of encoder layers.
        eunits (int, optional): Number of encoder hidden units.
        adim (int, optional): Number of attention transformation dimensions.
        aheads (int, optional): Number of heads for multi head attention.
        dlayers (int, optional): Number of decoder layers.
        dunits (int, optional): Number of decoder hidden units.
        postnet_layers (int, optional): Number of postnet layers (0 disables it).
        postnet_chans (int, optional): Number of postnet channels.
        postnet_filts (int, optional): Filter size of postnet.
        positionwise_layer_type (str, optional):
            Position-wise operation type.
        positionwise_conv_kernel_size (int, optional):
            Kernel size in position wise conv 1d.
        use_scaled_pos_enc (bool, optional):
            Whether to use trainable scaled positional encoding.
        use_batch_norm (bool, optional):
            Whether to use batch normalization in encoder prenet.
        encoder_normalize_before (bool, optional):
            Whether to perform layer normalization before encoder block.
        decoder_normalize_before (bool, optional):
            Whether to perform layer normalization before decoder block.
        encoder_concat_after (bool, optional): Whether to concatenate attention
            layer's input and output in encoder.
        decoder_concat_after (bool, optional): Whether to concatenate attention
            layer's input and output in decoder.
        reduction_factor (int, optional): Reduction factor.
        spk_embed_dim (int, optional): Number of speaker embedding dimensions.
        spk_embed_integration_type (str, optional): How to integrate speaker
            embedding ("add" or "concat").
        use_gst (bool, optional): Whether to use global style token.
        gst_tokens (int, optional): The number of GST embeddings.
        gst_heads (int, optional): The number of heads in GST multihead attention.
        gst_conv_layers (int, optional): The number of conv layers in GST.
        gst_conv_chans_list: (Sequence[int], optional):
            List of the number of channels of conv layers in GST.
        gst_conv_kernel_size (int, optional): Kernel size of conv layers in GST.
        gst_conv_stride (int, optional): Stride size of conv layers in GST.
        gst_gru_layers (int, optional): The number of GRU layers in GST.
        gst_gru_units (int, optional): The number of GRU units in GST.
        transformer_enc_dropout_rate (float, optional):
            Dropout rate in encoder except attention and positional encoding.
        transformer_enc_positional_dropout_rate (float, optional):
            Dropout rate after encoder positional encoding.
        transformer_enc_attn_dropout_rate (float, optional):
            Dropout rate in encoder self-attention module.
        transformer_dec_dropout_rate (float, optional):
            Dropout rate in decoder except attention & positional encoding.
        transformer_dec_positional_dropout_rate (float, optional):
            Dropout rate after decoder positional encoding.
        transformer_dec_attn_dropout_rate (float, optional):
            Dropout rate in decoder self-attention module.
        transformer_enc_dec_attn_dropout_rate (float, optional):
            Dropout rate in encoder-decoder attention module.
        eprenet_dropout_rate (float, optional): Dropout rate in encoder prenet.
        dprenet_dropout_rate (float, optional): Dropout rate in decoder prenet.
        postnet_dropout_rate (float, optional): Dropout rate in postnet.
        init_type (str, optional):
            How to initialize transformer parameters ("pytorch" keeps defaults).
        init_enc_alpha (float, optional):
            Initial value of alpha in scaled pos encoding of the encoder.
        init_dec_alpha (float, optional):
            Initial value of alpha in scaled pos encoding of the decoder.
        use_masking (bool, optional):
            Whether to apply masking for padded part in loss calculation.
        use_weighted_masking (bool, optional):
            Whether to apply weighted masking in loss calculation.
        bce_pos_weight (float, optional): Positive sample weight in bce calculation
            (only for use_masking=true).
        loss_type (str, optional): How to calculate loss ("L1", "L2" or "L1+L2").
        use_guided_attn_loss (bool, optional): Whether to use guided attention loss.
        num_heads_applied_guided_attn (int, optional):
            Number of heads in each layer to apply guided attention loss
            (-1 means all heads).
        num_layers_applied_guided_attn (int, optional):
            Number of layers to apply guided attention loss (-1 means ``elayers``).
        modules_applied_guided_attn (Sequence[str], optional):
            List of module names to apply guided attention loss. Supported names
            are "encoder", "decoder" and "encoder-decoder". A single string is
            treated as a one-element list.
        guided_attn_loss_sigma (float, optional): Sigma in guided attention loss.
        guided_attn_loss_lambda (float, optional): Lambda in guided attention loss.

    """

    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        embed_dim: int = 512,
        eprenet_conv_layers: int = 3,
        eprenet_conv_chans: int = 256,
        eprenet_conv_filts: int = 5,
        dprenet_layers: int = 2,
        dprenet_units: int = 256,
        elayers: int = 6,
        eunits: int = 1024,
        adim: int = 512,
        aheads: int = 4,
        dlayers: int = 6,
        dunits: int = 1024,
        postnet_layers: int = 5,
        postnet_chans: int = 256,
        postnet_filts: int = 5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = True,
        decoder_normalize_before: bool = True,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        reduction_factor: int = 1,
        spk_embed_dim: int = None,
        spk_embed_integration_type: str = "add",
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        transformer_enc_dec_attn_dropout_rate: float = 0.1,
        eprenet_dropout_rate: float = 0.5,
        dprenet_dropout_rate: float = 0.5,
        postnet_dropout_rate: float = 0.5,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
        bce_pos_weight: float = 5.0,
        loss_type: str = "L1",
        use_guided_attn_loss: bool = True,
        num_heads_applied_guided_attn: int = 2,
        num_layers_applied_guided_attn: int = 2,
        # NOTE: the trailing comma makes this a tuple. Without it the default is
        # the plain string "encoder-decoder", and the ``in`` checks in ``forward``
        # would also match "encoder" and "decoder" as substrings.
        modules_applied_guided_attn: Sequence[str] = ("encoder-decoder",),
        guided_attn_loss_sigma: float = 0.4,
        guided_attn_loss_lambda: float = 1.0,
    ):
        """Initialize Transformer module."""
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # the last token id is used as <eos> (appended in forward / inference)
        self.eos = idim - 1
        self.spk_embed_dim = spk_embed_dim
        self.reduction_factor = reduction_factor
        self.use_gst = use_gst
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.loss_type = loss_type
        self.use_guided_attn_loss = use_guided_attn_loss
        if self.use_guided_attn_loss:
            # -1 means "all layers" / "all heads"
            if num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = elayers
            else:
                self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
            if num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = aheads
            else:
                self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
            # A bare string would turn the membership tests in ``forward`` into
            # substring tests, so normalize it to a one-element tuple.
            if isinstance(modules_applied_guided_attn, str):
                modules_applied_guided_attn = (modules_applied_guided_attn,)
            self.modules_applied_guided_attn = modules_applied_guided_attn
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (
            ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
        )

        # define transformer encoder
        if eprenet_conv_layers != 0:
            # encoder prenet
            encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(
                    idim=idim,
                    embed_dim=embed_dim,
                    elayers=0,
                    econv_layers=eprenet_conv_layers,
                    econv_chans=eprenet_conv_chans,
                    econv_filts=eprenet_conv_filts,
                    use_batch_norm=use_batch_norm,
                    dropout_rate=eprenet_dropout_rate,
                    padding_idx=self.padding_idx,
                ),
                torch.nn.Linear(eprenet_conv_chans, adim),
            )
        else:
            encoder_input_layer = torch.nn.Embedding(
                num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
            )
        self.encoder = Encoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is the target feature sequence
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # define projection layer for speaker embedding integration
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

        # define transformer decoder
        if dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(
                    idim=odim,
                    n_layers=dprenet_layers,
                    n_units=dprenet_units,
                    dropout_rate=dprenet_dropout_rate,
                ),
                torch.nn.Linear(dprenet_units, adim),
            )
        else:
            decoder_input_layer = "linear"
        self.decoder = Decoder(
            odim=odim,  # odim is needed when no prenet is used
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
        )

        # define final projection: each decoder step emits `reduction_factor` frames
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
        self.prob_out = torch.nn.Linear(adim, reduction_factor)

        # define postnet
        self.postnet = (
            None
            if postnet_layers == 0
            else Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=postnet_dropout_rate,
            )
        )

        # define loss function
        self.criterion = TransformerLoss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

    def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
        """Initialize parameters and the scaled positional encoding alphas.

        Args:
            init_type (str): Initialization method; "pytorch" keeps torch defaults.
            init_enc_alpha (float): Initial alpha of the encoder scaled pos enc.
            init_dec_alpha (float): Initial alpha of the decoder scaled pos enc.

        """
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spembs: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value.

        Raises:
            ValueError: If ``loss_type`` is unknown, or if ``reduction_factor > 1``
                and some target is shorter than the reduction factor.

        """
        # trim padding beyond the longest item in this (possibly split) batch
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]
        batch_size = text.size(0)

        # add eos at the end of each sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, length in enumerate(text_lengths):
            xs[i, length] = self.eos
        ilens = text_lengths + 1

        ys = speech
        olens = speech_lengths

        # make labels for stop prediction: 1 from the last frame onwards
        labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
        labels = F.pad(labels, [0, 1], "constant", 1.0)

        if self.reduction_factor > 1 and bool(olens.lt(self.reduction_factor).any()):
            raise ValueError(
                "Output length must be greater than or equal to reduction factor."
            )

        # calculate transformer outputs
        after_outs, before_outs, logits = self._forward(xs, ilens, ys, olens, spembs)

        # drop the remainder frames that do not fill a whole reduction step
        olens_in = olens
        if self.reduction_factor > 1:
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]
            labels = labels[:, :max_olen]
            # mark the (truncated) last frame of EVERY utterance as the stop frame;
            # setting only labels[:, -1] would fix just the longest utterance
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)

        # calculate loss values
        l1_loss, l2_loss, bce_loss = self.criterion(
            after_outs, before_outs, logits, ys, labels, olens
        )
        if self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = l2_loss + bce_loss
        elif self.loss_type == "L1+L2":
            loss = l1_loss + l2_loss + bce_loss
        else:
            raise ValueError("unknown --loss-type " + self.loss_type)

        stats = dict(
            l1_loss=l1_loss.item(),
            l2_loss=l2_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate guided attention loss
        if self.use_guided_attn_loss:
            # calculate for encoder
            if "encoder" in self.modules_applied_guided_attn:
                att_ws = self._collect_guided_attn_weights(
                    self.encoder.encoders, "self_attn"
                )
                enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
                loss = loss + enc_attn_loss
                stats.update(enc_attn_loss=enc_attn_loss.item())
            # calculate for decoder
            if "decoder" in self.modules_applied_guided_attn:
                att_ws = self._collect_guided_attn_weights(
                    self.decoder.decoders, "self_attn"
                )
                dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
                loss = loss + dec_attn_loss
                stats.update(dec_attn_loss=dec_attn_loss.item())
            # calculate for encoder-decoder
            if "encoder-decoder" in self.modules_applied_guided_attn:
                att_ws = self._collect_guided_attn_weights(
                    self.decoder.decoders, "src_attn"
                )
                enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in)
                loss = loss + enc_dec_attn_loss
                stats.update(enc_dec_attn_loss=enc_dec_attn_loss.item())

        stats.update(loss=loss.item())

        # report extra information
        if self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
            )

        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _collect_guided_attn_weights(self, layers, attn_name: str) -> torch.Tensor:
        """Gather attention weights of the last layers for guided attention loss.

        Walks ``layers`` from the last layer backwards. For each layer it reads the
        ``attn`` attribute of its ``attn_name`` submodule (presumably the weights
        recorded during the most recent forward pass) and keeps the first
        ``num_heads_applied_guided_attn`` heads. It stops after
        ``num_layers_applied_guided_attn`` layers, or uses all layers if that
        count is never reached.

        Args:
            layers: Indexable stack of encoder/decoder layers.
            attn_name (str): Attribute name of the attention module
                ("self_attn" or "src_attn").

        Returns:
            Tensor: Selected attention weights concatenated along dim 1.

        """
        att_ws = []
        for idx, layer_idx in enumerate(reversed(range(len(layers)))):
            attn = getattr(layers[layer_idx], attn_name).attn
            att_ws.append(attn[:, : self.num_heads_applied_guided_attn])
            if idx + 1 == self.num_layers_applied_guided_attn:
                break
        return torch.cat(att_ws, dim=1)

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor,
        olens: torch.Tensor,
        spembs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run encoder, decoder and postnet with teacher forcing.

        Returns:
            Tensor: Outputs after postnet (B, Lmax', odim).
            Tensor: Outputs before postnet (B, Lmax', odim).
            Tensor: Stop-token logits (B, Lmax').

        """
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, h_masks = self.encoder(xs, x_masks)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(ys)
            hs = hs + style_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # thin out frames for reduction factor (keep the last frame of each group)
        if self.reduction_factor > 1:
            ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor]
            olens_in = olens.new([olen // self.reduction_factor for olen in olens])
        else:
            ys_in, olens_in = ys, olens

        # shift decoder inputs right by one frame (zero "go" frame first)
        ys_in = self._add_first_frame_and_remove_last_frame(ys_in)

        # forward decoder
        y_masks = self._target_mask(olens_in)
        zs, _ = self.decoder(ys_in, y_masks, hs, h_masks)
        # each step predicts `reduction_factor` frames; unfold them along time
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
        logits = self.prob_out(zs).view(zs.size(0), -1)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)
            ).transpose(1, 2)

        return after_outs, before_outs, logits

    def inference(
        self,
        text: torch.Tensor,
        speech: torch.Tensor = None,
        spembs: torch.Tensor = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T,).
            speech (Tensor, optional): Feature sequence to extract style (N, odim).
                Required when ``use_gst`` or ``use_teacher_forcing`` is enabled.
            spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,).
            threshold (float, optional): Threshold in inference.
            minlenratio (float, optional): Minimum length ratio in inference.
            maxlenratio (float, optional): Maximum length ratio in inference.
            use_teacher_forcing (bool, optional): Whether to use teacher forcing.

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,), or None with
                teacher forcing.
            Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T).

        """
        x = text
        y = speech
        spemb = spembs

        # add eos at the end of the sequence
        x = F.pad(x, [0, 1], "constant", self.eos)

        # inference with teacher forcing
        if use_teacher_forcing:
            assert speech is not None, "speech must be provided with teacher forcing."

            # get teacher forcing outputs
            xs, ys = x.unsqueeze(0), y.unsqueeze(0)
            spembs = None if spemb is None else spemb.unsqueeze(0)
            ilens = x.new_tensor([xs.size(1)]).long()
            olens = y.new_tensor([ys.size(1)]).long()
            outs, *_ = self._forward(xs, ilens, ys, olens, spembs)

            # get source attention weights of every decoder layer
            att_ws = []
            for i in range(len(self.decoder.decoders)):
                att_ws += [self.decoder.decoders[i].src_attn.attn]
            att_ws = torch.stack(att_ws, dim=1)

            return outs[0], None, att_ws[0]

        # forward encoder
        xs = x.unsqueeze(0)
        hs, _ = self.encoder(xs, None)

        # integrate GST
        if self.use_gst:
            style_embs = self.gst(y.unsqueeze(0))
            hs = hs + style_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            spembs = spemb.unsqueeze(0)
            hs = self._integrate_with_spk_embed(hs, spembs)

        # set limits of length (in decoder steps)
        maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
        minlen = int(hs.size(1) * minlenratio / self.reduction_factor)

        # initialize
        idx = 0
        ys = hs.new_zeros(1, 1, self.odim)
        outs, probs = [], []

        # forward decoder step-by-step
        z_cache = self.decoder.init_state(x)
        while True:
            # update index
            idx += 1

            # calculate output and stop prob at idx-th step
            y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
            z, z_cache = self.decoder.forward_one_step(
                ys, y_masks, hs, cache=z_cache
            )
            outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]
            probs += [torch.sigmoid(self.prob_out(z))[0]]

            # feed the last predicted frame back as the next decoder input
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1)

            # collect source attention weights of the newest step
            att_ws_ = []
            for name, m in self.named_modules():
                if isinstance(m, MultiHeadedAttention) and "src" in name:
                    att_ws_ += [m.attn[0, :, -1].unsqueeze(1)]
            if idx == 1:
                att_ws = att_ws_
            else:
                # append the newest step along the output-time axis
                att_ws = [
                    torch.cat([att_w, att_w_], dim=1)
                    for att_w, att_w_ in zip(att_ws, att_ws_)
                ]

            # check whether to finish generation
            if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
                # check minimum length
                if idx < minlen:
                    continue
                # (L, odim) -> (1, L, odim) -> (1, odim, L) for the postnet
                outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)
                if self.postnet is not None:
                    outs = outs + self.postnet(outs)
                outs = outs.transpose(2, 1).squeeze(0)
                probs = torch.cat(probs, dim=0)
                break

        # concatenate attention weights -> (#layers, #heads, L, T)
        att_ws = torch.stack(att_ws, dim=0)

        return outs, probs, att_ws

    def _add_first_frame_and_remove_last_frame(self, ys: torch.Tensor) -> torch.Tensor:
        """Shift ``ys`` right by one frame along dim 1, prepending a zero frame."""
        ys_in = torch.cat(
            [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1
        )
        return ys_in

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _target_mask(self, olens: torch.Tensor) -> torch.Tensor:
        """Make masks for masked self-attention.

        Args:
            olens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for masked self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> olens = [5, 3]
            >>> self._target_mask(olens)
            tensor([[[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1]],
                    [[1, 0, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
        s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks

    def _integrate_with_spk_embed(
        self, hs: torch.Tensor, spembs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).

        Raises:
            NotImplementedError: If the integration type is not "add" or "concat".

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs
|
|