| """ |
| Helpers for distributed training. |
| """ |
|
|
| import socket |
|
|
| import torch as th |
| import torch.distributed as dist |
|
|
| |
| |
| GPUS_PER_NODE = 8 |
|
|
| SETUP_RETRY_COUNT = 3 |
|
|
| used_device = 0 |
|
|


def setup_dist(device=0):
    """
    Setup a distributed process group.

    In this variant nothing is initialized here: the call only records which
    CUDA device dev() should return, and exits early if a process group has
    already been created elsewhere.
    """
    global used_device
    used_device = device
    if dist.is_initialized():
        return
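

# A minimal sketch (not part of the original helpers) of how a process group
# could actually be created for the single-node, single-process case, using
# the standard torch.distributed "env://" variables and the _find_free_port()
# helper defined below. The function name, backend choice, and RANK/WORLD_SIZE
# defaults of 0/1 are assumptions, not taken from this module.
def _setup_single_process_group(backend=None):
    import os  # local import so the sketch stays self-contained

    if dist.is_initialized():
        return
    backend = backend or ("nccl" if th.cuda.is_available() else "gloo")
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group(backend=backend, init_method="env://")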


def dev():
    """
    Get the device to use for torch.distributed.
    """
    global used_device
    if th.cuda.is_available() and used_device >= 0:
        return th.device(f"cuda:{used_device}")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file. Every rank reads the file directly; see the sketch
    below for a broadcast-based variant that avoids redundant fetches across
    ranks.
    """
    return th.load(path, **kwargs)
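

# Illustrative sketch (an assumption, not the behavior of load_state_dict
# above): to avoid redundant reads across ranks, rank 0 could read the raw
# bytes once and broadcast them before every rank deserializes locally.
def _load_state_dict_broadcast(path, **kwargs):
    import io  # local import so the sketch stays self-contained

    if not (dist.is_available() and dist.is_initialized()):
        return th.load(path, **kwargs)
    if dist.get_rank() == 0:
        with open(path, "rb") as f:
            payload = [f.read()]
    else:
        payload = [None]
    dist.broadcast_object_list(payload, src=0)
    return th.load(io.BytesIO(payload[0]), **kwargs)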


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)
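

# Typical calling pattern for these helpers (an illustrative assumption, not
# taken from this module; `model` and `ckpt_path` are hypothetical):
#
#     setup_dist(device=0)
#     model.to(dev())
#     model.load_state_dict(load_state_dict(ckpt_path, map_location="cpu"))
#     sync_params(model.parameters())  # requires an initialized process group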


def _find_free_port():
    # Bind to port 0 so the OS picks an unused port, then report it. The
    # socket is created outside the try block so `finally` never closes an
    # unbound name, and SO_REUSEADDR is set before bind so it takes effect.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        s.close()