File size: 344 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""MPS compatibility helpers and local test scripts."""


def is_mps_available():
    """Report whether PyTorch's MPS backend can be used in this process.

    ``torch`` is imported lazily so that importing this module does not
    require it.

    Returns:
        ``True`` only when ``torch`` is importable, exposes
        ``torch.backends.mps``, and ``torch.backends.mps.is_available()``
        reports ``True``; ``False`` otherwise (including when ``torch`` is
        not installed).
    """
    try:
        import torch
    except ImportError:
        # A predicate should answer "no" rather than raise when the optional
        # dependency is missing: without torch there is no MPS backend.
        return False
    # Guard against torch builds that do not provide ``torch.backends.mps``.
    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def mps_device() -> "torch.device | None":
    """Return ``torch.device("mps")``, or ``None`` when MPS is unavailable.

    Availability is decided by ``is_mps_available()``.
    """
    import torch

    if not is_mps_available():
        return None
    return torch.device("mps")


def mps_device_or(default):
    """Return the MPS device if one is available, otherwise *default*.

    Args:
        default: Value returned when ``mps_device()`` yields ``None``.
    """
    device = mps_device()
    if device is None:
        return default
    return device