| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import asyncio |
| import logging |
| import os |
| import time |
| import uuid |
| from collections import defaultdict, deque |
| from dataclasses import dataclass |
| from typing import AsyncGenerator, Generator |
| from unittest.mock import patch |
|
|
| with patch("importlib.metadata.distributions", return_value=[]): |
| import cupy as cp |
|
|
| import nixl._api as nixl_api |
| import nixl._bindings as nixl_bindings |
| import ray |
| import torch |
| import zmq |
| import zmq.asyncio |
|
|
| from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry, TensorMeta |
| from verl.utils.net_utils import get_free_port, is_valid_ipv6_address |
|
|
# Module-level logger. Its threshold is taken from the VERL_LOGGING_LEVEL environment
# variable and falls back to "WARN" when the variable is not set.
logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))
|
|
|
|
@dataclass
class NixlAgentMetadata:
    """Connection details one NIXL agent hands to a peer so the peer can reach it."""

    # Name the agent was created with (a UUID string in NixlAgent).
    agent_name: str
    # Opaque blob returned by the underlying agent's get_agent_metadata().
    agent_metadata: bytes
    # IP address the agent's ZeroMQ server socket is bound to (without IPv6 brackets).
    zmq_ip: str
    # TCP port the agent's ZeroMQ server socket is bound to.
    zmq_port: int
|
|
|
|
class NixlAgent:
    """Wrapper around ``nixl_agent`` that uses ZeroMQ instead of ``nixl_agent.send_notif``
    to send bucket tensor metadata.

    Each instance binds one ZeroMQ PULL socket for incoming messages and opens one PUSH
    socket per remote agent it is connected to. Attributes that are not defined on the
    wrapper itself are forwarded to the wrapped ``nixl_agent``.
    """

    def __init__(self):
        self.agent_name = str(uuid.uuid4())
        self.agent = nixl_api.nixl_agent(self.agent_name)
        # Notifications already polled from the agent, queued per remote agent name.
        self.notifications: dict[str, deque[bytes]] = defaultdict(deque)

        self.start_zmq_server()
        # One PUSH socket per connected remote agent, keyed by remote agent name.
        self.zmq_clients: dict[str, zmq.Socket] = {}
        # Messages already received over ZeroMQ, queued per sending agent name.
        self.messages: dict[str, deque[dict]] = defaultdict(deque)

    def __getattr__(self, name):
        """Forward attribute lookups that fail on the wrapper to the wrapped agent."""
        # __getattr__ only runs when normal lookup fails. If ``agent`` itself is missing
        # (e.g. __init__ did not run or failed early), ``self.agent`` below would re-enter
        # this method without end, so report the missing attribute instead.
        if name == "agent":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self.agent, name)

    def get_agent_metadata(self) -> "NixlAgentMetadata":
        """Return the metadata a peer needs to connect to this agent."""
        return NixlAgentMetadata(
            agent_name=self.agent_name,
            agent_metadata=self.agent.get_agent_metadata(),
            zmq_ip=self.ip,
            zmq_port=self.listen_port,
        )

    def start_zmq_server(self):
        """Bind the asyncio ZeroMQ PULL socket on this node's IP and a free port."""
        self.ip = ray.util.get_node_ip_address().strip("[]")
        self.listen_port, _ = get_free_port(self.ip)

        context = zmq.asyncio.Context()
        self.socket = context.socket(zmq.PULL)
        if is_valid_ipv6_address(self.ip):
            address = f"tcp://[{self.ip}]:{self.listen_port}"
            self.socket.setsockopt(zmq.IPV6, 1)
        else:
            address = f"tcp://{self.ip}:{self.listen_port}"

        self.socket.bind(address)

    def add_remote_agent(self, metadata: "NixlAgentMetadata") -> str:
        """Register a remote agent and open a ZeroMQ PUSH socket to it.

        Args:
            metadata (NixlAgentMetadata): Metadata published by the remote agent.

        Returns:
            str: The name of the remote agent.
        """
        agent_name = self.agent.add_remote_agent(metadata.agent_metadata).decode("utf-8")
        assert agent_name == metadata.agent_name, f"Agent name {agent_name} not equal to {metadata.agent_name}"

        context = zmq.Context()
        socket = context.socket(zmq.PUSH)
        if is_valid_ipv6_address(metadata.zmq_ip):
            address = f"tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}"
            socket.setsockopt(zmq.IPV6, 1)
        else:
            address = f"tcp://{metadata.zmq_ip}:{metadata.zmq_port}"

        socket.connect(address)
        self.zmq_clients[agent_name] = socket
        return agent_name

    def remove_remote_agent(self, agent_name: str):
        """Unregister a remote agent and close the PUSH socket that targets it."""
        self.agent.remove_remote_agent(agent_name)
        socket = self.zmq_clients.pop(agent_name)
        socket.close()

    def send_message(self, agent_name, message: dict):
        """Send ``message`` to a remote agent without blocking.

        The payload is the pickled tuple ``(own agent name, message)`` so the receiver can
        tell senders apart.
        """
        socket = self.zmq_clients[agent_name]
        socket.send_pyobj((self.agent_name, message), zmq.DONTWAIT)

    async def read_message(self, agent_name: str) -> dict:
        """Wait for and return the next message sent by ``agent_name``.

        Messages from other senders that arrive in the meantime are queued under their own
        sender name for later calls.
        """
        # SECURITY NOTE: recv_pyobj unpickles whatever arrives on the socket; this must only
        # be exposed to trusted peers.
        while len(self.messages[agent_name]) == 0:
            recv_agent_name, message = await self.socket.recv_pyobj()
            self.messages[recv_agent_name].append(message)
        return self.messages[agent_name].popleft()

    async def get_notification(self, remote_name: str) -> bytes:
        """Wait for and return the next notification sent by ``remote_name``.

        Notifications from other agents that are polled in the meantime are queued under
        their own sender name for later calls.
        """
        pending = self.notifications[remote_name]
        while len(pending) == 0:
            # The loop variable must not reuse the ``remote_name`` parameter: rebinding it
            # would make the loop condition and the final pop act on another agent's queue.
            for sender, notifs in self.agent.get_new_notifs().items():
                self.notifications[sender].extend(notifs)
            # Yield to the event loop between polls.
            await asyncio.sleep(0)
        return pending.popleft()
