|
|
import torch
|
|
|
|
|
|
|
|
|
def show_device_list(backend: str) -> int:
|
|
|
"""
|
|
|
Displays a list of all detected devices for a given PyTorch backend.
|
|
|
|
|
|
Args:
|
|
|
backend: The name of the device backend module (e.g., "cuda", "xpu").
|
|
|
|
|
|
Returns:
|
|
|
The number of devices found if the backend is usable, otherwise 0.
|
|
|
"""
|
|
|
|
|
|
backend_upper = backend.upper()
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
|
backend_module = getattr(torch, backend)
|
|
|
|
|
|
|
|
|
brand_name = backend_upper
|
|
|
if backend == "cuda":
|
|
|
|
|
|
|
|
|
|
|
|
brand_name = "NVIDIA CUDA / AMD ROCm"
|
|
|
elif backend == "xpu":
|
|
|
brand_name = "Intel XPU"
|
|
|
elif backend == "mps":
|
|
|
brand_name = "Apple MPS"
|
|
|
|
|
|
if not backend_module.is_available():
|
|
|
print(f"PyTorch: No devices found for {brand_name} backend.")
|
|
|
return 0
|
|
|
|
|
|
print(f"PyTorch: {brand_name} is available!")
|
|
|
|
|
|
|
|
|
device_count = backend_module.device_count()
|
|
|
print(f" * Number of {backend_upper} devices found: {device_count}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if backend != "mps":
|
|
|
for i in range(device_count):
|
|
|
device_name = backend_module.get_device_name(i)
|
|
|
print(f' * Device {i}: "{device_name}"')
|
|
|
|
|
|
return device_count
|
|
|
|
|
|
except AttributeError:
|
|
|
print(
|
|
|
f'Error: The PyTorch backend "{backend}" does not exist, or is missing the necessary APIs (is_available, device_count, get_device_name).'
|
|
|
)
|
|
|
except Exception as e:
|
|
|
print(f"Error: {e}")
|
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
def check_torch_devices() -> None:
    """Scan the CUDA, XPU and MPS backends and print an overall verdict.

    Each backend is probed in turn through ``show_device_list``, which prints
    its own per-backend report and returns a device count. The counts are
    added up and a closing line states whether any device was detected.
    """
    print("Scanning for PyTorch hardware acceleration devices...\n")

    # sum() consumes the generator in order, so the backends are probed (and
    # their reports printed) as cuda, then xpu, then mps.
    total_devices = sum(
        show_device_list(name) for name in ("cuda", "xpu", "mps")
    )

    verdict = (
        "\nHardware acceleration detected. Your system is ready!"
        if total_devices > 0
        else "\nNo hardware acceleration detected. Running in CPU mode."
    )
    print(verdict)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the scan only when executed as a script; importing this module
    # should not print anything.
    check_torch_devices()
|
|
|
|