Signlink / utils.py
gaurannggg7's picture
Deploy production architecture with single zipped asset matrix
876754e
raw
history blame contribute delete
630 Bytes
import os
import torch
def get_best_device() -> torch.device:
    """Return the fastest available ``torch.device``.

    Devices are probed in this order, and the first available one wins:

    1. TPU via ``torch_xla``. This is only attempted when the
       ``COLAB_TPU_ADDR`` environment variable is set (Colab TPU runtime).
    2. CUDA GPU (``torch.cuda.is_available()``).
    3. Apple MPS (``torch.backends.mps.is_available()``). This is skipped
       when the installed torch build has no ``torch.backends.mps`` attribute.
    4. CPU, as the unconditional fallback.

    If the TPU probe fails, a ``RuntimeWarning`` is issued and probing
    continues with the next option rather than aborting. Failure means either
    that ``torch_xla`` is not installed, or that acquiring the XLA device
    raises ``RuntimeError``.

    Returns:
        The first available device in the order above.
    """
    if 'COLAB_TPU_ADDR' in os.environ:
        try:
            import torch_xla.core.xla_model as xm
            return xm.xla_device()
        except (ImportError, RuntimeError) as err:
            # The env var says "TPU runtime" but we could not get an XLA
            # device. Tell the user instead of silently dropping to a slower
            # backend, then keep probing.
            import warnings
            warnings.warn(
                f"COLAB_TPU_ADDR is set but no XLA device could be acquired "
                f"({err!r}); falling back to the next available device.",
                RuntimeWarning,
                stacklevel=2,
            )
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Older torch builds have no MPS backend at all, hence the getattr guard.
    if getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')