File size: 5,857 Bytes
b6f9c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# -*- coding: utf-8 -*-
"""

设备检测模块 - 自动检测并选择最佳计算设备

支持: CUDA (NVIDIA / AMD ROCm), XPU (Intel Arc via IPEX), DirectML, MPS (Apple), CPU

"""
import torch


def _has_xpu() -> bool:
    """Report whether an Intel XPU device is usable from this process.

    ``intel_extension_for_pytorch`` (IPEX) is imported on a best-effort
    basis because some PyTorch builds only expose ``torch.xpu`` after that
    import.  The final answer always comes from ``torch.xpu`` itself, so
    builds that provide ``torch.xpu`` without IPEX are detected as well.

    Returns:
        True when ``torch.xpu`` exists and reports an available device,
        False otherwise (including when probing the backend fails).
    """
    try:
        import intel_extension_for_pytorch  # noqa: F401
    except Exception:
        # IPEX is optional: a missing or broken installation (for example a
        # native library that fails to load) must not abort device detection.
        pass

    xpu = getattr(torch, "xpu", None)
    if xpu is None:
        return False
    try:
        return bool(xpu.is_available())
    except Exception:
        # Treat a failing probe as "no XPU" instead of crashing the caller.
        return False


def _has_directml() -> bool:
    """Report whether the ``torch_directml`` package can be imported.

    DirectML is the backend used for AMD/Intel GPUs on Windows.  Only the
    importability of the package is checked; no device is opened here.
    """
    try:
        import torch_directml  # noqa: F401
    except ImportError:
        importable = False
    else:
        importable = True
    return importable


def _has_mps() -> bool:
    """Report whether Apple's MPS backend is present and actually usable.

    Besides asking ``torch.backends.mps`` whether it is available, a tiny
    tensor is moved to the ``mps`` device as a smoke test; any failure
    during that transfer counts as "not available".
    """
    backend_ready = (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    )
    if backend_ready:
        try:
            # Smoke test: allocate on CPU and move a single element to MPS.
            torch.zeros(1).to(torch.device("mps"))
        except Exception:
            return False
        return True
    return False


def _is_rocm() -> bool:
    """Report whether this PyTorch build is a ROCm (AMD GPU) build.

    The check is based solely on ``torch.version.hip`` being present and
    not ``None``.
    """
    hip_version = getattr(torch.version, "hip", None)
    return hip_version is not None


def _default_directml_device() -> torch.device:
    """Return the device object for torch_directml's default adapter."""
    import torch_directml

    return torch_directml.device(torch_directml.default_device())


def get_device(preferred: str = "cuda") -> torch.device:
    """Resolve the compute device to use, falling back automatically.

    The preferred backend is honoured when it is available.  Otherwise the
    first available backend in the order CUDA (including ROCm) > XPU >
    DirectML > MPS > CPU is returned.

    Args:
        preferred: Requested device, e.g. "cuda", "cuda:1", "xpu",
            "directml", "mps" or "cpu".  Matching is case-insensitive and
            ignores surrounding whitespace.  ``None`` and ``torch.device``
            objects are accepted too; ``None`` means "auto-detect".  An
            explicit CUDA/XPU index is honoured when that device exists.

    Returns:
        torch.device: A usable compute device.
    """
    # Accept None and torch.device objects in addition to plain strings.
    p = "" if preferred is None else str(preferred).lower().strip()
    backend, sep, index_text = p.partition(":")
    if not sep:
        index = 0  # bare backend name: use the default device
    elif index_text.isdecimal():
        index = int(index_text)
    else:
        index = None  # malformed index: treat the request as unmatched

    # Explicit request.  The backend name is compared first so that the
    # availability probes only run for the backend that was asked for.
    if backend == "cuda" and index is not None and torch.cuda.is_available():
        if index == 0:
            return torch.device("cuda")
        if index < torch.cuda.device_count():
            return torch.device("cuda", index)
    if backend == "xpu" and index is not None and _has_xpu():
        if index == 0:
            return torch.device("xpu")
        if index < torch.xpu.device_count():
            return torch.device("xpu", index)
    if (p == "directml" or p.startswith("privateuseone")) and _has_directml():
        return _default_directml_device()
    if p in ("mps", "mps:0") and _has_mps():
        return torch.device("mps")
    if backend == "cpu":
        # An explicit CPU request must never be upgraded to an accelerator.
        return torch.device("cpu")

    # Auto-detect: CUDA (incl. ROCm) > XPU > DirectML > MPS > CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if _has_xpu():
        return torch.device("xpu")
    if _has_directml():
        return _default_directml_device()
    if _has_mps():
        return torch.device("mps")
    return torch.device("cpu")


