Spaces:
Running
Running
File size: 3,054 Bytes
c658ad5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | """TimesFM 2.5 forecasting wrapper (Transformers path).
Shares the baseline contract: forecast(obs_inc, horizon, ...) -> {q10, q50, q90}.
IMPORTANT
* Do NOT externally normalise inputs - TimesFM 2.5 applies internal instance
normalisation (RevIN). Feeding pre-scaled data corrupts the forecast.
* Verified against transformers 5.10.2 / google/timesfm-2.5-200m-transformers:
forward(past_values=<list of 1D tensors>); out.full_predictions is (B, 128, 10)
laid out as [mean, q0.1 .. q0.9]; horizon is fixed at 128 (slice to what you need).
Docs: https://huggingface.co/docs/transformers/model_doc/timesfm2_5
"""
from __future__ import annotations
import numpy as np
# Default Hugging Face repo id that `load_timesfm` passes to `from_pretrained`.
DEFAULT_REPO = "google/timesfm-2.5-200m-transformers"
# NOTE(review): not referenced anywhere in the visible code. 1.2816 looks like
# the standard-normal z-score for the 10% / 90% quantiles - confirm before use
# or removal.
_Z10 = 1.2816
# Single-entry module-level cache used by `load_timesfm`: the most recently
# loaded model and the key identifying the configuration it was loaded under.
# Both stay None until the first load.
_model = None
_model_key = None
def load_timesfm(repo: str = DEFAULT_REPO, device: str = "cpu", adapter: str | None = None):
    """Load (and cache) the model, optionally with a LoRA adapter for the fine-tune.

    Only one model is cached at a time, in the module globals ``_model`` and
    ``_model_key``. Requesting a different ``repo`` or ``adapter`` reloads the
    weights and replaces the cached model. Requesting the same weights on a
    different ``device`` moves the cached model instead of reloading it.

    Args:
        repo: Repo id (or path) passed to ``from_pretrained``.
        device: Device the returned model must live on.
        adapter: Optional adapter id passed to ``PeftModel.from_pretrained``;
            a falsy value loads the base model only.

    Returns:
        The cached model, in eval mode, on ``device``.
    """
    global _model, _model_key
    key = (repo, adapter, device)
    if _model is not None and _model_key == key:
        return _model

    # Reload weights only when the repo/adapter pair changed; a device-only
    # change reuses the weights already in memory.
    if _model is None or _model_key[:2] != key[:2]:
        # Imported lazily so importing this module does not pull in transformers.
        from transformers import TimesFm2_5ModelForPrediction

        model = TimesFm2_5ModelForPrediction.from_pretrained(repo)
        if adapter:
            from peft import PeftModel

            model = PeftModel.from_pretrained(model, adapter)
        _model = model.eval()

    # The device is part of the cache key: without it a cached model would be
    # returned on its old device while callers build inputs on the new one.
    _model = _model.to(device)
    _model_key = key
    return _model
def timesfm_forecast(
    obs_inc,
    horizon,
    *,
    repo: str = DEFAULT_REPO,
    device: str = "cpu",
    adapter: str | None = None,
    forecast_context_len: int | None = None,
    **kw,
) -> dict:
    """Forecast the incremental series and return P10/P50/P90 increments.

    Runs zero-shot when ``adapter`` is None, or with the fine-tuned adapter
    when ``adapter`` names one.

    Args:
        obs_inc: Observed increments. Converted to a float32 array and passed
            to the model as a single series, without any external scaling.
        horizon: Number of forecast steps to return; must be non-negative.
            Steps beyond what the model emits are filled by repeating the
            last predicted value.
        repo: Forwarded to ``load_timesfm``.
        device: Forwarded to ``load_timesfm``; the input tensor is also
            created on this device.
        adapter: Forwarded to ``load_timesfm``.
        forecast_context_len: Passed to the model call only when truthy.
        **kw: Accepted and ignored (presumably so callers can use one
            signature across baseline forecasters - TODO confirm).

    Returns:
        Dict with keys ``"q10"``, ``"q50"`` and ``"q90"``, each a 1-D array of
        ``horizon`` values clipped below at 0.

    Raises:
        ValueError: If ``horizon`` is negative.
    """
    import torch

    # A negative horizon would act as a negative slice bound below and
    # silently return the wrong number of steps, so reject it up front
    # (before paying for a model load).
    if horizon < 0:
        raise ValueError(f"horizon must be non-negative, got {horizon}")

    model = load_timesfm(repo=repo, device=device, adapter=adapter)
    series = [torch.tensor(np.asarray(obs_inc, dtype="float32"), device=device)]
    model_kwargs = {"forecast_context_len": forecast_context_len} if forecast_context_len else {}
    with torch.no_grad():
        out = model(past_values=series, **model_kwargs)

    # full_predictions: (B, 128, 1 + n_quantiles) laid out as [mean, q0.1 .. q0.9].
    full = np.asarray(out.full_predictions[0].detach().cpu())
    steps = min(horizon, full.shape[0])
    pad = horizon - steps  # > 0 only when asked beyond the model's fixed horizon
    levels = list(getattr(model.config, "quantiles", (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)))

    def _column(level):
        # Column 0 is the mean; quantiles start at column 1. Pick the
        # configured quantile closest to the requested level.
        nearest = min(range(len(levels)), key=lambda i: abs(levels[i] - level))
        return 1 + nearest

    def _band(level):
        # Slice one quantile track, extend it by holding the last value if
        # needed, and floor it at zero.
        values = full[:steps, _column(level)]
        if pad:
            values = np.concatenate([values, np.repeat(values[-1], pad)])
        return np.clip(values, 0.0, None)

    return {
        "q10": _band(0.1),
        "q50": _band(0.5),
        "q90": _band(0.9),
    }
|