| |
| |
| |
| |
| """isort:skip_file""" |
|
|
| import argparse |
| import importlib |
| import os |
|
|
| from contextlib import ExitStack |
|
|
| from fairseq.dataclass import FairseqDataclass |
| from fairseq.dataclass.utils import merge_with_parent |
| from hydra.core.config_store import ConfigStore |
| from omegaconf import open_dict, OmegaConf |
|
|
| from .composite_encoder import CompositeEncoder |
| from .distributed_fairseq_model import DistributedFairseqModel |
| from .fairseq_decoder import FairseqDecoder |
| from .fairseq_encoder import FairseqEncoder |
| from .fairseq_incremental_decoder import FairseqIncrementalDecoder |
| from .fairseq_model import ( |
| BaseFairseqModel, |
| FairseqEncoderDecoderModel, |
| FairseqEncoderModel, |
| FairseqLanguageModel, |
| FairseqModel, |
| FairseqMultiModel, |
| ) |
|
|
|
|
# Module-level registries. They are filled in by ``register_model`` and
# ``register_model_architecture`` and read by ``build_model``.

# model name -> model class
MODEL_REGISTRY = {}
# model name -> config dataclass (only for models registered with one)
MODEL_DATACLASS_REGISTRY = {}
# architecture name -> model class
ARCH_MODEL_REGISTRY = {}
# architecture name -> model name
ARCH_MODEL_NAME_REGISTRY = {}
# model name -> list of architecture names registered for it
ARCH_MODEL_INV_REGISTRY = {}
# architecture name -> function that fills in the architecture's defaults
ARCH_CONFIG_REGISTRY = {}
|
|
|
|
# Public names re-exported by this package (all imported at the top of
# this module from the sibling model/encoder/decoder modules).
__all__ = [
    "BaseFairseqModel",
    "CompositeEncoder",
    "DistributedFairseqModel",
    "FairseqDecoder",
    "FairseqEncoder",
    "FairseqEncoderDecoderModel",
    "FairseqEncoderModel",
    "FairseqIncrementalDecoder",
    "FairseqLanguageModel",
    "FairseqModel",
    "FairseqMultiModel",
]
|
|
|
|
def build_model(cfg: FairseqDataclass, task, from_checkpoint=False):
    """Instantiate the model selected by *cfg* for the given *task*.

    The model type is read from ``cfg._name`` or, failing that, ``cfg.arch``.
    If neither is set and *cfg* holds exactly one entry, that entry's key is
    taken as the model type and its value becomes the effective config.

    Args:
        cfg: model configuration; either an :class:`argparse.Namespace` or a
            config object accepted by ``merge_with_parent``.
        task: task object, forwarded unchanged to the model's ``build_model``.
        from_checkpoint (bool): forwarded to ``merge_with_parent`` when the
            config is merged with the registered dataclass defaults.

    Returns:
        Whatever ``build_model(cfg, task)`` of the resolved model class
        returns.

    Raises:
        Exception: if the model type had to be inferred from a single-entry
            config and that entry is not a registered dataclass model.
        AssertionError: if no registered model matches the model type.
    """
    model = None
    model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None)

    if not model_type and len(cfg) == 1:
        # No explicit name: the config is nested one level deep, keyed by the
        # model type, so unwrap it.
        model_type = next(iter(cfg))
        if model_type not in MODEL_DATACLASS_REGISTRY:
            raise Exception(
                "Could not infer model type from directory. Please add _name field to indicate model type. "
                "Available models: "
                + str(MODEL_DATACLASS_REGISTRY.keys())
                + " Requested model type: "
                # str() so a non-string key cannot turn this into a TypeError
                + str(model_type)
            )
        cfg = cfg[model_type]

    # Architecture names take precedence over dataclass-registered names.
    if model_type in ARCH_MODEL_REGISTRY:
        model = ARCH_MODEL_REGISTRY[model_type]
    elif model_type in MODEL_DATACLASS_REGISTRY:
        model = MODEL_REGISTRY[model_type]

    if model_type in MODEL_DATACLASS_REGISTRY:
        # Fill in defaults from the registered dataclass.
        dc = MODEL_DATACLASS_REGISTRY[model_type]
        if isinstance(cfg, argparse.Namespace):
            cfg = dc.from_namespace(cfg)
        else:
            cfg = merge_with_parent(dc(), cfg, from_checkpoint)
    elif model_type in ARCH_CONFIG_REGISTRY:
        # The architecture function mutates cfg in place; OmegaConf configs
        # must be opened for writing first, other objects need no wrapper.
        with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack():
            ARCH_CONFIG_REGISTRY[model_type](cfg)

    if model is None:
        # Built purely from f-strings: the previous message applied
        # ``str.format`` to text that already contained the rendered cfg, so
        # any brace in the config made the error message itself blow up.
        # Raised explicitly (not ``assert``) so the check survives ``-O``.
        raise AssertionError(
            f"Could not infer model type from {cfg}. "
            f"Available models: {MODEL_DATACLASS_REGISTRY.keys()}"
            f" Requested model type: {model_type}"
        )

    return model.build_model(cfg, task)
