# metalmind / src / utils / distributed.py
# (uploaded via huggingface_hub, revision ada3f28)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import torch
import torch.distributed as dist
from logging import getLogger
logger = getLogger()
# dist.init_process_group(backend='nccl')
def init_distributed(port=37123, rank_and_world_size=(None, None)):
    """Initialize ``torch.distributed`` (NCCL backend) and return ``(world_size, rank)``.

    Rank/world-size resolution order:
      1. If a process group is already initialized, report its topology.
      2. Use ``rank_and_world_size`` if both values are given.
      3. Otherwise read SLURM environment variables; if they are absent,
         fall back to single-process mode ``(1, 0)`` without initializing.

    :param port: TCP port exported as ``MASTER_PORT`` for rendezvous.
    :param rank_and_world_size: explicit ``(rank, world_size)``, or ``(None, None)``
        to auto-detect from SLURM.
    :returns: tuple ``(world_size, rank)``; ``(1, 0)`` when distributed
        training is unavailable.
    """
    # Already initialized (e.g. by an earlier call) -- just report topology.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()

    rank, world_size = rank_and_world_size
    os.environ['MASTER_ADDR'] = 'localhost'

    if (rank is None) or (world_size is None):
        # Auto-detect from SLURM. KeyError: vars absent; ValueError: non-integer.
        try:
            world_size = int(os.environ['SLURM_NTASKS'])
            rank = int(os.environ['SLURM_PROCID'])
            os.environ['MASTER_ADDR'] = os.environ['HOSTNAME']
        except (KeyError, ValueError):
            logger.info('SLURM vars not set (distributed training not available)')
            world_size, rank = 1, 0
            return world_size, rank

    try:
        os.environ['MASTER_PORT'] = str(port)
        # Use the module alias consistently (was torch.distributed.init_process_group).
        dist.init_process_group(
            backend='nccl',
            world_size=world_size,
            rank=rank
        )
    except Exception as e:
        # Rendezvous/backend failure: degrade gracefully to single-process mode.
        world_size, rank = 1, 0
        logger.info('Rank: %s. Distributed training not available %s', rank, e)
    return world_size, rank
class AllGather(torch.autograd.Function):
    """Autograd-aware all-gather: concatenates per-rank tensors along dim 0.

    Forward gathers every rank's tensor and concatenates them; backward
    all-reduces the incoming gradient and returns this rank's slice.
    Acts as the identity when no multi-process group is active.
    """

    @staticmethod
    def forward(ctx, x):
        # Single-process (or uninitialized) run: pass through untouched.
        if not (dist.is_available() and dist.is_initialized()):
            return x
        world_size = dist.get_world_size()
        if world_size <= 1:
            return x
        x = x.contiguous()
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grads):
        if not (dist.is_available() and dist.is_initialized()):
            return grads
        world_size = dist.get_world_size()
        if world_size <= 1:
            return grads
        # Sum gradients across ranks, then keep only this rank's chunk.
        chunk = grads.shape[0] // world_size
        start = chunk * dist.get_rank()
        grads = grads.contiguous()
        dist.all_reduce(grads)
        return grads[start:start + chunk]
class AllReduceSum(torch.autograd.Function):
    """Sum a tensor across ranks in forward; gradients pass through unchanged.

    Identity when no multi-process group is active.
    """

    @staticmethod
    def forward(ctx, x):
        # No active group -> nothing to reduce.
        if not (dist.is_available() and dist.is_initialized()):
            return x
        if dist.get_world_size() <= 1:
            return x
        x = x.contiguous()
        dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        # d(sum-reduce)/dx is the identity on each rank.
        return grads
class AllReduce(torch.autograd.Function):
    """Average a tensor across ranks in forward; gradients pass through unchanged.

    Divides by world size before the sum all-reduce, yielding the mean.
    Identity when no multi-process group is active.
    """

    @staticmethod
    def forward(ctx, x):
        # No active group -> nothing to average.
        if not (dist.is_available() and dist.is_initialized()):
            return x
        n = dist.get_world_size()
        if n <= 1:
            return x
        # Pre-divide so the summed result is the cross-rank mean.
        x = x.contiguous() / n
        dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grads):
        return grads