Spaces:
Running
Running
File size: 1,645 Bytes
cb6f1ba | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | """
Resolve the ``torch.device`` to use for the installed PyTorch build (CUDA, Apple MPS, or CPU-only wheels).
"""
from __future__ import annotations
import logging
from typing import Optional, Union
import torch
# Module-level logger; the device resolver warns through it when a requested device falls back to CPU.
logger: logging.Logger = logging.getLogger(__name__)
def _cuda_tensor_works() -> bool:
    """
    Probe whether CUDA is really usable by allocating a one-element tensor on it.

    ``torch.cuda.is_available()`` alone is not trusted. Actually allocating on
    ``"cuda"`` also exercises lazy CUDA initialisation, so CPU-only builds and
    broken driver setups are reported here as unusable instead of failing later.

    Returns:
        ``True`` if the probe allocation succeeded, ``False`` if CUDA is not
        available or the allocation raised ``AssertionError``/``RuntimeError``.
        Any other exception propagates to the caller.
    """
    usable = torch.cuda.is_available()
    if usable:
        try:
            torch.zeros(1, device="cuda")
        except (AssertionError, RuntimeError):
            # Allocation failed: treat CUDA as unusable rather than crashing.
            usable = False
    return usable
def _mps_available() -> bool:
    """Return True when this PyTorch build exposes an available MPS backend."""
    # Older PyTorch builds have no ``torch.backends.mps`` attribute at all.
    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def resolve_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
    """
    Pick the best available device, or validate an explicitly requested one.

    Args:
        device: ``None`` to auto-select (CUDA first, then MPS, then CPU), or any
            value accepted by ``torch.device``, e.g. ``"cpu"``, ``"cuda"``,
            ``"cuda:1"`` or ``"mps"``.

    Returns:
        The device to use. Falls back to CPU when CUDA is requested or
        auto-selected but not usable (e.g. PyTorch installed without CUDA
        support, or a lazy CUDA init failure), when a requested CUDA device
        index does not exist on this machine, or when MPS is requested but not
        available. Explicit requests that fall back log a warning. A valid
        explicit CUDA index (such as ``"cuda:1"``) is preserved in the result.

    Errors raised by ``torch.device`` for an unparseable ``device`` value
    propagate unchanged.
    """
    if device is None:
        if _cuda_tensor_works():
            return torch.device("cuda")
        if _mps_available():
            return torch.device("mps")
        return torch.device("cpu")

    requested = torch.device(device)

    if requested.type == "cuda":
        if not _cuda_tensor_works():
            logger.warning("CUDA requested or auto-selected but is not usable; using CPU.")
            return torch.device("cpu")
        # _cuda_tensor_works() only probes the default "cuda" device, so an
        # explicit index still has to be checked against the visible devices.
        index = requested.index
        visible = torch.cuda.device_count()
        if index is not None and not 0 <= index < visible:
            logger.warning(
                "CUDA device index %d requested but only %d device(s) visible; using CPU.",
                index,
                visible,
            )
            return torch.device("cpu")
        # Return the parsed device rather than a bare "cuda" so an explicit
        # index such as "cuda:1" is not silently dropped.
        return requested

    if requested.type == "mps" and not _mps_available():
        logger.warning("MPS requested but not available; using CPU.")
        return torch.device("cpu")

    return requested
|