| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import inspect |
| from functools import partial, wraps |
| from types import FunctionType |
|
|
| from verl.protocol import DataProtoFuture, _padding_size_key |
| from verl.utils.py_functional import DynamicEnum |
|
|
| |
# Attribute name under which `register` stores a decorated function's
# dispatch_mode / execute_mode / blocking settings on the returned wrapper.
MAGIC_ATTR = "attrs_3141562937"
|
|
|
|
class Dispatch(DynamicEnum):
    """Runtime-extensible collection of dispatch modes.

    A dispatch mode names a strategy for handing a call's arguments to the
    ranks of a worker group and for gathering the per-rank outputs again.
    Members are created through ``Dispatch.register(name)``; the built-in
    members are added by ``init_predefined_dispatch_mode``.
    """

    # Declared on the subclass so Dispatch keeps its own member table and
    # value counter (presumably the state DynamicEnum.register maintains --
    # confirm in verl.utils.py_functional).
    _next_value = 0
    _registry = {}
|
|
|
|
def init_predefined_dispatch_mode():
    """Register the built-in dispatch modes on ``Dispatch``.

    Names are registered in a fixed order, one ``Dispatch.register`` call per
    name.
    """
    builtin_names = (
        "RANK_ZERO",
        "ONE_TO_ALL",
        "ALL_TO_ALL",
        "DP_COMPUTE",
        "DP_COMPUTE_PROTO",
        "DP_COMPUTE_PROTO_WITH_FUNC",
        "DP_COMPUTE_METRIC",
        "DIRECT_ROLLOUT_METHOD",
    )
    for mode_name in builtin_names:
        Dispatch.register(mode_name)
|
|
|
|
class Execute(DynamicEnum):
    """Runtime-extensible collection of execution modes.

    An execution mode selects which worker-group execute method runs a
    registered function (see ``get_predefined_execute_fn``). Members are
    created through ``Execute.register(name)``; the built-in members are
    added by ``init_predefined_execute_mode``.
    """

    # Declared on the subclass so Execute keeps its own member table and
    # value counter (presumably the state DynamicEnum.register maintains --
    # confirm in verl.utils.py_functional).
    _next_value = 0
    _registry = {}
|
|
|
|
def init_predefined_execute_mode():
    """Register the built-in execution modes (``ALL``, ``RANK_ZERO``) on ``Execute``."""
    for mode_name in ("ALL", "RANK_ZERO"):
        Execute.register(mode_name)
|
|
|
|
| |
# Register the built-in dispatch and execute modes at import time; the
# DISPATCH_MODE_FN_REGISTRY literal and the default arguments of `register`
# further down reference these members.
init_predefined_dispatch_mode()
init_predefined_execute_mode()
|
|
|
|
def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
    """Split every positional and keyword argument into ``chunks`` pieces.

    Each value is wrapped in ``BatchData``, must report ``is_chunkable()``,
    and must yield exactly ``chunks`` pieces from ``chunk(chunks=chunks)``.

    Args:
        chunks: Number of pieces each argument is split into.
        *args: Positional values to split.
        **kwargs: Keyword values to split.

    Returns:
        ``(splitted_args, splitted_kwargs)``: a list with one chunk sequence
        per positional argument, and a dict mapping each keyword to its chunk
        sequence.

    Raises:
        AssertionError: If a value is not chunkable or the number of pieces
            differs from ``chunks``.
    """
    from verl.protocol import BatchData

    positional_pieces = []
    for arg in args:
        assert BatchData(arg).is_chunkable(), f"arg of type {type(arg)} is not chunkable"
        pieces = BatchData(arg).chunk(chunks=chunks)
        assert len(pieces) == chunks
        positional_pieces.append(pieces)

    keyword_pieces = {}
    for key, val in kwargs.items():
        assert BatchData(val).is_chunkable(), f"kwarg '{key}' of type {type(val)} is not chunkable"
        pieces = BatchData(val).chunk(chunks=chunks)
        assert len(pieces) == chunks
        keyword_pieces[key] = pieces

    return positional_pieces, keyword_pieces
