|
|
|
|
|
import os |
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import logging |
|
|
from typing import Tuple, Optional |
|
|
from compatibility_utils import setup_timeout |
|
|
|
|
|
def setup_distributed() -> Tuple[bool, int, int, int, torch.device]:
    """
    First-principles DDP setup with single source of truth for device mapping.

    Detects a torchrun-style launch via the RANK/WORLD_SIZE environment
    variables, initializes the NCCL process group once (idempotent), and
    pins this process to the GPU named by LOCAL_RANK. Without those
    variables it falls back to single-process mode on cuda:0 or CPU.

    Returns:
        (is_ddp, rank, local_rank, world_size, device)

    Raises:
        RuntimeError: if the active CUDA device does not match LOCAL_RANK
            after pinning (a broken launcher/device mapping).
    """
    ddp = "RANK" in os.environ and "WORLD_SIZE" in os.environ

    if ddp:
        # Idempotent: safe to call even if another component already
        # initialized the process group.
        if not dist.is_initialized():
            timeout = setup_timeout()
            dist.init_process_group(
                backend="nccl",
                timeout=timeout
            )

        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        # LOCAL_RANK is the single source of truth for device selection.
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")

        # NOTE: an `assert` here would be stripped under `python -O`, but a
        # device-mapping mismatch must always be fatal — raise explicitly.
        current = torch.cuda.current_device()
        if current != local_rank:
            raise RuntimeError(
                f"Device mapping error: current={current}, local_rank={local_rank}"
            )
    else:
        # Single-process fallback: rank 0 of a world of 1.
        local_rank, rank, world_size = 0, 0, 1
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    return ddp, rank, local_rank, world_size, device
|
|
|
|
|
def setup_environment():
    """Set process-wide environment defaults once at process start.

    Uses setdefault so any value the user exported explicitly always wins,
    and removes the deprecated un-prefixed NCCL_ASYNC_ERROR_HANDLING
    variable in favour of the TORCH_-prefixed replacement.
    """
    defaults = {
        "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
        "NCCL_IB_DISABLE": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
    for key, value in defaults.items():
        os.environ.setdefault(key, value)

    # Drop the deprecated variable if present so it cannot conflict with
    # TORCH_NCCL_ASYNC_ERROR_HANDLING; pop() is a no-op when it is unset.
    os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
|
|
|
|
|
|
|
|
|
|
|
def cleanup_distributed():
    """Tear down the process group, synchronizing ranks first.

    Best-effort shutdown: a failure in the barrier or in destroying the
    group is logged as a warning rather than raised, so cleanup errors
    never mask the process's real exit path.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    try:
        # Let every rank reach shutdown before tearing the group down.
        dist.barrier()
        dist.destroy_process_group()
    except Exception as exc:
        logging.warning(f"Cleanup warning: {exc}")
|
|
|
|
|
class RankZeroOnly:
    """Context manager that quiets non-main ranks.

    While active on a non-main rank, the root logger is raised to WARNING
    so duplicated per-rank log lines are suppressed; the previous level is
    restored on exit. The ``print`` method forwards to the builtin only on
    the main rank.
    """

    def __init__(self, is_main: bool):
        # True on the rank allowed to log/print (rank 0).
        self.is_main = is_main
        # Root-logger level captured on entry; stays None on the main rank.
        self.original_level = None

    def __enter__(self):
        if self.is_main:
            return self
        root = logging.getLogger()
        self.original_level = root.getEffectiveLevel()
        root.setLevel(logging.WARNING)
        return self

    def __exit__(self, *args):
        # Restore only what we actually changed.
        if self.is_main or self.original_level is None:
            return
        logging.getLogger().setLevel(self.original_level)

    def print(self, *args, **kwargs):
        if self.is_main:
            print(*args, **kwargs)
|
|
|