| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import asyncio |
| import gc |
| import logging |
| import os |
| import time |
| from typing import Any, AsyncGenerator, Generator |
|
|
| import ray |
| import torch |
| from mooncake.engine import TransferEngine |
| from vllm.distributed.utils import StatelessProcessGroup |
|
|
| from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta |
| from verl.utils.device import get_torch_device |
| from verl.utils.net_utils import get_free_port |
|
|
logger = logging.getLogger(__name__)
# Verbosity is controlled by the VERL_LOGGING_LEVEL environment variable; INFO when unset.
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "INFO"))
|
|
|
|
@CheckpointEngineRegistry.register("mooncake")
class MooncakeCheckpointEngine(CheckpointEngine):
    """Mooncake checkpoint engine with p2p communication using TransferEngine.

    Weights flow along a chain: rank 0 (sender) -> rank 1 -> ... -> rank N (receivers).
    Every rank owns one registered buffer of ``2 * bucket_size`` bytes, split into two
    halves that are used alternately. A rank pulls a bucket from its upstream neighbour
    with ``transfer_sync_read``. After consuming it, the rank writes a 4-byte completion
    marker into the first bytes of the upstream half, signalling that the half may be reused.

    Args:
        bucket_size (int): Bucket size in bytes to transfer multiple weights at one time.
        device (str): Device holding the transfer buffers, e.g. "cpu", "cuda" or "npu".
            "npu" selects the "ascend_direct" transport; any other value selects "rdma".
        rollout_dtype (torch.dtype): Dtype the weights are cast to before being sent to
            rollout workers.
        device_name (str): Mooncake device name filter.
        is_master (bool): Stored on the instance; not read by this class.
        rebuild_group (bool): Stored on the instance; not read by this class.
    """

    # Marker a receiver writes into its upstream half once it has consumed a bucket.
    _MAGIC = (0xAB, 0xDC, 0xEF, 0x88)
    # Size in bytes of the registered scratch buffer that holds the marker to be written upstream.
    _MAGIC_BUF_SIZE = 4 * 1024

    def __init__(
        self,
        bucket_size: int,
        device: str = "cuda",
        rollout_dtype: torch.dtype = torch.bfloat16,
        device_name: str = "",
        is_master: bool = False,
        rebuild_group: bool = False,
    ):
        self.bucket_size = bucket_size
        self.device = device
        self.rollout_dtype = rollout_dtype
        self.is_master = is_master
        self.rebuild_group = rebuild_group

        # Bind this process to a local device: global RANK modulo the visible device count.
        rank = int(os.environ["RANK"])
        local_rank = rank % get_torch_device().device_count()
        get_torch_device().set_device(local_rank)

        self.engine = TransferEngine()
        # Strip brackets, presumably so a bracketed IPv6 address can be used as a hostname.
        hostname = ray.util.get_node_ip_address().strip("[]")
        protocol = "ascend_direct" if self.device == "npu" else "rdma"
        ret = self.engine.initialize(hostname, "P2PHANDSHAKE", protocol, device_name)
        if ret != 0:
            raise RuntimeError(f"TransferEngine initialize failed ret={ret}")

        self.hostname = hostname
        self.session_id = f"{hostname}:{self.engine.get_rpc_port()}"

        # Two bucket-sized halves for double buffering, plus a small buffer holding the
        # completion marker. Both are registered so the engine can read from / write to them.
        self.buf = torch.empty(2 * self.bucket_size, dtype=torch.uint8, device=self.device)
        self.magic_buf = torch.empty(self._MAGIC_BUF_SIZE, dtype=torch.uint8, device=self.device)
        ret = self.engine.batch_register_memory(
            [self.buf.data_ptr(), self.magic_buf.data_ptr()],
            [2 * self.bucket_size, self._MAGIC_BUF_SIZE],
        )
        if ret != 0:
            raise RuntimeError(f"batch_register_memory failed ret={ret}")
        logger.info(f"__init__ session_id={self.session_id}")

    def prepare(self) -> dict[str, Any]:
        """Pick a free port on this node for the rendezvous store.

        Returns:
            ``{"addr": <node IP>, "port": <free port>}``. ``build_topology`` hands
            ``metadatas[0]`` to every participant.
        """
        logger.info(
            f"prepare ptr={self.buf.data_ptr():#x} len={2 * self.bucket_size} "
            f"magic_buf_ptr={self.magic_buf.data_ptr():#x}"
        )
        port, _ = get_free_port(self.hostname)
        return {"addr": self.hostname, "port": port}

    @classmethod
    def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadatas: list[dict]):
        """Assign chain ranks to trainer and rollout workers.

        - Trainer rank 0 becomes chain rank 0 (the sender).
        - Every other trainer rank gets -1 and does not take part.
        - Rollout workers become ranks ``1..rollout_world_size``.
        - All participants share ``world_size = rollout_world_size + 1`` and use
          ``metadatas[0]`` as the rendezvous address.

        Returns:
            ``(trainer_kwargs, rollout_kwargs)``. Each is a dict mapping the
            ``init_process_group`` argument names to per-worker value lists.
        """
        trainer_kwargs = {
            "rank": [0] + [-1] * (trainer_world_size - 1),
            "world_size": [rollout_world_size + 1] * trainer_world_size,
            "metadata": [metadatas[0]] * trainer_world_size,
        }
        rollout_kwargs = {
            "rank": list(range(1, rollout_world_size + 1)),
            "world_size": [rollout_world_size + 1] * rollout_world_size,
            "metadata": [metadatas[0]] * rollout_world_size,
        }
        return trainer_kwargs, rollout_kwargs

    def init_process_group(self, rank: int, world_size: int, metadata: dict[str, Any]):
        """Join the rendezvous group and record the upstream neighbour's buffer info.

        Args:
            rank: 0 for the sender, 1..N for receivers, negative for non-participants.
            world_size: Number of participating ranks.
            metadata: ``{"addr", "port"}`` rendezvous address as returned by ``prepare``.
        """
        self.rank = rank
        self.world_size = world_size
        if rank < 0:
            logger.info(f"init_process_group rank={rank}")
            return

        self.store = StatelessProcessGroup.create(
            host=metadata["addr"],
            port=metadata["port"],
            rank=rank,
            world_size=world_size,
        )

        info = {"session_id": self.session_id, "ptr": self.buf.data_ptr()}
        info_list = self.store.all_gather_obj(info)
        # Chain topology: rank r pulls from rank r - 1, so the sender has no upstream.
        self.buffer_info = None if rank == 0 else info_list[rank - 1]

        logger.info(f"init_process_group rank={rank} world_size={world_size} buffer_info={self.buffer_info}")

    def finalize(self):
        """Drop the rendezvous store and release cached device memory.

        The buffers registered with the TransferEngine in ``__init__`` are NOT
        deregistered here; they remain registered for later rounds.
        """
        self.store = None
        get_torch_device().empty_cache()
        gc.collect()
        logger.info(f"finalize rank={self.rank}")

    async def wait_for_complete(self, buf: torch.Tensor):
        """Poll until the first bytes of ``buf`` equal the completion marker.

        The downstream neighbour writes the marker after consuming the bucket, so a
        return means ``buf`` may be overwritten. The loop yields to the event loop
        between polls.
        """
        magic = torch.tensor(self._MAGIC, dtype=torch.uint8, device=self.device)
        while not torch.equal(buf[: len(self._MAGIC)], magic):
            await asyncio.sleep(0)

    def _send_bucket(self, bucket_meta: dict, buf: torch.Tensor, length: int, is_last: bool):
        """Wait for pending copies into ``buf``, then announce the bucket to rank 1."""
        # The copies into the bucket were non_blocking; the bytes must be in place before rank 1 reads them.
        get_torch_device().synchronize()
        info = {
            "bucket_meta": bucket_meta,
            "ptr": buf.data_ptr(),
            "len": length,
            "is_last": is_last,
        }
        self.store.send_obj(info, 1)

    @torch.no_grad()
    async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
        """Send weights to rank 1 using Mooncake TransferEngine.

        Each weight is cast to ``rollout_dtype`` and packed into the current half of
        ``self.buf``. When a bucket is full, its metadata (per-tensor name/shape/dtype/offset,
        buffer pointer, length, ``is_last``) is sent to rank 1, which pulls the bytes itself.
        A half is refilled only after rank 1 has written the completion marker into it.
        Ranks with ``rank < 0`` only drain ``weights``.

        Raises:
            ValueError: if a single weight is larger than ``bucket_size``.
        """
        if self.rank < 0:
            # Still exhaust the generator, presumably so any collective work it performs
            # runs on every trainer rank.
            for _ in weights:
                pass
            logger.info(f"send_weights rank={self.rank}")
            return

        total_bytes = 0
        start_time = time.perf_counter()
        bufs = [self.buf[: self.bucket_size], self.buf[self.bucket_size :]]
        idx = 0
        current = bufs[idx]
        bucket_meta: dict[str, TensorMeta] = {}
        offset = 0
        # The half switched to on the first flush has never been handed out. Every later
        # switch targets the half rank 1 received one bucket earlier, which must be released first.
        should_wait = False

        for name, weight in weights:
            # torch.no_grad() as a decorator on an async def only covers creating the
            # coroutine, not running it, so detach explicitly.
            weight = weight.detach().to(self.rollout_dtype)
            nbytes = weight.nbytes
            # Check before flushing so an oversized weight does not first ship an empty bucket.
            if nbytes > self.bucket_size:
                raise ValueError(f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket.")

            if offset + nbytes > self.bucket_size:
                total_bytes += offset
                self._send_bucket(bucket_meta, current, offset, is_last=False)

                idx ^= 1
                current = bufs[idx]
                bucket_meta = {}
                offset = 0

                if should_wait:
                    await self.wait_for_complete(current)
                should_wait = True

            bucket_meta[name] = {
                "name": name,
                "shape": weight.shape,
                "dtype": weight.dtype,
                "offset": offset,
            }
            current[offset : offset + nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
            offset += nbytes

        # The final bucket's bytes count toward the total too.
        total_bytes += offset
        self._send_bucket(bucket_meta, current, offset, is_last=True)
        await self.wait_for_complete(current)

        time_cost = time.perf_counter() - start_time
        bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) if time_cost > 0 else 0.0
        logger.info(
            f"Rank {self.rank} send weights done, "
            f"total bytes: {total_bytes} time cost: {time_cost:.2f}s bandwidth: {bandwidth:.2f} GB/s"
        )

    @torch.no_grad()
    async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
        """Receive weights from rank ``self.rank - 1`` and relay them to ``self.rank + 1``.

        For each bucket, this method:
        1. receives the metadata from upstream;
        2. pulls the bytes into the current half of ``self.buf`` with ``transfer_sync_read``;
        3. forwards the metadata downstream, pointing at this rank's copy;
        4. yields ``(name, tensor)`` pairs;
        5. writes the completion marker into the upstream half.

        Yields:
            ``(name, tensor)``. ``tensor`` is a view into ``self.buf`` that is overwritten
            when that half is reused for a later bucket, so consumers must copy it.

        Raises:
            RuntimeError: if a Mooncake transfer returns a non-zero code.
        """
        start_time = time.perf_counter()
        total_bytes = 0
        bufs = [self.buf[: self.bucket_size], self.buf[self.bucket_size :]]
        idx = 0
        current = bufs[idx]
        magic_len = len(self._MAGIC)
        self.magic_buf[:magic_len] = torch.tensor(self._MAGIC, dtype=torch.uint8, device=self.device)
        has_downstream = self.rank < self.world_size - 1
        upstream_session = self.buffer_info["session_id"]

        while True:
            info = self.store.recv_obj(self.rank - 1)
            # From the third bucket on, `current` was forwarded downstream two buckets ago;
            # wait for the downstream rank to release it before overwriting.
            if idx >= 2 and has_downstream:
                await self.wait_for_complete(current)

            upstream_ptr = info["ptr"]
            ret = self.engine.transfer_sync_read(upstream_session, current.data_ptr(), upstream_ptr, info["len"])
            if ret != 0:
                raise RuntimeError(f"transfer_sync_read failed {ret}")
            total_bytes += info["len"]

            # Relay the bucket downstream, pointing at this rank's copy of the bytes.
            info["ptr"] = current.data_ptr()
            if has_downstream:
                self.store.send_obj(info, self.rank + 1)

            # NOTE(review): the downstream rank writes the marker into current[:4] as soon as it
            # has consumed this bucket. That may happen while the tensors yielded below (the first
            # one starts at offset 0) are still in use here, including async device copies.
            # Confirm consumers are done first.
            for name, meta in info["bucket_meta"].items():
                dtype, shape = meta["dtype"], meta["shape"]
                start = meta["offset"]
                size = dtype.itemsize * shape.numel()
                yield name, current[start : start + size].view(dtype=dtype).view(shape)

            # Tell upstream that its half can be reused.
            ret = self.engine.transfer_sync_write(upstream_session, self.magic_buf.data_ptr(), upstream_ptr, magic_len)
            if ret != 0:
                raise RuntimeError(f"transfer_sync_write failed {ret}")

            idx += 1
            current = bufs[idx % 2]
            get_torch_device().synchronize()

            if info["is_last"]:
                break

        time_cost = time.perf_counter() - start_time
        bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) if time_cost > 0 else 0.0
        logger.info(
            f"Rank {self.rank} receive weights done, time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s"
        )
|
|