MiloMusic / YuE /finetune /core /arguments.py
futurespyhi
1.add YuE 2.modify .gitignore 3.modify requirements.txt
15389e6
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Megatron arguments."""
import argparse
import dataclasses
import json
import os
import torch
import types
# import torch.nn.functional as F
# from megatron.global_vars import set_retro_args, get_retro_args
# from tools.retro.utils import get_args_path as get_retro_args_path
# from megatron.core.models.retro import RetroConfig
# from megatron.core.transformer import TransformerConfig
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Build the YuE finetune argument parser and parse the command line.

    Args:
        extra_args_provider: Optional callable that takes the parser and
            returns it, typically after registering extra arguments.
        ignore_unknown_args: If true, options the parser does not know are
            dropped instead of triggering a parse error.

    Returns:
        The parsed namespace, with ``rank`` and ``world_size`` read from the
        ``RANK`` and ``WORLD_SIZE`` environment variables (defaults 0 and 1).
    """
    parser = argparse.ArgumentParser(description='YuE Finetune Arguments',
                                     allow_abbrev=False)

    # Standard argument groups, registered in this exact order.
    standard_groups = (
        _add_initialization_args,
        _add_data_args,
        _add_checkpointing_args,
        _add_training_args1,
        _add_validation_args,
        _add_retro_args,
        _add_logging_args,
        _add_finetune_args,
    )
    for register in standard_groups:
        parser = register(parser)

    # Caller-specific arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse; with ignore_unknown_args the list of leftovers is discarded.
    if ignore_unknown_args:
        args = parser.parse_known_args()[0]
    else:
        args = parser.parse_args()

    # Distributed topology comes from the environment, not the CLI.
    environment = os.environ
    args.rank = int(environment.get('RANK', '0'))
    args.world_size = int(environment.get("WORLD_SIZE", '1'))
    return args