|
|
|
|
def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
    """Chunk DataProto / DataProtoFuture arguments, padding first where enabled.

    A ``DataProto`` whose ``is_padding_enabled()`` is true is padded so that
    its length becomes a multiple of ``chunks`` before being chunked. The pad
    amount is computed from the first such object; every later one must have
    the same length. Other arguments are chunked unchanged.

    Args:
        chunks: Number of pieces each argument is split into.
        *args: Positional DataProto / DataProtoFuture values.
        **kwargs: Keyword DataProto / DataProtoFuture values.

    Returns:
        ``(splitted_args, splitted_kwargs)``. If any argument went through the
        padding path, ``splitted_kwargs`` also carries the pad amount under
        ``_padding_size_key``.

    Raises:
        AssertionError: If an argument has an unexpected type, or padded
            arguments differ in length.
    """
    from verl.protocol import DataProto, DataProtoFuture

    data_proto_len = None
    padding_size = None

    def _pad_then_chunk(obj):
        nonlocal data_proto_len, padding_size
        assert isinstance(obj, (DataProto, DataProtoFuture))

        wants_padding = isinstance(obj, DataProto) and obj.is_padding_enabled()
        if not wants_padding:
            return obj.chunk(chunks=chunks)

        if data_proto_len is None:
            # The first padded object fixes the reference length and pad amount.
            data_proto_len = len(obj)
            remainder = data_proto_len % chunks
            padding_size = (chunks - remainder) if remainder > 0 else 0
        else:
            assert data_proto_len == len(obj), (
                f"expecting all arg share same length of {data_proto_len}, but got {len(obj)}"
            )
        obj.padding(padding_size=padding_size)
        return obj.chunk(chunks=chunks)

    # Positional arguments are processed before keyword arguments.
    splitted_args = []
    for arg in args:
        splitted_args.append(_pad_then_chunk(arg))

    splitted_kwargs = {}
    for key, val in kwargs.items():
        splitted_kwargs[key] = _pad_then_chunk(val)

    if padding_size is not None:
        splitted_kwargs[_padding_size_key] = padding_size

    return splitted_args, splitted_kwargs
|
|
|
|
def dispatch_one_to_all(worker_group, *args, **kwargs):
    """Replicate every argument once per rank of ``worker_group``.

    Args:
        worker_group: Object exposing ``world_size``.
        *args: Positional values to replicate.
        **kwargs: Keyword values to replicate.

    Returns:
        ``(args, kwargs)`` where each value is replaced by a list holding
        ``world_size`` references to it; ``args`` is a tuple of such lists.
    """

    def _replicate(value):
        return [value] * worker_group.world_size

    replicated_args = tuple(map(_replicate, args))
    replicated_kwargs = {name: _replicate(value) for name, value in kwargs.items()}
    return replicated_args, replicated_kwargs
|
|
|
|
def dummy_direct_rollout_call(worker_group, *args, **kwargs):
    """Placeholder dispatch/collect function that always refuses to run.

    Raises:
        NotImplementedError: Always.
    """
    reason = "Direct rollout call is forbidden."
    raise NotImplementedError(reason)
|
|
|
|
def dispatch_all_to_all(worker_group, *args, **kwargs):
    """Pass the arguments through untouched.

    Returns:
        The ``(args, kwargs)`` pair exactly as received; ``worker_group`` is
        not consulted.
    """
    passthrough = (args, kwargs)
    return passthrough
|
|
|
|
def collect_all_to_all(worker_group, output):
    """Return the per-rank ``output`` unchanged; ``worker_group`` is not consulted."""
    collected = output
    return collected
|
|
|
|
def _concat_data_proto_or_future(output: list):
    """Concatenate a homogeneous list of outputs via ``BatchData(output).concat()``.

    Args:
        output: Per-rank results; every element must have exactly the same
            type as the first element.

    Returns:
        Whatever ``BatchData(output).concat()`` returns.

    Raises:
        AssertionError: If the element types are not all identical.
    """
    from verl.protocol import BatchData

    # Exact type identity is required, not just a shared base class.
    for element in output:
        assert type(element) is type(output[0])

    batch = BatchData(output)
    return batch.concat()
|
|
|
|
def dispatch_dp_compute(worker_group, *args, **kwargs):
    """Validate arguments that the caller has already split per rank.

    Every positional and keyword value must be a tuple or list with exactly
    ``worker_group.world_size`` entries.

    Returns:
        The ``(args, kwargs)`` pair unchanged.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup`` or a
            value has the wrong type or length.
    """
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)

    # Positional values are checked first, then keyword values.
    for per_rank in (*args, *kwargs.values()):
        assert isinstance(per_rank, (tuple, list)) and len(per_rank) == worker_group.world_size

    return args, kwargs
