Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621) | |
| # Github source: https://github.com/mbzuai-nlp/ArTST | |
| # Based on speecht5, fairseq and espnet code bases | |
| # https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet | |
| # -------------------------------------------------------- | |
| import contextlib | |
| import torch | |
| import torch.nn as nn | |
| from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet | |
class SpeechDecoderPostnet(nn.Module):
    """Project decoder hidden states to output frames and per-frame logits.

    Each decoder step emits ``reduction_factor`` (r) frames at once:
    ``feat_out`` predicts ``odim * r`` features per step, which are unfolded
    into ``r`` frames of ``odim`` features, and ``prob_out`` predicts one
    logit per frame (presumably a stop-token score — confirm against the
    loss that consumes it). An optional espnet ``Postnet`` adds a residual
    refinement on top of the projected frames.

    Args:
        odim (int): number of output features per frame.
        args: configuration namespace. Attributes read:
            ``decoder_embed_dim`` (input feature size),
            ``reduction_factor`` (frames emitted per decoder step),
            ``postnet_layers`` (0 disables the postnet),
            ``postnet_chans``, ``postnet_filts``, ``use_batch_norm``,
            ``postnet_dropout_rate`` (forwarded to ``Postnet``), and
            ``freeze_decoder_updates`` (number of updates during which
            ``forward`` runs without gradient tracking).
    """

    def __init__(
        self,
        odim,
        args,
    ):
        super().__init__()
        reduction = args.reduction_factor
        # Final projections: features for r frames, and one logit per frame.
        self.feat_out = torch.nn.Linear(args.decoder_embed_dim, odim * reduction)
        self.prob_out = torch.nn.Linear(args.decoder_embed_dim, reduction)
        # Optional residual postnet; idim is passed as 0 as in the original
        # configuration (espnet's Postnet sizes its convs from odim).
        self.postnet = (
            None
            if args.postnet_layers == 0
            else Postnet(
                idim=0,
                odim=odim,
                n_layers=args.postnet_layers,
                n_chans=args.postnet_chans,
                n_filts=args.postnet_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.postnet_dropout_rate,
            )
        )
        self.odim = odim
        self.num_updates = 0
        self.freeze_decoder_updates = args.freeze_decoder_updates

    def forward(self, zs):
        """Compute frames before/after the postnet and per-frame logits.

        Args:
            zs (Tensor): decoder output, expected shape
                (B, Lmax//r, decoder_embed_dim) with r = ``reduction_factor``.

        Returns:
            tuple: ``(before_outs, after_outs, logits)`` with shapes
            (B, Lmax//r * r, odim), (B, Lmax//r * r, odim) and
            (B, Lmax//r * r). ``after_outs`` is ``before_outs`` itself when
            no postnet is configured.

        While fewer than ``freeze_decoder_updates`` updates have been
        recorded via ``set_num_updates``, everything runs under
        ``torch.no_grad()``, so no gradients flow through these outputs.
        """
        fine_tune = self.freeze_decoder_updates <= self.num_updates
        # Frozen phase: disable autograd; afterwards a no-op context.
        with contextlib.nullcontext() if fine_tune else torch.no_grad():
            # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
            before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
            # (B, Lmax//r, r) -> (B, Lmax//r * r)
            logits = self.prob_out(zs).view(zs.size(0), -1)
            if self.postnet is None:
                after_outs = before_outs
            else:
                # Postnet works channels-first: (B, odim, T) in and out.
                after_outs = before_outs + self.postnet(
                    before_outs.transpose(1, 2)
                ).transpose(1, 2)
        return before_outs, after_outs, logits

    def set_num_updates(self, num_updates):
        """Record the number of parameter updates (controls freezing)."""
        self.num_updates = num_updates