vmjn committed on
Commit
f3634b6
·
verified ·
1 Parent(s): 33d7b77

patch KronosPredictor for torch 2.6 meta tensor safety

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. model/kronos.py +18 -2
Dockerfile CHANGED
@@ -19,7 +19,7 @@ RUN pip install --user --no-cache-dir --index-url https://download.pytorch.org/w
19
  # App deps — all pinned to known-compatible versions
20
  RUN pip install --user --no-cache-dir \
21
  "gradio[mcp]==5.30.0" \
22
- "huggingface_hub>=0.25.0,<0.27.0" \
23
  "numpy>=1.26,<2.3" \
24
  "pandas>=2.1" \
25
  "yfinance>=0.2.50" \
 
19
  # App deps — all pinned to known-compatible versions
20
  RUN pip install --user --no-cache-dir \
21
  "gradio[mcp]==5.30.0" \
22
+ "huggingface_hub>=0.27.0,<1.0" \
23
  "numpy>=1.26,<2.3" \
24
  "pandas>=2.1" \
25
  "yfinance>=0.2.50" \
model/kronos.py CHANGED
@@ -502,8 +502,24 @@ class KronosPredictor:
502
 
503
  self.device = device
504
 
505
- self.tokenizer = self.tokenizer.to(self.device)
506
- self.model = self.model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
  def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
509
 
 
502
 
503
  self.device = device
504
 
505
+ # torch 2.6 + huggingface_hub 0.27+ compat: skip .to() when target is CPU
506
+ # (models already load to CPU by default) and handle meta tensors gracefully.
507
+ def _safe_to(mod, dev):
508
+ try:
509
+ has_meta = any(p.is_meta for p in mod.parameters()) or \
510
+ any(b.is_meta for b in mod.buffers())
511
+ except Exception:
512
+ has_meta = False
513
+ if has_meta:
514
+ # Should be rare for fresh from_pretrained; skip .to() — torch raises on meta.to()
515
+ return mod
516
+ # Skip no-op .to('cpu') to avoid touching the moved-from-meta code path
517
+ if str(dev) in ("cpu", "torch.device('cpu')") and next(mod.parameters()).device.type == "cpu":
518
+ return mod
519
+ return mod.to(dev)
520
+
521
+ self.tokenizer = _safe_to(self.tokenizer, self.device)
522
+ self.model = _safe_to(self.model, self.device)
523
 
524
  def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
525