Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import math | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from fairseq import checkpoint_utils | |
| from fairseq.incremental_decoding_utils import with_incremental_state | |
| from fairseq.models import ( | |
| CompositeEncoder, | |
| FairseqDecoder, | |
| FairseqEncoder, | |
| FairseqEncoderDecoderModel, | |
| register_model, | |
| register_model_architecture, | |
| ) | |
| from fairseq.modules import ( | |
| DownsampledMultiHeadAttention, | |
| FairseqDropout, | |
| GradMultiply, | |
| LayerNorm, | |
| LearnedPositionalEmbedding, | |
| LinearizedConvolution, | |
| ) | |
# Module-level logger, named after this module per fairseq convention.
logger = logging.getLogger(__name__)
@register_model("fconv_self_att")
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
    """Convolutional seq2seq model with decoder self-attention, optionally
    fused with a frozen pretrained seq2seq model ("fusion" mode, used for
    hierarchical story generation).

    Fixes vs. previous revision: restored the ``@register_model``,
    ``@classmethod``, ``@staticmethod`` and ``@property`` decorators that had
    been stripped — ``hub_models`` and ``build_model`` are written to receive
    ``cls`` and fail with a TypeError when invoked on the class without
    ``@classmethod``.
    """

    @classmethod
    def hub_models(cls):
        """Return the torch.hub checkpoint registry for this model."""
        return {
            "conv.stories.pretrained": {
                "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
                "checkpoint_file": "pretrained_checkpoint.pt",
                "tokenizer": "nltk",
            },
            "conv.stories": {
                "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz",
                "checkpoint_file": "fusion_checkpoint.pt",
                "tokenizer": "nltk",
                "pretrained": "True",
                "pretrained_checkpoint": "./pretrained_checkpoint.pt",
            },
            # Test set containing dictionaries
            "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2",
        }

    def __init__(self, encoder, decoder, pretrained_encoder=None):
        """Wrap encoder/decoder; in fusion mode also hold the frozen encoder.

        Args:
            encoder: trainable FConvEncoder
            decoder: trainable FConvDecoder
            pretrained_encoder: frozen encoder of the pretrained model, or
                None when not running in fusion mode
        """
        super().__init__(encoder, decoder)
        # The encoder scales gradients by the number of decoder attention
        # layers (see GradMultiply in FConvEncoder.forward), so tell it how
        # many there are.
        self.encoder.num_attention_layers = sum(
            layer is not None for layer in decoder.attention
        )
        self.pretrained_encoder = pretrained_encoder
        if self.pretrained_encoder is None:
            encoders = {"encoder": encoder}
        else:
            encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder}
        # for fusion model, CompositeEncoder contains both pretrained and training encoders
        # these are forwarded and then combined in the decoder
        self.encoder = CompositeEncoder(encoders)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
                            help='encoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
                            help='decoder layers [(dim, kernel_size), ...]')
        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
                            help='decoder output embedding dimension')
        parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
                            help='decoder attention [True, ...]')
        parser.add_argument('--self-attention', type=str, metavar='EXPR',
                            help='decoder self-attention layers, ex: [True] + [False]*5')
        parser.add_argument('--multihead-attention-nheads', type=int,
                            help='Number of heads to use in attention')
        parser.add_argument('--multihead-self-attention-nheads', type=int,
                            help='Number of heads to use in self-attention')
        parser.add_argument('--encoder-attention', type=str, metavar='EXPR',
                            help='encoder attention [True, ...]')
        parser.add_argument('--encoder-attention-nheads', type=int,
                            help='Number of heads to use in encoder attention')
        parser.add_argument('--project-input', type=str, metavar='EXPR',
                            help='Use projections in self-attention [True, ...]')
        parser.add_argument('--gated-attention', type=str, metavar='EXPR',
                            help='Use GLU layers in self-attention projections [True, ...]')
        parser.add_argument('--downsample', type=str, metavar='EXPR',
                            help='Use downsampling in self-attention [True, ...]')
        parser.add_argument('--pretrained-checkpoint', metavar='DIR',
                            help='path to load checkpoint from pretrained model')
        parser.add_argument('--pretrained', type=str, metavar='EXPR',
                            help='use pretrained model when training [True, ...]')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        trained_encoder, trained_decoder = None, None
        # NOTE(review): eval() on CLI/config strings is unsafe with untrusted
        # input; retained because the stored architecture configs (e.g.
        # "[(512, 3)] * 3") depend on it.
        pretrained = eval(args.pretrained)
        if pretrained:
            logger.info("loading pretrained model")
            if not os.path.exists(args.pretrained_checkpoint):
                # fall back to a checkpoint path relative to the data dir
                new_pretrained_checkpoint = os.path.join(
                    args.data, args.pretrained_checkpoint
                )
                if os.path.exists(new_pretrained_checkpoint):
                    args.pretrained_checkpoint = new_pretrained_checkpoint
            trained_model = checkpoint_utils.load_model_ensemble(
                filenames=[args.pretrained_checkpoint],
                task=task,
            )[0][0]
            trained_decoder = list(trained_model.children())[1]
            trained_encoder = list(trained_model.children())[0]

            # freeze pretrained model
            for param in trained_decoder.parameters():
                param.requires_grad = False
            for param in trained_encoder.parameters():
                param.requires_grad = False

        encoder = FConvEncoder(
            task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            convolutions=eval(args.encoder_layers),
            dropout=args.dropout,
            max_positions=args.max_source_positions,
            attention=eval(args.encoder_attention),
            attention_nheads=args.encoder_attention_nheads,
        )

        decoder = FConvDecoder(
            task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            convolutions=eval(args.decoder_layers),
            out_embed_dim=args.decoder_out_embed_dim,
            attention=eval(args.decoder_attention),
            dropout=args.dropout,
            max_positions=args.max_target_positions,
            selfattention=eval(args.self_attention),
            attention_nheads=args.multihead_attention_nheads,
            selfattention_nheads=args.multihead_self_attention_nheads,
            project_input=eval(args.project_input),
            gated_attention=eval(args.gated_attention),
            downsample=eval(args.downsample),
            pretrained=pretrained,
            trained_decoder=trained_decoder,
        )
        model = FConvModelSelfAtt(encoder, decoder, trained_encoder)
        return model

    @property
    def pretrained(self):
        # True when a frozen pretrained encoder was supplied (fusion mode).
        return self.pretrained_encoder is not None
