# compatibility_utils.py
import torch
import datetime

def get_pytorch_version():
    """Return the installed PyTorch version as a (major, minor) integer tuple.

    Only the first two dot-separated components are parsed, so suffixed
    builds such as '2.1.0+cu118' are handled correctly.
    """
    major, minor = map(int, torch.__version__.split('.')[:2])
    return major, minor
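
# Usage sketch (this helper is not in the original): compare the parsed tuple
# rather than the raw version string, since lexicographic string comparison
# gets "1.10" vs "1.9" wrong.
def is_at_least(major, minor):
    """True if the installed PyTorch is at least release (major, minor)."""
    return get_pytorch_version() >= (major, minor)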

def setup_timeout():
    """Create a process-group timeout compatible with the PyTorch version."""
    major, minor = get_pytorch_version()

    # Compare as a tuple: the naive "major >= 1 and minor >= 10" would
    # misclassify e.g. PyTorch 2.0 (minor == 0) as an old release.
    if (major, minor) >= (1, 10):
        # Prefer the timedelta re-exported by torch.distributed, if present
        if hasattr(torch.distributed, 'timedelta'):
            return torch.distributed.timedelta(seconds=1800)
        # Fallback to the standard library
        return datetime.timedelta(seconds=1800)
    else:
        # Older versions take a plain datetime.timedelta
        return datetime.timedelta(seconds=1800)
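
# A minimal usage sketch (not in the original): the timeout above is intended
# for torch.distributed.init_process_group, whose `timeout` parameter accepts
# a datetime.timedelta. Assumes RANK/WORLD_SIZE (and MASTER_ADDR/MASTER_PORT)
# are provided by a launcher such as torchrun; the backend choice here is
# illustrative only.
def init_distributed_sketch():
    """Illustrative only: initialize a process group with the compatible timeout."""
    import os
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    torch.distributed.init_process_group(
        backend=backend,
        timeout=setup_timeout(),
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
    )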

def check_environment():
    """Check PyTorch environment and compatibility"""
    print("=== Environment Check ===")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        # The range is empty when no GPUs are present, so no extra guard is needed
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    
    # Check the distributed module (some builds ship without distributed support)
    print(f"Distributed available: {torch.distributed.is_available()}")
    if hasattr(torch.distributed, 'timedelta'):
        print("✓ torch.distributed.timedelta available")
    else:
        print("✗ torch.distributed.timedelta not available - using datetime")
    
    # Check critical attributes
    critical_attrs = ['init_process_group', 'is_initialized', 'destroy_process_group']
    for attr in critical_attrs:
        if hasattr(torch.distributed, attr):
            print(f"✓ torch.distributed.{attr} available")
        else:
            print(f"✗ torch.distributed.{attr} missing!")
    
    # Check timeout compatibility
    timeout = setup_timeout()
    print(f"✓ Timeout setup: {type(timeout).__name__}")

if __name__ == "__main__":
    check_environment()