import datetime

import torch
import torch.distributed


def get_pytorch_version():
    """Detect the PyTorch version for compatibility checks."""
    version = torch.__version__
    # Keep only the major and minor components; a local suffix such as
    # "+cu118" sits in the patch segment and is dropped by the slice.
    major, minor = map(int, version.split('.')[:2])
    return major, minor
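
# Example: with torch.__version__ == "2.1.0+cu118" (a hypothetical build
# string), get_pytorch_version() returns (2, 1).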


def setup_timeout():
    """Create a timeout compatible with the detected PyTorch version."""
    major, minor = get_pytorch_version()

    # Compare as a tuple: the naive `major >= 1 and minor >= 10` test
    # wrongly rejects releases such as 2.0, where minor resets to 0.
    if (major, minor) >= (1, 10):
        # Prefer a timedelta alias if the build exposes one on
        # torch.distributed; otherwise fall back to the standard library.
        if hasattr(torch.distributed, 'timedelta'):
            return torch.distributed.timedelta(seconds=1800)
        return datetime.timedelta(seconds=1800)
    else:
        return datetime.timedelta(seconds=1800)
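

# A minimal sketch of where the timeout is consumed, assuming an env://
# rendezvous. The backend choice, the address/port defaults, and the
# MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment variables are
# assumptions about the launch setup, not something this script configures.
def init_distributed_sketch(backend="gloo"):
    """Initialise a process group (single-node by default) with the timeout."""
    import os

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # assumed single-node default
    os.environ.setdefault("MASTER_PORT", "29500")      # assumed free port
    torch.distributed.init_process_group(
        backend=backend,
        init_method="env://",
        timeout=setup_timeout(),  # both branches yield a timedelta-compatible value
        rank=int(os.environ.get("RANK", "0")),
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
    )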


def check_environment():
    """Check the PyTorch environment and report compatibility."""
    print("=== Environment Check ===")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

    # Report whether the optional timedelta alias exists on torch.distributed.
    if hasattr(torch.distributed, 'timedelta'):
        print("✓ torch.distributed.timedelta available")
    else:
        print("✗ torch.distributed.timedelta not available - using datetime")

    # Verify the process-group API surface this script relies on.
    critical_attrs = ['init_process_group', 'is_initialized', 'destroy_process_group']
    for attr in critical_attrs:
        if hasattr(torch.distributed, attr):
            print(f"✓ torch.distributed.{attr} available")
        else:
            print(f"✗ torch.distributed.{attr} missing!")

    timeout = setup_timeout()
    print(f"✓ Timeout setup: {type(timeout).__name__}")


if __name__ == "__main__":
    check_environment()