import os
import torch

from datetime import timedelta

# Distributed rank and world size, injected by the launcher (e.g. torchrun).
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

# Fraction of each GPU's memory this process is allowed to allocate.
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


class FakeBarrier:
    # No-op stand-in for the async work handle returned by real collectives.
    def wait(self):
        pass


class FakeGroup:
    # Single-process stand-in for a torch.distributed process group: it mimics
    # the collective API so that WORLD_SIZE == 1 runs need no special-casing.
    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        # Reducing across a single process is the identity; just return a handle.
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank


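# Usage sketch (illustrative only; not part of the original module): the fake
# collectives follow the same call pattern as a real torch.distributed group,
# so single-process code can call them unchanged.
#
#     group = FakeGroup(rank=0, size=1)
#     local = [torch.ones(1)]
#     outputs = [[torch.empty(1)]]
#     group.allgather(outputs, local).wait()  # copies local[0] into outputs[0][0]
#     assert outputs[0][0].item() == 1.0

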
def initialize_torch_distributed():
    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Pin this process to a single GPU and cap how much memory it may use.
        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one GPU"
        device = RANK % torch.cuda.device_count()
        torch.cuda.set_device(device)
        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
        backend = "nccl"
        # NCCL options: run collectives on a high-priority stream and time
        # them out after 60 seconds.
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        # No GPU available: fall back to the CPU-only gloo backend.
        backend = "gloo"
        options = None

    if WORLD_SIZE == 1:
        # Single process: skip real initialization and return the fake group.
        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
    else:
        if os.getenv("DEBUG", None) == "1":
            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE

        if not torch.distributed.is_initialized():
            # Initialize the global process group exactly once per process.
            torch.distributed.init_process_group(
                backend=backend,
                world_size=WORLD_SIZE,
                rank=RANK,
                timeout=timedelta(seconds=60),
                pg_options=options,
            )
        else:
            print("torch.distributed is already initialized.")

        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
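

# Minimal smoke test (a sketch added for illustration; not part of the
# original module, and the filename `dist.py` below is an assumption).
# Single process exercises the FakeGroup path:
#     python dist.py
# Multiple processes exercise the real backend:
#     torchrun --nproc_per_node=2 dist.py
if __name__ == "__main__":
    group, rank, world_size = initialize_torch_distributed()
    print(f"rank={rank} world_size={world_size} group={type(group).__name__}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    t = torch.ones(1, device=device)
    if isinstance(group, FakeGroup):
        group.allreduce(t).wait()  # identity reduction on the fake group
    else:
        torch.distributed.all_reduce(t, group=group)  # sums across all ranks
    print(f"rank={rank}: all_reduce -> {t.item()}")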