will702 committed on
Commit
ec4688a
·
1 Parent(s): 49e908a

fix: patch torchmetrics._apply using CPU probe tensor instead of self.device

Browse files

The root cause of 'Torch not compiled with CUDA enabled' is that
torchmetrics.Metric._apply does fn(torch.zeros(1, device=self.device))
where self.device='cuda:0' from the checkpoint — even after map_location
moves all tensors to CPU, because _device is an attribute not a tensor.

Fix: replace the CUDA probe with fn(torch.zeros(1, device='cpu')) so
the destination device is inferred safely without touching CUDA hardware.
Patch is applied before pytorch_forecasting is imported so it is in place
when Lightning restores metric state during load_from_checkpoint.

Files changed (1) hide show
  1. app/models/tft_predictor.py +32 -3
app/models/tft_predictor.py CHANGED
@@ -57,6 +57,31 @@ def _maybe_download(filename: str, local_path: str) -> bool:
57
  return False
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def load_model(model_path: str):
61
  """Load and cache the pytorch-forecasting TFT from a Lightning checkpoint."""
62
  global _model, _model_path_cached
@@ -67,11 +92,15 @@ def load_model(model_path: str):
67
  if not os.path.exists(model_path):
68
  raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
69
 
 
 
 
 
70
  from pytorch_forecasting import TemporalFusionTransformer
71
 
72
- # Checkpoint was saved on GPU (Colab). Using a callable map_location moves
73
- # all tensors to CPU AND skips Lightning's model.to(device) call, which is
74
- # what triggers the torchmetrics CUDA assertion on CPU-only servers.
75
  model = TemporalFusionTransformer.load_from_checkpoint(
76
  model_path,
77
  map_location=lambda storage, loc: storage.cpu(),
 
57
  return False
58
 
59
 
60
+ def _patch_torchmetrics_cpu():
61
+ """Patch torchmetrics.Metric._apply to avoid CUDA errors on CPU-only servers.
62
+
63
+ When a GPU-trained checkpoint is loaded on CPU-only hardware, the torchmetrics
64
+ Metric._apply method does `fn(torch.zeros(1, device=self.device))` where
65
+ self.device may still be "cuda:0" from the checkpoint. We replace that with
66
+ a safe CPU probe so the destination device is inferred without touching CUDA.
67
+ """
68
+ try:
69
+ import torchmetrics
70
+ import torch.nn as nn
71
+
72
+ _orig = torchmetrics.Metric._apply
73
+
74
+ def _safe_apply(self, fn):
75
+ # Probe destination device via a CPU tensor — never touches CUDA.
76
+ self._device = fn(torch.zeros(1, device="cpu")).device
77
+ return nn.Module._apply(self, fn)
78
+
79
+ torchmetrics.Metric._apply = _safe_apply
80
+ print("[tft] torchmetrics._apply patched for CPU-only inference")
81
+ except Exception as e:
82
+ print(f"[tft] torchmetrics patch skipped: {e}")
83
+
84
+
85
  def load_model(model_path: str):
86
  """Load and cache the pytorch-forecasting TFT from a Lightning checkpoint."""
87
  global _model, _model_path_cached
 
92
  if not os.path.exists(model_path):
93
  raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
94
 
95
+ # Patch torchmetrics BEFORE importing pytorch_forecasting so the patched
96
+ # _apply is in place when Lightning restores metric state from the checkpoint.
97
+ _patch_torchmetrics_cpu()
98
+
99
  from pytorch_forecasting import TemporalFusionTransformer
100
 
101
+ # Callable map_location: moves all tensors to CPU AND skips Lightning's
102
+ # isinstance(map_location, (str, torch.device)) branch that would call
103
+ # model.to(map_location) which would re-trigger the CUDA error.
104
  model = TemporalFusionTransformer.load_from_checkpoint(
105
  model_path,
106
  map_location=lambda storage, loc: storage.cpu(),