# Synesthesia/runtime/gpu_validator.py
# Ashiedu's picture
# Sync unified workbench
# 0490201 verified
import torch
import logging
try:
import onnxruntime as ort
except ImportError:
ort = None
try:
import jax
except ImportError:
jax = None
logger = logging.getLogger(__name__)
class GPUStatus:
    """Outcome of a GPU backend check across PyTorch, JAX and ONNX Runtime.

    Attributes:
        torch_ok: Flag reported on the "PyTorch ROCm" line of the summary.
        jax_ok: Flag reported on the "JAX ROCm" line of the summary.
        onnx_ok: Flag reported on the "ONNX ROCm" line of the summary.
        device_name: Name shown on the "Device" line of the summary.
        ok: Truthy when at least one of the three flags is truthy.
    """

    def __init__(self, torch_ok=False, jax_ok=False, onnx_ok=False, device_name="Unknown"):
        """Store the per-framework flags and derive the overall result."""
        self.device_name = device_name
        self.torch_ok = torch_ok
        self.jax_ok = jax_ok
        self.onnx_ok = onnx_ok
        # Plain `or` chain: `ok` takes the first truthy flag (or the last
        # flag when none is truthy), so one working backend is enough.
        self.ok = torch_ok or jax_ok or onnx_ok

    def summary(self):
        """Return a multi-line, human-readable report of the check."""

        def yes_no(flag):
            # Render a flag the way the report expects it.
            return "YES" if flag else "NO"

        verdict = "PASS" if self.ok else "FAIL"
        report_lines = [
            f"GPU Validation: {verdict}",
            f" - PyTorch ROCm: {yes_no(self.torch_ok)}",
            f" - JAX ROCm: {yes_no(self.jax_ok)}",
            f" - ONNX ROCm: {yes_no(self.onnx_ok)}",
            f" - Device: {self.device_name}",
        ]
        return "\n".join(report_lines)
class GPUValidator:
    """Check whether PyTorch, JAX and ONNX Runtime can each see a GPU.

    JAX and ONNX Runtime are optional: the module binds ``jax`` / ``ort`` to
    ``None`` when they cannot be imported, and the matching probe then
    reports ``False``.
    """

    # Platform names accepted as "GPU" when inspecting JAX devices.
    # NOTE(review): the exact string JAX reports for ROCm devices may vary
    # by JAX version ("gpu" vs "rocm") -- confirm on the target install.
    _JAX_GPU_PLATFORMS = frozenset({"gpu", "cuda", "rocm"})

    def validate(self):
        """Probe every framework and return the combined result.

        Returns:
            GPUStatus: per-framework flags plus the name of PyTorch device 0,
            or "Unknown" when PyTorch reports no device.
        """
        # NOTE(review): torch.cuda.is_available() does not by itself tell a
        # ROCm build from a CUDA build; check torch.version.hip if the
        # distinction matters to callers.
        torch_ok = torch.cuda.is_available()
        device_name = torch.cuda.get_device_name(0) if torch_ok else "Unknown"
        jax_ok = self._probe_jax()
        onnx_ok = self._probe_onnx()
        return GPUStatus(torch_ok, jax_ok, onnx_ok, device_name)

    def _probe_jax(self):
        """Return True when JAX exposes at least one GPU device."""
        if jax is None:
            return False
        try:
            # jax.devices() lists the default backend's devices, which is the
            # CPU when no accelerator is present. Counting devices would
            # therefore pass on a CPU-only host; inspect the platform instead.
            return any(
                device.platform in self._JAX_GPU_PLATFORMS
                for device in jax.devices()
            )
        except Exception:
            # Best-effort probe: a broken backend must not abort validation,
            # but the failure is recorded instead of silently dropped.
            logger.warning(
                "JAX device query failed; treating JAX as unavailable",
                exc_info=True,
            )
            return False

    def _probe_onnx(self):
        """Return True when ONNX Runtime lists the ROCm execution provider."""
        if ort is None:
            return False
        try:
            return "ROCmExecutionProvider" in ort.get_available_providers()
        except Exception:
            # Same best-effort policy as the JAX probe.
            logger.warning(
                "ONNX Runtime provider query failed; treating ONNX as unavailable",
                exc_info=True,
            )
            return False