class FConvEncoder(FairseqEncoder):
    """Convolutional encoder.

    Stacks temporal (TBC) convolutions with GLU activations, residual
    connections, and optional per-layer self-attention. Returns both the
    final encoder states and their sum with the input embeddings (the
    classic fconv (x, y) pair consumed by the decoder attention).
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        max_positions=1024,
        convolutions=((512, 3),) * 20,
        dropout=0.1,
        attention=False,
        attention_nheads=1,
    ):
        """
        Args:
            dictionary: source vocabulary (provides pad index and size)
            embed_dim: token/position embedding dimension
            max_positions: maximum supported source length
            convolutions: sequence of (out_channels, kernel_size) per layer
            dropout: dropout probability
            attention: bool or per-layer list — enable self-attention per layer
            attention_nheads: number of heads in each SelfAttention block
        """
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        # Set later by FConvModelSelfAtt.__init__ (number of decoder
        # attention layers); used for gradient scaling in forward().
        self.num_attention_layers = None

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            self.padding_idx,
        )

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)

        in_channels = convolutions[0][0]
        # fc1 projects embeddings into the first conv layer's channel size.
        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        # NOTE(review): attproj is created but never populated or used in the
        # encoder — appears vestigial (the decoder has a populated attproj).
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            # Residual projection only when the channel count changes.
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            # Conv outputs 2*out_channels so GLU can halve it back.
            self.convolutions.append(
                ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout)
            )
            self.attention.append(
                SelfAttention(out_channels, embed_dim, attention_nheads)
                if attention[i]
                else None
            )
            in_channels = out_channels

        # fc2 projects back to the embedding dimension.
        self.fc2 = Linear(in_channels, embed_dim)

    def forward(self, src_tokens, src_lengths):
        """Encode source tokens.

        Args:
            src_tokens: LongTensor of shape B x T
            src_lengths: unused here (kept for the FairseqEncoder interface)

        Returns:
            dict with:
                "encoder_out": tuple (x, y), each B x T x C — final states and
                    their embedding-augmented version for attention
                "encoder_padding_mask": B x T bool mask, or None if no padding
        """
        # embed tokens and positions
        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
        x = self.dropout_module(x)
        input_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()  # -> T x B
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        for proj, conv, attention in zip(
            self.projections, self.convolutions, self.attention
        ):
            residual = x if proj is None else proj(x)

            if encoder_padding_mask is not None:
                # zero out padded timesteps so the conv doesn't read them
                x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

            x = self.dropout_module(x)
            # symmetric padding keeps the output length equal to the input
            padding_l = (conv.kernel_size[0] - 1) // 2
            padding_r = conv.kernel_size[0] // 2
            x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
            x = conv(x)
            x = F.glu(x, dim=2)
            if attention is not None:
                x = attention(x)
            # residual connection scaled to preserve variance
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # project back to size of embedding
        x = self.fc2(x)

        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.t()  # -> B x T
            x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0)

        # scale gradients (this only affects backward, not forward)
        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

        # add output to input embedding for attention
        y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5)

        return {
            "encoder_out": (x, y),
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output along the batch dim (used by beam search)."""
        encoder_out["encoder_out"] = tuple(
            eo.index_select(0, new_order) for eo in encoder_out["encoder_out"]
        )

        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)

        if "pretrained" in encoder_out:
            # fusion mode: the pretrained encoder's outputs ride along in the
            # same dict and must be reordered too
            encoder_out["pretrained"]["encoder_out"] = tuple(
                eo.index_select(0, new_order)
                for eo in encoder_out["pretrained"]["encoder_out"]
            )

        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions
