# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
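"""Text decoder postnet for ArTST.

Wraps the final output projection of the text decoder: either a tied or
untied linear layer, or an adaptive softmax, selected from the model args.
"""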

import contextlib

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.modules import AdaptiveSoftmax


class TextDecoderPostnet(nn.Module):
    """Text decoder postnet: projects decoder features to vocabulary logits.

    Args:
        embed_tokens (nn.Embedding): input token embeddings, reused for
            weight tying when ``share_input_output_embed`` is set
        dictionary (~fairseq.data.Dictionary): target vocabulary
        args (argparse.Namespace): model configuration
        output_projection (nn.Module, optional): pre-built projection layer;
            built from ``args`` when omitted
    """

    def __init__(self, embed_tokens, dictionary, args, output_projection=None):
        super().__init__()
        self.output_embed_dim = args.decoder_output_dim
        self.output_projection = output_projection
        self.adaptive_softmax = None
        self.share_input_output_embed = args.share_input_output_embed
        if self.output_projection is None:
            self.build_output_projection(args, dictionary, embed_tokens)
        self.freeze_decoder_updates = args.freeze_decoder_updates
        self.num_updates = 0

    def output_layer(self, features):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            return self.output_projection(features)
        else:
            # adaptive softmax consumes the raw features directly
            return features

    def build_output_projection(self, args, dictionary, embed_tokens):
        if args.adaptive_softmax_cutoff is not None:
            # factorized softmax for large vocabularies
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
                factor=args.adaptive_softmax_factor,
                tie_proj=args.tie_adaptive_proj,
            )
        elif self.share_input_output_embed:
            # tie the output projection to the input embedding matrix
            self.output_projection = nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = embed_tokens.weight
        else:
            self.output_projection = nn.Linear(
                self.output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
            )
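    # Design note: the branches in build_output_projection mirror fairseq's
    # TransformerDecoder. Weight tying (share_input_output_embed) reuses the
    # input embedding matrix as the output projection, saving a vocab-sized
    # parameter matrix; AdaptiveSoftmax instead factorizes the vocabulary by
    # frequency to reduce softmax cost on large vocabularies.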

    def forward(self, x):
        # keep the postnet frozen (run under no_grad) until num_updates
        # reaches args.freeze_decoder_updates; afterwards gradients flow
        ft = self.freeze_decoder_updates <= self.num_updates
        with torch.no_grad() if not ft else contextlib.ExitStack():
            return self._forward(x)

    def _forward(self, x):
        # project decoder features to vocabulary logits
        x = self.output_layer(x)
        return x

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self.num_updates = num_updates
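

# The smoke test below is an illustrative sketch, not part of the original
# module: the config values and the range() stand-in for a fairseq Dictionary
# are assumptions chosen so the postnet can run standalone.
if __name__ == "__main__":
    from argparse import Namespace

    vocab_size, embed_dim = 100, 16
    embed_tokens = nn.Embedding(vocab_size, embed_dim)
    args = Namespace(
        decoder_output_dim=embed_dim,
        share_input_output_embed=True,  # tie output projection to embeddings
        adaptive_softmax_cutoff=None,   # use a plain linear projection
        freeze_decoder_updates=0,       # no frozen warm-up in this sketch
    )
    # Only len(dictionary) matters on this code path, so a range suffices.
    postnet = TextDecoderPostnet(embed_tokens, range(vocab_size), args)
    logits = postnet(torch.randn(2, 5, embed_dim))  # (batch, time, features)
    print(logits.shape)  # -> torch.Size([2, 5, 100])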