| | |
| | |
| | |
| | |
| |
|
from datetime import timedelta
from typing import Any, Optional, Union

import torch
import torch.distributed
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
    Store,
    _new_process_group_helper,
    _world,
    default_pg_timeout,
    rendezvous,
)
| |
|
| |
|
def torch_dist_barrier_and_cuda_sync():
    """Block until every rank reaches this point, then drain the local CUDA stream.

    Order matters: the collective barrier runs first so all ranks line up,
    and the CUDA synchronize then waits for this rank's already-queued
    kernels to finish before returning.
    """
    # Rendezvous across all processes in the default group.
    torch.distributed.barrier()
    # Wait for all outstanding work on the current CUDA device.
    torch.cuda.synchronize()
| |
|
| |
|
| | |
| | |
def init_process_group(
    backend: Optional[Union[str, Backend]] = None,
    init_method: Optional[str] = None,
    timeout: Optional[timedelta] = None,
    world_size: int = -1,
    rank: int = -1,
    store: Optional[Store] = None,
    group_name: Optional[str] = None,
    pg_options: Optional[Any] = None,
):
    """Create a new process group without installing it as the global default.

    Mirrors ``torch.distributed.init_process_group`` but returns the group
    instead of registering it as the default, so multiple independent groups
    can coexist in one process.

    Args:
        backend: Backend name or instance (e.g. "nccl", "gloo"); when falsy,
            the "undefined" backend placeholder is used.
        init_method: Rendezvous URL; defaults to "env://" when ``store`` is None.
        timeout: Collective timeout; falls back to ``default_pg_timeout``.
        world_size: Total number of ranks (must be > 0 when ``store`` is given).
        rank: This process's rank (must be >= 0 when ``store`` is given).
        store: Pre-built key/value store; mutually exclusive with ``init_method``.
        group_name: Name used to prefix this group's keys in the store.
        pg_options: Backend-specific process-group options.

    Returns:
        The newly created process group.
    """
    # NOTE(review): these asserts vanish under ``python -O``; kept as-is so
    # callers catching AssertionError keep working.
    assert (store is None) or (init_method is None), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend) if backend else Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    if store is None:
        # No explicit store: perform rendezvous to obtain one (may also
        # resolve rank/world_size from the environment).
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

    # Scope this group's keys so groups sharing a store don't collide.
    store = PrefixStore(group_name, store)

    def _version_tuple(version: str) -> tuple:
        # Parse the leading "major.minor" as ints, ignoring suffixes like
        # "2.6.0a0+cu118". Compares numerically — a lexicographic string
        # compare breaks for versions like "2.10" ("2.10" < "2.6" as strings).
        parts = []
        for piece in version.split(".")[:2]:
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits or 0))
        return tuple(parts)

    # torch >= 2.6 renamed the keyword from ``pg_options`` to ``backend_options``.
    pg_options_param_name = (
        "backend_options" if _version_tuple(str(torch.__version__)) >= (2, 6) else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    # Register an identity rank mapping so torch's bookkeeping treats this
    # group as covering global ranks 0..world_size-1.
    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
| |
|