| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | from pathlib import Path |
| | from typing import Callable, List, Optional, Union |
| |
|
| | import torch |
| | from fairseq import utils |
| | from fairseq.data.indexed_dataset import get_available_dataset_impl |
| | from fairseq.dataclass.configs import ( |
| | CheckpointConfig, |
| | CommonConfig, |
| | CommonEvalConfig, |
| | DatasetConfig, |
| | DistributedTrainingConfig, |
| | EvalLMConfig, |
| | GenerationConfig, |
| | InteractiveConfig, |
| | OptimizationConfig, |
| | ) |
| | from fairseq.dataclass.utils import gen_parser_from_dataclass |
| |
|
| | |
| | from fairseq.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list |
| |
|
| |
|
def get_preprocessing_parser(default_task="translation"):
    """Build the parser for ``fairseq-preprocess``.

    Args:
        default_task (str): task used when ``--task`` is not given.

    Returns:
        argparse.ArgumentParser: base parser extended with preprocessing options.
    """
    preprocess_parser = get_parser("Preprocessing", default_task)
    add_preprocess_args(preprocess_parser)
    return preprocess_parser
| |
|
| |
|
def get_training_parser(default_task="translation"):
    """Build the parser for ``fairseq-train``.

    Args:
        default_task (str): task used when ``--task`` is not given.

    Returns:
        argparse.ArgumentParser: base parser extended with dataset,
        distributed-training, model, optimization and checkpoint options.
    """
    train_parser = get_parser("Trainer", default_task)
    # Attach each option group in turn; order matches the help output.
    add_dataset_args(train_parser, train=True)
    for attach in (
        add_distributed_training_args,
        add_model_args,
        add_optimization_args,
        add_checkpoint_args,
    ):
        attach(train_parser)
    return train_parser
| |
|
| |
|
def get_generation_parser(interactive=False, default_task="translation"):
    """Build the parser for ``fairseq-generate`` / ``fairseq-interactive``.

    Args:
        interactive (bool): also add interactive-mode options.
        default_task (str): task used when ``--task`` is not given.

    Returns:
        argparse.ArgumentParser: base parser extended with generation options.
    """
    gen_parser = get_parser("Generation", default_task)
    add_dataset_args(gen_parser, gen=True)
    # Generation defaults to a single worker unless overridden.
    add_distributed_training_args(gen_parser, default_world_size=1)
    add_generation_args(gen_parser)
    add_checkpoint_args(gen_parser)
    if interactive:
        add_interactive_args(gen_parser)
    return gen_parser
| |
|
| |
|
def get_interactive_generation_parser(default_task="translation"):
    """Build the parser for ``fairseq-interactive``.

    Thin wrapper around :func:`get_generation_parser` with
    ``interactive=True``.
    """
    return get_generation_parser(
        interactive=True,
        default_task=default_task,
    )
| |
|
| |
|
def get_eval_lm_parser(default_task="language_modeling"):
    """Build the parser for ``fairseq-eval-lm``.

    Args:
        default_task (str): task used when ``--task`` is not given.

    Returns:
        argparse.ArgumentParser: base parser extended with LM-evaluation options.
    """
    lm_parser = get_parser("Evaluate Language Model", default_task)
    add_dataset_args(lm_parser, gen=True)
    # Evaluation defaults to a single worker unless overridden.
    add_distributed_training_args(lm_parser, default_world_size=1)
    add_eval_lm_args(lm_parser)
    return lm_parser
| |
|
| |
|
def get_validation_parser(default_task=None):
    """Build the parser for ``fairseq-validate``.

    Args:
        default_task (Optional[str]): task used when ``--task`` is not given.

    Returns:
        argparse.ArgumentParser: base parser extended with dataset,
        distributed-training and common-evaluation options.
    """
    val_parser = get_parser("Validation", default_task)
    add_dataset_args(val_parser, train=True)
    add_distributed_training_args(val_parser, default_world_size=1)
    eval_group = val_parser.add_argument_group("Evaluation")
    gen_parser_from_dataclass(eval_group, CommonEvalConfig())
    return val_parser
