| """ | |
| Distributed training utilities | |
| """ | |
| import os | |
| import torch | |
| import torch.distributed as dist | |
def setup_distributed():
    """Initialize distributed training from a torchrun-style environment.

    Reads ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` from the environment
    when present (as exported by ``torchrun`` and similar launchers) and
    falls back to a single-process configuration otherwise.  The NCCL
    process group is initialized and the CUDA device pinned only when
    running with more than one process, so the single-process path is safe
    on CPU-only machines.

    Returns:
        Tuple of ``(rank, world_size, local_rank)``.
    """
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        # Some launchers export RANK/WORLD_SIZE but not LOCAL_RANK; default
        # to 0 instead of crashing with a KeyError.
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    else:
        rank = 0
        world_size = 1
        local_rank = 0
    if world_size > 1:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank
def cleanup_distributed():
    """Tear down the distributed process group, if one was initialized.

    Safe to call unconditionally: a no-op when no process group exists.
    """
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def print_rank0(msg, rank=0):
    """Print *msg* to stdout from the rank-0 process only.

    Args:
        msg: Message to print.
        rank: This process's rank; non-zero ranks stay silent.
    """
    if rank != 0:
        return
    print(msg)
def batch_mm_loop(a, b):
    """
    Batch matrix multiply using a Python loop over the batch dimension.

    Deliberately avoids CUBLAS strided batched routines (torch.bmm), which
    have issues on L40S/CUDA 12.8/PyTorch 2.10 -- do not "optimize" this
    back into a single batched call.

    Args:
        a: Tensor of shape (batch, m, k)
        b: Tensor of shape (batch, k, n)

    Returns:
        Tensor of shape (batch, m, n).  An empty batch yields an empty
        (0, m, n) tensor instead of raising from ``torch.stack([])``.
    """
    batch = a.shape[0]
    if batch == 0:
        # torch.stack raises on an empty sequence; build the empty result
        # directly with the input's dtype/device.
        return a.new_empty((0, a.shape[1], b.shape[2]))
    return torch.stack([torch.mm(a[i], b[i]) for i in range(batch)], dim=0)