Spaces:
Sleeping
Sleeping
fix: patch torchmetrics._apply using CPU probe tensor instead of self.device
The root cause of 'Torch not compiled with CUDA enabled' is that
torchmetrics.Metric._apply does fn(torch.zeros(1, device=self.device))
where self.device='cuda:0' from the checkpoint — even after map_location
moves all tensors to CPU, because _device is an attribute not a tensor.
Fix: replace the CUDA probe with fn(torch.zeros(1, device='cpu')) so
the destination device is inferred safely without touching CUDA hardware.
Patch is applied before pytorch_forecasting is imported so it is in place
when Lightning restores metric state during load_from_checkpoint.
- app/models/tft_predictor.py +32 -3
app/models/tft_predictor.py
CHANGED
|
@@ -57,6 +57,31 @@ def _maybe_download(filename: str, local_path: str) -> bool:
|
|
| 57 |
return False
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def load_model(model_path: str):
|
| 61 |
"""Load and cache the pytorch-forecasting TFT from a Lightning checkpoint."""
|
| 62 |
global _model, _model_path_cached
|
|
@@ -67,11 +92,15 @@ def load_model(model_path: str):
|
|
| 67 |
if not os.path.exists(model_path):
|
| 68 |
raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
from pytorch_forecasting import TemporalFusionTransformer
|
| 71 |
|
| 72 |
-
#
|
| 73 |
-
#
|
| 74 |
-
#
|
| 75 |
model = TemporalFusionTransformer.load_from_checkpoint(
|
| 76 |
model_path,
|
| 77 |
map_location=lambda storage, loc: storage.cpu(),
|
|
|
|
| 57 |
return False
|
| 58 |
|
| 59 |
|
| 60 |
+
def _patch_torchmetrics_cpu():
|
| 61 |
+
"""Patch torchmetrics.Metric._apply to avoid CUDA errors on CPU-only servers.
|
| 62 |
+
|
| 63 |
+
When a GPU-trained checkpoint is loaded on CPU-only hardware, the torchmetrics
|
| 64 |
+
Metric._apply method does `fn(torch.zeros(1, device=self.device))` where
|
| 65 |
+
self.device may still be "cuda:0" from the checkpoint. We replace that with
|
| 66 |
+
a safe CPU probe so the destination device is inferred without touching CUDA.
|
| 67 |
+
"""
|
| 68 |
+
try:
|
| 69 |
+
import torchmetrics
|
| 70 |
+
import torch.nn as nn
|
| 71 |
+
|
| 72 |
+
_orig = torchmetrics.Metric._apply
|
| 73 |
+
|
| 74 |
+
def _safe_apply(self, fn):
|
| 75 |
+
# Probe destination device via a CPU tensor — never touches CUDA.
|
| 76 |
+
self._device = fn(torch.zeros(1, device="cpu")).device
|
| 77 |
+
return nn.Module._apply(self, fn)
|
| 78 |
+
|
| 79 |
+
torchmetrics.Metric._apply = _safe_apply
|
| 80 |
+
print("[tft] torchmetrics._apply patched for CPU-only inference")
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"[tft] torchmetrics patch skipped: {e}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
def load_model(model_path: str):
|
| 86 |
"""Load and cache the pytorch-forecasting TFT from a Lightning checkpoint."""
|
| 87 |
global _model, _model_path_cached
|
|
|
|
| 92 |
if not os.path.exists(model_path):
|
| 93 |
raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
|
| 94 |
|
| 95 |
+
# Patch torchmetrics BEFORE importing pytorch_forecasting so the patched
|
| 96 |
+
# _apply is in place when Lightning restores metric state from the checkpoint.
|
| 97 |
+
_patch_torchmetrics_cpu()
|
| 98 |
+
|
| 99 |
from pytorch_forecasting import TemporalFusionTransformer
|
| 100 |
|
| 101 |
+
# Callable map_location: moves all tensors to CPU AND skips Lightning's
|
| 102 |
+
# isinstance(map_location, (str, torch.device)) branch that would call
|
| 103 |
+
# model.to(map_location) — which would re-trigger the CUDA error.
|
| 104 |
model = TemporalFusionTransformer.load_from_checkpoint(
|
| 105 |
model_path,
|
| 106 |
map_location=lambda storage, loc: storage.cpu(),
|