shinfxh commited on
Commit
a0eaa2d
·
0 Parent(s):

initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ *.egg-info/
7
+ *.egg
8
+ dist/
9
+ build/
10
+ .eggs/
11
+
12
+ # Environments
13
+ .env
14
+ .venv
15
+ venv/
16
+
17
+ # Data and results (large files)
18
+ data/
19
+ results/
20
+ gift-eval/
21
+
22
+ # Outputs
23
+ *.png
24
+ !figures/*.png
25
+
26
+ # Scripts
27
+ push.sh
28
+
29
+ # Tools
30
+ .mypy_cache/
31
+ .ruff_cache/
32
+ .pytest_cache/
33
+ .ipynb_checkpoints/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Reverso Authors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: pytorch
4
+ pipeline_tag: time-series-forecasting
5
+ tags:
6
+ - time-series
7
+ - forecasting
8
+ - zero-shot
9
+ - convolution
10
+ - deltanet
11
+ - flash-fft-conv
12
+ - flash-linear-attention
13
+ ---
14
+
15
+ <h1 align="center">Reverso</h1>
16
+
17
+ <h3 align="center">
18
+ Efficient time-series foundation models for zero-shot forecasting.
19
+ </h3>
20
+
21
+ <p align="center">
22
+ <a href="https://arxiv.org/abs/2602.17634">Paper</a> •
23
+ <a href="https://github.com/shinfxh/reverso">GitHub</a> •
24
+ <a href="https://huggingface.co/shinfxh/reverso">Hugging Face</a>
25
+ </p>
26
+
27
+ <p align="center">
28
+ By combining long convolutions with linear RNN layers, Reverso matches the performance of transformer-based models that are over <b>100x larger</b>.
29
+ </p>
30
+
31
+ ## Key Results
32
+
33
+ <p align="center">
34
+ <img src="figures/gift_eval_pareto_overall.png" width="800">
35
+ </p>
36
+
37
+ Evaluated on [Gift-Eval](https://github.com/SalesforceAIResearch/gift-eval), a comprehensive time-series forecasting benchmark spanning 97 tasks within 23 datasets across 7 domains.
38
+
39
+ | Model | Params | Gift-Eval MASE |
40
+ |---|---|---|
41
+ | **Reverso** | 2.6M | **0.711** |
42
+ | Reverso-Small | 550K | 0.726 |
43
+ | Reverso-Nano | 200K | 0.760 |
44
+
45
+ For reference, Xihe-Max (1.5B params) achieves 0.711 and TimesFM-2.5 (200M params) achieves 0.705 on the same benchmark.
46
+
47
+ ## Installation
48
+
49
+ ```bash
50
+ pip install -r requirements.txt
51
+ pip install --no-build-isolation git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
52
+ pip install --no-build-isolation git+https://github.com/HazyResearch/flash-fft-conv.git
53
+ pip install -e .
54
+ ```
55
+
56
+ ### Requirements
57
+
58
+ - Python >= 3.11
59
+ - PyTorch 2.6.0
60
+ - CUDA-compatible GPU
61
+ - [FlashFFTConv](https://github.com/HazyResearch/flash-fft-conv)
62
+ - [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention)
63
+
64
+ ## Model Architecture
65
+
66
+ <p align="center">
67
+ <img src="figures/new_arch.png" width="800">
68
+ </p>
69
+
70
+ Reverso uses a hybrid architecture that interleaves:
71
+ 1. **Long convolution layers** ([FlashFFTConv](https://github.com/HazyResearch/flash-fft-conv)) with gated short convolutions
72
+ 2. **DeltaNet layers** for modeling sequential dependencies
73
+ 3. **MLP layers** for channel mixing
74
+ 4. **Attention-based decoder head** for producing the final forecast
75
+
76
+ Input sequences are normalized to [0, 1] and processed point-wise (no patching). The model predicts 48 time steps at a time and rolls out autoregressively for longer horizons.
77
+
78
+ | Config | Params | Layers | d_model |
79
+ |---|---|---|---|
80
+ | Reverso | 2.6M | 8 | 128 |
81
+ | Reverso-Small | 550K | 4 | 64 |
82
+ | Reverso-Nano | 200K | 2 | 32 |
83
+
84
+ The modeling code is in [`reverso/`](reverso/).
85
+
86
+ ## Quick Start
87
+
88
+ ```python
89
+ import torch
90
+ from reverso import load_model, forecast
91
+
92
+ model, cfg = load_model(
93
+ "checkpoints/reverso_small/checkpoint.pth",
94
+ "checkpoints/reverso_small/args.json",
95
+ device="cuda",
96
+ )
97
+
98
+ context = torch.full((1, 2048, 1), 5.0, device="cuda") # (batch, seq_len, 1)
99
+ predictions = forecast(
100
+ model, context,
101
+ prediction_length=96,
102
+ seq_len=cfg.seq_len,
103
+ output_token_len=cfg.output_token_len,
104
+ )
105
+ print(predictions.shape) # (1, 96, 1)
106
+ ```
107
+
108
+ ## Examples
109
+
110
+ Install the example dependencies first:
111
+
112
+ ```bash
113
+ pip install -r example/requirements.txt
114
+ ```
115
+
116
+ ### Forecast Demo
117
+
118
+ Run Reverso on synthetic signals (constant, linear, sine, sawtooth, square):
119
+
120
+ ```bash
121
+ python example/forecast_demo.py --signal all
122
+ ```
123
+
124
+ Use `--signal sine` to run a single signal, or `--list` to see all options.
125
+
126
+ ### Gift-Eval Benchmark
127
+
128
+ To reproduce the benchmark results, first follow the [Gift-Eval setup instructions](https://github.com/SalesforceAIResearch/gift-eval) to install the package and download the data. Then run:
129
+
130
+ ```bash
131
+ python example/eval_gift.py \
132
+ --checkpoint checkpoints/reverso_small/checkpoint.pth \
133
+ --output-dir results/ \
134
+ --force-flip-invariance
135
+ ```
136
+
137
+ > **Note:** Dependencies within Gift-Eval may conflict with those in Reverso. If you encounter issues, try upgrading `huggingface_hub`:
138
+ > ```bash
139
+ > pip install --upgrade huggingface_hub
140
+ > ```
141
+ > **Note:** While running this benchmark, it is recommended to use flip invariance, but this requires two forward passes of the model. The inference speed is also not fully optimized and could be further sped up.
142
+
143
+ ## Available Checkpoints
144
+
145
+ | Model | Status | Path |
146
+ |---|---|---|
147
+ | Reverso-Small (550K) | Available | `checkpoints/reverso_small/` |
148
+ | Reverso (2.6M) | Coming soon | — |
149
+ | Reverso-Nano (200K) | Coming soon | — |
150
+
151
+ ## Citation
152
+
153
+ ```bibtex
154
+ @misc{fu2026reversoefficienttimeseries,
155
+ title={Reverso: Efficient Time Series Foundation Models for Zero-shot Forecasting},
156
+ author={Xinghong Fu and Yanhong Li and Georgios Papaioannou and Yoon Kim},
157
+ year={2026},
158
+ eprint={2602.17634},
159
+ archivePrefix={arXiv},
160
+ primaryClass={cs.LG},
161
+ url={https://arxiv.org/abs/2602.17634},
162
+ }
163
+ ```
164
+
165
+ ## License
166
+
167
+ This project is licensed under the [MIT License](LICENSE).
checkpoints/reverso/args.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "seq_len": 2048,
3
+ "input_token_len": 2048,
4
+ "output_token_len": 48,
5
+ "e_layers": 8,
6
+ "d_model": 128,
7
+ "d_intermediate": 256,
8
+ "output_bottleneck_dim": 48,
9
+ "expand_v": 1.0,
10
+ "state_weaving": 1,
11
+ "gating_kernel_size": 3,
12
+ "main_module": "conv,attn,conv,attn",
13
+ "use_norm": true,
14
+ "learn_bias": 1
15
+ }
checkpoints/reverso_nano/args.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "seq_len": 2048,
3
+ "input_token_len": 2048,
4
+ "output_token_len": 48,
5
+ "e_layers": 2,
6
+ "d_model": 32,
7
+ "d_intermediate": 256,
8
+ "output_bottleneck_dim": 48,
9
+ "expand_v": 1.0,
10
+ "state_weaving": 1,
11
+ "gating_kernel_size": 3,
12
+ "main_module": "conv,attn,conv,attn",
13
+ "use_norm": true,
14
+ "learn_bias": 1
15
+ }
checkpoints/reverso_small/args.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "seq_len": 2048,
3
+ "input_token_len": 2048,
4
+ "output_token_len": 48,
5
+ "e_layers": 4,
6
+ "d_model": 64,
7
+ "d_intermediate": 256,
8
+ "output_bottleneck_dim": 48,
9
+ "expand_v": 1.0,
10
+ "state_weaving": 1,
11
+ "gating_kernel_size": 3,
12
+ "main_module": "conv,attn,conv,attn",
13
+ "use_norm": true,
14
+ "learn_bias": 1
15
+ }
checkpoints/reverso_small/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a475a728a7d6a625bead27b283abd4e7746d3a7099086e5a8cf6a23bb647502b
3
+ size 2252946
config/dataset_properties.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"m4_yearly": {"domain": "Econ/Fin", "frequency": "A", "num_variates": 1}, "m4_quarterly": {"domain": "Econ/Fin", "frequency": "Q", "num_variates": 1}, "m4_monthly": {"domain": "Econ/Fin", "frequency": "M", "num_variates": 1}, "m4_weekly": {"domain": "Econ/Fin", "frequency": "W", "num_variates": 1}, "m4_daily": {"domain": "Econ/Fin", "frequency": "D", "num_variates": 1}, "m4_hourly": {"domain": "Econ/Fin", "frequency": "H", "num_variates": 1}, "electricity": {"domain": "Energy", "frequency": "W", "num_variates": 1}, "ett1": {"domain": "Energy", "frequency": "W", "num_variates": 7}, "ett2": {"domain": "Energy", "frequency": "W", "num_variates": 7}, "solar": {"domain": "Energy", "frequency": "W", "num_variates": 1}, "hospital": {"domain": "Healthcare", "frequency": "M", "num_variates": 1}, "covid_deaths": {"domain": "Healthcare", "frequency": "D", "num_variates": 1}, "us_births": {"domain": "Healthcare", "frequency": "M", "num_variates": 1}, "saugeen": {"domain": "Nature", "frequency": "M", "num_variates": 1}, "temperature_rain": {"domain": "Nature", "frequency": "D", "num_variates": 1}, "kdd_cup_2018": {"domain": "Nature", "frequency": "D", "num_variates": 1}, "jena_weather": {"domain": "Nature", "frequency": "D", "num_variates": 21}, "car_parts": {"domain": "Sales", "frequency": "M", "num_variates": 1}, "restaurant": {"domain": "Sales", "frequency": "D", "num_variates": 1}, "hierarchical_sales": {"domain": "Sales", "frequency": "W-WED", "num_variates": 1}, "loop_seattle": {"domain": "Transport", "frequency": "D", "num_variates": 1}, "sz_taxi": {"domain": "Transport", "frequency": "H", "num_variates": 1}, "m_dense": {"domain": "Transport", "frequency": "D", "num_variates": 1}, "bitbrains_fast_storage": {"domain": "Web/CloudOps", "frequency": "H", "num_variates": 2}, "bitbrains_rnd": {"domain": "Web/CloudOps", "frequency": "H", "num_variates": 2}, "bizitobs_application": {"domain": "Web/CloudOps", "frequency": "10S", "num_variates": 2}, "bizitobs_service": 
{"domain": "Web/CloudOps", "frequency": "10S", "num_variates": 2}, "bizitobs_l2c": {"domain": "Web/CloudOps", "frequency": "H", "num_variates": 7}}
config/downsample_factors.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "bizitobs_l2c/5t/medium": 7,
3
+ "bizitobs_l2c/5t/long": 7
4
+ }
example/eval_gift.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GiftEval evaluation script for Reverso.
3
+ """
4
+ import os
5
+ import json
6
+ import math
7
+ import argparse
8
+ import csv
9
+ from types import SimpleNamespace
10
+ from typing import List, Optional, Tuple
11
+ from datetime import datetime
12
+
13
+ import numpy as np
14
+ import torch
15
+ import pandas as pd
16
+
17
+ from reverso.forecast import load_checkpoint
18
+
19
+ try:
20
+ from torch.cuda.amp import autocast as autocast_fp
21
+ except Exception:
22
+ autocast_fp = None
23
+
24
def numpy_fill(arr: np.ndarray) -> np.ndarray:
    """Forward-fill NaNs along axis 1 of a 2-D array.

    Each NaN is replaced by the nearest non-NaN value to its left in the
    same row. Leading NaNs (no valid value to the left) resolve to column
    0 of that row and therefore stay NaN if column 0 is NaN.
    """
    nan_mask = np.isnan(arr)
    # Column index of the last valid entry at or before each position:
    # valid columns keep their own index, NaN columns contribute 0, and a
    # running maximum propagates the latest valid index rightward.
    col_idx = np.where(nan_mask, 0, np.arange(arr.shape[1]))
    np.maximum.accumulate(col_idx, axis=1, out=col_idx)
    row_idx = np.arange(arr.shape[0])[:, None]
    return arr[row_idx, col_idx]
30
+
31
+
32
class ReversoPredictor:
    """GiftEval predictor for reverso.Model.

    Wraps a Reverso checkpoint behind the GluonTS predictor interface:
    ``predict`` consumes GiftEval test entries and yields ``SampleForecast``
    objects (point forecasts repeated ``num_samples`` times).
    """

    def __init__(
        self,
        prediction_length: int,
        checkpoint_path: Optional[str] = None,
        device: str = "cuda",
        seq_len: int = 2048,
        input_token_len: int = 2048,
        output_token_len: int = 48,
        e_layers: int = 8,
        d_model: int = 128,
        d_intermediate: int = 512,
        output_bottleneck_dim: int = 48,
        expand_v: float = 1.0,
        state_weaving: int = 1,
        gating_kernel_size: int = 3,
        main_module: str = "conv,attn,conv,attn,conv,attn,conv,attn",
        num_samples: int = 100,
        batch_size: int = 256,
        use_amp: int = 1,
        downsample_factor: int = 1,
        force_flip_invariance: bool = False,
    ):
        # Fall back to CPU up front if CUDA is not available at all.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.prediction_length = int(prediction_length)
        self.num_samples = int(num_samples)
        self.batch_size = int(batch_size)
        self.seq_len = int(seq_len)
        self.input_token_len = int(input_token_len)
        self.output_token_len = int(output_token_len)
        self.use_amp = int(use_amp)
        self.downsample_factor = int(downsample_factor)
        self.force_flip_invariance = bool(force_flip_invariance)

        # Model hyperparameters packed the way reverso.Model expects.
        # NOTE(review): `e_layers` is accepted above but NOT forwarded here —
        # presumably the depth is derived from `main_module`; confirm against
        # reverso.model.Model before relying on `e_layers`.
        args = SimpleNamespace(
            input_token_len=self.input_token_len,
            output_token_len=self.output_token_len,
            seq_len=self.seq_len,
            d_model=int(d_model),
            d_intermediate=int(d_intermediate),
            use_norm=True,
            learn_bias=1,
            output_bottleneck_dim=int(output_bottleneck_dim),
            expand_v=float(expand_v),
            state_weaving=int(state_weaving),
            gating_kernel_size=int(gating_kernel_size),
            main_module=str(main_module),
        )

        from reverso import model as model_impl
        try:
            self.model = model_impl.Model(args).to(self.device)
        except RuntimeError as e:
            # CUDA can be "available" yet unusable (e.g. driver mismatch);
            # retry on CPU with AMP disabled, re-raise anything else.
            if "CUDA" in str(e):
                print(f"CUDA not usable ({e}); falling back to CPU.")
                self.device = torch.device("cpu")
                self.use_amp = 0
                self.model = model_impl.Model(args).to(self.device)
            else:
                raise
        self.model.eval()

        if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            self._load_checkpoint(checkpoint_path)
        else:
            # Deliberate soft failure: evaluation still runs, but results
            # are meaningless with random weights.
            print("Warning: checkpoint_path not provided or file not found. Using randomly initialized weights.")

    def _load_checkpoint(self, ckpt_path: str):
        """Load model weights from *ckpt_path* onto the active device."""
        load_checkpoint(self.model, ckpt_path, device=str(self.device))

    def _downsample_if_needed(self, series: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Stride-subsample a 1-D series; returns (series, factor used)."""
        cur = series
        if self.downsample_factor > 1:
            cur = cur[::self.downsample_factor]
        return cur, self.downsample_factor

    def _left_pad_to_len(self, arr: np.ndarray, target_len: int) -> Tuple[np.ndarray, int]:
        """Left-pad (or left-truncate) a 1-D array to exactly *target_len*.

        Padding repeats the first observed value (0.0 for an empty series).
        Returns (padded array, number of padded positions).
        """
        if arr.shape[0] >= target_len:
            return arr[-target_len:], 0
        pad_len = target_len - arr.shape[0]
        fill_value = arr[0] if arr.shape[0] > 0 else 0.0
        padding = np.full((pad_len,), fill_value, dtype=arr.dtype)
        return np.concatenate([padding, arr], axis=0), pad_len

    def _prepare_context_matrix(self, context: List[torch.Tensor]) -> Tuple[torch.Tensor, List[int]]:
        """Build the (batch, seq_len, 1) context tensor for the model.

        Per series: optional downsampling, left-padding to seq_len, then NaN
        imputation (linear interpolation when >= 2 valid points, otherwise
        forward fill, with a backward-fill pass and a final 0.0 fallback).
        Also returns the downsample factor applied to each series.
        """
        xs = []
        downsample_factors = []

        for c in context:
            cur, downsample_factor = self._downsample_if_needed(c)
            downsample_factors.append(downsample_factor)

            cur_np = cur.detach().cpu().float().numpy()
            cur_np, _ = self._left_pad_to_len(cur_np, self.seq_len)

            x2d = cur_np[None, :]
            x_interp = np.copy(x2d)
            series = x2d[0]
            if np.any(np.isnan(series)):
                valid_mask = ~np.isnan(series)
                if np.sum(valid_mask) >= 2:
                    # Enough anchors: interpolate across the gaps.
                    valid_indices = np.where(valid_mask)[0]
                    valid_values = series[valid_mask]
                    x_interp[0] = np.interp(np.arange(len(series)), valid_indices, valid_values)
                else:
                    # 0 or 1 valid points: forward fill is all we can do.
                    x_interp = numpy_fill(x2d)
            # Safety net: forward fill, then backward fill, then zeros.
            ff = numpy_fill(x_interp)
            bf = np.flip(numpy_fill(np.flip(x_interp, axis=1)), axis=1)
            x_imp = np.where(np.isnan(ff), bf, ff)
            x_imp = np.where(np.isnan(x_imp), 0.0, x_imp)
            xs.append(x_imp[0])

        x = torch.tensor(np.stack(xs), device=self.device, dtype=torch.float32).unsqueeze(-1)
        return x, downsample_factors

    def _decode_autoregressive(self, init_ctx: torch.Tensor, use_bf16: bool, downsample_factors: List[int]) -> torch.Tensor:
        """Roll the model forward in output_token_len-sized chunks.

        Enough steps are run to cover the longest (downsample-adjusted)
        prediction length in the batch; each chunk is appended to the
        context for the next step. Returns (B, steps * output_token_len, C).
        """
        B, _, C = init_ctx.shape
        roll_len = int(self.output_token_len)

        # Each series only needs prediction_length // factor model steps.
        target_pred_lens = [int(self.prediction_length) // int(max(1, df)) for df in downsample_factors]
        max_target_pred_len = max(target_pred_lens)
        steps = math.ceil(max_target_pred_len / roll_len)
        preds: List[torch.Tensor] = []
        batch_ctx = init_ctx

        # Time-mark inputs are unused by this model; pass zeros.
        y_mark = torch.zeros(B, self.output_token_len, C, device=self.device, dtype=init_ctx.dtype)

        for _ in range(steps):
            x_in = batch_ctx[:, -self.seq_len:, :]
            x_mark = torch.zeros_like(x_in)

            if autocast_fp is not None and self.use_amp and use_bf16:
                try:
                    # bf16 autocast when supported; fall back to fp32 on
                    # any autocast failure rather than aborting the run.
                    with autocast_fp(dtype=torch.bfloat16):
                        outputs = self.model(x_in, x_mark, y_mark)
                except Exception:
                    outputs = self.model(x_in, x_mark, y_mark)
            else:
                outputs = self.model(x_in, x_mark, y_mark)

            out_chunk = outputs[:, -self.output_token_len:, :]
            take_chunk = out_chunk[:, :roll_len, :]
            preds.append(take_chunk)
            # Feed the prediction back in for the next autoregressive step.
            batch_ctx = torch.cat([batch_ctx, take_chunk], dim=1)

        return torch.cat(preds, dim=1)

    @torch.no_grad()
    def predict(self, test_data_input, use_bf16_if_available: bool = True):
        """Forecast every entry in *test_data_input* (GiftEval test split).

        Returns a list of gluonts ``SampleForecast``s; the point forecast is
        replicated num_samples times since the model is deterministic.
        """
        from gluonts.itertools import batcher
        from gluonts.model.forecast import SampleForecast

        forecasts = []
        use_bf16 = bool(
            use_bf16_if_available
            and self.device.type == "cuda"
            and torch.cuda.is_available()
            and torch.cuda.is_bf16_supported()
        )

        for batch in batcher(test_data_input, batch_size=self.batch_size):
            targets = [torch.tensor(entry["target"], dtype=torch.float32) for entry in batch]
            batch_ctx, downsample_factors = self._prepare_context_matrix(targets)

            pred_pos = self._decode_autoregressive(batch_ctx, use_bf16, downsample_factors)
            if self.force_flip_invariance:
                # Average f(x) with -f(-x): two forward passes, enforcing
                # sign-flip symmetry of the forecast.
                pred_neg = self._decode_autoregressive(-batch_ctx, use_bf16, downsample_factors)
                pred_full = 0.5 * (pred_pos - pred_neg)
            else:
                pred_full = pred_pos

            # Repair any NaNs the rollout produced by forward-filling.
            if torch.isnan(pred_full).any():
                pf_2d = pred_full.squeeze(-1).detach().cpu().numpy()
                pf_2d = numpy_fill(pf_2d)
                pred_full = torch.tensor(pf_2d, device=pred_full.device, dtype=pred_full.dtype).unsqueeze(-1)

            pred_full_np = pred_full.float().squeeze(-1).detach().cpu().numpy()
            pred_list = []
            for i in range(len(downsample_factors)):
                df = downsample_factors[i]
                target_pred_len = int(self.prediction_length) // int(max(1, df))
                seq = pred_full_np[i, :target_pred_len]
                if df > 1:
                    # Upsample back to the original resolution by linear
                    # interpolation so the forecast length matches.
                    old_len = len(seq)
                    new_len = int(self.prediction_length)
                    seq = np.interp(np.linspace(0, 1, new_len), np.linspace(0, 1, old_len), seq)
                pred_list.append(seq)
            pred_full_np = np.array(pred_list)

            for i, ts in enumerate(batch):
                # Forecast starts right after the observed target.
                start_date = ts["start"] + len(ts["target"])
                samples = np.repeat(pred_full_np[i][None, :], self.num_samples, axis=0)
                forecasts.append(SampleForecast(samples=samples, start_date=start_date))

        return forecasts
229
+
230
+
231
# ==========================
# GiftEval evaluation script
# ==========================
from gluonts.ev.metrics import (
    MAE, MAPE, MASE, MSE, MSIS, ND, NRMSE, RMSE, SMAPE,
    MeanWeightedSumQuantileLoss,
)

# Metric suite reported per dataset/term. MSE is evaluated both on the
# mean forecast and on the 0.5 quantile (median) forecast.
METRICS = [
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
]

# Maps GiftEval dataset names to the keys used in
# config/dataset_properties.json (lookup falls back to the lowercased name).
PRETTY_NAMES = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Whitespace-separated dataset specs ("name" or "name/frequency") evaluated
# on the short-horizon term.
SHORT_DATASETS = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/W ett2/D jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"

# Subset of datasets that additionally get medium- and long-horizon terms.
MED_LONG_DATASETS = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
263
+
264
+
265
def main():
    """Run Reverso over the GiftEval benchmark and append results to a CSV.

    For every (dataset, term) combination: load the dataset, build a fresh
    ReversoPredictor with the horizon-specific prediction length, evaluate
    with gluonts, and append one CSV row of metrics.
    """
    parser = argparse.ArgumentParser(description="Run Reverso GiftEval across datasets")
    parser.add_argument("--checkpoint", default='checkpoints/reverso_small/checkpoint.pth', help="Path to model checkpoint")
    parser.add_argument("--json_path", default='checkpoints/reverso_small/args.json', help="Path to JSON file with model config overrides")
    parser.add_argument("--output-dir", dest="output_dir", default='results/reverso_small', help="Output directory for results")
    parser.add_argument("--dataset", default=None, help="Filter to specific dataset (substring match)")
    parser.add_argument("--term", default=None, choices=["short", "medium", "long"], help="Filter to specific term")
    parser.add_argument("--force-flip-invariance", dest="force_flip_invariance", action="store_true",
                        help="Average f(x) with -f(-x) for flip invariance")
    parser.add_argument("--downsample-json", dest="downsample_json",
                        default="config/downsample_factors.json",
                        help="Path to JSON with downsample factors per dataset/term")
    args = parser.parse_args()

    # Load model config from JSON if provided
    json_cfg = {}
    if args.json_path and os.path.isfile(args.json_path):
        with open(args.json_path, "r") as f:
            json_cfg = json.load(f)

    # Model hyperparameters: JSON overrides fall back to the full-size
    # Reverso defaults.
    SEQ_LEN = int(json_cfg.get("seq_len", 2048))
    INPUT_TOKEN_LEN = int(json_cfg.get("input_token_len", 2048))
    OUTPUT_TOKEN_LEN = int(json_cfg.get("output_token_len", 48))
    E_LAYERS = int(json_cfg.get("e_layers", 8))
    D_MODEL = int(json_cfg.get("d_model", 128))
    D_INTERMEDIATE = int(json_cfg.get("d_intermediate", 512))
    OUTPUT_BOTTLENECK_DIM = int(json_cfg.get("output_bottleneck_dim", 48))
    EXPAND_V = float(json_cfg.get("expand_v", 1.0))
    STATE_WEAVING = int(json_cfg.get("state_weaving", 1))
    GATING_KERNEL_SIZE = int(json_cfg.get("gating_kernel_size", 3))
    MAIN_MODULE = str(json_cfg.get("main_module", "conv,attn,conv,attn,conv,attn,conv,attn"))

    # Fixed run settings (not exposed as CLI flags).
    DEVICE = "cuda"
    NUM_SAMPLES = 100
    BATCH_SIZE = 256
    USE_AMP = 1

    # Optional per-(dataset/freq/term) downsample factors; missing file
    # simply means no downsampling anywhere.
    downsample_map = {}
    if os.path.isfile(args.downsample_json):
        with open(args.downsample_json, "r") as f:
            downsample_map = json.load(f)

    # Setup datasets
    all_datasets = sorted(set(SHORT_DATASETS.split() + MED_LONG_DATASETS.split()))
    med_long_set = set(MED_LONG_DATASETS.split())
    all_terms = ["short", "medium", "long"]

    with open("config/dataset_properties.json", "r") as f:
        dataset_properties = json.load(f)

    # Point gift_eval at <repo root>/data unless GIFT_EVAL is already set.
    os.environ.setdefault("GIFT_EVAL", os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data"))

    if args.dataset:
        all_datasets = [ds for ds in all_datasets if args.dataset in ds]
        if not all_datasets:
            print(f"No datasets found matching '{args.dataset}'")
            return

    if args.term:
        all_terms = [args.term]

    # Setup output: one timestamped CSV per run, header written up front.
    output_dir = args.output_dir or os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(output_dir, f"all_results_{timestamp}.csv")

    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "dataset", "model",
            "eval_metrics/MSE[mean]", "eval_metrics/MSE[0.5]",
            "eval_metrics/MAE[0.5]", "eval_metrics/MASE[0.5]",
            "eval_metrics/MAPE[0.5]", "eval_metrics/sMAPE[0.5]",
            "eval_metrics/MSIS", "eval_metrics/RMSE[mean]",
            "eval_metrics/NRMSE[mean]", "eval_metrics/ND[0.5]",
            "eval_metrics/mean_weighted_sum_quantile_loss",
            "domain", "num_variates",
        ])

    # Imported here so that config/CLI errors surface before the heavier
    # gluonts/gift_eval imports.
    from gluonts.model import evaluate_model
    from gluonts.time_feature import get_seasonality
    from gift_eval.data import Dataset

    print(f"Evaluating {len(all_datasets)} datasets, terms: {all_terms}")
    print(f"Flip invariance: {args.force_flip_invariance}")

    for ds_num, ds_name in enumerate(all_datasets):
        # Dataset spec is either "name" or "name/frequency"; the properties
        # key is the pretty-name mapping of the lowercased base name.
        if "/" in ds_name:
            ds_key = PRETTY_NAMES.get(ds_name.split("/")[0].lower(), ds_name.split("/")[0].lower())
            ds_freq = ds_name.split("/")[1]
        else:
            ds_key = PRETTY_NAMES.get(ds_name.lower(), ds_name.lower())
            ds_freq = dataset_properties[ds_key]["frequency"]

        print(f"[{ds_num + 1}/{len(all_datasets)}] {ds_name}")

        for term in all_terms:
            # Medium/long horizons only exist for the MED_LONG subset.
            if term in ("medium", "long") and ds_name not in med_long_set:
                continue

            ds_config = f"{ds_key}/{ds_freq}/{term}"
            # Probe once to learn the target dimension, then reload with
            # to_univariate set accordingly.
            probe = Dataset(name=ds_name, term=term, to_univariate=False)
            to_univariate = probe.target_dim != 1
            dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate)
            season_length = get_seasonality(dataset.freq)

            downsample_key = f"{ds_key}/{ds_freq}/{term}".lower()
            downsample_factor = downsample_map.get(downsample_key, 1)

            info = f"  {term}: {len(dataset.test_data)} instances"
            if downsample_factor > 1:
                info += f", downsample={downsample_factor}"
            print(info)

            # A fresh predictor per (dataset, term) because the prediction
            # length and downsample factor differ between configurations.
            predictor = ReversoPredictor(
                prediction_length=dataset.prediction_length,
                checkpoint_path=args.checkpoint,
                device=DEVICE,
                seq_len=SEQ_LEN,
                input_token_len=INPUT_TOKEN_LEN,
                output_token_len=OUTPUT_TOKEN_LEN,
                e_layers=E_LAYERS,
                d_model=D_MODEL,
                d_intermediate=D_INTERMEDIATE,
                output_bottleneck_dim=OUTPUT_BOTTLENECK_DIM,
                expand_v=EXPAND_V,
                state_weaving=STATE_WEAVING,
                gating_kernel_size=GATING_KERNEL_SIZE,
                main_module=MAIN_MODULE,
                num_samples=NUM_SAMPLES,
                batch_size=BATCH_SIZE,
                use_amp=USE_AMP,
                downsample_factor=downsample_factor,
                force_flip_invariance=args.force_flip_invariance,
            )

            res = evaluate_model(
                predictor,
                test_data=dataset.test_data,
                metrics=METRICS,
                batch_size=BATCH_SIZE,
                axis=None,
                mask_invalid_label=True,
                allow_nan_forecast=False,
                seasonality=season_length,
            )

            # Append one row per config so partial runs still leave results.
            with open(csv_path, "a", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([
                    ds_config, "reverso",
                    res["MSE[mean]"][0], res["MSE[0.5]"][0],
                    res["MAE[0.5]"][0], res["MASE[0.5]"][0],
                    res["MAPE[0.5]"][0], res["sMAPE[0.5]"][0],
                    res["MSIS"][0], res["RMSE[mean]"][0],
                    res["NRMSE[mean]"][0], res["ND[0.5]"][0],
                    res["mean_weighted_sum_quantile_loss"][0],
                    dataset_properties[ds_key]["domain"],
                    dataset_properties[ds_key]["num_variates"],
                ])

    print(f"\nResults saved to: {csv_path}")
429
+
430
+
431
# Script entry point.
if __name__ == "__main__":
    main()
example/forecast_demo.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Demo: autoregressive forecasting on simple synthetic signals.
3
+
4
+ Run all signals: python example/forecast_demo.py --signal all
5
+ Run one signal: python example/forecast_demo.py --signal sine
6
+ List available: python example/forecast_demo.py --list
7
+ """
8
+ import argparse
9
+
10
+ import numpy as np
11
+ import torch
12
+ import matplotlib.pyplot as plt
13
+
14
+ from reverso.forecast import load_model, forecast
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Signal generators — each returns float32 array of length n
19
+ # ---------------------------------------------------------------------------
20
+
21
def signal_constant(n: int) -> np.ndarray:
    """Flat line at 5.0 — sanity check that the model tracks a constant level."""
    flat = np.empty(n, dtype=np.float32)
    flat.fill(5.0)
    return flat
23
+
24
+
25
def signal_linear(n: int) -> np.ndarray:
    """Straight ramp from 0 to 40 spanning the whole window."""
    ramp = np.linspace(0, 40, n)
    return ramp.astype(np.float32)
27
+
28
+
29
+
30
def signal_sine(n: int) -> np.ndarray:
    """Sine wave with period 200 samples and amplitude 5."""
    steps = np.arange(n, dtype=np.float64)
    wave = 5.0 * np.sin(2 * np.pi * steps / 200)
    return wave.astype(np.float32)
33
+
34
+
35
def signal_sawtooth(n: int) -> np.ndarray:
    """Rising sawtooth with period 200 samples, values in [0, 10)."""
    period = 200
    steps = np.arange(n, dtype=np.float64)
    ramp = 10.0 * (steps % period) / period
    return ramp.astype(np.float32)
39
+
40
+
41
def signal_square(n: int) -> np.ndarray:
    """Square wave with period 200 samples and levels ±5 (0 at sign changes)."""
    steps = np.arange(n, dtype=np.float64)
    levels = np.sign(np.sin(2 * np.pi * steps / 200))
    return (5.0 * levels).astype(np.float32)
44
+
45
+
46
+
47
# Registry mapping CLI signal names to (plot label, generator) pairs.
SIGNALS = dict(
    constant=("Constant", signal_constant),
    linear=("Linear trend", signal_linear),
    sine=("Sine wave", signal_sine),
    sawtooth=("Sawtooth wave", signal_sawtooth),
    square=("Square wave", signal_square),
)
54
+
55
+
56
def run_one(name, label, gen_fn, model, cfg, device, context_length, prediction_length,
            output_dir, flip_invariance=False):
    """Forecast one synthetic signal and save a context/truth/forecast plot.

    Args:
        name: Signal key, used for the output filename.
        label: Human-readable name shown in the plot title.
        gen_fn: Generator returning a float32 array of the requested length.
        model: Loaded Reverso model.
        cfg: Config namespace providing ``seq_len`` and ``output_token_len``.
        device: Torch device string for the context tensor.
        context_length: Steps fed to the model as history.
        prediction_length: Steps to forecast beyond the context.
        output_dir: Directory the PNG is written into.
        flip_invariance: If True, average f(x) with -f(-x).
    """
    total_len = context_length + prediction_length
    series = gen_fn(total_len)
    history = series[:context_length]
    future_truth = series[context_length:]

    # (L,) -> (1, L, 1): add the batch and channel dims forecast() expects.
    history_t = torch.tensor(history, device=device).unsqueeze(0).unsqueeze(-1)
    forward_pred = forecast(
        model, history_t, prediction_length,
        seq_len=cfg.seq_len, output_token_len=cfg.output_token_len,
    )
    if not flip_invariance:
        combined = forward_pred
    else:
        # Enforce odd symmetry by averaging f(x) with -f(-x).
        mirrored_pred = forecast(
            model, -history_t, prediction_length,
            seq_len=cfg.seq_len, output_token_len=cfg.output_token_len,
        )
        combined = 0.5 * (forward_pred - mirrored_pred)
    point_forecast = combined[0, :, 0].float().cpu().numpy()

    hist_axis = np.arange(context_length)
    fut_axis = np.arange(context_length, total_len)

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.plot(hist_axis, history, color="steelblue", label="Context")
    ax.plot(fut_axis, future_truth, color="gray", linestyle="--", label="Ground truth")
    ax.plot(fut_axis, point_forecast, color="tomato", label="Forecast")
    ax.axvline(context_length, color="black", linestyle=":", alpha=0.5)
    ax.set_xlabel("Time step")
    ax.set_ylabel("Value")
    ax.set_title(f"Reverso: {label}")
    ax.legend()
    fig.tight_layout()
    out_path = f"{output_dir}/{name}_forecast.png"
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f" {label:25s} -> {out_path}")
95
+
96
+
97
def main():
    """CLI entry point: parse args, load the model, forecast selected signals."""
    parser = argparse.ArgumentParser(description="Reverso forecast demo on synthetic signals")
    parser.add_argument("--signal", type=str, default="all",
                        help="Signal name, or 'all' to run every signal")
    parser.add_argument("--list", action="store_true", help="List available signals and exit")
    parser.add_argument("--checkpoint", type=str,
                        default="checkpoints/reverso_small/checkpoint.pth")
    parser.add_argument("--args-json", type=str,
                        default="checkpoints/reverso_small/args.json")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--context-length", type=int, default=2048)
    parser.add_argument("--prediction-length", type=int, default=480)
    parser.add_argument("--output-dir", type=str, default="example")
    parser.add_argument("--flip-invariance", action="store_true",
                        help="Average f(x) with -f(-x) for flip invariance")
    args = parser.parse_args()

    # --list short-circuits before any model loading.
    if args.list:
        for key, (label, _) in SIGNALS.items():
            print(f" {key:15s} {label}")
        return

    model, cfg = load_model(args.checkpoint, args.args_json, args.device)
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")

    if args.signal == "all":
        selected = list(SIGNALS.items())
    else:
        if args.signal not in SIGNALS:
            print(f"Unknown signal '{args.signal}'. Use --list to see options.")
            return
        selected = [(args.signal, SIGNALS[args.signal])]

    for key, (label, generator) in selected:
        run_one(key, label, generator, model, cfg, args.device,
                args.context_length, args.prediction_length, args.output_dir,
                args.flip_invariance)
134
+
135
+
136
# Script entry point: python example/forecast_demo.py [--signal NAME]
if __name__ == "__main__":
    main()
example/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ matplotlib
2
+ gluonts~=0.15.1
3
+ python-dotenv==1.0.0
figures/gift_eval_pareto_overall.png ADDED

Git LFS Details

  • SHA256: dd1e9d6b26355dd1e773948df13806d844c599b0f9a454c15b02f091b45ad8a6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.08 MB
figures/new_arch.png ADDED

Git LFS Details

  • SHA256: 71d34bc959984b3d89e1c2a83e6842a5a05606401db0d52867e21c70f1f683eb
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
pyproject.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=77"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "reverso"
7
+ version = "0.1.0"
8
+ description = "Efficient time-series foundation models for zero-shot forecasting"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ requires-python = ">=3.11"
12
+ dependencies = [
13
+ "torch>=2.6.0",
14
+ "numpy",
15
+ "pandas",
16
+ "flash-linear-attention",
17
+ ]
18
+
19
+ [project.optional-dependencies]
20
+ examples = [
21
+ "matplotlib",
22
+ "gluonts~=0.15.1",
23
+ "python-dotenv>=1.0.0",
24
+ ]
25
+
26
+ [project.urls]
27
+ Repository = "https://github.com/shinfxh/reverso"
28
+
29
+ [tool.setuptools.packages.find]
30
+ include = ["reverso*"]
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==2.6.0
2
+ numpy
3
+ pandas
4
+ flash-linear-attention
reverso/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """Reverso: Efficient time-series foundation models for zero-shot forecasting."""
2
+
3
+ from reverso.model import Model
4
+ from reverso.forecast import forecast, load_checkpoint, load_model
5
+
6
+ __all__ = ["Model", "forecast", "load_checkpoint", "load_model"]
reverso/forecast.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Autoregressive forecasting utilities for Reverso."""
2
+ import json
3
+ import math
4
+ from types import SimpleNamespace
5
+
6
+ import torch
7
+
8
+ from reverso.model import Model
9
+
10
+
11
def load_checkpoint(model: "Model", checkpoint_path: str, device: str = "cuda"):
    """Load a checkpoint into an existing Reverso model, in place.

    Handles common checkpoint formats (raw state_dict, or dicts keyed by
    "model_state_dict", "state_dict", "model", "ema", "ema_state_dict")
    and strips the "module." prefix left by DDP.

    Args:
        model: Module the weights are loaded into.
        checkpoint_path: Path to a ``torch.save``-produced file.
        device: ``map_location`` passed to ``torch.load``.
    """
    try:
        # Prefer the safe loader: tensor/state-dict checkpoints do not need
        # arbitrary pickle execution (weights_only=True rejects it).
        raw = torch.load(checkpoint_path, map_location=device, weights_only=True)
    except Exception:
        # Fallback for checkpoints that pickle extra objects (optimizer
        # state, argparse namespaces, ...). Only use with trusted files.
        raw = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = raw
    if isinstance(raw, dict):
        for key in ("model_state_dict", "state_dict", "model", "ema", "ema_state_dict"):
            if key in raw and isinstance(raw[key], dict):
                state_dict = raw[key]
                break
    # DDP wraps parameter names under "module."; strip it so names line up.
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=True)
27
+
28
+
29
def load_model(checkpoint_path: str, args_json: str, device: str = "cuda"):
    """Load a Reverso model from a checkpoint and config JSON.

    Args:
        checkpoint_path: Path to the saved model weights.
        args_json: Path to the JSON file of training arguments.
        device: Device the model is moved to.

    Returns:
        (model, cfg) tuple, with the model already in eval mode.
    """
    with open(args_json) as fh:
        config = SimpleNamespace(**json.load(fh))

    net = Model(config).to(device)
    load_checkpoint(net, checkpoint_path, device)
    net.eval()
    return net, config
42
+
43
+
44
@torch.no_grad()
def forecast(
    model: "Model",
    context: torch.Tensor,
    prediction_length: int,
    seq_len: int,
    output_token_len: int,
    use_amp: bool = True,
) -> torch.Tensor:
    """Autoregressive multi-step forecast.

    Repeatedly feeds the last ``seq_len`` steps of the (growing) context to
    the model, appends the newest ``output_token_len`` predicted steps, and
    stops once at least ``prediction_length`` steps have been produced.

    Args:
        model: Reverso Model (already on the target device, in eval mode).
        context: Input context tensor of shape (B, L, 1).
        prediction_length: Number of future steps to predict.
        seq_len: Model's context window length (cfg.seq_len).
        output_token_len: Steps produced per model call (cfg.output_token_len).
        use_amp: Whether to use bfloat16 autocast (CUDA only; ignored on CPU).

    Returns:
        Predictions tensor of shape (B, prediction_length, 1).
    """
    device = context.device
    B, _, C = context.shape
    steps = math.ceil(prediction_length / output_token_len)

    batch_ctx = context
    preds = []

    # Decoder time marks are unused by the model but required positionally.
    y_mark = torch.zeros(B, output_token_len, C, device=device, dtype=context.dtype)

    for _ in range(steps):
        x_in = batch_ctx[:, -seq_len:, :]
        x_mark = torch.zeros_like(x_in)

        # torch.cuda.amp.autocast is deprecated; use the device-agnostic
        # torch.amp.autocast (the project pins torch>=2.6).
        if use_amp and device.type == "cuda":
            with torch.amp.autocast("cuda", dtype=torch.bfloat16):
                outputs = model(x_in, x_mark, y_mark)
        else:
            outputs = model(x_in, x_mark, y_mark)

        # Keep only the newest output token and roll it back into the context.
        out_chunk = outputs[:, -output_token_len:, :]
        preds.append(out_chunk)
        batch_ctx = torch.cat([batch_ctx, out_chunk], dim=1)

    # Trim any overshoot from the final step.
    return torch.cat(preds, dim=1)[:, :prediction_length, :]
reverso/model.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reverso: conv-attention hybrid for time series forecasting.
3
+ """
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from flashfftconv import FlashFFTConv
8
+ from fla.layers import DeltaNet
9
+ from typing import Any
10
+
11
+
12
class Gating(nn.Module):
    """Per-channel sigmoid gate from a small temporal conv stack.

    A depthwise conv mixes each channel over a short temporal window, a SiLU
    adds nonlinearity, and a pointwise conv mixes channels; the sigmoid
    squashes the result into (0, 1).
    """

    def __init__(self, channels, temporal_kernel=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=temporal_kernel,
                      padding=temporal_kernel // 2, groups=channels),
            nn.SiLU(),
            nn.Conv1d(channels, channels, kernel_size=1),
        )

    def forward(self, x):
        """Map (B, C, L) activations to gate values in (0, 1), same shape."""
        gate_logits = self.net(x)
        return gate_logits.sigmoid()
24
+
25
+
26
class MLPBlock(nn.Module):
    """Residual MLP applied over the channel dimension.

    Accepts (B, C, L) channels-first input (transposed internally so the
    linears act on C) or plain 2-D input. With a positive ``d_intermediate``
    the MLP has two layers with a ReLU between them, otherwise one linear.
    The residual is projected only when the width changes.
    """

    def __init__(self, d_in, d_out, d_intermediate=0):
        super().__init__()
        self.norm = nn.LayerNorm(d_out)
        if d_intermediate and d_intermediate > 0:
            self.linear = nn.Linear(d_in, d_intermediate)
            self.linear_final = nn.Linear(d_intermediate, d_out)
        else:
            self.linear = nn.Linear(d_in, d_out)
            self.linear_final = nn.Identity()
        self.activation = nn.ReLU()
        # Identity skip when shapes already match.
        self.skip_linear = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()

    def forward(self, x):
        channels_first = x.ndim == 3
        if channels_first:
            x = x.permute(0, 2, 1)  # (B, C, L) -> (B, L, C)
        shortcut = self.skip_linear(x)
        hidden = self.activation(self.linear(x))
        hidden = self.norm(self.linear_final(hidden))
        out = shortcut + hidden
        if channels_first:
            out = out.permute(0, 2, 1)  # back to (B, C, L)
        return out
51
+
52
+
53
class CNNBlock(nn.Module):
    """Residual long-convolution block built on a shared FlashFFTConv.

    Holds one learned kernel per channel spanning the full sequence; the
    input is sigmoid pre-gated before the FFT convolution. The conv input is
    cast to bfloat16 (FlashFFTConv is constructed with that dtype).
    """

    def __init__(self, channels, seq_len, flashfftconv, gating_kernel_size=3):
        super().__init__()
        self.flashfftconv = flashfftconv
        # Full-length per-channel kernel, kept in float32.
        self.k = nn.Parameter(torch.randn(channels, seq_len, dtype=torch.float32))
        self.pregate = Gating(channels, gating_kernel_size)
        self.activation = nn.ReLU()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        shortcut = x
        conv_in = x.contiguous().to(torch.bfloat16)
        # Gate computed in fp32 for stability, then cast to match the conv.
        gate = self.pregate(conv_in.float()).to(conv_in.dtype)
        # The conv API requires a post-gate; all-ones disables it.
        identity_gate = torch.ones_like(conv_in)
        y = self.flashfftconv(conv_in, self.k, pregate=gate, postgate=identity_gate)
        y = self.activation(y)
        # LayerNorm acts on the channel axis: flip to (B, L, C) and back.
        y = self.norm(y.transpose(1, 2)).transpose(1, 2)
        return y + shortcut
74
+
75
+
76
class AttentionBlock(nn.Module):
    """Residual DeltaNet (linear attention) block.

    Input arrives channels-first (B, C, L) from the surrounding conv blocks;
    it is transposed to (B, L, C) for DeltaNet, normed, residually added,
    and transposed back. When ``state_weaving`` is enabled on an
    intermediate layer, the last time step's features are folded into the
    first time step before attention — presumably to leak end-of-sequence
    state forward; confirm intent against the training code.
    """

    def __init__(self, d_model, expand_v, state_weaving=False, is_intermediate=False):
        super().__init__()
        self.state_weaving = state_weaving
        self.is_intermediate = is_intermediate
        # DeltaNet from flash-linear-attention, in chunked (parallel) mode.
        self.attention = DeltaNet(
            mode='chunk',
            d_model=d_model,
            expand_k=1.0,
            expand_v=expand_v,
            num_heads=4,
            use_beta=True,
            use_gate=False,
            use_short_conv=True,
            conv_size=4,
            allow_neg_eigval=False,
            qk_activation='silu',
            qk_norm='l2',
            layer_idx=0,
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # (B, C, L) -> (B, L, C): DeltaNet consumes sequence-major input.
        x_t = x.transpose(1, 2)
        residual = x_t
        if self.state_weaving and self.is_intermediate:
            # Clone so the in-place add below cannot mutate the residual.
            x_t = x_t.clone()
            x_t[:, 0:1, :] = x_t[:, 0:1, :] + x_t[:, -1:, :]
        attn_out = self.attention(hidden_states=x_t, attention_mask=None)
        # DeltaNet may return (output, aux, ...); keep only the output.
        if isinstance(attn_out, tuple):
            out = attn_out[0]
        else:
            out = attn_out
        out = self.norm(out)
        out = out + residual
        out = out.transpose(1, 2)
        return out
113
+
114
+
115
class Model(nn.Module):
    """
    Reverso: conv-deltanet hybrid for time series forecasting.

    Pipeline: optional per-series min-max scaling, scalar-to-d_model
    embedding, a stack of conv/attention blocks (each followed by a channel
    MLP), then a small cross-attention decode head projecting back to one
    value per output position.
    """
    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.input_token_len = configs.input_token_len
        self.output_token_len = configs.output_token_len
        self.d_model = configs.d_model
        self.use_norm = configs.use_norm

        # Scalar -> d_model embedding: (B, L, 1) -> (B, L, d_model).
        self.embedding = nn.Linear(1, self.d_model, bias=False)
        # One FFT-conv instance shared by every CNNBlock.
        self.shared_flashfftconv = FlashFFTConv(self.seq_len, dtype=torch.bfloat16)

        d_intermediate = configs.d_intermediate
        expand_v = getattr(configs, 'expand_v', 1.0)
        state_weaving = getattr(configs, 'state_weaving', False)
        gating_kernel_size = getattr(configs, 'gating_kernel_size', 3)
        # main_module is a comma-separated layer spec, e.g. "conv,attn,conv".
        module_list = [m.strip() for m in configs.main_module.split(',')]
        e_layers = len(module_list)

        layers = []
        for i, layer_type in enumerate(module_list):
            if layer_type == 'conv':
                layers.append(CNNBlock(
                    self.d_model, self.seq_len, self.shared_flashfftconv, gating_kernel_size,
                ))
            elif layer_type == 'attn':
                # State weaving applies only to layers strictly inside the stack.
                is_intermediate = (i > 0) and (i < e_layers - 1)
                layers.append(AttentionBlock(
                    self.d_model, expand_v, state_weaving, is_intermediate,
                ))
            else:
                raise ValueError(f'Invalid layer type: {layer_type}')
            # Every conv/attn block is followed by a channel MLP.
            layers.append(MLPBlock(self.d_model, self.d_model, d_intermediate))
        self.layers = nn.Sequential(*layers)

        output_bottleneck_dim = getattr(configs, 'output_bottleneck_dim', self.output_token_len)
        # Head projects the temporal axis to the output query positions;
        # assumes the input length equals input_token_len — confirm upstream.
        self.head = nn.Linear(self.input_token_len, output_bottleneck_dim, bias=configs.learn_bias)
        self.simple_q_proj = nn.Linear(self.d_model, self.d_model)
        self.key_proj = nn.Linear(self.d_model, self.d_model)
        self.value_proj = nn.Linear(self.d_model, self.d_model)
        self.out_proj = nn.Linear(self.d_model, 1)

    def forward(self, x, x_mark=None, y_mark=None, **kwargs: Any):
        """Forecast from context ``x`` of shape (B, L, C).

        ``x_mark``/``y_mark`` are accepted for interface compatibility but
        unused here. Returns a tensor of shape (B, output_bottleneck_dim, 1).
        """
        B, L, C = x.shape

        if self.use_norm:
            # Min-max scaling to [0, 1] per series. NOTE(review): the names
            # `means`/`stdev` are misleading — they hold the min and the
            # range, not mean/std; de-normalization below matches this.
            x_min = x.min(1, keepdim=True)[0].detach()
            x_max = x.max(1, keepdim=True)[0].detach()
            x_range = torch.clamp(x_max - x_min, min=1e-5).detach()
            x = (x - x_min) / x_range
            means = x_min
            stdev = x_range

        # (B, L, 1) -> (B, d_model, L): the block stack is channels-first.
        x = self.embedding(x).transpose(1, 2)

        dec_out = self.layers(x)

        # Build queries by projecting the temporal axis down to the output
        # positions, then attend them over the encoded time steps.
        temp_out = self.head(dec_out).permute(0, 2, 1)
        q = self.simple_q_proj(temp_out)

        dec_out_perm = dec_out.permute(0, 2, 1)
        k = self.key_proj(dec_out_perm)
        v = self.value_proj(dec_out_perm)

        attn = F.scaled_dot_product_attention(q, k, v)
        dec_out = self.out_proj(attn)

        if self.use_norm:
            # Undo the min-max scaling (range * y + min).
            dec_out = dec_out * stdev + means

        return dec_out

    def forecast(self, x, x_mark=None, y_mark=None, **kwargs):
        """Alias for :meth:`forward`, kept for external API compatibility."""
        return self.forward(x, x_mark, y_mark, **kwargs)