|
|
|
|
def collect_dp_compute(worker_group, output):
    """Check that there is one output per rank and return ``output`` as is.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup`` or
            ``len(output)`` differs from ``worker_group.world_size``.
    """
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)
    result_count = len(output)
    assert result_count == worker_group.world_size
    return output
|
|
|
|
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
    """Split every argument into ``world_size`` chunks with automatic padding.

    Delegates to ``_split_args_kwargs_data_proto_with_auto_padding`` using
    ``worker_group.world_size`` as the chunk count.

    Returns:
        The ``(splitted_args, splitted_kwargs)`` pair produced by that helper.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup``.
    """
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)

    num_chunks = worker_group.world_size
    return _split_args_kwargs_data_proto_with_auto_padding(num_chunks, *args, **kwargs)
|
|
|
|
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
    """Split data arguments per rank while replicating a leading function.

    The first positional argument must be a plain Python function; it is
    repeated ``world_size`` times. The remaining positional arguments and all
    keyword arguments are chunked by ``_split_args_kwargs_data_proto``.

    Returns:
        ``(splitted_args_with_func, splitted_kwargs)`` where the first entry
        of the positional list is the replicated function.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup`` or the
            first positional argument is not a ``FunctionType``.
    """
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)
    assert isinstance(args[0], FunctionType)

    func = args[0]
    data_args = args[1:]
    world_size = worker_group.world_size

    splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(world_size, *data_args, **kwargs)
    replicated_func = [func] * world_size
    return [replicated_func, *splitted_args], splitted_kwargs
|
|
|
|
def collect_dp_compute_data_proto(worker_group, output):
    """Validate per-rank outputs and concatenate them into one result.

    Steps, in order: check ``BatchData(output).is_concatable()``, run the
    ``collect_dp_compute`` checks, then concatenate.

    Raises:
        AssertionError: If the output is not concatable or fails the
            ``collect_dp_compute`` checks.
    """
    from verl.protocol import BatchData

    concatable = BatchData(output).is_concatable()
    assert concatable, (
        f"expecting concatable output, but got element type {type(output[0]) if output else 'empty'}"
    )

    per_rank = collect_dp_compute(worker_group, output)
    return _concat_data_proto_or_future(per_rank)
|
|
|
|
def dispatch_nd_compute(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):
    """Route per-DP-rank payloads to every global rank of ``worker_group``.

    Each positional and keyword argument is first passed through
    ``parallel_put`` (from ``verl.utils.ray_utils``; presumably uploads the
    items to the Ray object store -- confirm there). The result must be a
    tuple or list with ``dp_size`` entries. For each global rank ``i`` the
    entry at index ``dp_rank_mapping[i]`` is selected.

    Args:
        dp_rank_mapping: For each global rank, the index of the payload entry
            it should receive.
        dp_size: Expected number of entries in every payload.
        worker_group: The ``WorkerGroup`` whose ``world_size`` determines the
            number of output slots.
        *args: Positional payloads, one entry per DP rank.
        **kwargs: Keyword payloads, one entry per DP rank.

    Returns:
        ``(all_args, all_kwargs)``: a tuple of lists and a dict of lists, each
        list having ``world_size`` entries.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup`` or a
            payload is not a tuple/list of length ``dp_size``.
    """
    import os

    from verl.single_controller.base.worker_group import WorkerGroup
    from verl.utils.ray_utils import parallel_put

    assert isinstance(worker_group, WorkerGroup)

    # Size the thread pool from the first available payload. Looking only at
    # args[0] would raise IndexError for keyword-only calls, and
    # os.cpu_count() may return None when the count cannot be determined.
    payloads = [*args, *kwargs.values()]
    num_items = len(payloads[0]) if payloads else 1
    cpu_count = os.cpu_count() or 1
    max_workers = max(1, min(num_items, cpu_count))

    # All payloads are put before any shape validation, as before.
    args = [parallel_put(arg, max_workers=max_workers) for arg in args]
    kwargs = {k: parallel_put(v, max_workers=max_workers) for k, v in kwargs.items()}

    def _route_to_ranks(per_dp_values):
        # Expand a dp_size-long payload into one entry per global rank.
        assert isinstance(per_dp_values, tuple | list) and len(per_dp_values) == dp_size
        return [per_dp_values[dp_rank_mapping[rank]] for rank in range(worker_group.world_size)]

    all_args = tuple(_route_to_ranks(arg) for arg in args)
    all_kwargs = {k: _route_to_ranks(v) for k, v in kwargs.items()}
    return all_args, all_kwargs
