| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| from contextlib import contextmanager |
|
|
| import torch |
|
|
|
|
def init_distributed(cuda):
    """
    Create the torch.distributed process group when running on several workers.

    The number of workers is taken from the ``WORLD_SIZE`` environment
    variable (1 if unset). A process group is created, with ``env://``
    initialization, only when that number is greater than one.

    :param cuda: (bool) use the NCCL backend when True, the Gloo backend
        when False
    :returns: (bool) True if a process group was initialized, else False
    """
    is_multi_worker = int(os.environ.get('WORLD_SIZE', 1)) > 1
    if not is_multi_worker:
        return False

    torch.distributed.init_process_group(
        backend='nccl' if cuda else 'gloo',
        init_method='env://',
    )
    # Sanity check that the group really came up.
    assert torch.distributed.is_initialized()
    return True
|
|
|
|
def barrier():
    """
    Wait until every worker reaches this point.

    Does nothing when torch.distributed is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return
    dist.barrier()
|
|
|
|
def get_rank():
    """
    Return this worker's distributed rank.

    Falls back to 0 when torch.distributed is unavailable or not initialized.
    """
    dist = torch.distributed
    in_use = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if in_use else 0
|
|
|
|
def get_world_size():
    """
    Return the total number of distributed workers.

    Falls back to 1 when torch.distributed is unavailable or not initialized.
    """
    dist = torch.distributed
    in_use = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if in_use else 1
|
|
|
|
def all_reduce_item(value, op='sum'):
    """
    All-reduces a single scalar value across workers if distributed is in use.

    :param value: Python scalar to reduce
    :param op: one of 'sum', 'mean', 'min', 'max', 'product'
    :returns: the reduced value as a Python scalar; when distributed is not
        initialized, ``value`` is returned unchanged
    :raises RuntimeError: if ``op`` is not supported, or if the active
        distributed backend is neither NCCL nor Gloo
    """
    # Validate eagerly so a bad ``op`` fails in single-process runs too,
    # not only once the job is launched on several workers.
    if op not in ('sum', 'mean', 'min', 'max', 'product'):
        raise RuntimeError('Unsupported reduce op')

    if not (torch.distributed.is_available()
            and torch.distributed.is_initialized()):
        return value

    reduce_ops = {
        'sum': torch.distributed.ReduceOp.SUM,
        # 'mean' is a SUM followed by division by the world size below.
        'mean': torch.distributed.ReduceOp.SUM,
        'min': torch.distributed.ReduceOp.MIN,
        'max': torch.distributed.ReduceOp.MAX,
        'product': torch.distributed.ReduceOp.PRODUCT,
    }
    dop = reduce_ops[op]

    backend = torch.distributed.get_backend()
    if backend == torch.distributed.Backend.NCCL:
        # NCCL collectives operate on CUDA tensors.
        device = torch.device('cuda')
    elif backend == torch.distributed.Backend.GLOO:
        device = torch.device('cpu')
    else:
        raise RuntimeError('Unsupported distributed backend')

    tensor = torch.tensor(value, device=device)
    torch.distributed.all_reduce(tensor, dop)
    if op == 'mean':
        # Out-of-place division: an in-place ``/=`` cannot store the float
        # result back into an integer tensor and would raise.
        tensor = tensor / torch.distributed.get_world_size()
    return tensor.item()
|
|
|
|
@contextmanager
def sync_workers():
    """
    Context manager that provides this worker's distributed rank.

    After the ``with`` body completes normally, waits at a barrier for all
    workers. NOTE(review): if the body raises, the exception propagates
    from the ``yield`` and the barrier is skipped.
    """
    yield get_rank()
    barrier()
|
|