|
|
|
|
class ReadableOperation:
    """Offer a local buffer to a remote agent and wait until it has been read.

    1. On construction, send the buffer descriptors, a fresh notify key and the caller's
       metadata to the remote agent.
    2. ``wait_for_complete`` then waits for the notification that carries that key.

    Args:
        agent (NixlAgent): The Nixl agent.
        remote_agent (str): The name of the remote agent.
        local_descs (nixl_bindings.nixlXferDList): The local transfer descriptors.
        metadata (dict): Metadata for the read operation; merged into the sent message.
    """

    def __init__(
        self,
        agent: NixlAgent,
        remote_agent: str,
        local_descs: nixl_bindings.nixlXferDList,
        metadata: dict,
    ):
        self.agent = agent
        self.remote_agent = remote_agent
        self.local_descs = local_descs
        # Random key the remote side is expected to echo back in its notification.
        self.notify_key = uuid.uuid4().bytes

        payload = {"notify_key": self.notify_key, "remote_descs": self.local_descs}
        payload.update(metadata)
        self.agent.send_message(self.remote_agent, payload)

    async def wait_for_complete(self):
        """Block until the remote agent reports that its read has completed."""
        received_key = await self.agent.get_notification(self.remote_agent)
        assert received_key == self.notify_key, f"Notify key {self.notify_key} not equal to {received_key}"
        logger.debug(f"ReadableOperation to {self.remote_agent} complete")
