"""Sanity-check that JAX can see a GPU.

Prints the JAX version and the local devices. The process exit status is
0 when at least one device reports platform ``'gpu'``. It is 1 otherwise,
including when querying the devices raises.
"""
import sys

import jax


def main(devices=None):
    """Report JAX's local devices and return a process exit status.

    Args:
        devices: Optional iterable of device-like objects (anything with a
            ``platform`` attribute). When ``None`` (the default),
            ``jax.local_devices()`` is queried.

    Returns:
        0 if at least one device has ``platform == 'gpu'``; 1 if none does,
        or if obtaining the device list raised an exception.
    """
    print(f"JAX version: {jax.__version__}")

    # Only the device query is guarded: backend initialisation problems
    # surface here, and they should be reported rather than crash the check.
    try:
        local_devices = jax.local_devices() if devices is None else list(devices)
    except Exception as e:
        print(f"\nError checking devices: {e}")
        return 1

    print(f"Local devices: {local_devices}")

    # Collect the GPU devices once and reuse the list for the success report.
    gpus = [d for d in local_devices if d.platform == "gpu"]
    if not gpus:
        print("\nFAILURE: No GPU detected. JAX is running on CPU.")
        print("Please verify CUDA installation/environment.")
        return 1

    print("\nSUCCESS: GPU detected!")
    # Print details about the GPU(s).
    for d in gpus:
        print(f" - {d}")
    return 0


if __name__ == "__main__":
    sys.exit(main())