Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import subprocess | |
def get_device(force_cpu=False):
    """Pick the torch device string to run on.

    Args:
        force_cpu: When truthy, skip accelerator detection entirely.

    Returns:
        ``"cuda"`` if CUDA is available, otherwise ``"mps"`` if the MPS
        backend is available, otherwise ``"cpu"``. Always ``"cpu"`` when
        ``force_cpu`` is truthy.
    """
    if not force_cpu:
        # CUDA takes priority over MPS when both report as available.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            # The MPS cache is emptied before handing the device back.
            torch.mps.empty_cache()
            return "mps"
    return "cpu"
def get_torch_and_np_dtypes(device, use_bfloat16=False):
    """Return the ``(torch_dtype, np_dtype)`` pair to use on ``device``.

    Args:
        device: Device string; ``"cuda"`` and ``"mps"`` select half
            precision, anything else selects single precision.
        use_bfloat16: On ``"cuda"``/``"mps"``, choose ``torch.bfloat16``
            instead of ``torch.float16`` for the torch dtype.

    Returns:
        A 2-tuple ``(torch_dtype, np_dtype)``.
    """
    if device in ("cuda", "mps"):
        half_dtype = torch.bfloat16 if use_bfloat16 else torch.float16
        # The NumPy side stays float16 even when bfloat16 is requested
        # (presumably because NumPy has no native bfloat16 type).
        return half_dtype, np.float16
    # Any other device (including "cpu") ignores use_bfloat16.
    return torch.float32, np.float32
def cuda_version_check():
    """Report the CUDA version and the name of CUDA device 0.

    The version is taken from ``nvcc --version`` when that command can be
    run and its output parsed; otherwise it falls back to the CUDA version
    PyTorch reports (``torch.version.cuda``).

    Returns:
        ``(cuda_version, device_name)`` when CUDA is available, else
        ``(None, None)``.
    """
    if not torch.cuda.is_available():
        return None, None

    try:
        nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode(
            errors="replace"
        )
    except (OSError, subprocess.SubprocessError):
        # nvcc missing, not executable, or exited non-zero: best-effort
        # lookup, so fall through to PyTorch's own version below.
        nvcc_output = ""

    cuda_version = None
    # nvcc reports the toolkit as "... release 12.2, V12.2.140". Taking a
    # fixed token position from the end is unreliable because newer
    # toolkits append a trailing "Build cuda_..." line, so anchor on the
    # "release" marker instead and drop the trailing comma.
    _, marker, after_release = nvcc_output.partition("release ")
    if marker:
        tokens = after_release.split(",", 1)[0].split()
        if tokens:
            cuda_version = tokens[0]

    if cuda_version is None:
        # Fallback to PyTorch's built-in version if nvcc isn't usable.
        cuda_version = torch.version.cuda

    device_name = torch.cuda.get_device_name(0)
    return cuda_version, device_name