import os
import io
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Set up the default process group for distributed training.

    Exports the rendezvous info as environment variables, binds this process
    to its local GPU, and initializes the NCCL backend.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


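# Launch sketch (hypothetical single-node run; launchers such as torchrun set
# RANK/LOCAL_RANK/WORLD_SIZE themselves, in which case calling setup_dist
# with values read back from the environment is enough):
#
#   import torch.multiprocessing as mp
#
#   def _worker(local_rank, world_size):
#       # On a single node, the global rank equals the local rank.
#       setup_dist(local_rank, local_rank, world_size, 'localhost', 29500)
#       ...
#       dist.destroy_process_group()
#
#   if __name__ == '__main__':
#       n_gpus = torch.cuda.device_count()
#       mp.spawn(_worker, args=(n_gpus,), nprocs=n_gpus)

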
def read_file_dist(path):
    """
    Read a binary file cooperatively in a distributed run.

    The file is read once by the rank-0 process and broadcast to all other
    processes, so shared storage is only hit by a single reader.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            # Wrap the raw bytes in a CUDA byte tensor so they can be
            # broadcast over NCCL, which only supports device tensors.
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # Tell the other ranks how large a receive buffer to allocate.
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        dist.broadcast(data, src=0)
        data = data.cpu().numpy().tobytes()
        return io.BytesIO(data)
    else:
        with open(path, 'rb') as f:
            data = f.read()
        return io.BytesIO(data)


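# Usage sketch (hypothetical checkpoint path; note that every rank must call
# read_file_dist together, since it performs collective broadcasts):
#
#   buffer = read_file_dist('checkpoints/model.pt')
#   state_dict = torch.load(buffer, map_location='cpu')

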
def unwrap_dist(model):
    """
    Unwrap a model from its DistributedDataParallel wrapper, if any.
    """
    if isinstance(model, DDP):
        return model.module
    return model


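# Usage sketch: saving through unwrap_dist keeps checkpoints free of the
# 'module.' prefix that DDP adds to parameter names:
#
#   torch.save(unwrap_dist(model).state_dict(), 'model.pt')

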
@contextmanager
def master_first():
    """
    A context manager that makes the master process (global rank 0) execute
    the enclosed block before all other processes.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            yield
            dist.barrier()  # release the waiting ranks
        else:
            dist.barrier()  # wait until rank 0 has finished
            yield


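# Usage sketch (download_if_missing is a hypothetical helper; rank 0 fills a
# shared cache, then the remaining ranks find the file already present):
#
#   with master_first():
#       path = download_if_missing('https://example.com/weights.pt')

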
@contextmanager
def local_master_first():
    """
    A context manager that makes the local master process (the first rank on
    each node) execute the enclosed block before the other processes on that
    node.
    """
    if not dist.is_initialized():
        yield
    else:
        # Assumes ranks are assigned contiguously per node and every node has
        # the same number of GPUs, so the first rank on each node satisfies
        # rank % torch.cuda.device_count() == 0.
        if dist.get_rank() % torch.cuda.device_count() == 0:
            yield
            dist.barrier()  # release the other ranks
        else:
            dist.barrier()  # wait until the local masters have finished
            yield
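

# Usage sketch (hypothetical per-node preprocessing; one process on each node
# unpacks data to node-local disk while its local peers wait):
#
#   with local_master_first():
#       extract_archive('/data/shard.tar', '/tmp/shard')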