Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import pytest | |
| import torch | |
| import torch.distributed as dist | |
| from xformers.ops import init_ipc | |
| from .multiprocessing_utils import launch_subprocesses | |
# Compute capability of the default CUDA device; falls back to (0, 0) when CUDA
# is unavailable so the tuple comparison below is always well-defined.
compute_capability = (
    torch.cuda.get_device_capability("cuda") if torch.cuda.is_available() else (0, 0)
)

# Pytest marker that skips a test unless the device capability is at least 7.0.
cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0), reason="requires sm70+"
)
def inner_test_ipc() -> None:
    """Per-rank worker: swap one-element CUDA buffers over ``init_ipc`` connections.

    Each rank builds one int32 buffer per rank, filled with its own rank, sends
    buffer ``i`` over connection ``i`` and stores whatever ``conn.recv()``
    returns in slot ``i`` of its receive list. After a barrier it asserts that
    every held buffer contains the index of its slot, overwrites it with its
    own rank, and after a second barrier asserts that its original send buffers
    now contain the index of their slot. That last check can only pass if a
    peer's write through the received buffer is visible in the sender's tensor.

    Must run where ``torch.distributed`` is already initialized, since it
    queries the rank and world size and creates a new group.
    """
    my_rank = dist.get_rank()
    num_ranks = dist.get_world_size()
    group = dist.new_group()
    connections = init_ipc(group)

    outgoing = []
    for _ in range(num_ranks):
        outgoing.append(torch.full([1], my_rank, device="cuda", dtype=torch.int32))
    # Shallow copy: slots that never get replaced below keep aliasing the
    # local send buffer.
    incoming = list(outgoing)

    # Entries of ``connections`` that are None are skipped (presumably the slot
    # for this rank itself -- TODO confirm against init_ipc).
    for other_rank, conn in enumerate(connections):
        if conn is not None:
            conn.send(outgoing[other_rank])
    for other_rank, conn in enumerate(connections):
        if conn is not None:
            incoming[other_rank] = conn.recv()

    torch.cuda.synchronize()
    dist.barrier(group)

    # Use the buffer to send data
    for other_rank, received in enumerate(incoming):
        assert received[0].item() == other_rank
        received.fill_(my_rank)

    torch.cuda.synchronize()
    dist.barrier(group)

    # Verify we've received the data correctly
    for other_rank, buf in enumerate(outgoing):
        assert (
            buf[0].item() == other_rank
        ), f"[#{my_rank}] {other_rank=} != {buf[0].item()=}"
@cuda_sm70_only
def test_ipc() -> None:
    """Run ``inner_test_ipc`` through ``launch_subprocesses`` with a world size of 2.

    The worker allocates its buffers with ``device="cuda"``, so the test is
    guarded by the module's ``cuda_sm70_only`` marker and is skipped (rather
    than failing) when the device capability is below 7.0 or CUDA is
    unavailable.
    """
    world_size = 2
    launch_subprocesses(
        world_size=world_size,
        fn=inner_test_ipc,
    )
# We had an issue where the second rendezvous in a single process would use the
# same store keys as the first one, and would thus retrieve a stale address to
# connect to, and fail.
def inner_test_ipc_twice() -> None:
    """Per-rank worker: call ``init_ipc`` twice on the same new group.

    Regression check for the stale-address issue described above; it passes as
    long as both calls return without raising.
    """
    group = dist.new_group()
    for _ in range(2):
        init_ipc(group)
@cuda_sm70_only
def test_ipc_twice() -> None:
    """Run ``inner_test_ipc_twice`` through ``launch_subprocesses`` with a world size of 2.

    Guarded by the module's ``cuda_sm70_only`` marker so that hosts without a
    suitable CUDA device skip the test instead of failing.
    NOTE(review): assumes ``init_ipc`` needs CUDA like the other IPC test in
    this file -- confirm.
    """
    world_size = 2
    launch_subprocesses(
        world_size=world_size,
        fn=inner_test_ipc_twice,
    )