| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import logging |
| import os |
| import time |
| from dataclasses import dataclass |
| from typing import AsyncGenerator, Generator |
|
|
| import ray |
| import torch |
| import zmq |
| from vllm.distributed.utils import StatelessProcessGroup |
|
|
| from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta |
| from verl.utils.device import is_torch_npu_available |
| from verl.utils.distributed import stateless_init_process_group |
| from verl.utils.net_utils import get_free_port, is_valid_ipv6_address |
|
|
| if not is_torch_npu_available(check_device=False): |
| raise ImportError("HCCLCheckpointEngine is unavailable because the torch.npu module is not available.") |
|
|
| logger = logging.getLogger(__name__) |
| logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) |
|
|
|
|
| @dataclass |
| class MasterMetadata: |
| zmq_ip: str |
| zmq_port: int |
| dist_ip: str |
| dist_port: int |
|
|
|
|
| class BroadcastOperation: |
| """Async broadcast operation with HCCL in separate thread. |
| |
| Args: |
| rank (int): The rank of the current process. |
| group_name (str): The name of the HCCL process group. |
| bucket (torch.Tensor): The tensor to broadcast. |
| metadata (dict[str, TensorMeta]): The metadata of the tensor. |
| socket (zmq.Socket): The zeromq socket to communicate with master. |
| topic (str): The topic to subscribe. |
| """ |
|
|
| def __init__( |
| self, |
| rank: int, |
| process_group: StatelessProcessGroup | str, |
| bucket: torch.Tensor, |
| metadata: dict[str, TensorMeta], |
| socket: zmq.Socket, |
| topic: str, |
| ) -> None: |
| self.rank = rank |
| self.pyhccl = process_group |
| self.bucket = bucket |
| self.metadata = metadata |
| self.socket = socket |
| self.topic = topic |
|
|
| self._run() |
|
|
| def _run(self): |
| |
| if self.rank == 0: |
| self.socket.send_string(self.topic, flags=zmq.SNDMORE) |
| self.socket.send_pyobj(self.metadata) |
| else: |
| self.socket.recv_string() |
| self.metadata = self.socket.recv_pyobj() |
|
|
| |
| self.pyhccl.broadcast(self.bucket, src=0) |
|
|
| async def wait_for_complete(self) -> dict[str, TensorMeta]: |
| """Wait for the broadcast operation to complete. |
| |
| Returns: |
| dict[str, TensorMeta]: The bucket meta after broadcast. |
| """ |
| return self.metadata |
|
|
|
|
| @CheckpointEngineRegistry.register("nccl") |
| class HCCLCheckpointEngine(CheckpointEngine): |
| """HCCL checkpoint engine with collective communication. |
| |
| Args: |
| bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use |
| two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size. |
| group_name (str): The name of the HCCL process group. Defaults to "default". |
| rebuild_group (bool): Whether to rebuild the HCCL process group in each update. Defaults to False. |
| is_master (bool): Whether the current process is the master process. Defaults to False. |
| rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16. |
| """ |
|
|
| def __init__( |
| self, |
| bucket_size: int, |
| group_name: str = "default", |
| rebuild_group: bool = False, |
| is_master: bool = False, |
| rollout_dtype: torch.dtype = torch.bfloat16, |
| ) -> None: |
| self.bucket_size = bucket_size |
| self.group_name = group_name |
| self.rebuild_group = rebuild_group |
| self.rollout_dtype = rollout_dtype |
| self.pyhccl = None |
| self.device = torch.npu.current_device() |
|
|
| |
| self.is_master = is_master |
| self.topic = "bucket_metadata" |
| if self.is_master: |
| self._start_zmq_server() |
| self.dist_port, _ = get_free_port(self.ip) |
|
|
| def prepare(self) -> MasterMetadata: |
| self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") |
| self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu") |
|
|
| return ( |
| MasterMetadata(zmq_ip=self.ip, zmq_port=self.zmq_port, dist_ip=self.ip, dist_port=self.dist_port) |
| if self.is_master |
| else None |
| ) |
|
|
| def finalize(self): |
| """Destroy the HCCL process group if rebuild_group is True.""" |
| if self.rebuild_group: |
| if self.rank >= 0: |
| self.pyhccl.destroyComm(self.pyhccl.comm) |
| self.pyhccl = None |
| self.rank = None |
| self.world_size = None |
|
|
| self.send_buf = None |
| self.recv_buf = None |
| torch.npu.empty_cache() |
|
|
| @classmethod |
| def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metadata: list[dict]): |
| trainer_kwargs = { |
| "rank": [0] + [-1] * (trainer_world_size - 1), |
| "world_size": [rollout_world_size + 1] * trainer_world_size, |
| "master_metadata": [metadata[0]] * trainer_world_size, |
| } |
| rollout_kwargs = { |
| "rank": list(range(1, rollout_world_size + 1)), |
| "world_size": [rollout_world_size + 1] * rollout_world_size, |
| "master_metadata": [metadata[0]] * rollout_world_size, |
| } |
| return trainer_kwargs, rollout_kwargs |
|
|
| def _start_zmq_server(self): |
| self.ip = ray.util.get_node_ip_address().strip("[]") |
| self.zmq_port, _ = get_free_port(self.ip) |
|
|
| context = zmq.Context() |
| self.socket = context.socket(zmq.PUB) |
| if is_valid_ipv6_address(self.ip): |
| address = f"tcp://[{self.ip}]:{self.zmq_port}" |
| self.socket.setsockopt(zmq.IPV6, 1) |
| else: |
| address = f"tcp://{self.ip}:{self.zmq_port}" |
|
|
| self.socket.bind(address) |
|
|
| def _connect_zmq_client(self, metadata: MasterMetadata): |
| assert not self.is_master, "Master process should not connect to other processes." |
| context = zmq.Context() |
| self.socket = context.socket(zmq.SUB) |
| if is_valid_ipv6_address(metadata.zmq_ip): |
| address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}" |
| self.socket.setsockopt(zmq.IPV6, 1) |
| else: |
| address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}" |
|
|
| self.socket.connect(address) |
| self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic) |
|
|
| def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): |
| """Initialize the HCCL process group. |
| |
| Args: |
| rank (int): The rank of the current process. |
| world_size (int): The total number of processes. |
| """ |
| |
| if rank < 0: |
| self.rank = rank |
| self.world_size = world_size |
| return |
|
|
| if self.rebuild_group or self.pyhccl is None: |
| self.pyhccl = stateless_init_process_group( |
| master_metadata.dist_ip, master_metadata.dist_port, rank, world_size, self.device |
| ) |
| self.rank = rank |
| self.world_size = world_size |
| else: |
| assert self.rank == rank, f"rank {rank} is not equal to self.rank {self.rank}" |
| assert self.world_size == world_size, ( |
| f"world_size {world_size} is not equal to self.world_size {self.world_size}" |
| ) |
|
|
| if self.rank > 0: |
| self._connect_zmq_client(master_metadata) |
|
|
| |
| signal = torch.tensor([1], dtype=torch.int8, device=torch.npu.current_device()) |
| self.pyhccl.all_reduce(signal) |
|
|
| logger.info(f"init_process_group rank: {self.rank}, world_size: {self.world_size}") |
|
|
| @torch.no_grad() |
| async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): |
| """Send the weights of the model. |
| |
| Args: |
| weights: A generator that yields the name of the weight tensor and the tensor itself. |
| """ |
| assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights." |
|
|
| |
| if self.rank < 0: |
| for name, weight in weights: |
| pass |
| return |
|
|
| send_buf, recv_buf = self.send_buf, self.recv_buf |
| broadcast_op = None |
|
|
| start_time = time.time() |
| bucket_meta: dict[str, TensorMeta] = {} |
| offset = 0 |
| for name, weight in weights: |
| |
| if offset + weight.nbytes > self.bucket_size: |
| torch.npu.synchronize() |
|
|
| |
| if broadcast_op is not None: |
| await broadcast_op.wait_for_complete() |
|
|
| broadcast_op = BroadcastOperation( |
| rank=self.rank, |
| process_group=self.pyhccl, |
| bucket=send_buf, |
| metadata={"bucket_meta": bucket_meta, "is_last": False}, |
| socket=self.socket, |
| topic=self.topic, |
| ) |
|
|
| |
| send_buf, recv_buf = recv_buf, send_buf |
| bucket_meta = {} |
| offset = 0 |
|
|
| assert offset + weight.nbytes <= self.bucket_size, ( |
| f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." |
| ) |
|
|
| bucket_meta[name] = { |
| "name": name, |
| "shape": weight.shape, |
| "dtype": weight.dtype, |
| "offset": offset, |
| } |
| send_buf[offset : offset + weight.nbytes] = weight.view(-1).view(torch.uint8) |
| offset += weight.nbytes |
|
|
| |
| torch.npu.synchronize() |
| if broadcast_op is not None: |
| await broadcast_op.wait_for_complete() |
|
|
| broadcast_op = BroadcastOperation( |
| rank=self.rank, |
| process_group=self.pyhccl, |
| bucket=send_buf, |
| metadata={"bucket_meta": bucket_meta, "is_last": True}, |
| socket=self.socket, |
| topic=self.topic, |
| ) |
| await broadcast_op.wait_for_complete() |
| logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s") |
|
|
| @torch.no_grad() |
| async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]: |
| """Receive the weights of the model. |
| |
| Yields: |
| A tuple of the name of the weight tensor and the tensor itself. |
| """ |
| assert self.rank > 0, "Rank 0 should not receive weights." |
| send_buf, recv_buf = self.send_buf, self.recv_buf |
| total_bytes, total_params = 0, 0 |
|
|
| |
| start_time = time.time() |
| broadcast_op = BroadcastOperation( |
| rank=self.rank, |
| process_group=self.pyhccl, |
| bucket=recv_buf, |
| metadata=None, |
| socket=self.socket, |
| topic=self.topic, |
| ) |
| metadata = await broadcast_op.wait_for_complete() |
| total_bytes += self.bucket_size |
| total_params += len(metadata["bucket_meta"]) |
|
|
| |
| send_buf, recv_buf = recv_buf, send_buf |
| while not metadata["is_last"]: |
| |
| broadcast_op = BroadcastOperation( |
| rank=self.rank, |
| process_group=self.pyhccl, |
| bucket=recv_buf, |
| metadata=None, |
| socket=self.socket, |
| topic=self.topic, |
| ) |
|
|
| |
| for name, meta in metadata["bucket_meta"].items(): |
| dtype, shape = meta["dtype"], meta["shape"] |
| size = dtype.itemsize * shape.numel() |
| tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) |
| yield name, tensor |
|
|
| |
| metadata = await broadcast_op.wait_for_complete() |
| total_bytes += self.bucket_size |
| total_params += len(metadata["bucket_meta"]) |
|
|
| |
| torch.npu.synchronize() |
| send_buf, recv_buf = recv_buf, send_buf |
|
|
| |
| for name, meta in metadata["bucket_meta"].items(): |
| dtype, shape = meta["dtype"], meta["shape"] |
| size = dtype.itemsize * shape.numel() |
| tensor = send_buf[meta["offset"] : meta["offset"] + size].view(dtype=dtype).view(shape) |
| yield name, tensor |
|
|
| time_cost = time.time() - start_time |
| bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024) |
| logger.info( |
| f"Rank {self.rank} receive weights done, total_params: {total_params}, " |
| f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s" |
| ) |
|
|