will702 committed on
Commit
49e908a
·
1 Parent(s): 5c3e365

fix: use callable map_location to skip Lightning model.to() — avoids torchmetrics CUDA assertion

Browse files
Files changed (1) hide show
  1. app/models/tft_predictor.py +8 -26
app/models/tft_predictor.py CHANGED
@@ -57,21 +57,6 @@ def _maybe_download(filename: str, local_path: str) -> bool:
57
  return False
58
 
59
 
60
- def _patch_checkpoint_for_cpu(obj):
61
- """Recursively replace CUDA device references with CPU in a checkpoint dict."""
62
- if isinstance(obj, dict):
63
- return {k: _patch_checkpoint_for_cpu(v) for k, v in obj.items()}
64
- if isinstance(obj, list):
65
- return [_patch_checkpoint_for_cpu(v) for v in obj]
66
- if isinstance(obj, torch.Tensor):
67
- return obj.cpu()
68
- if isinstance(obj, torch.device) and obj.type == "cuda":
69
- return torch.device("cpu")
70
- if isinstance(obj, str) and obj.lower().startswith("cuda"):
71
- return "cpu"
72
- return obj
73
-
74
-
75
  def load_model(model_path: str):
76
  """Load and cache the pytorch-forecasting TFT from a Lightning checkpoint."""
77
  global _model, _model_path_cached
@@ -84,20 +69,17 @@ def load_model(model_path: str):
84
 
85
  from pytorch_forecasting import TemporalFusionTransformer
86
 
87
- # Checkpoint was saved on GPU (Colab). Patch the raw checkpoint to replace
88
- # all CUDA device references before Lightning tries to move the model.
89
- cpu_path = model_path + ".cpu"
90
- if not os.path.exists(cpu_path):
91
- ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
92
- ckpt = _patch_checkpoint_for_cpu(ckpt)
93
- torch.save(ckpt, cpu_path)
94
-
95
- model = TemporalFusionTransformer.load_from_checkpoint(cpu_path, map_location="cpu")
96
- model = model.cpu()
97
  model.eval()
98
  _model = model
99
  _model_path_cached = model_path
100
- print(f"[tft] Loaded pytorch-forecasting TFT from {cpu_path}")
101
  return model
102
 
103
 
 
57
  return False
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def load_model(model_path: str):
61
  """Load and cache the pytorch-forecasting TFT from a Lightning checkpoint."""
62
  global _model, _model_path_cached
 
69
 
70
  from pytorch_forecasting import TemporalFusionTransformer
71
 
72
+ # Checkpoint was saved on GPU (Colab). Using a callable map_location moves
73
+ # all tensors to CPU AND skips Lightning's model.to(device) call, which is
74
+ # what triggers the torchmetrics CUDA assertion on CPU-only servers.
75
+ model = TemporalFusionTransformer.load_from_checkpoint(
76
+ model_path,
77
+ map_location=lambda storage, loc: storage.cpu(),
78
+ )
 
 
 
79
  model.eval()
80
  _model = model
81
  _model_path_cached = model_path
82
+ print(f"[tft] Loaded pytorch-forecasting TFT from {model_path}")
83
  return model
84
 
85