# ricl / verify_gpu.py
# Uploaded by doanh25032004 via the upload-large-folder tool ("Add files using
# upload-large-folder tool"), commit b71de11 (verified).
import jax
import sys
# Sanity check: confirm that JAX can see at least one GPU.
# Exit status 0 means a GPU was found; 1 means either no GPU was visible or
# querying devices raised an error.
print(f"JAX version: {jax.__version__}")

# Keep the try narrow: only the device query is expected to fail. Everything
# below it is plain printing and should surface real tracebacks if it breaks.
try:
    local_devices = jax.local_devices()
except Exception as e:
    # Top-level boundary of a CLI check: report the failure and exit non-zero
    # instead of dumping a traceback.
    print(f"\nError checking devices: {e}")
    sys.exit(1)

print(f"Local devices: {local_devices}")

# Filter once and reuse the result for both the presence check and the listing.
gpu_devices = [d for d in local_devices if d.platform == 'gpu']

if not gpu_devices:
    print("\nFAILURE: No GPU detected. JAX is running on CPU.")
    print("Please verify CUDA installation/environment.")
    sys.exit(1)

print("\nSUCCESS: GPU detected!")
# Print details about the GPU(s)
for d in gpu_devices:
    print(f" - {d}")