| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| In this test, we instantiate a data parallel worker with 8 GPUs |
| """ |
|
|
| from verl.single_controller.base import Worker |
| from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool |
|
|
| from verl.single_controller.base.decorator import Dispatch, register |
|
|
| import ray |
| import torch |
|
|
| from torch import distributed as dist |
|
|
| from verl import DataProto |
| from verl.utils.ray_utils import parallel_put |
|
|
| from codetiming import Timer |
|
|
| import tensordict |
|
|
|
|
| @ray.remote |
| class DummyWorker(Worker): |
|
|
| def __init__(self): |
| super().__init__() |
| dist.init_process_group() |
|
|
| @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) |
| def do_nothing(self, data): |
| for key in data.batch.keys(): |
| data.batch[key] += 1 |
| if tensordict.__version__ >= '0.5.0': |
| data.batch = data.batch.consolidate() |
| return data |
|
|
|
|
| def test_data_transfer(): |
| ray.init() |
| |
| resource_pool = RayResourcePool([8]) |
| cls_with_init = RayClassWithInitArgs(cls=DummyWorker) |
| |
| wg = RayWorkerGroup(resource_pool, cls_with_init) |
|
|
| |
| batch_size = 4096 |
| seqlen = 32768 |
|
|
| data_dict = {} |
|
|
| for i in range(2): |
| data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen)) |
|
|
| data = DataProto.from_dict(tensors=data_dict) |
|
|
| print(data) |
|
|
| |
| data_list = data.chunk(wg.world_size) |
|
|
| for i in range(wg.world_size): |
| |
| if tensordict.__version__ >= '0.5.0': |
| data_list[i].batch = data_list[i].batch.consolidate() |
|
|
| with Timer(name='ray.pickle', initial_text=True): |
| for i in range(wg.world_size): |
| ray.cloudpickle.pickle.dumps(data_list[i]) |
|
|
| with Timer(name='raw.pickle', initial_text=True): |
| import pickle |
| for i in range(wg.world_size): |
| pickle.dumps(data_list[i]) |
|
|
| |
| with Timer(name='put', initial_text=True): |
| |
| data_list_ref = parallel_put(data_list) |
| |
| |
|
|
| with Timer(name='launch', initial_text=True): |
| output_ref = wg.do_nothing(data_list_ref) |
|
|
| with Timer(name='get', initial_text=True): |
| |
| output_lst = ray.get(output_ref) |
|
|
| for input_data, output_data in zip(data_list, output_lst): |
| for key in input_data.batch.keys(): |
| assert torch.all(torch.eq(input_data.batch[key] + 1, |
| output_data.batch[key])), (input_data.batch[key], output_data.batch[key], key) |
|
|
| ray.shutdown() |
|
|