import argparse import json import logging import os from typing import Any import yaml from sglang_router.launch_router import RouterArgs from transformers import AutoConfig from slime.backends.sglang_utils.arguments import add_sglang_arguments from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args from slime.utils.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list from slime.utils.logging_utils import configure_logger logger = logging.getLogger(__name__) def reset_arg(parser, name, **kwargs): """ Reset the default value of a Megatron argument. :param parser: The argument parser. :param name: The name of the argument to reset. :param default: The new default value. """ for action in parser._actions: if name in action.option_strings: if "default" in kwargs: action.default = kwargs["default"] break else: parser.add_argument(name, **kwargs) def get_slime_extra_args_provider(add_custom_arguments=None): def add_slime_arguments(parser): # Ray def add_cluster_arguments(parser): parser.add_argument("--actor-num-nodes", type=int, default=1, help="Number of nodes for training actor") parser.add_argument( "--actor-num-gpus-per-node", type=int, default=8, help="Number of gpus per node for training actor" ) parser.add_argument( "--critic-num-nodes", type=int, default=None, help="Number of nodes for training actor" ) parser.add_argument( "--critic-num-gpus-per-node", type=int, default=None, help="Number of gpus per node for training actor" ) parser.add_argument( "--rollout-num-gpus", type=int, default=None, help=( "Number of GPUs for inference. Note that when using --colocate, " "i.e. the training and the inference engines are on the same gpus, this param will be ignored and will be set as " "actor_num_gpus_per_node * actor_num_nodes." ), ) parser.add_argument( "--rollout-num-gpus-per-engine", type=int, default=1, help="Number of GPUs per inference engine, just like the tp_size in sglang.", ) parser.add_argument( "--num-gpus-per-node", type=int, default=8, help=( "Number of gpus per node for rollout." "Notice: If you are going to use less than 8 gpus per node under colocate mode, you should set this number." ), ) parser.add_argument( "--colocate", action="store_true", default=False, help=( "Whether to colocate the inference engines and the actor. " "Turning this on will also set --offload to true." ), ) parser.add_argument( "--offload", action="store_true", default=False, help=("Equivalent to --offload-train + --offload-rollout. "), ) parser.add_argument( "--offload-train", action=argparse.BooleanOptionalAction, help=( "Whether to offload the training actor to CPU during training. " "This will always be true when --colocate is set." ), ) parser.add_argument( "--offload-rollout", action=argparse.BooleanOptionalAction, help=( "Whether to offload the rollout generator to CPU during training. " "This will always be true when --colocate is set." ), ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) return parser def add_train_arguments(parser): parser.add_argument( "--train-backend", type=str, choices=["megatron", "fsdp"], default="megatron", help="The backend for training.", ) parser.add_argument( "--qkv-format", type=str, choices=["thd", "bshd"], default="thd", help="The qkv layout for Megatron backend.", ) parser.add_argument( "--true-on-policy-mode", action="store_true", default=False, help="Whether to enable true-on-policy mode.", ) parser.add_argument( "--train-env-vars", type=json.loads, default="{}", help="Extra environment variables for training process, e.g. PyTorch memory management ones.", ) parser.add_argument( "--train-memory-margin-bytes", type=int, default=1024**3, help="Add margin for train memory allocation. By default we will reserve 1GB as margin.", ) parser.add_argument( "--disable-weights-backuper", action="store_false", dest="enable_weights_backuper", help="Whether to disable weights backuper to save host memory.", ) parser.add_argument( "--megatron-to-hf-mode", choices=["raw", "bridge"], default="raw", help="The method to convert megatron weights to hugging face weights for SGLang.", ) parser.add_argument( "--custom-model-provider-path", type=str, default=None, help=( "Path to a custom model provider function. " "If set, we will use this function instead of the default model provider. " "The function should have the signature " "`def custom_model_provider(pre_process: bool, post_process: bool, vp_stage: int | None = None) -> GPTModel`. " "Example: 'my_module.my_model_provider'." ), ) parser.add_argument( "--recompute-loss-function", action="store_true", help="Whether to disable recompute loss function to save memory during training.", ) parser.add_argument( "--log-probs-chunk-size", type=int, default=-1, help="Chunk size to compute log probs to save memory" ) parser.add_argument( "--only-train-params-name-list", type=str, nargs="*", default=None, help="""List of regex patterns of parameter names to TRAIN. All other parameters will be FROZEN. Supports Python regex syntax (re.search). Examples: 1. Train ONLY MoE experts: --only-train-params-name-list experts 2. Train ONLY Indexer parameters: --only-train-params-name-list self_attention.wq_b self_attention.wk self_attention.k_norm self_attention.weights_proj 3. Train ONLY Layer 20 to 23: --only-train-params-name-list layers\.2[0-3]\. """, ) parser.add_argument( "--freeze-params-name-list", type=str, nargs="*", default=None, help="""List of regex patterns of parameter names to FREEZE. Other parameters will remain trainable. Supports Python regex syntax (re.search). Examples: 1. Freeze Embeddings and Output Layer (common for fine-tuning): --freeze-params-name-list embedding output_layer 2. Freeze Indexer parameters: --freeze-params-name-list self_attention.wq_b self_attention.wk self_attention.k_norm self_attention.weights_proj 3. Freeze specific projection layers (e.g., all Gate/Up projections): --freeze-params-name-list linear_fc1 """, ) return parser # rollout def add_rollout_arguments(parser): parser.add_argument( "--hf-checkpoint", type=str, default=None, help=( "The huggingface checkpoint of the trained model. " "This is used to initialize sglang and also provide the tokenizer. " "Note that, we will always update the parameters in sglang with that of megatron before training, " "so you only need to provide a huggingface checkpoint that has the same architecture as the model you want to train. " "It doesn't necessary need to contain the most up-to-date parameters." ), ) parser.add_argument( "--model-name", type=str, default=None, help=( "The name of the model, this is used to convert the megatron weights into huggingface format. " "If not set, we will use `type(AutoConfig.from_pretrained(args.hf_checkpoint)).__name__.lower()` as model_name. " "Also, sometimes this will help alleviate the bug that transformers cannot find certain model." ), ) parser.add_argument( "--rollout-function-path", type=str, default="slime.rollout.sglang_rollout.generate_rollout", help=( "Path to the rollout generation function." "You should use this model to create your own custom rollout function, " "and then set this to the path of your custom rollout function. " "The signature of the function should be " "`def generate_rollout(args, rollout_id, *, evaluation=False) -> list[list[Sample]]`" "and within the output sample, you should at least set `tokens`, `response_length`, `reward` " "and `truncated`." ), ) parser.add_argument( "--rollout-temperature", type=float, default=1.0, help="the temperature for the inference engine during rollout.", ) parser.add_argument( "--rollout-top-p", type=float, default=1.0, help="the top-p for the inference engine during rollout." ) parser.add_argument( "--rollout-top-k", type=int, default=-1, help="the top-k for the inference engine during rollout." ) parser.add_argument( "--rollout-max-context-len", type=int, default=None, help=( "The maximum context size for the inference engine during rollout." "It should no exceed the `max_position_embeddinds` in Huggingface model's `config.json`" ), ) parser.add_argument( "--rollout-max-prompt-len", type=int, default=None, help=( "The maximum length of the prompt for the inference engine during rollout. " "If set, we will filter out the long prompts during initialization of the global dataset. " "This is not recommended if the dataset is large." ), ) parser.add_argument( "--rollout-max-response-len", type=int, default=None, help=( "The maximum length of the response for the inference engine during rollout. " "It is basically `max_tokens` in sglang." ), ) parser.add_argument( "--rollout-skip-special-tokens", action="store_true", default=False, help=( "Whether to skip special tokens in the response during rollout. " "This is useful when you want to use the response as a prompt for the next rollout." ), ) parser.add_argument( "--rollout-stop", type=str, nargs="+", default=None, help=( "The stop words for the inference engine during rollout. " "It can be a list of strings or a single string. " "It may be hard to pass special tokens in command line, in that case rollout_stop_token_ids can be used." ), ) parser.add_argument( "--rollout-stop-token-ids", type=int, nargs="+", default=None, help=( "The stop token ids for the inference engine during rollout. " "It can be a list of integers or a single integer." ), ) parser.add_argument( "--rollout-shuffle", action="store_true", default=False, help=("Whether to shuffle the prompts during rollout."), ) parser.add_argument( "--rollout-seed", type=int, default=42, help=( "The seed for the random number generator during rollout. " "This is used to shuffle the prompts and also for the random sampling of the prompts." ), ) # sampling parser.add_argument( "--over-sampling-batch-size", type=int, default=None, help=( "This defines the granularity of the sampling batch in the rollout function. " "When the number of available samples falls below the target, a sampling " "operation of size over_sampling_batch_size will be triggered." "Regardless of whether partial rollout is used or filters are applied, " "the sampling granularity is always determined by this value. " "If this value is None, rollout_batch_size will be used as the default over_sampling_batch_size." ), ) parser.add_argument( "--dynamic-sampling-filter-path", type=str, default=None, help=( "This is the filter function for dynamic sampling. " "It should be able to judge whether the result of a prompt should be selected or not." "We will do dynamic filter for sampling as in DAPO. e.g. not all correct or all wrong samples." "You could use `slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std` as an example." ), ) # partial rollout parser.add_argument( "--partial-rollout", action="store_true", default=False, help=( "Whether to use partial rollout. " "If set, the unfinished samples during dynamic sampling will be recycled back to data buffer. " "This is useful for long responses." ), ) parser.add_argument( "--mask-offpolicy-in-partial-rollout", action="store_true", default=False, help=( "Whether to mask previous generation in partial rollout. " "If set, only on-policy generated tokens will be used in training" ), ) parser.add_argument( "--custom-generate-function-path", type=str, default=None, help=( "Only substitue the `def generate(args, sample, sampling_params)` function within the example rollout function. " "This should be useful if you need to implement some special rollout logic, e.g. multi-turn, function calling." ), ) parser.add_argument( "--custom-rollout-log-function-path", type=str, default=None, help=( "The custom function for logging rollout data. The signature of the functions is: " "def log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_time) -> bool. " "The return value indicates whether to skip the default logging. " ), ) parser.add_argument( "--custom-eval-rollout-log-function-path", type=str, default=None, help=( "The custom function for logging eval rollout data. " "def log_eval_rollout_data(rollout_id, args, data, extra_metrics) -> bool. " "The return value indicates whether to skip the default logging. " ), ) parser.add_argument( "--buffer-filter-path", type=str, default=None, help=( "Path to the buffer filter function. " "It should be able to select the samples in the buffer. " "The function should take list[list[Sample]] and return list[list[Sample]]." ), ) # update weight parser.add_argument( "--update-weight-buffer-size", type=int, default=512 * 1024**2, help=( "buffer size for update weight, in bytes. " "This is used for updating weights by chunk and should be useful for MoE models." ), ) parser.add_argument( "--update-weights-interval", type=int, default=1, help="Interval for updating the weights", ) parser.add_argument( "--keep-old-actor", action="store_true", help="Whether to keep the rollout model on training process", ) parser.add_argument( "--rollout-data-postprocess-path", type=str, default=None, help=( "The called after we have all the rollout data including log_probs. " "It may be helpful for updating loss mask." ), ) parser.add_argument( "--rollout-external", action="store_true", default=False, help="Use external SGLang instances instead of launching them inside the framework.", ) parser.add_argument( "--rollout-external-engine-addrs", type=str, default=None, nargs="+", help="Address and ports of the external engines.", ) return parser def add_fault_tolerance_arguments(parser): parser.add_argument( "--use-fault-tolerance", action="store_true", default=False, help="Whether to enable the fault tolerance function during rollout.", ) parser.add_argument( "--rollout-health-check-interval", type=float, default=30.0, help="Interval in seconds between rollout engine /health_generate checks during generate/eval.", ) parser.add_argument( "--rollout-health-check-timeout", type=float, default=30.0, help="Timeout in seconds to wait for a rollout engine /health_generate response before killing it.", ) parser.add_argument( "--rollout-health-check-first-wait", type=float, default=0, help="Initial grace period (in seconds) before starting health checks. This allows time for model compilation and initialization. Increase this value significantly when using deepgemm.", ) return parser # data def add_data_arguments(parser): # dataset # TODO: maybe add an num_epoch and calculate the num_rollout from buffer parser.add_argument( "--num-rollout", type=int, default=None, help="Number of rollout steps. If not set, we will calculate the number of rollout steps from the dataset size.", ) parser.add_argument( "--num-epoch", type=int, default=None, help=( "Number of epochs for the training. " "This is used to calculate the number of rollout steps from the dataset size. " "If set, we will calculate the number of rollout steps as `num_rollout = num_epoch * dataset_size // rollout_batch_size`." "If both `--num-epoch` and `--num-rollout` are set, `--num-epoch` will be ignored." ), ) parser.add_argument( "--disable-rollout-global-dataset", action="store_false", dest="rollout_global_dataset", help=( "Whether to use a global dataset for rollout. " "If set, the rollout will use the `--prompt-data` as the prompt dataset, " "and the prompts for rollout will be sampled from the dataset. " "If not set, you need to manage the data by your self." ), ) parser.add_argument( "--data-source-path", type=str, default="slime.rollout.data_source.RolloutDataSourceWithBuffer", help="The data source class for rollout data.", ) parser.add_argument( "--prompt-data", type=str, default=None, help=( "The path to the prompt data. " "Currently we only support jsonl format, and each line should contains --input-key and --label-key, " "which will be used as the prompt and the label respectively. " "If you want to use a custom template, you can set --apply-chat-template to true, in that case, " "the input should be the same structure as an openai message, e.g. [{'role': 'user', 'content': 'blabla'}]. " ), ) parser.add_argument("--apply-chat-template", action="store_true", default=False) # Temporarily be JSON-serialized str, will be a real dict after using Omegaconf parser.add_argument("--apply-chat-template-kwargs", type=json.loads, default="{}") parser.add_argument("--input-key", type=str, default="input", help="JSON dataset key") parser.add_argument("--label-key", type=str, default=None, help="JSON dataset key") parser.add_argument( "--multimodal-keys", type=json.loads, default=None, help=( 'JSON string for multimodal data mapping media types to data keys. Example: \'{"image": "image_file"}\'' ), ) parser.add_argument("--metadata-key", type=str, default="metadata", help="JSON dataset key") parser.add_argument( "--tool-key", type=str, default="tools", help=( "When need to add tools during apply_chat_template, you should provide the key for the tools in the prompt dataset." ), ) parser.add_argument( "--start-rollout-id", type=int, default=None, help=( "The starting rollout step, if not set, will try to load the step from --load when doing continue training, " "otherwise will be set to 0, meaning training from start." ), ) # batch sizes parser.add_argument( "--rollout-batch-size", type=int, required=True, help=( "The number of prompts in each rollout step. " "The total data returned should be rollout_batch_size * n_samples_per_prompt. " ), ) parser.add_argument( "--n-samples-per-prompt", type=int, default=1, help="Number of responses for each prompt in generation" ) # gbs of the training, note that the gbs is of sample, not of prompts, # so if you hope to train 1 step for each rollout, the global_bach_size should be set as # `rollout_batch_size * n_samples_per_prompt`. reset_arg(parser, "--global-batch-size", type=int, default=None) parser.add_argument( "--num-steps-per-rollout", type=int, default=None, help=( "Number of steps per rollout, e.g. It is equivalent to setting gbs as " "`rollout_batch_size * n_samples_per_prompt // num_steps_per_rollout`." ), ) # mbs for the training, will be ignored if `use_dynamic_batch_size` is set. reset_arg(parser, "--micro-batch-size", type=int, default=1) parser.add_argument( "--balance-data", action="store_true", default=False, help=( "Balance the number of tokens between data parallel ranks with `karmarkar_karp` for verl. " "Note that this may allocate the different response of the same prompt into different training steps." ), ) parser.add_argument( "--use-dynamic-batch-size", action="store_true", default=False, help=( "Because the sample length varies, to maximize the GPU utilization, " "we will use the dynamic batch size to adjust the micro batch size according to the maximum number of tokens each gpu can run. " "For example, if we have 3 samples, with the length of 100, 200, and 300, and the max_tokens_per_gpu is 300, when enabling " "dynamic batch size, slime will make 2 micro batches, i.e. [100, 200], [300]." ), ) parser.add_argument( "--max-tokens-per-gpu", type=int, default=None, help=( "The maximum number of tokens per GPU for dynamic batch size. " "Note that when enabling context parallel (CP), the max tokens per gpu should be around " "`max_response_len // cp_size` instead of `max_response_len`." ), ) parser.add_argument( "--log-probs-max-tokens-per-gpu", type=int, default=None, help=( "The maximum number of tokens per GPU for calculating log probs. " "This is used to calculate the log probs of the responses during rollout, " "and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. " ), ) return parser def add_eval_arguments(parser): parser.add_argument( "--eval-function-path", type=str, default=None, help=( "Path to the eval generation function." "If not set, we will use rollout_function_path as the default. " ), ) # change the default value of eval_interval from Megatron to None reset_arg(parser, "--eval-interval", type=int, default=None) parser.add_argument( "--eval-prompt-data", type=str, default=None, nargs="+", help=( "Path to the evaluation prompt data, " "should first input the name of the eval dataset and then the path, e.g. " "aime /path/to/aime.jsonl" ), ) parser.add_argument( "--eval-config", type=str, default=None, help=( "Path to an OmegaConf YAML/JSON file describing evaluation datasets. " "When provided, this overrides --eval-prompt-data." ), ) parser.add_argument( "--skip-eval-before-train", action="store_true", default=False, help="Whether to skip evaluation before training.", ) # The following keys are used to override the rollout version during eval. parser.add_argument("--eval-input-key", type=str, default=None, help="JSON dataset key") parser.add_argument("--eval-label-key", type=str, default=None, help="JSON dataset key") parser.add_argument("--eval-tool-key", type=str, default=None, help="JSON dataset key") parser.add_argument( "--n-samples-per-eval-prompt", type=int, default=1, help="number of responses for each prompt in generation", ) parser.add_argument("--eval-temperature", type=float, default=None) parser.add_argument("--eval-top-p", type=float, default=None) parser.add_argument("--eval-top-k", type=int, default=None) parser.add_argument("--eval-max-response-len", type=int, default=None) parser.add_argument("--eval-max-prompt-len", type=int, default=None) parser.add_argument("--eval-min-new-tokens", type=int, default=None) parser.add_argument("--eval-max-context-len", type=int, default=None) return parser def add_algo_arguments(parser): parser.add_argument( "--ref-load", type=str, default=None, help=( "The checkpoint for reference model. " "When --load is not set, this will be used as the initial checkpoint for training. " ), ) parser.add_argument( "--ref-ckpt-step", type=int, default=None, help="The checkpoint step for reference model. " ) reset_arg(parser, "--load", type=str, default=None) reset_arg(parser, "--save", type=str, default=None) reset_arg(parser, "--save-interval", type=int, default=None) reset_arg(parser, "--async-save", action="store_true") reset_arg( parser, "--no-save-optim", action="store_true", default=False, help=( "If set, do not save the optimizer state when saving checkpoints. " "This reduces checkpoint size but disables training resumption from the saved checkpoint." ), ) parser.add_argument( "--save-hf", type=str, default=None, help=( "Path to save the model in HuggingFace format when using Megatron backend. " "The model will be saved to `save_hf.format(rollout_id)`. " ), ) reset_arg(parser, "--seed", type=int, default=1234) reset_arg(parser, "--clip-grad", type=float, default=1.0) reset_arg(parser, "--calculate-per-token-loss", action="store_true") reset_arg(parser, "--lr", type=float, default=1e-6) parser.add_argument("--num-critic-only-steps", type=int, default=0, help="Number of critic only steps") parser.add_argument("--critic-load", type=str, default=None, help="The checkpoint for critic model.") parser.add_argument("--critic-save", type=str, default=None, help="The checkpoint for critic model.") parser.add_argument("--critic-lr", type=float, default=None, help="The lr for critic model") parser.add_argument( "--critic-lr-warmup-iters", type=int, default=0, help="number of iterations to linearly warmup for critic model.", ) parser.add_argument("--eps-clip", type=float, default=0.2, help="PPO clip range") parser.add_argument("--eps-clip-high", type=float, default=None, help="PPO clip upper range") parser.add_argument( "--eps-clip-c", type=float, default=None, help="lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729", ) parser.add_argument("--value-clip", type=float, default=0.2, help="the clip for value loss") parser.add_argument( "--kl-coef", type=float, default=0.00, help="KL penalty coefficient for reward shaping. This is applied to the reward signal before advantage calculation.", ) parser.add_argument( "--loss-type", type=str, choices=["policy_loss", "sft_loss", "custom_loss"], default="policy_loss", help=( "Choose loss type, currently support ppo policy_loss or sft_loss, " "if custom_loss is set, we will use the function path from `--custom-loss-function-path`." ), ) parser.add_argument( "--custom-loss-function-path", type=str, default=None, help=( "Path to the custom loss function, if the loss_type is `custom_loss`, " "we will use this function to calculate the loss. " ), ) parser.add_argument( "--kl-loss-type", type=str, choices=["k1", "k2", "k3", "low_var_kl"], default="k1", help="Choose KL loss type: kl, k2, k3, low_var_kl", ) parser.add_argument( "--advantage-estimator", type=str, choices=[ "grpo", "gspo", "reinforce_plus_plus", "reinforce_plus_plus_baseline", "ppo", ], default="grpo", help=( "Advantage estimator to use. Note: on-policy distillation (OPD) is now orthogonal " "to the advantage estimator. Use --opd-kl-coef > 0 to enable OPD on top of any estimator." ), ) parser.add_argument( "--disable-compute-advantages-and-returns", action="store_false", dest="compute_advantages_and_returns", help=( "Whether to disable computing advantages and returns. " "If set, we will not compute the advantages and returns, " "This is useful for sft or custom loss function." ), ) parser.add_argument( "--use-kl-loss", action="store_true", default=False, help="whether to use KL loss from GRPO" ) parser.add_argument( "--kl-loss-coef", type=float, default=0.0, help="KL penalty coefficient for the loss function. This is added to the final PPO loss.", ) parser.add_argument( "--use-unbiased-kl", action="store_true", default=False, help="Whether to enable unbiased KL estimation.", ) parser.add_argument( "--ref-update-interval", type=int, default=None, help="Interval (in rollout steps) to update ref model from actor. If None, ref model is not updated.", ) parser.add_argument("--entropy-coef", type=float, default=0.0, help="Entropy loss coef") parser.add_argument("--gamma", type=float, default=1.0, help="PPO GAE gamma") parser.add_argument("--lambd", type=float, default=1.0, help="PPO GAE lambd") parser.add_argument("--normalize-advantages", action="store_true", default=False) parser.add_argument( "--disable-grpo-std-normalization", action="store_false", dest="grpo_std_normalization", help="from Dr.GRPO https://arxiv.org/pdf/2503.20783", ) parser.add_argument( "--disable-rewards-normalization", action="store_false", dest="rewards_normalization", help="Disable rewards normalization", ) parser.add_argument( "--use-rollout-entropy", action="store_true", default=False, help=( "Whether to calculate the entropy when calculating the logprobs from actor and reference model. " "This is useful for doing special loss mask." ), ) parser.add_argument( "--get-mismatch-metrics", action="store_true", default=False, help="Whether to calculate the mismatch metrics.", ) parser.add_argument( "--reset-optimizer-states", action="store_true", default=False, help=( "Whether to reset optimizer states after each rollout. " "If enabled, the optimizer's history will be cleared at the end of each rollout, which can sometimes help with training stability or fulfill specific experiment requirements." ), ) parser.add_argument( "--use-rollout-logprobs", action="store_true", default=False, help=( "Whether to use the rollout logprobs when calculating the importance sampling ratios. " "If not set, we will use the logprobs from the actor model." ), ) # Off-Policy Correction using Importance Sampling: https://fengyao.notion.site/off-policy-rl parser.add_argument( "--use-tis", action="store_true", default=False, help="Enable TIS from https://fengyao.notion.site/off-policy-rl for off-policy importance sampling.", ) parser.add_argument( "--tis-clip", type=float, default=2.0, help="Clipping threshold C for importance sampling ratios to control variance.", ) parser.add_argument( "--tis-clip-low", type=float, default=0, help="Lower bound clipping threshold C for importance sampling ratios to control variance.", ) parser.add_argument( "--custom-tis-function-path", type=str, default=None, help="Path to the custom TIS/RS function (e.g., examples/train_infer_mismatch_helper/mis.py:compute_mis_weights_with_cp).", ) parser.add_argument( "--custom-pg-loss-reducer-function-path", type=str, default=None, help="Path to a custom reducer function for pg_loss only. When set, pg_loss will use this custom reducer while other metrics (pg_clipfrac, ppo_kl, entropy_loss, etc.) still use the default sum_of_sample_mean. (e.g., examples/Dr.GRPO/custom_reducer.py:get_pg_loss_reducer).", ) parser.add_argument( "--use-routing-replay", action="store_true", default=False, help="The routing replay technique from https://arxiv.org/abs/2507.18071", ) parser.add_argument( "--use-rollout-routing-replay", action="store_true", default=False, help="The rollout routing replay technique from https://arxiv.org/abs/2510.11370", ) parser.add_argument( "--use-opsm", action="store_true", default=False, help="Whether to enable Off-Policy Sequence Masking (OPSM).", ) parser.add_argument( "--opsm-delta", type=float, default=1e-4, help="The threshold for Off-Policy Sequence Masking (OPSM).", ) return parser def add_on_policy_distillation_arguments(parser): """Add on-policy distillation (OPD) related arguments. OPD is orthogonal to advantage estimators and can be applied on top of any estimator (GRPO, PPO, etc.) by adding a KL penalty to advantages. """ parser.add_argument( "--use-opd", action="store_true", default=False, help="Enable on-policy distillation (OPD). Must specify --opd-type when enabled.", ) parser.add_argument( "--opd-type", type=str, choices=["sglang", "megatron"], default=None, help=( "Type of on-policy distillation. " "'sglang': Teacher log-probs are obtained from external SGLang server during rollout. " "'megatron': Teacher model is loaded via --opd-teacher-load and forwarded during training." ), ) parser.add_argument( "--opd-kl-coef", type=float, default=1.0, help="On-policy distillation KL penalty coefficient. Default is 1.0.", ) parser.add_argument( "--opd-teacher-load", type=str, default=None, help=( "The checkpoint for OPD teacher model. Required when --opd-type=megatron. " "The teacher model should have the same architecture as policy/ref model." ), ) parser.add_argument( "--opd-teacher-ckpt-step", type=int, default=None, help="The checkpoint step for OPD teacher model." ) return parser def add_router_arguments(parser): parser.add_argument( "--use-slime-router", action="store_true", default=False, help="Whether to use SlimeRouter for text-based routing instead of SGLang token-based routing", ) parser.add_argument( "--slime-router-middleware-paths", type=str, nargs="+", default="", ) parser.add_argument( "--slime-router-timeout", type=float, default=None, help="Timeout for SlimeRouter HTTP requests in seconds.", ) parser.add_argument( "--slime-router-max-connections", type=int, default=None, help="Max connections for SlimeRouter HTTP client.", ) parser.add_argument( "--slime-router-health-check-failure-threshold", type=int, default=3, help="Number of consecutive failures before marking a worker as unhealthy.", ) RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True) return parser # wandb def add_wandb_arguments(parser): # wandb parameters parser.add_argument("--use-wandb", action="store_true", default=False) parser.add_argument( "--wandb-mode", type=str, default=None, choices=["online", "offline", "disabled"], help="W&B mode: online (default), offline (local only), or disabled. Overrides WANDB_MODE env var.", ) parser.add_argument( "--wandb-dir", type=str, default=None, help="Directory to store wandb logs. Default is ./wandb in current directory.", ) parser.add_argument("--wandb-key", type=str, default=None) parser.add_argument("--wandb-host", type=str, default=None) parser.add_argument("--wandb-team", type=str, default=None) parser.add_argument("--wandb-group", type=str, default=None) reset_arg(parser, "--wandb-project", type=str, default=None) parser.add_argument( "--disable-wandb-random-suffix", action="store_false", dest="wandb_random_suffix", default=True, help=( "Whether to add a random suffix to the wandb run name. " "By default, we will add a random 6 length string with characters to the run name." ), ) parser.add_argument( "--wandb-always-use-train-step", action="store_true", default=False, help=( "Whether to always use train step as the step metric in wandb. " "If set, we will always use the train steps for wandb logging, " "otherwise, will use rollout step for most info other than train/*. " ), ) parser.add_argument( "--log-multi-turn", action="store_true", default=False, help="Whether to log information for multi-turn rollout.", ) parser.add_argument( "--log-passrate", action="store_true", default=False, help="Whether to turn on passrate logging, which will log the pass@n of the responses in the rollout.", ) parser.add_argument( "--log-reward-category", type=str, default=None, help=( "Log statistics of the category of reward, such as why the reward function considers it as failed. " "Specify the key in the reward dict using this argument.", ), ) parser.add_argument( "--log-correct-samples", action="store_true", default=False, help="Whether to turn on passrate logging, which will log the pass@n of the responses in the rollout.", ) parser.add_argument("--wandb-run-id", type=str, default=None) return parser # tensorboard def add_tensorboard_arguments(parser): # tb_project_name, tb_experiment_name parser.add_argument("--use-tensorboard", action="store_true", default=False) parser.add_argument( "--tb-project-name", type=str, default=None, help="Directory to store tensorboard logs. Default is os.environ.get('TENSORBOARD_DIR') directory.", ) parser.add_argument("--tb-experiment-name", type=str, default=None) return parser # debug def add_debug_arguments(parser): parser.add_argument( "--save-debug-rollout-data", type=str, default=None, help=( "Save the rollout data to this path for debugging. " "The file will be saved to `save_debug_rollout_data.format(rollout_id)`." ), ) parser.add_argument( "--load-debug-rollout-data", type=str, default=None, help=( "Load the rollout data from this path for debugging. " "The file will be loaded from `load_debug_rollout_data.format(rollout_id)`. " "When this is enabled, slime will not instantiate sglang servers." ), ) parser.add_argument( "--load-debug-rollout-data-subsample", type=float, default=None, help="Subsample a portion of the debug rollout data for faster debugging.", ) parser.add_argument( "--debug-rollout-only", action="store_true", default=False, help=( "Whether to only run the rollout generation without training. " "This is useful for debugging the rollout generation function." ), ) parser.add_argument( "--debug-train-only", action="store_true", default=False, help=( "Whether to only run the training without sglang servers. " "This is useful for debugging the rollout generation function." ), ) parser.add_argument( "--save-debug-train-data", type=str, default=None, help=( "Save the train data to this path for debugging. " "The file will be saved to `save_debug_train_data.format(rollout_id)`." ), ) parser.add_argument( "--dump-details", type=str, default=None, help=("Dump all details of training for post-hoc analysis and visualization."), ) # use together with --record-memory-history and --memory-snapshot-path (defined in Megatron) parser.add_argument( "--memory-snapshot-dir", type=str, default=".", ) parser.add_argument( "--memory-snapshot-num-steps", type=int, default=None, ) parser.add_argument( "--profile-target", type=str, choices=["train_overall", "train_actor", "train_log_probs"], default=["train_overall"], nargs="+", ) parser.add_argument( "--memory-recorder", type=str, choices=["torch", "memray"], default="torch", ) parser.add_argument("--check-weight-update-equal", action="store_true") return parser def add_network_arguments(parser): parser.add_argument("--http-proxy", type=str, default=None) parser.add_argument("--use-distributed-post", action="store_true", default=False) return parser def add_reward_model_arguments(parser): parser.add_argument( "--rm-type", type=str, default=None, help="Type of the reward model", ) parser.add_argument( "--reward-key", type=str, default=None, help=( "Some reward model may return a dict instead of a value, " "this is the key to extract the reward value from the dict. " ), ) parser.add_argument( "--eval-reward-key", type=str, default=None, help="The eval variant for --reward-key", ) parser.add_argument( "--group-rm", action="store_true", default=False, help="Whether to do rm on a whole group." ) parser.add_argument( "--rm-url", type=str, default=None, help="URL for the reward model service for --rm-type remote_rm, e.g. http://localhost:8000", ) parser.add_argument( "--custom-rm-path", type=str, default=None, help=( "Path to the custom reward model function. " "If set, we will use this function to calculate the reward instead of the default one. " "The function should have the signature `def custom_rm(args, sample) -> float`." ), ) parser.add_argument( "--custom-reward-post-process-path", type=str, default=None, help=( "Path to the custom function that will post process reward, by default it will be the normalization for grpo. " ), ) parser.add_argument( "--custom-convert-samples-to-train-data-path", type=str, default=None, help=( "Path to a custom function that converts samples to training data. " "If set, this function will replace the default _convert_samples_to_train_data. " "The function should have the signature `def convert_samples_to_train_data(args, samples) -> dict`." ), ) return parser def add_rollout_buffer_arguments(parser): parser.add_argument( "--rollout-buffer-url", type=str, default=None, help="URL for the rollout buffer", ) parser.add_argument( "--fetch-trajectory-retry-times", type=int, default=-1, help="Number of times to retry fetching trajectory, -1 means unlimited retry", ) parser.add_argument( "--min-batch-collection-ratio", type=float, default=1, help="Minimum batch collection ratio", ) parser.add_argument( "--rollout-task-type", type=str, default="math", ) parser.add_argument( "--loss-mask-type", type=str, default="qwen", choices=["qwen", "qwen3", "distill_qwen"], help="Loss mask type", ) parser.add_argument( "--data-pad-size-multiplier", type=int, default=128, help="Multiplier for data padding size in data processing.", ) parser.add_argument( "--rollout-sample-filter-path", type=str, default=None, help=( "Path to the rollout sample filter function. " "This function determines whether a sample will participate in loss calculation. " "The function should take args and samples (list[Sample]) as input, and return None. " "Please directly modify the remove_sample attribute of Sample. " "Note: This attribute does not determine whether the sample participates in advantage normalization." ), ) parser.add_argument( "--rollout-all-samples-process-path", type=str, default=None, help=( "Path to the rollout all samples process function that " "can process all samples including filtered ones." ), ) parser.add_argument( "--disable-rollout-trim-samples", action="store_true", default=False, help="disable trim samples in rollout buffer when converting samples to train data", ) parser.add_argument( "--use-dynamic-global-batch-size", action="store_true", default=False, help="enable dynamic global batch size, disable trim samples in rollout buffer when converting samples to train data", ) return parser def add_custom_megatron_plugins_arguments(parser): """ Add custom Megatron plugins arguments. This is a placeholder for any additional arguments that might be needed. """ # Custom arguments can be added here parser.add_argument( "--custom-megatron-init-path", type=str, default=None, ) parser.add_argument( "--custom-megatron-before-log-prob-hook-path", type=str, default=None, ) parser.add_argument( "--custom-megatron-before-train-step-hook-path", type=str, default=None, ) return parser def add_mtp_training_arguments(parser): """Add MTP training specific arguments.""" reset_arg(parser, "--mtp-num-layers", type=int, default=None) reset_arg(parser, "--mtp-loss-scaling-factor", type=float, default=0.2) parser.add_argument( "--enable-mtp-training", action="store_true", default=False, help="Enable MTP layer parameter updates during training", ) return parser def add_prefill_decode_disaggregation_arguments(parser): parser.add_argument( "--prefill-num-servers", type=int, default=None, help="Number of prefill servers for disaggregation.", ) return parser def add_ci_arguments(parser): parser.add_argument( "--ci-test", action="store_true", ) parser.add_argument( "--ci-disable-kl-checker", action="store_true", ) parser.add_argument( "--ci-metric-checker-key", type=str, default=None, ) parser.add_argument( "--ci-metric-checker-threshold", type=float, default=None, ) parser.add_argument( "--ci-save-grad-norm", type=str, default=None, ) parser.add_argument( "--ci-load-grad-norm", type=str, default=None, ) return parser def add_sglang_tp_size(): temp_parser = argparse.ArgumentParser(add_help=False) temp_parser.add_argument("--rollout-num-gpus-per-engine", type=int, default=1) temp_parser.add_argument("--sglang-pp-size", type=int, default=1) temp_parser.add_argument("--sglang-pipeline-parallel-size", type=int, default=1) temp_args, _ = temp_parser.parse_known_args() # Use sglang_pp_size if set (non-default), otherwise use sglang_pipeline_parallel_size pp_size = ( temp_args.sglang_pp_size if temp_args.sglang_pp_size != 1 else temp_args.sglang_pipeline_parallel_size ) sglang_tp_size = temp_args.rollout_num_gpus_per_engine // pp_size return sglang_tp_size def add_custom_arguments(parser): parser.add_argument( "--evolving-gym", action="store_true", default=False, help="Use SingleTaskEvolvingGym as the rollout environment (mutually exclusive with global dataset).", ) parser.add_argument( "--evolving-gym-initial-program", type=str, default=None, help="Path to initial program file for EvolvingGym.", ) parser.add_argument( "--evolving-gym-evaluator-file", type=str, default=None, help="Path to evaluator file for EvolvingGym.", ) parser.add_argument( "--evolving-gym-config-path", type=str, default=None, help="Path to config yaml for the gym.", ) parser.add_argument( "--evolving-gym-max-concurrent-evals", type=int, default=8, help="Max concurrent evaluations inside EvolvingGym.", ) parser.add_argument( "--evolving-gym-log-prompts", action="store_true", default=True, help="Whether to log prompts in EvolvingGym.", ) parser.add_argument( "--evolving-gym-record", action="store_true", default=False, help="Enable EvolvingGym recorder and save per-rollout snapshots." ) parser.add_argument( "--evolving-gym-record-dir", type=str, default="./gym_output", help="Directory to store EvolvingGym recorder outputs." ) parser.add_argument( "--evolving-gym-lazy-output-penalty-level", type=int, default=2, help="Lazy output penalty level: 0=no penalty, 1=check parent only, 2=check parent+database." ) parser.add_argument( "--evolving-gym-seed", type=int, default=1234, help="Random seed for evolving gym (database sampling, prompt selection). Default: 1234." ) parser.add_argument( "--evolving-gym-database-reinit-ratio", type=float, default=0.0, help="Database reinitialization ratio: when (max_score - min_score) / abs(max_score) < ratio, clear database and reinitialize. 0.0 disables reinitialization." ) parser.add_argument( "--evolving-gym-smallest-restart-step", type=int, default=10000000, help="Minimum steps between database reinitializations" ) parser.add_argument( "--evolving-gym-largest-restart-step", type=int, default=10000000, help="Maximum steps between database reinitializations (force restart)" ) parser.add_argument( "--evolving-gym-add-historical-programs", type=int, default=0, help="Number of historical best programs to reload after reinitialization" ) parser.add_argument( "--evolving-gym-reward-process-type", type=str, default="original_reward", choices=["original_reward", "rl_normalized_reward", "format_reward", "validation_reward", "improve_reward"], help=( "Reward processing type:\n" " - original_reward: Use raw combined_score without normalization (default)\n" " - rl_normalized_reward: Use RL-normalized reward from metrics\n" " - format_reward: Keep format errors negative, others get 1.0\n" " - validation_reward: Keep all errors negative, only success gets 1.0\n" " - improve_reward: Binary reward (1.0 if child > parent, else 0.0)" ) ) return parser # Add custom arguments in front to prevent overwritten some slime arguments. if add_custom_arguments is not None: parser = add_custom_arguments(parser) parser = add_cluster_arguments(parser) parser = add_train_arguments(parser) parser = add_rollout_arguments(parser) parser = add_fault_tolerance_arguments(parser) parser = add_data_arguments(parser) parser = add_eval_arguments(parser) parser = add_algo_arguments(parser) parser = add_on_policy_distillation_arguments(parser) parser = add_wandb_arguments(parser) parser = add_tensorboard_arguments(parser) parser = add_router_arguments(parser) parser = add_debug_arguments(parser) parser = add_sglang_arguments(parser) parser = add_network_arguments(parser) parser = add_reward_model_arguments(parser) parser = add_rollout_buffer_arguments(parser) parser = add_mtp_training_arguments(parser) parser = add_prefill_decode_disaggregation_arguments(parser) parser = add_ci_arguments(parser) parser = add_custom_megatron_plugins_arguments(parser) reset_arg( parser, "--custom-config-path", type=str, default=None, help="Path to the YAML config for custom function arguments.", ) reset_arg(parser, "--padded-vocab-size", type=int, default=None) parser.set_defaults(sglang_tensor_parallel_size=add_sglang_tp_size()) return parser return add_slime_arguments def parse_args(add_custom_arguments=None): # Users may call `parse_args` very early, thus we ensure logger is configured here configure_logger() add_slime_arguments = get_slime_extra_args_provider(add_custom_arguments) backend = parse_args_train_backend() if backend == "megatron": from slime.backends.megatron_utils.arguments import parse_args as megatron_parse_args from slime.backends.megatron_utils.arguments import set_default_megatron_args from slime.backends.megatron_utils.arguments import validate_args as megatron_validate_args args = megatron_parse_args(extra_args_provider=add_slime_arguments) if args.hf_checkpoint and not args.debug_rollout_only: hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True) hf_validate_args(args, hf_config) args.rank = 0 args.world_size = args.actor_num_nodes * args.actor_num_gpus_per_node args = set_default_megatron_args(args) else: logger.warning( "🚧 🚧 🚧 FSDP backend is being rewritten, please use Megatron backend for better stability. 🚧 🚧 🚧" ) from slime.backends.fsdp_utils.arguments import load_fsdp_args args = load_fsdp_args(extra_args_provider=add_slime_arguments) args.rank = 0 # Primary process rank for wandb initialization args.world_size = args.actor_num_nodes * args.actor_num_gpus_per_node slime_validate_args(args) if backend == "megatron": megatron_validate_args(args) # always use varlen args.variable_seq_lengths = True if getattr(args, "moe_token_dispatcher_type", None) == "allgather": logger.info( "--moe-token-dispatcher-type allgather does not support variable sequence length, " "please use alltoall dispatcher instead." ) args.moe_token_dispatcher_type = "alltoall" if args.pipeline_model_parallel_size == 1: assert args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None, ( "decoder_first_pipeline_num_layers and decoder_last_pipeline_num_layers should be None when " "pipeline_model_parallel_size is 1." ) sglang_validate_args(args) return args def parse_args_train_backend(): if os.environ.get("SLIME_BACKEND") is not None: raise Exception("`SLIME_BACKEND` is deprecated, please use --train-backend directly.") parser = argparse.ArgumentParser() get_slime_extra_args_provider()(parser) args_partial, _ = parser.parse_known_args() return args_partial.train_backend def _resolve_eval_datasets(args) -> list[EvalDatasetConfig]: """ Build evaluation dataset configurations from either --eval-config or --eval-prompt-data. """ datasets_config = [] defaults: dict[str, Any] = {} if args.eval_config: from omegaconf import OmegaConf cfg = OmegaConf.load(args.eval_config) cfg_dict = OmegaConf.to_container(cfg, resolve=True) if not isinstance(cfg_dict, dict): raise ValueError("--eval-config must contain a mapping at the root.") eval_cfg = cfg_dict.get("eval", cfg_dict) if not isinstance(eval_cfg, dict): raise ValueError("--eval-config must define an `eval` mapping or be a mapping itself.") defaults = dict(eval_cfg.get("defaults") or {}) datasets_config = ensure_dataset_list(eval_cfg.get("datasets")) if not datasets_config: raise ValueError("--eval-config does not define any datasets under `eval.datasets`.") elif args.eval_prompt_data: values = list(args.eval_prompt_data) if len(values) == 1: logger.info("[legacy] only one eval_prompt_data detected, will assume it is data for aime") values = ["aime", values[0]] if len(values) % 2 != 0: raise ValueError("eval prompt data must be provided as name/path pairs.") datasets_config = [{"name": values[i], "path": values[i + 1]} for i in range(0, len(values), 2)] else: datasets_config = [] eval_datasets = build_eval_dataset_configs(args, datasets_config, defaults) if eval_datasets: args.eval_prompt_data = [item for dataset in eval_datasets for item in (dataset.name, dataset.path)] else: args.eval_prompt_data = None return eval_datasets def slime_validate_args(args): args.eval_datasets = _resolve_eval_datasets(args) if args.kl_coef != 0 or args.use_kl_loss: if not os.path.exists(args.ref_load): raise FileNotFoundError(f"ref_load {args.ref_load} does not exist, please check the path.") if not os.path.exists(os.path.join(args.ref_load, "latest_checkpointed_iteration.txt")): logger.info( f"ref_load {args.ref_load} does not have latest_checkpointed_iteration.txt, " "please make sure it is a valid megatron checkpoint directory." ) # Validate on-policy distillation (OPD) arguments if args.use_opd: if args.opd_type is None: raise ValueError("--opd-type must be specified when --use-opd is enabled. Choose 'sglang' or 'megatron'.") if args.opd_type == "megatron": if args.opd_teacher_load is None: raise ValueError( "--opd-teacher-load is required when --opd-type=megatron. " "Please provide the path to the teacher model checkpoint." ) if not os.path.exists(args.opd_teacher_load): raise FileNotFoundError( f"opd_teacher_load {args.opd_teacher_load} does not exist, please check the path." ) if not os.path.exists(os.path.join(args.opd_teacher_load, "latest_checkpointed_iteration.txt")): logger.info( f"opd_teacher_load {args.opd_teacher_load} does not have latest_checkpointed_iteration.txt, " "please make sure it is a valid megatron checkpoint directory." ) elif args.opd_type == "sglang": if args.opd_teacher_load is not None: raise ValueError( "--opd-teacher-load should not be set when --opd-type=sglang. " "In sglang mode, teacher log-probs are obtained from external server during rollout." ) else: # If OPD is not enabled, opd_teacher_load should not be set if args.opd_teacher_load is not None: raise ValueError("--opd-teacher-load is set but --use-opd is not enabled. Please add --use-opd flag.") if args.megatron_to_hf_mode == "bridge": if ( args.load is not None and os.path.exists(args.load) and os.path.exists(os.path.join(args.load, "latest_checkpointed_iteration.txt")) ): # If is a Megatron checkpoint, won't use bridge to load hf weight. pass else: if args.load is None: args.load = args.ref_load or args.hf_checkpoint # If is a HF checkpoint, set start_rollout_id to 0 here. args.start_rollout_id = 0 else: if ( args.load is None or not os.path.exists(args.load) or not os.path.exists(os.path.join(args.load, "latest_checkpointed_iteration.txt")) ): args.no_load_optim = True args.no_load_rng = True args.finetune = True args.load = args.ref_load if args.ref_ckpt_step is not None: args.ckpt_step = args.ref_ckpt_step args.start_rollout_id = 0 if args.eval_interval is not None: assert args.eval_datasets, "Evaluation datasets must be configured when eval_interval is set." if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: assert args.normalize_advantages, ( "The 'reinforce_plus_plus' and 'reinforce_plus_plus_baseline' advantage estimators " "require advantage normalization. Please add `--normalize-advantages` to your command." ) if args.use_rollout_logprobs: assert not args.use_tis, "use_rollout_logprobs and use_tis cannot be set at the same time." if args.get_mismatch_metrics: assert ( args.custom_tis_function_path is not None ), "custom_tis_function_path must be set when get_mismatch_metrics is set" if args.use_rollout_logprobs: logger.info( "get_mismatch_metrics is set; For metrics calculation, the log probs will still be recomputed by training engine. One more forward pass will be applied." ) if args.use_dynamic_batch_size: assert args.max_tokens_per_gpu is not None, "max_tokens_per_gpu must be set when use_dynamic_batch_size is set" if args.log_probs_max_tokens_per_gpu is None: args.log_probs_max_tokens_per_gpu = args.max_tokens_per_gpu if args.eps_clip_high is None: args.eps_clip_high = args.eps_clip if args.eval_reward_key is None: args.eval_reward_key = args.reward_key if args.dump_details is not None: args.save_debug_rollout_data = f"{args.dump_details}/rollout_data/{{rollout_id}}.pt" args.save_debug_train_data = f"{args.dump_details}/train_data/{{rollout_id}}_{{rank}}.pt" if args.load_debug_rollout_data is not None: logger.info( f"load_debug_rollout_data {args.load_debug_rollout_data} is set, " "will not instantiate sglang servers and will only run the training process." ) args.debug_train_only = True args.use_critic = args.advantage_estimator == "ppo" if args.critic_num_gpus_per_node is None: args.critic_num_gpus_per_node = args.actor_num_gpus_per_node if args.critic_num_nodes is None: args.critic_num_nodes = args.actor_num_nodes if args.critic_load is None: args.critic_load = args.load if args.critic_lr is None: args.critic_lr = args.lr if args.offload: args.offload_train = True args.offload_rollout = True del args.offload if args.debug_rollout_only: if args.colocate and (not args.rollout_num_gpus): args.rollout_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes else: args.actor_num_gpus_per_node = min(8, args.rollout_num_gpus) args.actor_num_nodes = args.rollout_num_gpus // args.actor_num_gpus_per_node args.colocate = False args.offload_train = args.offload_rollout = False if args.train_memory_margin_bytes > 0: logger.warning("Force train_memory_margin_bytes=0 since debug_rollout_only does not support it") args.train_memory_margin_bytes = 0 assert not (args.debug_rollout_only and args.debug_train_only), ( "debug_rollout_only and debug_train_only cannot be set at the same time, " "please set only one of them." ) # always true on offload for colocate at the moment. if args.colocate: if args.offload_train is None: args.offload_train = True if args.offload_rollout is None: args.offload_rollout = True if args.rollout_num_gpus != args.actor_num_gpus_per_node * args.actor_num_nodes: logger.info( f"rollout_num_gpus {args.rollout_num_gpus} != actor_num_gpus_per_node {args.actor_num_gpus_per_node} " f"* actor_num_nodes {args.actor_num_nodes}, overriding rollout_num_gpus to match actor_num_gpus_per_node * actor_num_nodes." ) args.rollout_num_gpus = args.actor_num_gpus_per_node * args.actor_num_nodes if args.use_critic: args.rollout_num_gpus += args.critic_num_gpus_per_node * args.critic_num_nodes if args.offload_train is None: args.offload_train = False if args.offload_rollout is None: args.offload_rollout = False if args.eval_function_path is None: args.eval_function_path = args.rollout_function_path if args.num_steps_per_rollout is not None: global_batch_size = args.rollout_batch_size * args.n_samples_per_prompt // args.num_steps_per_rollout if args.global_batch_size is not None: assert args.global_batch_size == global_batch_size, ( f"global_batch_size {args.global_batch_size} is not equal to " f"rollout_batch_size {args.rollout_batch_size} * n_samples_per_prompt {args.n_samples_per_prompt} " f"// num_steps_per_rollout {args.num_steps_per_rollout}" ) args.global_batch_size = global_batch_size if args.n_samples_per_prompt == 1: args.grpo_std_normalization = False logger.info("n_samples_per_prompt is set to 1, grpo_std_normalization will be set to False.") if args.over_sampling_batch_size is None: args.over_sampling_batch_size = args.rollout_batch_size assert args.over_sampling_batch_size >= args.rollout_batch_size, ( f"over_sampling_batch_size {args.over_sampling_batch_size} should be greater than or equal to " f"rollout_batch_size {args.rollout_batch_size}" ) if args.num_epoch is not None: if args.num_rollout is not None: logger.info("Both num_epoch and num_rollout are set, num_epoch will be ignored.") else: assert args.rollout_global_dataset, ( "num_epoch is set, but rollout_global_dataset is not set, " "please remove --disable-rollout-global-dataset to use num_epoch" ) else: # if num_epoch is not set, we should set num_rollout assert args.num_rollout is not None, ( "num_epoch is not set, but num_rollout is not set, " "please set --num-rollout or --num-epoch" ) if args.enable_mtp_training: assert args.mtp_num_layers, "mtp_num_layers must be set when enable_mtp_training is set" if args.use_rollout_routing_replay: args.use_routing_replay = True if args.custom_config_path: with open(args.custom_config_path) as f: data = yaml.safe_load(f) or {} for k, v in data.items(): if hasattr(args, k): logger.info(f"Warning: Argument {k} is already set to {getattr(args, k)}, will override with {v}.") setattr(args, k, v) if args.eval_max_context_len is None: logger.info( f"args.eval_max_context_len is not set. Use args.rollout_max_context_len {args.rollout_max_context_len} as default value." ) args.eval_max_context_len = args.rollout_max_context_len if args.rollout_max_context_len is not None: if args.rollout_max_prompt_len is None: args.rollout_max_prompt_len = args.rollout_max_context_len - 1 logger.info( f"args.rollout_max_prompt_len is not set. Use args.rollout_max_context_len - 1 ({args.rollout_max_context_len} - 1) as default value so that there is at least one generated token to compute loss." ) assert ( args.rollout_max_prompt_len <= args.rollout_max_context_len - 1 ), f"args.rollout_max_prompt_len ({args.rollout_max_prompt_len}) must be smaller than args.rollout_max_context_len ({args.rollout_max_context_len}) so that there is at least one generated token to compute loss." assert not ( args.prefill_num_servers is not None and args.rollout_external ), "prefill_num_servers cannot be set when rollout_external is set." if args.qkv_format == "bshd": assert args.train_backend == "megatron", "bshd format is only supported for megatron backend." assert ( args.use_dynamic_batch_size is False ), "Dynamic batch size is not supported for bshd format. Please specify --micro-batch-size instead." if args.only_train_params_name_list and args.freeze_params_name_list: raise ValueError("You can only specify ONE of: --only-train-params-name-list, or --freeze-params-name-list.") def hf_validate_args(args, hf_config): def equal(x, y): return x == y errors = [] # multimodal models have different config structure if hasattr(hf_config, "text_config"): hf_config = hf_config.text_config for hf_config_name, megatron_config_name, compare_fn in [ ("hidden_size", "hidden_size", equal), ("num_attention_heads", "num_attention_heads", equal), ("num_hidden_layers", "num_layers", equal), ("intermediate_size", "ffn_hidden_size", equal), ("tie_word_embeddings", "untie_embeddings_and_output_weights", lambda x, y: not x == y), ("rms_norm_eps", "norm_epsilon", equal), ("rope_theta", "rotary_base", equal), ]: if hasattr(hf_config, hf_config_name): if not compare_fn(getattr(hf_config, hf_config_name), getattr(args, megatron_config_name)): errors.append( f"{hf_config_name} in hf config {getattr(hf_config, hf_config_name)} is not equal to " f"{megatron_config_name} {getattr(args, megatron_config_name)}, please check the config." ) if len(errors) > 0: raise AssertionError("hf_validate_args failed: " + "; ".join(errors))