# fea-surrogate / src/utils/device.py
# Uploaded to the Hugging Face Hub by WolfDavid via huggingface_hub
# (commit 8e5ba9e, verified).
"""Device detection and configuration for training.
Supports: Intel XPU (Arc GPU), NVIDIA CUDA, and CPU fallback.
Intel Arc GPUs use PyTorch's XPU backend, which is API-compatible
with CUDA — same .to(device), same autocast, same amp.GradScaler.
"""
import logging
import torch
logger = logging.getLogger(__name__)  # module-level logger, named after this module per logging convention
def get_device() -> torch.device:
    """Auto-detect the best available device.

    Checks accelerators in priority order — XPU (Intel Arc), then
    CUDA (NVIDIA) — and falls back to CPU when neither is available.
    Logs the selected device name and total memory in GB.
    """
    display = {"xpu": "Intel XPU", "cuda": "NVIDIA CUDA"}
    for kind in ("xpu", "cuda"):
        # torch.xpu only exists in builds with Intel support; torch.cuda
        # always exists, so getattr-with-None covers both uniformly.
        backend = getattr(torch, kind, None)
        if backend is None or not backend.is_available():
            continue
        gpu_name = backend.get_device_name(0)
        total_gb = backend.get_device_properties(0).total_memory / 1024**3
        logger.info(f"Using {display[kind]}: {gpu_name} ({total_gb:.1f} GB)")
        return torch.device(kind)
    logger.info("Using CPU (no GPU detected)")
    return torch.device("cpu")
def get_amp_backend(device: torch.device) -> str:
    """Return the autocast backend string for ``torch.amp``.

    XPU and CUDA map to their own backend names ("xpu"/"cuda");
    everything else falls back to the "cpu" backend (bf16 on
    supported CPUs).
    """
    return device.type if device.type in ("xpu", "cuda") else "cpu"
def supports_mixed_precision(device: torch.device) -> bool:
    """Return True if *device* supports fp16 mixed precision (XPU or CUDA)."""
    return device.type == "xpu" or device.type == "cuda"
def get_dtype(device: torch.device) -> torch.dtype:
    """Return the recommended compute dtype for *device*.

    Intel Arc supports both fp16 and bf16.
    NVIDIA T4 supports fp16 only (no bf16).
    CPU stays in full fp32.
    """
    kind = device.type
    if kind == "cuda":
        # bf16 needs compute capability 8.x (Ampere) or newer.
        major, _minor = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16
    if kind == "xpu":
        # Arc 140T supports fp16 well
        return torch.float16
    return torch.float32