File size: 4,614 Bytes
722bda8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | #!/usr/bin/env python3
"""Shared RunPod GPU runtime preflight."""
from __future__ import annotations

import argparse
import re
import subprocess
import sys
from typing import Any
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Validate RunPod GPU runtime.")
parser.add_argument(
"--context",
default="RunPod",
help="Short label printed in error messages.",
)
return parser.parse_args(argv)
def detect_gpu_visibility() -> bool:
    """Return True when ``nvidia-smi`` succeeds and prints at least one GPU name.

    Any exception while launching or waiting for the command (missing binary,
    10-second timeout, ...) is treated as "no GPU visible" rather than raised.
    """
    command = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]
    try:
        completed = subprocess.run(
            command,
            check=False,
            text=True,
            capture_output=True,
            timeout=10,
        )
    except Exception:
        return False
    if completed.returncode != 0:
        return False
    # Exit status 0 alone is not enough; require some non-whitespace output.
    return completed.stdout.strip() != ""
def probe_torch(torch_module: Any | None = None) -> dict[str, Any]:
torch = torch_module
if torch is None:
import torch as torch # type: ignore[no-redef]
cuda_available = bool(torch.cuda.is_available())
device_count = int(torch.cuda.device_count())
probe: dict[str, Any] = {
"torch_version": str(torch.__version__),
"cuda_version": getattr(torch.version, "cuda", None),
"cuda_available": cuda_available,
"device_count": device_count,
"device_name": None,
"total_memory_gb": None,
"capability_tag": None,
"supported_arches": [],
"smoke_error": None,
}
if not cuda_available or device_count <= 0:
return probe
major, minor = torch.cuda.get_device_capability(0)
props = torch.cuda.get_device_properties(0)
supported_arches = []
if hasattr(torch.cuda, "get_arch_list"):
try:
supported_arches = list(torch.cuda.get_arch_list())
except Exception:
supported_arches = []
probe.update(
{
"device_name": torch.cuda.get_device_name(0),
"total_memory_gb": round(props.total_memory / 1e9, 1),
"capability_tag": f"sm_{major}{minor}",
"supported_arches": supported_arches,
}
)
try:
sample = torch.tensor([1.0], device="cuda")
sample = sample + 1
_ = float(sample.sum().item())
torch.cuda.synchronize()
except Exception as exc: # pragma: no cover - exercised via runtime
probe["smoke_error"] = str(exc)
return probe
def evaluate_runtime(*, gpu_visible: bool, probe: dict[str, Any]) -> tuple[bool, str | None]:
if gpu_visible and not probe["cuda_available"]:
return False, "GPU is visible to nvidia-smi but PyTorch CUDA is unavailable"
if probe["cuda_available"] and probe["capability_tag"] and probe["supported_arches"]:
if probe["capability_tag"] not in probe["supported_arches"]:
supported = " ".join(probe["supported_arches"])
return (
False,
f"GPU capability {probe['capability_tag']} is not supported by this PyTorch build "
f"(supported: {supported})",
)
if probe["smoke_error"]:
return False, f"CUDA smoke test failed: {probe['smoke_error']}"
return True, None
def print_probe(*, gpu_visible: bool, probe: dict[str, Any]) -> None:
    """Print a human-readable summary of a ``probe_torch`` result to stdout.

    The five base lines are always printed. The GPU details, arch list and
    smoke-error lines are printed only when the corresponding probe field is
    truthy.
    """
    lines = [
        f" torch: {probe['torch_version']}",
        f" torch.version.cuda: {probe['cuda_version']}",
        f" CUDA available: {probe['cuda_available']}",
        f" device_count: {probe['device_count']}",
        f" nvidia-smi GPU visible: {gpu_visible}",
    ]
    if probe["device_name"]:
        lines.append(
            f" GPU: {probe['device_name']}, VRAM: {probe['total_memory_gb']} GB, "
            f"capability: {probe['capability_tag']}"
        )
    arches = probe["supported_arches"]
    if arches:
        lines.append(f" PyTorch CUDA arch list: {' '.join(arches)}")
    smoke_error = probe["smoke_error"]
    if smoke_error:
        lines.append(f" CUDA smoke error: {smoke_error}")
    for line in lines:
        print(line)
def main(argv: list[str] | None = None) -> int:
    """Run the GPU runtime preflight.

    Steps: parse args, check nvidia-smi visibility, probe PyTorch/CUDA, print
    the probe summary, then evaluate it.

    Returns:
        0 when the runtime is usable. Otherwise prints the error plus a hint
        (prefixed with ``--context``) and returns 1.
    """
    options = parse_args(argv)
    visible = detect_gpu_visibility()
    results = probe_torch()
    print_probe(gpu_visible=visible, probe=results)
    usable, reason = evaluate_runtime(gpu_visible=visible, probe=results)
    if usable:
        return 0
    print(f" ERROR: {reason}")
    print(
        f" {options.context} GPU runtime is not usable with the current PyTorch/CUDA stack. "
        "Use a supported NVIDIA architecture or a newer compatible template."
    )
    return 1
if __name__ == "__main__":
    # sys.exit raises SystemExit, so main()'s return value becomes the exit status.
    sys.exit(main())
|