|
|
|
|
def collect_nd_compute(collect_mask: list[bool], worker_group, output):
    """Keep only the outputs of the ranks selected by ``collect_mask``.

    Args:
        collect_mask: Indexed by global rank; a truthy entry means that
            rank's output is kept.
        worker_group: The ``WorkerGroup`` whose ``world_size`` bounds the scan.
        output: Per-rank results, one per global rank.

    Returns:
        The selected outputs in ascending global-rank order.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup`` or
            ``len(output)`` differs from ``worker_group.world_size``.
    """
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)
    assert len(output) == worker_group.world_size

    return [output[rank] for rank in range(worker_group.world_size) if collect_mask[rank]]
|
|
|
|
def dispatch_nd_compute_dataproto(dp_rank_mapping: list[int], dp_size, worker_group, *args, **kwargs):
    """Chunk every argument into ``dp_size`` pieces, then route them per rank.

    Combines ``_split_args_kwargs_data_proto`` (chunking) with
    ``dispatch_nd_compute`` (mapping chunks to global ranks).

    Returns:
        The ``(all_args, all_kwargs)`` pair from ``dispatch_nd_compute``.
    """
    chunked_args, chunked_kwargs = _split_args_kwargs_data_proto(dp_size, *args, **kwargs)
    routed = dispatch_nd_compute(dp_rank_mapping, dp_size, worker_group, *chunked_args, **chunked_kwargs)
    return routed
|
|
|
|
def collect_nd_compute_dataproto(collect_mask: list[bool], worker_group, output):
    """Select the masked ranks' outputs and concatenate them.

    The outputs are first filtered with ``collect_nd_compute``; the filtered
    list must be concatable according to ``BatchData`` before it is joined.

    Raises:
        AssertionError: If the filtered output is not concatable, or any
            check in ``collect_nd_compute`` fails.
    """
    output = collect_nd_compute(collect_mask, worker_group, output)

    from verl.protocol import BatchData

    concatable = BatchData(output).is_concatable()
    assert concatable, (
        f"expecting concatable output, but got element type {type(output[0]) if output else 'empty'}"
    )
    return _concat_data_proto_or_future(output)
|
|
|
|
def dispatch_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
    """Dispatch using a rank mapping that is looked up once per mesh and cached.

    On the first call for ``mesh_name`` the mapping is fetched with
    ``worker_group._query_dispatch_info(mesh_name)`` and stored in
    ``worker_group._dispatch_info``; later calls reuse the stored value.

    Returns:
        The result of ``dispatch_nd_compute_dataproto`` for that mapping.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup`` or a
            freshly queried mapping does not have ``world_size`` entries.
    """
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)

    if mesh_name not in worker_group._dispatch_info:
        queried_mapping = worker_group._query_dispatch_info(mesh_name)
        worker_group._dispatch_info[mesh_name] = queried_mapping
        assert len(queried_mapping) == worker_group.world_size

    dp_rank_mapping = worker_group._dispatch_info[mesh_name]
    # DP size is derived from the largest rank index in the mapping.
    highest_dp_rank = max(dp_rank_mapping)
    dp_size = highest_dp_rank + 1
    return dispatch_nd_compute_dataproto(dp_rank_mapping, dp_size, worker_group, *args, **kwargs)
|
|
|
|
def collect_lazy_compute_data_proto(mesh_name, worker_group, *args, **kwargs):
    """Collect using a mask that is looked up once per mesh and cached.

    Requires that dispatch info for ``mesh_name`` is already cached on the
    worker group. On the first call the mask is fetched with
    ``worker_group._query_collect_info(mesh_name)`` and stored in
    ``worker_group._collect_info``; later calls reuse the stored value.

    Returns:
        The result of ``collect_nd_compute_dataproto`` for that mask.

    Raises:
        AssertionError: If ``worker_group`` is not a ``WorkerGroup``, no
            dispatch info is cached for ``mesh_name``, or a freshly queried
            mask does not have ``world_size`` entries.
    """
    from verl.single_controller.base.worker_group import WorkerGroup

    assert isinstance(worker_group, WorkerGroup)

    # The dispatch side must already have cached info for this mesh.
    assert mesh_name in worker_group._dispatch_info

    if mesh_name not in worker_group._collect_info:
        queried_mask = worker_group._query_collect_info(mesh_name)
        worker_group._collect_info[mesh_name] = queried_mask
        assert len(queried_mask) == worker_group.world_size

    collect_mask = worker_group._collect_info[mesh_name]
    return collect_nd_compute_dataproto(collect_mask, worker_group, *args, **kwargs)
