# Source: arithmetic-grpo/tests/single_controller/test_device_mesh_register.py
# Uploaded by LeTue09 in "initial clean commit" (revision 1faccd4).
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import ray
import torch
from tensordict import TensorDict
import verl.utils.tensordict_utils as tu
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import make_nd_compute_dataproto_dispatch_fn, register
from verl.utils.device import get_device_name, get_nccl_backend
@ray.remote
class TestActor(Worker):
    """Ray worker that registers two dispatch/collect layouts and adds rank-dependent offsets to its input.

    Two device meshes are built over the same 8 ranks:

    * ``"infer"``: mesh ``[dp=2, tp=4]``; outputs are collected from ``tp`` rank 0.
    * ``"train"``: mesh ``[pp=2, dp=2, tp=2]``; outputs are collected from ``tp`` rank 0 on ``pp`` rank 1.

    Each registered method adds an offset computed from the calling rank's mesh coordinates. The driver can
    therefore tell from the collected result which ranks produced it.
    """

    def __init__(self):
        super().__init__()
        import torch.distributed

        torch.distributed.init_process_group(backend=get_nccl_backend())
        # Two independent logical layouts over the same 8 ranks (2 * 4 == 2 * 2 * 2).
        self.infer_device_mesh = torch.distributed.device_mesh.init_device_mesh(
            device_type=get_device_name(), mesh_shape=[2, 4], mesh_dim_names=["dp", "tp"]
        )
        self.train_device_mesh = torch.distributed.device_mesh.init_device_mesh(
            device_type=get_device_name(), mesh_shape=[2, 2, 2], mesh_dim_names=["pp", "dp", "tp"]
        )
        # "infer": one output per dp group, taken from tp rank 0.
        self._register_dispatch_collect_info(
            "infer",
            dp_rank=self.infer_device_mesh["dp"].get_local_rank(),
            is_collect=self.infer_device_mesh["tp"].get_local_rank() == 0,
        )
        # "train": one output per dp group, taken from tp rank 0 on pp rank 1
        # (with pp=2, rank 1 is the last pipeline stage).
        self._register_dispatch_collect_info(
            "train",
            dp_rank=self.train_device_mesh["dp"].get_local_rank(),
            is_collect=self.train_device_mesh["tp"].get_local_rank() == 0
            and self.train_device_mesh["pp"].get_local_rank() == 1,
        )

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_data_proto(self, data: DataProto):
        """Add ``(tp_rank + 1) * dp_rank`` (infer mesh) to ``data.batch["a"]`` in place and return ``data``."""
        tp_rank = self.infer_device_mesh["tp"].get_local_rank()
        dp_rank = self.infer_device_mesh["dp"].get_local_rank()
        data.batch["a"] += (tp_rank + 1) * dp_rank
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_tensordict(self, data: TensorDict):
        """TensorDict counterpart of ``generate_data_proto``: add ``(tp_rank + 1) * dp_rank`` to ``data["a"]``."""
        tp_rank = self.infer_device_mesh["tp"].get_local_rank()
        dp_rank = self.infer_device_mesh["dp"].get_local_rank()
        data["a"] += (tp_rank + 1) * dp_rank
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_data_proto(self, data: DataProto):
        """Add ``(tp_rank + 1) * (dp_rank + 2) * (pp_rank + 3)`` (train mesh) to ``data.batch["a"]``."""
        tp_rank = self.train_device_mesh["tp"].get_local_rank()
        dp_rank = self.train_device_mesh["dp"].get_local_rank()
        pp_rank = self.train_device_mesh["pp"].get_local_rank()
        data.batch["a"] += (tp_rank + 1) * (dp_rank + 2) * (pp_rank + 3)
        # Collected ranks have tp=0 and pp=1, so the offset is 1 * (dp_rank + 2) * 4:
        #   dp rank 0: offset 8  -> with the test input 3, output 3 + 8 = 11
        #   dp rank 1: offset 12 -> with the test input 4, output 4 + 12 = 16
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_tensordict(self, data: TensorDict):
        """TensorDict counterpart of ``train_data_proto``: add the same train-mesh offset to ``data["a"]``."""
        tp_rank = self.train_device_mesh["tp"].get_local_rank()
        dp_rank = self.train_device_mesh["dp"].get_local_rank()
        pp_rank = self.train_device_mesh["pp"].get_local_rank()
        data["a"] += (tp_rank + 1) * (dp_rank + 2) * (pp_rank + 3)
        # Same expected outputs as train_data_proto: 3 + 8 = 11 (dp 0) and 4 + 12 = 16 (dp 1).
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_nested_tensor(self, data: TensorDict):
        """Add ``tp_rank + dp_rank`` (infer mesh) to ``data["input_ids"]`` and return ``data``.

        The driver sends 16 nested (jagged) samples over dp=2, so each rank must receive exactly 8.
        """
        tp_rank = self.infer_device_mesh["tp"].get_local_rank()
        dp_rank = self.infer_device_mesh["dp"].get_local_rank()
        assert data.shape[0] == 8
        data["input_ids"] += tp_rank + dp_rank
        return data
