File size: 4,798 Bytes
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import importlib
import os
from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import (
BaseFairseqModel,
FairseqEncoderModel,
FairseqEncoderDecoderModel,
FairseqLanguageModel,
FairseqModel,
FairseqMultiModel,
)
from .composite_encoder import CompositeEncoder
from .distributed_fairseq_model import DistributedFairseqModel
# Registries populated by the ``register_model`` and
# ``register_model_architecture`` decorators defined in this module.
MODEL_REGISTRY = {}  # model name -> model class
ARCH_MODEL_REGISTRY = {}  # architecture name -> model class
ARCH_MODEL_INV_REGISTRY = {}  # model name -> list of architecture names
ARCH_CONFIG_REGISTRY = {}  # architecture name -> architecture config function

# Names exported by a star-import of this package.
__all__ = [
    'BaseFairseqModel',
    'CompositeEncoder',
    'DistributedFairseqModel',
    'FairseqDecoder',
    'FairseqEncoder',
    'FairseqEncoderDecoderModel',
    'FairseqEncoderModel',
    'FairseqIncrementalDecoder',
    'FairseqLanguageModel',
    'FairseqModel',
    'FairseqMultiModel',
]
def build_model(args, task):
    """
    Build a model instance for the architecture named by ``args.arch``.

    The model class is looked up in ``ARCH_MODEL_REGISTRY`` and construction
    is delegated to that class's ``build_model(args, task)``.

    Args:
        args: object exposing an ``arch`` attribute; it is forwarded
            unchanged to the model class
        task: forwarded unchanged to the model class

    Returns:
        whatever the selected model class's ``build_model`` returns

    Raises:
        KeyError: if ``args.arch`` is not a registered architecture name
    """
    model_cls = ARCH_MODEL_REGISTRY[args.arch]
    return model_cls.build_model(args, task)
def register_model(name):
    """
    Class decorator factory that adds a new model type to fairseq.

    Usage::

        @register_model('lstm')
        class LSTM(FairseqEncoderDecoderModel):
            (...)

    .. note:: Every model must implement the :class:`BaseFairseqModel`
        interface. Typically you will extend
        :class:`FairseqEncoderDecoderModel` for sequence-to-sequence tasks or
        :class:`FairseqLanguageModel` for language modeling tasks.

    Args:
        name (str): the name of the model

    Returns:
        a decorator that stores the decorated class in ``MODEL_REGISTRY``
        under *name* and returns the class unchanged. The decorator raises
        :class:`ValueError` if *name* is already registered or if the class
        does not extend :class:`BaseFairseqModel`.
    """

    def register_model_cls(model_cls):
        # The duplicate-name check deliberately runs before the type check.
        if name in MODEL_REGISTRY:
            message = 'Cannot register duplicate model ({})'
            raise ValueError(message.format(name))

        if not issubclass(model_cls, BaseFairseqModel):
            message = 'Model ({}: {}) must extend BaseFairseqModel'
            raise ValueError(message.format(name, model_cls.__name__))

        MODEL_REGISTRY[name] = model_cls
        return model_cls

    return register_model_cls
def register_model_architecture(model_name, arch_name):
    """
    Function decorator factory that adds a named architecture to an
    already-registered model. Once registered, the architecture can be
    selected with the ``--arch`` command-line argument.

    Usage::

        @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
        def lstm_luong_wmt_en_de(args):
            args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
            (...)

    The decorated function should accept a single argument *args*, an
    :class:`argparse.Namespace` of arguments parsed from the command-line,
    and should modify it in-place to match the desired architecture.

    Args:
        model_name (str): the name of the Model (Model must already be
            registered)
        arch_name (str): the name of the model architecture (``--arch``)

    Returns:
        a decorator that records the decorated function in the architecture
        registries and returns it unchanged. The decorator raises
        :class:`ValueError` if *model_name* is unknown, if *arch_name* is
        already registered, or if the decorated object is not callable.
    """

    def register_model_arch_fn(fn):
        model_is_known = model_name in MODEL_REGISTRY
        if not model_is_known:
            message = 'Cannot register model architecture for unknown model type ({})'
            raise ValueError(message.format(model_name))

        arch_is_taken = arch_name in ARCH_MODEL_REGISTRY
        if arch_is_taken:
            message = 'Cannot register duplicate model architecture ({})'
            raise ValueError(message.format(arch_name))

        if not callable(fn):
            message = 'Model architecture must be callable ({})'
            raise ValueError(message.format(arch_name))

        # architecture name -> model class
        model_cls = MODEL_REGISTRY[model_name]
        ARCH_MODEL_REGISTRY[arch_name] = model_cls

        # model name -> every architecture name registered for it
        arch_names = ARCH_MODEL_INV_REGISTRY.setdefault(model_name, [])
        arch_names.append(arch_name)

        # architecture name -> function that configures *args*
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn

    return register_model_arch_fn
# Automatically import every Python file and sub-directory that sits next to
# this file, so that any registration decorators those modules apply take
# effect as soon as the package is imported.
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    if (
        not file.startswith('_')
        and not file.startswith('.')
        and (file.endswith('.py') or os.path.isdir(path))
    ):
        # Strip only the trailing '.py' suffix. Searching for the first
        # '.py' substring would truncate a name that contains '.py' earlier.
        model_name = file[:-len('.py')] if file.endswith('.py') else file
        module = importlib.import_module('fairseq.models.' + model_name)

        # extra `model_parser` for sphinx
        if model_name in MODEL_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_archs = parser.add_argument_group('Named architectures')
            # A model may be registered before any architecture has been
            # registered for it; fall back to an empty choices list rather
            # than failing the whole package import with a KeyError.
            group_archs.add_argument(
                '--arch',
                choices=ARCH_MODEL_INV_REGISTRY.get(model_name, []),
            )
            group_args = parser.add_argument_group('Additional command-line arguments')
            MODEL_REGISTRY[model_name].add_args(group_args)
            # Expose the parser as a module-level `<model_name>_parser` name.
            globals()[model_name + '_parser'] = parser
|