"""Device detection and configuration for training. Supports: Intel XPU (Arc GPU), NVIDIA CUDA, and CPU fallback. Intel Arc GPUs use PyTorch's XPU backend, which is API-compatible with CUDA — same .to(device), same autocast, same amp.GradScaler. """ import logging import torch logger = logging.getLogger(__name__) def get_device() -> torch.device: """Auto-detect the best available device. Priority: XPU (Intel Arc) > CUDA (NVIDIA) > CPU """ if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("xpu") name = torch.xpu.get_device_name(0) mem = torch.xpu.get_device_properties(0).total_memory / 1024**3 logger.info(f"Using Intel XPU: {name} ({mem:.1f} GB)") return device if torch.cuda.is_available(): device = torch.device("cuda") name = torch.cuda.get_device_name(0) mem = torch.cuda.get_device_properties(0).total_memory / 1024**3 logger.info(f"Using NVIDIA CUDA: {name} ({mem:.1f} GB)") return device logger.info("Using CPU (no GPU detected)") return torch.device("cpu") def get_amp_backend(device: torch.device) -> str: """Get the appropriate autocast backend string for torch.amp. XPU and CUDA both support 'xpu'/'cuda' respectively. CPU uses 'cpu' backend (bf16 on supported CPUs). """ if device.type == "xpu": return "xpu" elif device.type == "cuda": return "cuda" return "cpu" def supports_mixed_precision(device: torch.device) -> bool: """Check if the device supports fp16 mixed precision.""" return device.type in ("xpu", "cuda") def get_dtype(device: torch.device) -> torch.dtype: """Get the recommended compute dtype for the device. Intel Arc supports both fp16 and bf16. NVIDIA T4 supports fp16 only (no bf16). """ if device.type == "xpu": return torch.float16 # Arc 140T supports fp16 well elif device.type == "cuda": # Check for bf16 support (Ampere+) if torch.cuda.get_device_capability()[0] >= 8: return torch.bfloat16 return torch.float16 return torch.float32