# protloc-ai/src/utils/device.py
# Tanoj22
# Initial commit: ProtLoc-AI project setup and core app
# cb6f1ba
"""
Resolve a usable ``torch.device`` for the installed PyTorch build (CUDA, Apple MPS, or CPU-only wheels).
"""
from __future__ import annotations
import logging
from typing import Optional, Union
import torch
# Module-level logger; device fallback warnings and probe diagnostics go through it.
logger = logging.getLogger(__name__)


def _cuda_tensor_works() -> bool:
    """Return True only if a tensor can actually be allocated on the default CUDA device.

    ``torch.cuda.is_available()`` alone is not trusted: the probe allocation also
    catches CPU-only builds and broken drivers / lazy CUDA initialisation
    failures (surfacing as ``AssertionError`` or ``RuntimeError``). The failure
    reason is logged at DEBUG level so a later "using CPU" message is diagnosable.
    """
    if not torch.cuda.is_available():
        return False
    try:
        torch.zeros(1, device="cuda")
    except (AssertionError, RuntimeError) as exc:
        logger.debug("CUDA probe allocation failed: %s", exc)
        return False
    return True


def _mps_available() -> bool:
    """Return True if this PyTorch build exposes an available Apple MPS backend."""
    # Some PyTorch builds have no ``torch.backends.mps`` attribute at all.
    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


def resolve_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
    """Return the best usable device, or validate an explicit request.

    With ``device=None`` the preference order is CUDA, then Apple MPS, then CPU;
    unusable backends are skipped without a warning.

    With an explicit ``device`` (a string such as ``"cuda:1"`` or a
    ``torch.device``), a CUDA or MPS request that cannot be honoured logs a
    warning and falls back to CPU. Examples are a CPU-only PyTorch wheel, a
    CUDA init failure, or a CUDA ordinal not below ``torch.cuda.device_count()``.
    Any other device is returned as parsed.

    Args:
        device: Requested device, or None to auto-select.

    Returns:
        The resolved ``torch.device``. An explicit, valid CUDA ordinal is preserved.

    Raises:
        Propagates whatever ``torch.device`` raises for an unparseable device string.
    """
    if device is None:
        if _cuda_tensor_works():
            return torch.device("cuda")
        if _mps_available():
            return torch.device("mps")
        return torch.device("cpu")

    d = torch.device(device)

    if d.type == "cuda":
        if not _cuda_tensor_works():
            logger.warning("CUDA requested but is not usable; using CPU.")
            return torch.device("cpu")
        # Preserve an explicit ordinal (e.g. "cuda:1") instead of collapsing it to
        # the default device, but reject ordinals that do not exist.
        visible = torch.cuda.device_count()
        if d.index is not None and d.index >= visible:
            logger.warning(
                "CUDA device %s requested but only %d device(s) are visible; using CPU.",
                d,
                visible,
            )
            return torch.device("cpu")
        return d

    if d.type == "mps" and not _mps_available():
        logger.warning("MPS requested but not available; using CPU.")
        return torch.device("cpu")

    return d