|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fairseq.dataclass.utils import gen_parser_from_dataclass |
|
|
from fairseq.models import ( |
|
|
register_model, |
|
|
register_model_architecture, |
|
|
) |
|
|
from fairseq.models.transformer.transformer_config import ( |
|
|
TransformerConfig, |
|
|
DEFAULT_MAX_SOURCE_POSITIONS, |
|
|
DEFAULT_MAX_TARGET_POSITIONS, |
|
|
DEFAULT_MIN_PARAMS_TO_WRAP, |
|
|
) |
|
|
from fairseq.models.transformer.transformer_base import ( |
|
|
TransformerModelBase, |
|
|
) |
|
|
|
|
|
|
|
|
@register_model("transformer")
class TransformerModel(TransformerModelBase):
    """Legacy, argparse-driven Transformer model.

    A thin adapter around ``TransformerModelBase``. Each entry point accepts
    the flat argparse namespace (``args``). It converts that namespace with
    ``TransformerConfig.from_namespace`` and then delegates to the base-class
    implementation.
    """

    @classmethod
    def hub_models(cls):
        """Return the catalogue of downloadable pretrained checkpoints.

        Maps a hub name to one of two forms:
        - a bare archive URL, or
        - a dict with the archive ``path`` plus the ``tokenizer``/``bpe``
          settings to pair with it.
        """
        # Tokenizer/BPE recipes; each entry gets its own fresh dict built
        # from one of these (key order: path first, then the recipe's keys).
        moses_subword = {'tokenizer': 'moses', 'bpe': 'subword_nmt'}
        moses_fastbpe = {'tokenizer': 'moses', 'bpe': 'fastbpe'}
        sentencepiece = {'bpe': 'sentencepiece', 'tokenizer': 'space'}

        # fmt: off
        catalogue = (
            ('transformer.wmt14.en-fr', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2', moses_subword),
            # This entry is deliberately a bare URL with no tokenizer/bpe recipe.
            ('transformer.wmt16.en-de', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', None),
            ('transformer.wmt18.en-de', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz', moses_subword),
            ('transformer.wmt19.en-de', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz', moses_fastbpe),
            ('transformer.wmt19.en-ru', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz', moses_fastbpe),
            ('transformer.wmt19.de-en', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz', moses_fastbpe),
            ('transformer.wmt19.ru-en', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz', moses_fastbpe),
            ('transformer.wmt19.en-de.single_model', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz', moses_fastbpe),
            ('transformer.wmt19.en-ru.single_model', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz', moses_fastbpe),
            ('transformer.wmt19.de-en.single_model', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz', moses_fastbpe),
            ('transformer.wmt19.ru-en.single_model', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz', moses_fastbpe),
            ('transformer.wmt20.en-ta', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz', sentencepiece),
            ('transformer.wmt20.en-iu.news', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz', sentencepiece),
            ('transformer.wmt20.en-iu.nh', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz', sentencepiece),
            ('transformer.wmt20.ta-en', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz', sentencepiece),
            ('transformer.wmt20.iu-en.news', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz', sentencepiece),
            ('transformer.wmt20.iu-en.nh', 'https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz', sentencepiece),
            ('transformer.flores101.mm100.615M', 'https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz', sentencepiece),
            ('transformer.flores101.mm100.175M', 'https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz', sentencepiece),
        )
        # fmt: on

        return {
            name: url if recipe is None else {'path': url, **recipe}
            for name, url, recipe in catalogue
        }

    def __init__(self, args, encoder, decoder):
        """Initialise from an argparse namespace plus pre-built encoder/decoder."""
        super().__init__(TransformerConfig.from_namespace(args), encoder, decoder)
        # Retain the raw namespace alongside the dataclass config.
        self.args = args

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        # Every TransformerConfig field becomes an option without a default.
        # Presumably this lets the getattr-based architecture functions in
        # this module supply the defaults instead; confirm against
        # gen_parser_from_dataclass.
        gen_parser_from_dataclass(
            parser,
            TransformerConfig(),
            with_prefix="",
            delete_default=True,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance from ``args`` and ``task``.

        Steps:
        - fill in missing hyper-parameters;
        - apply the ``*_layers_to_keep`` overrides;
        - validate ``--share-all-embeddings`` against the task's
          dictionaries and embedding settings;
        - convert ``args`` to a ``TransformerConfig`` and delegate to the
          base class.

        Raises:
            ValueError: when ``--share-all-embeddings`` is set but
                - the source and target dictionaries differ,
                - the encoder and decoder embedding dims differ, or
                - a decoder embedding path differs from the encoder's.
        """
        # Older saved args may miss some fields; populate them with defaults.
        base_architecture(args)

        # A comma-separated list of layers to keep dictates the layer count.
        for side in ("encoder", "decoder"):
            kept = getattr(args, side + "_layers_to_keep")
            if kept:
                setattr(args, side + "_layers", len(kept.split(",")))

        position_defaults = (
            ("max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS),
            ("max_target_positions", DEFAULT_MAX_TARGET_POSITIONS),
        )
        for attr, fallback in position_defaults:
            if getattr(args, attr, None) is None:
                setattr(args, attr, fallback)

        src_dict = task.source_dictionary
        tgt_dict = task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            decoder_path = args.decoder_embed_path
            if decoder_path and decoder_path != args.encoder_embed_path:
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            # Sharing all embeddings also ties the decoder input/output ones.
            args.share_decoder_input_output_embed = True

        # Offloading activations implies checkpointing them.
        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True

        if not args.share_all_embeddings:
            args.min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )

        return super().build_model(TransformerConfig.from_namespace(args), task)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Build a token embedding; ``args`` is converted to a config first."""
        cfg = TransformerConfig.from_namespace(args)
        return super().build_embedding(cfg, dictionary, embed_dim, path)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Build the encoder; ``args`` is converted to a config first."""
        cfg = TransformerConfig.from_namespace(args)
        return super().build_encoder(cfg, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build the decoder; ``args`` is converted to a config first."""
        cfg = TransformerConfig.from_namespace(args)
        return super().build_decoder(cfg, tgt_dict, embed_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer_tiny")
def tiny_architecture(args):
    """Small preset: 64-dim encoder embeddings/FFN, 2 layers and 2 heads per side.

    Values already present on ``args`` are kept. Everything else falls back to
    ``base_architecture``. That includes the decoder embedding/FFN widths,
    which it derives from the encoder widths.
    """
    tiny_defaults = (
        ("encoder_embed_dim", 64),
        ("encoder_ffn_embed_dim", 64),
        ("encoder_layers", 2),
        ("encoder_attention_heads", 2),
        ("decoder_layers", 2),
        ("decoder_attention_heads", 2),
    )
    for name, fallback in tiny_defaults:
        setattr(args, name, getattr(args, name, fallback))
    return base_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer")
def base_architecture(args):
    """Populate every Transformer hyper-parameter that ``args`` lacks.

    Each attribute is written back onto ``args``. An existing value is kept,
    and a missing one receives the default given below. Some defaults are
    derived from values resolved earlier in this function:
    - decoder widths follow the encoder widths;
    - decoder input/output dims follow the decoder embedding dim.
    So the call order matters.
    """

    def _fill(name, fallback):
        # Keep an existing value, otherwise store ``fallback``; return what was stored.
        value = getattr(args, name, fallback)
        setattr(args, name, value)
        return value

    # Encoder.
    _fill("encoder_embed_path", None)
    encoder_dim = _fill("encoder_embed_dim", 512)
    encoder_ffn_dim = _fill("encoder_ffn_embed_dim", 2048)
    _fill("encoder_layers", 6)
    _fill("encoder_attention_heads", 8)
    _fill("encoder_normalize_before", False)
    _fill("encoder_learned_pos", False)

    # Decoder; its widths default to the encoder's.
    _fill("decoder_embed_path", None)
    decoder_dim = _fill("decoder_embed_dim", encoder_dim)
    _fill("decoder_ffn_embed_dim", encoder_ffn_dim)
    _fill("decoder_layers", 6)
    _fill("decoder_attention_heads", 8)
    _fill("decoder_normalize_before", False)
    _fill("decoder_learned_pos", False)

    # Regularisation and activation.
    _fill("attention_dropout", 0.0)
    _fill("activation_dropout", 0.0)
    _fill("activation_fn", "relu")
    _fill("dropout", 0.1)
    _fill("adaptive_softmax_cutoff", None)
    _fill("adaptive_softmax_dropout", 0)

    # Embedding sharing / positional / attention-structure switches.
    _fill("share_decoder_input_output_embed", False)
    _fill("share_all_embeddings", False)
    _fill("merge_src_tgt_embed", False)
    _fill("no_token_positional_embeddings", False)
    _fill("adaptive_input", False)
    _fill("no_cross_attention", False)
    _fill("cross_self_attention", False)

    # Decoder projection sizes default to the decoder embedding dim.
    _fill("decoder_output_dim", decoder_dim)
    _fill("decoder_input_dim", decoder_dim)

    _fill("no_scale_embedding", False)
    _fill("layernorm_embedding", False)
    _fill("tie_adaptive_weights", False)
    _fill("checkpoint_activations", False)
    # Offloading activations forces activation checkpointing on.
    if _fill("offload_activations", False):
        args.checkpoint_activations = True

    # Layer pruning / LayerDrop.
    _fill("encoder_layers_to_keep", None)
    _fill("decoder_layers_to_keep", None)
    _fill("encoder_layerdrop", 0)
    _fill("decoder_layerdrop", 0)

    # Quantization noise.
    _fill("quant_noise_pq", 0)
    _fill("quant_noise_pq_block_size", 8)
    _fill("quant_noise_scalar", 0)
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer_iwslt_de_en")
def transformer_iwslt_de_en(args):
    """Preset with 512-dim embeddings, 1024-dim FFN, 4 heads and 6 layers per side.

    Values already present on ``args`` are kept. Remaining fields come from
    ``base_architecture``.
    """
    for side in ("encoder", "decoder"):
        for suffix, fallback in (
            ("_embed_dim", 512),
            ("_ffn_embed_dim", 1024),
            ("_attention_heads", 4),
            ("_layers", 6),
        ):
            name = side + suffix
            setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer_wmt_en_de")
def transformer_wmt_en_de(args):
    """Preset identical to the base Transformer: no overrides, only base defaults."""
    # base_architecture has no return statement, so this returns None.
    return base_architecture(args)
|
|
|
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer_vaswani_wmt_en_de_big")
def transformer_vaswani_wmt_en_de_big(args):
    """"Big" preset: 1024-dim embeddings, 4096-dim FFN, 16 heads, dropout 0.3.

    The encoder uses post-norm (``encoder_normalize_before=False``) unless
    ``args`` says otherwise. Layer counts and all other fields come from
    ``base_architecture``.
    """
    big_defaults = [
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
        ("encoder_normalize_before", False),
        ("decoder_embed_dim", 1024),
        ("decoder_ffn_embed_dim", 4096),
        ("decoder_attention_heads", 16),
        ("dropout", 0.3),
    ]
    for key, value in big_defaults:
        setattr(args, key, getattr(args, key, value))
    base_architecture(args)
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer_vaswani_wmt_en_fr_big")
def transformer_vaswani_wmt_en_fr_big(args):
    """Same as the Vaswani en-de "big" preset, but with dropout defaulting to 0.1."""
    # Set dropout first so the en-de preset's 0.3 default does not apply.
    dropout = getattr(args, "dropout", 0.1)
    args.dropout = dropout
    transformer_vaswani_wmt_en_de_big(args)
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer_wmt_en_de_big")
def transformer_wmt_en_de_big(args):
    """Vaswani en-de "big" preset plus attention dropout defaulting to 0.1."""
    setattr(args, "attention_dropout", getattr(args, "attention_dropout", 0.1))
    transformer_vaswani_wmt_en_de_big(args)
|
|
|
|
|
|
|
|
|
|
|
@register_model_architecture("transformer", "transformer_wmt_en_de_big_t2t")
def transformer_wmt_en_de_big_t2t(args):
    """Vaswani en-de "big" preset with pre-norm layers and extra dropout.

    Defaults when not already set on ``args``:
    - ``normalize_before`` on both encoder and decoder;
    - attention and activation dropout of 0.1.
    """
    t2t_defaults = (
        ("encoder_normalize_before", True),
        ("decoder_normalize_before", True),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.1),
    )
    for attr, value in t2t_defaults:
        setattr(args, attr, getattr(args, attr, value))
    transformer_vaswani_wmt_en_de_big(args)
|
|
|