|
|
|
|
def make_nd_compute_dataproto_dispatch_fn(mesh_name):
    """Build a dispatch-mode dict bound to ``mesh_name``.

    Returns:
        A dict with ``"dispatch_fn"`` and ``"collect_fn"`` entries: the lazy
        dispatch and collect functions with ``mesh_name`` pre-bound as their
        first argument.
    """
    bound_dispatch = partial(dispatch_lazy_compute_data_proto, mesh_name)
    bound_collect = partial(collect_lazy_compute_data_proto, mesh_name)
    return {"dispatch_fn": bound_dispatch, "collect_fn": bound_collect}
|
|
|
|
| |
# Maps each supported dispatch mode to its pair of functions: "dispatch_fn"
# prepares the per-rank arguments and "collect_fn" post-processes the per-rank
# outputs. Extended by `register_dispatch_mode`, replaced per key by
# `update_dispatch_mode`.
DISPATCH_MODE_FN_REGISTRY = {
    Dispatch.ONE_TO_ALL: dict(
        dispatch_fn=dispatch_one_to_all,
        collect_fn=collect_all_to_all,
    ),
    Dispatch.ALL_TO_ALL: dict(
        dispatch_fn=dispatch_all_to_all,
        collect_fn=collect_all_to_all,
    ),
    Dispatch.DP_COMPUTE: dict(
        dispatch_fn=dispatch_dp_compute,
        collect_fn=collect_dp_compute,
    ),
    Dispatch.DP_COMPUTE_PROTO: dict(
        dispatch_fn=dispatch_dp_compute_data_proto,
        collect_fn=collect_dp_compute_data_proto,
    ),
    Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: dict(
        dispatch_fn=dispatch_dp_compute_data_proto_with_func,
        collect_fn=collect_dp_compute_data_proto,
    ),
    # Same splitting as DP_COMPUTE_PROTO, but outputs are returned per rank
    # instead of being concatenated.
    Dispatch.DP_COMPUTE_METRIC: dict(
        dispatch_fn=dispatch_dp_compute_data_proto,
        collect_fn=collect_dp_compute,
    ),
    # Both entries raise NotImplementedError when called.
    Dispatch.DIRECT_ROLLOUT_METHOD: dict(
        dispatch_fn=dummy_direct_rollout_call,
        collect_fn=dummy_direct_rollout_call,
    ),
}
|
|
|
|
def get_predefined_dispatch_fn(dispatch_mode):
    """Look up the function pair registered for ``dispatch_mode``.

    Returns:
        The registry entry, a dict with ``"dispatch_fn"`` and ``"collect_fn"``.

    Raises:
        KeyError: If ``dispatch_mode`` has no entry in
            ``DISPATCH_MODE_FN_REGISTRY``.
    """
    fn_pair = DISPATCH_MODE_FN_REGISTRY[dispatch_mode]
    return fn_pair
|
|
|
|
def register_dispatch_mode(dispatch_mode_name, dispatch_fn, collect_fn):
    """Create a new ``Dispatch`` member and register its function pair.

    Args:
        dispatch_mode_name: Name passed to ``Dispatch.register``.
        dispatch_fn: Function stored under ``"dispatch_fn"``.
        collect_fn: Function stored under ``"collect_fn"``.

    Raises:
        AssertionError: If the new mode fails ``_check_dispatch_mode`` or
            already has an entry in ``DISPATCH_MODE_FN_REGISTRY``.
    """
    new_mode = Dispatch.register(dispatch_mode_name)
    _check_dispatch_mode(new_mode)

    already_registered = new_mode in DISPATCH_MODE_FN_REGISTRY
    assert not already_registered, f"dispatch_mode_name {dispatch_mode_name} already exists"

    fn_pair = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn}
    DISPATCH_MODE_FN_REGISTRY[new_mode] = fn_pair
