Nikita committed on
Commit
ab631da
·
1 Parent(s): 9ff4612

torch.compile, removed tirex folder, loading from pip

Browse files
app.py CHANGED
@@ -15,7 +15,7 @@ from tirex import load_model, ForecastModel
15
  # ----------------------------
16
 
17
  torch.manual_seed(42)
18
- model: ForecastModel = load_model("NX-AI/TiRex",device='cuda')
19
 
20
  def model_forecast(input_data, forecast_length=256, file_name=None):
21
  if os.path.basename(file_name) == "loop.csv" and forecast_length==256:
 
15
  # ----------------------------
16
 
17
  torch.manual_seed(42)
18
+ model: ForecastModel = load_model("NX-AI/TiRex", backend="torch", device="cuda", compile=True)
19
 
20
  def model_forecast(input_data, forecast_length=256, file_name=None):
21
  if os.path.basename(file_name) == "loop.csv" and forecast_length==256:
environment.yaml CHANGED
@@ -15,6 +15,7 @@ dependencies:
15
  - cuda-toolkit=12.6
16
  - cuda-cccl=12.6
17
  - pip:
 
18
  - --index-url https://download.pytorch.org/whl/cu126
19
  - --extra-index-url https://pypi.org/simple
20
  - pyarrow
 
15
  - cuda-toolkit=12.6
16
  - cuda-cccl=12.6
17
  - pip:
18
+ - tirex-ts
19
  - --index-url https://download.pytorch.org/whl/cu126
20
  - --extra-index-url https://pypi.org/simple
21
  - pyarrow
tirex/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
- from .api_adapter.forecast import ForecastModel
5
- from .base import load_model
6
- from .models.tirex import TiRexZero
7
-
8
- __all__ = ["load_model", "ForecastModel"]
 
 
 
 
 
 
 
 
 
