# compatibility_utils.py
import datetime

import torch


def get_pytorch_version():
    """Detect PyTorch version for compatibility"""
    version = torch.__version__
    major, minor = map(int, version.split('.')[:2])
    return major, minor

def setup_timeout():
    """Create a timeout compatible with the PyTorch version"""
    major, minor = get_pytorch_version()
    # Tuple comparison so that e.g. 2.0 also counts as newer than 1.10
    if (major, minor) >= (1, 10):
        # Use the timedelta re-exported by torch.distributed, if present
        if hasattr(torch.distributed, 'timedelta'):
            return torch.distributed.timedelta(seconds=1800)
        else:
            # Fallback to datetime
            return datetime.timedelta(seconds=1800)
    else:
        # Use datetime for older versions
        return datetime.timedelta(seconds=1800)

def check_environment():
    """Check PyTorch environment and compatibility"""
    print("=== Environment Check ===")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        if torch.cuda.device_count() > 0:
            for i in range(torch.cuda.device_count()):
                print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Check distributed module
    if hasattr(torch.distributed, 'timedelta'):
        print("✓ torch.distributed.timedelta available")
    else:
        print("✗ torch.distributed.timedelta not available - using datetime")

    # Check critical attributes
    critical_attrs = ['init_process_group', 'is_initialized', 'destroy_process_group']
    for attr in critical_attrs:
        if hasattr(torch.distributed, attr):
            print(f"✓ torch.distributed.{attr} available")
        else:
            print(f"✗ torch.distributed.{attr} missing!")

    # Check timeout compatibility
    timeout = setup_timeout()
    print(f"✓ Timeout setup: {type(timeout).__name__}")


if __name__ == "__main__":
    check_environment()
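
# Usage sketch: a timeout like the one returned by setup_timeout() can be
# passed to torch.distributed.init_process_group() via its `timeout` keyword
# argument. The backend and env:// rendezvous below are illustrative
# assumptions, not part of this utility; with env:// initialization the
# MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE environment variables must be
# set (e.g. by torchrun).
#
#   import torch.distributed as dist
#   dist.init_process_group(backend="gloo", init_method="env://",
#                           timeout=setup_timeout())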