| import importlib.util |
|
|
| import torch |
| import torch.distributed as dist |
|
|
# Optional sequence-parallel backend.
#
# Prefer PAI's packaged fork ("paifuser") when it is importable, otherwise fall
# back to upstream xfuser. If the import fails for any reason, every exported
# name is set to None so callers can check availability (e.g. `get_sp_group is
# None`) instead of this module failing at import time.
_xfuser_import_error = None  # exception that disabled the backend, if any

try:
    if importlib.util.find_spec("paifuser") is not None:
        import paifuser
        from paifuser.xfuser.core.distributed import (
            get_sequence_parallel_rank,
            get_sequence_parallel_world_size,
            get_sp_group,
            get_world_group,
            init_distributed_environment,
            initialize_model_parallel,
        )
        from paifuser.xfuser.core.long_ctx_attention import (
            xFuserLongContextAttention,
        )
        print("Import PAI DiT Turbo")
    else:
        import xfuser
        from xfuser.core.distributed import (
            get_sequence_parallel_rank,
            get_sequence_parallel_world_size,
            get_sp_group,
            get_world_group,
            init_distributed_environment,
            initialize_model_parallel,
        )
        from xfuser.core.long_ctx_attention import xFuserLongContextAttention
        print("Xfuser import sucessful")
except Exception as ex:
    # Deliberately broad: a backend that is installed but broken (e.g. CUDA or
    # version mismatch) should also degrade to "unavailable" rather than crash
    # the import. Keep the exception so the real cause can still be inspected;
    # `ex` itself is unbound when the except clause ends.
    _xfuser_import_error = ex
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    get_world_group = None
    init_distributed_environment = None
    initialize_model_parallel = None
|
|
def _parallel_config_message(ulysses_degree, ring_degree,
                             classifier_free_guidance_degree, rank, world_size):
    """Return the banner line describing this rank's parallel configuration."""
    return (
        "parallel inference enabled: ulysses_degree=%d ring_degree=%d "
        "classifier_free_guidance_degree=%d rank=%d world_size=%d"
        % (ulysses_degree, ring_degree, classifier_free_guidance_degree,
           rank, world_size)
    )


def _check_world_size(world_size, ulysses_degree, ring_degree,
                      classifier_free_guidance_degree):
    """Raise ValueError unless world_size equals the product of the degrees."""
    expected = ring_degree * ulysses_degree * classifier_free_guidance_degree
    if world_size != expected:
        raise ValueError(
            "number of GPUs(%d) should be equal to ring_degree * ulysses_degree"
            " * classifier_free_guidance_degree (%d * %d * %d = %d)."
            % (world_size, ring_degree, ulysses_degree,
               classifier_free_guidance_degree, expected)
        )


def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
    """Set up multi-GPU (xfuser) parallelism if requested and pick the device.

    Args:
        ulysses_degree: Ulysses sequence-parallel degree.
        ring_degree: Ring-attention sequence-parallel degree.
        classifier_free_guidance_degree: Degree of parallelism across the
            classifier-free-guidance branches.

    Returns:
        ``torch.device("cuda:<local_rank>")`` when any degree is > 1, where the
        local rank comes from xfuser's world group; otherwise the string
        ``"cuda"``.

    Raises:
        RuntimeError: parallelism was requested but the xfuser backend could
            not be imported.
        ValueError: the process-group world size differs from
            ``ring_degree * ulysses_degree * classifier_free_guidance_degree``.
    """
    if ulysses_degree <= 1 and ring_degree <= 1 and classifier_free_guidance_degree <= 1:
        # Single-GPU path: no distributed setup at all.
        return "cuda"

    if get_sp_group is None:
        raise RuntimeError("xfuser is not installed.")

    # The NCCL default group may already exist (e.g. created by the launcher
    # or an earlier call); initializing it twice raises.
    if not dist.is_initialized():
        dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    print(_parallel_config_message(ulysses_degree, ring_degree,
                                   classifier_free_guidance_degree,
                                   rank, world_size))
    _check_world_size(world_size, ulysses_degree, ring_degree,
                      classifier_free_guidance_degree)

    init_distributed_environment(rank=rank, world_size=world_size)
    # Sequence parallelism spans both the Ulysses and the ring dimensions.
    initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
                              classifier_free_guidance_degree=classifier_free_guidance_degree,
                              ring_degree=ring_degree,
                              ulysses_degree=ulysses_degree)

    device = torch.device(f"cuda:{get_world_group().local_rank}")
    print('rank=%d device=%s' % (get_world_group().rank, str(device)))
    return device