"""Minimal DDP helpers driven entirely by torchrun environment variables.

Launch with: torchrun --nproc_per_node=<N> framework/train.py ...
Single-process (no torchrun) also works: world_size falls back to 1.
"""
from __future__ import annotations

import os
import random
from typing import Any, List

import numpy as np
import torch
import torch.distributed as dist
|
|
|
|
def is_dist() -> bool:
    """Report whether a torch.distributed process group is active.

    Returns:
        True only when ``dist.is_available()`` and ``dist.is_initialized()``
        both hold; False otherwise.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
|
|
|
def get_rank() -> int:
    """Return this process's rank, or 0 when no process group is active."""
    if is_dist():
        return dist.get_rank()
    return 0
|
|
|
|
def get_world_size() -> int:
    """Return the process-group size, or 1 when no process group is active."""
    if is_dist():
        return dist.get_world_size()
    return 1
|
|
|
|
def is_main() -> bool:
    """Return True when the current process is rank 0."""
    rank = get_rank()
    return rank == 0
|
|
|
|
def setup_distributed() -> int:
    """Init the process group if launched under torchrun. Returns local_rank.

    A torchrun launch is detected by the presence of both ``RANK`` and
    ``WORLD_SIZE`` in the environment; ``LOCAL_RANK`` defaults to 0 when
    unset.

    * torchrun + CUDA available: selects ``cuda:<local_rank>``, initializes
      an NCCL process group via ``env://`` and synchronizes with a barrier.
    * torchrun + no CUDA: initializes a Gloo process group instead, so the
      same launch command works on CPU-only hosts.
    * no torchrun: no process group is created; device 0 is selected when
      CUDA is available.

    Returns:
        The local rank under torchrun, otherwise 0.
    """
    launched_by_torchrun = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    if not launched_by_torchrun:
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
        return 0

    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if not torch.cuda.is_available():
        # NCCL needs CUDA devices. Fall back to Gloo, which takes neither a
        # CUDA ``device_id`` nor ``device_ids`` on the barrier.
        dist.init_process_group(backend="gloo", init_method="env://")
        dist.barrier()
        return local_rank

    torch.cuda.set_device(local_rank)
    try:
        dist.init_process_group(backend="nccl", init_method="env://",
                                device_id=torch.device("cuda", local_rank))
    except TypeError:
        # torch versions whose init_process_group lacks the ``device_id``
        # keyword reject it with TypeError; retry without it.
        dist.init_process_group(backend="nccl", init_method="env://")
    dist.barrier(device_ids=[local_rank])
    return local_rank
|
|
|
|
def cleanup_distributed() -> None:
    """Synchronize all ranks, then destroy the process group.

    Does nothing when no process group is active.
    """
    if not is_dist():
        return
    dist.barrier()
    dist.destroy_process_group()
|
|
|
|
def all_gather_object(obj: Any) -> List[Any]:
    """Collect one picklable object from every rank.

    Args:
        obj: The object contributed by the calling process.

    Returns:
        A list with ``get_world_size()`` entries filled in by
        ``dist.all_gather_object``; when no process group is active, the
        single-element list ``[obj]``.
    """
    if is_dist():
        gathered: List[Any] = [None] * get_world_size()
        dist.all_gather_object(gathered, obj)
        return gathered
    return [obj]
|
|
|
|
def set_seed(seed: int, rank: int = 0, deterministic: bool = False) -> None:
    """Seed all RNGs. Each rank gets a distinct stream (seed + rank) so DDP
    workers don't draw identical augmentation noise, while staying reproducible.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU plus all
    CUDA devices), then sets the cuDNN flags.

    Args:
        seed: Base seed shared by all ranks.
        rank: Offset added to ``seed`` for this process.
        deterministic: When True, sets ``cudnn.deterministic = True`` and
            ``cudnn.benchmark = False``. When False, sets
            ``cudnn.benchmark = True`` and leaves ``cudnn.deterministic``
            untouched.
    """
    s = seed + rank
    random.seed(s)
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap so large
    # seeds (e.g. hash- or time-derived) do not raise ValueError.
    np.random.seed(s % (2 ** 32))
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True
|
|
|
|
def print_main(*args, **kwargs) -> None:
    """Print on the main process (rank 0) only; other ranks stay silent.

    Accepts the same positional and keyword arguments as ``print``. Output
    is flushed by default, and an explicit ``flush=...`` from the caller is
    honored.
    """
    if is_main():
        # setdefault rather than a hard-coded flush=True, which would raise
        # TypeError when the caller also passes ``flush``.
        kwargs.setdefault("flush", True)
        print(*args, **kwargs)
|
|