|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
import torch |
|
|
from torch import distributed as dist |
|
|
|
|
|
|
|
|
def setup_dist(
    rank=None, world_size=None, master_port=None, use_ddp_launch=False, master_addr=None
):
    """Initialize the default ``torch.distributed`` process group (NCCL backend).

    Args:
        rank: Global rank of this process. Used only when ``use_ddp_launch``
            is False; ignored otherwise.
        world_size: Total number of processes. Used only when
            ``use_ddp_launch`` is False.
        master_port: Rendezvous port. Falls back to "12354" when neither this
            argument nor the ``MASTER_PORT`` environment variable is set.
        use_ddp_launch: True when the process was started by a launcher
            (e.g. torchrun) that already exported RANK / WORLD_SIZE into the
            environment, in which case rank/world_size are read from there.
        master_addr: Rendezvous host. Falls back to "localhost" when neither
            this argument nor the ``MASTER_ADDR`` environment variable is set.
    """
    # Pre-existing environment variables win over the function arguments, so
    # values exported by a launcher are never overwritten here.
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = (
            "localhost" if master_addr is None else str(master_addr)
        )

    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)

    if not use_ddp_launch:
        # Manually spawned process: the caller supplies rank/world_size and
        # each process binds its own CUDA device by rank.
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
    else:
        # Launcher mode: rank/world_size come from the environment.
        dist.init_process_group("nccl")
|
|
|
|
|
|
|
|
def cleanup_dist():
    """Destroy the default process group created by ``setup_dist``."""
    dist.destroy_process_group()
|
|
|
|
|
|
|
|
def get_world_size():
    """Return the total number of processes; 1 when not running distributed.

    The ``WORLD_SIZE`` environment variable (set by launchers such as
    torchrun) takes precedence over an initialized process group.
    """
    env_world_size = os.environ.get("WORLD_SIZE")
    if env_world_size is not None:
        return int(env_world_size)
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
|
|
|
|
def get_rank():
    """Return this process's global rank; 0 when not running distributed.

    The ``RANK`` environment variable (set by launchers such as torchrun)
    takes precedence over an initialized process group.
    """
    env_rank = os.environ.get("RANK")
    if env_rank is not None:
        return int(env_rank)
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
|
|
|
|
def get_local_rank():
    """Return the node-local rank from ``LOCAL_RANK``; 0 when it is unset."""
    try:
        return int(os.environ["LOCAL_RANK"])
    except KeyError:
        return 0
|
|
|