Spaces:
Running
Running
| from __future__ import annotations | |
| import torch | |
def resolve_torch_device(requested_device: str | None = None) -> torch.device:
    """Pick the torch device to run on.

    Args:
        requested_device: Device string handed to ``torch.device`` (for
            example ``"cpu"`` or ``"cuda"``). A falsy value (``None`` or an
            empty string) asks for automatic selection.

    Returns:
        The requested device when one was given; otherwise ``cuda`` if
        ``torch.cuda.is_available()`` is true, else ``mps`` if
        ``mps_is_available()`` is true, else ``cpu``.

    Raises:
        RuntimeError: If ``cuda`` or ``mps`` was explicitly requested but the
            corresponding availability check fails.
    """
    if not requested_device:
        # Nothing requested: prefer CUDA, then MPS, and finally fall back to CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("mps" if mps_is_available() else "cpu")

    chosen = torch.device(requested_device)
    backend = chosen.type

    # An explicit request is honoured only when its backend is usable here.
    if backend == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA was requested but is not available in this PyTorch build.")
    if backend == "mps" and not mps_is_available():
        raise RuntimeError("MPS was requested but is not available in this PyTorch build.")
    return chosen
def mps_is_available() -> bool:
    """Return True when the MPS backend exists, is available, and is built.

    The ``torch.backends.mps`` attribute is probed first, so the function
    returns False instead of raising when that attribute is absent.
    """
    if not hasattr(torch.backends, "mps"):
        return False

    mps_backend = torch.backends.mps
    return bool(mps_backend.is_available() and mps_backend.is_built())
def configure_torch_runtime(device: torch.device) -> None:
    """Apply global torch backend settings for the given device.

    For a ``cuda`` device this sets ``torch.backends.cudnn.benchmark`` to True
    and sets ``allow_tf32`` to True on ``torch.backends.cuda.matmul`` and
    ``torch.backends.cudnn`` wherever that attribute exists. Any other device
    type leaves the backend settings untouched.

    Args:
        device: The device the caller intends to run on; only its ``type``
            is inspected.
    """
    if device.type != "cuda":
        return

    torch.backends.cudnn.benchmark = True

    # The TF32 switches are probed with hasattr so that a PyTorch version
    # lacking the attribute is skipped rather than raising.
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        if hasattr(backend, "allow_tf32"):
            backend.allow_tf32 = True