| | import torch |
| | import torch.multiprocessing as mp |
| | import torch.distributed as dist |
| | import os |
| |
|
| |
|
def find_free_port():
    """Return, as a string, a TCP port number that was free when checked.

    Binds an IPv4 TCP socket to port 0 on all interfaces so the OS assigns
    an unused ephemeral port, reads that port back, and closes the socket.
    The port is released before returning, so another process could claim
    it before the caller binds it.

    Technique from
    https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(('', 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = probe.getsockname()
    return str(port)
| |
|
| |
|
def setup_process(rank, master_addr, master_port, world_size, backend='nccl'):
    """Initialise and then immediately tear down a torch.distributed group.

    Used as the per-rank target of ``mp.spawn``, which calls it as
    ``setup_process(i, *args)``. It only checks that rendezvous and
    teardown succeed for *backend*; no collectives are run.

    Args:
        rank: Index of this process within the group.
        master_addr: Host of the rendezvous store, e.g. ``'127.0.0.1'``.
        master_port: Rendezvous port, given as ``str`` or ``int``.
        world_size: Total number of processes expected to join.
        backend: torch.distributed backend name, e.g. ``'nccl'`` or ``'gloo'``.
    """
    print(f'setting up {rank=} {world_size=} {backend=}')

    # With no init_method, init_process_group uses the ``env://`` rendezvous,
    # which reads these variables. os.environ only accepts str values, so
    # coerce here to also support an int port (previously a TypeError).
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)
    print(f"{master_addr=} {master_port=}")

    dist.init_process_group(backend, rank=rank, world_size=world_size)
    print(f"{rank=} init complete")
    dist.destroy_process_group()
    print(f"{rank=} destroy complete")
| |
|
| |
|
if __name__ == '__main__':
    world_size = 2
    master_addr = '127.0.0.1'
    # The port is only known to be free at probe time; another process could
    # grab it before rank 0 binds it.
    master_port = find_free_port()

    # NCCL expects a distinct CUDA device per rank and is unusable on CPU-only
    # machines or builds; fall back to Gloo so the smoke test can still run.
    use_nccl = (dist.is_nccl_available()
                and torch.cuda.device_count() >= world_size)
    backend = 'nccl' if use_nccl else 'gloo'

    # mp.spawn calls setup_process(i, *args) for i in range(nprocs).
    mp.spawn(setup_process,
             args=(
                 master_addr,
                 master_port,
                 world_size,
                 backend,
             ),
             nprocs=world_size)
| |
|