Spaces:
Running
Running
| """ | |
| Resolve ``torch.device`` for the installed PyTorch build (CPU-only wheels vs CUDA). | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Optional, Union | |
| import torch | |
# Module-level logger; used to warn when a requested accelerator falls back to CPU.
logger = logging.getLogger(__name__)
def _cuda_tensor_works() -> bool:
    """
    Report whether a tensor can really be allocated on the CUDA device.

    ``torch.cuda.is_available()`` acts as a cheap first gate. A one-element
    allocation then confirms CUDA is actually usable. Failures caused by
    CPU-only builds or broken drivers surface as ``AssertionError`` or
    ``RuntimeError`` and are reported as ``False``.
    """
    usable = torch.cuda.is_available()
    if usable:
        try:
            torch.zeros(1, device="cuda")
        except (AssertionError, RuntimeError):
            usable = False
    return usable
def _mps_available() -> bool:
    """True if this PyTorch build exposes the MPS backend and it reports itself available."""
    # Older PyTorch builds have no ``torch.backends.mps`` attribute at all.
    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def resolve_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
    """
    Default to the best available device, or validate an explicit request.

    Args:
        device: ``None`` to auto-select (CUDA, then MPS, then CPU), or any value
            accepted by ``torch.device``, e.g. ``"cpu"``, ``"cuda:1"`` or
            ``torch.device("mps")``.

    Returns:
        The resolved ``torch.device``. An explicit CUDA device index is preserved.

        Falls back to CPU when ``cuda`` is requested or auto-selected but not usable.
        Examples are PyTorch installed without CUDA support, a lazy CUDA init
        failure, or a requested index beyond ``torch.cuda.device_count()``.

        Also falls back to CPU when ``mps`` is requested but unavailable. Any other
        device type is returned unchanged.

    Raises:
        Whatever ``torch.device`` raises for an unparseable ``device`` value.
    """
    if device is None:
        if _cuda_tensor_works():
            return torch.device("cuda")
        if _mps_available():
            return torch.device("mps")
        return torch.device("cpu")

    d = torch.device(device)
    if d.type == "cuda":
        if not _cuda_tensor_works():
            logger.warning("CUDA requested or auto-selected but is not usable; using CPU.")
            return torch.device("cpu")
        # The probe above only exercises the default GPU. An explicit index must
        # also refer to a visible device, or the returned device fails at first use.
        if d.index is not None:
            visible = torch.cuda.device_count()
            if d.index >= visible:
                logger.warning(
                    "CUDA device index %d requested but only %d device(s) visible; using CPU.",
                    d.index,
                    visible,
                )
                return torch.device("cpu")
        # Return ``d`` itself, not a bare "cuda" device, so "cuda:N" keeps its index.
        # A bare device would silently redirect work to the current default GPU.
        return d
    if d.type == "mps" and not _mps_available():
        logger.warning("MPS requested but not available; using CPU.")
        return torch.device("cpu")
    return d