File size: 1,292 Bytes
a585f5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
Distributed training utilities
"""

import os
import torch
import torch.distributed as dist


def setup_distributed(backend='nccl'):
    """Initialize distributed training from environment variables.

    When ``RANK`` is present in the environment (as set by launchers such as
    torchrun), ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` are read from it.
    Otherwise a single-process configuration ``(0, 1, 0)`` is assumed.
    A process group is only created when ``world_size > 1``.

    Args:
        backend: Backend passed to ``dist.init_process_group``. Defaults to
            ``'nccl'``, the value previously hard-coded here. For the NCCL
            backend the process is bound to GPU ``local_rank``.

    Returns:
        Tuple ``(rank, world_size, local_rank)`` of ints.

    Raises:
        KeyError: if ``RANK`` is set but ``WORLD_SIZE`` or ``LOCAL_RANK`` is not.
        ValueError: if any of those variables is not an integer string.
    """
    if 'RANK' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
    else:
        rank, world_size, local_rank = 0, 1, 0

    if world_size > 1:
        if backend == 'nccl':
            # Bind this process to its GPU *before* creating the process group,
            # so NCCL communicators and collectives use the intended device
            # instead of whatever the current device is (cuda:0 by default).
            torch.cuda.set_device(local_rank)
        dist.init_process_group(backend)

    return rank, world_size, local_rank


def cleanup_distributed():
    """Destroy the default process group; no-op when none has been initialized."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()


def print_rank0(msg, rank=0):
    """Print ``msg`` only when ``rank`` is 0; other ranks stay silent."""
    if rank != 0:
        return
    print(msg)


def batch_mm_loop(a, b):
    """
    Batch matrix multiply using a loop over the batch dimension.
    Avoids CUBLAS strided batched routines which have issues on L40S/CUDA 12.8/PyTorch 2.10.

    Each batch element is multiplied with a separate ``torch.mm`` call and the
    results are stacked, giving the same result as ``torch.bmm(a, b)``.

    Args:
        a: Tensor of shape (batch, m, k)
        b: Tensor of shape (batch, k, n)

    Returns:
        Tensor of shape (batch, m, n). For an empty batch, an empty
        (0, m, n) tensor with ``a``'s dtype and device.

    Raises:
        ValueError: if ``a`` and ``b`` have different batch sizes.
    """
    batch = a.shape[0]
    if b.shape[0] != batch:
        # Without this check, zip/indexing would silently truncate or raise an
        # opaque IndexError instead of reporting the shape mismatch.
        raise ValueError(
            f"batch size mismatch: a has {batch}, b has {b.shape[0]}"
        )
    if batch == 0:
        # torch.stack rejects an empty list, so build the empty result directly.
        return a.new_empty((0, a.shape[1], b.shape[2]))
    # Iterating a tensor yields its slices along dim 0, i.e. a[i] / b[i].
    return torch.stack([torch.mm(x, y) for x, y in zip(a, b)], dim=0)