"""
Distributed training utilities
"""
import os
import torch
import torch.distributed as dist
def setup_distributed():
    """Set up torch.distributed from the standard launcher env vars.

    Reads ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` when a launcher
    (e.g. torchrun) has set them, falling back to a single-process
    configuration otherwise. For multi-process runs it initializes the
    NCCL process group and pins this process to its local CUDA device.

    Returns:
        Tuple of ints ``(rank, world_size, local_rank)``.
    """
    env = os.environ
    if 'RANK' in env:
        rank = int(env['RANK'])
        world_size = int(env['WORLD_SIZE'])
        local_rank = int(env['LOCAL_RANK'])
    else:
        # No launcher detected: behave as a plain single-process run.
        rank, world_size, local_rank = 0, 1, 0
    if world_size > 1:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
def cleanup_distributed():
    """Tear down the distributed process group, if one was created.

    Safe to call unconditionally: a no-op when torch.distributed was
    never initialized.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def print_rank0(msg, rank=0):
    """Print *msg* only on rank 0, so multi-process runs log once.

    Args:
        msg: Object to print.
        rank: Caller's process rank; non-zero ranks stay silent.
    """
    if rank != 0:
        return
    print(msg)
def batch_mm_loop(a, b):
    """Batch matrix multiply via an explicit loop over the batch dimension.

    Deliberately avoids torch.bmm / CUBLAS strided batched routines, which
    have issues on L40S / CUDA 12.8 / PyTorch 2.10 — do not "simplify" this
    back to torch.bmm.

    Args:
        a: Tensor of shape (batch, m, k).
        b: Tensor of shape (batch, k, n).

    Returns:
        Tensor of shape (batch, m, n).
    """
    batch = a.shape[0]
    if batch == 0:
        # torch.stack rejects an empty list, so return a correctly-shaped
        # empty result (same dtype/device as `a`) instead of raising.
        return a.new_empty((0, a.shape[1], b.shape[2]))
    return torch.stack([torch.mm(a[i], b[i]) for i in range(batch)], dim=0)