import torch


def select_device() -> torch.device:
    """Return the preferred torch device for this process.

    Preference order: Apple MPS, then CUDA, then CPU.

    Returns:
        A ``torch.device`` for the first available backend in that order.
    """
    # torch.backends.mps only exists in newer PyTorch releases (1.12+);
    # look it up defensively so older builds fall through to CUDA/CPU
    # instead of raising AttributeError at import time.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Device chosen once at import time; other code reads this constant.
DEVICE = select_device()