class FConvDecoder(FairseqDecoder):
    """Convolutional decoder.

    Mirrors FConvEncoder's conv+GLU stack, adding per-layer encoder attention
    and optional self-attention. In fusion mode (``pretrained=True``) the
    hidden states of a frozen pretrained decoder are combined with this
    decoder's states through learned sigmoid gates before the output layer.
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        out_embed_dim=256,
        max_positions=1024,
        convolutions=((512, 3),) * 8,
        attention=True,
        dropout=0.1,
        selfattention=False,
        attention_nheads=1,
        selfattention_nheads=1,
        project_input=False,
        gated_attention=False,
        downsample=False,
        pretrained=False,
        trained_decoder=None,
    ):
        """
        Args:
            dictionary: target vocabulary
            embed_dim: token/position embedding dimension
            out_embed_dim: hidden size before the vocabulary projection
            max_positions: maximum supported target length
            convolutions: sequence of (out_channels, kernel_size) per layer
            attention: bool or per-layer list — encoder attention per layer
            dropout: dropout probability
            selfattention: bool or per-layer list — self-attention per layer
            attention_nheads: heads for encoder attention
            selfattention_nheads: heads for self-attention
            project_input / gated_attention / downsample: options forwarded
                to the self-attention blocks
            pretrained: enable fusion with a frozen pretrained decoder
            trained_decoder: the frozen pretrained decoder (fusion mode only)
        """
        super().__init__(dictionary)
        # checkpoint versioning buffer
        self.register_buffer("version", torch.Tensor([2]))
        self.pretrained = pretrained
        self.pretrained_decoder = trained_decoder
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.need_attn = True
        in_channels = convolutions[0][0]

        def expand_bool_array(val):
            if isinstance(val, bool):
                # expand True into [True, True, ...] and do the same with False
                return [val] * len(convolutions)
            return val

        attention = expand_bool_array(attention)
        selfattention = expand_bool_array(selfattention)

        if not isinstance(attention, list) or len(attention) != len(convolutions):
            raise ValueError(
                "Attention is expected to be a list of booleans of "
                "length equal to the number of layers."
            )

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        self.embed_positions = PositionalEmbedding(
            max_positions,
            embed_dim,
            padding_idx,
        )

        self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
        self.projections = nn.ModuleList()
        self.convolutions = nn.ModuleList()
        self.attention = nn.ModuleList()
        self.selfattention = nn.ModuleList()
        self.attproj = nn.ModuleList()
        for i, (out_channels, kernel_size) in enumerate(convolutions):
            # Residual projection only when the channel count changes.
            self.projections.append(
                Linear(in_channels, out_channels)
                if in_channels != out_channels
                else None
            )
            # Left-padded (causal) conv; LinearizedConv1d supports
            # incremental decoding. Output is 2*out_channels for GLU.
            self.convolutions.append(
                LinearizedConv1d(
                    in_channels,
                    out_channels * 2,
                    kernel_size,
                    padding=(kernel_size - 1),
                    dropout=dropout,
                )
            )
            # Encoder attention: gating/downsampling deliberately disabled
            # here (only the self-attention blocks use those options).
            self.attention.append(
                DownsampledMultiHeadAttention(
                    out_channels,
                    embed_dim,
                    attention_nheads,
                    project_input=project_input,
                    gated=False,
                    downsample=False,
                )
                if attention[i]
                else None
            )
            # Projects conv states to embed_dim before adding target
            # embeddings for the attention query.
            self.attproj.append(
                Linear(out_channels, embed_dim, dropout=dropout)
                if attention[i]
                else None
            )
            self.selfattention.append(
                SelfAttention(
                    out_channels,
                    embed_dim,
                    selfattention_nheads,
                    project_input=project_input,
                    gated=gated_attention,
                    downsample=downsample,
                )
                if selfattention[i]
                else None
            )
            in_channels = out_channels

        self.fc2 = Linear(in_channels, out_embed_dim)
        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

        # model fusion
        if self.pretrained:
            # independent gates are learned from the concatenated input
            self.gate1 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            self.gate2 = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid()
            )
            # pretrained and trained models are joined
            self.joining = nn.Sequential(
                Linear(out_embed_dim * 2, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim * 2),
                LayerNorm(out_embed_dim * 2),
                nn.GLU(),
                Linear(out_embed_dim, out_embed_dim),
                LayerNorm(out_embed_dim),
            )
            # pretrained model contains an output layer that is nhid -> vocab size
            # but the models are combined in their hidden state
            # the hook stores the output of the pretrained model forward
            self.pretrained_outputs = {}

            def save_output():
                def hook(a, b, output):
                    self.pretrained_outputs["out"] = output

                return hook

            self.pretrained_decoder.fc2.register_forward_hook(save_output())

    def forward(self, prev_output_tokens, encoder_out):
        """Decode one pass.

        Args:
            prev_output_tokens: LongTensor B x T of shifted target tokens
            encoder_out: dict from CompositeEncoder — key "encoder" always,
                plus "pretrained" in fusion mode

        Returns:
            (logits B x T x V, averaged attention scores or None)
        """
        trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None
        encoder_out = encoder_out["encoder"]["encoder_out"]

        encoder_a, encoder_b = self._split_encoder_out(encoder_out)

        # embed positions
        positions = self.embed_positions(prev_output_tokens)

        # embed tokens and positions
        x = self.embed_tokens(prev_output_tokens) + positions
        x = self.dropout_module(x)
        target_embedding = x.transpose(0, 1)

        # project to size of convolution
        x = self.fc1(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # temporal convolutions
        avg_attn_scores = None
        for proj, conv, attention, selfattention, attproj in zip(
            self.projections,
            self.convolutions,
            self.attention,
            self.selfattention,
            self.attproj,
        ):
            residual = x if proj is None else proj(x)

            x = self.dropout_module(x)
            x = conv(x)
            x = F.glu(x, dim=2)

            # attention
            if attention is not None:
                r = x
                x, attn_scores = attention(
                    attproj(x) + target_embedding, encoder_a, encoder_b
                )
                x = x + r
                # accumulate attention for visualization at inference time
                if not self.training and self.need_attn:
                    if avg_attn_scores is None:
                        avg_attn_scores = attn_scores
                    else:
                        avg_attn_scores.add_(attn_scores)

            if selfattention is not None:
                x = selfattention(x)

            # residual connection scaled to preserve variance
            x = (x + residual) * math.sqrt(0.5)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = self.fc2(x)
        x = self.dropout_module(x)
        if not self.pretrained:
            x = self.fc3(x)

        # fusion gating
        if self.pretrained:
            # The call populates self.pretrained_outputs["out"] via the
            # forward hook on the frozen decoder's fc2; the direct return
            # value is not used.
            trained_x, _ = self.pretrained_decoder.forward(
                prev_output_tokens, trained_encoder_out
            )
            y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1)
            gate1 = self.gate1(y)
            gate2 = self.gate2(y)
            gated_x1 = gate1 * x
            gated_x2 = gate2 * self.pretrained_outputs["out"]
            fusion = torch.cat([gated_x1, gated_x2], dim=-1)
            fusion = self.joining(fusion)
            fusion_output = self.fc3(fusion)
            return fusion_output, avg_attn_scores
        else:
            return x, avg_attn_scores

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions

    def make_generation_fast_(self, need_attn=False, **kwargs):
        # Toggle attention-score accumulation during generation.
        self.need_attn = need_attn

    def _split_encoder_out(self, encoder_out):
        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
        encoder_b = encoder_b.transpose(0, 1).contiguous()
        result = (encoder_a, encoder_b)
        return result
class SelfAttention(nn.Module):
    """Multi-head self-attention block with residual connection and LayerNorm.

    Projects the input into separate query/key/value spaces, applies
    causally-masked downsampled multi-head attention with a scalar bias,
    and layer-normalizes the residual sum.
    """

    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        project_input=False,
        gated=False,
        downsample=False,
    ):
        super().__init__()
        # separate input projections for query, key and value
        self.in_proj_q = Linear(out_channels, embed_dim)
        self.in_proj_k = Linear(out_channels, embed_dim)
        self.in_proj_v = Linear(out_channels, embed_dim)
        self.attention = DownsampledMultiHeadAttention(
            out_channels,
            embed_dim,
            num_heads,
            dropout=0,
            bias=True,
            project_input=project_input,
            gated=gated,
            downsample=downsample,
        )
        self.ln = LayerNorm(out_channels)

    def forward(self, x):
        shortcut = x
        q = self.in_proj_q(x)
        k = self.in_proj_k(x)
        v = self.in_proj_v(x)
        attended, _ = self.attention(
            q, k, v, mask_future_timesteps=True, use_scalar_bias=True
        )
        # residual connection followed by layer normalization
        return self.ln(attended + shortcut)
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build a token embedding whose weights are drawn from N(0, 0.1)."""
    layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(layer.weight, mean=0, std=0.1)
    return layer
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx):
    """Build a learned positional embedding initialized from N(0, 0.1)."""
    layer = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
    layer.weight.data.normal_(0, 0.1)
    return layer
