Spaces:
Runtime error
Runtime error
| import torch | |
| import logging | |
| try: | |
| import onnxruntime as ort | |
| except ImportError: | |
| ort = None | |
| try: | |
| import jax | |
| except ImportError: | |
| jax = None | |
| logger = logging.getLogger(__name__) | |
class GPUStatus:
    """Result of a GPU capability check across several ML frameworks.

    Attributes are plain instance fields:
      - ``torch_ok`` / ``jax_ok`` / ``onnx_ok``: per-framework verdicts.
      - ``device_name``: human-readable name of the detected device.
      - ``ok``: overall verdict, truthy when any framework passed.
    """

    def __init__(self, torch_ok=False, jax_ok=False, onnx_ok=False, device_name="Unknown"):
        self.torch_ok, self.jax_ok, self.onnx_ok = torch_ok, jax_ok, onnx_ok
        self.device_name = device_name
        # Overall verdict is computed once here; it is not refreshed if the
        # individual flags are reassigned later.
        self.ok = self.torch_ok or self.jax_ok or self.onnx_ok

    def summary(self):
        """Return a multi-line, human-readable report of this status."""

        def yes_no(flag):
            return 'YES' if flag else 'NO'

        report_lines = [
            f"GPU Validation: {'PASS' if self.ok else 'FAIL'}",
            f" - PyTorch ROCm: {yes_no(self.torch_ok)}",
            f" - JAX ROCm: {yes_no(self.jax_ok)}",
            f" - ONNX ROCm: {yes_no(self.onnx_ok)}",
            f" - Device: {self.device_name}",
        ]
        return "\n".join(report_lines)
class GPUValidator:
    """Probe the installed ML frameworks for a usable GPU backend.

    Each framework is checked independently and on a best-effort basis: a
    missing package or a failing probe marks only that framework as
    unavailable (and is logged) instead of aborting the whole validation.

    NOTE(review): ``torch.cuda.is_available()`` and JAX's ``"gpu"`` backend
    also report CUDA GPUs, so "ROCm" in the summary really means "GPU" for
    those two; ``torch.version.hip`` could be used to tell them apart if a
    strict ROCm check is wanted — confirm intended semantics.
    """

    def validate(self):
        """Run every framework probe and aggregate the results.

        Returns:
            GPUStatus: availability flags for PyTorch, JAX and ONNX Runtime,
            plus the name of the first GPU found (``"Unknown"`` if none).
        """
        torch_ok, torch_name = self._check_torch()
        jax_ok, jax_name = self._check_jax()
        onnx_ok = self._check_onnx()
        # Prefer PyTorch's device name; fall back to JAX's when PyTorch
        # could not provide one.
        device_name = torch_name or jax_name or "Unknown"
        return GPUStatus(torch_ok, jax_ok, onnx_ok, device_name)

    @staticmethod
    def _check_torch():
        """Return ``(available, device_name_or_None)`` for PyTorch's GPU backend."""
        if not torch.cuda.is_available():
            return False, None
        # Querying the name may initialise the device runtime, which can
        # still fail; keep the availability verdict but don't let the name
        # lookup crash the whole validation.
        try:
            return True, torch.cuda.get_device_name(0)
        except Exception:
            logger.warning(
                "PyTorch reports a GPU but its device name could not be read",
                exc_info=True,
            )
            return True, None

    @staticmethod
    def _check_jax():
        """Return ``(available, device_kind_or_None)`` for JAX's GPU backend."""
        if jax is None:
            return False, None
        try:
            # Ask for the GPU backend explicitly: a bare jax.devices() returns
            # the default backend, which falls back to CPU, so it is never
            # empty and the old ``len(jax.devices()) > 0`` check always passed.
            gpu_devices = jax.devices("gpu")
        except RuntimeError:
            # Typically raised when no GPU backend is present/initialisable.
            logger.debug("JAX found no GPU backend", exc_info=True)
            return False, None
        except Exception:
            logger.warning("JAX GPU probe failed", exc_info=True)
            return False, None
        if not gpu_devices:
            return False, None
        return True, getattr(gpu_devices[0], "device_kind", None)

    @staticmethod
    def _check_onnx():
        """Return True if onnxruntime exposes the ROCm execution provider."""
        if ort is None:
            return False
        try:
            return "ROCmExecutionProvider" in ort.get_available_providers()
        except Exception:
            logger.warning("onnxruntime provider query failed", exc_info=True)
            return False