# distributed_utils.py
import os
import torch
import torch.distributed as dist
import logging
from typing import Tuple, Optional
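# compatibility_utils is a sibling module; setup_timeout() is assumed to
# return a datetime.timedelta accepted by dist.init_process_group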
from compatibility_utils import setup_timeout

def setup_distributed() -> Tuple[bool, int, int, int, torch.device]:
    """
    First-principles DDP setup with single source of truth for device mapping.
    Returns: (is_ddp, rank, local_rank, world_size, device)
    """
    ddp = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    
    if ddp:
        # Initialize process group with robust timeout
        if not dist.is_initialized():
            # Use compatibility-aware timeout
            timeout = setup_timeout()
            dist.init_process_group(
                backend="nccl",
                timeout=timeout
            )
        
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        
        # Critical: Set device BEFORE any CUDA operations
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        
        # Verify device mapping
        assert torch.cuda.current_device() == local_rank, \
            f"Device mapping error: current={torch.cuda.current_device()}, local_rank={local_rank}"
        
    else:
        local_rank, rank, world_size = 0, 0, 1
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    return ddp, rank, local_rank, world_size, device

def setup_environment():
    """Set environment variables once at process start"""
    os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")  # Modern replacement
    os.environ.setdefault("NCCL_IB_DISABLE", "1")  # Disable InfiniBand on single node
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")  # Prevent tokenizer conflicts
    
    # Remove deprecated variables
    if "NCCL_ASYNC_ERROR_HANDLING" in os.environ:
        del os.environ["NCCL_ASYNC_ERROR_HANDLING"]
    
    # Do NOT set NCCL_P2P_DISABLE - allow peer-to-peer on single node

def cleanup_distributed():
    """Clean shutdown of process group"""
    if dist.is_available() and dist.is_initialized():
        try:
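            # Synchronize all ranks so no process tears down the group mid-collective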
            dist.barrier()
            dist.destroy_process_group()
        except Exception as e:
            logging.warning(f"Cleanup warning: {e}")

class RankZeroOnly:
    """Context manager for rank-0 only execution"""
    def __init__(self, is_main: bool):
        self.is_main = is_main
        self.original_level: Optional[int] = None
    
    def __enter__(self):
        if not self.is_main:
            # Suppress logging for non-main ranks
            self.original_level = logging.getLogger().getEffectiveLevel()
            logging.getLogger().setLevel(logging.WARNING)
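            # WARNING level still lets real problems from worker ranks surface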
        return self
    
    def __exit__(self, *args):
        if not self.is_main and self.original_level is not None:
            logging.getLogger().setLevel(self.original_level)
    
    def print(self, *args, **kwargs):
        if self.is_main:
            print(*args, **kwargs)
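
# Minimal usage sketch (assumes launch via `torchrun --nproc_per_node=N
# distributed_utils.py`; without torchrun it falls through to the
# single-process branch. The training step is illustrative, not part of
# this module):
if __name__ == "__main__":
    setup_environment()
    is_ddp, rank, local_rank, world_size, device = setup_distributed()
    try:
        with RankZeroOnly(is_main=(rank == 0)) as rz:
            rz.print(f"ddp={is_ddp} world_size={world_size} device={device}")
        # ... build model, wrap in torch.nn.parallel.DistributedDataParallel,
        # run the training loop ...
    finally:
        cleanup_distributed()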