| | |
| | |
| | |
| | |
| | """isort:skip_file""" |
| |
|
| | import argparse |
| | import importlib |
| | import os |
| |
|
| | from fairseq.dataclass import FairseqDataclass |
| | from fairseq.dataclass.utils import merge_with_parent, populate_dataclass |
| | from hydra.core.config_store import ConfigStore |
| |
|
| | from .fairseq_task import FairseqTask, LegacyFairseqTask |
| |
|
| |
|
| | |
# Module-level registries filled in by the ``register_task`` decorator:
#   TASK_DATACLASS_REGISTRY  task name -> config dataclass (only for tasks
#                            registered with a dataclass)
#   TASK_REGISTRY            task name -> task class
#   TASK_CLASS_NAMES         class names already registered; used to reject
#                            two tasks sharing one class name
TASK_DATACLASS_REGISTRY = dict()
TASK_REGISTRY = dict()
TASK_CLASS_NAMES = set()
| |
|
| |
|
def setup_task(cfg: FairseqDataclass, **kwargs):
    """Resolve the task class described by ``cfg`` and set it up.

    Two configuration styles are recognised:

    * legacy configs, whose ``task`` attribute is a string naming an entry
      of ``TASK_REGISTRY``; if that task also registered a dataclass, ``cfg``
      is converted with ``populate_dataclass``;
    * dataclass configs, whose ``_name`` attribute names an entry of
      ``TASK_DATACLASS_REGISTRY``; ``cfg`` is combined with a fresh instance
      of that dataclass via ``merge_with_parent``.

    Args:
        cfg: the task configuration.
        **kwargs: forwarded unchanged to the task class's ``setup_task``.

    Returns:
        The result of ``setup_task(cfg, **kwargs)`` on the resolved task class.

    Raises:
        KeyError: if a legacy ``task`` string is not a registered task name.
        AssertionError: if no task class could be inferred from ``cfg``.
    """
    resolved_cls = None
    task_key = getattr(cfg, "task", None)

    if isinstance(task_key, str):
        # Legacy config: the task name is carried directly in ``cfg.task``.
        resolved_cls = TASK_REGISTRY[task_key]
        if task_key in TASK_DATACLASS_REGISTRY:
            cfg = populate_dataclass(TASK_DATACLASS_REGISTRY[task_key](), cfg)
    else:
        # Dataclass config: the task is identified by ``cfg._name``.
        task_key = getattr(cfg, "_name", None)
        if task_key and task_key in TASK_DATACLASS_REGISTRY:
            cfg = merge_with_parent(TASK_DATACLASS_REGISTRY[task_key](), cfg)
            resolved_cls = TASK_REGISTRY[task_key]

    assert (
        resolved_cls is not None
    ), f"Could not infer task type from {cfg}. Available tasks: {TASK_REGISTRY.keys()}"

    return resolved_cls.setup_task(cfg, **kwargs)
