File size: 344 Bytes
7344bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | """MPS compatibility helpers and local test scripts."""
def is_mps_available():
    """Return whether the installed torch reports a usable MPS backend.

    torch is imported inside the function, so importing this module does
    not itself import torch.

    Returns:
        False when ``torch.backends`` has no ``mps`` attribute; otherwise
        whatever ``torch.backends.mps.is_available()`` returns.
    """
    import torch

    # Guard first: without this check, touching ``torch.backends.mps`` on a
    # torch build that lacks the attribute would raise AttributeError.
    if not hasattr(torch.backends, "mps"):
        return False
    return torch.backends.mps.is_available()
def mps_device():
    """Return ``torch.device("mps")`` if MPS is available, else ``None``.

    Availability is decided by ``is_mps_available()``; torch is imported
    inside the function, before that check runs.
    """
    import torch

    if not is_mps_available():
        return None
    return torch.device("mps")
def mps_device_or(default):
    """Return the result of ``mps_device()``, falling back to ``default``.

    Args:
        default: Value handed back when ``mps_device()`` yields a falsy
            result (it returns ``None`` when MPS is unavailable).

    Returns:
        The value from ``mps_device()`` when it is truthy, otherwise
        ``default`` unchanged.
    """
    device = mps_device()
    return device if device else default
|