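"""GPU monitoring utilities for trackio, built on NVML via nvidia-ml-py."""
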
import os
import threading
import warnings
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from trackio.run import Run

pynvml: Any = None
PYNVML_AVAILABLE = False
_nvml_initialized = False
_nvml_lock = threading.Lock()
_energy_baseline: dict[int, float] = {}


def _ensure_pynvml():
    global PYNVML_AVAILABLE, pynvml
    if PYNVML_AVAILABLE:
        return pynvml
    try:
        import pynvml as _pynvml

        pynvml = _pynvml
        PYNVML_AVAILABLE = True
        return pynvml
    except ImportError:
        raise ImportError(
            "nvidia-ml-py is required for GPU monitoring. "
            "Install it with: pip install nvidia-ml-py"
        )


def _init_nvml() -> bool:
    global _nvml_initialized
    with _nvml_lock:
        if _nvml_initialized:
            return True
        try:
            nvml = _ensure_pynvml()
            nvml.nvmlInit()
            _nvml_initialized = True
            return True
        except Exception:
            return False


def _shutdown_nvml():
    global _nvml_initialized
    with _nvml_lock:
        if _nvml_initialized and pynvml is not None:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass
            _nvml_initialized = False


def get_gpu_count() -> tuple[int, list[int]]:
    """
    Get the number of GPUs visible to this process and their physical indices.
    Respects the CUDA_VISIBLE_DEVICES environment variable.

    Returns:
        Tuple of (count, physical_indices) where:
        - count: Number of visible GPUs
        - physical_indices: List mapping each logical index to a physical GPU index.
          e.g., with CUDA_VISIBLE_DEVICES=2,3 this returns (2, [2, 3]),
          meaning logical GPU 0 is physical GPU 2 and logical GPU 1 is physical GPU 3.
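
    Example:
        ```python
        count, physical = get_gpu_count()
        for logical, phys in enumerate(physical):
            print(f"CUDA device {logical} -> physical GPU {phys}")
        ```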
    """
    if not _init_nvml():
        return 0, []

    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cuda_visible is not None and cuda_visible.strip():
        try:
            indices = [int(x.strip()) for x in cuda_visible.split(",") if x.strip()]
            return len(indices), indices
        except ValueError:
            pass

    try:
        total = pynvml.nvmlDeviceGetCount()
        return total, list(range(total))
    except Exception:
        return 0, []


def gpu_available() -> bool:
    """
    Check if GPU monitoring is available.

    Returns True if nvidia-ml-py is installed and at least one NVIDIA GPU is detected.
    This is used for auto-detection of GPU logging.
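
    Example:
        ```python
        if gpu_available():
            metrics = collect_gpu_metrics()
        ```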
    """
    try:
        _ensure_pynvml()
        count, _ = get_gpu_count()
        return count > 0
    except ImportError:
        return False
    except Exception:
        return False


def reset_energy_baseline():
    """Reset the energy baseline for all GPUs. Called when a new run starts."""
    global _energy_baseline
    _energy_baseline = {}


def collect_gpu_metrics(device: int | None = None) -> dict:
    """
    Collect GPU metrics for visible GPUs.

    Args:
        device: CUDA device index to collect metrics from. If None, collects
                from all GPUs visible to this process (respects CUDA_VISIBLE_DEVICES).
                The device index is the logical CUDA index (0, 1, 2...), not the
                physical GPU index.

    Returns:
        Dictionary of GPU metrics. Keys use logical device indices (gpu/0/, gpu/1/, etc.),
        which correspond to CUDA device indices, not physical GPU indices.
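
    Example:
        ```python
        metrics = collect_gpu_metrics()          # all visible GPUs
        metrics = collect_gpu_metrics(device=0)  # only CUDA device 0
        # Keys look like "gpu/0/utilization", "gpu/0/allocated_memory", ...
        ```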
    """
    if not _init_nvml():
        return {}

    gpu_count, visible_gpus = get_gpu_count()
    if gpu_count == 0:
        return {}

    if device is not None:
        if device < 0 or device >= gpu_count:
            return {}
        gpu_indices = [(device, visible_gpus[device])]
    else:
        gpu_indices = list(enumerate(visible_gpus))

    metrics = {}
    total_util = 0.0
    total_mem_used_gib = 0.0
    total_power = 0.0
    max_temp = 0.0
    valid_util_count = 0

    for logical_idx, physical_idx in gpu_indices:
        prefix = f"gpu/{logical_idx}"
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_idx)

            try:
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                metrics[f"{prefix}/utilization"] = util.gpu
                metrics[f"{prefix}/memory_utilization"] = util.memory
                total_util += util.gpu
                valid_util_count += 1
            except Exception:
                pass

            try:
                mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                mem_used_gib = mem.used / (1024**3)
                mem_total_gib = mem.total / (1024**3)
                metrics[f"{prefix}/allocated_memory"] = mem_used_gib
                metrics[f"{prefix}/total_memory"] = mem_total_gib
                if mem.total > 0:
                    metrics[f"{prefix}/memory_usage"] = mem.used / mem.total
                total_mem_used_gib += mem_used_gib
            except Exception:
                pass

            try:
                power_mw = pynvml.nvmlDeviceGetPowerUsage(handle)
                power_w = power_mw / 1000.0
                metrics[f"{prefix}/power"] = power_w
                total_power += power_w
            except Exception:
                pass

            try:
                power_limit_mw = pynvml.nvmlDeviceGetPowerManagementLimit(handle)
                power_limit_w = power_limit_mw / 1000.0
                metrics[f"{prefix}/power_limit"] = power_limit_w
                if power_limit_w > 0 and f"{prefix}/power" in metrics:
                    metrics[f"{prefix}/power_percent"] = (
                        metrics[f"{prefix}/power"] / power_limit_w
                    ) * 100
            except Exception:
                pass

            try:
                temp = pynvml.nvmlDeviceGetTemperature(
                    handle, pynvml.NVML_TEMPERATURE_GPU
                )
                metrics[f"{prefix}/temp"] = temp
                max_temp = max(max_temp, temp)
            except Exception:
                pass

            try:
                sm_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)
                metrics[f"{prefix}/sm_clock"] = sm_clock
            except Exception:
                pass

            try:
                mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM)
                metrics[f"{prefix}/memory_clock"] = mem_clock
            except Exception:
                pass

            try:
                fan_speed = pynvml.nvmlDeviceGetFanSpeed(handle)
                metrics[f"{prefix}/fan_speed"] = fan_speed
            except Exception:
                pass

            try:
                pstate = pynvml.nvmlDeviceGetPerformanceState(handle)
                metrics[f"{prefix}/performance_state"] = pstate
            except Exception:
                pass

            try:
                # NVML reports cumulative energy in millijoules; log the joules
                # consumed since the baseline captured when the run started.
                energy_mj = pynvml.nvmlDeviceGetTotalEnergyConsumption(handle)
                if logical_idx not in _energy_baseline:
                    _energy_baseline[logical_idx] = energy_mj
                energy_consumed_mj = energy_mj - _energy_baseline[logical_idx]
                metrics[f"{prefix}/energy_consumed"] = energy_consumed_mj / 1000.0
            except Exception:
                pass

            try:
                # NVML reports PCIe throughput in KB/s; convert to MB/s.
                pcie_tx = pynvml.nvmlDeviceGetPcieThroughput(
                    handle, pynvml.NVML_PCIE_UTIL_TX_BYTES
                )
                pcie_rx = pynvml.nvmlDeviceGetPcieThroughput(
                    handle, pynvml.NVML_PCIE_UTIL_RX_BYTES
                )
                metrics[f"{prefix}/pcie_tx"] = pcie_tx / 1024.0
                metrics[f"{prefix}/pcie_rx"] = pcie_rx / 1024.0
            except Exception:
                pass

            try:
                throttle = pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(handle)
                metrics[f"{prefix}/throttle_thermal"] = int(
                    bool(throttle & pynvml.nvmlClocksThrottleReasonSwThermalSlowdown)
                )
                metrics[f"{prefix}/throttle_power"] = int(
                    bool(throttle & pynvml.nvmlClocksThrottleReasonSwPowerCap)
                )
                metrics[f"{prefix}/throttle_hw_slowdown"] = int(
                    bool(throttle & pynvml.nvmlClocksThrottleReasonHwSlowdown)
                )
                metrics[f"{prefix}/throttle_apps"] = int(
                    bool(
                        throttle
                        & pynvml.nvmlClocksThrottleReasonApplicationsClocksSetting
                    )
                )
            except Exception:
                pass

            try:
                ecc_corrected = pynvml.nvmlDeviceGetTotalEccErrors(
                    handle,
                    pynvml.NVML_MEMORY_ERROR_TYPE_CORRECTED,
                    pynvml.NVML_VOLATILE_ECC,
                )
                metrics[f"{prefix}/corrected_memory_errors"] = ecc_corrected
            except Exception:
                pass

            try:
                ecc_uncorrected = pynvml.nvmlDeviceGetTotalEccErrors(
                    handle,
                    pynvml.NVML_MEMORY_ERROR_TYPE_UNCORRECTED,
                    pynvml.NVML_VOLATILE_ECC,
                )
                metrics[f"{prefix}/uncorrected_memory_errors"] = ecc_uncorrected
            except Exception:
                pass

        except Exception:
            continue

    if valid_util_count > 0:
        metrics["gpu/mean_utilization"] = total_util / valid_util_count
    if total_mem_used_gib > 0:
        metrics["gpu/total_allocated_memory"] = total_mem_used_gib
    if total_power > 0:
        metrics["gpu/total_power"] = total_power
    if max_temp > 0:
        metrics["gpu/max_temp"] = max_temp

    return metrics


class GpuMonitor:
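    """
    Background thread that periodically collects GPU metrics and logs them
    to a run via `Run.log_system`.

    Example (sketch; assumes `run` was returned by `trackio.init()`):
        ```python
        monitor = GpuMonitor(run, interval=10.0)
        monitor.start()
        # ... training loop ...
        monitor.stop()
        ```
    """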
    def __init__(self, run: "Run", interval: float = 10.0):
        self._run = run
        self._interval = interval
        self._stop_flag = threading.Event()
        self._thread: "threading.Thread | None" = None

    def start(self):
        count, _ = get_gpu_count()
        if count == 0:
            warnings.warn(
                "auto_log_gpu=True but no NVIDIA GPUs detected. GPU logging disabled."
            )
            return

        reset_energy_baseline()
        self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._thread.start()

    def stop(self):
        self._stop_flag.set()
        if self._thread is not None:
            self._thread.join(timeout=2.0)

    def _monitor_loop(self):
        while not self._stop_flag.is_set():
            try:
                metrics = collect_gpu_metrics()
                if metrics:
                    self._run.log_system(metrics)
            except Exception:
                pass

            self._stop_flag.wait(timeout=self._interval)


def log_gpu(run: "Run | None" = None, device: int | None = None) -> dict:
    """
    Log GPU metrics to the current or specified run as system metrics.

    Args:
        run: Optional Run instance. If None, uses current run from context.
        device: CUDA device index to collect metrics from. If None, collects
                from all GPUs visible to this process (respects CUDA_VISIBLE_DEVICES).

    Returns:
        dict: The GPU metrics that were logged.

    Example:
        ```python
        import trackio

        run = trackio.init(project="my-project")
        trackio.log({"loss": 0.5})
        trackio.log_gpu()  # logs all visible GPUs
        trackio.log_gpu(device=0)  # logs only CUDA device 0
        ```
    """
    from trackio import context_vars

    if run is None:
        run = context_vars.current_run.get()
        if run is None:
            raise RuntimeError("Call trackio.init() before trackio.log_gpu().")

    metrics = collect_gpu_metrics(device=device)
    if metrics:
        run.log_system(metrics)
    return metrics