The Design of ``verl.single_controller``
==============================================
Last updated: 05/21/2025.
**Author:**\ `Wang Zhang <https://github.com/zw0610>`__
Preface
-------
We prepared this document for developers of ``verl``, particularly those
interested in understanding or contributing to the
``verl.single_controller`` module. It is not intended for end users, but
for contributors seeking to understand the architectural rationale and
internal mechanics.
--------------
Origin
------
The ``single_controller`` module originated from a request I received —
to adapt a toy single-process RLHF script into a distributed system with
minimal changes, while maintaining ease of debugging.
Common practice, such as using PyTorch's Distributed Data Parallel
(DDP), typically involves wrapping ``nn.Module`` and launching multiple
processes that execute the same function under different ranks. However,
this approach presents two main limitations in the context of
distributed RLHF:

- Difficulty representing multiple DAGs as required by PPO;
- Difficulty inspecting intermediate tensors during training.
To maintain debuggability, we opted for a different approach — breaking
the training loop into well-defined stages like ``generate_sequences``,
``compute_advantages``, and so on.
We selected `Ray <https://www.ray.io/>`__ as the initial backend for
``verl`` due to its ability to expose Python class methods as RPC
endpoints. However, Ray’s default model only supports **one method call,
one RPC**, while training LLMs typically requires coordination across
multiple processes.
To hide from users the fact that a single method call fans out to
multiple Ray actors, we introduced the following components:
- ``WorkerGroup`` – manages a group of remote workers and provides
a unified interface for multi-process distributed computation;
- ``ResourcePool`` – binds computational resources to worker
processes;
- ``ClassWithInitArgs`` – enables delayed remote instantiation with
specified initialization arguments.
--------------
A Running Example: ``generate_sequences``
-----------------------------------------
To illustrate the design, we walk through how the ``generate_sequences``
method in the ``ActorRolloutRefWorker`` class is registered and invoked
across distributed workers.
--------------
Step 1: Register with a Decorator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The first step is to define the ``generate_sequences`` method and
decorate it with ``@register``, because it will be called from the
driver script.
**Source:**
`fsdp_workers.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/workers/fsdp_workers.py#L528>`__
.. code:: python

   class ActorRolloutRefWorker(Worker):
       ...

       @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
       def generate_sequences(self, prompts: DataProto):
           prompts = prompts.to(torch.cuda.current_device())
           ...
The ``@register`` decorator adds metadata to the ``generate_sequences``
method. Apart from optionally materializing futures passed as arguments,
it does not currently alter the method's behavior; it only attaches
attributes under a magic key (``MAGIC_ATTR``):
**Source:**
`decorator.py <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L411>`__
.. code:: python

   def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
       ...

       def decorator(func):
           @wraps(func)
           def inner(*args, **kwargs):
               if materialize_futures:
                   args, kwargs = _materialize_futures(*args, **kwargs)
               return func(*args, **kwargs)

           attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
           setattr(inner, MAGIC_ATTR, attrs)
           return inner

       return decorator
As the code shows, the values of ``dispatch_mode``, ``execute_mode`` and
``blocking`` are attached to the ``generate_sequences`` method.
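The effect of the decorator can be checked in isolation. The following is a minimal sketch rather than the ``verl`` implementation: ``MAGIC_ATTR`` is given a placeholder value, the modes are plain strings instead of the ``Dispatch`` and ``Execute`` enums, future materialization is left out, and ``ToyWorker`` is a hypothetical class.

```python
from functools import wraps

# Placeholder key for this sketch; the real value lives in verl's decorator.py.
MAGIC_ATTR = "toy_attrs_for_illustration"


def register(dispatch_mode="ALL_TO_ALL", execute_mode="ALL", blocking=True):
    """Toy version of @register: tag the function, leave its behavior alone."""

    def decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            return func(*args, **kwargs)

        attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
        setattr(inner, MAGIC_ATTR, attrs)
        return inner

    return decorator


class ToyWorker:
    @register(dispatch_mode="DP_COMPUTE_PROTO")
    def generate_sequences(self, prompts):
        return [p.upper() for p in prompts]


# The metadata is readable from the class attribute without instantiating the worker.
registered_attrs = getattr(ToyWorker.generate_sequences, MAGIC_ATTR)
# Calling the method still behaves like the undecorated function.
result = ToyWorker().generate_sequences(["a", "b"])
```

Because the metadata sits on the function object, it can be read from the class itself, which is what allows the binding step to run before any worker is used.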
--------------
Step 2: Binding During Initialization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
These attached attributes are extracted and used when
``ActorRolloutRefWorker``, wrapped in a ``RayClassWithInitArgs``, is
passed into a ``RayWorkerGroup``.
**Source:**
`main_generation.py <https://github.com/volcengine/verl/blob/4ae9a0fdab229f75f080e9478807783ed4c97154/verl/trainer/main_generation.py#L82>`__
.. code:: python

   ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
   resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
   wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
During the
`initialization <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L184>`__
of ``RayWorkerGroup``, two key steps occur:
1. Worker instances (Ray actors) are created:
`RayWorkerGroup._init_with_resource_pool <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L211>`__
2. Methods decorated with ``@register`` are bound to ``RayWorkerGroup``:
`RayWorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/ray/base.py#L214>`__
.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/worker_group_init.png?raw=true
   :alt: initialization_and_binding_of_worker_group

   initialization_and_binding_of_worker_group
The binding procedure is the heart of ``verl.single_controller``.
**Key function:**
`WorkerGroup._bind_worker_method <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/worker_group.py#L143>`__
.. code:: python

   def _bind_worker_method(self, user_defined_cls, func_generator):
       ...
       for method_name in dir(user_defined_cls):
           try:
               method = getattr(user_defined_cls, method_name)
               assert callable(method)
           except Exception:
               continue  # Skip properties
           # <<<to be continued 1>>>
When a method has the ``MAGIC_ATTR``, the attributes set by
``@register`` are extracted:
.. code:: python

   # <<<continue 1>>>
   if hasattr(method, MAGIC_ATTR):
       attribute = getattr(method, MAGIC_ATTR)
       dispatch_mode = attribute["dispatch_mode"]
       execute_mode = attribute["execute_mode"]
       blocking = attribute["blocking"]
       # <<<to be continued 2>>>
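The scan-and-extract pattern in the two snippets above can be tried on its own. This is a sketch under stated assumptions: ``MAGIC_ATTR`` has a placeholder value, and ``tag``, ``ToyWorker`` and ``find_registered_methods`` are hypothetical names that do not exist in ``verl``.

```python
MAGIC_ATTR = "toy_attrs_for_illustration"  # placeholder key, not verl's real value


def tag(**attrs):
    """Hypothetical stand-in for @register: attach a metadata dict to a function."""

    def decorator(func):
        setattr(func, MAGIC_ATTR, attrs)
        return func

    return decorator


class ToyWorker:
    @property
    def rank(self):
        return 0

    def helper(self):  # not tagged, so it must not be picked up
        return None

    @tag(dispatch_mode="DP_COMPUTE_PROTO", execute_mode="ALL", blocking=True)
    def generate_sequences(self, prompts):
        return prompts


def find_registered_methods(cls):
    """Return {method_name: attrs} for every callable on cls that carries MAGIC_ATTR."""
    found = {}
    for name in dir(cls):
        member = getattr(cls, name, None)
        if not callable(member):
            continue  # properties and plain data attributes are skipped
        if hasattr(member, MAGIC_ATTR):
            found[name] = getattr(member, MAGIC_ATTR)
    return found


registered = find_registered_methods(ToyWorker)
```

Only ``generate_sequences`` ends up in ``registered``: the property is not callable, and ``helper`` carries no metadata.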
As shown in the flow chart above, these attributes are fed into
``func_generator``. However, ``func_generator`` takes ``method_name``,
``dispatch_fn``, ``collect_fn``, ``execute_fn``, and ``blocking`` rather
than the modes themselves, so we first need to look up the
``dispatch_fn`` and ``collect_fn`` associated with the ``dispatch_mode``
(``DP_COMPUTE_PROTO``) in
`DISPATCH_MODE_FN_REGISTRY <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L387>`__:
.. code:: python

   DISPATCH_MODE_FN_REGISTRY = {
       Dispatch.ONE_TO_ALL: {
           "dispatch_fn": dispatch_one_to_all,
           "collect_fn": collect_all_to_all,
       },
       # ...
       Dispatch.DP_COMPUTE_PROTO: {
           "dispatch_fn": dispatch_dp_compute_data_proto,
           "collect_fn": collect_dp_compute_data_proto,
       },
       # ...
   }
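A registry of this shape is a plain dictionary keyed by mode. The sketch below illustrates the lookup with simplified stand-ins: the toy ``Dispatch`` enum has only two members, the four helper functions take ``(world_size, value)`` or ``(outputs)`` instead of the real ``(worker_group, *args, **kwargs)`` signatures, and ``get_dispatch_and_collect`` is a hypothetical name.

```python
from enum import Enum, auto


class Dispatch(Enum):
    """Toy enum holding only the two modes discussed in this section."""

    ONE_TO_ALL = auto()
    DP_COMPUTE_PROTO = auto()


def replicate(world_size, value):
    return [value] * world_size


def keep_list(outputs):
    return outputs


def split_evenly(world_size, value):
    size = len(value) // world_size  # assumes len(value) is divisible by world_size
    return [value[i * size:(i + 1) * size] for i in range(world_size)]


def concatenate(outputs):
    return [item for chunk in outputs for item in chunk]


TOY_REGISTRY = {
    Dispatch.ONE_TO_ALL: {"dispatch_fn": replicate, "collect_fn": keep_list},
    Dispatch.DP_COMPUTE_PROTO: {"dispatch_fn": split_evenly, "collect_fn": concatenate},
}


def get_dispatch_and_collect(dispatch_mode):
    """Resolve a mode to its (dispatch_fn, collect_fn) pair, failing loudly if unknown."""
    if dispatch_mode not in TOY_REGISTRY:
        raise KeyError(f"unsupported dispatch mode: {dispatch_mode}")
    entry = TOY_REGISTRY[dispatch_mode]
    return entry["dispatch_fn"], entry["collect_fn"]


dispatch_fn, collect_fn = get_dispatch_and_collect(Dispatch.DP_COMPUTE_PROTO)
shards = dispatch_fn(2, [1, 2, 3, 4])
round_trip = collect_fn(shards)
```

Pairing the two functions under one key keeps them consistent: whatever layout ``dispatch_fn`` produces is the layout ``collect_fn`` knows how to undo.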
Similarly, the ``execute_fn`` is selected by ``execute_mode`` and looked
up by name on the worker group:
.. code:: python

   # <<<continue 2>>>
   # get execute_fn_name
   execute_mode = get_predefined_execute_fn(execute_mode=execute_mode)
   wg_execute_fn_name = execute_mode["execute_fn_name"]

   # get execute_fn from string
   try:
       execute_fn = getattr(self, wg_execute_fn_name)
       assert callable(execute_fn), "execute_fn must be callable"
   except Exception:
       print(f"execute_fn {wg_execute_fn_name} is invalid")
       raise
   # <<<to be continued 3>>>
In the ``generate_sequences`` case:

- ``dispatch_mode = Dispatch.DP_COMPUTE_PROTO``
- ``dispatch_fn = dispatch_dp_compute_data_proto``
- ``collect_fn = collect_dp_compute_data_proto``
- ``execute_fn = RayWorkerGroup.execute_all``
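Resolving ``execute_fn`` from a string, as described above, relies on ``getattr`` against the worker group instance. The following sketch assumes local objects in place of Ray actors; ``ToyWorkerGroup``, ``resolve_execute_fn`` and ``Echo`` are hypothetical, and only the method name ``execute_all`` is taken from the text.

```python
class Echo:
    """Hypothetical worker with a single method."""

    def say(self, x):
        return x * 2


class ToyWorkerGroup:
    def __init__(self, workers):
        self.workers = workers

    def execute_all(self, method_name, *args):
        """Call method_name on every worker, passing the i-th slice of each argument."""
        return [
            getattr(worker, method_name)(*(arg[i] for arg in args))
            for i, worker in enumerate(self.workers)
        ]

    def resolve_execute_fn(self, execute_fn_name):
        """Turn a method name into a bound method, rejecting unknown names."""
        execute_fn = getattr(self, execute_fn_name, None)
        if not callable(execute_fn):
            raise AttributeError(f"execute_fn {execute_fn_name} is invalid")
        return execute_fn


group = ToyWorkerGroup([Echo(), Echo()])
execute_fn = group.resolve_execute_fn("execute_all")
outputs = execute_fn("say", [1, 2])  # worker 0 receives 1, worker 1 receives 2
```

Storing the name instead of the function lets the same execute mode resolve to different implementations on different ``WorkerGroup`` subclasses.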
ONE_TO_ALL vs. DP_COMPUTE_PROTO
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``dispatch_mode`` is associated with a ``dispatch_fn`` and a
``collect_fn``. As the name implies, ``dispatch_fn`` processes the input
arguments passed to the ``WorkerGroup`` and generates a batch (list) of
input arguments, one for each worker attached to the ``WorkerGroup``.
``dispatch_fn`` of ``ONE_TO_ALL`` is
`dispatch_one_to_all <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L119>`__,
which just duplicates all the input arguments into N replicas, where N
equals the number of Workers attached to the ``worker_group``:
.. code:: python

   def dispatch_one_to_all(worker_group, *args, **kwargs):
       args = tuple([arg] * worker_group.world_size for arg in args)
       kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
       return args, kwargs
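To see the shape of the result, the function can be run against a stand-in object. In this sketch ``fake_group`` is a hypothetical ``SimpleNamespace`` that only provides ``world_size``, and the argument values are made up.

```python
from types import SimpleNamespace


def dispatch_one_to_all(worker_group, *args, **kwargs):
    args = tuple([arg] * worker_group.world_size for arg in args)
    kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
    return args, kwargs


# Hypothetical stand-in: only the world_size attribute is needed here.
fake_group = SimpleNamespace(world_size=3)

args, kwargs = dispatch_one_to_all(fake_group, "config.yaml", step=7)
# args   == (["config.yaml", "config.yaml", "config.yaml"],)
# kwargs == {"step": [7, 7, 7]}
```

Each positional and keyword argument becomes a list of length ``world_size``, so worker ``i`` later receives element ``i`` of every list. Note that ``[arg] * n`` repeats references to the same object rather than copying it.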
``dispatch_fn`` of ``DP_COMPUTE_PROTO`` is
`dispatch_dp_compute_data_proto <https://github.com/volcengine/verl/blob/c59ab2f4788f9a910836a9f2f53dcdb62dfa314e/verl/single_controller/base/decorator.py#L350>`__,
which uses ``DataProto.chunk`` to split a large ``DataProto`` into N
smaller ``DataProto`` objects, where N equals the ``world_size`` (the
number of workers) of the ``worker_group``:
.. code:: python

   def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
       from verl.single_controller.base.worker_group import WorkerGroup

       assert isinstance(worker_group, WorkerGroup)
       # Note: enable auto padding for dp compute DataProto
       splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding(
           worker_group.world_size,
           *args,
           **kwargs,
       )
       return splitted_args, splitted_kwargs
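The idea of splitting with automatic padding can be illustrated on a plain list. This is a toy analog, not the behavior of ``DataProto.chunk`` or ``_split_args_kwargs_data_proto_with_auto_padding``: the function name ``split_with_padding`` is hypothetical, and the padding strategy (cycling through items from the start of the batch) is an assumption made for the example.

```python
def split_with_padding(items, world_size):
    """Split items into world_size equal chunks, padding by cycling from the start."""
    if not items:
        raise ValueError("cannot split an empty batch")
    # Number of extra items needed to make the length divisible by world_size.
    pad_size = (-len(items)) % world_size
    padded = list(items) + [items[i % len(items)] for i in range(pad_size)]
    chunk_size = len(padded) // world_size
    chunks = [padded[i * chunk_size:(i + 1) * chunk_size] for i in range(world_size)]
    return chunks, pad_size


chunks, pad_size = split_with_padding([0, 1, 2, 3, 4, 5, 6], world_size=4)
# chunks   == [[0, 1], [2, 3], [4, 5], [6, 0]]
# pad_size == 1
```

Returning ``pad_size`` alongside the chunks lets the caller discard the padded results after the workers finish.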
The ``collect_fn`` follows the same pattern: it processes the batch
(list) of return values from all workers of a ``WorkerGroup`` and merges
them, either into a list, as ``collect_all_to_all`` does, or into one
large ``DataProto``, as ``collect_dp_compute_data_proto`` does.
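The two collection styles can be contrasted with plain lists. Both functions below are hypothetical analogs rather than the ``verl`` functions, and dropping padded items inside the collect step is an assumption of this sketch.

```python
def collect_as_list(outputs):
    """Analog of the list-style collect: hand back one entry per worker."""
    return list(outputs)


def collect_concat(outputs, pad_size=0):
    """Analog of the data-parallel collect: concatenate, then drop padded items."""
    merged = [item for chunk in outputs for item in chunk]
    return merged[:len(merged) - pad_size] if pad_size else merged


# Example per-worker results: 7 real items plus 1 padded item at the end.
per_worker = [[0, 1], [2, 3], [4, 5], [6, 0]]
as_list = collect_as_list(per_worker)
merged = collect_concat(per_worker, pad_size=1)
```

The list-style result keeps the per-worker boundaries visible, while the concatenated result looks like the output of a single-process call on the whole batch.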
Finally, a new method is dynamically generated using ``func_generator``
and added to the ``WorkerGroup`` instance:
.. code:: python

   # <<<continue 3>>>
   # bind a new method to the RayWorkerGroup
   func = func_generator(
       self,
       method_name,
       dispatch_fn=dispatch_fn,
       collect_fn=collect_fn,
       execute_fn=execute_fn,
       blocking=blocking,
   )

   try:
       setattr(self, method_name, func)
       method_names.append(method_name)
   except Exception as e:
       raise ValueError(f"Fail to set method_name {method_name}") from e
This makes the method invocable via the ``WorkerGroup`` interface.
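Putting the pieces together, the whole binding step can be sketched end to end with local objects. Everything here is a simplified assumption: ``ToyWorker`` and ``ToyWorkerGroup`` replace Ray actors and ``RayWorkerGroup``, ``dispatch_chunks`` and ``collect_concat`` are hypothetical helpers, and this ``func_generator`` only mirrors the argument list shown above, not the real body.

```python
class ToyWorker:
    """Hypothetical local worker; in verl this would be a remote Ray actor."""

    def __init__(self, rank):
        self.rank = rank

    def generate_sequences(self, prompts):
        return [f"rank{self.rank}:{p}" for p in prompts]


def dispatch_chunks(group, batch):
    size = len(batch) // group.world_size  # assumes an evenly divisible batch
    chunks = [batch[i * size:(i + 1) * size] for i in range(group.world_size)]
    return (chunks,), {}


def collect_concat(group, outputs):
    return [item for chunk in outputs for item in chunk]


def func_generator(group, method_name, dispatch_fn, collect_fn, execute_fn, blocking):
    """Build the group-level method: dispatch -> execute -> collect."""

    def func(*args, **kwargs):
        args, kwargs = dispatch_fn(group, *args, **kwargs)
        outputs = execute_fn(method_name, *args, **kwargs)
        # `blocking` would decide whether to wait on remote futures; the local
        # calls in this sketch have already finished, so it is unused.
        return collect_fn(group, outputs)

    return func


class ToyWorkerGroup:
    def __init__(self, world_size):
        self.workers = [ToyWorker(rank) for rank in range(world_size)]
        self.world_size = world_size
        bound = func_generator(self, "generate_sequences", dispatch_chunks,
                               collect_concat, self.execute_all, blocking=True)
        setattr(self, "generate_sequences", bound)

    def execute_all(self, method_name, *args, **kwargs):
        results = []
        for i, worker in enumerate(self.workers):
            worker_args = [arg[i] for arg in args]
            worker_kwargs = {k: v[i] for k, v in kwargs.items()}
            results.append(getattr(worker, method_name)(*worker_args, **worker_kwargs))
        return results


group = ToyWorkerGroup(world_size=2)
output = group.generate_sequences(["a", "b", "c", "d"])
# output == ["rank0:a", "rank0:b", "rank1:c", "rank1:d"]
```

The generated function is a closure over the group and the three chosen functions, so the caller sees an ordinary method with the same name as the worker method.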
--------------
Step 3: Call Chain
~~~~~~~~~~~~~~~~~~
All the machinery above ensures that distributed calls feel identical to
single-process ones. In the original single-process script, the code
looks like:
.. code:: python

   rollout = Rollout()
   rollout.generate_sequences(batch)
With ``verl``, the multiprocess program becomes:
.. code:: python

   rollout = RayWorkerGroup(
       resource_pool=RayResourcePool(process_on_nodes=[4]),
       ray_cls_with_init=RayClassWithInitArgs(cls=ray.remote(Rollout)),
   )
   rollout.generate_sequences(batch)
.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/call_generate_sequences.png?raw=true
   :alt: call_chain_of_generate_sequences

   call_chain_of_generate_sequences
Behind this simple call:

- ``dispatch_fn`` splits the input across workers
- ``execute_fn`` performs the actual remote invocation
- ``collect_fn`` gathers the results
All of this is abstracted away, enabling developers to write distributed
code with minimal changes to their existing logic.
--------------
Beyond RL Post-Training: Generalizing ``verl.single_controller``
----------------------------------------------------------------
The ``verl.single_controller`` module generalizes well beyond
reinforcement learning. It provides a clean abstraction to batch-process
remote method calls, with automatic input/output handling.
By minimizing the gap between single-process and multi-process scripts,
``verl.single_controller`` opens the door to distributed computing in
broader domains — not limited to RL post-training.
We hope this design inspires more examples and extensions from the
community.