|
|
|
|
class ReadOperation:
    """Encapsulates a read operation from a remote agent.

    1. read metadata from the remote agent
    2. start the read transfer operation
    3. wait until the read completes

    Args:
        agent (NixlAgent): The Nixl agent.
        remote_agent (str): The name of the remote agent.
        local_descs (nixl_bindings.nixlXferDList): The local transfer descriptors.
        bucket_size (int): The size of the bucket in bytes; only used to report bandwidth.
    """

    def __init__(self, agent: NixlAgent, remote_agent: str, local_descs: nixl_bindings.nixlXferDList, bucket_size: int):
        self.agent = agent
        self.remote_agent = remote_agent
        self.local_descs = local_descs
        # Filled in by read_metadata().
        self.remote_descs = None
        self.notify_key = None
        # Filled in by begin_read().
        self.xfer_handle = None
        self.start_time = None
        self.bucket_size = bucket_size

    async def read_metadata(self) -> dict:
        """Block until the remote agent sends the metadata.

        The ``remote_descs`` and ``notify_key`` entries are removed from the message and
        stored on this operation.

        Returns:
            dict: Metadata from the remote agent, without those two entries.
        """
        metadata = await self.agent.read_message(self.remote_agent)
        self.remote_descs = metadata.pop("remote_descs")
        self.notify_key = metadata.pop("notify_key")
        return metadata

    def begin_read(self):
        """Start the read operation.

        Raises:
            RuntimeError: If the transfer enters the "ERR" state right after being posted.
        """
        assert self.remote_descs is not None and self.notify_key is not None
        self.xfer_handle = self.agent.initialize_xfer(
            "READ", self.local_descs, self.remote_descs, self.remote_agent, self.notify_key
        )
        state = self.agent.transfer(self.xfer_handle)
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if state == "ERR":
            self.agent.release_xfer_handle(self.xfer_handle)
            raise RuntimeError(f"Read from {self.remote_agent} got to {state} state.")
        self.start_time = time.time()

    async def wait_for_complete(self):
        """Block until the read operation completes, then release the transfer handle.

        Raises:
            RuntimeError: If ``begin_read`` has not been called, or if the transfer enters
                the "ERR" state.
        """
        if self.xfer_handle is None or self.start_time is None:
            raise RuntimeError("begin_read() must be called before wait_for_complete().")

        while True:
            state = self.agent.check_xfer_state(self.xfer_handle)
            if state == "DONE":
                break
            if state == "ERR":
                # Release the handle and raise so the caller decides how to react; the
                # interactive `exit()` helper must not terminate the process from here.
                self.agent.release_xfer_handle(self.xfer_handle)
                logger.error(f"Read from {self.remote_agent} got to {state} state.")
                raise RuntimeError(f"Read from {self.remote_agent} got to {state} state.")
            # Yield to the event loop between polls.
            await asyncio.sleep(0)

        self.agent.release_xfer_handle(self.xfer_handle)
        # Guard against a zero-length interval on coarse clocks.
        elapsed = max(time.time() - self.start_time, 1e-9)
        bandwidth = self.bucket_size / elapsed / (1024 * 1024 * 1024)
        logger.debug(f"ReadOperation read data from {self.remote_agent} complete, bandwidth: {bandwidth:.2f} GB/s")
