# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

from fairseq.distributed import fsdp_wrap
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerEncoder
from fairseq.models.xmod.hub_interface import XMODHubInterface
from fairseq.models.xmod.transformer_layer_xmod import XMODTransformerEncoderLayerBase
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from ..roberta.model import RobertaEncoder, base_architecture
from ..roberta.model_xlmr import XLMRModel

DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)


@register_model("xmod")
class XMODModel(XLMRModel):
    @classmethod
    def hub_models(cls):
        return {
            "xmod.base": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.81.1M.tar.gz",
            "xmod.large.prenorm": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.large.prenorm.81.500k.tar.gz",
            "xmod.base.13.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.13.125k.tar.gz",
            "xmod.base.30.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.125k.tar.gz",
            "xmod.base.30.195k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.195k.tar.gz",
            "xmod.base.60.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.125k.tar.gz",
            "xmod.base.60.265k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.265k.tar.gz",
            "xmod.base.75.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.125k.tar.gz",
            "xmod.base.75.269k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.269k.tar.gz",
        }

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="sentencepiece",
        **kwargs,
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return XMODHubInterface(x["args"], x["task"], x["models"][0])
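
    # A minimal usage sketch: "xmod.base" is resolved through the archive map
    # in hub_models() above, and the returned object is an XMODHubInterface,
    # so downstream calls depend on that interface rather than on this class.
    #
    #   xmod = XMODModel.from_pretrained("xmod.base", checkpoint_file="model.pt")
    #   xmod.eval()  # disable dropout before inference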

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        from omegaconf import OmegaConf

        if OmegaConf.is_config(args):
            OmegaConf.set_struct(args, False)

        # make sure all arguments are present
        base_architecture(args)

        if not hasattr(args, "max_positions"):
            if not hasattr(args, "tokens_per_sample"):
                args.tokens_per_sample = task.max_positions()
            args.max_positions = args.tokens_per_sample

        encoder = XMODEncoder(args, task.source_dictionary)

        if OmegaConf.is_config(args):
            OmegaConf.set_struct(args, True)

        return cls(args, encoder)

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        lang_id=None,
        **kwargs,
    ):
        if classification_head_name is not None:
            features_only = True

        x, extra = self.encoder(
            src_tokens, features_only, return_all_hiddens, lang_id=lang_id, **kwargs
        )
        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra
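

# Sketch of the forward contract above (illustrative): a registered
# classification head forces features_only=True and its output replaces the
# LM logits. Head registration is inherited from RobertaModel; the head name
# "sentiment" here is a made-up example.
#
#   model.register_classification_head("sentiment", num_classes=2)
#   logits, _ = model(src_tokens, classification_head_name="sentiment", lang_id=lang_id)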

class XMODEncoder(RobertaEncoder):
    """XMOD encoder."""

    def build_encoder(self, args, dictionary, embed_tokens):
        encoder = XMODTransformerEncoder(args, dictionary, embed_tokens)
        encoder.apply(init_bert_params)
        return encoder

    def forward(
        self,
        src_tokens,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        lang_id=None,
        **unused,
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            lang_id: identifier of the language adapter to route through
                in each encoder layer.

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, embed_dim)`.
        """
        x, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens, lang_id=lang_id
        )
        if not features_only:
            x = self.output_layer(x, masked_tokens=masked_tokens)
        return x, extra

    def extract_features(
        self, src_tokens, return_all_hiddens=False, lang_id=None, **kwargs
    ):
        encoder_out = self.sentence_encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens,
            lang_id=lang_id,
            token_embeddings=kwargs.get("token_embeddings", None),
        )
        # T x B x C -> B x T x C
        features = encoder_out["encoder_out"][0].transpose(0, 1)
        inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
        return features, {"inner_states": inner_states}


class XMODTransformerEncoder(TransformerEncoder):
    def build_encoder_layer(self, cfg):
        layer = XMODTransformerEncoderLayerBase(cfg)
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer
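
    # Wrapping order: activation checkpointing wraps the layer first, and
    # forcing min_params_to_wrap to 0 guarantees FSDP then wraps the
    # checkpointed module regardless of its parameter count; without
    # checkpointing, cfg.min_params_to_wrap (default DEFAULT_MIN_PARAMS_TO_WRAP)
    # decides whether fsdp_wrap shards the layer at all.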

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        lang_id=None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                the default `None` recomputes embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        return self.forward_scriptable(
            src_tokens,
            src_lengths,
            return_all_hiddens,
            token_embeddings,
            lang_id=lang_id,
        )

    # TorchScript doesn't support super() calls, so a scriptable subclass
    # can't access the base class's method directly. The current workaround
    # is a helper method with a different name that the scriptable subclass
    # calls instead.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        lang_id=None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                the default `None` recomputes embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        # encoder layers
        for layer in self.layers:
            x = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if has_pads else None,
                lang_id=lang_id,
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The PyTorch Mobile lite interpreter does not support returning
        # NamedTuple in `forward`, so we use a dictionary instead.
        # TorchScript does not support mixed values, so the values are all
        # lists. The empty list is equivalent to None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }
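

# Consuming the output dictionary above (a minimal sketch; shapes follow the
# inline comments in forward_scriptable):
#
#   out = encoder(src_tokens, lang_id=lang_id)
#   last_hidden = out["encoder_out"][0]            # T x B x C
#   padding_mask = out["encoder_padding_mask"][0]  # B x T
#   features = last_hidden.transpose(0, 1)         # B x T x C, as in extract_features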


@register_model_architecture("xmod", "xmod_base_13")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "ar_AR",
            "en_XX",
            "fi_FI",
            "fr_XX",
            "hi_IN",
            "id_ID",
            "ka_GE",
            "ko_KR",
            "ru_RU",
            "sw_KE",
            "ta_IN",
            "th_TH",
            "vi_VN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base_30")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "ar_AR",
            "cs_CZ",
            "en_XX",
            "eu_ES",
            "fi_FI",
            "fr_XX",
            "hi_IN",
            "hr_HR",
            "hu_HU",
            "hy_AM",
            "id_ID",
            "it_IT",
            "ka_GE",
            "ko_KR",
            "lt_LT",
            "ml_IN",
            "mn_MN",
            "ms_MY",
            "pl_PL",
            "ro_RO",
            "ru_RU",
            "si_LK",
            "sk_SK",
            "sq_AL",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "th_TH",
            "tl_XX",
            "vi_VN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base_60")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "af_ZA",
            "am_ET",
            "ar_AR",
            "be_BY",
            "bn_IN",
            "ca_ES",
            "cs_CZ",
            "cy_GB",
            "da_DK",
            "en_XX",
            "eo_EO",
            "et_EE",
            "eu_ES",
            "fa_IR",
            "fi_FI",
            "fr_XX",
            "ga_IE",
            "gl_ES",
            "gu_IN",
            "ha_NG",
            "hi_IN",
            "hr_HR",
            "hu_HU",
            "hy_AM",
            "id_ID",
            "is_IS",
            "it_IT",
            "ka_GE",
            "ko_KR",
            "ku_TR",
            "la_VA",
            "lt_LT",
            "lv_LV",
            "mk_MK",
            "ml_IN",
            "mn_MN",
            "ms_MY",
            "ne_NP",
            "nl_XX",
            "no_XX",
            "pl_PL",
            "ps_AF",
            "pt_XX",
            "ro_RO",
            "ru_RU",
            "sa_IN",
            "sd_PK",
            "si_LK",
            "sk_SK",
            "sl_SI",
            "so_SO",
            "sq_AL",
            "sr_RS",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "te_IN",
            "th_TH",
            "tl_XX",
            "vi_VN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base_75")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "af_ZA",
            "am_ET",
            "ar_AR",
            "as_IN",
            "be_BY",
            "bn_IN",
            "br_FR",
            "bs_BA",
            "ca_ES",
            "cs_CZ",
            "cy_GB",
            "da_DK",
            "en_XX",
            "eo_EO",
            "et_EE",
            "eu_ES",
            "fa_IR",
            "fi_FI",
            "fr_XX",
            "fy_NL",
            "ga_IE",
            "gd_GB",
            "gl_ES",
            "gu_IN",
            "ha_NG",
            "hi_IN",
            "hr_HR",
            "hu_HU",
            "hy_AM",
            "id_ID",
            "is_IS",
            "it_IT",
            "jv_ID",
            "ka_GE",
            "kn_IN",
            "ko_KR",
            "ku_TR",
            "la_VA",
            "lt_LT",
            "lv_LV",
            "mg_MG",
            "mk_MK",
            "ml_IN",
            "mn_MN",
            "mr_IN",
            "ms_MY",
            "ne_NP",
            "nl_XX",
            "no_XX",
            "om_KE",
            "or_IN",
            "pa_IN",
            "pl_PL",
            "ps_AF",
            "pt_XX",
            "ro_RO",
            "ru_RU",
            "sa_IN",
            "sd_PK",
            "si_LK",
            "sk_SK",
            "sl_SI",
            "so_SO",
            "sq_AL",
            "sr_RS",
            "su_ID",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "te_IN",
            "th_TH",
            "tl_XX",
            "vi_VN",
            "xh_ZA",
            "yi_DE",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_base")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
    args.languages = getattr(
        args,
        "languages",
        [
            "en_XX",
            "id_ID",
            "vi_VN",
            "ru_RU",
            "fa_IR",
            "sv_SE",
            "ja_XX",
            "fr_XX",
            "de_DE",
            "ro_RO",
            "ko_KR",
            "hu_HU",
            "es_XX",
            "fi_FI",
            "uk_UA",
            "da_DK",
            "pt_XX",
            "no_XX",
            "th_TH",
            "pl_PL",
            "bg_BG",
            "nl_XX",
            "zh_CN",
            "he_IL",
            "el_GR",
            "it_IT",
            "sk_SK",
            "hr_HR",
            "tr_TR",
            "ar_AR",
            "cs_CZ",
            "lt_LT",
            "hi_IN",
            "zh_TW",
            "ca_ES",
            "ms_MY",
            "sl_SI",
            "lv_LV",
            "ta_IN",
            "bn_IN",
            "et_EE",
            "az_AZ",
            "sq_AL",
            "sr_RS",
            "kk_KZ",
            "ka_GE",
            "tl_XX",
            "ur_PK",
            "is_IS",
            "hy_AM",
            "ml_IN",
            "mk_MK",
            "be_BY",
            "la_VA",
            "te_IN",
            "eu_ES",
            "gl_ES",
            "mn_MN",
            "kn_IN",
            "ne_NP",
            "sw_KE",
            "si_LK",
            "mr_IN",
            "af_ZA",
            "gu_IN",
            "cy_GB",
            "eo_EO",
            "km_KH",
            "ky_KG",
            "uz_UZ",
            "ps_AF",
            "pa_IN",
            "ga_IE",
            "ha_NG",
            "am_ET",
            "lo_LA",
            "ku_TR",
            "so_SO",
            "my_MM",
            "or_IN",
            "sa_IN",
        ],
    )
    base_architecture(args)


@register_model_architecture("xmod", "xmod_large_prenorm")
def roberta_base_architecture(args):
    args.ffn_modules = getattr(args, "ffn_modules", False)
    args.adapter_modules = getattr(args, "adapter_modules", True)
    args.adapter_layer_norm = getattr(args, "adapter_layer_norm", True)
    args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", False)
    args.ln_before_adapter = getattr(args, "ln_before_adapter", False)
    # args.bottleneck = getattr(args, "bottleneck", 8)
    args.bottleneck = getattr(args, "bottleneck", 4)
    args.languages = getattr(
        args,
        "languages",
        [
            "en_XX",
            "id_ID",
            "vi_VN",
            "ru_RU",
            "fa_IR",
            "sv_SE",
            "ja_XX",
            "fr_XX",
            "de_DE",
            "ro_RO",
            "ko_KR",
            "hu_HU",
            "es_XX",
            "fi_FI",
            "uk_UA",
            "da_DK",
            "pt_XX",
            "no_XX",
            "th_TH",
            "pl_PL",
            "bg_BG",
            "nl_XX",
            "zh_CN",
            "he_IL",
            "el_GR",
            "it_IT",
            "sk_SK",
            "hr_HR",
            "tr_TR",
            "ar_AR",
            "cs_CZ",
            "lt_LT",
            "hi_IN",
            "zh_TW",
            "ca_ES",
            "ms_MY",
            "sl_SI",
            "lv_LV",
            "ta_IN",
            "bn_IN",
            "et_EE",
            "az_AZ",
            "sq_AL",
            "sr_RS",
            "kk_KZ",
            "ka_GE",
            "tl_XX",
            "ur_PK",
            "is_IS",
            "hy_AM",
            "ml_IN",
            "mk_MK",
            "be_BY",
            "la_VA",
            "te_IN",
            "eu_ES",
            "gl_ES",
            "mn_MN",
            "kn_IN",
            "ne_NP",
            "sw_KE",
            "si_LK",
            "mr_IN",
            "af_ZA",
            "gu_IN",
            "cy_GB",
            "eo_EO",
            "km_KH",
            "ky_KG",
            "uz_UZ",
            "ps_AF",
            "pa_IN",
            "ga_IE",
            "ha_NG",
            "am_ET",
            "lo_LA",
            "ku_TR",
            "so_SO",
            "my_MM",
            "or_IN",
            "sa_IN",
        ],
    )
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.encoder_layers = getattr(args, "encoder_layers", 24)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    base_architecture(args)
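

# A hypothetical sketch of adding another variant: the name "xmod_base_2" and
# the two-language list are invented for illustration. Real variants above
# follow the same pattern of seeding adapter defaults on `args` before
# delegating to base_architecture.
#
# @register_model_architecture("xmod", "xmod_base_2")
# def xmod_base_2_architecture(args):
#     args.ffn_modules = getattr(args, "ffn_modules", False)
#     args.adapter_modules = getattr(args, "adapter_modules", True)
#     args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
#     args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
#     args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
#     args.languages = getattr(args, "languages", ["en_XX", "fr_XX"])
#     base_architecture(args)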