import torch def get_device(): if torch.cuda.is_available(): return "cuda" elif torch.backends.mps.is_available(): return "mps" else: return "cpu" def set_seed(seed): torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed)