MedAI-ACM / src /utils /device_utils.py
Tirath5504's picture
deploy
bf07f10
raw
history blame contribute delete
645 Bytes
import torch
def get_device():
    """Return the preferred torch device available on this machine.

    The preference order is CUDA, then MPS, then CPU.

    Returns:
        torch.device: ``cuda`` if ``torch.cuda.is_available()`` is true;
        otherwise ``mps`` if ``torch.backends`` has an ``mps`` attribute
        whose ``is_available()`` is true; otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')

    # getattr with a None default avoids an AttributeError when
    # ``torch.backends`` has no ``mps`` attribute.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    return torch.device('cpu')
def require_mps():
    """Return the MPS device, failing if it cannot be used.

    Intended for scripts that must run on the MPS backend.

    Returns:
        torch.device: the ``mps`` device.

    Raises:
        RuntimeError: if ``torch.backends`` has no ``mps`` attribute or
            its ``is_available()`` returns false.
    """
    # getattr with a None default avoids an AttributeError when
    # ``torch.backends`` has no ``mps`` attribute.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')

    raise RuntimeError('MPS (Apple Silicon) is required but not available.')
# Module-level device, resolved once by get_device() when this module is
# executed (i.e. at import time).
DEVICE = get_device()