File size: 1,218 Bytes
8bf6610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import math
from loguru import logger
import datetime
import torch
import torch.distributed as dist

def get_parallel_degree(world_size, num_heads):
    """Split ``world_size`` into an Ulysses degree and a ring degree.

    Ulysses parallelism is the faster of the two, but its degree has to
    divide ``num_heads`` evenly, so it is given the greatest common divisor
    of ``world_size`` and ``num_heads``. Ring parallelism takes the
    remaining factor of ``world_size``.

    Args:
        world_size: Total number of participating processes.
        num_heads: Number of attention heads.

    Returns:
        A ``(ulysses_degree, ring_degree)`` tuple whose product equals
        ``world_size``.
    """
    shared_divisor = math.gcd(world_size, num_heads)
    return shared_divisor, world_size // shared_divisor

def get_device(ulysses_degree, ring_degree):
    """Return the device for this process, initializing parallelism if needed.

    When either degree is greater than 1, the NCCL process group and the
    xfuser distributed/model-parallel state are initialized, the current
    CUDA device is set for this process, and the chosen device is logged.
    Otherwise nothing is initialized.

    Args:
        ulysses_degree: Ulysses degree forwarded to xfuser's
            ``initialize_model_parallel``.
        ring_degree: Ring degree forwarded to xfuser's
            ``initialize_model_parallel``.

    Returns:
        A ``torch.device`` of the form ``cuda:<index>`` on the parallel
        path, or the plain string ``"cuda"`` when both degrees are <= 1.
    """
    if ulysses_degree > 1 or ring_degree > 1:
        # Imported here, so xfuser is only needed on the parallel path.
        from xfuser.core.distributed import (
            init_distributed_environment,
            initialize_model_parallel,
            get_world_group,
        )

        # The process-group timeout is set to 7 days.
        dist.init_process_group("nccl", timeout=datetime.timedelta(hours=24*7))
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=ring_degree,
            ulysses_degree=ulysses_degree,
        )

        rank = get_world_group().rank
        # The global rank is only a valid CUDA index on a single node. Wrap
        # it onto the GPUs visible to this process so multi-node jobs (or
        # processes with a restricted CUDA_VISIBLE_DEVICES) pick an existing
        # device. On a single node this is identical to using the rank.
        device_index = rank % torch.cuda.device_count()
        device = torch.device(f"cuda:{device_index}")
        torch.cuda.set_device(device_index)

        logger.info(f'rank={rank} device={str(device)}')
    else:
        device = "cuda"
    return device