File size: 6,043 Bytes
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for distributed training."""
import ctypes
import os
import socket
from datetime import timedelta
import ray
import torch.distributed
from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device, is_npu_available
from verl.utils.net_utils import is_ipv6
def set_numa_affinity():
    """Ask NVML to set CPU affinity for the accelerator assigned by Ray.

    Returns immediately when ``is_npu_available`` is truthy, or when
    ``libnuma.so`` reports that NUMA is unavailable. Errors raised during the
    setup are printed as warnings instead of being propagated. NVML is shut
    down again if (and only if) it was initialised here.
    """
    if is_npu_available:
        # TODO (FightingZhen) libnuma.so is not available in e2e_ascend CI image, remove this code after image update.
        return

    # Tracks whether nvmlInit() succeeded so the finally clause knows
    # whether a matching nvmlShutdown() is required.
    nvml_started = False
    try:
        numa_lib = ctypes.CDLL("libnuma.so")
        if libnuma_unavailable := numa_lib.numa_available() < 0:
            return

        import pynvml

        pynvml.nvmlInit()
        nvml_started = True

        if is_npu_available:
            accelerator_key = "NPU"
        else:
            accelerator_key = "GPU"

        # First accelerator id that Ray reports for this worker.
        accelerator_ids = ray.get_runtime_context().get_accelerator_ids()
        device_index = int(accelerator_ids[accelerator_key][0])

        device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        pynvml.nvmlDeviceSetCpuAffinity(device_handle)
    except ImportError:
        print("Warning: pynvml not available, skipping NUMA affinity setup")
    except Exception as e:
        print(f"Warning: Failed to set NUMA affinity: {e}")
    finally:
        if nvml_started:
            pynvml.nvmlShutdown()
def initialize_global_process_group(timeout_second=36000):
    """Initialise the default process group and select the local device.

    The backend comes from ``get_nccl_backend()`` and the init method from the
    optional ``DIST_INIT_METHOD`` environment variable.

    Args:
        timeout_second: Process-group timeout, in seconds.

    Returns:
        Tuple ``(local_rank, rank, world_size)`` parsed from the
        ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables.

    Raises:
        KeyError: If one of those three environment variables is missing.
    """
    backend_name = get_nccl_backend()
    pg_timeout = timedelta(seconds=timeout_second)
    init_method = os.environ.get("DIST_INIT_METHOD", None)
    torch.distributed.init_process_group(backend_name, timeout=pg_timeout, init_method=init_method)

    # Read in the same order as before: LOCAL_RANK, RANK, WORLD_SIZE.
    local_rank, rank, world_size = (int(os.environ[key]) for key in ("LOCAL_RANK", "RANK", "WORLD_SIZE"))

    if torch.distributed.is_initialized():
        get_torch_device().set_device(local_rank)

    return local_rank, rank, world_size
def destroy_global_process_group():
    """Destroy the default process group; do nothing if none is initialised."""
    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()
def initialize_global_process_group_ray(timeout_second=None, backend=None):
    """Initialise the default process group unless one already exists.

    Args:
        timeout_second: Optional timeout in seconds; ``None`` passes
            ``timeout=None`` through to ``init_process_group``.
        backend: Backend string. When falsy, a combined
            ``cpu:gloo,<device>:<backend>`` string is built from
            ``get_device_name()`` and ``get_nccl_backend()``.

    ``RANK`` and ``WORLD_SIZE`` are read from the environment and default to
    0 and 1; ``DIST_INIT_METHOD`` is optional.
    """
    # in current ray environment, LOCAL_RANK is always zero.
    import torch.distributed

    if timeout_second is None:
        timeout = None
    else:
        timeout = timedelta(seconds=timeout_second)

    if not backend:
        backend = f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}"

    if torch.distributed.is_initialized():
        return

    torch.distributed.init_process_group(
        backend=backend,
        rank=int(os.environ.get("RANK", 0)),
        world_size=int(os.environ.get("WORLD_SIZE", 1)),
        timeout=timeout,
        init_method=os.environ.get("DIST_INIT_METHOD", None),
    )
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
    """Create a stateless process group and wrap it in a device communicator.

    vLLM provides `StatelessProcessGroup` to create a process group
    without considering the global process group in torch.distributed.
    It is recommended to create `StatelessProcessGroup`, and then initialize
    the data-plane communication (NCCL) between external (train processes)
    and vLLM workers.

    Args:
        master_address: Host of the rank-0 TCP store; rank 0 binds its
            listening socket to this address (IPv4 or IPv6).
        master_port: Port of the rank-0 TCP store.
        rank: Rank of the calling process in the new group; rank 0 hosts the store.
        world_size: Number of processes in the new group.
        device: Device handed through to the communicator.

    Returns:
        A communicator built on the new group: vLLM's `PyNcclCommunicator`, or
        vllm_ascend's `PyHcclCommunicator` when `is_npu_available` is truthy.
    """
    # NOTE: If it is necessary to support weight synchronization with the sglang backend in the future,
    # the following can be used:
    # from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator
    # from sglang.srt.distributed.utils import statelessprocessgroup
    from torch.distributed import TCPStore
    from vllm.distributed.utils import StatelessProcessGroup

    from verl.utils.device import is_npu_available

    if is_npu_available:
        from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator
    else:
        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

    def create_process_group(
        host: str,
        port: int,
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
        store_timeout: int = 300,
    ) -> "StatelessProcessGroup":
        """
        This is copied from vllm/distributed/utils.py:StatelessProcessGroup.create
        Modified to support ipv6 stateless communication groups."""
        launch_server = rank == 0
        listen_socket = None
        listen_fd = None
        if launch_server:
            # listen on the specified interface (instead of 0.0.0.0)
            # The address family is derived from `host`, the address that is
            # actually bound below, so the two can never disagree.
            family = socket.AF_INET6 if is_ipv6(host) else socket.AF_INET
            listen_socket = socket.socket(family, socket.SOCK_STREAM)
            try:
                listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                listen_socket.bind((host, port))
                listen_socket.listen()
            except Exception:
                # Do not leak the descriptor when the address cannot be bound
                # (e.g. port already in use); the caller sees the original error.
                listen_socket.close()
                raise
            listen_fd = listen_socket.fileno()

        store = TCPStore(
            host_name=host,
            port=port,
            world_size=world_size,
            is_master=launch_server,
            timeout=timedelta(seconds=store_timeout),
            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
            master_listen_fd=listen_fd,
        )

        # The socket object is passed along with the store that was given its descriptor.
        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
            socket=listen_socket,
            data_expiration_seconds=data_expiration_seconds,
        )

    pg = create_process_group(host=master_address, port=master_port, rank=rank, world_size=world_size)
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl
|