"""Minimal DDP helpers driven entirely by torchrun environment variables.

Launch with:
    torchrun --nproc_per_node=NUM_GPUS framework/train.py ...

Single-process (no torchrun) also works: world_size falls back to 1.
"""
from __future__ import annotations

import os
import random
from typing import List, Any

import numpy as np
import torch
import torch.distributed as dist


def is_dist() -> bool:
    """Return True when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    """Return this process's rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist() else 0


def get_world_size() -> int:
    """Return the number of processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist() else 1


def is_main() -> bool:
    """Return True on rank 0 (always True in single-process runs)."""
    return get_rank() == 0


def setup_distributed() -> int:
    """Init the process group if launched under torchrun.

    The launch is detected through the ``RANK`` and ``WORLD_SIZE`` environment
    variables; ``LOCAL_RANK`` (default 0) selects the device.

    Returns:
        The local rank under torchrun, otherwise 0.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        # single GPU / CPU fallback
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
        return 0

    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if not torch.cuda.is_available():
        # CPU-only host: there is no CUDA device to bind and NCCL cannot be
        # used, so fall back to the gloo backend instead of crashing.
        dist.init_process_group(backend="gloo", init_method="env://")
        dist.barrier()
        return local_rank

    torch.cuda.set_device(local_rank)
    # bind this rank's device to the PG so collectives/barrier don't guess
    # GPU 0 (avoids the NCCL "devices unknown" warning + potential hang)
    try:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            device_id=torch.device("cuda", local_rank),
        )
    except TypeError:
        # older torch without device_id kwarg
        dist.init_process_group(backend="nccl", init_method="env://")
    dist.barrier(device_ids=[local_rank])
    return local_rank


def cleanup_distributed() -> None:
    """Synchronize all ranks and destroy the process group, if one exists."""
    if is_dist():
        dist.barrier()
        dist.destroy_process_group()


def all_gather_object(obj: Any) -> List[Any]:
    """Gather arbitrary picklable objects from all ranks into a flat list.

    In single-process runs the result is simply ``[obj]``.
    """
    if not is_dist():
        return [obj]
    out: List[Any] = [None for _ in range(get_world_size())]
    dist.all_gather_object(out, obj)
    return out


def set_seed(seed: int, rank: int = 0, deterministic: bool = False) -> None:
    """Seed all RNGs.

    Each rank gets a distinct stream (seed + rank) so DDP workers don't draw
    identical augmentation noise, while staying reproducible.

    Args:
        seed: Base seed shared by all ranks.
        rank: Offset added to ``seed`` for this process.
        deterministic: If True, set cuDNN to deterministic mode and disable
            benchmark mode; otherwise enable cuDNN benchmark mode.
    """
    s = seed + rank
    random.seed(s)
    # NumPy's legacy seeding only accepts values in [0, 2**32 - 1]; wrap so
    # large or negative seeds don't raise. In-range seeds are unchanged.
    np.random.seed(s % 2**32)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True


def print_main(*args, **kwargs) -> None:
    """Print on the main rank only, flushing by default.

    Accepts the same arguments as ``print``. ``flush`` defaults to True but
    may be overridden by the caller.
    """
    if is_main():
        # setdefault avoids a duplicate-keyword TypeError when the caller
        # passes flush= explicitly.
        kwargs.setdefault("flush", True)
        print(*args, **kwargs)