|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
import importlib |
|
|
import os |
|
|
|
|
|
from .fairseq_decoder import FairseqDecoder |
|
|
from .fairseq_encoder import FairseqEncoder |
|
|
from .fairseq_incremental_decoder import FairseqIncrementalDecoder |
|
|
from .fairseq_model import ( |
|
|
BaseFairseqModel, |
|
|
FairseqEncoderModel, |
|
|
FairseqEncoderDecoderModel, |
|
|
FairseqLanguageModel, |
|
|
FairseqModel, |
|
|
FairseqMultiModel, |
|
|
) |
|
|
|
|
|
from .composite_encoder import CompositeEncoder |
|
|
from .distributed_fairseq_model import DistributedFairseqModel |
|
|
|
|
|
|
|
|
# Global registries, filled in by the decorators defined in this module.

# model name -> model class (written by ``register_model``).
MODEL_REGISTRY = {}

# architecture name -> model class (written by ``register_model_architecture``).
ARCH_MODEL_REGISTRY = {}

# model name -> list of architecture names registered for that model.
ARCH_MODEL_INV_REGISTRY = {}

# architecture name -> the decorated architecture function.
ARCH_CONFIG_REGISTRY = {}
|
|
|
|
|
|
|
|
# Names exported by ``from <this package> import *``: the base model,
# encoder and decoder classes imported at the top of this module.
__all__ = [
    'BaseFairseqModel',
    'CompositeEncoder',
    'DistributedFairseqModel',
    'FairseqDecoder',
    'FairseqEncoder',
    'FairseqEncoderDecoderModel',
    'FairseqEncoderModel',
    'FairseqIncrementalDecoder',
    'FairseqLanguageModel',
    'FairseqModel',
    'FairseqMultiModel',
]
|
|
|
|
|
|
|
|
def build_model(args, task):
    """Build the model selected by ``args.arch``.

    The class registered for ``args.arch`` in ``ARCH_MODEL_REGISTRY`` is
    looked up and its ``build_model(args, task)`` is called.

    Args:
        args: object exposing an ``arch`` attribute; it is passed through
            unchanged to the model class.
        task: passed through unchanged to the model class.

    Returns:
        Whatever the registered class's ``build_model`` returns.

    Raises:
        KeyError: if ``args.arch`` is not a key of ``ARCH_MODEL_REGISTRY``.
    """
    model_cls = ARCH_MODEL_REGISTRY[args.arch]
    return model_cls.build_model(args, task)
|
|
|
|
|
|
|
|
def register_model(name):
    """
    Class decorator factory used to add new model types to fairseq.

    For example::

        @register_model('lstm')
        class LSTM(FairseqEncoderDecoderModel):
            (...)

    .. note:: All models must implement the :class:`BaseFairseqModel`
        interface. Typically you will extend
        :class:`FairseqEncoderDecoderModel` for sequence-to-sequence tasks
        or :class:`FairseqLanguageModel` for language modeling tasks.

    Args:
        name (str): the name of the model

    Returns:
        A decorator that stores the decorated class in ``MODEL_REGISTRY``
        under *name* and returns the class unchanged. The decorator raises
        :class:`ValueError` if *name* is already registered or if the class
        does not extend :class:`BaseFairseqModel`.
    """

    def register_model_cls(cls):
        # The duplicate-name check runs first, so it wins when both
        # conditions would fail.
        if name in MODEL_REGISTRY:
            duplicate_msg = 'Cannot register duplicate model ({})'.format(name)
            raise ValueError(duplicate_msg)

        if issubclass(cls, BaseFairseqModel):
            MODEL_REGISTRY[name] = cls
            return cls

        raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))

    return register_model_cls
|
|
|
|
|
|
|
|
def register_model_architecture(model_name, arch_name):
    """
    Function decorator factory used to add new model architectures to
    fairseq. Once registered, an architecture can be selected with the
    ``--arch`` command-line argument.

    For example::

        @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
        def lstm_luong_wmt_en_de(args):
            args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
            (...)

    The decorated function should take a single argument *args*, which is a
    :class:`argparse.Namespace` of arguments parsed from the command-line.
    It should modify these arguments in-place to match the desired
    architecture.

    Args:
        model_name (str): the name of the Model (Model must already be
            registered)
        arch_name (str): the name of the model architecture (``--arch``)

    Returns:
        A decorator that records the decorated function in the architecture
        registries and returns it unchanged. The decorator raises
        :class:`ValueError` if *model_name* is not registered, if
        *arch_name* is already registered, or if the decorated object is
        not callable.
    """

    def register_model_arch_fn(fn):
        # Validate in a fixed order; only the first failing check is reported.
        problem = None
        if model_name not in MODEL_REGISTRY:
            problem = 'Cannot register model architecture for unknown model type ({})'.format(model_name)
        elif arch_name in ARCH_MODEL_REGISTRY:
            problem = 'Cannot register duplicate model architecture ({})'.format(arch_name)
        elif not callable(fn):
            problem = 'Model architecture must be callable ({})'.format(arch_name)
        if problem is not None:
            raise ValueError(problem)

        # arch name -> model class
        ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
        # model name -> every arch name registered for it
        archs_for_model = ARCH_MODEL_INV_REGISTRY.setdefault(model_name, [])
        archs_for_model.append(arch_name)
        # arch name -> the function that configures args for this arch
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn

    return register_model_arch_fn
|
|
|
|
|
|
|
|
|
|
|
# Import every public Python file and sub-directory that sits next to this
# module. Importing is expected to run the registration decorators those
# modules contain, which is why MODEL_REGISTRY is consulted right after each
# import.
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)

    # Skip private/dunder entries and dotfiles.
    if file.startswith(('_', '.')):
        continue
    is_py_file = file.endswith('.py')
    if not (is_py_file or os.path.isdir(path)):
        continue

    # Strip only the trailing '.py'. Searching for the first '.py' would cut
    # the name short whenever that substring occurs earlier in the filename.
    model_name = file[:-len('.py')] if is_py_file else file
    importlib.import_module('fairseq.models.' + model_name)

    # When the module name is also a registered model name, expose a
    # standalone argument parser as the module-level global
    # '<model_name>_parser'.
    # NOTE(review): presumably consumed by documentation tooling - confirm.
    if model_name in MODEL_REGISTRY:
        parser = argparse.ArgumentParser(add_help=False)
        group_archs = parser.add_argument_group('Named architectures')
        # A model may be registered without any architecture; fall back to an
        # empty choice list instead of raising KeyError at import time.
        group_archs.add_argument('--arch', choices=ARCH_MODEL_INV_REGISTRY.get(model_name, []))
        group_args = parser.add_argument_group('Additional command-line arguments')
        MODEL_REGISTRY[model_name].add_args(group_args)
        globals()[model_name + '_parser'] = parser
|
|
|