# Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for distributed training.""" import ctypes import os import socket from datetime import timedelta import ray import torch.distributed from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device, is_npu_available from verl.utils.net_utils import is_ipv6 def set_numa_affinity(): if is_npu_available: # TODO (FightingZhen) libnuma.so is not available in e2e_ascend CI image, remove this code after image update. return initialized = False try: libnuma = ctypes.CDLL("libnuma.so") if libnuma.numa_available() < 0: return import pynvml pynvml.nvmlInit() initialized = True device_name = "NPU" if is_npu_available else "GPU" local_rank = int(ray.get_runtime_context().get_accelerator_ids()[device_name][0]) handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank) pynvml.nvmlDeviceSetCpuAffinity(handle) except ImportError: print("Warning: pynvml not available, skipping NUMA affinity setup") except Exception as e: print(f"Warning: Failed to set NUMA affinity: {e}") finally: if initialized: pynvml.nvmlShutdown() def initialize_global_process_group(timeout_second=36000): torch.distributed.init_process_group( get_nccl_backend(), timeout=timedelta(seconds=timeout_second), init_method=os.environ.get("DIST_INIT_METHOD", None), ) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) if torch.distributed.is_initialized(): get_torch_device().set_device(local_rank) return local_rank, rank, world_size def destroy_global_process_group(): if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() def initialize_global_process_group_ray(timeout_second=None, backend=None): # in current ray environment, LOCAL_RANK is always zero. import torch.distributed timeout = timedelta(seconds=timeout_second) if timeout_second is not None else None backend = backend or f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}" if not torch.distributed.is_initialized(): rank = int(os.environ.get("RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) torch.distributed.init_process_group( backend=backend, rank=rank, world_size=world_size, timeout=timeout, init_method=os.environ.get("DIST_INIT_METHOD", None), ) def stateless_init_process_group(master_address, master_port, rank, world_size, device): """ vLLM provides `StatelessProcessGroup` to create a process group without considering the global process group in torch.distributed. It is recommended to create `StatelessProcessGroup`, and then initialize the data-plane communication (NCCL) between external (train processes) and vLLM workers. """ # NOTE: If it is necessary to support weight synchronization with the sglang backend in the future, # the following can be used: # from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator # from sglang.srt.distributed.utils import statelessprocessgroup from torch.distributed import TCPStore from vllm.distributed.utils import StatelessProcessGroup from verl.utils.device import is_npu_available if is_npu_available: from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator else: from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator def create_process_group( host: str, port: int, rank: int, world_size: int, data_expiration_seconds: int = 3600, store_timeout: int = 300, ) -> "StatelessProcessGroup": """ This is copied from vllm/distributed/utils.py:StatelessProcessGroup.create Modified to support ipv6 stateless communication groups.""" launch_server = rank == 0 if launch_server: # listen on the specified interface (instead of 0.0.0.0) if is_ipv6(master_address): listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) else: listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listen_socket.bind((host, port)) listen_socket.listen() listen_fd = listen_socket.fileno() else: listen_socket = None listen_fd = None store = TCPStore( host_name=host, port=port, world_size=world_size, is_master=launch_server, timeout=timedelta(seconds=store_timeout), use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215 master_listen_fd=listen_fd, ) return StatelessProcessGroup( rank=rank, world_size=world_size, store=store, socket=listen_socket, data_expiration_seconds=data_expiration_seconds, ) pg = create_process_group(host=master_address, port=master_port, rank=rank, world_size=world_size) pynccl = PyNcclCommunicator(pg, device=device) return pynccl