| """Device detection and configuration for training. | |
| Supports: Intel XPU (Arc GPU), NVIDIA CUDA, and CPU fallback. | |
| Intel Arc GPUs use PyTorch's XPU backend, which is API-compatible | |
| with CUDA — same .to(device), same autocast, same amp.GradScaler. | |
| """ | |
| import logging | |
| import torch | |
| logger = logging.getLogger(__name__) | |
def get_device() -> torch.device:
    """Auto-detect the best available device.

    Priority: XPU (Intel Arc) > CUDA (NVIDIA) > CPU
    """
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
        name = torch.xpu.get_device_name(0)
        mem = torch.xpu.get_device_properties(0).total_memory / 1024**3
        logger.info(f"Using Intel XPU: {name} ({mem:.1f} GB)")
        return device

    if torch.cuda.is_available():
        device = torch.device("cuda")
        name = torch.cuda.get_device_name(0)
        mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logger.info(f"Using NVIDIA CUDA: {name} ({mem:.1f} GB)")
        return device

    logger.info("Using CPU (no GPU detected)")
    return torch.device("cpu")

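# Usage sketch (placeholder names, not part of the original module): the
# returned device drops into standard PyTorch calls unchanged, e.g.
#
#     device = get_device()
#     model = model.to(device)
#     batch = batch.to(device, non_blocking=True)
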
def get_amp_backend(device: torch.device) -> str:
    """Get the autocast `device_type` string for torch.amp.

    XPU maps to "xpu" and CUDA to "cuda"; everything else falls back
    to the "cpu" backend (bf16 autocast on supported CPUs).
    """
    if device.type == "xpu":
        return "xpu"
    elif device.type == "cuda":
        return "cuda"
    return "cpu"

def supports_mixed_precision(device: torch.device) -> bool:
    """Check if the device supports fp16 mixed precision."""
    return device.type in ("xpu", "cuda")

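# Illustrative guard for a training loop (an assumption, not from the original
# module; requires a PyTorch build whose GradScaler supports the device type):
# torch.amp.GradScaler takes the device type string and an `enabled` flag, so
# loss scaling can be switched off cleanly on CPU, e.g.
#
#     scaler = torch.amp.GradScaler(device.type, enabled=supports_mixed_precision(device))
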
def get_dtype(device: torch.device) -> torch.dtype:
    """Get the recommended compute dtype for the device.

    Intel Arc supports both fp16 and bf16.
    NVIDIA T4 supports fp16 only (no bf16).
    """
    if device.type == "xpu":
        return torch.float16  # Arc 140T supports fp16 well
    elif device.type == "cuda":
        # Check for bf16 support (Ampere+, compute capability >= 8)
        if torch.cuda.get_device_capability()[0] >= 8:
            return torch.bfloat16
        return torch.float16
    return torch.float32

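# Minimal smoke test (a sketch, not part of the original module): running the
# file directly reports what this machine resolves to, using only the
# functions defined above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dev = get_device()
    print(
        f"device={dev}, amp_backend={get_amp_backend(dev)}, "
        f"mixed_precision={supports_mixed_precision(dev)}, dtype={get_dtype(dev)}"
    )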