GenSeg-Baselines / code /framework /engine /distributed.py
MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
2.62 kB
"""Minimal DDP helpers driven entirely by torchrun environment variables.
Launch with: torchrun --nproc_per_node=<N> framework/train.py ...
Single-process (no torchrun) also works: world_size falls back to 1.
"""
from __future__ import annotations
import os
import random
from typing import List, Any
import numpy as np
import torch
import torch.distributed as dist
def is_dist() -> bool:
    """Report whether torch.distributed is usable AND a default group exists."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
def get_rank() -> int:
    """Rank of this process in the default group; 0 when no group is active."""
    if not is_dist():
        return 0
    return dist.get_rank()
def get_world_size() -> int:
    """Number of processes in the default group; 1 when no group is active."""
    if not is_dist():
        return 1
    return dist.get_world_size()
def is_main() -> bool:
    """True on the rank-0 process (and therefore always True when not distributed)."""
    return not get_rank()
def setup_distributed() -> int:
    """Init the process group if launched under torchrun. Returns local_rank.

    When both ``RANK`` and ``WORLD_SIZE`` are set (as torchrun does), the
    default process group is initialised via the ``env://`` rendezvous:
    NCCL with this rank's CUDA device bound when CUDA is available, otherwise
    gloo so CPU-only torchrun launches also work. If a default group is
    already initialised, it is reused instead of initialising a second one
    (which torch.distributed rejects).

    Without those variables the run is single-process: CUDA device 0 is
    selected when CUDA is available and 0 is returned.

    Returns:
        ``LOCAL_RANK`` (default 0) under torchrun, otherwise 0.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        # single GPU / CPU fallback
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
        return 0

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(local_rank)

    if dist.is_initialized():
        # A second init_process_group call would raise; keep the live group.
        return local_rank

    if use_cuda:
        # bind this rank's device to the PG so collectives/barrier don't guess
        # GPU 0 (avoids the NCCL "devices unknown" warning + potential hang)
        try:
            dist.init_process_group(backend="nccl", init_method="env://",
                                    device_id=torch.device("cuda", local_rank))
        except TypeError:  # older torch without device_id kwarg
            dist.init_process_group(backend="nccl", init_method="env://")
        dist.barrier(device_ids=[local_rank])
    else:
        # NCCL requires GPUs; gloo lets CPU-only torchrun runs proceed.
        dist.init_process_group(backend="gloo", init_method="env://")
        dist.barrier()
    return local_rank
def cleanup_distributed() -> None:
    """Synchronise all ranks, then destroy the default process group.

    Does nothing when no process group is active.
    """
    if not is_dist():
        return
    dist.barrier()
    dist.destroy_process_group()
def all_gather_object(obj: Any) -> List[Any]:
    """Collect a picklable ``obj`` from every rank.

    Returns a list with one entry per rank. Without an active process group
    the result is simply ``[obj]``.
    """
    if is_dist():
        gathered: List[Any] = [None] * get_world_size()
        dist.all_gather_object(gathered, obj)
        return gathered
    return [obj]
def set_seed(seed: int, rank: int = 0, deterministic: bool = False) -> None:
    """Seed all RNGs. Each rank gets a distinct stream (seed + rank) so DDP
    workers don't draw identical augmentation noise, while staying reproducible.

    Args:
        seed: Base seed shared by all ranks.
        rank: Offset added to ``seed`` so every process gets its own stream.
        deterministic: If True, force deterministic cuDNN kernels and disable
            cuDNN autotuning; if False, enable autotuning (``benchmark``) and
            allow non-deterministic kernels.
    """
    s = seed + rank
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    # Write both flags on every call so a prior deterministic=True setting
    # cannot linger alongside benchmark=True after a deterministic=False call.
    flag = bool(deterministic)
    torch.backends.cudnn.deterministic = flag
    torch.backends.cudnn.benchmark = not flag
def print_main(*args, **kwargs) -> None:
    """Print only on the main (rank-0) process.

    Takes the same arguments as the built-in ``print``. Output is flushed by
    default so rank-0 logs appear promptly; a caller-supplied ``flush``
    keyword takes precedence instead of raising a duplicate-keyword TypeError.
    """
    if is_main():
        kwargs.setdefault("flush", True)
        print(*args, **kwargs)