|
|
|
|
@CheckpointEngineRegistry.register("nixl")
class NIXLCheckpointEngine(CheckpointEngine):
    """NIXL checkpoint engine with p2p communication, supporting various backends: ucx, uccl, mooncake, etc.

    For UCX backend, some environment variables need to be set: UCX_TLS, UCX_IB_GID_INDEX, UCX_IB_DEVICES, etc.
    Please refer to: https://openucx.readthedocs.io/en/master/faq.html

    Weights travel along a chain: rank 0 (the sending trainer) offers buckets to rank 1, every
    intermediate rank reads a bucket from its predecessor and offers it to its successor, and the
    last rank only reads.

    Args:
        bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use
            two buffers to send and recv weights at the same time, so the memory overhead is 2 * bucket_size.
        device (str): The device to use for the checkpoint engine, "cpu" or "cuda".
        rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to
            torch.bfloat16. Stored on the instance; not otherwise used by this class.
        is_master (bool): Stored on the instance; not otherwise used by this class.
    """

    def __init__(
        self,
        bucket_size: int,
        device: str = "cuda",
        rollout_dtype: torch.dtype = torch.bfloat16,
        is_master: bool = False,
    ):
        self.bucket_size = bucket_size
        self.device = device
        self.rollout_dtype = rollout_dtype
        self.agent = NixlAgent()
        self.is_master = is_master

        # Populated by prepare(); initialised here so finalize() is safe at any time.
        self.send_buf = None
        self.recv_buf = None
        self.send_reg_descs = None
        self.recv_reg_descs = None
        self.send_descs = None
        self.recv_descs = None

        # Populated by init_process_group().
        self.rank = None
        self.world_size = None
        self.prev_agent = None
        self.next_agent = None

    @staticmethod
    def _synchronize_device():
        """Wait for outstanding CUDA work; a no-op when CUDA is not available (CPU-only hosts)."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _iter_bucket_tensors(buf: torch.Tensor, bucket_meta: dict):
        """Yield ``(name, tensor)`` views into ``buf`` for every entry of ``bucket_meta``."""
        for name, meta in bucket_meta.items():
            dtype, shape = meta["dtype"], meta["shape"]
            begin = meta["offset"]
            size = dtype.itemsize * shape.numel()
            # Reinterpret the raw bytes as the original dtype and shape; no data is copied.
            yield name, buf[begin : begin + size].view(dtype=dtype).view(shape)

    def prepare(self) -> NixlAgentMetadata:
        """Prepare send and recv bucket.

        Allocates both buckets, registers them with the agent and builds their transfer
        descriptors.

        Returns:
            NixlAgentMetadata: The metadata of the current nixl agent.
        """
        if self.device == "cuda":
            # The CUDA buckets are allocated through cupy and wrapped as torch tensors.
            # NOTE(review): presumably to keep them outside torch's caching allocator -- confirm.
            send_buf = cp.zeros(self.bucket_size, dtype=cp.uint8)
            recv_buf = cp.zeros(self.bucket_size, dtype=cp.uint8)
            self.send_buf = torch.as_tensor(send_buf, dtype=torch.uint8)
            self.recv_buf = torch.as_tensor(recv_buf, dtype=torch.uint8)
        else:
            self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device, pin_memory=True)
            self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device=self.device, pin_memory=True)
        self.send_reg_descs = self.agent.register_memory(self.send_buf)
        self.recv_reg_descs = self.agent.register_memory(self.recv_buf)
        self.send_descs = self.agent.get_xfer_descs(self.send_buf)
        self.recv_descs = self.agent.get_xfer_descs(self.recv_buf)

        return self.agent.get_agent_metadata()

    @classmethod
    def build_topology(
        cls, trainer_world_size: int, rollout_world_size: int, metadata: list[NixlAgentMetadata]
    ) -> tuple[dict, dict]:
        """Build the per-worker ``init_process_group`` arguments for the transfer chain.

        ``metadata[0]`` is used as the sending trainer and the last ``rollout_world_size``
        entries of ``metadata`` as the rollout workers, in chain order.

        Args:
            trainer_world_size (int): Number of trainer workers (>= 1).
            rollout_world_size (int): Number of rollout workers (>= 1).
            metadata (list[NixlAgentMetadata]): Agent metadata of all workers.

        Returns:
            tuple[dict, dict]: ``(trainer_kwargs, rollout_kwargs)``; every value is a list with
            one entry per trainer / rollout worker respectively.

        Raises:
            ValueError: If either world size is smaller than 1.
        """
        if trainer_world_size < 1 or rollout_world_size < 1:
            raise ValueError(
                f"trainer_world_size ({trainer_world_size}) and rollout_world_size "
                f"({rollout_world_size}) must both be at least 1."
            )

        # Slice the rollout segment once. A negative start offset such as
        # ``metadata[-rollout_world_size + 1:]`` collapses to ``metadata[0:]`` (the whole
        # list) when rollout_world_size == 1.
        rollout_metadata = metadata[-rollout_world_size:]

        trainer_kwargs = {
            "method": ["init_process_group"] * trainer_world_size,
            # Only the first trainer sends; the others get rank -1 and no neighbours.
            "rank": [0] + [-1] * (trainer_world_size - 1),
            "world_size": [rollout_world_size + 1] * trainer_world_size,
            "prev_agent_metadata": [None] * trainer_world_size,
            "next_agent_metadata": [rollout_metadata[0]] + [None] * (trainer_world_size - 1),
        }

        rollout_kwargs = {
            "method": ["init_process_group"] * rollout_world_size,
            "rank": list(range(1, rollout_world_size + 1)),
            "world_size": [rollout_world_size + 1] * rollout_world_size,
            # The first rollout worker reads from the trainer, the rest from their predecessor.
            "prev_agent_metadata": [metadata[0]] + rollout_metadata[:-1],
            # The last rollout worker has nobody to forward to.
            "next_agent_metadata": rollout_metadata[1:] + [None],
        }
        return trainer_kwargs, rollout_kwargs

    def init_process_group(
        self, rank: int, world_size: int, prev_agent_metadata: NixlAgentMetadata, next_agent_metadata: NixlAgentMetadata
    ):
        """Setup the communication with the previous and next agent.

        Args:
            rank (int): The rank of the current process; negative for trainers that do not send.
            world_size (int): The total number of processes.
            prev_agent_metadata (NixlAgentMetadata): The metadata of the previous nixl agent, or None.
            next_agent_metadata (NixlAgentMetadata): The metadata of the next nixl agent, or None.
        """
        if rank < 0:
            assert not prev_agent_metadata and not next_agent_metadata, (
                f"rank {rank} should not have prev_agent_metadata or next_agent_metadata"
            )
        elif rank == 0:
            assert not prev_agent_metadata and next_agent_metadata, f"rank {rank} should have next_agent_metadata"
        elif 0 < rank < world_size - 1:
            assert prev_agent_metadata and next_agent_metadata, (
                f"rank {rank} should have prev_agent_metadata and next_agent_metadata"
            )
        elif rank == world_size - 1:
            assert prev_agent_metadata and not next_agent_metadata, (
                f"rank {rank} should have prev_agent_metadata and not next_agent_metadata"
            )

        self.rank = rank
        self.world_size = world_size
        self.prev_agent = None
        self.next_agent = None

        if prev_agent_metadata is not None:
            self.prev_agent = self.agent.add_remote_agent(prev_agent_metadata)

        if next_agent_metadata is not None:
            self.next_agent = self.agent.add_remote_agent(next_agent_metadata)

        logger.info(
            f"init_process_group rank: {self.rank}, world_size: {self.world_size}, "
            f"prev_agent: {self.prev_agent}, next_agent: {self.next_agent}"
        )

    def finalize(self):
        """Cleanup communication with the previous and next agent, and deregister the memory.

        Safe to call before ``prepare`` / ``init_process_group`` and more than once.
        """
        if self.prev_agent:
            self.agent.remove_remote_agent(self.prev_agent)
        if self.next_agent:
            self.agent.remove_remote_agent(self.next_agent)

        if self.send_reg_descs is not None:
            self.agent.deregister_memory(self.send_reg_descs)
        if self.recv_reg_descs is not None:
            self.agent.deregister_memory(self.recv_reg_descs)
        self.send_buf = None
        self.recv_buf = None
        self.send_reg_descs = None
        self.recv_reg_descs = None
        self.send_descs = None
        self.recv_descs = None

        self.rank = None
        self.world_size = None
        self.prev_agent = None
        self.next_agent = None

    # NOTE(review): on an ``async def`` this decorator appears to wrap only the call that
    # creates the coroutine, not its body -- confirm. It is kept so the method object stays
    # as before; the tensor copy below runs in an explicit ``torch.no_grad()`` block.
    @torch.no_grad()
    async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
        """Send the weights of the model.

        Weights are packed back to back into the current send bucket. When the next weight
        does not fit, the bucket is offered to the next agent and packing continues in the
        other bucket, so packing overlaps with the remote read.

        Args:
            weights: A generator that yields the name of the weight tensor and the tensor itself.
        """
        assert self.rank <= 0, "Trainer workers other than rank 0 should not send weights."

        # Non-sending trainers still exhaust the generator.
        # NOTE(review): presumably because producing the weights needs every trainer rank
        # to take part -- confirm with the caller.
        if self.rank < 0:
            for _name, _weight in weights:
                pass
            return

        assert self.next_agent is not None, "Next agent is not set."
        send_buf, recv_buf = self.send_buf, self.recv_buf
        send_descs, recv_descs = self.send_descs, self.recv_descs
        readable_op = None

        start_time = time.time()
        bucket_meta: dict[str, TensorMeta] = {}
        offset = 0
        for name, weight in weights:
            nbytes = weight.nbytes
            # Reject oversized weights up front, before a pointless flush of the bucket.
            assert nbytes <= self.bucket_size, (
                f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket."
            )

            # The bucket is full: hand it to the next agent and switch buckets.
            if offset + nbytes > self.bucket_size:
                # Make sure the non-blocking copies into the bucket have landed.
                self._synchronize_device()

                # The other bucket may only be reused once the previous read has finished.
                if readable_op is not None:
                    await readable_op.wait_for_complete()

                readable_op = ReadableOperation(
                    self.agent,
                    self.next_agent,
                    send_descs,
                    {"bucket_meta": bucket_meta, "is_last": False},
                )

                send_buf, recv_buf = recv_buf, send_buf
                send_descs, recv_descs = recv_descs, send_descs
                bucket_meta = {}
                offset = 0

            bucket_meta[name] = {
                "name": name,
                "shape": weight.shape,
                "dtype": weight.dtype,
                "offset": offset,
            }
            with torch.no_grad():
                # reshape(-1) also accepts non-contiguous weights (it copies only when needed).
                send_buf[offset : offset + nbytes].copy_(weight.reshape(-1).view(torch.uint8), non_blocking=True)
            offset += nbytes

        # Flush the final (possibly partial) bucket and wait until it has been read.
        self._synchronize_device()
        if readable_op is not None:
            await readable_op.wait_for_complete()

        readable_op = ReadableOperation(
            self.agent, self.next_agent, send_descs, {"bucket_meta": bucket_meta, "is_last": True}
        )
        await readable_op.wait_for_complete()
        logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s")

    @torch.no_grad()
    async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
        """Receive the weights of the model.

        Each bucket is read from the previous agent and, when a next agent exists, offered
        to it. While one bucket is yielded and forwarded, the following bucket is already
        being read into the other buffer. The yielded tensors are views into a buffer that
        is overwritten by a later bucket.

        Yields:
            A tuple of the name of the weight tensor and the tensor itself.
        """
        assert self.prev_agent is not None, "Previous agent is not set."
        send_buf, recv_buf = self.send_buf, self.recv_buf
        send_descs, recv_descs = self.send_descs, self.recv_descs
        total_bytes, total_params = 0, 0

        # Read the first bucket; there is nothing to overlap it with yet.
        start_time = time.time()
        read_op = ReadOperation(self.agent, self.prev_agent, recv_descs, self.bucket_size)
        metadata = await read_op.read_metadata()
        read_op.begin_read()
        await read_op.wait_for_complete()
        total_bytes += self.bucket_size
        total_params += len(metadata["bucket_meta"])

        # The bucket just received becomes the one to forward and yield from.
        send_buf, recv_buf = recv_buf, send_buf
        send_descs, recv_descs = recv_descs, send_descs
        while not metadata["is_last"]:
            # 1. Offer the current bucket to the next agent, if any.
            readable_op = None
            if self.next_agent is not None:
                readable_op = ReadableOperation(
                    self.agent,
                    self.next_agent,
                    send_descs,
                    metadata,
                )

            # 2. Start reading the following bucket into the other buffer.
            read_op = ReadOperation(self.agent, self.prev_agent, recv_descs, self.bucket_size)
            next_metadata = await read_op.read_metadata()
            read_op.begin_read()

            # 3. Yield the tensors of the current bucket while both transfers are in flight.
            for item in self._iter_bucket_tensors(send_buf, metadata["bucket_meta"]):
                yield item

            # 4. Wait for both the outgoing and the incoming transfer.
            if readable_op is not None:
                await readable_op.wait_for_complete()
            await read_op.wait_for_complete()
            total_bytes += self.bucket_size
            total_params += len(next_metadata["bucket_meta"])

            # 5. Let device work on the yielded views finish before that buffer is reused.
            self._synchronize_device()
            metadata = next_metadata
            send_buf, recv_buf = recv_buf, send_buf
            send_descs, recv_descs = recv_descs, send_descs

        # Forward and yield the last bucket.
        readable_op = None
        if self.next_agent is not None:
            readable_op = ReadableOperation(
                self.agent,
                self.next_agent,
                send_descs,
                metadata,
            )

        for item in self._iter_bucket_tensors(send_buf, metadata["bucket_meta"]):
            yield item

        if readable_op is not None:
            await readable_op.wait_for_complete()
        time_cost = time.time() - start_time
        bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024)
        logger.info(
            f"Rank {self.rank} receive weights done, total_params: {total_params}, "
            f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s"
        )
|
|