def validate_args(args, defaults={}):
    """Validate parsed arguments and derive dependent settings in place.

    The function walks through ``args`` in a fixed order: it clamps and checks
    the parallelism sizes against ``world_size``, rejects deprecated options,
    applies ``defaults``, and then derives or checks batch size, dtype,
    training schedule, layer/sequence sizes, recompute options, sequence
    parallelism, retro, RoPE and MoE settings. Several attributes are
    written on ``args`` (for example ``data_parallel_size``,
    ``global_batch_size``, ``params_dtype``, ``ffn_hidden_size``,
    ``kv_channels``) and the deprecated ones are deleted from it. A summary
    of all arguments is printed on rank 0 at the end.

    Args:
        args: Namespace of parsed arguments; mutated in place.
        defaults: Mapping of attribute name to value. A value is applied only
            when the attribute is missing or None on ``args``; otherwise a
            warning is printed on rank 0. The mapping is only read here,
            never modified.

    Returns:
        The same ``args`` object.

    Raises:
        AssertionError: When one of the consistency checks fails.
        RuntimeError: When ``CUDA_DEVICE_MAX_CONNECTIONS`` is not ``"1"``
            while sequence parallelism or async tensor-parallel all-reduce
            is enabled, or when position embeddings are disabled without
            selecting ``rope``.
        SystemExit: Via ``exit()`` when ``checkpoint_activations`` is set.
    """
    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # One pipeline stage is subtracted when standalone_embedding_stage is set.
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embedding_stage else
        args.pipeline_model_parallel_size
    )
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
        args.tensor_model_parallel_size
    assert args.world_size % (model_parallel_size * args.context_parallel_size) == 0, \
        'world size ({}) is not divisible by tensor parallel size ({}) times ' \
        'pipeline parallel size ({}) times context parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size, args.context_parallel_size)
    # Whatever is left of the world size after model and context parallelism
    # becomes the data-parallel size.
    args.data_parallel_size = args.world_size // (model_parallel_size * args.context_parallel_size)
    if args.rank == 0:
        print('using world size: {}, data-parallel size: {}, '
              'context-parallel size: {} '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.context_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)
    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                args.pipeline_model_parallel_size, 'split rank needs'\
                ' to be less than pipeline model parallel size ({})'.format(
                    args.pipeline_model_parallel_size)
    if args.tp_comm_overlap:
        assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled'
    # Deprecated arguments
    # Each deprecated attribute must still be None and is then removed from
    # the namespace.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size
    if args.checkpoint_activations:
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, use --recompute-activations, '
                  'or, for more control, --recompute-granularity and --recompute-method.')
        # NOTE(review): exit() without an argument ends the process with
        # status 0 although this is an error path - confirm that is intended.
        exit()
    del args.checkpoint_activations
    # --recompute-activations is shorthand for selective recompute granularity.
    if args.recompute_activations:
        args.recompute_granularity = 'selective'
    del args.recompute_activations
    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key, None) is not None:
            if args.rank == 0:
                # The string continuation below starts at column 0 on
                # purpose: any indentation there would become part of the
                # printed message.
                print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
                        v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])
    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        # NOTE(review): args.num_layers is read here, before the
        # encoder_num_layers fallback further down assigns it; if only
        # encoder-num-layers was supplied (and not via `defaults`) this
        # would operate on None - confirm.
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
            'number of layers should be divisible by the pipeline parallel size'
        num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
        assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage'
        args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None
        # Overlap P2P communication is disabled if not using the interleaved schedule.
        args.overlap_p2p_comm = False
        if args.rank == 0:
            print('WARNING: Setting args.overlap_p2p_comm to False since non-interleaved '
                  'schedule does not support overlapping p2p communication')
    if args.overlap_param_gather:
        assert args.use_distributed_optimizer, \
            '--overlap-param-gather only supported with distributed optimizer'
        assert args.overlap_grad_reduce, \
            '--overlap-grad-reduce should be turned on when using --overlap-param-gather'
    # Parameters dtype.
    # Starts at fp32; fp16 and bf16 are mutually exclusive overrides.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)
    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)
    if args.dataloader_type is None:
        args.dataloader_type = 'single'
    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0
    # Support for variable sequence lengths across batches/microbatches.
    # set it if the dataloader supports generation of variable sequence lengths
    # across batches/microbatches. Due to additional communication overhead
    # during pipeline parallelism, it should not be set if sequence length
    # is constant during training.
    args.variable_seq_lengths = False
    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'
    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learnig rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'
    # num_layers and encoder_num_layers are aliases: exactly one may be
    # given and the other is filled in from it.
    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'cannot have both num-layers and encoder-num-layers specified'
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, \
            'either num-layers or encoder-num-layers should be specified'
        args.num_layers = args.encoder_num_layers
    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)
    # Checks.
    if args.ffn_hidden_size is None:
        if args.swiglu:
            # reduce the dimension for MLP since projections happens on
            # two linear layers. this keeps the number of parameters in
            # the same ballpark as the counterpart with 4*h size
            # we keep it a multiple of 64, which means the actual tensor size
            # will be a multiple of 64 / tp_size
            args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
        else:
            args.ffn_hidden_size = 4 * args.hidden_size
    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads
    # seq_length and encoder_seq_length are aliases, handled like the
    # num_layers pair above.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    if args.moe_grouped_gemm:
        assert args.bf16, 'Currently GroupedGEMM for MoE only supports bf16 dtype.'
        # Queries the current CUDA device; the major capability must be >= 8.
        dc = torch.cuda.get_device_capability()
        assert dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels."
    if args.weight_decay_incr_style == 'constant':
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')
    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model ' \
            'parallel groups'
        assert args.recompute_granularity == 'full', \
            'distributed recompute activations is only '\
            'application to full recompute granularity'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you '\
            'need to use a recompute method '
        assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
    if args.recompute_granularity == 'selective':
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'
    # disable sequence parallelism when tp=1
    # to avoid change in numerics when
    # sequence_parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False
    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False
    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
        if args.sequence_parallel:
            raise RuntimeError(
                "Using sequence parallelism requires setting the environment variable "
                "CUDA_DEVICE_MAX_CONNECTIONS to 1")
        if args.async_tensor_model_parallel_allreduce:
            raise RuntimeError(
                "Using async gradient all reduce requires setting the environment "
                "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
    # Disable bias gelu fusion if we are disabling bias altogether
    if not args.add_bias_linear:
        args.bias_gelu_fusion = False
    # Retro checks.
    if args.retro_add_retriever:
        # Sequence parallelism unsupported.
        assert not args.sequence_parallel, \
            "retro currently does not support sequence parallelism."
        # Pipeline parallelism unsupported.
        assert args.pipeline_model_parallel_size == 1, \
            "retro currently does not support pipeline parallelism."
    # Load retro args (used by both Retro & GPT).
    # if args.retro_workdir:
    #     retro_args_path = get_retro_args_path(args.retro_workdir)
    #     assert os.path.exists(retro_args_path), "retro workdir missing args.json"
    #     with open(retro_args_path) as f:
    #         retro_args = types.SimpleNamespace(**json.load(f))
    #         retro_args.retro_return_doc_ids = args.retro_return_doc_ids
    #         retro_args.retro_gpt_retrieved_length = \
    #             args.retro_num_retrieved_chunks * \
    #             retro_args.retro_gpt_chunk_length
    #         set_retro_args(retro_args)
    # Legacy RoPE arguments
    if args.use_rotary_position_embeddings:
        args.position_embedding_type = 'rope'
    # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now
    # don't allow it to keep things simple
    if not args.add_position_embedding and args.position_embedding_type != 'rope':
        raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')
    # MoE Spec check
    if args.num_experts is not None:
        assert args.spec is None, "Model Spec must be None when using MoEs"
    # Expert parallelism check
    if args.expert_model_parallel_size > 1:
        assert args.num_experts is not None, "num_experts must be non None to use expert model parallelism"
        assert args.num_experts % args.expert_model_parallel_size == 0, \
            "Number of experts should be a multiple of expert model parallel_size."
        assert not args.use_distributed_optimizer, \
            "Expert parallelism is not suppored with distributed optimizer."
        assert not args.fp16, \
            "Expert parallelism is not supported with fp16 training."
        if args.tensor_model_parallel_size > 1:
            assert args.sequence_parallel, \
                "When using expert parallelism and tensor parallelism, sequence parallelism must be used."
    # Print arguments.
    _print_args("arguments", args)
    # retro_args = get_retro_args()
    # if retro_args and args != retro_args:
    #     _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))
    return args
def _print_args(title, args):
    """Print every attribute of ``args`` to stdout, on rank 0 only.

    Each attribute becomes one ``name <dots> value`` line, where the run of
    dots is ``48 - len(name)`` characters long. The lines are emitted in
    case-insensitive order of the formatted text, framed by a header and a
    footer that both carry ``title``. Other ranks print nothing.
    """
    if args.rank != 0:
        return

    print(f'------------------------ {title} ------------------------',
          flush=True)
    entries = [
        ' {} {} {}'.format(name, '.' * (48 - len(name)), getattr(args, name))
        for name in vars(args)
    ]
    entries.sort(key=str.lower)
    for entry in entries:
        print(entry, flush=True)
    print(f'-------------------- end of {title} ---------------------',
          flush=True)
def _check_arg_is_not_none(args, arg):
    """Assert that the attribute named ``arg`` on ``args`` is not None.

    Raises:
        AssertionError: With the message ``'<arg> argument is None'`` when
            the attribute is None.
    """
    current = getattr(args, arg)
    assert current is not None, '{} argument is None'.format(arg)
# def core_transformer_config_from_args(args):
# # Translate args to core transformer configuration
# kw_args = {}
# for f in dataclasses.fields(TransformerConfig):
# if hasattr(args, f.name):
# kw_args[f.name] = getattr(args, f.name)
# kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
# kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
# kw_args['layernorm_epsilon'] = args.norm_epsilon
# kw_args['deallocate_pipeline_outputs'] = True
# kw_args['pipeline_dtype'] = args.params_dtype
# kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
# kw_args['num_moe_experts'] = args.num_experts
# if args.swiglu:
# kw_args['activation_func'] = F.silu
# kw_args['gated_linear_unit'] = True
# kw_args['bias_gelu_fusion'] = False
# if args.squared_relu:
# assert not args.swiglu
# def squared_relu(x):
# return torch.pow(F.relu(x), 2)
# kw_args['activation_func'] = squared_relu
# if args.init_method_xavier_uniform:
# kw_args['init_method'] = torch.nn.init.xavier_uniform_
# kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
# if args.group_query_attention:
# kw_args['num_query_groups'] = args.num_query_groups
# else:
# kw_args['num_query_groups'] = None
# # If using Retro, return Retro config.
# retro_args = get_retro_args()
# if retro_args:
# kw_args['retro_preprocess'] = retro_args
# return RetroConfig(**kw_args)
# # Return Transformer config.
# return TransformerConfig(**kw_args)
def _add_transformer_engine_args(parser):
    """Register the 'Transformer-Engine' argument group on ``parser``.

    Adds the fp8 options (stored under ``fp8``, ``fp8_margin``,
    ``fp8_interval``, ``fp8_amax_history_len``, ``fp8_amax_compute_algo``,
    ``fp8_wgrad``) and ``--transformer-impl``.

    Returns:
        The same ``parser`` object.
    """
    te_group = parser.add_argument_group(title='Transformer-Engine')

    # (flag, add_argument keyword arguments), in registration order.
    option_specs = (
        ('--fp8-format', dict(
            dest='fp8', default=None, choices=['e4m3', 'hybrid'],
            help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass')),
        ('--fp8-margin', dict(
            dest='fp8_margin', type=int, default=0,
            help='Scaling margin for fp8')),
        ('--fp8-interval', dict(
            dest='fp8_interval', type=int, default=1,
            help='Scaling update interval for fp8')),
        ('--fp8-amax-history-len', dict(
            dest='fp8_amax_history_len', type=int, default=1,
            help='Number of steps for which amax history is recorded per tensor')),
        ('--fp8-amax-compute-algo', dict(
            dest='fp8_amax_compute_algo', default='most_recent',
            choices=['most_recent', 'max'],
            help='Algorithm for computing amax from history')),
        ('--no-fp8-wgrad', dict(
            dest='fp8_wgrad', action='store_false',
            help='Execute wgrad in higher precision even for FP8 runs')),
        ('--transformer-impl', dict(
            default='local', choices=['local', 'transformer_engine'],
            help='Which Transformer implementation to use.')),
    )
    for flag, spec in option_specs:
        te_group.add_argument(flag, **spec)
    return parser
def _add_inference_args(parser):
    """Register the 'inference' argument group on ``parser``.

    Adds ``--inference-batch-times-seqlen-threshold``,
    ``--max-tokens-to-oom``, ``--output-bert-embeddings`` and
    ``--bert-embedder-type``.

    Returns:
        The same ``parser`` object.
    """
    group = parser.add_argument_group(title='inference')
    group.add_argument('--inference-batch-times-seqlen-threshold',
                       type=int, default=512,
                       help='During inference, if batch-size times '
                       'sequence-length is smaller than this threshold '
                       'then we will not use pipelining, otherwise we will.')
    # Adjacent string literals are concatenated verbatim, so every fragment
    # must carry its own trailing separator.
    group.add_argument('--max-tokens-to-oom',
                       type=int, default=12000,
                       help='Maximum number of tokens during inference; '
                       'tokens here is # in prompt + # to generate. '
                       'Allows us to throw an error before OOM crashes server')
    group.add_argument('--output-bert-embeddings', action='store_true',
                       help='Output Bert embeddings (via mean pooling) from '
                       'model, rather than its binary head output or entire '
                       'hidden batch.')
    group.add_argument('--bert-embedder-type', default="megatron",
                       choices=["megatron", "huggingface"],
                       help='Select either Megatron or Huggingface as the '
                       'Bert embedder.')
    return parser
def _add_retro_args(parser):
    """Register the 'retro' argument group on ``parser``.

    After registration, every action in the group is checked to have a
    destination whose first underscore-separated component is ``retro``.

    Returns:
        The same ``parser`` object.

    Raises:
        AssertionError: If an argument in the group does not follow the
            ``--retro-*`` naming convention.
    """
    group = parser.add_argument_group(title='retro')
    register = group.add_argument

    register('--retro-workdir', default=None,
             help='Retro working directory, which contains the '
             'preprocessed data for for pretraining. This directory '
             'is built during preprocessing (see '
             'tools/retro/README.md), and contains subdirectories '
             'for the chunk database and pretraining neighbors.')
    register('--retro-add-retriever',
             action='store_true', default=False,
             help='Add a retriever to the transformer, for use in '
             'pretraining a Retro model.')
    register('--retro-cyclic-train-iters', type=int, default=None,
             help='Set number of training iterations for cyclic '
             'Retro training.')
    register('--retro-encoder-layers', type=int, default=2,
             help='Number of layers to use for the retrieval '
             'encoder.')
    register('--retro-encoder-hidden-dropout',
             type=float, default=0.1, help='Hidden dropout for '
             'retrieval encoder.')
    register('--retro-encoder-attention-dropout',
             type=float, default=0.1, help='Attention dropout for '
             'retrieval encoder.')
    register("--retro-num-neighbors", type=int, default=2,
             help='Number of neighbors to retrieve during '
             'pretraining.')
    register("--retro-num-retrieved-chunks", type=int, default=2,
             help='Number of chunks to retrieve from the retrieval '
             'database.')
    register("--retro-return-doc-ids", action="store_true",
             help="Turn this on when preprocessing retro data.")
    register("--retro-attention-gate", type=float, default=1,
             help="Gated cross attention.")
    register("--retro-no-verify-neighbor-count", action="store_false",
             dest="retro_verify_neighbor_count",
             help="Skip verifying that len(GPT dataset) == len(saved "
             "neighbors).")

    # Enforce argument naming convention.
    for registered in group._group_actions:
        leading, _, _ = registered.dest.partition("_")
        assert leading == "retro", \
            "Retro args must be prefixed with '--retro-*', for consistent " \
            "styling. Please fix '%s'." % ", ".join(registered.option_strings)
    return parser
def _add_network_size_args(parser):
    """Register the 'network size' argument group on ``parser``.

    Covers layer counts, hidden sizes, attention heads, position-embedding
    options, normalization, activation selection and the MoE options.

    ``--onnx-safe`` takes a textual boolean: the values ``''``, ``0``,
    ``false``, ``f``, ``no``, ``n`` and ``off`` (case-insensitive) parse as
    False; any other value parses as True. When the option is absent the
    attribute stays None.

    Returns:
        The same ``parser`` object.
    """
    def _onnx_safe_value(text):
        # argparse passes the raw string; plain bool() would turn every
        # non-empty string - including 'False' - into True.
        return text.lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')

    group = parser.add_argument_group(title='network size')
    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--encoder-num-layers', type=int, default=None,
                       help='Number of encoder transformer layers.')
    group.add_argument('--decoder-num-layers', type=int, default=None,
                       help='Number of decoder transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Tansformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       ' args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--group-query-attention', action='store_true',
                       help='Use group-query attention.')
    group.add_argument('--num-query-groups', type=int, default=1)
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
                       choices=['learned_absolute', 'rope'],
                       help='Position embedding type.')
    group.add_argument('--use-rotary-position-embeddings', action='store_true',
                       help='Use rotary positional embeddings or not. '
                       'Deprecated: use --position-embedding-type')
    group.add_argument('--rotary-percent', type=float, default=1.0,
                       help='Percent of rotary dimension to use, default 100%%')
    group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
                       help='Sequence length interpolation factor for rotary embeddings.')
    group.add_argument('--rotary-base', type=int, default=10000,
                       help='Base period for rotary position embeddings. Ignored unless position_embedding_type is \'rope\'. Defaults to 10000.')
    group.add_argument('--no-position-embedding',
                       action='store_false',
                       help='Disable position embedding. Deprecated: use --position-embedding-type',
                       dest='add_position_embedding')
    # Adjacent literals are joined verbatim, so each fragment ends with the
    # space that separates it from the next one.
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--normalization', default='LayerNorm',
                       choices=['LayerNorm', 'RMSNorm'],
                       help='Which normalization technique to use.')
    group.add_argument('--norm-epsilon', type=float, default=1e-5,
                       help='Epsilon for layer norm and RMS norm.')
    group.add_argument('--apply-layernorm-1p', action='store_true',
                       help='Adjust LayerNorm weights such that they are centered '
                       'around zero. This improves numerical stability.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residula connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
    group.add_argument('--squared-relu', action='store_true',
                       help='Use squared relu activation instead of default gelu')
    group.add_argument('--swiglu', action='store_true',
                       help='Use gated linear units and SiLU activation instead of default gelu')
    group.add_argument('--onnx-safe', type=_onnx_safe_value, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in Switch Transformer (None means no Switch)')
    group.add_argument('--moe-grouped-gemm', action='store_true',
                       help='When there are multiple experts per rank, compress '
                       'multiple local (potentially small) gemms in a single kernel '
                       'launch to improve the utilization and performance by '
                       'leveraging the Grouped GEMM feature introduced since '
                       'CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).')
    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                       help='Untie embeddings and output weights.')
    return parser
def _add_logging_args(parser):
    """Register the 'logging' argument group on ``parser``.

    Adds the norm/throughput logging switches, the timing options, the
    tensorboard options and the wandb options.

    The learning-rate switch is reachable under two spellings: the original
    ``--no-log-learnig-rate-to-tensorboard`` and the corrected
    ``--no-log-learning-rate-to-tensorboard``. Both store False in
    ``log_learning_rate_to_tensorboard``.

    Returns:
        The same ``parser`` object.
    """
    group = parser.add_argument_group(title='logging')
    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--log-throughput', action='store_true',
                       help='If set, calculate and log throughput per GPU.')
    group.add_argument('--timing-log-level', type=int,
                       default=0, choices=range(0,3),
                       help='Granularity level to measure and report timing. '
                       ' 0: report only iteration time and make sure timing '
                       ' does not introduce extra overhead.'
                       ' 1: report timing for operations that are executed '
                       ' very limited times (basically once) during '
                       ' each iteration (such as gradient all-reduce) '
                       ' 2: report timing for operations that migh be '
                       ' executed numerous times during each iteration. '
                       'Note that setting the level to 1 or 2 might '
                       'cause increase in iteration time.')
    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                       help='If not set, use barrier with level 1 time '
                       'measurements. Note that this is up to the user '
                       'to make sure calling barrier with their timers '
                       'will not result in hangs. This can happen if for '
                       'example the user adds a level 1 timer that is not '
                       'called by all ranks.',
                       dest='barrier_with_L1_time')
    group.add_argument('--timing-log-option', type=str, default='minmax',
                       choices=['max', 'minmax', 'all'],
                       help='Options for logging timing:'
                       ' max: report the max timing across all ranks'
                       ' minmax: report min and max timings across all ranks'
                       ' all: report timings of all ranks.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    # The misspelled flag ("learnig") is kept so existing launch scripts keep
    # working; the correctly spelled alias writes to the same destination.
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       '--no-log-learning-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')
    group.add_argument('--wandb-project', type=str, default='',
                       help='The wandb project name. Ignore wandb by default.')
    group.add_argument('--wandb-exp-name', type=str, default='',
                       help='The wandb experiment name.')
    group.add_argument('--wandb-save-dir', type=str, default='',
                       help='Path to save the wandb results locally.')
    return parser
