File size: 645 Bytes
bf07f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch

def get_device():
    """Pick the best available torch device.

    Preference order is CUDA first, then Apple's MPS backend, then CPU.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true;
        otherwise ``mps`` when ``torch.backends.mps`` exists and reports
        availability; otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        name = 'cuda'
    else:
        mps_backend = getattr(torch.backends, 'mps', None)
        # The getattr guard tolerates torch builds that expose no
        # ``torch.backends.mps`` attribute at all.
        has_mps = mps_backend is not None and mps_backend.is_available()
        name = 'mps' if has_mps else 'cpu'
    return torch.device(name)

def require_mps():
    """Return the MPS device, failing loudly when it cannot be used.

    Intended for scripts that only run on Apple Silicon.

    Returns:
        torch.device: ``torch.device('mps')``.

    Raises:
        RuntimeError: if ``torch.backends.mps`` is missing or reports that
            the backend is not available.
    """
    backend = getattr(torch.backends, 'mps', None)
    usable = backend is not None and backend.is_available()
    if not usable:
        raise RuntimeError('MPS (Apple Silicon) is required but not available.')
    return torch.device('mps')

# Module-wide default device, resolved once when this module is imported.
DEVICE: torch.device = get_device()