import os

import torch
import torch.distributed as dist

# Process group for sequence parallelism; populated by initialize_sequence_parallelism().
_SEQUENCE_PARALLEL_GROUP = None


def get_global_rank() -> int:
    """
    Get the global rank, i.e. the global index of the GPU across all nodes.
    """
    return int(os.environ.get("RANK", "0"))


def get_local_rank() -> int:
    """
    Get the local rank, i.e. the index of the GPU on the current node.
    """
    return int(os.environ.get("LOCAL_RANK", "0"))


def get_world_size() -> int:
    """
    Get the world size, i.e. the total number of GPUs across all nodes.
    """
    return int(os.environ.get("WORLD_SIZE", "1"))


def get_device() -> torch.device:
    """
    Get the CUDA device for the current rank.
    """
    return torch.device("cuda", get_local_rank())


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to.

    Returns None if initialize_sequence_parallelism() has not been called.
    """
    return _SEQUENCE_PARALLEL_GROUP


def initialize_sequence_parallelism(sequence_parallel_size: int) -> None:
    """
    Split the world into consecutive groups of `sequence_parallel_size` ranks and
    store the group containing the caller rank in _SEQUENCE_PARALLEL_GROUP.
    """
    assert get_world_size() % sequence_parallel_size == 0, \
        "world size must be divisible by sequence_parallel_size"
    sequence_parallel_num_groups = get_world_size() // sequence_parallel_size
    global _SEQUENCE_PARALLEL_GROUP
    for i in range(sequence_parallel_num_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        # new_group() is collective: every rank must call it for every group,
        # even for groups it does not belong to.
        group = dist.new_group(ranks)
        if get_global_rank() in ranks:
            print(f"Rank {get_global_rank()} joined group with ranks {list(ranks)}")
            _SEQUENCE_PARALLEL_GROUP = group
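

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumes launch via `torchrun --nproc_per_node=<num_gpus> this_file.py`,
    # which sets RANK, LOCAL_RANK and WORLD_SIZE in the environment, and that
    # an NCCL backend is available. sequence_parallel_size=2 is an example
    # value and must divide the world size.
    torch.cuda.set_device(get_local_rank())
    dist.init_process_group(backend="nccl")
    initialize_sequence_parallelism(sequence_parallel_size=2)
    sp_group = get_sequence_parallel_group()
    print(f"Rank {get_global_rank()}: sequence-parallel group size = "
          f"{dist.get_world_size(group=sp_group)}")
    dist.destroy_process_group()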