"""Device detection and configuration for training.

Supports: Intel XPU (Arc GPUs), NVIDIA CUDA, and a CPU fallback.
Intel Arc GPUs use PyTorch's XPU backend, which largely mirrors the CUDA
API: the same .to(device), autocast, and amp.GradScaler usage applies,
with "xpu" as the device type.
"""

import logging

import torch

logger = logging.getLogger(__name__)


def get_device() -> torch.device:
    """Auto-detect the best available device.

    Priority: XPU (Intel Arc) > CUDA (NVIDIA) > CPU
    """
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
        name = torch.xpu.get_device_name(0)
        mem = torch.xpu.get_device_properties(0).total_memory / 1024**3
        logger.info(f"Using Intel XPU: {name} ({mem:.1f} GB)")
        return device

    if torch.cuda.is_available():
        device = torch.device("cuda")
        name = torch.cuda.get_device_name(0)
        mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logger.info(f"Using NVIDIA CUDA: {name} ({mem:.1f} GB)")
        return device

    logger.info("Using CPU (no GPU detected)")
    return torch.device("cpu")


def get_amp_backend(device: torch.device) -> str:
    """Return the device_type string to pass to torch.autocast.

    XPU and CUDA devices map to "xpu" and "cuda"; everything else falls
    back to "cpu" (CPU autocast uses bf16 on CPUs that support it).
    """
    if device.type == "xpu":
        return "xpu"
    elif device.type == "cuda":
        return "cuda"
    return "cpu"


def supports_mixed_precision(device: torch.device) -> bool:
    """Check if the device supports fp16 mixed precision."""
    return device.type in ("xpu", "cuda")


def get_dtype(device: torch.device) -> torch.dtype:
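    """Get the recommended compute dtype for the device.

    Intel Arc supports both fp16 and bf16; fp16 is used on XPU. On CUDA,
    bf16 is used on Ampere and newer (compute capability >= 8.0), while
    older GPUs such as the T4 (Turing) lack bf16 support and get fp16.
    CPU training stays in fp32.
    """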
    if device.type == "xpu":
        return torch.float16  # Arc 140T supports fp16 well
    elif device.type == "cuda":
        # Check for bf16 support (Ampere+)
        if torch.cuda.get_device_capability()[0] >= 8:
            return torch.bfloat16
        return torch.float16
    return torch.float32