def _add_regularization_args(parser):
    """Register the 'regularization' argument group on ``parser``.

    Adds the dropout, weight-decay, gradient-clipping, Adam and SGD momentum
    options. ``--start-weight-decay`` and ``--end-weight-decay`` have no
    explicit default and therefore parse to None when omitted.

    Returns:
        The same ``parser`` object.
    """
    group = parser.add_argument_group(title='regularization')
    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--start-weight-decay', type=float,
                       help='Initial weight decay coefficient for L2 regularization.')
    group.add_argument('--end-weight-decay', type=float,
                       help='End of run weight decay coefficient for L2 regularization.')
    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                       choices=['constant', 'linear', 'cosine'],
                       help='Weight decay increment function.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    # The first fragment ends with a space: adjacent literals are joined
    # verbatim.
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')
    return parser
def _add_finetune_args(parser):
group = parser.add_argument_group(title='finetune')
group.add_argument('--model-name-or-path', type=str, default=None,
help='Path to the model to finetune.')
group.add_argument('--cache-dir', type=str, default=None,
help='Cache directory for the model.')
group.add_argument('--optim', type=str, default="adamw_torch_fused",
help='Optimizer to use.')
group.add_argument('--model-max-length', type=int, default=2048,
help='Maximum sequence length.')
group.add_argument('--logging-steps', type=int, default=100,
help='Log every X updates.')
group.add_argument('--report-to', type=str, default=None,
help='The integration to report the results and logs to.')
group.add_argument('--run-name', type=str, default=None,
help='The name of the run for logging.')
group.add_argument('--gradient-checkpointing', action='store_true',
help='Enable gradient checkpointing.')
group.add_argument('--lr-scheduler-type', type=str, default="cosine",
help='The learning rate scheduler to use.')
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--num-train-epochs', type=int, default=200,
help='Total number of training epochs.')
group.add_argument('--per-device-train-batch-size', type=int, default=1,
help='Batch size per device during training.')
group.add_argument('--per-device-eval-batch-size', type=int, default=1,
help='Batch size per device during evaluation.')
group.add_argument('--gradient-accumulation-steps', type=int, default=1,
help='Number of updates steps to accumulate before performing a backward/update pass.')
group.add_argument('--evaluation-strategy', type=str, default="steps",
help='The evaluation strategy to use.')
group.add_argument('--eval-steps', type=int, default=5000,
help='Number of update steps between two evaluations.')
group.add_argument('--save-strategy', type=str, default="steps",
help='The checkpoint save strategy to use.')
group.add_argument('--save-steps', type=int, default=100,
help='Number of updates steps before two checkpoint saves.')
group.add_argument('--save-total-limit', type=int, default=100,
help='Limit the total amount of checkpoints.')
group.add_argument('--learning-rate', type=float, default=0.0005,
help='The initial learning rate for training.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--adam-beta2', type=float, default=0.95,
help='Beta2 for Adam optimizer.')
group.add_argument('--warmup-ratio', type=float, default=0.03,
help='Linear warmup over warmup_ratio fraction of total steps.')
group.add_argument('--dataloader-num-workers', type=int, default=4,
help='Number of subprocesses to use for data loading.')
group.add_argument('--dataloader-prefetch-factor', type=int, default=4,
help='Number of batches loaded in advance by each worker.')
group.add_argument('--deepspeed', type=str, default="ds_config_zero2.json",
help='Path to deepspeed config file.')
group.add_argument('--output-dir', type=str, default=None,
help='Path to save the finetuned model.')
# LoRA parameters
group.add_argument('--lora-r', type=int, default=64,
help='Rank of the LoRA update matrices.')
group.add_argument('--lora-alpha', type=int, default=32,
help='Scaling factor for the LoRA update.')
group.add_argument('--lora-target-modules', nargs='+', default=["q_proj", "k_proj", "v_proj", "o_proj"],
help='List of module names to apply LoRA to.')
group.add_argument('--lora-dropout', type=float, default=0.1,
help='Dropout probability for LoRA layers.')
return parser
def _add_training_args1(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '
' <batch size incerement> '
' <ramp-up samples> '
'For example:'
' --rampup-batch-size 16 8 300000 '
' --global-batch-size 1024'
'will start with global batch size 16 and over '
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--lr-scheduler-type ', type=str, default="cosine",
help='The learning rate scheduler to use.')
return parser
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--batch-size', type=int, default=None,
help='Old batch size parameter, do not use. '
'Use --micro-batch-size instead')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '
' <batch size incerement> '
' <ramp-up samples> '
'For example:'
' --rampup-batch-size 16 8 300000 \ '
' --global-batch-size 1024'
'will start with global batch size 16 and over '
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--recompute-activations', action='store_true',
help='recompute activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--recompute-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.')
group.add_argument('--no-check-for-nan-in-loss-and-grad', action='store_false',
help='Check for NaNs in loss and grad',
dest='check_for_nan_in_loss_and_grad')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute recomputed activations '
'across model parallel group.')
group.add_argument('--recompute-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and recompute the input activation of '
'each divided chunk at specified granularity, '
'2) recompute the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any recomputing at specified granularity'
'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=None,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to recompute within each pipeline stage.')
group.add_argument('--no-clone-scatter-output-in-embedding', action='store_false',
help='If not set, clone the output of the scatter in embedding layer to GC original tensor.',
dest='clone_scatter_output_in_embedding')
group.add_argument('--profile', action='store_true',
help='Enable nsys profiling. When using this option, nsys '
'options should be specified in commandline. An example '
'nsys commandline is `nsys profile -s none -t nvtx,cuda '
'-o <path/to/output_file> --force-overwrite true '
'--capture-range=cudaProfilerApi '
'--capture-range-end=stop`.')
group.add_argument('--profile-step-start', type=int, default=10,
help='Global step to start profiling.')
group.add_argument('--profile-step-end', type=int, default=12,
help='Global step to stop profiling.')
group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
help='Global ranks to profile.')
group.add_argument('--tp-comm-overlap', action='store_true', help = 'Enables the '
' overlap of Tensor parallel communication and GEMM kernels.')
group.add_argument('--tp-comm-overlap-cfg', type=str, default=None,
help = 'Config file when tp_comm_overlap is enabled.')
group.add_argument('--disable-tp-comm-split-ag', action='store_false',
help = 'Disables the All-Gather overlap with fprop GEMM.',
dest='tp_comm_split_ag')
group.add_argument('--disable-tp-comm-split-rs', action='store_false',
help = 'Disables the Reduce-Scatter overlap with fprop GEMM.',
dest='tp_comm_split_rs')
group.add_argument('--disable-tp-comm-bulk-dgrad', action='store_false',
help = 'Disables the All-Gather overlap with bprop activation gradient GEMM.',
dest='tp_comm_bulk_dgrad')
group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false',
help = 'Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.',
dest='tp_comm_bulk_wgrad')
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--exit-signal-handler', action='store_true',
help='Dynamically save the checkpoint and shutdown the '
'training if SIGTERM is received')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
'masking, and softmax.',
dest='masked_softmax_fusion')
group.add_argument('--no-bias-gelu-fusion', action='store_false',
help='Disable bias and gelu fusion.',
dest='bias_gelu_fusion')
group.add_argument('--no-bias-dropout-fusion', action='store_false',
help='Disable bias and dropout fusion.',
dest='bias_dropout_fusion')
group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_false',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
group.add_argument('--use-mcore-models', action='store_true',
help='Use the implementation from megatron core')
group.add_argument('--manual-gc', action='store_true',
help='Disable the threshold-based default garbage '
'collector and trigger the garbage collection manually. '
'Manual garbage collection helps to align the timing of '
'the collection across ranks which mitigates the impact '
'of CPU-associated jitters. When the manual gc is enabled, '
'garbage collection is performed only at the start and the '
'end of the validation routine by default.')
group.add_argument('--manual-gc-interval', type=int, default=0,
help='Training step interval to trigger manual garbage '
'collection. When the value is set to 0, garbage '
'collection is not triggered between training steps.')
group.add_argument('--no-manual-gc-eval', action='store_false',
help='When using manual garbage collection, disable '
'garbage collection at the start and the end of each '
'evaluation run.', dest='manual_gc_eval')
return parser
def _add_initialization_args(parser):
group = parser.add_argument_group(title='initialization')
group.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
group.add_argument('--data-parallel-random-init', action='store_true',
help='Enable random initialization of params '
'across data parallel ranks')
group.add_argument('--init-method-std', type=float, default=0.02,
help='Standard deviation of the zero mean normal '
'distribution used for weight initialization.')
group.add_argument('--init-method-xavier-uniform', action='store_true',
help='Enable Xavier uniform parameter initialization')
return parser
def _add_learning_rate_args(parser):
group = parser.add_argument_group(title='learning rate')
group.add_argument('--lr', type=float, default=None,
help='Initial learning rate. Depending on decay style '
'and initial warmup, the learing rate at each '
'iteration would be different.')
group.add_argument('--lr-decay-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine', 'inverse-square-root', 'wsd'],
help='Learning rate decay function.')
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over,'
' If None defaults to `--train-iters`')
group.add_argument('--lr-decay-samples', type=int, default=None,
help='number of samples to decay learning rate over,'
' If None defaults to `--train-samples`')
group.add_argument('--lr-warmup-fraction', type=float, default=None,
help='fraction of lr-warmup-(iters/samples) to use '
'for warmup (as a float)')
group.add_argument('--lr-warmup-iters', type=int, default=0,
help='number of iterations to linearly warmup '
'learning rate over.')
group.add_argument('--lr-warmup-samples', type=int, default=0,
help='number of samples to linearly warmup '
'learning rate over.')
group.add_argument('--lr-warmup-init', type=float, default=0.0,
help='Initial value for learning rate warmup. The '
'scheduler starts warmup from this value.')
group.add_argument('--warmup', type=int, default=None,
help='Old lr warmup argument, do not use. Use one of the'
'--lr-warmup-* arguments above')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minumum value for learning rate. The scheduler'
'clip values below this threshold.')
group.add_argument('--override-opt_param-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style from input '
'arguments and ignore values from checkpoints. Note'
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
'from checkpoint and ignore input arguments.')
group.add_argument('--wsd_decay_ratio', type=float, default=0.1,
help='used in wsd, usewsd_decay_ratio only works when lr_stable_steps > 0')
group.add_argument('--wsd_half_life', type=int, default=-1,
help='used in wsd, wsd_half_life == -1: use the default value(0.5 * (lr_decay_steps - lr_stable_steps) + 1)')
group.add_argument('--lr_stable_steps', type=int, default=-1,
help='used in wsd, lr_stable_steps == -1: use the default value(lr_decay_steps / (1 + wsd_decay_ratio)))'
'lr_stable_steps == -2: use the lr_decay_steps value, no decay stage')
return parser
def _add_checkpointing_args(parser):
group = parser.add_argument_group(title='checkpointing')
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
group.add_argument('--save-interval', type=int, default=None,
help='Number of iterations between checkpoint saves.')
group.add_argument('--no-save-optim', action='store_true', default=None,
help='Do not save current optimizer.')
group.add_argument('--no-save-rng', action='store_true', default=None,
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
'Assumed when loading a release checkpoint.')
group.add_argument('--no-initialization', action='store_false',
help='Do not perform initialization when building model, '
'can reduce startup time when definitely loading from a '
'checkpoint',
dest='perform_initialization')
group.add_argument('--use-checkpoint-args', action='store_true',
help='Override any command line arguments with arguments '
'from the checkpoint')
group.add_argument('--exit-on-missing-checkpoint', action='store_true',
help="If '--load' is set, but checkpoint is not found "
"(e.g., path typo), then exit instead of random "
"initialization.")
group.add_argument('--overwrite-iteration', type=str, default=None,
help='overwrite the iteration number to load, None means load from latest.')
return parser
def _add_mixed_precision_args(parser):
group = parser.add_argument_group(title='mixed precision')
group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic'
'loss scaling is used.')
group.add_argument('--initial-loss-scale', type=float, default=2**32,
help='Initial loss-scale for dynamic loss scaling.')
group.add_argument('--min-loss-scale', type=float, default=1.0,
help='Minimum loss scale for dynamic loss scale.')
group.add_argument('--loss-scale-window', type=float, default=1000,
help='Window over which to raise/lower dynamic scale.')
group.add_argument('--hysteresis', type=int, default=2,
help='hysteresis for dynamic loss scaling')
group.add_argument('--fp32-residual-connection', action='store_true',
help='Move residual connections to fp32.')
group.add_argument('--apply-query-key-layer-scaling', action='store_true',
help='Scale Q * K^T by 1 / layer-number. '
'Useful for fp16 training.')
group.add_argument('--attention-softmax-in-fp32', action='store_true',
help='Run attention masking and softmax in fp32. '
'This flag is ignored unless '
'--no-query-key-layer-scaling is specified.')
group.add_argument('--accumulate-allreduce-grads-in-fp32',
action='store_true',
help='Gradient accumulation and all-reduce in fp32.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
return parser
def _add_distributed_args(parser):
group = parser.add_argument_group(title='distributed')
group.add_argument('--tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--no-overlap-p2p-communication', action='store_false',
help='overlap pipeline parallel communication with forward and backward chunks',
dest='overlap_p2p_comm')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
group.add_argument('--overlap-grad-reduce', action='store_true',
default=False, help='If set, overlap DDP grad reduce.')
group.add_argument('--no-delay-grad-reduce', action='store_false',
help='If not set, delay / synchronize grad reductions in all but first PP stage.',
dest='delay_grad_reduce')
group.add_argument('--overlap-param-gather', action='store_true',
default=False, help='If set, overlap param all-gather in distributed optimizer.')
group.add_argument('--delay-param-gather', action='store_true',
default=False, help='If set, delay / synchronize param all-gathers in all but first PP stage.')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='If not set, use scatter/gather to optimize communication of tensors in pipeline.',
dest='scatter_gather_tensors_in_pipeline')
group.add_argument('--use-ring-exchange-p2p', action='store_true',
default=False, help='If set, use custom-built ring exchange '
'for p2p communications. Note that this option will require '
'a custom built image that support ring-exchange p2p.')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher.')
group.add_argument('--lazy-mpu-init', type=bool, required=False,
help='If set to True, initialize_megatron() '
'skips DDP initialization and returns function to '
'complete it instead.Also turns on '
'--use-cpu-initialization flag. This is for '
'external DDP manager.' )
group.add_argument('--use-cpu-initialization', action='store_true',
default=None, help='If set, affine parallel weights '
'initialization uses CPU' )
group.add_argument('--empty-unused-memory-level', default=0, type=int,
choices=[0, 1, 2],
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--standalone-embedding-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')
group.add_argument('--expert-model-parallel-size', type=int, default=1,
help='Degree of expert model parallelism.')
group.add_argument('--context-parallel-size', type=int, default=1,
help='Degree of context parallelism.')
group.add_argument('--nccl-communicator-config-path', type=str, default=None,
help='Path to the yaml file with NCCL communicator '
'configurations. The number of min/max thread groups and thread '
'group cluster size of each communicator can be configured by '
'setting `min_ctas`, `max_ctas`, and `cga_cluster_size`.')
return parser
def _add_validation_args(parser):
group = parser.add_argument_group(title='validation')
group.add_argument('--eval-iters', type=int, default=100,
help='Number of iterations to run for evaluation'
'validation/test for.')
group.add_argument('--eval-interval', type=int, default=1000,
help='Interval between running evaluation on '
'validation set.')
group.add_argument('--skip-train', action='store_true',
default=False, help='If set, bypass the training loop, '
'optionally do evaluation for validation/test, and exit.')
return parser
def _add_data_args(parser):
group = parser.add_argument_group(title='data and dataloader')
group.add_argument('--data-path', nargs='*', default=None,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ... It is used with --split when a '
'single dataset used for all three: train, valid '
'and test. It is exclusive to the other '
'--*-data-path args')
group.add_argument('--split', type=str, default='960, 30, 10',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
'`90,5,5` will use 90%% of data for training, 5%% for '
'validation and 5%% for test.')
group.add_argument('--train-data-path', nargs='*', default=None,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--valid-data-path', nargs='*', default=None,
help='Path to the validation dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--test-data-path', nargs='*', default=None,
help='Path to the test dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--data-cache-path', default=None,
help='Path to a directory to hold cached index files.')
group.add_argument('--vocab-size', type=int, default=None,
help='Size of vocab before EOD or padding.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--vocab-extra-ids', type=int, default=0,
help='Number of additional vocabulary tokens. '
'They are used for span masking in the T5 model')
group.add_argument('--seq-length', type=int, default=None,
help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
help='Maximum encoder sequence length to process.'
'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.")
group.add_argument('--retriever-seq-length', type=int, default=256,
help='Maximum sequence length for the biencoder model '
'for retriever')
group.add_argument('--sample-rate', type=float, default=1.0,
help='sample rate for training data. Supposed to be 0 '
' < sample_rate < 1')
group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.')
group.add_argument('--short-seq-prob', type=float, default=0.1,
help='Probability of producing a short sequence.')
group.add_argument('--num-workers', type=int, default=2,
help="Dataloader number of workers.")
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer',
'SentencePieceTokenizer',
'GPTSentencePieceTokenizer',
'Llama2Tokenizer',
'NullTokenizer',
'MMSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='Sentencepiece tokenizer model.')
group.add_argument('--reset-position-ids', action='store_true',
help='Reset posistion ids after end-of-document token.')
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention maske after '
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.')
group.add_argument('--enable-shuffle', action='store_true',
help='Enable shuffle of the data')
return parser
def _add_autoresume_args(parser):
group = parser.add_argument_group(title='autoresume')
group.add_argument('--adlr-autoresume', action='store_true',
help='Enable autoresume on adlr cluster.')
group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
help='Intervals over which check for autoresume'
'termination signal')
return parser
def _add_biencoder_args(parser):
group = parser.add_argument_group(title='biencoder')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and '
'REALM (paper default: 128)')
group.add_argument('--biencoder-projection-dim', type=int, default=0,
help='Size of projection head used in biencoder (paper'
' default: 128)')
group.add_argument('--biencoder-shared-query-context-model', action='store_true',
help='Whether to share the parameters of the query '
'and context models or not')
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing an BertModel checkpoint '
'(needed to start ICT and REALM)')
# data
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for '
'ICT dataset')
group.add_argument('--use-one-sent-docs', action='store_true',
help='Whether to use one sentence documents in ICT')
group.add_argument('--evidence-data-path', type=str, default=None,
help='Path to Wikipedia Evidence frm DPR paper')
# training
group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
default=[], help="Which top-k accuracies to report "
"(e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true',
help='Whether to scale retriever scores by inverse '
'square root of hidden size')
# faiss index
group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from')
group.add_argument('--embedding-path', type=str, default=None,
help='Where to save/load Open-Retrieval Embedding'
' data to/from')
# indexer
group.add_argument('--indexer-batch-size', type=int, default=128,
help='How large of batches to use when doing indexing '
'jobs')
group.add_argument('--indexer-log-interval', type=int, default=1000,
help='After how many batches should the indexer '
'report progress')
return parser
def _add_vision_args(parser):
    """Register the "vision" argument group on ``parser``.

    The group covers four families of options: general vision
    classification settings, pretraining objective / backbone selection,
    inpainting masks, and dino.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title="vision")

    # general vision arguments
    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classification task')
    group.add_argument('--img-h', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-w', type=int, default=224,
                       help='Image width for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension')
    group.add_argument('--classes-fraction', type=float, default=1.0,
                       help='training with fraction of classes.')
    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                       help='training with fraction of data per class.')
    # store_false: ``data_sharding`` defaults to True and the flag turns it off.
    group.add_argument('--no-data-sharding', action='store_false',
                       help='Disable data sharding.',
                       dest='data_sharding')
    group.add_argument('--head-lr-mult', type=float, default=1.0,
                       help='learning rate multiplier for head during finetuning')

    # pretraining type and backbone selection
    group.add_argument('--vision-pretraining', action='store_true',
                       help='flag to indicate vision pretraining')
    group.add_argument('--vision-pretraining-type', type=str, default='classify',
                       choices=['classify', 'inpaint', 'dino'],
                       help='pretraining objectives')
    group.add_argument('--vision-backbone-type', type=str, default='vit',
                       choices=['vit', 'mit', 'swin'],
                       help='backbone types')
    group.add_argument('--swin-backbone-type', type=str, default='tiny',
                       choices=['tiny', 'base', 'h3'],
                       help='swin backbone types')

    # inpainting arguments
    group.add_argument('--mask-type', type=str, default='random',
                       choices=['random', 'row'],
                       help='mask types')
    group.add_argument('--mask-factor', type=float, default=1.0,
                       help='mask size scaling parameter')

    # dino arguments
    group.add_argument('--iter-per-epoch', type=int, default=1250,
                       help='iterations per epoch')
    group.add_argument('--dino-local-img-size', type=int, default=96,
                       help='Local image size for dino')
    group.add_argument('--dino-local-crops-number', type=int, default=10,
                       help='Number of local crops')
    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
                       help='Hidden dimension size in dino head')
    group.add_argument('--dino-bottleneck-size', type=int, default=256,
                       help='Bottle neck dimension in dino head ')
    # The default is the int 1: argparse applies ``type`` only to string
    # values, so the default is stored as given.
    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
                       help='Freezing last layer weights')
    # NOTE(review): this is a store_true flag, so passing it sets
    # ``dino_norm_last_layer`` to True, yet the help text says "Disable".
    # Wording left unchanged -- confirm the intent against the consumer.
    group.add_argument('--dino-norm-last-layer', action='store_true',
                       help='Disable Norm in last layer.')
    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
                       help='warmup teacher temperature')
    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
                       help='teacher temperature')
    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
                       help='warmup teacher temperature epochs')

    return parser
def _add_experimental_args(parser):
    """Register the "experimental" argument group on ``parser``.

    The group holds a single option, ``--spec``, which takes exactly two
    string values and defaults to ``None``.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object that was passed in.
    """
    spec_help = (
        'Specify the <module_location function_name> pair '
        'that returns a spec to customize a model, transformer '
        'block, or transformer layer, depending on the use case. '
        'For more details, see the model class, '
        '`transformer_block.py`, or `transformer_layer.py`'
    )

    experimental = parser.add_argument_group(title='experimental')
    experimental.add_argument(
        '--spec',
        nargs=2,
        type=str,
        default=None,
        help=spec_help,
    )
    return parser