File size: 190 Bytes
846c883
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import torch

def _select_device() -> torch.device:
    """Pick the torch device this module should run on.

    Backends are tried in a fixed order of preference: MPS first, then
    CUDA. CPU is the fallback when neither reports itself as available.
    CUDA is only queried when MPS is unavailable.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Resolved once at import time; shared by everything that imports DEVICE.
DEVICE = _select_device()