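"""Distributed training utilities.

Detects the launcher environment (SLURM, MPI/OpenMPI, torchrun, Horovod),
initializes torch.distributed when appropriate, and selects the device for
the current process.
"""
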
import os
import socket

import torch

try:
    import horovod.torch as hvd
except ImportError:
    # Horovod is optional; it is only required when running with args.horovod.
    hvd = None


def is_global_master(args):
    """True if this process is the global master (rank 0 across all nodes)."""
    return args.rank == 0


def is_local_master(args):
    """True if this process is the first-ranked process on its own node."""
    return args.local_rank == 0


def is_master(args, local=False):
    """Check master status; pass local=True to test per-node mastership."""
    return is_local_master(args) if local else is_global_master(args)


def is_using_horovod():
    # NOTE: a horovodrun/mpirun launch sets the OMPI_* variables, while a
    # SLURM PMI launch sets PMI_* instead, so either complete pair counts.
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    return all(var in os.environ for var in ompi_vars) or all(
        var in os.environ for var in pmi_vars
    )


def is_using_distributed():
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"]) > 1
    if "SLURM_NTASKS" in os.environ:
        return int(os.environ["SLURM_NTASKS"]) > 1
    return False


def world_info_from_env():
    """Read (local_rank, global_rank, world_size) from launcher env vars.

    Earlier names win, so SLURM variables take precedence over MPI/PMI ones,
    which in turn take precedence over torch.distributed's LOCAL_RANK, RANK,
    and WORLD_SIZE.
    """
    local_rank = 0
    for v in (
        "SLURM_LOCALID",
        "MPI_LOCALRANKID",
        "OMPI_COMM_WORLD_LOCAL_RANK",
        "LOCAL_RANK",
    ):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"):
        if v in os.environ:
            world_size = int(os.environ[v])
            break

    return local_rank, global_rank, world_size
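
# Example: under `torchrun --nproc_per_node=4`, the launcher exports
# LOCAL_RANK, RANK, and WORLD_SIZE, which world_info_from_env() picks up
# when no SLURM or MPI variables are present.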


def init_distributed_device(args):
    """Set up single- or multi-GPU training and return the torch device.

    Fills in args.distributed, args.world_size, args.rank, args.local_rank,
    and args.device as a side effect.
    """
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        # World info comes from the OpenMPI environment set up by the launcher.
        world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
        world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
        local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        args.local_rank = local_rank
        args.rank = world_rank
        args.world_size = world_size
        args.distributed = True
        # Mirror the world info into torch.distributed-style env vars so
        # downstream code can rely on them regardless of the launcher.
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)
        print(
            f"Distributed training: local_rank={args.local_rank}, "
            f"rank={args.rank}, world_size={args.world_size}, "
            f"hostname={socket.gethostname()}, pid={os.getpid()}"
        )
    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM: copy the SLURM world info into the standard
            # torch.distributed env vars before initializing.
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        elif "OMPI_COMM_WORLD_SIZE" in os.environ:
            # DDP via an MPI launch (mpirun/mpiexec) without Horovod.
            world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
            world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
            local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
            args.local_rank = local_rank
            args.rank = world_rank
            args.world_size = world_size
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun / torch.distributed.launch: the launcher already
            # exported RANK and WORLD_SIZE, so init_process_group reads them.
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
        print(
            f"Distributed training: local_rank={args.local_rank}, "
            f"rank={args.rank}, world_size={args.world_size}, "
            f"hostname={socket.gethostname()}, pid={os.getpid()}"
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            # Bind this process to the GPU matching its local rank.
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
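

# Usage sketch (illustrative only): init_distributed_device() expects an
# argparse-style namespace exposing `horovod`, `dist_backend`, `dist_url`,
# and `no_set_device_rank`. The flag names and defaults below are assumptions
# for the example, not defined by this module.
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--horovod", action="store_true")
#     parser.add_argument("--dist-backend", default="nccl")
#     parser.add_argument("--dist-url", default="env://")
#     parser.add_argument("--no-set-device-rank", action="store_true")
#     args = parser.parse_args()
#
#     device = init_distributed_device(args)
#     if is_master(args):
#         print(f"rank 0 of {args.world_size} training on {device}")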