"""
Distributed training utilities
"""

import os

import torch
import torch.distributed as dist


def setup_distributed():
    """Initialize distributed training from launcher environment variables.

    When ``RANK`` is present in the environment, ``RANK``, ``WORLD_SIZE`` and
    ``LOCAL_RANK`` are read from it; otherwise a single-process configuration
    ``(0, 1, 0)`` is used. A NCCL process group is created, and the current
    CUDA device pinned to ``local_rank``, only when ``world_size > 1``.

    Returns:
        Tuple ``(rank, world_size, local_rank)``.

    Raises:
        KeyError: if ``RANK`` is set but ``WORLD_SIZE`` or ``LOCAL_RANK`` is not.
    """
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
    else:
        rank = 0
        world_size = 1
        local_rank = 0

    if world_size > 1:
        # Pin this rank's GPU *before* creating the NCCL group so any CUDA
        # context / communicator set up during init lands on the right device
        # instead of whatever the default current device is.
        torch.cuda.set_device(local_rank)
        dist.init_process_group('nccl')

    return rank, world_size, local_rank


def cleanup_distributed():
    """Destroy the default process group if one has been initialized.

    Safe to call in single-process runs: it is a no-op when
    ``dist.is_initialized()`` is False.
    """
    if dist.is_initialized():
        dist.destroy_process_group()


def print_rank0(msg, rank=0):
    """Print ``msg`` only when ``rank`` is 0.

    Args:
        msg: Object to print.
        rank: Rank of the calling process; nothing is printed unless it is 0.
    """
    if rank == 0:
        print(msg)


def batch_mm_loop(a, b):
    """
    Batch matrix multiply using a loop over the batch dimension.

    Avoids CUBLAS strided batched routines which have issues on
    L40S/CUDA 12.8/PyTorch 2.10: each batch element is multiplied with a
    separate ``torch.mm`` call and the results are stacked.

    Args:
        a: Tensor of shape (batch, m, k)
        b: Tensor of shape (batch, k, n)

    Returns:
        Tensor of shape (batch, m, n). For ``batch == 0`` an empty tensor of
        shape (0, m, n) with ``a``'s dtype and device is returned.

    Raises:
        ValueError: if ``a`` and ``b`` have different batch sizes.
    """
    batch = a.shape[0]
    if b.shape[0] != batch:
        # Without this check a longer ``b`` would be silently truncated and a
        # shorter one would fail with an opaque IndexError.
        raise ValueError(
            f"batch size mismatch: a has {batch}, b has {b.shape[0]}"
        )
    if batch == 0:
        # torch.stack raises on an empty list, so build the empty result
        # directly with the shape a non-empty batch would have.
        return a.new_empty((0, a.shape[1], b.shape[2]))
    # Iterating a tensor yields slices along dim 0, i.e. a[i] / b[i].
    return torch.stack([torch.mm(x, y) for x, y in zip(a, b)], dim=0)