def supports_fp16(device: torch.device) -> bool:
    """Decide whether FP16 inference should be enabled for ``device``.

    Args:
        device: A ``torch.device`` (its ``type`` is used) or any object whose
            string form names a device, such as "cuda" or "cuda:1".

    Returns:
        True for CUDA (which includes ROCm builds) and XPU devices.  False
        for everything else: DirectML, MPS and CPU are treated as unstable
        for FP16 and are disabled by default.
    """
    kind = str(device.type) if hasattr(device, "type") else str(device)
    # A plain string may carry an index ("cuda:1"); only the backend matters.
    kind = kind.partition(":")[0].strip().lower()
    return kind in ("cuda", "xpu")


def empty_device_cache(device: torch.device = None):
    """Release cached device memory in a backend-agnostic way.

    Args:
        device: Device whose cache should be emptied.  May be a
            ``torch.device`` or a string such as "cuda:1"; any index is
            ignored.  When ``None`` (the default) the caches of all
            available backends among CUDA, XPU and MPS are emptied.
    """
    if device is None:
        kind = None
    else:
        kind = str(device.type) if hasattr(device, "type") else str(device)
        # A plain string may carry an index ("cuda:1"); only the backend matters.
        kind = kind.partition(":")[0].strip().lower()

    if kind in (None, "cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
    if kind in (None, "xpu") and _has_xpu():
        torch.xpu.empty_cache()
    if kind in (None, "mps") and _has_mps():
        # torch.mps, or its empty_cache function, may be absent on some
        # PyTorch builds; skip silently in that case.
        release = getattr(getattr(torch, "mps", None), "empty_cache", None)
        if release is not None:
            release()


def get_device_info() -> dict:
    """Collect a summary of the accelerator backends visible to PyTorch.

    Returns:
        dict with three keys:

        * ``"backends"``: labels of every available backend, or ``["CPU"]``
          when none is available.
        * ``"current_device"``: the highest-priority available backend
          ("cuda", "xpu", "directml", "mps"), or "cpu".
        * ``"devices"``: one dict per detected device with the keys
          ``"index"``, ``"backend"``, ``"name"`` and ``"total_memory_gb"``
          (``None`` when the backend does not report memory here).
    """
    info = {
        "backends": [],
        "current_device": "cpu",
        "devices": [],
    }
    bytes_per_gb = 1024**3

    # CUDA (NVIDIA, or AMD when this is a ROCm build)
    if torch.cuda.is_available():
        backend = "ROCm (AMD)" if _is_rocm() else "CUDA (NVIDIA)"
        info["backends"].append(backend)
        info["current_device"] = "cuda"
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            info["devices"].append({
                "index": i,
                "backend": backend,
                "name": props.name,
                "total_memory_gb": round(props.total_memory / bytes_per_gb, 2),
            })

    # Intel XPU
    if _has_xpu():
        info["backends"].append("XPU (Intel)")
        # Only becomes the current device when no earlier backend claimed it.
        if not info["devices"]:
            info["current_device"] = "xpu"
        for i in range(torch.xpu.device_count()):
            props = torch.xpu.get_device_properties(i)
            info["devices"].append({
                "index": i,
                "backend": "XPU (Intel)",
                "name": props.name,
                "total_memory_gb": round(props.total_memory / bytes_per_gb, 2),
            })

    # DirectML
    if _has_directml():
        import torch_directml
        info["backends"].append("DirectML")
        if not info["devices"]:
            info["current_device"] = "directml"
        # Describe the same adapter that get_device() would select.
        dml_index = torch_directml.default_device()
        info["devices"].append({
            "index": dml_index,
            "backend": "DirectML",
            "name": torch_directml.device_name(dml_index),
            "total_memory_gb": None,
        })

    # MPS
    if _has_mps():
        info["backends"].append("MPS (Apple)")
        if not info["devices"]:
            info["current_device"] = "mps"
        # List the MPS device as well, so that an MPS machine is not
        # reported as having no GPU device at all.
        info["devices"].append({
            "index": 0,
            "backend": "MPS (Apple)",
            "name": "Apple MPS",
            "total_memory_gb": None,
        })

    if not info["backends"]:
        info["backends"].append("CPU")

    return info


def print_device_info():
    """Print a human-readable device report to the console.

    The report shows the available backends, the current device and one
    line per detected device; when no device is listed, a note about CPU
    inference is printed instead.
    """
    report = get_device_info()
    rule = "=" * 50

    print(rule)
    print("设备信息")
    print(rule)
    print(f"可用后端: {', '.join(report['backends'])}")
    print(f"当前设备: {report['current_device']}")

    devices = report["devices"]
    for entry in devices:
        # A missing, None or zero memory figure is shown as "N/A".
        if entry.get("total_memory_gb"):
            mem = f"{entry['total_memory_gb']} GB"
        else:
            mem = "N/A"
        print(f"  [{entry['index']}] {entry['name']} ({entry['backend']}) - 显存: {mem}")

    if not devices:
        print("  无 GPU 设备,将使用 CPU 进行推理")

    print(rule)