# File size: 6,176 Bytes
# a402b9b
import ctypes
import multiprocessing as mp
import random
import socket
import unittest
from typing import Any, List, Optional
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Per-rank worker: compare the custom all-reduce against ``dist.all_reduce``.

    Joins an NCCL process group through ``tcp://localhost:<port>``, builds the
    shared metadata and staging buffers, and then, for every size in
    ``test_sizes`` and each of float32 / float16 / bfloat16, runs
    ``custom_ops.all_reduce`` ten times and checks the output against the
    result of ``dist.all_reduce`` on a clone of the same input.

    Args:
        world_size: Number of processes in the group.
        rank: Rank of this process; also used as its CUDA device index.
        distributed_init_port: Port on localhost used for the TCP rendezvous.
        test_sizes: Iterable of element counts for the tensors to reduce.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` when the custom
            result differs from the reference.
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    # Bind the resource handles before entering the try block: the finally
    # clause inspects all three, and a failure part-way through set-up would
    # otherwise surface as an UnboundLocalError that hides the real error.
    custom_ptr = None
    buffer_ptrs = None
    meta_ptrs = None
    try:
        max_size = 8192 * 1024
        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
            custom_ops.meta_size() + max_size, group=group
        )
        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
        custom_ops.register_buffer(custom_ptr, buffer_ptrs)

        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    # Keep an untouched copy to feed the reference reduction.
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.all_reduce(
                        custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        # Synchronise the ranks before any of them starts releasing memory.
        dist.barrier(group=group)
        if custom_ptr is not None:
            custom_ops.dispose(custom_ptr)
        if buffer_ptrs:
            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
        if meta_ptrs:
            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)
        dist.destroy_process_group(group=group)
def get_open_port() -> int:
    """Return a TCP port number that was free on the loopback interface.

    The port is obtained by binding a throw-away socket to port 0 on
    ``127.0.0.1``; if that raises ``OSError`` the same is attempted on the
    IPv6 loopback ``::1``. The socket is closed before returning.
    """

    def _probe(family, loopback):
        # Binding to port 0 asks the OS to pick any unused port.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind((loopback, 0))
            return sock.getsockname()[1]

    try:
        return _probe(socket.AF_INET, "127.0.0.1")
    except OSError:
        return _probe(socket.AF_INET6, "::1")
def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Run ``test_target`` in ``world_size`` spawned processes and check them.

    Each worker is called as
    ``test_target(world_size, rank, distributed_init_port, *target_args)``
    where the port is chosen once via ``get_open_port()``.

    Args:
        world_size: Number of worker processes to start.
        test_target: Callable executed in every worker process.
        target_args: Extra positional arguments appended to each worker call.

    Raises:
        AssertionError: If a worker exits with a non-zero exit code. Workers
            are checked in rank order and the first failure is reported.
    """
    mp.set_start_method("spawn", force=True)
    distributed_init_port = get_open_port()

    procs = []
    try:
        for i in range(world_size):
            proc_args = (world_size, i, distributed_init_port) + target_args
            proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
            proc.start()
            procs.append(proc)

        for i, proc in enumerate(procs):
            proc.join()
            # Raised explicitly (not via `assert`) so the check survives -O.
            if proc.exitcode != 0:
                raise AssertionError(
                    f"Process {i} failed with exit code {proc.exitcode}"
                )
    finally:
        # On failure the remaining workers may still be running; stop and
        # reap them so they do not outlive this call. On success every
        # process has already exited and this loop is a no-op.
        grace_seconds = 10
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
                proc.join(timeout=grace_seconds)
                if proc.is_alive():
                    proc.kill()
            proc.join()
class TestCustomAllReduce(unittest.TestCase):
    """Correctness test of ``custom_ops.all_reduce`` across several GPUs."""

    # Element counts of the tensors reduced by the workers.
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
        1048576,
        2097152,
    ]
    # Process-group sizes to try; sizes above the GPU count are skipped.
    world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """Allocate a buffer on this rank and share it with the group via IPC.

        Every rank allocates ``size_in_bytes`` with ``cudaMalloc``, publishes
        the IPC handle of its allocation through ``dist.all_gather`` and opens
        the handles received from the other ranks.

        Args:
            size_in_bytes: Size of the allocation made on each rank.
            group: Process group to share within; defaults to the WORLD group.

        Returns:
            One integer pointer value per rank, indexed by rank. The entry for
            the calling rank is its own allocation; the others come from
            ``cudaIpcOpenMemHandle``.

        Raises:
            Exception: Whatever ``cudaIpcOpenMemHandle`` raises is re-raised
                after a diagnostic line has been printed.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        if group is None:
            group = dist.group.WORLD
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)

        # Ship the raw bytes of the handle struct as a uint8 tensor so that
        # they can travel through the collective.
        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
        input_tensor = torch.tensor(
            list(handle_bytes), dtype=torch.uint8, device=f"cuda:{rank}"
        )
        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, input_tensor, group=group)

        # Rebuild one handle struct per rank from the gathered bytes.
        handle_type = type(handle)
        handles = []
        for tensor in gathered_tensors:
            bytes_data = bytes(tensor.cpu().tolist())
            handle_obj = handle_type()
            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
            handles.append(handle_obj)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                # The local allocation is used directly, not re-opened.
                pointers.append(pointer.value)
            else:
                try:
                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
                    pointers.append(opened_ptr.value)
                except Exception as e:
                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
                    raise

        dist.barrier(group=group)
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        """Free this rank's entry of ``pointers`` and synchronise the group.

        Only ``pointers[rank]`` is passed to ``cudaFree``; entries belonging
        to other ranks are left alone. A barrier on ``group`` follows.

        Args:
            pointers: Per-rank pointer values as returned by
                ``create_shared_buffer``.
            group: Process group the buffer was shared in; defaults to WORLD.
        """
        if group is None:
            group = dist.group.WORLD
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        if pointers and len(pointers) > rank and pointers[rank] is not None:
            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
        dist.barrier(group=group)

    def test_correctness(self):
        """Run the correctness worker for every world size that fits the GPUs.

        World sizes larger than ``torch.cuda.device_count()`` are skipped with
        a message. If none of them can run, the test is reported as skipped
        instead of silently passing without having checked anything.
        """
        available_gpus = torch.cuda.device_count()
        ran_any = False
        for world_size in self.world_sizes:
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
                )
                continue
            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")
            ran_any = True

        if not ran_any:
            self.skipTest(
                f"no world size in {self.world_sizes} fits the "
                f"{available_gpus} available GPU(s)"
            )
# Hand control to the unittest command-line runner when executed as a script.
if __name__ == "__main__":
    unittest.main()
|