| |
|
| |
|
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: Optional[List[str]] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """
    Parse command-line arguments, adding model-/task-/registry-specific
    options to *parser* based on a first "known-args" pass before the final
    parse.

    Args:
        parser (ArgumentParser): the parser
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values

    Returns:
        the parsed Namespace, or ``(namespace, extra_args)`` when
        ``parse_known`` is True.
    """
    if suppress_defaults:
        # Parse normally first, then re-parse with every default forced to
        # None and drop the None entries, leaving only explicitly-set args.
        args = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
        )
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()})
        args = suppressed_parser.parse_args(input_args)
        return argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )

    # Deferred import to avoid a circular import at module load time.
    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY

    # Pre-parse only --user-dir so user plugins are registered before any
    # registry lookups below.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    utils.import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # First pass: discover which arch/task/registry choices were requested
    # so their extra options can be added before the final parse.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to the parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # SUPPRESS keeps unset model options out of the Namespace so
            # architecture defaults (applied below) are not overridden.
            argument_default=argparse.SUPPRESS,
        )
        if args.arch in ARCH_MODEL_REGISTRY:
            ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        elif args.arch in MODEL_REGISTRY:
            MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        else:
            raise RuntimeError()

    # Add task-specific args to the parser.
    if hasattr(args, "task"):
        from fairseq.tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # BMUF options are only needed when --use-bmuf was requested.
        from fairseq.optim.bmuf import FairseqBMUF

        FairseqBMUF.add_args(parser)

    # Add args for any other selected registry entries (optimizer, lr
    # scheduler, criterion, ...).
    from fairseq.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            if hasattr(cls, "add_args"):
                cls.add_args(parser)
            elif hasattr(cls, "__dataclass"):
                gen_parser_from_dataclass(parser, cls.__dataclass())

    # Re-apply the caller's modifications now that all options exist.
    if modify_parser is not None:
        modify_parser(parser)

    # Second pass: parse the full argument set.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-processing: fall back to the train-time values for validation
    # when the *_valid variants were not given.
    if (
        hasattr(args, "batch_size_valid") and args.batch_size_valid is None
    ) or not hasattr(args, "batch_size_valid"):
        args.batch_size_valid = args.batch_size
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens
    # The memory-efficient variants imply their base precision flags.
    if getattr(args, "memory_efficient_fp16", False):
        args.fp16 = True
    if getattr(args, "memory_efficient_bf16", False):
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
    if args.bf16:
        # bf16 is only supported on TPUs here, so it implies --tpu.
        args.tpu = True
    if args.tpu and args.fp16:
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")

    # Record whether the user supplied a seed; downstream code may re-seed
    # differently when no explicit seed was provided.
    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Apply architecture configuration (fills in arch-specific defaults).
    if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args
| |
|
| |
|
def get_parser(desc, default_task="translation"):
    """Build the base parser shared by all fairseq command-line tools.

    Args:
        desc (str): human-readable description (currently unused; kept for
            interface compatibility).
        default_task (Optional[str]): default value for ``--task``.

    Returns:
        argparse.ArgumentParser: parser with common options, one flag per
        fairseq registry, and ``--task``.
    """
    # Pre-parse only --user-dir so user plugins are registered before the
    # registries consulted below are read.
    user_dir_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    user_dir_parser.add_argument("--user-dir", default=None)
    user_dir_args, _ = user_dir_parser.parse_known_args()
    utils.import_user_module(user_dir_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    gen_parser_from_dataclass(parser, CommonConfig())

    # Deferred import to avoid circular imports at module load time.
    from fairseq.registry import REGISTRIES

    # One selector flag per registry (optimizer, lr scheduler, criterion, ...).
    for name, registry in REGISTRIES.items():
        flag = "--" + name.replace("_", "-")
        parser.add_argument(
            flag,
            default=registry["default"],
            choices=registry["registry"].keys(),
        )

    from fairseq.tasks import TASK_REGISTRY

    parser.add_argument(
        "--task",
        metavar="TASK",
        default=default_task,
        choices=TASK_REGISTRY.keys(),
        help="task",
    )

    return parser
| |
|
| |
|
def add_preprocess_args(parser):
    """Add binarization/preprocessing options to *parser*.

    All options are registered on a "Preprocessing" argument group so they
    appear together in ``--help``.

    Args:
        parser (argparse.ArgumentParser): parser to extend in place.

    Returns:
        argparse.ArgumentParser: the same *parser*, for chaining.
    """
    group = parser.add_argument_group("Preprocessing")
    # fmt: off
    group.add_argument("-s", "--source-lang", default=None, metavar="SRC",
                       help="source language")
    group.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
                       help="target language")
    group.add_argument("--trainpref", metavar="FP", default=None,
                       help="train file prefix (also used to build dictionaries)")
    group.add_argument("--validpref", metavar="FP", default=None,
                       help="comma separated, valid file prefixes "
                            "(words missing from train set are replaced with <unk>)")
    group.add_argument("--testpref", metavar="FP", default=None,
                       help="comma separated, test file prefixes "
                            "(words missing from train set are replaced with <unk>)")
    group.add_argument("--align-suffix", metavar="FP", default=None,
                       help="alignment file suffix")
    group.add_argument("--destdir", metavar="DIR", default="data-bin",
                       help="destination dir")
    group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
                       help="map words appearing less than threshold times to unknown")
    group.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
                       help="map words appearing less than threshold times to unknown")
    group.add_argument("--tgtdict", metavar="FP",
                       help="reuse given target dictionary")
    group.add_argument("--srcdict", metavar="FP",
                       help="reuse given source dictionary")
    group.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
                       help="number of target words to retain")
    group.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
                       help="number of source words to retain")
    group.add_argument("--alignfile", metavar="ALIGN", default=None,
                       help="an alignment file (optional)")
    # NOTE: was parser.add_argument, which left --dataset-impl out of the
    # "Preprocessing" help group; registered on `group` for consistency.
    group.add_argument('--dataset-impl', metavar='FORMAT', default='mmap',
                       choices=get_available_dataset_impl(),
                       help='output dataset implementation')
    group.add_argument("--joined-dictionary", action="store_true",
                       help="Generate joined dictionary")
    group.add_argument("--only-source", action="store_true",
                       help="Only process the source language")
    group.add_argument("--padding-factor", metavar="N", default=8, type=int,
                       help="Pad dictionary size to be multiple of N")
    group.add_argument("--workers", metavar="N", default=1, type=int,
                       help="number of parallel workers")
    group.add_argument("--dict-only", action='store_true',
                       help="if true, only builds a dictionary and then exits")
    # fmt: on
    return parser
