File size: 4,542 Bytes
db06ffa
 
 
 
 
 
 
 
 
fa2127b
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa2127b
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa2127b
db06ffa
 
 
 
 
 
 
 
 
 
 
fa2127b
 
 
 
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa2127b
db06ffa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""GPU runtime and Hugging Face Spaces status helpers."""

from __future__ import annotations

from dataclasses import dataclass, field
import os
from typing import Any

from zsgdp.gpu.model_server import GPUModelConfig
from zsgdp.gpu.zero_gpu import is_zero_gpu_available
from zsgdp.utils import to_plain_data


@dataclass(slots=True)
class GPURuntimeStatus:
    """Point-in-time snapshot of the GPU runtime and GPU-related settings.

    Instances are produced by ``collect_gpu_runtime_status`` and can be turned
    into plain data with ``to_dict``. Because this is a dataclass, the
    positional constructor follows field declaration order, so fields must not
    be reordered without updating callers.
    """

    # Deployment identity, copied from GPUModelConfig and the ``deployment``
    # config section by collect_gpu_runtime_status.
    provider: str
    backend: str
    space_name: str
    gpu_models_target: str
    # Environment detection: True when SPACE_ID or SPACE_HOST is set; space_id and
    # hardware are the raw environment values (None when unset).
    running_on_huggingface_space: bool
    space_id: str | None
    hardware: str | None
    # Preferred torch device string ("cuda", "mps" or "cpu") as chosen by
    # _preferred_device from the torch probe below.
    device: str
    # PyTorch probe results; names mirror the keys returned by _torch_status.
    torch_available: bool
    torch_version: str | None = None
    cuda_available: bool = False
    cuda_device_count: int = 0
    cuda_devices: list[str] = field(default_factory=list)
    mps_available: bool = False
    # Per-document GPU limits read from the ``gpu`` config section (or
    # GPUModelConfig for max_batch_size). Units are presumably as named
    # (seconds / call count) — TODO confirm against the consumers.
    batch_pages: bool = True
    max_batch_size: int = 4
    max_gpu_seconds_per_doc: float = 120.0
    max_vlm_calls_per_doc: int = 30
    # Raw copy of the ``gpu.models`` config mapping.
    configured_models: dict[str, Any] = field(default_factory=dict)
    # Result of is_zero_gpu_available() at collection time.
    zero_gpu_available: bool = False
    # Human-readable diagnostics about the detected environment.
    notes: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return this status as plain data via ``zsgdp.utils.to_plain_data``.

        Presumably a recursive conversion to builtin types suitable for JSON
        output — see ``to_plain_data`` to confirm.
        """
        return to_plain_data(self)


def collect_gpu_runtime_status(config: dict[str, Any]) -> GPURuntimeStatus:
    """Snapshot the GPU runtime environment together with the configured GPU limits.

    Combines three sources:

    * the parsed project ``config`` (its ``gpu`` and ``deployment`` sections,
      plus whatever ``GPUModelConfig.from_config`` reads),
    * Hugging Face Spaces environment variables (``SPACE_ID``, ``SPACE_HOST``,
      ``SPACE_HARDWARE`` / ``HF_SPACE_HARDWARE``),
    * a live PyTorch probe (CUDA / MPS availability) and the ZeroGPU SDK check.

    Args:
        config: Parsed configuration mapping. A missing ``gpu``/``deployment``
            section, or one explicitly set to ``None`` (e.g. an empty ``gpu:``
            key in YAML), is treated as an empty mapping; likewise a null
            ``gpu.models``.

    Returns:
        A populated ``GPURuntimeStatus`` whose ``notes`` list carries
        human-readable hints about the detected environment.
    """
    # ``or {}`` rather than a ``.get`` default: a key that is present but null
    # would otherwise yield None and crash on the ``.get`` calls below.
    gpu = config.get("gpu") or {}
    deployment = config.get("deployment") or {}
    model_config = GPUModelConfig.from_config(config)
    torch_status = _torch_status()
    space_id = os.environ.get("SPACE_ID")
    running_on_space = bool(space_id or os.environ.get("SPACE_HOST"))
    hardware = os.environ.get("SPACE_HARDWARE") or os.environ.get("HF_SPACE_HARDWARE")
    device = _preferred_device(torch_status)

    zero_gpu = is_zero_gpu_available()
    notes: list[str] = []
    if not running_on_space:
        notes.append("Hugging Face Spaces environment variables were not detected; this looks like a local run.")
    if not torch_status.get("torch_available"):
        # Without torch no accelerator probe ran at all, so saying PyTorch
        # "detected" nothing would be misleading.
        notes.append("PyTorch could not be imported; accelerator detection was skipped.")
    elif device == "cpu":
        notes.append("No CUDA or MPS accelerator was detected by PyTorch.")
    elif device == "cuda":
        notes.append("CUDA accelerator detected.")
    elif device == "mps":
        notes.append("Apple MPS accelerator detected.")
    if model_config.provider == "huggingface_spaces" and not hardware:
        notes.append("No Space hardware label was found; set hardware in the Space settings for GPU deployment.")
    if zero_gpu:
        notes.append("ZeroGPU SDK detected — H200 slots will be allocated per @spaces.GPU call.")
    elif running_on_space and (hardware or "").lower().startswith("zero"):
        # Hardware label says ZeroGPU, but the SDK probe failed.
        notes.append("Hardware reports ZeroGPU but the `spaces` SDK was not importable; install via the Space's requirements.txt.")

    return GPURuntimeStatus(
        provider=model_config.provider,
        backend=model_config.backend,
        space_name=model_config.space_name,
        gpu_models_target=str(deployment.get("gpu_models_target", model_config.space_name)),
        running_on_huggingface_space=running_on_space,
        space_id=space_id,
        hardware=hardware,
        device=device,
        batch_pages=bool(gpu.get("batch_pages", True)),
        max_batch_size=model_config.max_batch_size,
        max_gpu_seconds_per_doc=float(gpu.get("max_gpu_seconds_per_doc", 120)),
        max_vlm_calls_per_doc=int(gpu.get("max_vlm_calls_per_doc", 30)),
        configured_models=dict(gpu.get("models") or {}),
        zero_gpu_available=zero_gpu,
        notes=notes,
        # Supplies torch_available, torch_version, cuda_* and mps_available.
        **torch_status,
    )


def _torch_status() -> dict[str, Any]:
    try:
        import torch  # type: ignore
    except Exception:
        return {
            "torch_available": False,
            "torch_version": None,
            "cuda_available": False,
            "cuda_device_count": 0,
            "cuda_devices": [],
            "mps_available": False,
        }

    cuda_available = bool(torch.cuda.is_available())
    cuda_device_count = int(torch.cuda.device_count()) if cuda_available else 0
    cuda_devices = [torch.cuda.get_device_name(index) for index in range(cuda_device_count)]
    mps_available = bool(getattr(torch.backends, "mps", None) and torch.backends.mps.is_available())
    return {
        "torch_available": True,
        "torch_version": getattr(torch, "__version__", None),
        "cuda_available": cuda_available,
        "cuda_device_count": cuda_device_count,
        "cuda_devices": cuda_devices,
        "mps_available": mps_available,
    }


def _preferred_device(torch_status: dict[str, Any]) -> str:
    if torch_status.get("cuda_available"):
        return "cuda"
    if torch_status.get("mps_available"):
        return "mps"
    return "cpu"