"""TimesFM 2.5 forecasting wrapper (Transformers path).

Shares the baseline contract: forecast(obs_inc, horizon, ...) -> {q10, q50, q90}.

IMPORTANT
  * Do NOT externally normalise inputs - TimesFM 2.5 applies internal instance
    normalisation (RevIN). Feeding pre-scaled data corrupts the forecast.
  * Verified against transformers 5.10.2 / google/timesfm-2.5-200m-transformers:
    forward(past_values=<list of tensors>); out.full_predictions is (B, 128, 10)
    laid out as [mean, q0.1 .. q0.9]; horizon is fixed at 128 (slice to what
    you need).

Docs: https://huggingface.co/docs/transformers/model_doc/timesfm2_5
"""
from __future__ import annotations

import numpy as np

DEFAULT_REPO = "google/timesfm-2.5-200m-transformers"

# NOTE(review): 1.2816 is approximately the standard-normal z-score of the
# 10% / 90% quantiles. It is not referenced by the code in this file -
# presumably kept for other baselines; confirm before removing.
_Z10 = 1.2816

# Single-slot model cache: the loaded model, the (repo, adapter) pair it was
# built from, and the device it currently lives on.
_model = None
_model_key = None
_model_device = None


def load_timesfm(repo: str = DEFAULT_REPO, device: str = "cpu", adapter: str | None = None):
    """Load (and cache) the model, optionally with a LoRA adapter for the fine-tune.

    Only one model is cached at a time, keyed on ``(repo, adapter)``; asking
    for a different pair replaces the cached model. Asking for the same pair
    on a different ``device`` moves the cached model instead of returning it
    on the old device.

    Args:
        repo: Model identifier handed to ``from_pretrained``.
        device: Device the returned model must live on.
        adapter: Optional adapter identifier handed to
            ``PeftModel.from_pretrained``; ``None`` loads the base model only.

    Returns:
        The cached model, in eval mode, on ``device``.
    """
    global _model, _model_key, _model_device

    key = (repo, adapter)
    if _model is None or _model_key != key:
        # Imported lazily so that importing this module does not require
        # transformers / peft to be installed.
        from transformers import TimesFm2_5ModelForPrediction

        model = TimesFm2_5ModelForPrediction.from_pretrained(repo)
        if adapter:
            from peft import PeftModel

            model = PeftModel.from_pretrained(model, adapter)
        _model = model.eval()
        _model_key = key
        _model_device = None  # freshly loaded: placement is handled below

    # The cache key deliberately ignores the device, so the placement has to
    # be reconciled on every call; otherwise a cached model would be returned
    # on a stale device while the caller builds its inputs on the new one.
    wanted = str(device)
    if _model_device != wanted:
        _model = _model.to(device)
        _model_device = wanted
    return _model


def timesfm_forecast(
    obs_inc,
    horizon,
    *,
    repo: str = DEFAULT_REPO,
    device: str = "cpu",
    adapter: str | None = None,
    forecast_context_len: int | None = None,
    **kw,
) -> dict:
    """Zero-shot (adapter=None) or fine-tuned (adapter=<adapter id>) forecast
    of the incremental series. Returns P10/P50/P90 increments.

    Args:
        obs_inc: Observed increments; converted to a float32 array and passed
            to the model unscaled (see the module docstring).
        horizon: Number of forecast steps to return. Steps beyond the length
            of the model output are filled by repeating the last predicted
            value of each quantile.
        repo: Model identifier, forwarded to ``load_timesfm``.
        device: Device for both the model and the input tensor.
        adapter: Optional adapter identifier, forwarded to ``load_timesfm``.
        forecast_context_len: Forwarded to the model call only when truthy.
        **kw: Accepted and ignored - presumably so that every baseline can be
            called with the same keyword set.

    Returns:
        A dict with keys ``"q10"``, ``"q50"`` and ``"q90"``, each a NumPy
        array of ``horizon`` values clipped below at 0.

    Raises:
        ValueError: If ``horizon`` is negative.
    """
    # Reject this up front: a negative horizon would otherwise turn into a
    # negative slice bound and silently return the wrong number of steps.
    if horizon < 0:
        raise ValueError(f"horizon must be >= 0, got {horizon!r}")

    import torch

    model = load_timesfm(repo=repo, device=device, adapter=adapter)
    series = torch.tensor(np.asarray(obs_inc, dtype="float32"), device=device)

    call_kwargs = {"past_values": [series]}
    if forecast_context_len:
        call_kwargs["forecast_context_len"] = forecast_context_len
    with torch.no_grad():
        out = model(**call_kwargs)

    # full_predictions: (B, 128, 1 + n_quantiles) laid out as [mean, q0.1 .. q0.9].
    full = np.asarray(out.full_predictions[0].detach().cpu())
    steps = min(horizon, full.shape[0])

    default_levels = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
    # ``or`` also covers a config attribute that exists but is None or empty.
    levels = list(getattr(model.config, "quantiles", None) or default_levels)

    def _quantile(level):
        # Column 0 is the mean; quantiles start at column 1. Use the
        # configured level closest to the requested one.
        nearest = min(range(len(levels)), key=lambda i: abs(levels[i] - level))
        values = full[:steps, 1 + nearest]
        if horizon > steps:
            # Asked beyond the model's fixed horizon: hold the last value.
            values = np.concatenate([values, np.repeat(values[-1], horizon - steps)])
        return np.clip(values, 0.0, None)

    return {
        "q10": _quantile(0.1),
        "q50": _quantile(0.5),
        "q90": _quantile(0.9),
    }