Spaces:
Runtime error
Runtime error
File size: 1,440 Bytes
0490201 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 | import torch
import logging
# Optional backends: each import is attempted independently, and the name is
# bound to None when the package is not installed. Later code can then test
# availability with a plain truthiness check (``if ort:`` / ``if jax:``).
try:
    import onnxruntime as ort
except ImportError:
    ort = None
try:
    import jax
except ImportError:
    jax = None
# Module-level logger.
# NOTE(review): not referenced by the code visible in this file; confirm
# whether it is still needed.
logger = logging.getLogger(__name__)
class GPUStatus:
    """Snapshot of which GPU-capable backends reported a usable device.

    Attributes:
        torch_ok: whether the PyTorch probe succeeded.
        jax_ok: whether the JAX probe succeeded.
        onnx_ok: whether the ONNX Runtime probe succeeded.
        device_name: human-readable device label ("Unknown" by default).
        ok: the ``or`` of the three flags, computed once at construction.
    """

    def __init__(self, torch_ok=False, jax_ok=False, onnx_ok=False, device_name="Unknown"):
        self.torch_ok = torch_ok
        self.jax_ok = jax_ok
        self.onnx_ok = onnx_ok
        self.device_name = device_name
        # Evaluated once; reassigning a flag later does not refresh ``ok``.
        self.ok = torch_ok or jax_ok or onnx_ok

    def summary(self):
        """Return a multi-line, human-readable report of the probe results."""

        def yes_no(flag):
            # Render a truthy/falsy flag as the report's YES/NO marker.
            return "YES" if flag else "NO"

        if self.ok:
            verdict = "PASS"
        else:
            verdict = "FAIL"

        report_lines = [
            "GPU Validation: " + verdict,
            " - PyTorch ROCm: " + yes_no(self.torch_ok),
            " - JAX ROCm: " + yes_no(self.jax_ok),
            " - ONNX ROCm: " + yes_no(self.onnx_ok),
            f" - Device: {self.device_name}",
        ]
        return "\n".join(report_lines)
class GPUValidator:
    """Probe PyTorch, JAX and ONNX Runtime for a usable GPU device."""

    # Same logger name as a module-level ``logging.getLogger(__name__)``.
    # It is bound on the class so this class needs no extra module globals.
    _logger = logging.getLogger(__name__)

    def validate(self):
        """Check each backend and return a ``GPUStatus`` snapshot.

        Returns:
            GPUStatus: per-backend availability flags, plus the PyTorch name
            of device 0 (``"Unknown"`` when PyTorch reports no device).

        Failures while probing the optional JAX / ONNX Runtime backends are
        logged at DEBUG level and reported as "not available" rather than
        raised.
        """
        torch_ok = torch.cuda.is_available()
        # ROCm (HIP) builds of PyTorch expose their devices via torch.cuda.*.
        device_name = torch.cuda.get_device_name(0) if torch_ok else "Unknown"
        # Arguments evaluate left to right, so the probe order (torch, jax,
        # onnx) matches the original implementation.
        return GPUStatus(torch_ok, self._probe_jax(), self._probe_onnx(), device_name)

    def _probe_jax(self):
        """Return True if JAX exposes at least one GPU device."""
        if not jax:
            return False
        try:
            # Bare jax.devices() falls back to CPU devices when no accelerator
            # backend is present, which made this check pass on GPU-less
            # hosts. Ask for the GPU backend explicitly instead; it raises
            # RuntimeError when that backend is unavailable.
            return len(jax.devices("gpu")) > 0
        except Exception:
            self._logger.debug("JAX GPU probe failed", exc_info=True)
            return False

    def _probe_onnx(self):
        """Return True if ONNX Runtime offers the ROCm execution provider."""
        if not ort:
            return False
        try:
            return "ROCmExecutionProvider" in ort.get_available_providers()
        except Exception:
            self._logger.debug("ONNX Runtime provider probe failed", exc_info=True)
            return False
|