# Asynchronous GRPO

> [!IMPORTANT]
> This trainer requires `vllm>=0.17.1` and `transformers>=5.2.0`. For distributed training, only FSDP2 is supported (DeepSpeed ZeRO is not).
>
> Currently, `vllm` and `transformers` have conflicting dependency constraints. To work around this, install vLLM first and then force-install transformers:
>
> ```bash
> pip install 'vllm>=0.17.1'
> pip install 'transformers>=5.2.0' --no-deps
> ```
## Overview

[`AsyncGRPOTrainer`] implements the same [GRPO](grpo_trainer) algorithm as [`GRPOTrainer`] but decouples rollout generation from training. A background worker continuously streams completions from a vLLM server while the training loop consumes them, so generation and gradient updates overlap instead of alternating. The API mirrors [`GRPOTrainer`]. For full details on the GRPO method itself (advantage computation, KL estimation, loss formulation, reward functions, etc.), see the [GRPO Trainer](grpo_trainer) documentation. Not all features from [`GRPOTrainer`] are available; refer to [`AsyncGRPOConfig`] for the supported parameters.

This trainer was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Amine Dirhoussi](https://huggingface.co/aminediroHF).
## How it differs from [`GRPOTrainer`]

In the standard [`GRPOTrainer`], generation and training are sequential: generate a batch, compute the loss, update weights, repeat. Even in [vLLM colocate mode](grpo_trainer#speed-up-training-with-vllm), where generation runs on the same GPUs, one phase must finish before the other begins.

[`AsyncGRPOTrainer`] separates these two concerns:

- **Rollout worker** (background thread): sends prompts to a vLLM server, scores completions with reward functions, computes advantages, and pushes ready-to-train samples into a queue.
- **Training loop** (main process): pulls samples from the queue, computes the clipped surrogate loss, and updates the model weights.

After every `weight_sync_steps` training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy.
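The split described above is a producer/consumer pattern. The following toy sketch is not the trainer's implementation: generation, reward scoring, the optimizer step, and the weight transfer are all placeholders, one sample counts as one training step, and the names `rollout_worker`, `training_loop`, and `SENTINEL` are made up for illustration. It only shows how a background thread can fill a queue while the main loop drains it and triggers a sync every `weight_sync_steps` steps.

```python
import queue
import threading

SENTINEL = None  # marks the end of the rollout stream in this toy example


def rollout_worker(prompts, sample_queue):
    """Background thread: stands in for generation, scoring, and advantage computation."""
    for prompt in prompts:
        completion = prompt[::-1]        # placeholder for a request to the vLLM server
        reward = float(len(completion))  # placeholder for a reward function
        sample_queue.put({"prompt": prompt, "completion": completion, "reward": reward})
    sample_queue.put(SENTINEL)


def training_loop(prompts, weight_sync_steps=2):
    """Main loop: consumes samples while the worker keeps producing them."""
    sample_queue = queue.Queue(maxsize=4)  # bounded, so the worker cannot run far ahead
    worker = threading.Thread(target=rollout_worker, args=(prompts, sample_queue), daemon=True)
    worker.start()

    step, sync_steps = 0, []
    while True:
        sample = sample_queue.get()
        if sample is SENTINEL:
            break
        step += 1                        # placeholder for loss computation + optimizer step
        if step % weight_sync_steps == 0:
            sync_steps.append(step)      # placeholder for the weight transfer to the server
    worker.join()
    return step, sync_steps


steps, syncs = training_loop(["p1", "p2", "p3", "p4", "p5"])
print(steps, syncs)
```

With five prompts and `weight_sync_steps=2`, the loop runs five steps and records a sync after steps 2 and 4.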
Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The `max_staleness` parameter controls how many weight updates a sample can lag behind before being discarded.
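As a rough illustration of the staleness rule, the sketch below tags each sample with the policy version that generated it and drops samples that lag too far behind. The `Sample` class, the `policy_version` field, and the inclusive boundary (`<=`) are assumptions made for this example; the trainer's actual bookkeeping may differ.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    completion: str
    policy_version: int  # assumed: number of weight updates applied when this sample was generated


def drop_stale(samples, current_version, max_staleness):
    """Keep only samples that lag at most `max_staleness` weight updates behind."""
    return [s for s in samples if current_version - s.policy_version <= max_staleness]


samples = [Sample("a", 3), Sample("b", 5), Sample("c", 7)]
fresh = drop_stale(samples, current_version=7, max_staleness=2)
print([s.completion for s in fresh])
```

Here the sample generated at version 3 lags four updates behind version 7 and is dropped, while the other two are kept.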
The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes`, which is the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful because the excess samples will be discarded.
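The default can be worked out by hand. The helper below is a hypothetical illustration of the product stated above, not a function exposed by the library, and the input values are made up.

```python
def default_max_inflight_tasks(
    max_staleness: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_processes: int,
) -> int:
    """Number of samples the trainer can consume before they become stale."""
    samples_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_processes
    return max_staleness * samples_per_step


# Example values: 2 training processes, batch size 8 per device, 2 accumulation steps, staleness 4.
print(default_max_inflight_tasks(4, 8, 2, 2))  # 4 * (8 * 2 * 2) = 128
```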
## Quick start

```python
# train_async_grpo.py
from datasets import load_dataset
from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen3-4B",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
```
The vLLM server and the trainer must run on **separate GPUs**. Use `CUDA_VISIBLE_DEVICES` to partition your GPUs. For example, with 2 GPUs, you can run the vLLM server on GPU 0 and the trainer on GPU 1 as follows:

```bash
# Terminal 1: vLLM server on GPU 0 (dev mode + NCCL weight transfer are required)
CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
    --max-model-len 4096 \
    --logprobs-mode processed_logprobs \
    --weight-transfer-config '{"backend":"nccl"}'
```
> [!TIP]
> Set `--max-model-len` to the maximum total sequence length (prompt + completion) you expect. A lower value reduces GPU memory usage on the server, freeing more memory for the KV cache and increasing throughput. A good starting point is the prompt length plus `max_completion_length` from your config.
```bash
# Terminal 2: training on GPU 1
CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py
```
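To pick a value for `--max-model-len` following the tip above, one option is to take the longest prompt you expect (in tokens) and add `max_completion_length`. The helper below is a hypothetical sketch; it assumes you have already measured prompt lengths with the model's tokenizer, and the numbers are made up.

```python
def suggest_max_model_len(prompt_token_counts, max_completion_length):
    """Longest expected prompt plus the completion budget, in tokens."""
    return max(prompt_token_counts) + max_completion_length


# Example: prompts of up to 512 tokens and a 1024-token completion budget.
print(suggest_max_model_len([120, 512, 300], max_completion_length=1024))  # 512 + 1024 = 1536
```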
## Design philosophy

This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand.
## AsyncGRPOConfig

[[autodoc]] trl.experimental.async_grpo.AsyncGRPOConfig

## AsyncGRPOTrainer

[[autodoc]] trl.experimental.async_grpo.AsyncGRPOTrainer