|
|
|
|
def register_model(name, dataclass=None):
    """
    New model types can be added to fairseq with the :func:`register_model`
    function decorator.

    For example::

        @register_model('lstm')
        class LSTM(FairseqEncoderDecoderModel):
            (...)

    .. note:: All models must implement the :class:`BaseFairseqModel` interface.
        Typically you will extend :class:`FairseqEncoderDecoderModel` for
        sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
        language modeling tasks.

    Args:
        name (str): the name of the model
        dataclass (type, optional): config dataclass for the model; must be a
            subclass of :class:`FairseqDataclass`. When given, it is recorded
            in ``MODEL_DATACLASS_REGISTRY``, a default instance is stored in
            hydra's ``ConfigStore`` under the ``model`` group, and a no-op
            architecture called *name* is registered for the model.

    Returns:
        A class decorator that registers the decorated class and returns it
        unchanged.

    Raises:
        ValueError: (from the decorator) if *name* is already registered, the
            class does not extend :class:`BaseFairseqModel`, or *dataclass*
            does not extend :class:`FairseqDataclass`.
    """

    def register_model_cls(cls):
        # Validate everything before touching any registry, so a rejected
        # registration leaves no half-registered model behind.
        if name in MODEL_REGISTRY:
            raise ValueError("Cannot register duplicate model ({})".format(name))
        if not issubclass(cls, BaseFairseqModel):
            raise ValueError(
                "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__)
            )
        if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
            raise ValueError(
                "Dataclass {} must extend FairseqDataclass".format(dataclass)
            )

        MODEL_REGISTRY[name] = cls
        cls.__dataclass = dataclass

        if dataclass is not None:
            MODEL_DATACLASS_REGISTRY[name] = dataclass

            # Expose the dataclass defaults to hydra as ``model/<name>``.
            cs = ConfigStore.instance()
            node = dataclass()
            node._name = name
            cs.store(name=name, group="model", node=node, provider="fairseq")

            # Make the model name itself usable as an architecture name.
            @register_model_architecture(name, name)
            def noop(_):
                pass

        return cls

    return register_model_cls
|
|
|
|
def register_model_architecture(model_name, arch_name):
    """
    Decorator factory that attaches a named architecture to a registered model.

    Once registered, the architecture name can be used to select the model
    (e.g. through the ``--arch`` command-line argument).

    For example::

        @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
        def lstm_luong_wmt_en_de(cfg):
            cfg.encoder_embed_dim = getattr(cfg, 'encoder_embed_dim', 1000)
            (...)

    The decorated function receives a single argument *cfg*, the model
    configuration, and is expected to modify it in-place so that it matches
    the desired architecture.

    Args:
        model_name (str): the name of the Model (Model must already be
            registered)
        arch_name (str): the name of the model architecture (``--arch``)

    Returns:
        A decorator that records the function in the architecture registries
        and returns it unchanged.

    Raises:
        ValueError: (from the decorator) if *model_name* is unknown,
            *arch_name* is already registered, or the decorated object is not
            callable.
    """

    def _register(arch_fn):
        # --- validation (same order as the error precedence callers see) ---
        if model_name not in MODEL_REGISTRY:
            message = "Cannot register model architecture for unknown model type ({})".format(
                model_name
            )
            raise ValueError(message)
        if arch_name in ARCH_MODEL_REGISTRY:
            message = "Cannot register duplicate model architecture ({})".format(
                arch_name
            )
            raise ValueError(message)
        if not callable(arch_fn):
            message = "Model architecture must be callable ({})".format(arch_name)
            raise ValueError(message)

        # --- bookkeeping ---
        ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
        ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name
        archs_of_model = ARCH_MODEL_INV_REGISTRY.setdefault(model_name, [])
        archs_of_model.append(arch_name)
        ARCH_CONFIG_REGISTRY[arch_name] = arch_fn
        return arch_fn

    return _register
|
|
|
|
def import_models(models_dir, namespace):
    """Import every model module found directly inside *models_dir*.

    An entry is imported as ``<namespace>.<module name>`` when its name does
    not start with ``_`` or ``.`` and it is either a ``.py`` file or a
    directory. For each imported module whose name is also a key of
    ``MODEL_REGISTRY``, an :class:`argparse.ArgumentParser` listing the
    model's named architectures and its additional arguments is stored in
    this module's globals as ``<module name>_parser``.

    Args:
        models_dir (str): directory to scan.
        namespace (str): dotted package name the modules are imported under.
    """
    for file in os.listdir(models_dir):
        if file.startswith(("_", ".")):
            continue

        if file.endswith(".py"):
            # Strip only the trailing suffix; ``str.find`` would cut at the
            # first ".py" occurring anywhere in the file name.
            model_name = file[: -len(".py")]
        elif os.path.isdir(os.path.join(models_dir, file)):
            model_name = file
        else:
            continue

        importlib.import_module(namespace + "." + model_name)

        if model_name in MODEL_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_archs = parser.add_argument_group("Named architectures")
            # setdefault: a model may be registered without any architecture
            # yet; this avoids a KeyError and still hands argparse the live
            # list that later architecture registrations append to.
            group_archs.add_argument(
                "--arch", choices=ARCH_MODEL_INV_REGISTRY.setdefault(model_name, [])
            )
            group_args = parser.add_argument_group(
                "Additional command-line arguments"
            )
            MODEL_REGISTRY[model_name].add_args(group_args)
            globals()[model_name + "_parser"] = parser
|
|
|
|
| |
# Import every model module that sits next to this file, so that importing
# the package is enough to run their registration decorators.
models_dir = os.path.dirname(__file__)
import_models(models_dir=models_dir, namespace="fairseq.models")
|
|