# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

# from icefall.utils import make_pad_mask
from torchmetrics.classification import BinaryAccuracy

from models.vallex import Transpose
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
from modules.scaling import BalancedDoubleSwish, ScaledLinear
from modules.transformer import (
    BalancedBasicNorm,
    IdentityNorm,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
from .visualizer import visualize

IdentityNorm = IdentityNorm
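

# `make_pad_mask` is used throughout this file, but the icefall import above is
# commented out. This is a minimal local stand-in with the same semantics
# (True at padded positions), assuming `lengths` is a 1-D tensor of lengths.
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a bool mask of shape (N, T) that is True at padded positions."""
    assert lengths.ndim == 1, lengths.shape
    max_len = max(max_len, int(lengths.max()))
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)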


class Transformer(nn.Module):
    """Seq2seq Transformer TTS for debugging (no StopPredictor and no SpeakerEmbedding).

    Neural Speech Synthesis with Transformer Network
    https://arxiv.org/abs/1809.08895
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        scaling_xformers: bool = False,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
          norm_first:
            If True, use pre-norm layers (LayerNorm before attention/FFN).
          add_prenet:
            If True, add a convolutional prenet before the encoder and a
            bottleneck prenet before the decoder.
          scaling_xformers:
            If True, build the ScaledLinear/BalancedDoubleSwish transformer
            variant instead of the vanilla nn.Transformer one.
        """
        super().__init__()

        self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x

        if add_prenet:
            self.encoder_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )
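            # Transformer-TTS-style decoder prenet: a narrow 256-dim bottleneck
            # that compresses mel frames before they enter the decoder.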
            self.decoder_prenet = nn.Sequential(
                nn.Linear(NUM_MEL_BINS, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, d_model),
            )

            assert scaling_xformers is False  # TODO: update this block
        else:
            self.encoder_prenet = nn.Identity()
            if scaling_xformers:
                self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
            else:
                self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)

        self.encoder_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
        )
        self.decoder_position = SinePositionalEmbedding(
            d_model, dropout=0.1, scale=False
        )

        if scaling_xformers:
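            # The scaling variant swaps each nn.Linear for ScaledLinear and the
            # ReLU feed-forward for BalancedDoubleSwish; initial_scale=0.01 on
            # the second projections keeps the residual branches small at init.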
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)
        else:
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )
            self.decoder = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)
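
        # Tracks how often the stop head separates real frames from padded
        # (post-utterance) frames; logged as a metric, not used in the loss.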
        self.stop_accuracy_metric = BinaryAccuracy(
            threshold=0.5, multidim_average="global"
        )

        # self.apply(self._init_weights)

    # def _init_weights(self, module):
    #     if isinstance(module, (nn.Linear)):
    #         module.weight.data.normal_(mean=0.0, std=0.02)
    #         if isinstance(module, nn.Linear) and module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.LayerNorm):
    #         module.bias.data.zero_()
    #         module.weight.data.fill_(1.0)
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        y_lens: torch.Tensor,
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Dict[str, Any]]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, S).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, T, NUM_MEL_BINS).
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `y`
            before padding.
          train_stage:
            Not used in this model.
        Returns:
          Return the encoder output together with the predicted mel frames, the
          total loss (MSE plus weighted stop loss), and a dict of metrics.
        """
        del train_stage

        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        assert torch.all(x_lens > 0)

        # NOTE: x has been padded in TextTokenCollater
        x_mask = make_pad_mask(x_lens).to(x.device)

        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        total_loss, metrics = 0.0, {}

        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_float = y_mask.type(torch.float32)
        data_mask = 1.0 - y_mask_float.unsqueeze(-1)

        # Training
        # AR Decoder
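        # Teacher forcing with a one-frame shift: prepend an all-zero frame so
        # that decoder step t sees frames < t as input and predicts frame t.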
        def pad_y(y):
            y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
            # inputs, targets
            return y[:, :-1], y[:, 1:]

        y, targets = pad_y(y * data_mask)  # mask padding as zeros

        y_emb = self.decoder_prenet(y)
        y_pos = self.decoder_position(y_emb)

        y_len = y_lens.max()
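        # Causal self-attention mask: True above the diagonal blocks each
        # decoder step from attending to future frames.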
        tgt_mask = torch.triu(
            torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
            diagonal=1,
        )
        y_dec = self.decoder(
            y_pos,
            x,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=x_mask,
        )
        predict = self.predict_layer(y_dec)

        # loss
        total_loss = F.mse_loss(predict, targets, reduction=reduction)

        logits = self.stop_layer(y_dec).squeeze(-1)
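        # The stop target is 1 at padded (post-utterance) positions; weighting
        # them 5x (1.0 + mask * 4.0) emphasizes the positive class so the stop
        # head does not collapse to "never stop".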
        stop_loss = F.binary_cross_entropy_with_logits(
            logits,
            y_mask_float.detach(),
            weight=1.0 + y_mask_float.detach() * 4.0,
            reduction=reduction,
        )
        metrics["stop_loss"] = stop_loss.detach()

        stop_accuracy = self.stop_accuracy_metric(
            (torch.sigmoid(logits) >= 0.5).type(torch.int64),
            y_mask.type(torch.int64),
        )
        # icefall MetricsTracker.norm_items()
        metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
            torch.float32
        )

        return ((x, predict), total_loss + 100.0 * stop_loss, metrics)

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Any = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
        Returns:
          Return the predicted mel-spectrogram frames.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert torch.all(x_lens > 0)

        x_mask = make_pad_mask(x_lens).to(x.device)

        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        # AR Decoder
        # TODO: manage decoder steps to avoid repetitive computation
        y = torch.zeros(
            [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
        )
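        # Greedy frame-by-frame decoding, seeded with a single all-zero frame
        # (the same <BOS> frame used during training).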
        while True:
            y_emb = self.decoder_prenet(y)
            y_pos = self.decoder_position(y_emb)

            tgt_mask = torch.triu(
                torch.ones(
                    y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
                ),
                diagonal=1,
            )

            y_dec = self.decoder(
                y_pos,
                x,
                tgt_mask=tgt_mask,
                memory_mask=None,
                memory_key_padding_mask=x_mask,
            )
            predict = self.predict_layer(y_dec[:, -1:])

            logits = self.stop_layer(y_dec[:, -1:]) > 0  # sigmoid(0.0) = 0.5
            if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
                print(
                    f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
                )
                break

            y = torch.concat([y, predict], dim=1)

        return y[:, 1:]

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        visualize(predicts, batch, output_dir, limit=limit)
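

# A minimal smoke-test sketch, not part of the original module: it assumes the
# relative package imports above resolve, and exercises forward() on random
# tensors with the tensor shapes the docstrings describe.
if __name__ == "__main__":
    model = Transformer(d_model=128, nhead=4, num_layers=2)
    x = torch.randint(0, NUM_TEXT_TOKENS, (2, 10))  # (N, S) padded text tokens
    x_lens = torch.tensor([10, 8])
    y = torch.randn(2, 50, NUM_MEL_BINS)  # (N, T, NUM_MEL_BINS) mel frames
    y_lens = torch.tensor([50, 40])
    (_, predict), loss, metrics = model(x, x_lens, y, y_lens)
    print(predict.shape, float(loss), metrics)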