def test_dist_global_info_wg():
    """End-to-end check of multi-mesh dispatch/collect on an 8-rank ``RayWorkerGroup``.

    Creates a worker group of 8 ``TestActor`` processes on one node and registers two layouts:
    an infer layout with tp=4, dp=2 and a train layout with tp=2, dp=2, pp=2. It then checks:

    * the per-rank dp assignment recorded for each layout;
    * the collected outputs of DataProto, TensorDict and jagged nested-tensor inputs.

    Ray is always shut down, even when an assertion fails.
    """
    from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

    ray.init()
    try:
        ray_cls = RayClassWithInitArgs(TestActor)
        resource_pool = RayResourcePool(process_on_nodes=[8])
        wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls, device_name=get_device_name())

        # Infer mesh [dp=2, tp=4]: ranks 0-3 are dp group 0, ranks 4-7 are dp group 1.
        infer_input_data_proto = DataProto.from_single_dict(data={"a": torch.tensor([1, 2])})
        infer_output_data_proto = wg.generate_data_proto(infer_input_data_proto)
        assert wg._dispatch_info["infer"] == [0, 0, 0, 0, 1, 1, 1, 1]
        # Collected from tp rank 0, so the offset is dp_rank: [1 + 0, 2 + 1].
        assert torch.equal(infer_output_data_proto.batch["a"], torch.tensor([1, 3]))

        infer_input_tensordict = infer_input_data_proto.to_tensordict()
        infer_output_tensordict = wg.generate_tensordict(infer_input_tensordict)
        assert torch.equal(infer_output_tensordict["a"], torch.tensor([1, 3]))

        # Train mesh [pp=2, dp=2, tp=2]: dp = (rank // 2) % 2.
        train_input_data_proto = DataProto.from_single_dict(data={"a": torch.tensor([3, 4])})
        train_output_data_proto = wg.train_data_proto(train_input_data_proto)
        assert wg._dispatch_info["train"] == [0, 0, 1, 1, 0, 0, 1, 1]
        # Collected from tp=0, pp=1: offsets 1*2*4 = 8 and 1*3*4 = 12.
        assert torch.equal(train_output_data_proto.batch["a"], torch.tensor([11, 16]))

        train_input_tensordict = train_input_data_proto.to_tensordict()
        train_output_tensordict = wg.train_tensordict(train_input_tensordict)
        assert torch.equal(train_output_tensordict["a"], torch.tensor([11, 16]))

        # Jagged batch of 16 variable-length sequences. The seeded generators make failures reproducible.
        rng = np.random.default_rng(seed=0)
        generator = torch.Generator().manual_seed(0)
        input_ids = [
            torch.randint(low=0, high=128, size=(int(rng.integers(low=1, high=10)),), generator=generator)
            for _ in range(16)
        ]
        input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
        data = tu.get_tensordict(tensor_dict={"input_ids": input_ids})
        output = wg.generate_nested_tensor(data)

        # Each dp group receives 8 samples. The collected rank (tp=0) adds tp_rank + dp_rank == dp_rank.
        # Out-of-place adds avoid mutating input_ids through the views returned by chunk().
        expected = tu.concat_nested_tensors(
            [chunk + dp_rank for dp_rank, chunk in enumerate(input_ids.chunk(2))]
        )
        assert torch.equal(output["input_ids"].values(), expected.values())
    finally:
        ray.shutdown()
if __name__ == "__main__":
    # Allow running the distributed test directly (outside pytest) on an 8-device node.
    test_dist_global_info_wg()