"""
CUDA Utilities
"""
import os
import torch
import warnings
def setup_cuda():
    """Initialize CUDA, falling back to CPU when it is unusable.

    Returns:
        bool: True when at least one usable CUDA device was found,
        False when CUDA is absent, busy, or raises an error.
    """
    # Keep CUDA kernel launches asynchronous (no forced synchronization).
    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'

    # Bail out early when this torch build / driver reports no CUDA support.
    if not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        return False

    try:
        # Touching the allocator forces lazy CUDA initialization, so a
        # broken runtime surfaces here as a RuntimeError.
        torch.cuda.empty_cache()
        n_devices = torch.cuda.device_count()
    except RuntimeError as err:
        text = str(err)
        if "CUDA" in text and ("busy" in text or "unavailable" in text):
            print("CUDA is busy/unavailable, falling back to CPU")
        else:
            print(f"CUDA error: {err}")
        return False

    if n_devices > 0:
        print(f"CUDA available with {n_devices} device(s)")
        return True
    print("No CUDA devices found")
    return False
def get_best_device():
    """Return the best available torch device.

    Returns:
        torch.device: ``cuda`` if :func:`setup_cuda` reports a usable
        GPU, otherwise ``cpu``.
    """
    return torch.device("cuda" if setup_cuda() else "cpu")
def suppress_cuda_warnings():
    """Silence CUDA-related warnings emitted through the warnings module."""
    # Filter by originating module first, then by message text as a catch-all.
    for spec in (
        {"category": UserWarning, "module": "torch.cuda"},
        {"message": ".*CUDA.*"},
    ):
        warnings.filterwarnings("ignore", **spec)