Spaces:
Running
Running
patch KronosPredictor for torch 2.6 meta tensor safety
Browse files
- Dockerfile +1 -1
- model/kronos.py +18 -2
Dockerfile
CHANGED
|
@@ -19,7 +19,7 @@ RUN pip install --user --no-cache-dir --index-url https://download.pytorch.org/w
|
|
| 19 |
# App deps — all pinned to known-compatible versions
|
| 20 |
RUN pip install --user --no-cache-dir \
|
| 21 |
"gradio[mcp]==5.30.0" \
|
| 22 |
-
"huggingface_hub>=0.
|
| 23 |
"numpy>=1.26,<2.3" \
|
| 24 |
"pandas>=2.1" \
|
| 25 |
"yfinance>=0.2.50" \
|
|
|
|
| 19 |
# App deps — all pinned to known-compatible versions
|
| 20 |
RUN pip install --user --no-cache-dir \
|
| 21 |
"gradio[mcp]==5.30.0" \
|
| 22 |
+
"huggingface_hub>=0.27.0,<1.0" \
|
| 23 |
"numpy>=1.26,<2.3" \
|
| 24 |
"pandas>=2.1" \
|
| 25 |
"yfinance>=0.2.50" \
|
model/kronos.py
CHANGED
|
@@ -502,8 +502,24 @@ class KronosPredictor:
|
|
| 502 |
|
| 503 |
self.device = device
|
| 504 |
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
|
| 509 |
|
|
|
|
| 502 |
|
| 503 |
self.device = device
|
| 504 |
|
| 505 |
+
# torch 2.6 + huggingface_hub 0.27+ compat: skip .to() when target is CPU
|
| 506 |
+
# (models already load to CPU by default) and handle meta tensors gracefully.
|
| 507 |
+
def _safe_to(mod, dev):
|
| 508 |
+
try:
|
| 509 |
+
has_meta = any(p.is_meta for p in mod.parameters()) or \
|
| 510 |
+
any(b.is_meta for b in mod.buffers())
|
| 511 |
+
except Exception:
|
| 512 |
+
has_meta = False
|
| 513 |
+
if has_meta:
|
| 514 |
+
# Should be rare for fresh from_pretrained; skip .to() — torch raises on meta.to()
|
| 515 |
+
return mod
|
| 516 |
+
# Skip no-op .to('cpu') to avoid touching the moved-from-meta code path
|
| 517 |
+
if str(dev) in ("cpu", "torch.device('cpu')") and next(mod.parameters()).device.type == "cpu":
|
| 518 |
+
return mod
|
| 519 |
+
return mod.to(dev)
|
| 520 |
+
|
| 521 |
+
self.tokenizer = _safe_to(self.tokenizer, self.device)
|
| 522 |
+
self.model = _safe_to(self.model, self.device)
|
| 523 |
|
| 524 |
def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
|
| 525 |
|