| |
|
| |
|
def register_task(name, dataclass=None):
    """
    New tasks can be added to fairseq with the
    :func:`~fairseq.tasks.register_task` function decorator.

    For example::

        @register_task('classification')
        class ClassificationTask(FairseqTask):
            (...)

    .. note::

        All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
        interface.

    Args:
        name (str): the name of the task
        dataclass (optional): a subclass of ``FairseqDataclass`` holding the
            task's configuration. When given, it is recorded in
            ``TASK_DATACLASS_REGISTRY`` and an instance of it (with ``_name``
            set to *name*) is stored in hydra's ``ConfigStore`` under the
            ``task`` group.

    Returns:
        A class decorator that registers the decorated class and returns it
        unchanged.

    Raises:
        ValueError: (raised by the returned decorator) if *name* or the
            class's ``__name__`` is already registered, if the class does not
            extend ``FairseqTask``, or if *dataclass* does not extend
            ``FairseqDataclass``.
    """

    def register_task_cls(cls):
        # Validate everything BEFORE mutating any registry, so a rejected
        # registration leaves no partial state behind. Previously the task
        # was registered first, and a bad dataclass then left a stale entry
        # that made any retry fail as a "duplicate".
        if name in TASK_REGISTRY:
            raise ValueError("Cannot register duplicate task ({})".format(name))
        if not issubclass(cls, FairseqTask):
            raise ValueError(
                "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__)
            )
        if cls.__name__ in TASK_CLASS_NAMES:
            raise ValueError(
                "Cannot register task with duplicate class name ({})".format(
                    cls.__name__
                )
            )
        if dataclass is not None and not issubclass(dataclass, FairseqDataclass):
            raise ValueError(
                "Dataclass {} must extend FairseqDataclass".format(dataclass)
            )

        # Build the ConfigStore node up front too, so a failing dataclass
        # constructor also aborts before any registry is touched.
        node = None
        if dataclass is not None:
            node = dataclass()
            node._name = name

        TASK_REGISTRY[name] = cls
        TASK_CLASS_NAMES.add(cls.__name__)

        # The attribute is literally named "__dataclass" (no name mangling
        # applies outside a class body); setattr makes that explicit.
        # NOTE(review): presumably read back elsewhere via
        # getattr(cls, "__dataclass") — verify in FairseqTask.
        setattr(cls, "__dataclass", dataclass)
        if dataclass is not None:
            TASK_DATACLASS_REGISTRY[name] = dataclass

            cs = ConfigStore.instance()
            cs.store(name=name, group="task", node=node, provider="fairseq")

        return cls

    return register_task_cls
| |
|
| |
|
def get_task(name):
    """Look up the task class registered under ``name``.

    Raises:
        KeyError: if no task was registered with that name.
    """
    registered_cls = TASK_REGISTRY[name]
    return registered_cls
| |
|
| |
|
def import_tasks(tasks_dir, namespace):
    """Import every task module or package found directly inside ``tasks_dir``.

    Entries whose names start with ``_`` or ``.`` are skipped. Of the rest,
    ``.py`` files are imported by their stem and directories by their name,
    both under the dotted package ``namespace``. Importing a module runs any
    ``@register_task`` decorators it contains.

    For every imported entry whose name is a registered task, an
    ``argparse.ArgumentParser`` listing that task's arguments is stored as
    the module-level global ``<task_name>_parser``.

    Args:
        tasks_dir (str): directory to scan.
        namespace (str): dotted package name that corresponds to ``tasks_dir``.
    """
    py_suffix = ".py"
    # os.listdir returns entries in arbitrary order; sorting makes the
    # import (and therefore registration) order deterministic.
    for file in sorted(os.listdir(tasks_dir)):
        if file.startswith(("_", ".")):
            continue

        path = os.path.join(tasks_dir, file)
        if file.endswith(py_suffix):
            # Strip only the trailing suffix. ``file.find(".py")`` would cut
            # at the FIRST ".py" anywhere in the name (e.g. "a.pyx.py" -> "a").
            task_name = file[: -len(py_suffix)]
        elif os.path.isdir(path):
            task_name = file
        else:
            continue

        importlib.import_module(namespace + "." + task_name)

        # Expose the task's argument parser as ``<task_name>_parser``.
        if task_name in TASK_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_task = parser.add_argument_group("Task name")
            group_task.add_argument(
                "--task",
                metavar=task_name,
                help="Enable this task with: ``--task=" + task_name + "``",
            )
            group_args = parser.add_argument_group(
                "Additional command-line arguments"
            )
            TASK_REGISTRY[task_name].add_args(group_args)
            globals()[task_name + "_parser"] = parser
| |
|
| |
|
| | |
# Import every task module that lives next to this file, so that any
# @register_task decorators in them run and populate the registries above.
tasks_dir = os.path.dirname(__file__)
import_tasks(tasks_dir, namespace="fairseq.tasks")
| |
|