"""Check that JAX can see at least one local GPU device.

Prints the installed JAX version and the list of local devices. If any local
device reports platform ``"gpu"``, prints a success message followed by each
GPU device. Otherwise prints a failure message and exits with status 1. An
exception raised while enumerating devices is printed and also causes exit
status 1.

NOTE: the check runs at import time (there is no ``__main__`` guard), so this
file is meant to be executed directly as a script.
"""
import sys

import jax

print(f"JAX version: {jax.__version__}")

try:
    # Only device enumeration is guarded. Presumably this is the call that
    # fails when the JAX backend cannot be initialised.
    local_devices = jax.local_devices()
except Exception as e:  # top-level diagnostic boundary: report and fail
    print(f"\nError checking devices: {e}")
    sys.exit(1)

print(f"Local devices: {local_devices}")

# Filter once and reuse, instead of scanning the device list twice.
gpu_devices = [d for d in local_devices if d.platform == "gpu"]

if not gpu_devices:
    # Report the platform(s) actually visible rather than assuming CPU. On a
    # CPU-only host this yields exactly "CPU", matching the previous message.
    platforms = ", ".join(sorted({d.platform for d in local_devices})).upper() or "CPU"
    print(f"\nFAILURE: No GPU detected. JAX is running on {platforms}.")
    print("Please verify CUDA installation/environment.")
    sys.exit(1)

print("\nSUCCESS: GPU detected!")
# Print details about the GPU(s).
for d in gpu_devices:
    print(f" - {d}")