File size: 653 Bytes
b71de11 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | import jax
import sys
def gpu_devices(devices):
    """Return the devices in *devices* whose ``platform`` is ``'gpu'``.

    Args:
        devices: An iterable of device objects, e.g. the result of
            ``jax.local_devices()``. Only the ``platform`` attribute is read.

    Returns:
        A list of the matching devices in their original order; empty if
        there are none.
    """
    return [d for d in devices if d.platform == 'gpu']


def main():
    """Print the JAX version and local devices, then check for a GPU.

    Returns:
        Process exit status: 0 if at least one local device reports the
        ``'gpu'`` platform, 1 if none does or if querying the devices raised.
    """
    print(f"JAX version: {jax.__version__}")

    # Guard only the device query, so a failure in the reporting code below
    # is not misreported as a device-check error.
    try:
        local_devices = jax.local_devices()
    except Exception as e:  # top-level boundary: report and fail
        print(f"\nError checking devices: {e}", file=sys.stderr)
        return 1

    print(f"Local devices: {local_devices}")

    gpus = gpu_devices(local_devices)
    if not gpus:
        print("\nFAILURE: No GPU detected. JAX is running on CPU.", file=sys.stderr)
        print("Please verify CUDA installation/environment.", file=sys.stderr)
        return 1

    print("\nSUCCESS: GPU detected!")
    # Print details about the GPU(s)
    for d in gpus:
        print(f" - {d}")
    return 0


# Run only when executed as a script, so importing this module has no side
# effects (in particular, it no longer calls sys.exit on import).
if __name__ == "__main__":
    sys.exit(main())
|