# Uploaded with the upload-large-folder tool (commit 248a619, verified).
from __future__ import annotations
import torch
def resolve_torch_device(requested_device: str | None = None) -> torch.device:
    """Select the torch.device to run on.

    When *requested_device* is given (e.g. ``"cuda"``, ``"cuda:1"``, ``"mps"``,
    ``"cpu"``), honour it after verifying the backend is actually usable.
    Otherwise fall back in preference order: CUDA, then MPS, then CPU.

    Args:
        requested_device: Optional device string understood by ``torch.device``.

    Returns:
        The resolved ``torch.device``.

    Raises:
        RuntimeError: An explicitly requested CUDA or MPS device is not
            available in this PyTorch build.
    """
    if requested_device:
        chosen = torch.device(requested_device)
        if chosen.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested but is not available in this PyTorch build.")
        if chosen.type == "mps" and not mps_is_available():
            raise RuntimeError("MPS was requested but is not available in this PyTorch build.")
        return chosen
    # No explicit request: probe accelerators in preference order.
    for backend_name, probe in (("cuda", torch.cuda.is_available), ("mps", mps_is_available)):
        if probe():
            return torch.device(backend_name)
    return torch.device("cpu")
def mps_is_available() -> bool:
    """Report whether the Apple MPS backend is usable.

    True only when the ``torch.backends.mps`` module exists in this build
    AND the backend reports itself both available and built.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        return False
    return bool(mps_backend.is_available() and mps_backend.is_built())
def configure_torch_runtime(device: torch.device) -> None:
    """Apply CUDA-specific runtime speed settings; no-op on other devices.

    Enables cuDNN autotuning and, where the running PyTorch build exposes
    the toggles, TF32 acceleration for matmul and cuDNN kernels.

    Args:
        device: The resolved device the process will run on.
    """
    if device.type != "cuda":
        return
    # Let cuDNN benchmark convolution algorithms for the observed shapes.
    torch.backends.cudnn.benchmark = True
    # TF32 switches only exist on sufficiently recent PyTorch builds,
    # hence the hasattr guards before assignment.
    matmul_backend = torch.backends.cuda.matmul
    if hasattr(matmul_backend, "allow_tf32"):
        matmul_backend.allow_tf32 = True
    cudnn_backend = torch.backends.cudnn
    if hasattr(cudnn_backend, "allow_tf32"):
        cudnn_backend.allow_tf32 = True