# -*- coding: utf-8 -*-
"""
设备检测模块 - 自动检测并选择最佳计算设备
支持: CUDA (NVIDIA / AMD ROCm), XPU (Intel Arc via IPEX), DirectML, MPS (Apple), CPU
"""
from typing import Optional

import torch
def _has_xpu() -> bool:
"""检测 Intel XPU (需要 intel_extension_for_pytorch)"""
try:
import intel_extension_for_pytorch # noqa: F401
return hasattr(torch, "xpu") and torch.xpu.is_available()
except ImportError:
return False
def _has_directml() -> bool:
"""检测 DirectML (AMD/Intel on Windows)"""
try:
import torch_directml # noqa: F401
return True
except ImportError:
return False
def _has_mps() -> bool:
"""检测 Apple MPS"""
if not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
def _is_rocm() -> bool:
"""检测当前 PyTorch 是否为 ROCm 构建 (AMD GPU)"""
return hasattr(torch.version, "hip") and torch.version.hip is not None
def get_device(preferred: str = "cuda") -> torch.device:
    """Return the best available compute device, with automatic fallback.

    Args:
        preferred: Preferred device string ("cuda", "cuda:1", "xpu",
            "directml", "mps", "cpu"). An explicit CUDA index such as
            "cuda:1" is honoured when that index exists. If the preferred
            backend is unavailable, the best available one is chosen:
            CUDA (incl. ROCm) > XPU > DirectML > MPS > CPU.

    Returns:
        torch.device: A usable device (a torch_directml device object
        for the DirectML backend).
    """
    p = preferred.lower().strip()

    # --- Honour an explicit request when the backend is present ---
    if p in ("cuda", "cuda:0") and torch.cuda.is_available():
        return torch.device("cuda")
    if p.startswith("cuda:") and torch.cuda.is_available():
        # Generalized: accept any valid explicit index, e.g. "cuda:1".
        try:
            idx = int(p.split(":", 1)[1])
        except ValueError:
            idx = -1
        if 0 <= idx < torch.cuda.device_count():
            return torch.device(p)
    if p in ("xpu", "xpu:0") and _has_xpu():
        return torch.device("xpu")
    if (p == "directml" or p.startswith("privateuseone")) and _has_directml():
        import torch_directml
        return torch_directml.device(torch_directml.default_device())
    if p == "mps" and _has_mps():
        return torch.device("mps")
    if p == "cpu":
        return torch.device("cpu")

    # --- Auto-detect: CUDA (incl. ROCm) > XPU > DirectML > MPS > CPU ---
    if torch.cuda.is_available():
        return torch.device("cuda")
    if _has_xpu():
        return torch.device("xpu")
    if _has_directml():
        import torch_directml
        return torch_directml.device(torch_directml.default_device())
    if _has_mps():
        return torch.device("mps")
    return torch.device("cpu")
def supports_fp16(device: torch.device) -> bool:
    """Return True when *device* reliably supports FP16 inference.

    CUDA (incl. ROCm) and Intel XPU handle half precision; DirectML,
    MPS and CPU are considered unstable and stay disabled.
    """
    # Accept either a torch.device or a plain string.
    kind = str(getattr(device, "type", device))
    return kind in ("cuda", "xpu")
def empty_device_cache(device: Optional[torch.device] = None) -> None:
    """Release cached allocator memory (device-agnostic).

    Args:
        device: Device whose cache should be cleared. When None, the
            cache of every available backend is cleared.
    """
    if device is None:
        kind = None
    else:
        # Accept either a torch.device or a plain string.
        kind = str(device.type) if hasattr(device, "type") else str(device)
    if kind in (None, "cuda") and torch.cuda.is_available():
        torch.cuda.empty_cache()
    if kind in (None, "xpu") and _has_xpu():
        torch.xpu.empty_cache()
    if kind in (None, "mps") and _has_mps():
        # torch.mps.empty_cache() only exists in newer PyTorch releases.
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()
def get_device_info() -> dict:
    """Collect details about every detected compute backend and device.

    Returns:
        dict with keys "backends" (list of backend labels),
        "current_device" (device-type string), "devices" (per-device dicts).
    """
    info = {
        "backends": [],
        "current_device": "cpu",
        "devices": [],
    }
    backends = info["backends"]
    devices = info["devices"]

    # CUDA covers both NVIDIA and AMD ROCm builds.
    if torch.cuda.is_available():
        label = "ROCm (AMD)" if _is_rocm() else "CUDA (NVIDIA)"
        backends.append(label)
        info["current_device"] = "cuda"
        for idx in range(torch.cuda.device_count()):
            prop = torch.cuda.get_device_properties(idx)
            devices.append({
                "index": idx,
                "backend": label,
                "name": prop.name,
                "total_memory_gb": round(prop.total_memory / (1024 ** 3), 2),
            })

    # Intel XPU.
    if _has_xpu():
        backends.append("XPU (Intel)")
        if not devices:
            info["current_device"] = "xpu"
        for idx in range(torch.xpu.device_count()):
            prop = torch.xpu.get_device_properties(idx)
            devices.append({
                "index": idx,
                "backend": "XPU (Intel)",
                "name": prop.name,
                "total_memory_gb": round(prop.total_memory / (1024 ** 3), 2),
            })

    # DirectML exposes a single default device and no memory query.
    if _has_directml():
        import torch_directml
        backends.append("DirectML")
        if not devices:
            info["current_device"] = "directml"
        devices.append({
            "index": 0,
            "backend": "DirectML",
            "name": torch_directml.device_name(0),
            "total_memory_gb": None,
        })

    # Apple MPS has no enumerable device list.
    if _has_mps():
        backends.append("MPS (Apple)")
        if not devices:
            info["current_device"] = "mps"

    if not backends:
        backends.append("CPU")
    return info
def print_device_info() -> None:
    """Print the detected device summary to the console."""
    info = get_device_info()
    separator = "=" * 50
    print(separator)
    print("设备信息")
    print(separator)
    print(f"可用后端: {', '.join(info['backends'])}")
    print(f"当前设备: {info['current_device']}")
    devices = info["devices"]
    for dev in devices:
        gb = dev.get("total_memory_gb")
        mem = f"{gb} GB" if gb else "N/A"
        print(f" [{dev['index']}] {dev['name']} ({dev['backend']}) - 显存: {mem}")
    if not devices:
        print(" 无 GPU 设备,将使用 CPU 进行推理")
    print(separator)