# ISDNet-pytorch — isdnet/utils/distributed.py
"""
Distributed training utilities
"""
import os
import torch
import torch.distributed as dist
def setup_distributed():
    """Initialize distributed training from torchrun-style environment variables.

    Reads ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` when present (as set by
    ``torchrun``); otherwise falls back to single-process defaults.  When
    running multi-process, joins the NCCL process group and pins this
    process to its local GPU.

    Returns:
        Tuple ``(rank, world_size, local_rank)``.
    """
    env = os.environ
    if 'RANK' in env:
        rank = int(env['RANK'])
        world_size = int(env['WORLD_SIZE'])
        local_rank = int(env['LOCAL_RANK'])
    else:
        # Single-process fallback: no process group is created.
        rank, world_size, local_rank = 0, 1, 0
    if world_size > 1:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
def cleanup_distributed():
    """Tear down the distributed process group, if one was created.

    Safe to call unconditionally: a no-op when ``init_process_group``
    was never invoked (e.g. single-process runs).
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def print_rank0(msg, rank=0):
    """Print *msg* only on the rank-0 process.

    Args:
        msg: Message to print.
        rank: This process's rank; non-zero ranks stay silent.
    """
    if rank != 0:
        return
    print(msg)
def batch_mm_loop(a, b):
    """
    Batch matrix multiply using a loop over the batch dimension.

    Avoids CUBLAS strided batched routines which have issues on
    L40S/CUDA 12.8/PyTorch 2.10 (deliberate workaround — do not
    replace with torch.bmm).

    Args:
        a: Tensor of shape (batch, m, k)
        b: Tensor of shape (batch, k, n)
    Returns:
        Tensor of shape (batch, m, n)
    """
    batch = a.shape[0]
    if batch == 0:
        # torch.stack raises on an empty list; return a correctly-shaped
        # empty result so batch == 0 is handled gracefully.
        return a.new_empty((0, a.shape[1], b.shape[2]))
    return torch.stack([torch.mm(a[i], b[i]) for i in range(batch)], dim=0)