"""Probe PyTorch, JAX and ONNX Runtime for a usable ROCm GPU backend."""

import logging

# Every framework import is optional so that a missing package is reported as
# "NO" by the validator instead of breaking import of this module.
try:
    import torch
except ImportError:
    torch = None

try:
    import onnxruntime as ort
except ImportError:
    ort = None

try:
    import jax
except ImportError:
    jax = None

logger = logging.getLogger(__name__)

# JAX device platform names that denote a GPU accelerator. CPU (and TPU)
# devices are excluded so that JAX's CPU fallback is not mistaken for a GPU.
_JAX_GPU_PLATFORMS = frozenset({"gpu", "rocm"})


class GPUStatus:
    """Outcome of a GPU validation run.

    Attributes:
        torch_ok: PyTorch reported a usable ROCm GPU.
        jax_ok: JAX reported at least one GPU device.
        onnx_ok: ONNX Runtime lists ``ROCmExecutionProvider``.
        device_name: Name reported by PyTorch for device 0, or ``"Unknown"``.
        ok: Truthy when at least one backend passed. Computed once at
            construction; it is not refreshed if the flags are reassigned.
    """

    def __init__(self, torch_ok=False, jax_ok=False, onnx_ok=False,
                 device_name="Unknown"):
        self.torch_ok = torch_ok
        self.jax_ok = jax_ok
        self.onnx_ok = onnx_ok
        self.device_name = device_name
        self.ok = torch_ok or jax_ok or onnx_ok

    def __repr__(self):
        return (f"{type(self).__name__}(torch_ok={self.torch_ok!r}, "
                f"jax_ok={self.jax_ok!r}, onnx_ok={self.onnx_ok!r}, "
                f"device_name={self.device_name!r})")

    def summary(self):
        """Return a multi-line, human-readable PASS/FAIL report."""
        def yes_no(flag):
            return "YES" if flag else "NO"

        status = "PASS" if self.ok else "FAIL"
        return (f"GPU Validation: {status}\n"
                f" - PyTorch ROCm: {yes_no(self.torch_ok)}\n"
                f" - JAX ROCm: {yes_no(self.jax_ok)}\n"
                f" - ONNX ROCm: {yes_no(self.onnx_ok)}\n"
                f" - Device: {self.device_name}")


class GPUValidator:
    """Check each installed framework for a working ROCm GPU backend.

    Each probe is best-effort: a framework that is missing or fails while
    being queried is reported as unavailable (and the failure is logged)
    rather than aborting the whole validation.
    """

    def validate(self):
        """Run all backend probes and return a :class:`GPUStatus`."""
        torch_ok, device_name = self._check_torch()
        jax_ok = self._check_jax()
        onnx_ok = self._check_onnx()
        return GPUStatus(torch_ok, jax_ok, onnx_ok, device_name)

    @staticmethod
    def _check_torch():
        """Return ``(torch_ok, device_name)`` for the PyTorch backend."""
        if torch is None:
            return False, "Unknown"
        try:
            available = torch.cuda.is_available()
        except Exception as exc:
            logger.warning("PyTorch GPU availability check failed: %s", exc)
            return False, "Unknown"
        # ROCm builds of PyTorch expose the GPU through the torch.cuda API;
        # torch.version.hip is None on CUDA and CPU-only builds.
        is_rocm = getattr(getattr(torch, "version", None), "hip", None) is not None
        device_name = "Unknown"
        if available:
            try:
                device_name = torch.cuda.get_device_name(0)
            except Exception as exc:
                logger.warning("Could not read PyTorch device name: %s", exc)
        return bool(available and is_rocm), device_name

    @staticmethod
    def _check_jax():
        """Return True when JAX exposes at least one GPU device."""
        if jax is None:
            return False
        try:
            devices = jax.devices()
        except Exception as exc:
            logger.warning("JAX device query failed: %s", exc)
            return False
        # jax.devices() falls back to CPU devices when no accelerator backend
        # is available, so a non-empty list alone does not prove a GPU.
        return any(getattr(d, "platform", None) in _JAX_GPU_PLATFORMS
                   for d in devices)

    @staticmethod
    def _check_onnx():
        """Return True when ONNX Runtime offers the ROCm execution provider."""
        if ort is None:
            return False
        try:
            providers = ort.get_available_providers()
        except Exception as exc:
            logger.warning("ONNX Runtime provider query failed: %s", exc)
            return False
        return "ROCmExecutionProvider" in providers