Spaces:
Runtime error
Runtime error
File size: 645 Bytes
bf07f10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
def get_device():
    """Return the best available ``torch.device``.

    Backends are tried in order: CUDA first, then Apple's MPS backend,
    and finally the CPU, which is always returned as the fallback.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # Older torch builds have no ``torch.backends.mps`` attribute at all,
    # so probe for it before asking whether MPS is usable.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')
def require_mps():
    """Return the MPS device, refusing to run without it.

    Intended for Mac-only scripts that must execute on Apple Silicon.

    Returns:
        ``torch.device('mps')``.

    Raises:
        RuntimeError: if this torch build has no ``torch.backends.mps``
            module, or MPS reports that it is not available.
    """
    mps_backend = getattr(torch.backends, 'mps', None)
    # ``and`` short-circuits, so ``is_available`` is only queried when the
    # backend module actually exists.
    mps_usable = mps_backend is not None and mps_backend.is_available()
    if not mps_usable:
        raise RuntimeError('MPS (Apple Silicon) is required but not available.')
    return torch.device('mps')
# Default device for this module, resolved once when the module is first
# imported (CUDA > MPS > CPU, per get_device).
DEVICE: torch.device = get_device()
|