tirex/api_adapter/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
 
 
tirex/api_adapter/forecast.py DELETED
@@ -1,209 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
- from abc import ABC, abstractmethod
5
- from typing import Literal
6
-
7
- import torch
8
-
9
- from .standard_adapter import ContextType, get_batches
10
-
11
- try:
12
- from .gluon import format_gluonts_output, get_gluon_batches
13
-
14
- _GLUONTS_AVAILABLE = True
15
- except ImportError:
16
- _GLUONTS_AVAILABLE = False
17
-
18
- try:
19
- from .hf_data import get_hfdata_batches
20
-
21
- _HF_DATASETS_AVAILABLE = True
22
- except ImportError:
23
- _HF_DATASETS_AVAILABLE = False
24
-
25
-
26
- DEF_TARGET_COLUMN = "target"
27
- DEF_META_COLUMNS = ("start", "item_id")
28
-
29
-
30
- def _format_output(
31
- quantiles: torch.Tensor,
32
- means: torch.Tensor,
33
- sample_meta: list[dict],
34
- quantile_levels: list[float],
35
- output_type: Literal["torch", "numpy", "gluonts"],
36
- ):
37
- if output_type == "torch":
38
- return quantiles.cpu(), means.cpu()
39
- elif output_type == "numpy":
40
- return quantiles.cpu().numpy(), means.cpu().numpy()
41
- elif output_type == "gluonts":
42
- if not _GLUONTS_AVAILABLE:
43
- raise ValueError("output_type glutonts needs GluonTs but GluonTS is not available (not installed)!")
44
- return format_gluonts_output(quantiles, means, sample_meta, quantile_levels)
45
- else:
46
- raise ValueError(f"Invalid output type: {output_type}")
47
-
48
-
49
- def _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs):
50
- for batch_ctx, batch_meta in batches:
51
- quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
52
- yield _format_output(
53
- quantiles=quantiles,
54
- means=mean,
55
- sample_meta=batch_meta,
56
- quantile_levels=quantile_levels,
57
- output_type=output_type,
58
- )
59
-
60
-
61
- def _gen_forecast(fc_func, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs):
62
- if yield_per_batch:
63
- return _as_generator(batches, fc_func, quantile_levels, output_type, **predict_kwargs)
64
-
65
- prediction_q = []
66
- prediction_m = []
67
- sample_meta = []
68
- for batch_ctx, batch_meta in batches:
69
- quantiles, mean = fc_func(batch_ctx, **predict_kwargs)
70
- prediction_q.append(quantiles)
71
- prediction_m.append(mean)
72
- sample_meta.extend(batch_meta)
73
-
74
- prediction_q = torch.cat(prediction_q, dim=0)
75
- prediction_m = torch.cat(prediction_m, dim=0)
76
-
77
- return _format_output(
78
- quantiles=prediction_q,
79
- means=prediction_m,
80
- sample_meta=sample_meta,
81
- quantile_levels=quantile_levels,
82
- output_type=output_type,
83
- )
84
-
85
-
86
- def _common_forecast_doc():
87
- common_doc = f"""
88
- This method takes historical context data as input and outputs probabilistic forecasts.
89
-
90
- Args:
91
- output_type (Literal["torch", "numpy", "gluonts"], optional):
92
- Specifies the desired format of the returned forecasts:
93
- - "torch": Returns forecasts as `torch.Tensor` objects [batch_dim, forecast_len, |quantile_levels|]
94
- - "numpy": Returns forecasts as `numpy.ndarray` objects [batch_dim, forecast_len, |quantile_levels|]
95
- - "gluonts": Returns forecasts as a list of GluonTS `Forecast` objects.
96
- Defaults to "torch".
97
-
98
- batch_size (int, optional): The number of time series instances to process concurrently by the model.
99
- Defaults to 512. Must be $>= 1$.
100
-
101
- quantile_levels (List[float], optional): Quantile levels for which predictions should be generated.
102
- Defaults to (0.1, 0.2, ..., 0.9).
103
-
104
- yield_per_batch (bool, optional): If `True`, the method will act as a generator, yielding
105
- forecasts batch by batch as they are computed.
106
- Defaults to `False`.
107
-
108
- **predict_kwargs: Additional keyword arguments that are passed directly to the underlying
109
- prediction mechanism of the pre-trained model. Refer to the model's
110
- internal prediction method documentation for available options.
111
-
112
- Returns:
113
- The return type depends on `output_type` and `yield_per_batch`:
114
- - If `yield_per_batch` is `True`: An iterator that yields forecasts. Each yielded item
115
- will correspond to a batch of forecasts in the format specified by `output_type`.
116
- - If `yield_per_batch` is `False`: A single object containing all forecasts.
117
- - If `output_type="torch"`: `Tuple[torch.Tensor, torch.Tensor]` (quantiles, mean).
118
- - If `output_type="numpy"`: `Tuple[numpy.ndarray, numpy.ndarray]` (quantiles, mean).
119
- - If `output_type="gluonts"`: A `List[gluonts.model.forecast.Forecast]` of all forecasts.
120
- """
121
- return common_doc
122
-
123
-
124
- class ForecastModel(ABC):
125
- @abstractmethod
126
- def _forecast_quantiles(self, batch, **predict_kwargs):
127
- pass
128
-
129
- def forecast(
130
- self,
131
- context: ContextType,
132
- output_type: Literal["torch", "numpy", "gluonts"] = "torch",
133
- batch_size: int = 512,
134
- quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
135
- yield_per_batch: bool = False,
136
- **predict_kwargs,
137
- ):
138
- f"""
139
- {_common_forecast_doc}
140
- Args:
141
- context (ContextType): The historical "context" data of the time series:
142
- - `torch.Tensor`: 1D `[context_length]` or 2D `[batch_dim, context_length]` tensor
143
- - `np.ndarray`: 1D `[context_length]` or 2D `[batch_dim, context_length]` array
144
- - `List[torch.Tensor]`: List of 1D tensors (samples with different lengths get padded per batch)
145
- - `List[np.ndarray]`: List of 1D arrays (samples with different lengths get padded per batch)
146
- """
147
- assert batch_size >= 1, "Batch size must be >= 1"
148
- batches = get_batches(context, batch_size)
149
- return _gen_forecast(
150
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
151
- )
152
-
153
- def forecast_gluon(
154
- self,
155
- gluonDataset,
156
- output_type: Literal["torch", "numpy", "gluonts"] = "torch",
157
- batch_size: int = 512,
158
- quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
159
- yield_per_batch: bool = False,
160
- data_kwargs: dict = {},
161
- **predict_kwargs,
162
- ):
163
- f"""
164
- {_common_forecast_doc()}
165
-
166
- Args:
167
- gluonDataset (gluon_ts.dataset.common.Dataset): A GluonTS dataset object containing the
168
- historical time series data.
169
-
170
- data_kwargs (dict, optional): Additional keyword arguments passed to the
171
- autogluon data processing function.
172
- """
173
- assert batch_size >= 1, "Batch size must be >= 1"
174
- if not _GLUONTS_AVAILABLE:
175
- raise ValueError("forecast_gluon glutonts needs GluonTs but GluonTS is not available (not installed)!")
176
- batches = get_gluon_batches(gluonDataset, batch_size, **data_kwargs)
177
- return _gen_forecast(
178
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
179
- )
180
-
181
- def forecast_hfdata(
182
- self,
183
- hf_dataset,
184
- output_type: Literal["torch", "numpy", "gluonts"] = "torch",
185
- batch_size: int = 512,
186
- quantile_levels: list[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
187
- yield_per_batch: bool = False,
188
- data_kwargs: dict = {},
189
- **predict_kwargs,
190
- ):
191
- f"""
192
- {_common_forecast_doc()}
193
-
194
- Args:
195
- hf_dataset (datasets.Dataset): A Hugging Face `Dataset` object containing the
196
- historical time series data.
197
-
198
- data_kwargs (dict, optional): Additional keyword arguments passed to the
199
- datasets data processing function.
200
- """
201
- assert batch_size >= 1, "Batch size must be >= 1"
202
- if not _HF_DATASETS_AVAILABLE:
203
- raise ValueError(
204
- "forecast_hfdata glutonts needs HuggingFace datasets but datasets is not available (not installed)!"
205
- )
206
- batches = get_hfdata_batches(hf_dataset, batch_size, **data_kwargs)
207
- return _gen_forecast(
208
- self._forecast_quantiles, batches, output_type, quantile_levels, yield_per_batch, **predict_kwargs
209
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/api_adapter/gluon.py DELETED
@@ -1,48 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
- import pandas as pd
5
- import torch
6
- from gluonts.dataset.common import Dataset
7
- from gluonts.dataset.field_names import FieldName
8
- from gluonts.model.forecast import QuantileForecast
9
-
10
- from .standard_adapter import _batch_pad_iterable
11
-
12
- DEF_TARGET_COLUMN = FieldName.TARGET # target
13
- DEF_META_COLUMNS = (FieldName.START, FieldName.ITEM_ID)
14
-
15
-
16
- def _get_gluon_ts_map(**gluon_kwargs):
17
- target_col = gluon_kwargs.get("target_column", DEF_TARGET_COLUMN)
18
- meta_columns = gluon_kwargs.get("meta_columns", DEF_META_COLUMNS)
19
-
20
- def extract_gluon(series):
21
- ctx = torch.Tensor(series[target_col])
22
- meta = {k: series[k] for k in meta_columns if k in series}
23
- meta["length"] = len(ctx)
24
- return ctx, meta
25
-
26
- return extract_gluon
27
-
28
-
29
- def get_gluon_batches(gluonDataset: Dataset, batch_size: int, **gluon_kwargs):
30
- return _batch_pad_iterable(map(_get_gluon_ts_map(**gluon_kwargs), gluonDataset), batch_size)
31
-
32
-
33
- def format_gluonts_output(quantile_forecasts: torch.Tensor, mean_forecasts, meta: list[dict], quantile_levels):
34
- forecasts = []
35
- for i in range(quantile_forecasts.shape[0]):
36
- start_date = meta[i].get(FieldName.START, pd.Period("01-01-2000", freq=meta[i].get("freq", "h")))
37
- start_date += meta[i].get("length", 0)
38
- forecasts.append(
39
- QuantileForecast(
40
- forecast_arrays=torch.cat((quantile_forecasts[i], mean_forecasts[i].unsqueeze(1)), dim=1)
41
- .T.cpu()
42
- .numpy(),
43
- start_date=start_date,
44
- item_id=meta[i].get(FieldName.ITEM_ID, None),
45
- forecast_keys=list(map(str, quantile_levels)) + ["mean"],
46
- )
47
- )
48
- return forecasts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/api_adapter/hf_data.py DELETED
@@ -1,38 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
- import datasets
5
- import torch
6
-
7
- from .standard_adapter import _batch_pad_iterable
8
-
9
- DEF_TARGET_COLUMN = "target"
10
-
11
-
12
- def _get_hf_map(dataset: datasets.Dataset, **hf_kwargs):
13
- target_col = hf_kwargs.get("target_column", DEF_TARGET_COLUMN)
14
- meta_columns = hf_kwargs.get("meta_columns", ())
15
-
16
- columns_to_pass = [target_col] + list(meta_columns)
17
- remove_cols = [col for col in dataset.column_names if col not in columns_to_pass]
18
- dataset = (
19
- dataset.with_format("torch")
20
- .remove_columns(remove_cols)
21
- .cast_column(target_col, datasets.Sequence(datasets.Value("float32")))
22
- )
23
-
24
- def yield_batch_tuples(sample: dict) -> tuple[torch.Tensor, dict]:
25
- context_data = sample[target_col]
26
- if context_data.ndim > 1:
27
- context_data = context_data.squeeze()
28
- assert context_data.ndim == 1
29
- meta = {k: sample[k] for k in meta_columns if k in sample}
30
- meta["length"] = len(context_data)
31
- return context_data, meta
32
-
33
- return dataset, yield_batch_tuples
34
-
35
-
36
- def get_hfdata_batches(hf_dataset: datasets.Dataset, batch_size: int, **hf_kwargs):
37
- dataset, map_func = _get_hf_map(hf_dataset, **hf_kwargs)
38
- return _batch_pad_iterable(map(map_func, dataset), batch_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/api_adapter/standard_adapter.py DELETED
@@ -1,67 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
- import itertools
5
- from collections.abc import Iterable, Iterator, Sequence
6
- from typing import Union
7
-
8
- import numpy as np
9
- import torch
10
-
11
- ContextType = Union[
12
- torch.Tensor,
13
- np.ndarray,
14
- list[torch.Tensor],
15
- list[np.ndarray],
16
- ]
17
-
18
-
19
- def _batched_slice(full_batch, full_meta: list[dict] | None, batch_size: int) -> Iterator[tuple[Sequence, list[dict]]]:
20
- if len(full_batch) <= batch_size:
21
- yield full_batch, full_meta if full_meta is not None else [{} for _ in range(len(full_batch))]
22
- else:
23
- for i in range(0, len(full_batch), batch_size):
24
- batch = full_batch[i : i + batch_size]
25
- yield batch, (full_meta[i : i + batch_size] if full_meta is not None else [{} for _ in range(len(batch))])
26
-
27
-
28
- def _batched(iterable: Iterable, n: int):
29
- it = iter(iterable)
30
- while batch := tuple(itertools.islice(it, n)):
31
- yield batch
32
-
33
-
34
- def _batch_pad_iterable(iterable: Iterable[tuple[torch.Tensor, dict]], batch_size: int):
35
- for batch in _batched(iterable, batch_size):
36
- # ctx_it_len, ctx_it_data, it_meta = itertools.tee(batch, 3)
37
- max_len = max(len(el[0]) for el in batch)
38
- padded_batch = []
39
- meta = []
40
- for el in batch:
41
- sample = el[0]
42
- assert isinstance(sample, torch.Tensor)
43
- assert sample.ndim == 1
44
- assert len(sample) > 0, "Each sample needs to have a length > 0"
45
- padding = torch.full(size=(max_len - len(sample),), fill_value=torch.nan, device=sample.device)
46
- padded_batch.append(torch.cat((padding, sample)))
47
- meta.append(el[1])
48
- yield torch.stack(padded_batch), meta
49
-
50
-
51
- def get_batches(context: ContextType, batch_size: int):
52
- batches = None
53
- if isinstance(context, torch.Tensor):
54
- if context.ndim == 1:
55
- context = context.unsqueeze(0)
56
- assert context.ndim == 2
57
- batches = _batched_slice(context, None, batch_size)
58
- elif isinstance(context, np.ndarray):
59
- if context.ndim == 1:
60
- context = np.expand_dims(context, axis=0)
61
- assert context.ndim == 2
62
- batches = map(lambda x: (torch.Tensor(x[0]), x[1]), _batched_slice(context, None, batch_size))
63
- elif isinstance(context, (list, Iterable)):
64
- batches = _batch_pad_iterable(map(lambda x: (torch.Tensor(x), None), context), batch_size)
65
- if batches is None:
66
- raise ValueError(f"Context type {type(context)} not supported! Supported Types: {ContextType}")
67
- return batches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/base.py DELETED
@@ -1,73 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
- import os
5
- from abc import ABC, abstractmethod
6
- from typing import TypeVar
7
-
8
- from huggingface_hub import hf_hub_download
9
-
10
- T = TypeVar("T", bound="PretrainedModel")
11
-
12
-
13
- def parse_hf_repo_id(path):
14
- parts = path.split("/")
15
- return "/".join(parts[0:2])
16
-
17
-
18
- class PretrainedModel(ABC):
19
- REGISTRY: dict[str, "PretrainedModel"] = {}
20
-
21
- def __init_subclass__(cls, **kwargs):
22
- super().__init_subclass__(**kwargs)
23
- cls.REGISTRY[cls.register_name()] = cls
24
-
25
- @classmethod
26
- def from_pretrained(cls: type[T], path, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> T:
27
- if hf_kwargs is None:
28
- hf_kwargs = {}
29
- if ckp_kwargs is None:
30
- ckp_kwargs = {}
31
- if os.path.exists(path):
32
- print("Loading weights from local directory")
33
- checkpoint_path = path
34
- else:
35
- repo_id = parse_hf_repo_id(path)
36
- checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
37
- model = cls.load_from_checkpoint(checkpoint_path, map_location=device, **ckp_kwargs)
38
- model.after_load_from_checkpoint()
39
- return model
40
-
41
- @classmethod
42
- @abstractmethod
43
- def register_name(cls) -> str:
44
- pass
45
-
46
- def after_load_from_checkpoint(self):
47
- pass
48
-
49
-
50
- def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> PretrainedModel:
51
- """Loads a TiRex model. This function attempts to load the specified model.
52
-
53
- Args:
54
- path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
55
- device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
56
- If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check repository FAQ!).
57
- hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
58
- ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
59
-
60
- Returns:
61
- PretrainedModel: The loaded model.
62
-
63
- Examples:
64
- model: ForecastModel = load_model("NX-AI/TiRex")
65
- """
66
- try:
67
- _, model_id = parse_hf_repo_id(path).split("/")
68
- except:
69
- raise ValueError(f"Invalid model path {path}")
70
- model_cls = PretrainedModel.REGISTRY.get(model_id, None)
71
- if model_cls is None:
72
- raise ValueError(f"Invalid model id {model_id}")
73
- return model_cls.from_pretrained(path, device=device, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/models/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
 
 
tirex/models/components.py DELETED
@@ -1,147 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
-
5
- from dataclasses import dataclass, field
6
- from typing import Any
7
-
8
- import torch
9
-
10
- SCALER_STATE = "scaler_state"
11
-
12
-
13
- class ResidualBlock(torch.nn.Module):
14
- def __init__(
15
- self,
16
- in_dim: int,
17
- h_dim: int,
18
- out_dim: int,
19
- dropout: float = 0,
20
- ) -> None:
21
- super().__init__()
22
- self.dropout = torch.nn.Dropout(dropout)
23
- self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
24
- self.output_layer = torch.nn.Linear(h_dim, out_dim)
25
- self.residual_layer = torch.nn.Linear(in_dim, out_dim)
26
- self.act = torch.nn.ReLU()
27
-
28
- def forward(self, x: torch.Tensor):
29
- hid = self.act(self.hidden_layer(x))
30
- out = self.output_layer(hid)
31
- res = self.residual_layer(x)
32
- out = out + res
33
- return out
34
-
35
-
36
- @dataclass
37
- class StandardScaler:
38
- eps: float = 1e-5
39
- nan_loc: float = 0.0
40
-
41
- def scale(
42
- self,
43
- x: torch.Tensor,
44
- loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
45
- ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
46
- if loc_scale is None:
47
- loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
48
- scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
49
- scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
50
- else:
51
- loc, scale = loc_scale
52
-
53
- return ((x - loc) / scale), (loc, scale)
54
-
55
- def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
56
- loc, scale = loc_scale
57
- return x * scale + loc
58
-
59
-
60
- @dataclass
61
- class _Patcher:
62
- patch_size: int
63
- patch_stride: int
64
- left_pad: bool
65
-
66
- def __post_init__(self):
67
- assert self.patch_size % self.patch_stride == 0
68
-
69
- def __call__(self, x: torch.Tensor) -> torch.Tensor:
70
- assert x.ndim == 2
71
- length = x.shape[-1]
72
-
73
- if length < self.patch_size or (length % self.patch_stride != 0):
74
- if length < self.patch_size:
75
- padding_size = (
76
- *x.shape[:-1],
77
- self.patch_size - (length % self.patch_size),
78
- )
79
- else:
80
- padding_size = (
81
- *x.shape[:-1],
82
- self.patch_stride - (length % self.patch_stride),
83
- )
84
- padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
85
- if self.left_pad:
86
- x = torch.concat((padding, x), dim=-1)
87
- else:
88
- x = torch.concat((x, padding), dim=-1)
89
-
90
- x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
91
- return x
92
-
93
-
94
- @dataclass
95
- class PatchedUniTokenizer:
96
- patch_size: int
97
- scaler: Any = field(default_factory=StandardScaler)
98
- patch_stride: int | None = None
99
-
100
- def __post_init__(self):
101
- if self.patch_stride is None:
102
- self.patch_stride = self.patch_size
103
- self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)
104
-
105
- def context_input_transform(self, data: torch.Tensor):
106
- assert data.ndim == 2
107
- data, scale_state = self.scaler.scale(data)
108
- return self.patcher(data), {SCALER_STATE: scale_state}
109
-
110
- def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
111
- data_shape = data.shape
112
- data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
113
- return data
114
-
115
-
116
- class StreamToLogger:
117
- """Fake file-like stream object that redirects writes to a logger
118
- instance."""
119
-
120
- def __init__(self, logger, log_level):
121
- self.logger = logger
122
- self.log_level = log_level
123
- self.linebuf = "" # Buffer for partial lines
124
-
125
- def write(self, message):
126
- # Filter out empty messages (often from just a newline)
127
- if message.strip():
128
- self.linebuf += message
129
- # If the message contains a newline, process the full line
130
- if "\n" in self.linebuf:
131
- lines = self.linebuf.splitlines(keepends=True)
132
- for line in lines:
133
- if line.endswith("\n"):
134
- # Log full lines without the trailing newline (logger adds its own)
135
- self.logger.log(self.log_level, line.rstrip("\n"))
136
- else:
137
- # Keep partial lines in buffer
138
- self.linebuf = line
139
- return
140
- self.linebuf = "" # All lines processed
141
- # If no newline, keep buffering
142
-
143
- def flush(self):
144
- # Log any remaining buffered content when flush is called
145
- if self.linebuf.strip():
146
- self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
147
- self.linebuf = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/models/mixed_stack.py DELETED
@@ -1,143 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
-
5
- import os
6
- from dataclasses import dataclass, field
7
-
8
- import torch
9
- from torch import nn
10
- from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
11
- from xlstm.xlstm_large import xLSTMLargeConfig
12
- from xlstm.xlstm_large.components import RMSNorm
13
- from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType
14
-
15
-
16
- def skip_cuda():
17
- return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
18
-
19
-
20
- def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
21
- return sLSTMLayer(
22
- sLSTMLayerConfig(
23
- embedding_dim=config.embedding_dim,
24
- num_heads=config.num_heads,
25
- conv1d_kernel_size=0, # 0 means no convolution included
26
- group_norm_weight=True,
27
- dropout=0,
28
- # CellConfig
29
- backend="vanilla" if skip_cuda() else "cuda",
30
- bias_init="powerlaw_blockdependent",
31
- recurrent_weight_init="zeros",
32
- num_gates=4,
33
- gradient_recurrent_cut=False,
34
- gradient_recurrent_clipval=None,
35
- forward_clipval=None,
36
- batch_size=8, # needed?
37
- _block_idx=block_idx,
38
- _num_blocks=num_blocks,
39
- )
40
- )
41
-
42
-
43
- sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
44
- sLSTMStateType = dict[int, sLSTMLayerStateType]
45
-
46
-
47
- class sLSTMBlock(nn.Module):
48
- def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
49
- super().__init__()
50
- self.config = config
51
- self.norm_slstm = RMSNorm(
52
- num_features=config.embedding_dim,
53
- eps=config.norm_eps,
54
- use_weight=True,
55
- use_bias=config.use_bias,
56
- force_float32_reductions=config.norm_reduction_force_float32,
57
- )
58
- self.slstm_layer = init_cell(config, block_idx, num_blocks)
59
-
60
- self.norm_ffn = RMSNorm(
61
- num_features=config.embedding_dim,
62
- eps=config.norm_eps,
63
- use_weight=True,
64
- use_bias=config.use_bias,
65
- force_float32_reductions=config.norm_reduction_force_float32,
66
- )
67
- self.ffn = FeedForward(config)
68
-
69
- def forward(
70
- self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
71
- ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
72
- x_slstm = self.norm_slstm(x)
73
- if state is None:
74
- conv_state, slstm_state = None, None
75
- else:
76
- conv_state, slstm_state = state
77
- x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
78
- x = x + x_slstm
79
-
80
- x_ffn = self.norm_ffn(x)
81
- x_ffn = self.ffn(x_ffn)
82
- x = x + x_ffn
83
-
84
- return x, (state["conv_state"], state["slstm_state"])
85
-
86
-
87
- @dataclass
88
- class xLSTMMixedLargeConfig(xLSTMLargeConfig):
89
- slstm_at: list[int] = field(default_factory=list)
90
- all_slstm: bool = True
91
-
92
- @property
93
- def block_types(self):
94
- return ["s" if i in self.slstm_at or self.all_slstm else "m" for i in range(self.num_blocks)]
95
-
96
-
97
- class xLSTMMixedLargeBlockStack(nn.Module):
98
- config_class = xLSTMMixedLargeConfig
99
-
100
- def __init__(self, config: xLSTMMixedLargeConfig):
101
- super().__init__()
102
- self.config = config
103
-
104
- self.blocks = nn.ModuleList(
105
- [
106
- sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
107
- for i, t in enumerate(config.block_types)
108
- ]
109
- )
110
-
111
- if self.config.add_out_norm:
112
- self.out_norm = RMSNorm(
113
- num_features=config.embedding_dim,
114
- eps=config.norm_eps,
115
- use_weight=True,
116
- use_bias=config.use_bias,
117
- force_float32_reductions=config.norm_reduction_force_float32,
118
- )
119
- else:
120
- self.out_norm = nn.Identity()
121
-
122
- def forward(
123
- self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
124
- ) -> tuple[torch.Tensor, mLSTMStateType]:
125
- if state is None:
126
- state = {i: None for i in range(len(self.blocks))}
127
-
128
- for i, block in enumerate(self.blocks):
129
- block_state = state[i]
130
- x, block_state_new = block(x, block_state)
131
-
132
- if block_state is None:
133
- state[i] = block_state_new
134
- else:
135
- pass
136
- ## layer state is a tuple of three tensors: c, n, m
137
- ## we update the state in place in order to avoid creating new tensors
138
- # for state_idx in range(len(block_state)):
139
- # state[i][state_idx].copy_(block_state_new[state_idx])
140
-
141
- x = self.out_norm(x)
142
-
143
- return x, state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/models/predict_utils.py DELETED
@@ -1,72 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
-
5
- import logging
6
- from abc import abstractmethod
7
-
8
- import torch
9
-
10
- from ..api_adapter.forecast import ForecastModel
11
-
12
- LOGGER = logging.getLogger()
13
-
14
-
15
- class TensorQuantileUniPredictMixin(ForecastModel):
16
- @abstractmethod
17
- def _forecast_tensor(
18
- self,
19
- context: torch.Tensor,
20
- prediction_length: int | None = None,
21
- **predict_kwargs,
22
- ) -> torch.Tensor:
23
- pass
24
-
25
- @property
26
- @abstractmethod
27
- def quantiles(self):
28
- pass
29
-
30
- def _forecast_quantiles(
31
- self,
32
- context: torch.Tensor,
33
- prediction_length: int | None = None,
34
- quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
35
- output_device: str = "cpu",
36
- auto_cast: bool = False,
37
- **predict_kwargs,
38
- ) -> tuple[torch.Tensor, torch.Tensor]:
39
- with torch.autocast(device_type=self.device.type, enabled=auto_cast):
40
- predictions = self._forecast_tensor(
41
- context=context, prediction_length=prediction_length, **predict_kwargs
42
- ).detach()
43
- predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
44
-
45
- training_quantile_levels = list(self.quantiles)
46
-
47
- if set(quantile_levels).issubset(set(training_quantile_levels)):
48
- quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
49
- else:
50
- if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
51
- training_quantile_levels
52
- ):
53
- logging.warning(
54
- f"Requested quantile levels ({quantile_levels}) fall outside the range of "
55
- f"quantiles the model was trained on ({training_quantile_levels}). "
56
- "Predictions for out-of-range quantiles will be clamped to the nearest "
57
- "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
58
- "This can significantly impact prediction accuracy, especially for extreme quantiles. "
59
- )
60
- # Interpolate quantiles
61
- augmented_predictions = torch.cat(
62
- [predictions[..., [0]], predictions, predictions[..., [-1]]],
63
- dim=-1,
64
- )
65
- quantiles = torch.quantile(
66
- augmented_predictions,
67
- q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
68
- dim=-1,
69
- ).permute(1, 2, 0)
70
- # median as mean
71
- mean = predictions[:, :, training_quantile_levels.index(0.5)]
72
- return quantiles, mean
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tirex/models/tirex.py DELETED
@@ -1,231 +0,0 @@
1
- # Copyright (c) NXAI GmbH.
2
- # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
-
4
- import logging
5
- import warnings
6
- from contextlib import redirect_stdout
7
- from dataclasses import dataclass
8
-
9
- import lightning as L
10
- import torch
11
- from dacite import Config, from_dict
12
-
13
- from ..base import PretrainedModel
14
- from .components import PatchedUniTokenizer, ResidualBlock, StreamToLogger
15
- from .mixed_stack import skip_cuda, xLSTMMixedLargeBlockStack, xLSTMMixedLargeConfig
16
- from .predict_utils import TensorQuantileUniPredictMixin
17
-
18
- LOGGER = logging.getLogger()
19
-
20
-
21
@dataclass
class TiRexZeroConfig:
    """Typed view of the ``model_config`` dictionary consumed by ``TiRexZero``.

    ``TiRexZero.__init__`` builds an instance of this class from a plain dict
    via ``dacite.from_dict`` in strict mode.
    """

    # Number of time steps per input patch (token).
    input_patch_size: int
    # Number of time steps per output patch; TiRexZero asserts that it equals
    # ``input_patch_size``.
    output_patch_size: int
    # Quantile levels the output head predicts.
    quantiles: list[float]
    # Raw keyword arguments for the block stack; TiRexZero replaces this with
    # the resolved block configuration object after construction.
    block_kwargs: dict
    # Hidden width of the input/output residual embedding blocks.
    input_ff_dim: int
28
-
29
-
30
class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixin):
    """Patch-based quantile forecaster built on a mixed xLSTM block stack.

    A univariate context is tokenized into patches, each patch is embedded
    together with its missing-value mask, passed through the block stack, and
    projected to ``num_quantiles * output_patch_size`` values per token.
    """

    def __init__(self, model_config: dict, train_ctx_len=None):
        """Build the model from a plain configuration dictionary.

        Args:
            model_config: Dictionary matching the fields of ``TiRexZeroConfig``
                (parsed strictly, so unknown keys are rejected).
            train_ctx_len: Context length used as the default ``max_context``
                and as the lower bound for left-padding in ``_forecast_tensor``.
        """
        super().__init__()
        self.model_config: TiRexZeroConfig = from_dict(TiRexZeroConfig, model_config, config=Config(strict=True))
        assert self.model_config.input_patch_size == self.model_config.output_patch_size
        self.train_ctx_len = train_ctx_len

        # Block Stack
        self.nan_mask_value = 0
        self.block_stack, resolved_config = self.init_block(self.model_config.block_kwargs)
        # From here on ``block_kwargs`` holds the resolved config object (attribute
        # access such as ``.embedding_dim``), not the raw dictionary.
        self.model_config.block_kwargs = resolved_config

        # Input Layer: each token is the patch values concatenated with the mask,
        # hence twice the patch size.
        self.input_patch_embedding = ResidualBlock(
            in_dim=self.model_config.input_patch_size * 2,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.model_config.block_kwargs.embedding_dim,
        )
        self.tokenizer = PatchedUniTokenizer(
            patch_size=self.model_config.input_patch_size,
        )

        # Output Layer
        self.num_quantiles = len(self.model_config.quantiles)
        quantiles = torch.tensor(self.model_config.quantiles)
        self.register_buffer("quantiles", quantiles, persistent=False)

        self.output_patch_embedding = ResidualBlock(
            in_dim=self.model_config.block_kwargs.embedding_dim,
            h_dim=self.model_config.input_ff_dim,
            out_dim=self.num_quantiles * self.model_config.output_patch_size,
        )

        self.save_hyperparameters()

    @classmethod
    def register_name(cls):
        """Return the name under which this model class is registered."""
        return "TiRex"

    def init_block(self, block_kwargs):
        """Create the xLSTM block stack and return it with its resolved config.

        Args:
            block_kwargs: Dictionary convertible to ``xLSTMMixedLargeConfig``.

        Returns:
            Tuple ``(block_stack, config)``.
        """
        config = from_dict(xLSTMMixedLargeConfig, block_kwargs)
        log_redirect = StreamToLogger(LOGGER, logging.INFO)
        with redirect_stdout(log_redirect):  # avoid excessive print statements of sLSTM compile
            model = xLSTMMixedLargeBlockStack(config)
        return model, config

    @property
    def quantiles(self):
        """Tensor of quantile levels registered as a non-persistent buffer.

        The buffer is read explicitly. Until it has been registered an
        ``AttributeError`` is raised, so that ``hasattr(self, "quantiles")`` is
        false and ``register_buffer("quantiles", ...)`` in ``__init__`` succeeds.
        """
        try:
            return self._buffers["quantiles"]
        except KeyError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute 'quantiles'") from None

    def _forward_model_tokenized(
        self,
        input_token,
        input_mask=None,
        rollouts=1,
    ):
        """Run the network on already tokenized input.

        Args:
            input_token: Tensor of shape ``[batch_size, num_token, token_dim]``;
                NaN marks missing values.
            input_mask: Optional mask concatenated with ``input_token`` along the
                last dimension. When omitted it is derived as "not NaN".
            rollouts: Number of tokens to predict in one pass. For values above 1,
                ``rollouts - 1`` fully masked tokens are appended to the input.

        Returns:
            Tuple ``(quantile_preds, hidden_states)`` where ``quantile_preds`` has
            shape ``[batch_size, num_quantiles, num_token, output_patch_size]``.
        """
        input_mask = (
            input_mask.to(input_token.dtype)
            if input_mask is not None
            else torch.isnan(input_token).logical_not().to(input_token.dtype)
        )
        assert rollouts >= 1
        bs, numb_ctx_token, token_dim = input_token.shape
        if rollouts > 1:
            # Append placeholder tokens (NaN values, zero mask) for the extra rollout steps.
            pad_shape = (bs, rollouts - 1, token_dim)
            input_token = torch.cat(
                (
                    input_token,
                    torch.full(
                        pad_shape,
                        fill_value=torch.nan,
                        device=input_token.device,
                        dtype=input_token.dtype,
                    ),
                ),
                dim=1,
            )
            input_mask = torch.cat(
                (
                    input_mask,
                    torch.full(
                        pad_shape,
                        fill_value=False,
                        device=input_mask.device,
                        dtype=input_mask.dtype,
                    ),
                ),
                dim=1,
            )
        # Missing values are replaced only after the mask has been computed.
        input_token = torch.nan_to_num(input_token, nan=self.nan_mask_value)
        input_embeds = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))

        x = self.block_stack(input_embeds)
        # The block stack may return a tuple; its first element is the hidden states.
        hidden_states = x[0] if isinstance(x, tuple) else x

        quantile_preds = self.output_patch_embedding(hidden_states)
        quantile_preds = torch.unflatten(quantile_preds, -1, (self.num_quantiles, self.model_config.output_patch_size))
        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]

        return quantile_preds, hidden_states

    @torch.inference_mode()
    def _forecast_tensor(
        self,
        context: torch.Tensor,
        prediction_length: int | None = None,
        max_context: int | None = None,
        max_accelerated_rollout_steps: int = 1,
    ) -> torch.Tensor:
        """Forecast quantiles for a batch of contexts by iterated patch rollout.

        Args:
            context: Tensor whose last dimension is time; it is moved to the
                model device as float32.
            prediction_length: Number of future steps; defaults to one patch.
            max_context: Maximum number of trailing context steps kept; defaults
                to ``train_ctx_len``.
            max_accelerated_rollout_steps: Maximum number of patches predicted per
                forward pass.

        Returns:
            Float32 tensor with the per-step predictions concatenated along the
            last dimension and cut to ``prediction_length``.
        """
        predictions = []
        if prediction_length is None:
            prediction_length = self.tokenizer.patch_size
        # Ceil division: number of patches needed to cover prediction_length.
        remaining = -(prediction_length // -self.tokenizer.patch_size)
        if max_context is None:
            max_context = self.train_ctx_len
        # NOTE: requires ``train_ctx_len`` to be set; ``max`` fails on ``None``.
        min_context = max(self.train_ctx_len, max_context)

        context = context.to(
            device=self.device,
            dtype=torch.float32,
        )
        while remaining > 0:
            if context.shape[-1] > max_context:
                context = context[..., -max_context:]
            if context.shape[-1] < min_context:
                # Left-pad short contexts with NaN up to the minimum length.
                pad = torch.full(
                    (context.shape[0], min_context - context.shape[-1]),
                    fill_value=torch.nan,
                    device=context.device,
                    dtype=context.dtype,
                )
                context = torch.cat((pad, context), dim=1)
            tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
            fut_rollouts = min(remaining, max_accelerated_rollout_steps)
            # No extra ``torch.no_grad()`` needed: the method runs in inference mode.
            prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
            prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
            # [bs, num_quantiles, num_predicted_token, output_patch_size]
            prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
            prediction = prediction.flatten(start_dim=2)

            predictions.append(prediction)
            remaining -= fut_rollouts

            if remaining <= 0:
                break

            # Extend the context with NaN placeholders for the span just predicted.
            context = torch.cat([context, torch.full_like(prediction[:, 0, :], fill_value=torch.nan)], dim=-1)

        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
            dtype=torch.float32,
        )

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        """Re-layout sLSTM weights in ``checkpoint`` when CUDA kernels are skipped.

        When ``skip_cuda()`` is true, a warning is emitted and the sLSTM recurrent
        kernel and bias tensors in ``checkpoint["state_dict"]`` are reshaped and
        permuted; all other entries are copied unchanged. Otherwise the checkpoint
        is left untouched.
        """
        if not skip_cuda():
            return

        warnings.warn(
            "You use TiRex without sLSTM CUDA kernels! This might slow down the model considerably and might degrade forecasting results! "
            "Set the environment variable TIREX_NO_CUDA to 0 to avoid this!"
        )
        block_kwargs = self.model_config.block_kwargs
        head_dim = block_kwargs.embedding_dim // block_kwargs.num_heads
        num_gates = 4
        new_state_dict = {}
        for k, v in checkpoint["state_dict"].items():
            if "slstm_layer.slstm_cell._recurrent_kernel_" in k:
                # (heads, head_dim, gates, head_dim) -> (heads, gates * head_dim, head_dim)
                new_state_dict[k] = (
                    v.reshape(
                        block_kwargs.num_heads,
                        head_dim,
                        num_gates,
                        head_dim,
                    )
                    .permute(0, 2, 3, 1)
                    .reshape(
                        block_kwargs.num_heads,
                        num_gates * head_dim,
                        head_dim,
                    )
                )
            elif "slstm_layer.slstm_cell._bias_" in k:
                # (heads, gates, head_dim) -> gate-major flat vector
                new_state_dict[k] = (
                    v.reshape(block_kwargs.num_heads, num_gates, head_dim).permute(1, 0, 2).reshape(-1)
                )
            else:
                new_state_dict[k] = v
        checkpoint["state_dict"] = new_state_dict

    def after_load_from_checkpoint(self):
        """Warn when the CUDA sLSTM kernels are enabled but the model is not on CUDA."""
        if not skip_cuda() and self.device.type != "cuda":
            warnings.warn(
                f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE DEVICE ON A CUDA DEVICE (device type is {self.device.type})! "
                "This is not supported and calls to the model will likely lead to an error if you dont move your model to a CUDA device! "
                "If you want to run TiRex on CPU you need to disable sLSTM CUDA kernels but be aware of the downsides (see FAQ)"
            )