|
|
|
|
def update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn):
    """Replace the function pair of an already registered dispatch mode.

    Args:
        dispatch_mode: Existing key of ``DISPATCH_MODE_FN_REGISTRY``.
        dispatch_fn: New function stored under ``"dispatch_fn"``.
        collect_fn: New function stored under ``"collect_fn"``.

    Raises:
        AssertionError: If ``dispatch_mode`` fails ``_check_dispatch_mode``
            or has no entry in the registry.
    """
    _check_dispatch_mode(dispatch_mode)

    is_known = dispatch_mode in DISPATCH_MODE_FN_REGISTRY
    assert is_known, f"dispatch_mode {dispatch_mode} not found"

    fn_pair = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn}
    DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = fn_pair
|
|
|
|
def get_predefined_execute_fn(execute_mode):
    """Return the execute-method name associated with ``execute_mode``.

    Only ``execute_all`` and ``execute_rank_zero`` are required to exist; how
    those two methods treat the ``blocking`` argument is left to their
    implementers.

    Returns:
        A new dict ``{"execute_fn_name": <method name>}``.

    Raises:
        KeyError: If ``execute_mode`` is neither ``Execute.ALL`` nor
            ``Execute.RANK_ZERO``.
    """
    method_name_by_mode = {
        Execute.ALL: "execute_all",
        Execute.RANK_ZERO: "execute_rank_zero",
    }
    return {"execute_fn_name": method_name_by_mode[execute_mode]}
|
|
|
|
def _check_dispatch_mode(dispatch_mode):
    """Assert that ``dispatch_mode`` is a ``Dispatch`` member or a valid dict.

    A dict must contain both the ``"dispatch_fn"`` and ``"collect_fn"`` keys.

    Raises:
        AssertionError: If the type is wrong or a required key is missing.
    """
    assert isinstance(dispatch_mode, (Dispatch, dict)), (
        f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
    )
    if not isinstance(dispatch_mode, dict):
        return

    for key in ("dispatch_fn", "collect_fn"):
        assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"
|
|
|
|
def _check_execute_mode(execute_mode):
    """Assert that ``execute_mode`` is an ``Execute`` member.

    Raises:
        AssertionError: If it is not.
    """
    is_execute_member = isinstance(execute_mode, Execute)
    assert is_execute_member, f"execute_mode must be a Execute. Got {execute_mode}"
|
|
|
|
| def _materialize_futures(*args, **kwargs): |
| new_args = [] |
| for arg in args: |
| if isinstance(arg, DataProtoFuture): |
| arg = arg.get() |
| |
| new_args.append(arg) |
| for k, v in kwargs.items(): |
| if isinstance(v, DataProtoFuture): |
| kwargs[k] = v.get() |
|
|
| new_args = tuple(new_args) |
| return new_args, kwargs |
|
|
|
|
def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True):
    """Decorator factory attaching distributed-execution settings to a function.

    The decorated function is wrapped so that, when ``materialize_futures`` is
    true, any ``DataProtoFuture`` arguments are resolved before the call.
    Coroutine functions get an ``async`` wrapper, all others a plain one. The
    settings are stored on the wrapper under the ``MAGIC_ATTR`` attribute.

    Args:
        dispatch_mode:
            Dispatch mode for computation distribution; a ``Dispatch`` member
            or a dict with ``"dispatch_fn"`` and ``"collect_fn"``.
            Default: Dispatch.ALL_TO_ALL.
        execute_mode:
            Execute mode for computation distribution. Default: Execute.ALL.
        blocking:
            Whether the execution should be blocking. Defaults to True.
        materialize_futures:
            Whether to materialize the data before calling the function.
            Defaults to True.

    Returns:
        A decorator that wraps the original function and tags the wrapper
        with ``{"dispatch_mode", "execute_mode", "blocking"}``.

    Raises:
        AssertionError: If ``dispatch_mode`` or ``execute_mode`` is invalid.
    """
    _check_dispatch_mode(dispatch_mode=dispatch_mode)
    _check_execute_mode(execute_mode=execute_mode)

    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def wrapper(*args, **kwargs):
                if materialize_futures:
                    args, kwargs = _materialize_futures(*args, **kwargs)
                return await func(*args, **kwargs)

        else:

            @wraps(func)
            def wrapper(*args, **kwargs):
                if materialize_futures:
                    args, kwargs = _materialize_futures(*args, **kwargs)
                return func(*args, **kwargs)

        # Attach the settings to the wrapper under the MAGIC_ATTR name.
        settings = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking}
        setattr(wrapper, MAGIC_ATTR, settings)
        return wrapper

    return decorator
|
|