| import torch |
| import logging |
| from packaging import version |
| import torch.backends |
| import torch.backends.mps |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def check_for_mps() -> bool:
    """Report whether the PyTorch MPS backend appears usable.

    For torch versions <= 2.0.1 the legacy ``torch.has_mps`` flag is
    consulted first. If it is set, a one-element tensor is moved to the
    ``"mps"`` device as a smoke test. Newer versions rely on
    ``torch.backends.mps.is_available()`` and
    ``torch.backends.mps.is_built()``.

    Returns:
        ``True`` if MPS looks usable, ``False`` otherwise. Ordinary
        exceptions raised while probing are treated as "not available";
        interpreter-level exits such as ``KeyboardInterrupt`` propagate.
    """
    if version.parse(torch.__version__) <= version.parse("2.0.1"):
        # Legacy path: the flag alone is not trusted, so confirm with a
        # real allocation on the device.
        if not getattr(torch, "has_mps", False):
            return False
        try:
            torch.zeros(1).to(torch.device("mps"))
            return True
        except Exception:
            return False

    try:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()
    except Exception:
        # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit are not
        # swallowed; message now describes what actually failed.
        logger.warning("MPS availability check failed", exc_info=True)
        return False
|
|
|
|
# Probed once at import time so callers can read a cached flag instead of
# re-running check_for_mps(), which may allocate a tensor on older torch.
has_mps: bool = check_for_mps()
|
|
|
|
def torch_mps_gc() -> None:
    """Best-effort release of memory cached by the MPS allocator.

    Returns immediately when the MPS backend is not available. Callers can
    therefore invoke this unconditionally on CPU/CUDA hosts without a
    warning (and traceback) being logged on every call. Any failure while
    checking availability, importing, or emptying the cache is logged as a
    warning and never raised.
    """
    try:
        # empty_cache() raises on builds/hosts without MPS; skip that
        # expected failure instead of logging it on every call.
        if not torch.backends.mps.is_available():
            return
        # Imported inside the try so a torch build lacking ``torch.mps`` is
        # logged rather than failing at module import time.
        from torch.mps import empty_cache

        empty_cache()
    except Exception:
        logger.warning("MPS garbage collection failed", exc_info=True)
|
|
|
|
if __name__ == "__main__":
    # Manual diagnostic: show the torch version and the cached MPS flag,
    # each on its own line, then run the cache-release helper once.
    print(torch.__version__, has_mps, sep="\n")
    torch_mps_gc()
|
|