
Recipe: Fully Async Policy Trainer

Author: https://github.com/meituan-search

Last updated: 12/25/2025.

This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and the Rollouter, supporting asynchronous sample generation and training. With this system, we achieved a 2.35x-2.67x speedup when training Qwen2.5-Math-7B on 128 GPUs, without significantly affecting the results.

Introduction

Background

Compared with the colocated architecture, a separated rollout/train architecture allocates resources more flexibly and allows more flexible training logic. This helps with the low GPU utilization and poor training efficiency caused by long-tail samples. one_step_off_policy uses such a separated architecture and runs rollout and training one step apart, which hides part of the long rollout time and improves training efficiency somewhat. However, it is fixed to exactly one step of asynchrony, which is inflexible and cannot fully eliminate the impact of long-tail samples on training efficiency. Other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow have implemented asynchronous and streaming training on top of separated architectures and achieved gains. We borrowed from their methods and implemented them in VERL. fully_async_policy supports asynchronous, streaming, and partial rollout training. With appropriate resource allocation and parameter synchronization frequency, fully_async_policy can significantly improve training efficiency.

Magistral https://arxiv.org/abs/2506.10910

AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning https://arxiv.org/abs/2505.24298

StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream Generation https://arxiv.org/abs/2504.15930

AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663

Core Contributions

  • Resource Isolation: Unlike the hybrid_engine mode, the Rollouter and Trainer use separate computing resources, and the resources each occupies must be specified separately.
  • Parallel Generation and Training: While the Trainer is training, the Rollouter is generating new samples.
  • Multi-step Asynchrony: Unlike one_step_off_policy, asynchrony can be set anywhere from a fraction of a step (0.x) to multiple steps, making the asynchronous scheme more flexible.
  • NCCL Parameter Synchronization: Built on NCCL communication primitives and following the design of checkpoint-engine, parameters are synchronized efficiently between the Rollouter and the Trainer.
  • Streaming Inference and Training: The Rollouter generates data sample by sample, and a single sample is the minimum unit of data transmission.
  • Asynchronous Training and Freshness Control: The async_training.staleness_threshold parameter allows training on samples generated with older parameters.
  • PartialRollout: The Rollouter supports partial rollout. During parameter synchronization, sleep() and resume() logic saves the samples of ongoing rollouts and continues generating them in the next rollout, which cuts the time spent waiting for ongoing tasks to finish before synchronization.

Currently supported combinations are Megatron or FSDP for training with vLLM or SGLang for rollout. vLLM/SGLang must run in the AgentLoop-based server mode.

Design

The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.

[Figure: fully_async_policy architecture]

  1. The Rollouter generates sequences sample by sample and puts them into the MessageQueue; its production rate is bounded by the freshness control (async_training.staleness_threshold).
  2. The MessageQueue temporarily stores the samples generated by the Rollouter.
  3. The Trainer fetches samples from the MessageQueue one at a time. Once it has fetched require_batches*ppo_mini_batch_size samples, it performs one training step (a local update). After async_training.trigger_parameter_sync_step local updates, it triggers a parameter synchronization with the Rollouter.
  4. The ParameterSynchronizer performs synchronous, NCCL-based parameter synchronization between the Trainer and the Rollouter.
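The data flow above can be sketched in plain Python with asyncio. This is a minimal single-process illustration under simplifying assumptions, not verl's implementation: an unbounded `asyncio.Queue` stands in for the MessageQueue, integers stand in for samples, freshness control is not modeled, and incrementing a hypothetical `param_version` counter stands in for a parameter synchronization. All function names are illustrative.

```python
import asyncio


async def rollouter(queue: asyncio.Queue, total_samples: int) -> None:
    # Streaming production: push samples into the queue one at a time.
    for sample_id in range(total_samples):
        await asyncio.sleep(0)  # stand-in for generating one sample
        await queue.put(sample_id)


async def trainer(queue: asyncio.Queue, total_samples: int, ppo_mini_batch_size: int,
                  require_batches: int, trigger_parameter_sync_step: int) -> dict:
    batch_size = require_batches * ppo_mini_batch_size
    local_updates = 0
    param_version = 0  # hypothetical counter standing in for weight syncs
    consumed = 0
    while consumed + batch_size <= total_samples:
        # Fetch sample by sample until a full training batch is available.
        batch = [await queue.get() for _ in range(batch_size)]
        consumed += len(batch)
        local_updates += 1  # stand-in for one training step on `batch`
        if local_updates % trigger_parameter_sync_step == 0:
            param_version += 1  # stand-in for syncing weights to the Rollouter
    return {"local_updates": local_updates, "param_version": param_version,
            "consumed": consumed}


async def main() -> dict:
    queue: asyncio.Queue = asyncio.Queue()
    total = 512
    producer = asyncio.create_task(rollouter(queue, total))
    stats = await trainer(queue, total, ppo_mini_batch_size=32, require_batches=4,
                          trigger_parameter_sync_step=4)
    await producer
    return stats


if __name__ == "__main__":
    print(asyncio.run(main()))
```

With require_batches=4 and ppo_mini_batch_size=32, each local update consumes 128 samples, so 512 samples yield 4 local updates and, with trigger_parameter_sync_step=4, one synchronization.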

Compared with the colocated baseline, the gains come from the following: in the colocated setup, adding more rollout resources cannot eliminate the idle time caused by long-tail samples. After resource isolation, rollout and training may each take longer than before (because each uses fewer resources), but because they overlap in time, the end-to-end time decreases.

[Figure: source of fully_async_policy's gains over colocate]

Usage

Parameter Description

| Parameter | Meaning |
|---|---|
| trainer.nnodes | Number of nodes for the Trainer |
| trainer.n_gpus_per_node | Number of GPUs per node for the Trainer |
| rollout.nnodes | Number of nodes for the Rollouter |
| rollout.n_gpus_per_node | Number of GPUs per node for the Rollouter |
| data.train_batch_size | Not used in the fully async strategy (default 0) |
| data.gen_batch_size | In the fully async strategy, samples are produced in a streaming fashion (default 1) |
| rollout.total_rollout_steps | Total number of rollout samples |
| rollout.test_freq | Number of Rollouter parameter updates between two validations |
| actor_rollout_ref.actor.ppo_mini_batch_size | Global mini-batch size, counted across all workers/GPUs |
| async_training.require_batches | Number of ppo_mini_batch_size batches that FullyAsyncTrainer fetches at once |
| async_training.trigger_parameter_sync_step | Number of local updates FullyAsyncTrainer performs between two parameter synchronizations |
| async_training.staleness_threshold | Freshness control: maximum proportion of stale samples allowed |
| async_training.partial_rollout | Whether to perform partial_rollout |
| async_training.use_rollout_log_probs | Whether to use the log_probs computed by the Rollouter |
| async_training.compute_prox_log_prob | Whether to compute log_prob with the training model's parameters during the training phase |
| async_training.checkpoint_engine.enable | Whether to use checkpoint_engine to accelerate parameter synchronization (default True) |
| async_training.checkpoint_engine.overlap_broadcast_and_consume | With checkpoint_engine enabled, whether to overlap broadcast and load_weights (default False) |
| async_training.checkpoint_engine.device_buffer_size_M | With checkpoint_engine enabled, the user-specified bucket size in MB (default 4096) |
| async_training.use_trainer_do_validate | Whether to run validation on the Trainer nodes (default False) |

Further Explanation:

  • rollout.total_rollout_steps

    To match a colocate run, set rollout.total_rollout_steps = data.train_batch_size * step, where step is the number of colocate training steps.

  • async_training.trigger_parameter_sync_step

    In the fully async strategy, this is the number of local updates the Trainer performs (each consuming require_batches * ppo_mini_batch_size samples) before synchronizing parameters with the Rollouter. Between two synchronizations, the Trainer therefore processes trigger_parameter_sync_step * require_batches * ppo_mini_batch_size samples. For a fair speed comparison with colocate, set trigger_parameter_sync_step = data.train_batch_size / (require_batches * ppo_mini_batch_size).

  • async_training.staleness_threshold

    In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used.

    • staleness_threshold=0 indicates synchronous training. The Rollouter generates a fixed number of samples between two parameter updates: $$\text{rollout\_num} = \text{trigger\_parameter\_sync\_step} \times \text{require\_batches} \times \text{ppo\_mini\_batch\_size}$$
    • staleness_threshold>0 indicates asynchronous training and can be set to a decimal for more flexible asynchrony. The Rollouter generates at most the following number of samples between two parameter updates: $$\text{rollout\_num} = (1 + \text{staleness\_threshold}) \times (\text{trigger\_parameter\_sync\_step} \times \text{require\_batches} \times \text{ppo\_mini\_batch\_size}) - \text{num\_staleness\_sample}$$

    num_staleness_sample represents the number of stale samples generated in excess during the last rollout.

    Because the system is streaming, the Rollouter keeps generating while the Trainer keeps consuming. If the Rollouter is slower than the Trainer, the Trainer triggers parameter synchronization before the Rollouter has produced rollout_num samples, so the actual count is lower. When rollout is fast enough, setting staleness_threshold to 1 is roughly equivalent to one_step_off_policy. To keep too many stale samples from hurting training accuracy, we recommend setting this value below 1.

  • async_training.partial_rollout

    partial_rollout only actually takes effect when staleness_threshold>0.

  • async_training.use_rollout_log_probs

    In reinforcement learning algorithms, log_probs are tied to both the parameter version and the tokens. In algorithms such as PPO/GRPO/DAPO, the importance-sampling computation requires old_log_prob to be the log_probs under the rollout parameters that generated the tokens; this is what keeps the algorithm correct. The fully async strategy therefore computes old_log_prob in the Rollouter by default rather than in the Trainer.

  • async_training.require_batches

    In pure streaming training, require_batches is 1: training starts as soon as ppo_mini_batch_size samples have been produced. In practice, we found that dispatching fewer samples at a time can, because of the order in which data arrives, cause training instability and longer responses. require_batches is therefore provided to keep streaming dispatch while controlling how many samples take part in each training step.

  • async_training.compute_prox_log_prob (experimental)

    During training, we observed that metrics and response lengths may become unstable in later stages. To mitigate this, Rollout Importance Sampling can be used. It requires log_prob computed by the training engine, which is what this switch enables. In addition, with compute_prox_log_prob and Rollout Importance Sampling both enabled under mode d (async stream pipeline with partial rollout), our implementation approximates AReaL's Decoupled PPO.

  • async_training.checkpoint_engine.enable

    Enabling the checkpoint engine generally reduces synchronization time by about 60% or more compared with the original per-tensor synchronization method. However, assembling buckets requires additional temporary GPU memory.

  • async_training.checkpoint_engine.overlap_broadcast_and_consume

    Enabling this option pipelines the broadcast and load_weights phases but allocates additional GPU memory. Because most of the parameter synchronization time is spent not in broadcast or load_weights but in generating the parameters (by Megatron or FSDP), this option is off by default.

  • async_training.checkpoint_engine.device_buffer_size_M

    It controls the size of the memory buffer used for synchronization when the checkpoint-engine is enabled. The actual bucket_size = max(device_buffer_size_M, maximum parameter tensor size).

    • With overlap_broadcast_and_consume enabled, the additional device memory overhead is 3 * bucket_size on each trainer rank and 2 * bucket_size on each rollout rank.
    • With overlap_broadcast_and_consume disabled, the additional device memory overhead is 2 * bucket_size on each trainer rank and 1 * bucket_size on each rollout rank.
  • async_training.use_trainer_do_validate

    It controls whether to use the trainer's do_validate method for validation. If set to True, the trainer will perform validation after each parameter update. It can reduce the validation time overhead and trainer node idle time. If set to False, the trainer will not perform validation.
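The sample-count and memory rules in the explanations above reduce to simple arithmetic. The helpers below are a hedged sketch of those formulas as stated in this section, not code taken from verl; the function names are hypothetical, and rounding rollout_num down to an integer is an assumption.

```python
def rollout_num(trigger_parameter_sync_step: int, require_batches: int,
                ppo_mini_batch_size: int, staleness_threshold: float = 0.0,
                num_staleness_sample: int = 0) -> int:
    """Upper bound on samples the Rollouter generates between two parameter updates."""
    per_sync = trigger_parameter_sync_step * require_batches * ppo_mini_batch_size
    if staleness_threshold == 0:
        # Synchronous training: exactly what the Trainer consumes per synchronization.
        return per_sync
    # Asynchronous training: a (1 + staleness_threshold) budget, minus the stale
    # samples already generated in excess during the previous round (rounded down).
    return int((1 + staleness_threshold) * per_sync) - num_staleness_sample


def checkpoint_engine_extra_memory_mb(device_buffer_size_m: float, max_param_tensor_mb: float,
                                      overlap_broadcast_and_consume: bool) -> tuple[float, float]:
    """Extra device memory in MB, returned as (trainer rank, rollout rank)."""
    bucket_size = max(device_buffer_size_m, max_param_tensor_mb)
    if overlap_broadcast_and_consume:
        return 3 * bucket_size, 2 * bucket_size
    return 2 * bucket_size, 1 * bucket_size


if __name__ == "__main__":
    print(rollout_num(4, 4, 32))  # 512
    print(rollout_num(4, 4, 32, staleness_threshold=0.5))  # 768
    # Default 4096 MB buffer, hypothetical largest parameter tensor of 1000 MB.
    print(checkpoint_engine_extra_memory_mb(4096, 1000, overlap_broadcast_and_consume=False))  # (8192, 4096)
```

With the 7B experiment settings (trigger_parameter_sync_step=4, require_batches=4, ppo_mini_batch_size=32), the Rollouter may generate at most 768 samples per synchronization at staleness_threshold=0.5, versus exactly 512 at staleness_threshold=0.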

Supported Modes

  1. on policy pipeline:

    1. trigger_parameter_sync_step=1, staleness_threshold=0
    2. Rollouter produces require_batches*ppo_mini_batch_size samples at once, Trainer fetches these samples for training, and after training completes, Trainer and Rollouter perform a parameter synchronization;
    3. During the rollout phase, if there are long-tail samples and only a few rollout samples, shorter samples cannot fill the idle resources, wasting some of them.
    4. As shown in figure a;
  2. stream off policy pipeline:

    1. trigger_parameter_sync_step>1, staleness_threshold=0
    2. Synchronous streaming training will be performed. Rollouter produces require_batches*ppo_mini_batch_size*trigger_parameter_sync_step samples at once, Trainer performs a local training every time it fetches require_batches*ppo_mini_batch_size samples, and after training trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization;
    3. Compared to a, since more samples are generated at once, resource idleness will be lower.
    4. Within each synchronization interval there are two idle periods: when fetching the first batch, the Trainer waits for require_batches*ppo_mini_batch_size samples to be produced, and during the last local update before synchronization, the Rollouter waits for training to complete.
    5. As shown in figure b;
  3. async stream pipeline with stale samples:

    1. trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False
    2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number of samples generated may be less than this value depending on rollout speed).
    3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples before parameter synchronization for immediate use by Trainer after synchronization. When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete and not add new tasks;
    4. Compared to b, every training step except the first no longer waits for the first batch of rollouts to finish, but it still waits for active tasks to finish at synchronization.
    5. As shown in figure c;
  4. async stream pipeline with partial rollout:

    1. trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True
    2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be generated after synchronization. This reduces the time to wait for active tasks to finish.
    3. As shown in figure d;

[Figure: fully_async_policy modes a to d]
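The four modes are determined entirely by three parameters. The small helper below is hypothetical (not part of verl); it maps a configuration to the mode names used above and makes the decision rules explicit, including the rule that partial_rollout only takes effect when staleness_threshold>0.

```python
def classify_mode(trigger_parameter_sync_step: int, staleness_threshold: float,
                  partial_rollout: bool) -> str:
    """Map fully async settings to the mode names a to d used in this section."""
    if trigger_parameter_sync_step < 1:
        raise ValueError("trigger_parameter_sync_step must be >= 1")
    if staleness_threshold < 0:
        raise ValueError("staleness_threshold must be >= 0")
    if staleness_threshold == 0:
        # Synchronous modes; partial_rollout has no effect here.
        if trigger_parameter_sync_step == 1:
            return "a: on policy pipeline"
        return "b: stream off policy pipeline"
    if partial_rollout:
        return "d: async stream pipeline with partial rollout"
    return "c: async stream pipeline with stale samples"


if __name__ == "__main__":
    print(classify_mode(16, 0, False))  # b: stream off policy pipeline
    print(classify_mode(4, 0.5, True))  # d: async stream pipeline with partial rollout
```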

Key Metrics

| Metric | Meaning |
|---|---|
| trainer/idle_ratio | Trainer idle rate |
| rollouter/idle_ratio | Rollouter idle rate |
| fully_async/count/stale_samples_processed | Total number of stale samples used in training |
| fully_async/count/stale_trajectory_processed | Total number of stale trajectories used in training (one sample produces rollout.n trajectories) |
| fully_async/partial/total_partial_num | Number of partial samples processed by the Trainer between two parameter synchronizations (every trigger_parameter_sync_step local updates) |
| fully_async/partial/partial_ratio | Ratio of partial samples processed by the Trainer between two parameter synchronizations |
| fully_async/partial/max_partial_span | Maximum parameter-version span of partial samples processed by the Trainer between two parameter synchronizations |

Parameter Tuning Recommendations

  • Resource Allocation and Adjustment:

    • Reasonable resource allocation is the prerequisite for achieving good training efficiency. The ideal resource allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire training process, avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource allocation can be adjusted based on the idle time of rollout and train during actual training, which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and trainer/idle_ratio is low, Trainer resources should be increased and Rollouter resources should be reduced, and vice versa.
  • Key Parameters:

    • staleness_threshold: Setting it too high will cause more old samples to be used, affecting model performance. It is recommended to set it to less than 1.
    • require_batches: The closer to 1, the closer to a pure streaming process, the smaller the training bubbles, and the greater the achievable speedup, but it affects the order in which samples are processed.
    • trigger_parameter_sync_step: Smaller values are closer to on-policy but cause frequent parameter synchronization, and long-tail samples then leave idle resources that short samples cannot fill, lowering utilization. Larger values raise computational efficiency, but accuracy is affected by off-policy training.
    • rollout.test_freq: It will occupy Rollouter resources and is not recommended to be set too small.
  • Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at different levels, suitable for tasks in different scenarios.

    • For small-scale tasks that require training stability and on-policy behavior and have low speed requirements, try the on policy pipeline mode (Mode 1).
    • For scenarios that need higher training throughput but are sensitive to staleness, try the stream off policy pipeline mode (Mode 2): set trigger_parameter_sync_step>1 to improve training efficiency while keeping the synchronous mechanism (staleness_threshold=0).
    • For large-scale tasks with high training speed requirements that can tolerate some off-policy training and staleness, set staleness_threshold>0 to use the async stream pipeline (Mode 3), and additionally set partial_rollout=True (Mode 4) to further improve training efficiency.
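The resource-allocation rule of thumb above can be written as a small heuristic. This is a hypothetical sketch, not a verl feature: `step_gpus` (GPUs moved per adjustment, e.g. one 8-GPU node) and `tolerance` are assumed knobs, and real adjustments must also respect constraints such as data being evenly divisible by the number of GPUs.

```python
def suggest_rebalance(rollouter_idle_ratio: float, trainer_idle_ratio: float,
                      trainer_gpus: int, rollout_gpus: int,
                      step_gpus: int = 8, tolerance: float = 0.1) -> tuple[int, int]:
    """Suggest (trainer_gpus, rollout_gpus) from the two idle-ratio metrics.

    step_gpus and tolerance are hypothetical knobs, not verl parameters.
    """
    gap = rollouter_idle_ratio - trainer_idle_ratio
    if gap > tolerance and rollout_gpus > step_gpus:
        # Rollouter mostly waits on the Trainer: shift GPUs to the Trainer.
        return trainer_gpus + step_gpus, rollout_gpus - step_gpus
    if -gap > tolerance and trainer_gpus > step_gpus:
        # Trainer mostly waits on the Rollouter: shift GPUs to the Rollouter.
        return trainer_gpus - step_gpus, rollout_gpus + step_gpus
    # Idle times are roughly balanced: keep the current split.
    return trainer_gpus, rollout_gpus


if __name__ == "__main__":
    print(suggest_rebalance(0.40, 0.05, trainer_gpus=32, rollout_gpus=32))  # (40, 24)
```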

Quick Start

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
    export VLLM_USE_V1=1
    return_raw_chat="True"
fi

train_prompt_bsz=0 # data.train_batch_size is not used in fully async mode
gen_prompt_bsz=1 # streaming: samples are produced one at a time
n_resp_per_prompt=16
train_prompt_mini_bsz=32
use_dynamic_bsz=True # assumed value; the variable is referenced below but was not set
total_rollout_steps=$((512*400)) # data.train_batch_size * step of a 512 x 400 colocate run
test_freq=10
staleness_threshold=0
trigger_parameter_sync_step=16
partial_rollout=False

# NNODES_TRAIN, NNODES_ROLLOUT and NGPUS_PER_NODE are expected to be set in the environment.
python -m verl.experimental.fully_async_policy.fully_async_main \
    data.train_batch_size=${train_prompt_bsz} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.return_raw_chat=${return_raw_chat} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.actor.strategy=fsdp2 \
    critic.strategy=fsdp2 \
    actor_rollout_ref.hybrid_engine=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.name=${rollout_name} \
    actor_rollout_ref.rollout.mode=${rollout_mode} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    trainer.nnodes="${NNODES_TRAIN}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    rollout.nnodes="${NNODES_ROLLOUT}" \
    rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
    rollout.total_rollout_steps="${total_rollout_steps}" \
    rollout.test_freq="${test_freq}" \
    async_training.staleness_threshold="${staleness_threshold}" \
    async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
    async_training.partial_rollout="${partial_rollout}"
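To align a fully async run with a colocate baseline, two rules from the parameter explanations apply: total_rollout_steps = train_batch_size * step, and trigger_parameter_sync_step = train_batch_size / (require_batches * ppo_mini_batch_size). The helpers below are a hedged sketch of that arithmetic with hypothetical names. For the Quick Start values (ppo_mini_batch_size=32, trigger_parameter_sync_step=16), alignment with a 512-prompt colocate batch holds if require_batches is 1; that is an assumption, since the script does not set require_batches.

```python
def aligned_total_rollout_steps(train_batch_size: int, steps: int) -> int:
    """Total rollout samples matching a colocate run of `steps` steps."""
    return train_batch_size * steps


def aligned_trigger_parameter_sync_step(train_batch_size: int, require_batches: int,
                                        ppo_mini_batch_size: int) -> int:
    """Local updates per synchronization that together consume one colocate batch."""
    per_update = require_batches * ppo_mini_batch_size
    if train_batch_size % per_update != 0:
        raise ValueError("train_batch_size must be divisible by "
                         "require_batches * ppo_mini_batch_size")
    return train_batch_size // per_update


if __name__ == "__main__":
    print(aligned_total_rollout_steps(512, 400))  # 204800
    print(aligned_trigger_parameter_sync_step(512, 1, 32))  # 16, as in the Quick Start
    print(aligned_trigger_parameter_sync_step(512, 4, 32))  # 4, as in the 7B experiments
```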

Experiments

Asynchronous Training on 7B Model

We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy with long responses at several resource scales. Using the async stream pipeline with partial rollout strategy, we achieved roughly 2x speedups (1.92x to 2.66x at 400 steps) on 32, 64, and 128 GPUs without significantly affecting experimental results.

  • Machine: H20

  • Model: Qwen2.5-Math-7B

  • Rollout length: max_response_length (FSDP2): 28K tokens

  • Algorithm: DAPO

  • Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet

  • Engine: vLLM + FSDP2

  • rollout.n: 16

  • ppo_mini_batch_size: 32

  • test_freq: 20

  • colocate sync:

    • step: 400
    • train_batch_size: 512
  • fully_async_policy

    • total_rollout_steps: 512*400
    • require_batches: 4
    • trigger_parameter_sync_step: 4
    • staleness_threshold: 0.5
    • partial_rollout: True
| training mode | resource allocation (GPUs) | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time, 100 steps | total time, 200 steps | total time, 300 steps | total time, 400 steps | acc/mean@1 |
|---|---|---|---|---|---|---|---|---|---|---|
| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 269.80 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313, last: 0.2448 |
| fully_async_policy | 16:16 | 294.77 | 21.26 | - | 313.81 | 7h 58m (1.72x) | 16h 21m (1.70x) | 1d 0h 53m (2.31x) | 1d 9h 26m (2.66x) | max: 0.3302, last: 0.2333 |
| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365, last: 0.2333 |
| fully_async_policy | 32:32 | 189.26 | 28.46 | - | 156.98 | 4h 57m (2.09x) | 10h 14m (2.03x) | 16h 58m (1.83x) | 21h 40m (1.92x) | max: 0.3677, last: 0.3406 |
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573, last: 0.2958 |
| fully_async_policy | 64:64 | 150.63 | 33.14 | - | 113.16 | 3h 13m (2.67x) | 6h 46m (2.65x) | 10h 53m (2.67x) | 17h 22m (2.35x) | max: 0.3521, last: 0.3094 |

source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
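The speedup figures in parentheses appear to be the ratio of colocate wall-clock time to fully_async_policy wall-clock time at the same step count. The hedged sketch below reproduces two of them; `to_minutes` is a hypothetical helper that only handles the "Xd Yh Zm" format used in these tables.

```python
import re


def to_minutes(duration: str) -> int:
    """Parse durations such as '1d 9h 26m' or '7h2m' into minutes."""
    units = {"d": 24 * 60, "h": 60, "m": 1}
    parts = re.findall(r"(\d+)\s*([dhm])", duration)
    if not parts:
        raise ValueError(f"unrecognized duration: {duration!r}")
    return sum(int(value) * units[unit] for value, unit in parts)


def speedup(baseline: str, candidate: str) -> float:
    """Wall-clock ratio of the baseline run to the candidate run."""
    return to_minutes(baseline) / to_minutes(candidate)


if __name__ == "__main__":
    print(f"{speedup('3d 17h 5m', '1d 9h 26m'):.2f}x")  # 2.66x (32 GPUs, 400 steps)
    print(f"{speedup('1d 16h 48m', '17h 22m'):.2f}x")  # 2.35x (128 GPUs, 400 steps)
```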

128-card 7B Asynchronous Mode Experiment

We used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async. We can see that the benefit brought by streaming is approximately 1.6x, and after combining staleness and partial_rollout, the benefit reaches 2.35x.

| mode | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time, 100 steps | total time, 200 steps | total time, 300 steps | total time, 400 steps | acc/mean@1 |
|---|---|---|---|---|---|---|---|---|---|
| colocate sync | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573, last: 0.2958 |
| stream off policy pipeline (+fully async: trigger_parameter_sync_step=4, require_batches=4) | 231.34 | 128.47 | - | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844, last: 0.2604 |
| async stream pipeline with stale samples (+staleness_threshold=0.5) | | | | | | | | | |
| async stream pipeline with partial rollout (+partial_rollout=True) | 150.63 | 33.14 | - | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521, last: 0.3094 |

source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg

128-card Stale Ablation Experiment

Under the async stream pipeline with partial rollout mode, we evaluated the impact of staleness_threshold on training efficiency. Larger staleness values yield more pronounced final gains. The total times for staleness values of 0.3 and 0.5 are quite close, because as training progresses the response length changes significantly, causing training instability. This issue needs further analysis and optimization.

| staleness_threshold | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time, 100 steps | total time, 200 steps | total time, 300 steps | total time, 400 steps | acc/mean@1 |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 231.34 | 128.47 | - | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844, last: 0.2604 |
| 0.1 | 171.30 | 58.17 | - | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542, last: 0.2979 |
| 0.3 | 146.11 | 38.88 | - | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469, last: 0.2865 |
| 0.5 | 150.63 | 33.14 | - | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521, last: 0.3094 |

source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg

128-card 7B require_batches Ablation Experiment

In multiple tests, we found that the number of samples dispatched to the Trainer at a time in streaming mode affects the response length during training, which in turn affects training time. We evaluated this by varying async_training.require_batches.

| require_batches | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time, 100 steps | total time, 200 steps | total time, 300 steps | acc/mean@1 |
|---|---|---|---|---|---|---|---|---|
| 1 | 203.47 | 30.88 | - | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349, last: 0.326 |
| 2 | 158.72 | 26.32 | - | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351, last: 0.3406 |
| 4 | 124.64 | 25.62 | - | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521, last: 0.3521 |

source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg

30B Model Mode Experiment

With the async stream pipeline with partial rollout strategy, we achieved a 1.7x speedup on the Qwen3-30B-A3B-Base model compared with the colocate setup. This is far from the upper limit of the gains achievable through asynchrony. First, the comparison used a maximum response length of only 8K, much shorter than the 20K sequence length in the previous experiments, so the rollout long-tail effect is less pronounced. Second, we used a highly skewed resource allocation (96 GPUs for rollout and 32 GPUs for the trainer), which is not an optimal configuration. During the experiments, we observed that the current verl implementation imposes certain constraints, such as requiring data to be evenly divisible by the number of GPUs, which makes resource adjustment less flexible. In addition, as asynchronous training and deployment speed up, the performance gap is gradually narrowing. Enabling more flexible resource allocation and dynamic resource adjustment is therefore our next focus.

  • Machine: H20

  • Model: Qwen3-30B-A3B-Base

  • Rollout length: max_response_length: 8K tokens

  • Algorithm: GRPO

  • Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet

  • Engine: vLLM + Megatron

  • rollout.n: 16

  • ppo_mini_batch_size: 128

  • test_freq: 20

  • colocate sync:

    • step: 400
    • train_batch_size: 512
  • fully_async_policy

    • total_rollout_steps: 512*400
    • trigger_parameter_sync_step: 512/128 = 4
    • staleness_threshold: 0.5
    • partial_rollout: True
| Training Mode | Resource Allocation (GPUs, rollout:trainer) | Step (s) | Gen (s) | Old Log Prob (s) | Ref (s) | Update Actor (s) | Total Time, 100 Steps | Total Time, 200 Steps | Total Time, 300 Steps | Total Time, 400 Steps | Acc/Mean@1 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Colocate Sync | 128 | 497.89 | 348.05 | 28.73 | 20.86 | 86.27 | 13h 36m | 1d 3h 48m | 1d 19h 4m | 2d 11h 39m | max: 0.3500, last: 0.3208 |
| Fully Async Policy | 96:32 | 282.75 | 22.06 | - | 50.05 | 206.63 | 6h 45m (2.01x) | 14h 48m (1.88x) | 1d 0h 9m (1.78x) | 1d 10h 41m (1.72x) | max: 0.3813, last: 0.3448 |

source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-30B?nw=nwuserhouzg

checkpoint-engine Ablation Experiment

We tested the single-step parameter synchronization time of the checkpoint-engine on three models: Qwen2.5-Math-7B, Qwen3-30B-A3B, and Qwen3-235B-A22B, using default checkpoint-engine configurations. All experiments were performed on H20 machines, and the Megatron engine was used for training.

| model | trainer ranks | rollout ranks | checkpoint-engine | total sync time |
|---|---|---|---|---|
| Qwen2.5-Math-7B | 4 | 4 | False | 0.12s |
| Qwen2.5-Math-7B | 4 | 4 | True | 0.02s |
| Qwen3-30B-A3B | 16 | 16 | False | 15.76s |
| Qwen3-30B-A3B | 16 | 16 | True | 4.38s |
| Qwen3-235B-A22B | 64 | 64 | False | 58.57s |
| Qwen3-235B-A22B | 64 | 64 | True | 23.70s |
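The relative savings implied by this table can be computed directly. The sketch below only re-derives ratios from the measured times above (the dictionary and function names are illustrative): the checkpoint-engine saves about 83%, 72%, and 60% (59.5%) of the synchronization time for the 7B, 30B, and 235B models respectively.

```python
# Single-step sync times in seconds from the table above:
# (without checkpoint-engine, with checkpoint-engine).
SYNC_TIMES_S = {
    "Qwen2.5-Math-7B": (0.12, 0.02),
    "Qwen3-30B-A3B": (15.76, 4.38),
    "Qwen3-235B-A22B": (58.57, 23.70),
}


def reduction(baseline_s: float, with_engine_s: float) -> float:
    """Fraction of synchronization time saved by the checkpoint-engine."""
    return 1 - with_engine_s / baseline_s


if __name__ == "__main__":
    for model, (baseline_s, with_engine_s) in SYNC_TIMES_S.items():
        print(f"{model}: {reduction(baseline_s, with_engine_s):.1%} less sync time, "
              f"{baseline_s / with_engine_s:.1f}x faster")
```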

use_trainer_do_validate Experiment

We tested the effect of use_trainer_do_validate=True on training with Qwen2.5-Math-7B. Enabling it reduces validation time overhead and trainer node idle time: validation was about 2x faster (validate time 95.699 vs. 205.080 in the table below), and trainer node idle time was reduced by about 40%.

  • Machine: H20

  • Model: Qwen2.5-Math-7B

  • Rollout length: max_response_length (FSDP2): 10K tokens

  • Algorithm: DAPO

  • Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet

  • Engine: vLLM + FSDP2

  • rollout.n: 16

  • ppo_mini_batch_size: 32

  • test_freq: 10

  • fully_async_policy

    • total_rollout_steps: 512*400
    • require_batches: 4
    • trigger_parameter_sync_step: 4
    • staleness_threshold: 0.5
    • partial_rollout: True
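| training mode | resource allocation (GPUs) | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | validate time (s) | total time, 50 steps | acc/mean@2 |
|---|---|---|---|---|---|---|---|---|
| colocate sync | 16 | 484.623 | 52.939 | 0 | 430.263 | 205.080 | 7h 9m | 22.6 |
| fully_async_policy | 8:8 | 489.953 | 52.622 | 0 | 435.874 | 95.699 | 7h 2m | 21.0 |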

Multi-Turn Tool Calling

Building on recipe/retool and ToolAgentLoop, we implemented AsyncPartialToolAgentLoop, a multi-turn tool-calling loop that supports partial_rollout for fully_async_policy.

Core Design

AsyncPartialToolAgentLoop inherits from ToolAgentLoop and is adapted for the asynchronous training mode of fully_async_policy. When partial_rollout=True, the Rollouter interrupts ongoing generation tasks before synchronizing parameters with the Trainer. AsyncPartialToolAgentLoop is capable of:

  1. Interrupting Tasks: Responding to an interrupt signal by saving the current state. Currently, interruption happens either during the GENERATING state or after any other state has completed.
  2. Resuming Tasks: Resuming execution from the saved state after parameter synchronization is complete, rather than starting over.
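The interrupt-and-resume behavior can be illustrated with a toy asyncio model. This is a hedged sketch, not AsyncPartialToolAgentLoop's actual code: `PartialSample`, `generate`, and the `asyncio.Event` used as the interrupt signal are hypothetical stand-ins, each decoding step appends a dummy token, and `param_versions` merely records which (hypothetical) parameter versions produced the sample's tokens.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class PartialSample:
    """Generation state that survives an interruption (hypothetical structure)."""
    prompt_id: int
    target_len: int
    tokens: list[int] = field(default_factory=list)
    param_versions: set[int] = field(default_factory=set)

    @property
    def done(self) -> bool:
        return len(self.tokens) >= self.target_len


async def generate(sample: PartialSample, interrupt: asyncio.Event, version: int) -> PartialSample:
    """Decode token by token; on interrupt, return early and keep the partial state."""
    while not sample.done:
        if interrupt.is_set():
            return sample  # partial result; resumed after parameter sync
        await asyncio.sleep(0)  # stand-in for one decoding step
        sample.tokens.append(len(sample.tokens))  # dummy token id
        sample.param_versions.add(version)
    return sample


async def demo() -> PartialSample:
    sample = PartialSample(prompt_id=0, target_len=10)
    interrupt = asyncio.Event()
    task = asyncio.create_task(generate(sample, interrupt, version=0))
    await asyncio.sleep(0)  # let generation start
    interrupt.set()  # parameter sync requested: pause in-flight generation
    await task
    # ... parameter synchronization with the Trainer would happen here ...
    interrupt.clear()
    # Resume from the saved state with the new weights instead of starting over.
    return await generate(sample, interrupt, version=1)


if __name__ == "__main__":
    result = asyncio.run(demo())
    print(len(result.tokens), sorted(result.param_versions))
```

The key design point mirrored here is that the interrupted sample keeps its tokens, so resuming only generates the remainder rather than starting over.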

How to Use

RL training with multi-turn tool calling in fully_async_policy is similar to recipe/retool. It is enabled by specifying multi_turn configurations in the config file.

  1. SFT Stage: First, the model should undergo SFT to learn how to follow tool-calling format instructions.
  2. Multi-turn Configuration: In the fully_async_policy training configuration, set the following parameters:
    actor_rollout_ref:
      rollout:
        multi_turn:
          enable: True # AsyncPartialToolAgentLoop will be used by default in fully_async_policy mode
          # Other multi_turn related configurations
    
  3. Async Parameters: To improve efficiency, enable partial_rollout and staleness_threshold when using multi-turn tool calling:
    async_training:
      partial_rollout: True
      staleness_threshold: 0.5
      # Other async parameters
    
  4. Example: See recipe/fully_async_policy/shell/dapo_7b_async_retool.sh.

Experimental Results

To validate the performance of fully_async_policy on multi-turn tool-calling tasks, we compared it with the standard colocate synchronous mode. Key parameter settings are as follows.

  • SFT Model: Based on Qwen2.5-7B-Instruct, trained for 6 epochs on the ReTool-SFT dataset
  • RL Algorithm: DAPO
  • Dataset:
    • Train: DAPO-Math-17k
    • Test: aime_2025
  • Resource and Mode Comparison:
    • colocate sync: 32 H20 GPUs
    • fully_async_policy: 16 GPUs for Trainer + 16 GPUs for Rollouter
  • Key Configurations:
    1. Tool Calling Configuration:
      • multi_turn.enable: True
      • multi_turn.max_user_turns: 16
      • multi_turn.max_assistant_turns: 16
      • multi_turn.tool_config_path: recipe/retool/sandbox_fusion_tool_config.yaml
    2. colocate sync Configuration:
      • ppo_mini_batch_size: 16
      • train_batch_size: 64
    3. fully_async_policy Configuration:
      • ppo_mini_batch_size: 16
      • trigger_parameter_sync_step: 4
      • require_batches: 1
      • staleness_threshold: 1
      • partial_rollout: True
| training mode | resource allocation (GPUs) | step (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time, 100 steps | total time, 200 steps | aime_2025 acc/mean@30 |
|---|---|---|---|---|---|---|---|---|
| colocate | 32 | 375.47 | 228.03 | 35.19 | 111.84 | 9h 46m | 22h 28m | start: 0.1078, last: 0.2056 |
| fully_async_policy | 16:16 | 221.36 | 40.59 | - | 179.58 | 6h 19m (1.55x) | 14h 4m (1.60x) | start: 0.11, last: 0.2044 |

source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-multiturn-tool?nw=nwuserhouzg

Future Plans

  • Transfer queue integration
  • Asynchronous parameter synchronization