
verl 0.7 release blog

Author: verl team

Last updated: 01/03/2026.

Overview

verl adopts a Hybrid-Controller architecture (also known as HybridFlow). Sharing design principles with asynchronous sharded dataflow systems like Google Pathways, verl models Reinforcement Learning (RL) algorithms such as PPO, GRPO, and DAPO as a multi-stage, multi-model, parallelizable dataflow graph.

To balance flexibility with performance, verl unifies two distinct programming models:

High-Level Single-Controller (MPMD): At the orchestration level, a single-process RLTrainer manages the global computation graph. It handles macro-tasks such as scheduling rollout generation, triggering reward scoring, and dispatching distributed training jobs.

Internal Multi-Controller (SPMD): Internally, the Model Engine operates in standard distributed-training mode. Workers execute identical programs, either via trainer backends such as FSDP, Megatron, or VeOmni, or via rollout executors (as opposed to rollout servers) such as vLLM, SGLang, or TensorRT-LLM, performing heavy distributed computation and synchronizing through collective communication.

hybridflow.png

This hybrid approach offers significant advantages:

Flexible Orchestration: The single-controller design allows verl to dynamically manage complex constraints within the computation graph, including flexible data dependencies, diverse resource allocation and model placement, and fine-grained asynchronous staleness control.

Abstraction of Complexity: We encapsulate complex parallel strategies—such as 5D parallelism (DP, TP, CP, PP, and EP)—strictly within the Model Engine. This allows users to focus entirely on RL algorithm implementation without getting bogged down by the details of distributed training.

Furthermore, leveraging Ray placement groups, verl provides ResourcePool and WorkerGroup abstractions. These enable flexible GPU sharing among the various roles in the RL process—such as actor, critic, reward, and rollout—allowing components to share resources efficiently while remaining isolated.
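
The dispatch/collect behavior of a WorkerGroup can be illustrated with a dependency-free toy sketch. The names here (ToyWorker, ToyWorkerGroup) are hypothetical, not verl's actual API, and real WorkerGroups run over Ray placement groups rather than in-process objects:

```python
# Toy sketch of the single-controller / worker-group dispatch pattern.
# ToyWorker and ToyWorkerGroup are hypothetical names for illustration only.

class ToyWorker:
    """SPMD worker: runs the same program on its own shard of the data."""
    def compute_log_prob(self, shard):
        return [x * 2 for x in shard]  # stand-in for heavy computation

class ToyWorkerGroup:
    """Single-controller view: dispatches shards to workers, collects results."""
    def __init__(self, num_workers):
        self.workers = [ToyWorker() for _ in range(num_workers)]

    def compute_log_prob(self, batch):
        n = len(self.workers)
        shards = [batch[i::n] for i in range(n)]  # data-parallel dispatch
        outs = [w.compute_log_prob(s) for w, s in zip(self.workers, shards)]
        # collect: interleave the sharded outputs back into the original order
        result = [None] * len(batch)
        for i, out in enumerate(outs):
            result[i::n] = out
        return result

wg = ToyWorkerGroup(num_workers=4)
print(wg.compute_log_prob(list(range(8))))  # [0, 2, 4, 6, 8, 10, 12, 14]
```

The caller sees a single method call on the group; the sharding and gathering are invisible, which is exactly the property the single-controller design relies on.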

As illustrated in the diagram below, the overall architecture of verl is divided into two layers:

  • verl-core: provides four components required for the RL pipeline: model engine, rollout engine, checkpoint engine, and transfer queue. Each component exposes abstract interfaces, making them both extensible and pluggable.
  • verl-trainer: builds upon these components to construct various RL pipelines—such as on-policy, one-step-off-policy, and fully asynchronous—tailored to the demands of diverse scenarios.
verl-arch.png

verl-core

Model Engine

The Model Engine serves as verl's core training engine, defining a set of abstract interfaces that support pluggable backends. It operates in SPMD mode:

  • SFT: Workers are launched via torchrun.
  • RL: Workers are executed via the WorkerGroup API, invoked by the single-controller.

The abstract interfaces include methods like initialize, forward, optimizer_step, and load/offload. Integrating a new training engine simply requires inheriting and implementing these interfaces. Crucially, because all backends adhere to this unified abstraction, adding a new Model Engine requires absolutely no code modification on the caller side. The RLTrainer remains completely agnostic to the backend's specific parallel strategy when calling these interfaces, while the WorkerGroup automatically handles data dispatch and collection based on the underlying parallelism.

Currently, the Model Engine supports the following backends (more backends may be supported in the future, e.g., torchtitan):

Backend | Parallelism    | Performance             | Supported Models                             | New Model Support Time
FSDP    | FSDP+SP        | Dense: medium, MoE: low | all transformer models                       | Day 0
MCore   | DP+TP+PP+EP+CP | High                    | see the Megatron-Bridge supported model list | a few weeks to a month
VeOmni  | FSDP+SP+EP     | Medium                  | see the VeOmni supported model list          | ~1 week

class BaseEngine:
    def initialize(self):
        """Instantiate or load the model, optimizer, and learning rate scheduler."""
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """Zero the gradients of the optimizer."""
        raise NotImplementedError

    def optimizer_step(self):
        """Perform an optimization step using the optimizer."""
        raise NotImplementedError

    def lr_scheduler_step(self):
        """Advance the learning rate scheduler by one step."""
        raise NotImplementedError

    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False) -> Any:
        """Perform a forward pass and optionally a backward pass on a batch of data."""
        raise NotImplementedError

    def get_per_tensor_param(self) -> tuple[Generator[tuple[str, torch.Tensor], None, None], Optional[dict]]:
        """Get a generator that yields per-tensor parameters and optional peft config."""
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """Move model parameters, optimizer states, or both to the specified device."""
        raise NotImplementedError
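
To make the pluggability concrete, here is a minimal, dependency-free sketch of what integrating a new backend looks like: subclass the abstract engine and fill in the lifecycle methods. ToyEngine is purely illustrative (a real backend wraps FSDP/Megatron/VeOmni models and PyTorch optimizers), and the BaseEngine below is a trimmed copy of the interface above:

```python
# Hedged sketch: ToyEngine is a hypothetical backend whose "model" is a
# single scalar weight, showing the shape of an engine implementation.

class BaseEngine:
    """Trimmed copy of the abstract interface, for self-containment."""
    def initialize(self):
        raise NotImplementedError
    def optimizer_zero_grad(self):
        raise NotImplementedError
    def optimizer_step(self):
        raise NotImplementedError
    def forward_backward_batch(self, data, loss_function, forward_only=False):
        raise NotImplementedError

class ToyEngine(BaseEngine):
    def initialize(self):
        self.w = 1.0      # "model": one scalar weight
        self.grad = 0.0
        self.lr = 0.1

    def optimizer_zero_grad(self):
        self.grad = 0.0

    def forward_backward_batch(self, data, loss_function, forward_only=False):
        preds = [self.w * x for x in data]
        loss = loss_function(preds)
        if not forward_only:
            # d mean(w * x) / dw = mean(x), computed analytically here
            self.grad = sum(data) / len(data)
        return loss

    def optimizer_step(self):
        self.w -= self.lr * self.grad

engine = ToyEngine()
engine.initialize()
engine.optimizer_zero_grad()
loss = engine.forward_backward_batch([1.0, 3.0], lambda p: sum(p) / len(p))
engine.optimizer_step()
print(loss, engine.w)  # 2.0 0.8
```

The point of the exercise: caller-side code only ever touches the BaseEngine methods, so swapping ToyEngine for a real FSDP or Megatron engine leaves the trainer untouched.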

Rollout Engine

As LLM reinforcement learning evolves from single-turn, static tasks to multi-turn, dynamic, and interactive agentic tasks, the legacy SPMD rollout mode previously used by verl has become insufficient. Consequently, in verl v0.7, we have removed the SPMD rollout mode and switched to rollout server mode by default.

rollout_engine.png

In server mode, the LLM server performs online serving rather than traditional offline batch inference. Clients send per-sample requests to the server, enabling the engine to use dynamic batching, which significantly improves throughput for multi-turn conversations. Furthermore, the server-based approach eliminates the need for intrusive modifications to the LLM inference engine, allowing seamless integration of modern inference backends such as vLLM, SGLang, and TensorRT-LLM.
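
As a toy illustration of why per-sample requests enable dynamic batching (this is not vLLM or SGLang internals, just the queueing idea): the server processes whatever requests are currently pending as one batch, instead of waiting for a fixed batch to assemble, so fast and slow clients no longer lock-step:

```python
import asyncio

# Hedged sketch: clients submit individual requests; the server drains the
# queue and batches whatever is pending. All names are illustrative.

async def server(queue, results):
    done = 0
    while done < 4:
        batch = [await queue.get()]
        while not queue.empty():
            batch.append(queue.get_nowait())   # drain pending requests
        for req in batch:
            results.append((req, len(batch)))  # record the batch size seen
        done += len(batch)

async def client(queue, i, delay):
    await asyncio.sleep(delay)  # clients arrive at different times
    await queue.put(i)

async def main():
    queue, results = asyncio.Queue(), []
    await asyncio.gather(
        server(queue, results),
        *(client(queue, i, 0.01 * (i % 2)) for i in range(4)),
    )
    return results

print(asyncio.run(main()))
```

Each request is served exactly once, but the batch size varies with arrival timing, which is the property dynamic batching exploits.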

On the client side, verl introduces an extensible AgentLoop abstraction designed to define custom agentic task loops. This abstraction manages the cycle of requesting responses from the LLM server and interacting with external environments to obtain feedback. We provide two default implementations:

  • SingleTurnAgentLoop: Designed for standard single-turn tasks.
  • ToolAgentLoop: Designed for classic ReAct architectures involving multi-turn tool invocation.

Users can implement custom AgentLoop logic tailored to their specific needs, such as SWEAgentLoop or GUIAgentLoop.

class AgentLoopBase(ABC):
    @abstractmethod
    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Run agent loop to interact with LLM server and environment.

        Args:
            sampling_params (Dict[str, Any]): LLM sampling params.
            **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.

        Returns:
            AgentLoopOutput: Agent loop output.
        """
        raise NotImplementedError
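
A custom loop might look like the following hedged sketch, where fake_llm_server and EchoToolAgentLoop are hypothetical stand-ins for the rollout server and a real tool-calling loop:

```python
import asyncio

# Hedged sketch of a custom multi-turn loop in the AgentLoop style.
# A real implementation would call the rollout server and actual tools;
# here a stub "server" asks for a tool call once, then finishes.

async def fake_llm_server(prompt: str) -> str:
    """Stand-in for a per-sample request to the LLM server."""
    await asyncio.sleep(0)  # simulate the network round-trip
    return "FINAL" if "tool_result" in prompt else "CALL_TOOL"

class EchoToolAgentLoop:
    async def run(self, sampling_params, **kwargs):
        prompt = kwargs["prompt"]
        turns = 0
        while turns < 4:  # cap the number of turns
            response = await fake_llm_server(prompt)
            turns += 1
            if response == "CALL_TOOL":
                prompt += " tool_result=42"  # environment feedback
            else:
                return {"response": response, "num_turns": turns}
        return {"response": response, "num_turns": turns}

out = asyncio.run(EchoToolAgentLoop().run({}, prompt="solve"))
print(out)  # {'response': 'FINAL', 'num_turns': 2}
```

The same request/feedback cycle generalizes to SWE or GUI agents; only the environment-interaction step changes.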

TransferQueue

As mentioned, verl uses a global single-controller RLTrainer to orchestrate the computation graph. A major limitation of the current implementation is that the RLTrainer handles both control flow and data flow, creating a bottleneck when dispatching data between components. This issue is amplified by the massive data volumes in multimodal training (images, video, audio) and by complex algorithms like router replay, which require transmitting large tensors per sample. Our earlier attempt to solve this with the Ray object store yielded poor performance due to the lack of tensor optimization and fine-grained column access.

transfer_queue.png

In v0.7, we experimentally introduced TransferQueue to decouple control flow from data flow. The RLTrainer now dispatches only instructions and metadata, while TransferQueue handles data transmission via reference passing. TransferQueue is specifically optimized for PyTorch tensors (supporting zero-copy and RDMA) and allows for backend extensions such as ZeroMQ, NIXL, and Ray RDT. We plan to make this the default transmission method in v0.8.

# In PPOTrainer
def fit(self):
    batch = next(dataloader)
    gen_batch: BatchMeta = self.rollout_manager.generate_sequences(batch)
    output: BatchMeta = self.actor_rollout_wg.compute_log_prob(gen_batch)
    gen_batch = gen_batch.union(output)
    output = self.actor_rollout_wg.update_actor(gen_batch)

# In Worker
def compute_log_prob(self, batch: BatchMeta) -> BatchMeta:
    data = tq.get(batch)
    output = self.actor.infer_batch(data=data)
    return tq.put(output)
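
The reference-passing idea can be sketched with a toy in-memory queue. ToyTransferQueue is hypothetical; the real TransferQueue adds zero-copy tensor serialization and RDMA-capable backends, but the metadata/payload split is the same:

```python
import itertools

# Hedged toy sketch: the controller moves only lightweight metadata
# (BatchMeta-like references), while payloads stay in the queue's store.

class ToyTransferQueue:
    _ids = itertools.count()

    def __init__(self):
        self.store = {}

    def put(self, data) -> dict:
        """Store the payload; return lightweight metadata (a reference)."""
        ref = {"id": next(self._ids), "columns": sorted(data)}
        self.store[ref["id"]] = data
        return ref

    def get(self, ref: dict, columns=None):
        """Resolve metadata back to (a column subset of) the payload."""
        data = self.store[ref["id"]]
        cols = columns or ref["columns"]
        return {c: data[c] for c in cols}

tq = ToyTransferQueue()
meta = tq.put({"responses": [1, 2, 3], "log_probs": [0.1, 0.2, 0.3]})
print(meta["columns"])                      # ['log_probs', 'responses']
print(tq.get(meta, columns=["responses"]))  # {'responses': [1, 2, 3]}
```

Note that column-level access is what the Ray object store lacked: a worker that only needs `responses` never pulls the other tensors.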

Checkpoint Engine

With the increase in LLM context lengths and the evolution of agentic tasks, the "long-tail" problem in rollout has become prominent, limiting the overall efficiency of RL training.

To mitigate this, a viable strategy is moving from on-policy synchronous training to off-policy asynchronous training, e.g., Laminar, AReaL, StreamRL, LlamaRL, PipelineRL. This involves separating the rollout and model engines onto different nodes (a disaggregated architecture, as opposed to a colocated one), with data transmitted via queues. This separation alleviates the rollout long-tail issue and enables elastic scaling of rollout, fault tolerance, and heterogeneous hardware. However, it introduces a new challenge: efficient cross-node parameter synchronization.

checkpoint_engine.png

To address this, we introduce the Checkpoint Engine: a unified abstraction layer designed to synchronize weights between various training and inference backends.

  • It provides three unified APIs to implement the streaming transmission of parameters.
  • Users can extend the Transport Layer implementation based on their specific infrastructure requirements (device, network, local cache, etc.).

Currently, we provide two transport backends: NCCL (collective broadcast communication) and NIXL (point-to-point, P2P, communication).

class CheckpointEngine(ABC):
    @abstractmethod
    async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
        """Send the weights of the model.

        Args:
            weights: A generator that yields the name of the weight tensor and the tensor itself.
        """
        raise NotImplementedError

    @abstractmethod
    async def receive_weights(self) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Receive the weights of the model.

        Yields:
            A tuple of the name of the weight tensor and the tensor itself.
        """
        raise NotImplementedError
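
A minimal in-memory transport backend satisfying this streaming contract might look like the following sketch. QueueCheckpointEngine is hypothetical: production backends stream over NCCL broadcast or NIXL P2P rather than an asyncio.Queue, but the generator-in, generator-out shape is the same:

```python
import asyncio

# Hedged sketch: an in-memory "transport layer" implementing the streaming
# send/receive weight-synchronization contract, for illustration only.

class QueueCheckpointEngine:
    def __init__(self):
        self.queue = asyncio.Queue()

    async def send_weights(self, weights):
        # Stream tensors one at a time; a sentinel marks end-of-stream,
        # so neither side ever holds the full checkpoint in memory.
        for name, tensor in weights:
            await self.queue.put((name, tensor))
        await self.queue.put(None)

    async def receive_weights(self):
        while (item := await self.queue.get()) is not None:
            yield item

async def main():
    engine = QueueCheckpointEngine()
    fake_weights = iter([("layer.0.w", [1, 2]), ("layer.1.w", [3, 4])])
    send = asyncio.create_task(engine.send_weights(fake_weights))
    received = [name async for name, _ in engine.receive_weights()]
    await send
    return received

print(asyncio.run(main()))  # ['layer.0.w', 'layer.1.w']
```

On the training side, the generator fed to send_weights would come from the Model Engine's get_per_tensor_param, which is what makes the two abstractions compose.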

verl-trainer

Building upon the four core components provided by verl-core, verl-trainer constructs several RL training pipelines tailored to specific scenarios. These pipelines are designed to address training efficiency challenges across varying scales and requirements:

On-policy (Synchronous)

  • Main Features: Executes rollout and training serially, typically sharing GPU resources (Colocate). It strictly adheres to standard on-policy algorithm definitions, where training must wait for all samples to be generated.
  • Scenarios: Best for baseline implementations, scenarios where strict algorithmic correctness is prioritized over training throughput.

One-step-off-policy (Async)

  • Main Features: Parallelizes generation and training by overlapping the current training step with the next batch's generation. It employs resource isolation and uses parameters from the previous step for rollout to minimize GPU idle time.
  • Scenarios: Ideal for scenarios requiring moderate efficiency gains (20%–40%) while maintaining training stability very close to strict on-policy methods.

Fully async (Decoupled & Streaming)

  • Main Features: Completely decouples the Trainer and Rollouter onto separate nodes. It utilizes streaming data transfer, staleness control, and partial rollout mechanisms to maximize throughput and mitigate long-tail generation latency.
  • Scenarios: Essential for large-scale training (e.g., 128+ GPUs) or complex reasoning tasks (e.g., long chain-of-thought) where generation latency significantly bottlenecks performance.
fully_async.png
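
The fully async pattern (streaming transfer plus staleness control) can be sketched with plain asyncio. All names here are illustrative, not verl's API; the real Rollouter and Trainer are separate processes on separate nodes:

```python
import asyncio

# Hedged sketch: a Rollouter streams samples into a queue while the Trainer
# consumes them, discarding samples whose policy version is too stale.

MAX_STALENESS = 1  # train only on samples at most 1 policy version old

async def rollouter(queue, policy_version):
    for i in range(6):
        await asyncio.sleep(0)  # simulate generation latency
        await queue.put({"sample": i, "version": policy_version[0]})
    await queue.put(None)  # end-of-stream sentinel

async def trainer(queue, policy_version):
    trained, dropped = 0, 0
    while (item := await queue.get()) is not None:
        if policy_version[0] - item["version"] > MAX_STALENESS:
            dropped += 1          # too stale: discard (or partial-rollout)
            continue
        trained += 1
        if trained % 2 == 0:      # "update" the policy every 2 samples
            policy_version[0] += 1
    return trained, dropped

async def main():
    queue, version = asyncio.Queue(), [0]
    _, result = await asyncio.gather(
        rollouter(queue, version), trainer(queue, version)
    )
    return result

trained, dropped = asyncio.run(main())
print(trained, dropped)
```

Every generated sample is either trained on or dropped; tightening MAX_STALENESS trades throughput for closeness to on-policy behavior, which is exactly the dial the fully async pipeline exposes.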

Roadmap

v0.7 release

Model Engine

  • Integrate Megatron-Bridge and support LoRA/PEFT, see blog post: How We Build Trillion Parameter Reasoning RL with 10% GPUs
  • Support experimental fp8 training for the Megatron backend
  • Support new models for the Megatron backend: GPT-OSS, Qwen3-Next
  • Comprehensive support for the new model engine; the FSDP and Megatron engines are production-ready.
    • Dispatch TensorDict with nested tensors instead of padded DataProto
    • Add a TrainingWorker that provides a Tinker-like API
    • Add VLM support to the model engine, SFT, and RL trainer
    • Add a model-engine-based critic model
    • Implement ActorRolloutRefWorker on top of TrainingWorker, supporting different backends in one worker
  • Add a new VeOmni engine, still in alpha status.

Rollout Engine

  • Remove the SPMD rollout mode
  • Support blockwise fp8 rollout for vLLM and SGLang; support online quantization for vLLM with torchao
  • Experimental router replay support for vLLM
  • Optimize multimodal data fetching and preprocessing; support video input
  • Upgrade to vllm==0.12.0 and sglang==0.5.6

Reward

  • Support hybrid reward scenarios, including generative, discriminative, rule-based rewards, and their combinations.
  • Refactor reward models into server mode, supporting both colocated and standalone deployments.
  • Introduce new reward managers to handle more complex scenarios: a limited mode for request-rate control and a remote mode for CPU-intensive tasks.

Algorithm

  • Add CISPO: Clipped IS-weight Policy Optimization
  • Add SAPO: Soft Adaptive Policy Optimization

Recipe

  • [NEW] VLA: add experimental support for VLA models
  • [NEW] rhymerl: History Rhymes: Accelerating LLM Reinforcement Learning with RhymeRL
  • TransferQueue: support multiple data partition and optimize tensor zero-copy serialization
  • One-step-off-policy/Fully async: optimize weight synchronization by checkpoint engine with bucket and pipeline support.

v0.8

Model Engine

  • Deprecate DataProto in favor of TensorDict for zero-padding transmission
  • Switch the default to the new model engine; mark the legacy engines (fsdp_workers.py, megatron_workers.py) as deprecated
  • Feature parity between the new and legacy model engines: LoRA/PEFT, etc.
  • Polish the VeOmni engine to production-ready status
  • Support MTP RL training
  • Optimize GPU memory for long context: fine-grained activation recomputation/offload
  • New model support: DeepSeek V3.2, etc.

Rollout Engine

  • New rollout engine: TensorRT-LLM
  • Separate the vLLM worker from the trainer process; update weights via CUDA IPC

TransferQueue

  • Merge TransferQueue recipe into main
  • Optimize e2e image/video vlm training pipeline by TransferQueue
  • Optimize router replay transmission by TransferQueue

Checkpoint Engine

  • Add checkpoint engine abstract interface
  • Add NCCL and NIXL transport backend
  • Add more transport backend

v0.9

Trainer

  • Merge the fully async pipeline into main: refactor with verl-core components

Model Engine

  • Remove legacy model engine (fsdp_workers.py, megatron_workers.py)
  • Support omni-model RL training: Qwen3-Omni, BAGEL, etc

Rollout Engine

  • New rollout engine vllm-omni

More agentic training recipe

  • SWEAgent
  • GUIAgent