File size: 3,054 Bytes
c658ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""TimesFM 2.5 forecasting wrapper (Transformers path).

Shares the baseline contract: forecast(obs_inc, horizon, ...) -> {q10, q50, q90}.

IMPORTANT
  * Do NOT externally normalise inputs - TimesFM 2.5 applies internal instance
    normalisation (RevIN). Feeding pre-scaled data corrupts the forecast.
  * Verified against transformers 5.10.2 / google/timesfm-2.5-200m-transformers:
    forward(past_values=<list of 1D tensors>); out.full_predictions is (B, 128, 10)
    laid out as [mean, q0.1 .. q0.9]; horizon is fixed at 128 (slice to what you need).
    Docs: https://huggingface.co/docs/transformers/model_doc/timesfm2_5
"""
from __future__ import annotations

import numpy as np

# Default Hugging Face repo id used by load_timesfm / timesfm_forecast when the
# caller does not pass `repo`.
DEFAULT_REPO = "google/timesfm-2.5-200m-transformers"
# NOTE(review): not referenced anywhere in the visible part of this file.
# Presumably the standard-normal z-score for the 10%/90% quantiles - confirm
# it is still needed before relying on it.
_Z10 = 1.2816

# Process-wide cache filled in by load_timesfm(): the loaded model and the key
# describing the configuration it was loaded for. Both stay None until the
# first load.
_model = None
_model_key = None


def load_timesfm(repo: str = DEFAULT_REPO, device: str = "cpu", adapter: str | None = None):
    """Load (and cache) the model, optionally with a LoRA adapter for the fine-tune.

    A single model is cached at module level. The weights are (re)loaded only
    when ``repo`` or ``adapter`` differs from the cached configuration; when
    only ``device`` differs, the cached model is moved to the new device
    instead of being reloaded from disk/hub.

    Args:
        repo: Model repo id or path handed to ``from_pretrained``.
        device: Torch device the returned model must live on.
        adapter: Optional PEFT adapter id/path; falsy means the base model.

    Returns:
        The cached model, in eval mode, placed on ``device``.
    """
    global _model, _model_key
    # The device is part of the key so a cached model is never handed back on
    # a different device than the one the caller will put its inputs on.
    key = (repo, adapter, str(device))

    same_weights = (
        _model is not None
        and _model_key is not None
        and tuple(_model_key[:2]) == key[:2]
    )
    if same_weights:
        if tuple(_model_key) != key:
            # Same weights, different (or untracked) device: relocate only.
            _model = _model.to(device)
            _model_key = key
        return _model

    # Heavy dependencies are imported lazily so importing this module is cheap.
    from transformers import TimesFm2_5ModelForPrediction

    model = TimesFm2_5ModelForPrediction.from_pretrained(repo)
    if adapter:
        from peft import PeftModel

        model = PeftModel.from_pretrained(model, adapter)
    _model = model.to(device).eval()
    _model_key = key
    return _model


def timesfm_forecast(
    obs_inc,
    horizon,
    *,
    repo: str = DEFAULT_REPO,
    device: str = "cpu",
    adapter: str | None = None,
    forecast_context_len: int | None = None,
    **kw,
) -> dict:
    """Forecast the incremental series, zero-shot or with a fine-tuned adapter.

    Inputs are passed to the model unscaled (see the module docstring: the
    model normalises internally).

    Args:
        obs_inc: Observed increments; converted to a float32 tensor.
        horizon: Number of future steps to return. Must be non-negative.
        repo: Model repo id or path forwarded to ``load_timesfm``.
        device: Torch device for both the model and the input tensor.
        adapter: ``None`` for zero-shot, or an adapter id/path for the
            fine-tuned model.
        forecast_context_len: Forwarded to the model call only when truthy.
        **kw: Accepted for compatibility with the shared baseline signature
            and ignored.

    Returns:
        Dict with keys ``"q10"``, ``"q50"``, ``"q90"``: NumPy arrays of length
        ``horizon`` holding the P10/P50/P90 increments, clipped below at 0.
        When ``horizon`` exceeds the number of steps the model emits, the last
        predicted value of each band is repeated to fill the remainder.

    Raises:
        ValueError: If ``horizon`` is negative.
    """
    # Fail fast, before the model load: a negative horizon would otherwise
    # slice from the end of the prediction and return the wrong length.
    if horizon < 0:
        raise ValueError(f"horizon must be non-negative, got {horizon!r}")

    import torch

    model = load_timesfm(repo=repo, device=device, adapter=adapter)
    series = torch.tensor(np.asarray(obs_inc, dtype="float32"), device=device)

    call_kwargs = {"past_values": [series]}
    if forecast_context_len:
        call_kwargs["forecast_context_len"] = forecast_context_len
    with torch.no_grad():
        out = model(**call_kwargs)

    # full_predictions: (B, 128, 1 + n_quantiles) laid out as [mean, q0.1 .. q0.9].
    full = np.asarray(out.full_predictions[0].detach().cpu())
    h = min(horizon, full.shape[0])
    qs = list(getattr(model.config, "quantiles", (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)))

    def _col(level):
        # Column 0 is the mean; quantiles start at column 1. Pick the
        # configured quantile closest to the requested level.
        return 1 + min(range(len(qs)), key=lambda i: abs(qs[i] - level))

    def _band(level):
        # Slice one quantile band, pad past the model's fixed horizon by
        # holding the last value, then floor at zero.
        values = full[:h, _col(level)]
        if horizon > h:
            values = np.concatenate([values, np.repeat(values[-1], horizon - h)])
        return np.clip(values, 0.0, None)

    return {"q10": _band(0.1), "q50": _band(0.5), "q90": _band(0.9)}