import sys
from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, List, Optional

import torch

from fairseq.dataclass.constants import (
    DATASET_IMPL_CHOICES,
    DDP_BACKEND_CHOICES,
    DDP_COMM_HOOK_CHOICES,
    GENERATION_CONSTRAINTS_CHOICES,
    GENERATION_DECODING_FORMAT_CHOICES,
    LOG_FORMAT_CHOICES,
    PIPELINE_CHECKPOINT_CHOICES,
    PRINT_ALIGNMENT_CHOICES,
    ZERO_SHARDING_CHOICES,
)

from omegaconf import II, MISSING

@dataclass
class FairseqDataclass:
    """fairseq base dataclass that supports fetching attributes and metas"""

    _name: Optional[str] = None

    @staticmethod
    def name():
        return None

    def _get_all_attributes(self) -> List[str]:
        return [k for k in self.__dataclass_fields__.keys()]

    def _get_meta(
        self, attribute_name: str, meta: str, default: Optional[Any] = None
    ) -> Any:
        return self.__dataclass_fields__[attribute_name].metadata.get(meta, default)

    def _get_name(self, attribute_name: str) -> str:
        return self.__dataclass_fields__[attribute_name].name

    def _get_default(self, attribute_name: str) -> Any:
        if hasattr(self, attribute_name):
            if str(getattr(self, attribute_name)).startswith("${"):
                return str(getattr(self, attribute_name))
            elif str(self.__dataclass_fields__[attribute_name].default).startswith(
                "${"
            ):
                return str(self.__dataclass_fields__[attribute_name].default)
            elif (
                getattr(self, attribute_name)
                != self.__dataclass_fields__[attribute_name].default
            ):
                return getattr(self, attribute_name)

        f = self.__dataclass_fields__[attribute_name]
        if not isinstance(f.default_factory, _MISSING_TYPE):
            return f.default_factory()
        return f.default

    def _get_type(self, attribute_name: str) -> Any:
        return self.__dataclass_fields__[attribute_name].type

    def _get_help(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "help")

    def _get_argparse_const(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "argparse_const")

    def _get_argparse_alias(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "argparse_alias")

    def _get_choices(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "choices")
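
# Usage sketch (illustrative; _ExampleConfig is hypothetical, not part of
# fairseq): the accessors above expose a field's metadata and defaults:
#
#     @dataclass
#     class _ExampleConfig(FairseqDataclass):
#         lr: float = field(default=0.1, metadata={"help": "learning rate"})
#
#     cfg = _ExampleConfig()
#     cfg._get_help("lr")     # -> "learning rate"
#     cfg._get_default("lr")  # -> 0.1
#
# Defaults of the form "${...}" (produced by omegaconf.II) are returned as
# strings by _get_default so OmegaConf can resolve the interpolation later.
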
@dataclass
class CommonConfig(FairseqDataclass):
    no_progress_bar: bool = field(
        default=False, metadata={"help": "disable progress bar"}
    )
    log_interval: int = field(
        default=100,
        metadata={
            "help": "log progress every N batches (when progress bar is disabled)"
        },
    )
    log_format: Optional[LOG_FORMAT_CHOICES] = field(
        default=None, metadata={"help": "log format to use"}
    )
    log_file: Optional[str] = field(
        default=None, metadata={"help": "log file to copy metrics to"}
    )
    tensorboard_logdir: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to save logs for tensorboard, should match --logdir "
            "of running tensorboard (default: no tensorboard logging)"
        },
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": "Weights and Biases project name to use for logging"},
    )
    azureml_logging: Optional[bool] = field(
        default=False,
        metadata={"help": "log scalars to AzureML context"},
    )
    seed: int = field(
        default=1, metadata={"help": "pseudo random number generator seed"}
    )
    cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"})
    tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"})
    bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"})
    memory_efficient_bf16: bool = field(
        default=False,
        metadata={
            "help": "use a memory-efficient version of BF16 training; implies --bf16"
        },
    )
    fp16: bool = field(default=False, metadata={"help": "use FP16"})
    memory_efficient_fp16: bool = field(
        default=False,
        metadata={
            "help": "use a memory-efficient version of FP16 training; implies --fp16"
        },
    )
    fp16_no_flatten_grads: bool = field(
        default=False, metadata={"help": "don't flatten FP16 grads tensor"}
    )
    fp16_init_scale: int = field(
        default=2 ** 7, metadata={"help": "default FP16 loss scale"}
    )
    fp16_scale_window: Optional[int] = field(
        default=None,
        metadata={"help": "number of updates before increasing loss scale"},
    )
    fp16_scale_tolerance: float = field(
        default=0.0,
        metadata={
            "help": "pct of updates that can overflow before decreasing the loss scale"
        },
    )
    on_cpu_convert_precision: bool = field(
        default=False,
        metadata={
            "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. "
            "This reduces bus transfer time and GPU memory usage."
        },
    )
    min_loss_scale: float = field(
        default=1e-4,
        metadata={
            "help": "minimum FP16/AMP loss scale, after which training is stopped"
        },
    )
    threshold_loss_scale: Optional[float] = field(
        default=None, metadata={"help": "threshold FP16 loss scale from below"}
    )
    amp: bool = field(
        default=False, metadata={"help": "use automatic mixed precision"}
    )
    amp_batch_retries: int = field(
        default=2,
        metadata={
            "help": "number of retries of same batch after reducing loss scale with AMP"
        },
    )
    amp_init_scale: int = field(
        default=2 ** 7, metadata={"help": "default AMP loss scale"}
    )
    amp_scale_window: Optional[int] = field(
        default=None,
        metadata={"help": "number of updates before increasing AMP loss scale"},
    )
    user_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "path to a python module containing custom extensions "
            "(tasks and/or architectures)"
        },
    )
    empty_cache_freq: int = field(
        default=0,
        metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"},
    )
    all_gather_list_size: int = field(
        default=16384,
        metadata={"help": "number of bytes reserved for gathering stats from workers"},
    )
    model_parallel_size: int = field(
        default=1, metadata={"help": "total number of GPUs to parallelize model over"}
    )
    quantization_config_path: Optional[str] = field(
        default=None, metadata={"help": "path to quantization config file"}
    )
    profile: bool = field(
        default=False, metadata={"help": "enable autograd profiler emit_nvtx"}
    )
    reset_logging: bool = field(
        default=False,
        metadata={
            "help": "when using Hydra, reset the logging at the beginning of training"
        },
    )
    suppress_crashes: bool = field(
        default=False,
        metadata={
            "help": "suppress crashes when training with the hydra_train entry point "
            "so that the main method can return a value (useful for sweeps)"
        },
    )
    use_plasma_view: bool = field(
        default=False, metadata={"help": "store indices and sizes in shared memory"}
    )
    plasma_path: Optional[str] = field(
        default="/tmp/plasma",
        metadata={
            "help": "path to run plasma_store, defaults to /tmp/plasma. "
            "Paths outside /tmp tend to fail."
        },
    )

@dataclass
class DistributedTrainingConfig(FairseqDataclass):
    distributed_world_size: int = field(
        default=max(1, torch.cuda.device_count()),
        metadata={
            "help": "total number of GPUs across all nodes (default: all visible GPUs)"
        },
    )
    distributed_num_procs: Optional[int] = field(
        default=max(1, torch.cuda.device_count()),
        metadata={
            "help": "total number of processes to fork (default: all visible GPUs)"
        },
    )
    distributed_rank: Optional[int] = field(
        default=0, metadata={"help": "rank of the current worker"}
    )
    distributed_backend: str = field(
        default="nccl", metadata={"help": "distributed backend"}
    )
    distributed_init_method: Optional[str] = field(
        default=None,
        metadata={
            "help": "typically tcp://hostname:port that will be used to "
            "establish initial connection"
        },
    )
    distributed_port: int = field(
        default=-1,
        metadata={
            "help": "port number (not required if using --distributed-init-method)"
        },
    )
    device_id: int = field(
        default=0,
        metadata={"help": "which GPU to use (usually configured automatically)"},
    )
    local_rank: int = field(
        default=0,
        metadata={"help": "which GPU to use (usually configured automatically)"},
    )
    distributed_no_spawn: bool = field(
        default=False,
        metadata={
            "help": "do not spawn multiple processes even if multiple GPUs are visible"
        },
    )
    ddp_backend: DDP_BACKEND_CHOICES = field(
        default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"}
    )
    ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field(
        default="none", metadata={"help": "communication hook"}
    )
    bucket_cap_mb: int = field(
        default=25, metadata={"help": "bucket size for reduction"}
    )
    fix_batches_to_gpus: bool = field(
        default=False,
        metadata={
            "help": "don't shuffle batches between GPUs; this reduces overall "
            "randomness and may affect precision but avoids the cost of re-reading the data"
        },
    )
    find_unused_parameters: bool = field(
        default=False,
        metadata={
            "help": "enable unused parameter detection (not applicable to "
            "--ddp-backend=legacy_ddp)"
        },
    )
    fast_stat_sync: bool = field(
        default=False,
        metadata={"help": "[deprecated] this is now defined per Criterion"},
    )
    heartbeat_timeout: int = field(
        default=-1,
        metadata={
            "help": "kill the job if no progress is made in N seconds; "
            "set to -1 to disable"
        },
    )
    broadcast_buffers: bool = field(
        default=False,
        metadata={
            "help": "copy non-trainable parameters between GPUs, such as "
            "batchnorm population statistics"
        },
    )
    slowmo_momentum: Optional[float] = field(
        default=None,
        metadata={
            "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, "
            "0.2 for 32 GPUs, 0.5 for 64 GPUs, and 0.6 for more than 64 GPUs"
        },
    )
    slowmo_algorithm: str = field(
        default="LocalSGD", metadata={"help": "whether to use LocalSGD or SGP"}
    )
    localsgd_frequency: int = field(
        default=3, metadata={"help": "Local SGD allreduce frequency"}
    )
    nprocs_per_node: int = field(
        default=max(1, torch.cuda.device_count()),
        metadata={
            "help": "number of GPUs in each node. An allreduce operation across GPUs in "
            "a node is very fast. Hence, we do allreduce across GPUs in a node, "
            "and gossip across different nodes"
        },
    )
    pipeline_model_parallel: bool = field(
        default=False,
        metadata={"help": "if set, use pipeline model parallelism across GPUs"},
    )
    pipeline_balance: Optional[str] = field(
        default=None,
        metadata={
            "help": "partition the model into N_K pieces, where each piece "
            "contains N_i layers. The sum(args.pipeline_balance) "
            "should equal the total number of layers in the model"
        },
    )
    pipeline_devices: Optional[str] = field(
        default=None,
        metadata={
            "help": "a list of device indices indicating which device to place "
            "each of the N_K partitions. The length of this list should "
            "equal the length of the --pipeline-balance argument"
        },
    )
    pipeline_chunks: Optional[int] = field(
        default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
    )
    pipeline_encoder_balance: Optional[str] = field(
        default=None,
        metadata={
            "help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
            "contains N_i layers. The sum(args.pipeline_encoder_balance) "
            "should equal the total number of encoder layers in the model"
        },
    )
    pipeline_encoder_devices: Optional[str] = field(
        default=None,
        metadata={
            "help": "a list of device indices indicating which device to place "
            "each of the N_K partitions. The length of this list should "
            "equal the length of the --pipeline-encoder-balance argument"
        },
    )
    pipeline_decoder_balance: Optional[str] = field(
        default=None,
        metadata={
            "help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
            "contains N_i layers. The sum(args.pipeline_decoder_balance) "
            "should equal the total number of decoder layers in the model"
        },
    )
    pipeline_decoder_devices: Optional[str] = field(
        default=None,
        metadata={
            "help": "a list of device indices indicating which device to place "
            "each of the N_K partitions. The length of this list should "
            "equal the length of the --pipeline-decoder-balance argument"
        },
    )
    pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field(
        default="never",
        metadata={"help": "checkpointing mode for pipeline model parallelism"},
    )
    zero_sharding: ZERO_SHARDING_CHOICES = field(
        default="none", metadata={"help": "ZeRO sharding"}
    )
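
    # values mirrored from CommonConfig via OmegaConf interpolation:
    # II("common.fp16") expands to the interpolation string "${common.fp16}"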
    fp16: bool = II("common.fp16")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    tpu: bool = II("common.tpu")
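
    # the fields below configure fully sharded data parallel training
    # (--ddp-backend=fully_sharded)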
    no_reshard_after_forward: bool = field(
        default=False,
        metadata={"help": "don't reshard parameters after forward pass"},
    )
    fp32_reduce_scatter: bool = field(
        default=False,
        metadata={"help": "reduce-scatter grads in FP32"},
    )
    cpu_offload: bool = field(
        default=False, metadata={"help": "offload FP32 params to CPU"}
    )
    use_sharded_state: bool = field(
        default=False,
        metadata={"help": "use sharded checkpoint files"},
    )

@dataclass
class DatasetConfig(FairseqDataclass):
    num_workers: int = field(
        default=1, metadata={"help": "how many subprocesses to use for data loading"}
    )
    skip_invalid_size_inputs_valid_test: bool = field(
        default=False,
        metadata={"help": "ignore too long or too short lines in valid and test set"},
    )
    max_tokens: Optional[int] = field(
        default=None, metadata={"help": "maximum number of tokens in a batch"}
    )
    batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "number of examples in a batch",
            "argparse_alias": "--max-sentences",
        },
    )
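    # Note (illustrative): max_tokens and batch_size are complementary batch
    # budgets; a batch must respect every limit that is set, so setting only
    # max_tokens yields token-based dynamic batching.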
    required_batch_size_multiple: int = field(
        default=8, metadata={"help": "batch size will be a multiple of this value"}
    )
    required_seq_len_multiple: int = field(
        default=1,
        metadata={
            "help": "maximum sequence length in batch will be a multiple of this value"
        },
    )
    dataset_impl: Optional[DATASET_IMPL_CHOICES] = field(
        default=None, metadata={"help": "output dataset implementation"}
    )
    data_buffer_size: int = field(
        default=10, metadata={"help": "number of batches to preload"}
    )
    train_subset: str = field(
        default="train",
        metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
    )
    valid_subset: str = field(
        default="valid",
        metadata={
            "help": "comma separated list of data subsets to use for validation"
            " (e.g. train, valid, test)"
        },
    )
    combine_valid_subsets: Optional[bool] = field(
        default=None,
        metadata={
            "help": "if set, combine the validation subsets into a single set",
            "argparse_alias": "--combine-val",
        },
    )
    ignore_unused_valid_subsets: Optional[bool] = field(
        default=False,
        metadata={"help": "do not raise error if valid subsets are ignored"},
    )
    validate_interval: int = field(
        default=1, metadata={"help": "validate every N epochs"}
    )
    validate_interval_updates: int = field(
        default=0, metadata={"help": "validate every N updates"}
    )
    validate_after_updates: int = field(
        default=0, metadata={"help": "don't validate until reaching this many updates"}
    )
    fixed_validation_seed: Optional[int] = field(
        default=None, metadata={"help": "specified random seed for validation"}
    )
    disable_validation: bool = field(
        default=False, metadata={"help": "disable validation"}
    )
    max_tokens_valid: Optional[int] = field(
        default=II("dataset.max_tokens"),
        metadata={
            "help": "maximum number of tokens in a validation batch"
            " (defaults to --max-tokens)"
        },
    )
    batch_size_valid: Optional[int] = field(
        default=II("dataset.batch_size"),
        metadata={
            "help": "batch size of the validation batch (defaults to --batch-size)",
            "argparse_alias": "--max-sentences-valid",
        },
    )
    curriculum: int = field(
        default=0, metadata={"help": "don't shuffle batches for first N epochs"}
    )
    gen_subset: str = field(
        default="test",
        metadata={"help": "data subset to generate (train, valid, test)"},
    )
    num_shards: int = field(
        default=1, metadata={"help": "shard generation over N shards"}
    )
    shard_id: int = field(
        default=0, metadata={"help": "id of the shard to generate (id < num_shards)"}
    )

@dataclass
class OptimizationConfig(FairseqDataclass):
    max_epoch: int = field(
        default=0, metadata={"help": "force stop training at specified epoch"}
    )
    max_update: int = field(
        default=0, metadata={"help": "force stop training at specified update"}
    )
    stop_time_hours: float = field(
        default=0,
        metadata={
            "help": "force stop training after specified cumulative time (if >0)"
        },
    )
    clip_norm: float = field(
        default=25.0, metadata={"help": "clip threshold of gradients"}
    )
    sentence_avg: bool = field(
        default=False,
        metadata={
            "help": "normalize gradients by the number of sentences in a batch"
            " (default is to normalize by number of tokens)"
        },
    )
    update_freq: List[int] = field(
        default_factory=lambda: [1],
        metadata={"help": "update parameters every N_i batches, when in epoch i"},
    )
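    # Example (illustrative): update_freq=[4] accumulates gradients over 4
    # batches before each optimizer step (~4x the effective batch size);
    # update_freq=[4, 2] uses 4 in epoch 1 and 2 from epoch 2 onwards.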
    lr: List[float] = field(
        default_factory=lambda: [0.25],
        metadata={
            "help": "learning rate for the first N epochs; epochs beyond N use LR_N"
            " (note: this may be interpreted differently depending on --lr-scheduler)"
        },
    )
    stop_min_lr: float = field(
        default=-1.0,
        metadata={"help": "stop training when the learning rate reaches this minimum"},
    )
    use_bmuf: bool = field(
        default=False,
        metadata={
            "help": "specify global optimizer for syncing models on different GPUs/shards"
        },
    )

@dataclass
class CheckpointConfig(FairseqDataclass):
    save_dir: str = field(
        default="checkpoints", metadata={"help": "path to save checkpoints"}
    )
    restore_file: str = field(
        default="checkpoint_last.pt",
        metadata={
            "help": "filename from which to load checkpoint "
            "(default: <save-dir>/checkpoint_last.pt)"
        },
    )
    finetune_from_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset"
        },
    )
    reset_dataloader: bool = field(
        default=False,
        metadata={
            "help": "if set, does not reload dataloader state from the checkpoint"
        },
    )
    reset_lr_scheduler: bool = field(
        default=False,
        metadata={
            "help": "if set, does not load lr scheduler state from the checkpoint"
        },
    )
    reset_meters: bool = field(
        default=False,
        metadata={"help": "if set, does not load meters from the checkpoint"},
    )
    reset_optimizer: bool = field(
        default=False,
        metadata={"help": "if set, does not load optimizer state from the checkpoint"},
    )
    optimizer_overrides: str = field(
        default="{}",
        metadata={
            "help": "a dictionary used to override optimizer args when loading a checkpoint"
        },
    )
    save_interval: int = field(
        default=1, metadata={"help": "save a checkpoint every N epochs"}
    )
    save_interval_updates: int = field(
        default=0, metadata={"help": "save a checkpoint (and validate) every N updates"}
    )
    keep_interval_updates: int = field(
        default=-1,
        metadata={
            "help": "keep the last N checkpoints saved with --save-interval-updates"
        },
    )
    keep_interval_updates_pattern: int = field(
        default=-1,
        metadata={
            "help": "when used with --keep-interval-updates, skips deleting "
            "any checkpoints with update X where "
            "X %% keep_interval_updates_pattern == 0"
        },
    )
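    # Example (illustrative): keep_interval_updates=5 with
    # keep_interval_updates_pattern=20000 keeps the 5 most recent
    # --save-interval-updates checkpoints, plus every checkpoint whose update
    # number is a multiple of 20000.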
    keep_last_epochs: int = field(
        default=-1, metadata={"help": "keep last N epoch checkpoints"}
    )
    keep_best_checkpoints: int = field(
        default=-1, metadata={"help": "keep best N checkpoints based on scores"}
    )
    no_save: bool = field(
        default=False, metadata={"help": "don't save models or checkpoints"}
    )
    no_epoch_checkpoints: bool = field(
        default=False, metadata={"help": "only store last and best checkpoints"}
    )
    no_last_checkpoints: bool = field(
        default=False, metadata={"help": "don't store last checkpoints"}
    )
    no_save_optimizer_state: bool = field(
        default=False,
        metadata={"help": "don't save optimizer-state as part of checkpoint"},
    )
    best_checkpoint_metric: str = field(
        default="loss",
        metadata={"help": 'metric to use for saving "best" checkpoints'},
    )
    maximize_best_checkpoint_metric: bool = field(
        default=False,
        metadata={
            "help": 'select the largest metric value for saving "best" checkpoints'
        },
    )
    patience: int = field(
        default=-1,
        metadata={
            "help": (
                "early stop training if valid performance doesn't "
                "improve for N consecutive validation runs; note "
                "that this is influenced by --validate-interval"
            )
        },
    )
    checkpoint_suffix: str = field(
        default="", metadata={"help": "suffix to add to the checkpoint file name"}
    )
    checkpoint_shard_count: int = field(
        default=1,
        metadata={
            "help": "number of shards containing the checkpoint - "
            "if the checkpoint is over 300GB, it is preferable "
            "to split it into shards to prevent OOM on CPU while loading "
            "the checkpoint"
        },
    )
    load_checkpoint_on_all_dp_ranks: bool = field(
        default=False,
        metadata={
            "help": "load checkpoints on all data parallel devices "
            "(default: only load on rank 0 and broadcast to other devices)"
        },
    )
    write_checkpoints_asynchronously: bool = field(
        default=False,
        metadata={
            "help": (
                "write checkpoints asynchronously in a separate "
                "thread. NOTE: this feature is currently being tested"
            ),
            "argparse_alias": "--save-async",
        },
    )
    model_parallel_size: int = II("common.model_parallel_size")

@dataclass
class FairseqBMUFConfig(FairseqDataclass):
    block_lr: float = field(
        default=1, metadata={"help": "block learning rate for BMUF"}
    )
    block_momentum: float = field(
        default=0.875, metadata={"help": "block momentum for BMUF"}
    )
    global_sync_iter: int = field(
        default=50, metadata={"help": "iteration interval for syncing global model"}
    )
    warmup_iterations: int = field(
        default=500, metadata={"help": "warmup iterations for model to broadcast"}
    )
    use_nbm: bool = field(
        default=False,
        metadata={
            "help": "if set, use Nesterov block momentum instead of classical block momentum"
        },
    )
    average_sync: bool = field(
        default=False,
        metadata={
            "help": "specify whether you want to average the local momentum after each sync"
        },
    )
    distributed_world_size: int = II("distributed_training.distributed_world_size")

@dataclass
class GenerationConfig(FairseqDataclass):
    beam: int = field(
        default=5, metadata={"help": "beam size"}
    )
    nbest: int = field(
        default=1, metadata={"help": "number of hypotheses to output"}
    )
    max_len_a: float = field(
        default=0,
        metadata={
            "help": "generate sequences of maximum length ax + b, where x is the source length"
        },
    )
    max_len_b: int = field(
        default=200,
        metadata={
            "help": "generate sequences of maximum length ax + b, where x is the source length"
        },
    )
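    # Example (illustrative): max_len_a=0, max_len_b=200 caps generation at a
    # fixed 200 tokens; max_len_a=1.5, max_len_b=10 allows up to
    # 1.5 * source_length + 10 tokens.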
    min_len: int = field(
        default=1, metadata={"help": "minimum generation length"}
    )
    match_source_len: bool = field(
        default=False, metadata={"help": "generations should match the source length"}
    )
    unnormalized: bool = field(
        default=False, metadata={"help": "compare unnormalized hypothesis scores"}
    )
    no_early_stop: bool = field(
        default=False, metadata={"help": "deprecated"}
    )
    no_beamable_mm: bool = field(
        default=False, metadata={"help": "don't use BeamableMM in attention layers"}
    )
    lenpen: float = field(
        default=1,
        metadata={
            "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
        },
    )
    unkpen: float = field(
        default=0,
        metadata={
            "help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
        },
    )
    replace_unk: Optional[str] = field(
        default=None,
        metadata={
            "help": "perform unknown replacement (optionally with alignment dictionary)",
            "argparse_const": "@@ ",
        },
    )
    sacrebleu: bool = field(
        default=False, metadata={"help": "score with sacrebleu"}
    )
    score_reference: bool = field(
        default=False, metadata={"help": "just score the reference translation"}
    )
    prefix_size: int = field(
        default=0,
        metadata={"help": "initialize generation by target prefix of given length"},
    )
    no_repeat_ngram_size: int = field(
        default=0,
        metadata={
            "help": "ngram blocking such that this size ngram cannot be repeated in the generation"
        },
    )
    sampling: bool = field(
        default=False,
        metadata={"help": "sample hypotheses instead of using beam search"},
    )
    sampling_topk: int = field(
        default=-1,
        metadata={"help": "sample from top K likely next words instead of all words"},
    )
    sampling_topp: float = field(
        default=-1.0,
        metadata={
            "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
        },
    )
    constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
        default=None,
        metadata={
            "help": "enables lexically constrained decoding",
            "argparse_const": "ordered",
        },
    )
    temperature: float = field(
        default=1.0, metadata={"help": "temperature for generation"}
    )
    diverse_beam_groups: int = field(
        default=-1, metadata={"help": "number of groups for Diverse Beam Search"}
    )
    diverse_beam_strength: float = field(
        default=0.5,
        metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
    )
    diversity_rate: float = field(
        default=-1.0,
        metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
    )
    print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field(
        default=None,
        metadata={
            "help": "if set, uses attention feedback to compute and print alignment to source tokens "
            "(valid options are: hard, soft, otherwise treated as hard alignment)",
            "argparse_const": "hard",
        },
    )
    print_step: bool = field(
        default=False, metadata={"help": "print steps"}
    )
    lm_path: Optional[str] = field(
        default=None, metadata={"help": "path to lm checkpoint for lm fusion"}
    )
    lm_weight: float = field(
        default=0.0, metadata={"help": "weight for lm probs for lm fusion"}
    )

    iter_decode_eos_penalty: float = field(
        default=0.0,
        metadata={"help": "if > 0.0, penalizes early stopping in decoding"},
    )
    iter_decode_max_iter: int = field(
        default=10, metadata={"help": "maximum iterations for iterative refinement"}
    )
    iter_decode_force_max_iter: bool = field(
        default=False,
        metadata={
            "help": "if set, run exactly the maximum number of iterations without early stopping"
        },
    )
    iter_decode_with_beam: int = field(
        default=1,
        metadata={
            "help": "if > 1, the model will generate translations of varying lengths"
        },
    )
    iter_decode_with_external_reranker: bool = field(
        default=False,
        metadata={
            "help": "if set, the last checkpoint is assumed to be a reranker for rescoring the translations"
        },
    )
    retain_iter_history: bool = field(
        default=False,
        metadata={
            "help": "if set, decoding returns the whole history of iterative refinement"
        },
    )
    retain_dropout: bool = field(
        default=False, metadata={"help": "use dropout at inference time"}
    )
    retain_dropout_modules: Any = field(
        default=None,
        metadata={
            "help": "if set, only retain dropout for the specified modules; "
            "if not set, then dropout will be retained for all modules"
        },
    )
    decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
        default=None,
        metadata={"help": "special decoding format for advanced decoding"},
    )
    no_seed_provided: bool = field(
        default=False,
        metadata={"help": "if set, don't use seed for initializing random generators"},
    )

@dataclass
class CommonEvalConfig(FairseqDataclass):
    path: Optional[str] = field(
        default=None, metadata={"help": "path(s) to model file(s), colon separated"}
    )
    post_process: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "post-process text by removing BPE, letter segmentation, etc. "
                "Valid options can be found in fairseq.data.utils.post_process."
            ),
            "argparse_const": "subword_nmt",
            "argparse_alias": "--remove-bpe",
        },
    )
    quiet: bool = field(default=False, metadata={"help": "only print final scores"})
    model_overrides: str = field(
        default="{}",
        metadata={
            "help": "a dictionary used to override model args at generation that were used during model training"
        },
    )
    results_path: Optional[str] = field(
        default=None, metadata={"help": "path to save eval results (optional)"}
    )

@dataclass
class EvalLMConfig(FairseqDataclass):
    output_word_probs: bool = field(
        default=False,
        metadata={
            "help": "if set, outputs words and their predicted log probabilities to standard output"
        },
    )
    output_word_stats: bool = field(
        default=False,
        metadata={
            "help": "if set, outputs word statistics such as word count, average probability, etc"
        },
    )
    context_window: int = field(
        default=0,
        metadata={
            "help": "ensures that every evaluated token has access to a context of at least this size, if possible"
        },
    )
    softmax_batch: int = field(
        default=sys.maxsize,
        metadata={
            "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory"
        },
    )

@dataclass
class InteractiveConfig(FairseqDataclass):
    buffer_size: int = field(
        default=0,
        metadata={
            "help": "read this many sentences into a buffer before processing them"
        },
    )
    input: str = field(
        default="-", metadata={"help": "file to read from; use - for stdin"}
    )

@dataclass
class FairseqConfig(FairseqDataclass):
    # use default_factory for dataclass-typed fields: plain instance defaults
    # are rejected as mutable defaults by dataclasses on Python 3.11+
    common: CommonConfig = field(default_factory=CommonConfig)
    common_eval: CommonEvalConfig = field(default_factory=CommonEvalConfig)
    distributed_training: DistributedTrainingConfig = field(
        default_factory=DistributedTrainingConfig
    )
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    optimization: OptimizationConfig = field(default_factory=OptimizationConfig)
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    bmuf: FairseqBMUFConfig = field(default_factory=FairseqBMUFConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    eval_lm: EvalLMConfig = field(default_factory=EvalLMConfig)
    interactive: InteractiveConfig = field(default_factory=InteractiveConfig)
    model: Any = MISSING
    task: Any = None
    criterion: Any = None
    optimizer: Any = None
    lr_scheduler: Any = None
    scoring: Any = None
    bpe: Any = None
    tokenizer: Any = None
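
# Usage sketch (illustrative, not part of this module): the config tree can be
# instantiated as an OmegaConf structured config, with II(...) fields resolving
# against their interpolation targets:
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.structured(FairseqConfig)
#     cfg.common.fp16 = True
#     assert cfg.distributed_training.fp16 is True  # via II("common.fp16")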