File size: 4,614 Bytes
722bda8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
"""Shared RunPod GPU runtime preflight."""

from __future__ import annotations

import argparse
import subprocess
import sys
from typing import Any


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Validate RunPod GPU runtime.")
    parser.add_argument(
        "--context",
        default="RunPod",
        help="Short label printed in error messages.",
    )
    return parser.parse_args(argv)


def detect_gpu_visibility() -> bool:
    """Report whether ``nvidia-smi`` can see at least one GPU.

    Runs ``nvidia-smi --query-gpu=name --format=csv,noheader`` with a 10 second
    timeout. Any exception while launching or waiting for it is treated as
    "not visible"; otherwise the answer is True only for a zero exit status
    together with non-blank output.
    """
    command = ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]
    try:
        completed = subprocess.run(command, capture_output=True, text=True, check=False, timeout=10)
    except Exception:
        # Best-effort probe: a missing binary, timeout, etc. simply means "no GPU seen".
        return False
    if completed.returncode != 0:
        return False
    return completed.stdout.strip() != ""


def probe_torch(torch_module: Any | None = None) -> dict[str, Any]:
    torch = torch_module
    if torch is None:
        import torch as torch  # type: ignore[no-redef]

    cuda_available = bool(torch.cuda.is_available())
    device_count = int(torch.cuda.device_count())
    probe: dict[str, Any] = {
        "torch_version": str(torch.__version__),
        "cuda_version": getattr(torch.version, "cuda", None),
        "cuda_available": cuda_available,
        "device_count": device_count,
        "device_name": None,
        "total_memory_gb": None,
        "capability_tag": None,
        "supported_arches": [],
        "smoke_error": None,
    }

    if not cuda_available or device_count <= 0:
        return probe

    major, minor = torch.cuda.get_device_capability(0)
    props = torch.cuda.get_device_properties(0)
    supported_arches = []
    if hasattr(torch.cuda, "get_arch_list"):
        try:
            supported_arches = list(torch.cuda.get_arch_list())
        except Exception:
            supported_arches = []

    probe.update(
        {
            "device_name": torch.cuda.get_device_name(0),
            "total_memory_gb": round(props.total_memory / 1e9, 1),
            "capability_tag": f"sm_{major}{minor}",
            "supported_arches": supported_arches,
        }
    )

    try:
        sample = torch.tensor([1.0], device="cuda")
        sample = sample + 1
        _ = float(sample.sum().item())
        torch.cuda.synchronize()
    except Exception as exc:  # pragma: no cover - exercised via runtime
        probe["smoke_error"] = str(exc)

    return probe


def evaluate_runtime(*, gpu_visible: bool, probe: dict[str, Any]) -> tuple[bool, str | None]:
    if gpu_visible and not probe["cuda_available"]:
        return False, "GPU is visible to nvidia-smi but PyTorch CUDA is unavailable"

    if probe["cuda_available"] and probe["capability_tag"] and probe["supported_arches"]:
        if probe["capability_tag"] not in probe["supported_arches"]:
            supported = " ".join(probe["supported_arches"])
            return (
                False,
                f"GPU capability {probe['capability_tag']} is not supported by this PyTorch build "
                f"(supported: {supported})",
            )

    if probe["smoke_error"]:
        return False, f"CUDA smoke test failed: {probe['smoke_error']}"

    return True, None


def print_probe(*, gpu_visible: bool, probe: dict[str, Any]) -> None:
    """Print the collected runtime facts as indented ``label: value`` lines.

    The five summary lines are always printed. The GPU, arch-list and
    smoke-error lines appear only when ``device_name``, ``supported_arches``
    and ``smoke_error`` respectively are truthy.
    """
    summary_fields = (
        ("torch", "torch_version"),
        ("torch.version.cuda", "cuda_version"),
        ("CUDA available", "cuda_available"),
        ("device_count", "device_count"),
    )
    for label, key in summary_fields:
        print(f"  {label}: {probe[key]}")
    print(f"  nvidia-smi GPU visible: {gpu_visible}")

    name = probe["device_name"]
    if name:
        print(f"  GPU: {name}, VRAM: {probe['total_memory_gb']} GB, capability: {probe['capability_tag']}")

    arches = probe["supported_arches"]
    if arches:
        joined = " ".join(arches)
        print(f"  PyTorch CUDA arch list: {joined}")

    failure = probe["smoke_error"]
    if failure:
        print(f"  CUDA smoke error: {failure}")


def main(argv: list[str] | None = None) -> int:
    """Run the preflight end to end and return a process exit code.

    Prints the probe summary, then returns 0 if ``evaluate_runtime`` accepts
    the runtime. Otherwise it prints the reason plus a hint labelled with the
    ``--context`` value and returns 1.
    """
    options = parse_args(argv)
    visible = detect_gpu_visibility()
    facts = probe_torch()
    print_probe(gpu_visible=visible, probe=facts)

    usable, reason = evaluate_runtime(gpu_visible=visible, probe=facts)
    if usable:
        return 0

    print(f"  ERROR: {reason}")
    print(
        f"  {options.context} GPU runtime is not usable with the current PyTorch/CUDA stack. "
        "Use a supported NVIDIA architecture or a newer compatible template."
    )
    return 1


if __name__ == "__main__":
    # Propagate main()'s return value (0 = usable, 1 = not usable) as the exit status.
    sys.exit(main())