| |
|
| |
|
def add_dataset_args(parser, train=False, gen=False):
    """Add dataset/data-loading options to *parser*.

    Args:
        parser (argparse.ArgumentParser): parser to extend in place.
        train (bool): kept for interface compatibility (unused here).
        gen (bool): kept for interface compatibility (unused here).

    Returns:
        the "dataset_data_loading" argument group.
    """
    grp = parser.add_argument_group("dataset_data_loading")
    gen_parser_from_dataclass(grp, DatasetConfig())
    return grp
| |
|
| |
|
def add_distributed_training_args(parser, default_world_size=None):
    """Add distributed-training options to *parser*.

    Args:
        parser (argparse.ArgumentParser): parser to extend in place.
        default_world_size (Optional[int]): default for
            ``--distributed-world-size``; when None, uses the number of
            visible CUDA devices (at least 1).

    Returns:
        the "distributed_training" argument group.
    """
    grp = parser.add_argument_group("distributed_training")
    world_size = default_world_size
    if world_size is None:
        # Fall back to one worker per visible GPU (minimum one).
        world_size = max(1, torch.cuda.device_count())
    gen_parser_from_dataclass(
        grp, DistributedTrainingConfig(distributed_world_size=world_size)
    )
    return grp
| |
|
| |
|
def add_optimization_args(parser):
    """Add optimization options to *parser* and return the argument group."""
    grp = parser.add_argument_group("optimization")
    gen_parser_from_dataclass(grp, OptimizationConfig())
    return grp
| |
|
| |
|
def add_checkpoint_args(parser):
    """Add checkpointing options to *parser* and return the argument group."""
    grp = parser.add_argument_group("checkpoint")
    gen_parser_from_dataclass(grp, CheckpointConfig())
    return grp
| |
|
| |
|
def add_common_eval_args(group):
    """Populate *group* with the evaluation options shared by all eval tools."""
    common_eval_config = CommonEvalConfig()
    gen_parser_from_dataclass(group, common_eval_config)
| |
|
| |
|
def add_eval_lm_args(parser):
    """Add language-model evaluation options to *parser*."""
    lm_group = parser.add_argument_group("LM Evaluation")
    # Shared eval options first, then the LM-specific ones.
    add_common_eval_args(lm_group)
    gen_parser_from_dataclass(lm_group, EvalLMConfig())
| |
|
| |
|
def add_generation_args(parser):
    """Add sequence-generation options to *parser* and return the group."""
    gen_group = parser.add_argument_group("Generation")
    # Shared eval options first, then the generation-specific ones.
    add_common_eval_args(gen_group)
    gen_parser_from_dataclass(gen_group, GenerationConfig())
    return gen_group
| |
|
| |
|
def add_interactive_args(parser):
    """Add interactive-mode options to *parser*."""
    interactive_group = parser.add_argument_group("Interactive")
    gen_parser_from_dataclass(interactive_group, InteractiveConfig())
| |
|
| |
|
def add_model_args(parser):
    """Add the ``--arch`` model-selection option to *parser*.

    The architecture choices come from fairseq's model registry; the
    selected architecture later drives which model-specific options are
    added in :func:`parse_args_and_arch`.

    Returns:
        the "Model configuration" argument group.
    """
    model_group = parser.add_argument_group("Model configuration")

    # Deferred import to avoid a circular import at module load time.
    from fairseq.models import ARCH_MODEL_REGISTRY

    model_group.add_argument(
        '--arch', '-a',
        metavar='ARCH',
        choices=ARCH_MODEL_REGISTRY.keys(),
        help='model architecture',
    )
    return model_group
| |
|
| |
|
def get_args(
    data: Union[str, Path],
    task: str = "translation",
    arch: str = "transformer",
    **overrides
):
    """Build a fully-parsed training Namespace programmatically.

    Args:
        data: path to the data directory (first positional CLI argument).
        task: fairseq task name.
        arch: model architecture name.
        **overrides: attributes set on the parsed Namespace afterwards.

    Returns:
        the parsed argparse.Namespace with *overrides* applied.
    """
    argv = [str(data), "--task", task, "--arch", arch]
    args = parse_args_and_arch(get_training_parser(task), argv)

    # Apply caller overrides last so they win over parsed defaults.
    for name, value in overrides.items():
        setattr(args, name, value)

    return args
| |
|