def Linear(in_features, out_features, dropout=0.0):
    """Linear layer with dropout-compensated fan-in init (input: N x T x C).

    Weights are drawn from N(0, sqrt((1 - dropout) / in_features)) and the
    bias is zeroed. (Despite the historical name used elsewhere in fconv
    code, no weight normalization is applied here.)
    """
    layer = nn.Linear(in_features, out_features)
    std = math.sqrt((1 - dropout) / in_features)
    layer.weight.data.normal_(mean=0, std=std)
    layer.bias.data.zero_()
    return layer
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Conv1d optimized for incremental decoding, with scaled normal init.

    Weight std is sqrt(4 * (1 - dropout) / (kernel_size * in_channels));
    the bias is zeroed. Extra kwargs (e.g. ``padding``) are forwarded to
    LinearizedConvolution.
    """
    conv = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    fan_in = conv.kernel_size[0] * in_channels
    std = math.sqrt((4 * (1.0 - dropout)) / fan_in)
    conv.weight.data.normal_(mean=0, std=std)
    conv.bias.data.zero_()
    return conv
def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs):
    """Time-Batch-Channel Conv1d with scaled normal init and zero bias."""
    # local import: this factory shadows the module class of the same name
    from fairseq.modules import ConvTBC

    conv = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
    fan_in = conv.kernel_size[0] * in_channels
    std = math.sqrt((4 * (1.0 - dropout)) / fan_in)
    conv.weight.data.normal_(mean=0, std=std)
    conv.bias.data.zero_()
    return conv
def base_architecture(args):
    """Fill in any missing hyperparameters with the base-model defaults.

    Attributes already present on ``args`` (even if None) are left as-is;
    string-valued defaults are eval()'d later by build_model.
    """
    _defaults = {
        "dropout": 0.1,
        "encoder_embed_dim": 512,
        "encoder_layers": "[(512, 3)] * 3",
        "decoder_embed_dim": 512,
        "decoder_layers": "[(512, 3)] * 8",
        "decoder_out_embed_dim": 256,
        "decoder_attention": "True",
        "self_attention": "False",
        "encoder_attention": "False",
        "multihead_attention_nheads": 1,
        "multihead_self_attention_nheads": 1,
        "encoder_attention_nheads": 1,
        "project_input": "False",
        "gated_attention": "False",
        "downsample": "False",
        "pretrained_checkpoint": "",
        "pretrained": "False",
    }
    for name, default in _defaults.items():
        setattr(args, name, getattr(args, name, default))
def fconv_self_att_wp(args):
    """Writing-prompts architecture: smaller embeddings, growing conv stack,
    self-attention enabled with projection, gating and downsampling.

    Only missing attributes are filled; base_architecture supplies the rest.
    """
    _overrides = {
        "encoder_embed_dim": 256,
        "encoder_layers": "[(128, 3)] * 2 + [(512,3)] * 1",
        "decoder_embed_dim": 256,
        "decoder_layers": "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1",
        "decoder_out_embed_dim": 256,
        "self_attention": "True",
        "multihead_self_attention_nheads": 4,
        "project_input": "True",
        "gated_attention": "True",
        "downsample": "True",
    }
    for name, default in _overrides.items():
        setattr(args, name, getattr(args, name, default))
    base_architecture(args)