shinfxh commited on
Commit ·
a0eaa2d
0
Parent(s):
initial commit
Browse files- .gitattributes +36 -0
- .gitignore +33 -0
- LICENSE +21 -0
- README.md +167 -0
- checkpoints/reverso/args.json +15 -0
- checkpoints/reverso_nano/args.json +15 -0
- checkpoints/reverso_small/args.json +15 -0
- checkpoints/reverso_small/checkpoint.pth +3 -0
- config/dataset_properties.json +1 -0
- config/downsample_factors.json +4 -0
- example/eval_gift.py +432 -0
- example/forecast_demo.py +137 -0
- example/requirements.txt +3 -0
- figures/gift_eval_pareto_overall.png +3 -0
- figures/new_arch.png +3 -0
- pyproject.toml +30 -0
- requirements.txt +4 -0
- reverso/__init__.py +6 -0
- reverso/forecast.py +93 -0
- reverso/model.py +191 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
figures/*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
*.egg-info/
|
| 7 |
+
*.egg
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
.eggs/
|
| 11 |
+
|
| 12 |
+
# Environments
|
| 13 |
+
.env
|
| 14 |
+
.venv
|
| 15 |
+
venv/
|
| 16 |
+
|
| 17 |
+
# Data and results (large files)
|
| 18 |
+
data/
|
| 19 |
+
results/
|
| 20 |
+
gift-eval/
|
| 21 |
+
|
| 22 |
+
# Outputs
|
| 23 |
+
*.png
|
| 24 |
+
!figures/*.png
|
| 25 |
+
|
| 26 |
+
# Scripts
|
| 27 |
+
push.sh
|
| 28 |
+
|
| 29 |
+
# Tools
|
| 30 |
+
.mypy_cache/
|
| 31 |
+
.ruff_cache/
|
| 32 |
+
.pytest_cache/
|
| 33 |
+
.ipynb_checkpoints/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Reverso Authors
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
library_name: pytorch
|
| 4 |
+
pipeline_tag: time-series-forecasting
|
| 5 |
+
tags:
|
| 6 |
+
- time-series
|
| 7 |
+
- forecasting
|
| 8 |
+
- zero-shot
|
| 9 |
+
- convolution
|
| 10 |
+
- deltanet
|
| 11 |
+
- flash-fft-conv
|
| 12 |
+
- flash-linear-attention
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
<h1 align="center">Reverso</h1>
|
| 16 |
+
|
| 17 |
+
<h3 align="center">
|
| 18 |
+
Efficient time-series foundation models for zero-shot forecasting.
|
| 19 |
+
</h3>
|
| 20 |
+
|
| 21 |
+
<p align="center">
|
| 22 |
+
<a href="https://arxiv.org/abs/2602.17634">Paper</a> •
|
| 23 |
+
<a href="https://github.com/shinfxh/reverso">GitHub</a> •
|
| 24 |
+
<a href="https://huggingface.co/shinfxh/reverso">Hugging Face</a>
|
| 25 |
+
</p>
|
| 26 |
+
|
| 27 |
+
<p align="center">
|
| 28 |
+
By combining long convolutions with linear RNN layers, Reverso matches the performance of transformer-based models that are over <b>100x larger</b>.
|
| 29 |
+
</p>
|
| 30 |
+
|
| 31 |
+
## Key Results
|
| 32 |
+
|
| 33 |
+
<p align="center">
|
| 34 |
+
<img src="figures/gift_eval_pareto_overall.png" width="800">
|
| 35 |
+
</p>
|
| 36 |
+
|
| 37 |
+
Evaluated on [Gift-Eval](https://github.com/SalesforceAIResearch/gift-eval), a comprehensive time-series forecasting benchmark spanning 97 tasks within 23 datasets across 7 domains.
|
| 38 |
+
|
| 39 |
+
| Model | Params | Gift-Eval MASE |
|
| 40 |
+
|---|---|---|
|
| 41 |
+
| **Reverso** | 2.6M | **0.711** |
|
| 42 |
+
| Reverso-Small | 550K | 0.726 |
|
| 43 |
+
| Reverso-Nano | 200K | 0.760 |
|
| 44 |
+
|
| 45 |
+
For reference, Xihe-Max (1.5B params) achieves 0.711 and TimesFM-2.5 (200M params) achieves 0.705 on the same benchmark.
|
| 46 |
+
|
| 47 |
+
## Installation
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
pip install -r requirements.txt
|
| 51 |
+
pip install --no-build-isolation git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
|
| 52 |
+
pip install --no-build-isolation git+https://github.com/HazyResearch/flash-fft-conv.git
|
| 53 |
+
pip install -e .
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
### Requirements
|
| 57 |
+
|
| 58 |
+
- Python >= 3.11
|
| 59 |
+
- PyTorch 2.6.0
|
| 60 |
+
- CUDA-compatible GPU
|
| 61 |
+
- [FlashFFTConv](https://github.com/HazyResearch/flash-fft-conv)
|
| 62 |
+
- [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention)
|
| 63 |
+
|
| 64 |
+
## Model Architecture
|
| 65 |
+
|
| 66 |
+
<p align="center">
|
| 67 |
+
<img src="figures/new_arch.png" width="800">
|
| 68 |
+
</p>
|
| 69 |
+
|
| 70 |
+
Reverso uses a hybrid architecture that interleaves:
|
| 71 |
+
1. **Long convolution layers** ([FlashFFTConv](https://github.com/HazyResearch/flash-fft-conv)) with gated short convolutions
|
| 72 |
+
2. **DeltaNet layers** for modeling sequential dependencies
|
| 73 |
+
3. **MLP layers** for channel mixing
|
| 74 |
+
4. **Attention-based decoder head** for producing the final forecast
|
| 75 |
+
|
| 76 |
+
Input sequences are normalized to [0, 1] and processed point-wise (no patching). The model predicts 48 time steps at a time and rolls out autoregressively for longer horizons.
|
| 77 |
+
|
| 78 |
+
| Config | Params | Layers | d_model |
|
| 79 |
+
|---|---|---|---|
|
| 80 |
+
| Reverso | 2.6M | 8 | 128 |
|
| 81 |
+
| Reverso-Small | 550K | 4 | 64 |
|
| 82 |
+
| Reverso-Nano | 200K | 2 | 32 |
|
| 83 |
+
|
| 84 |
+
The modeling code is in [`reverso/`](reverso/).
|
| 85 |
+
|
| 86 |
+
## Quick Start
|
| 87 |
+
|
| 88 |
+
```python
|
| 89 |
+
import torch
|
| 90 |
+
from reverso import load_model, forecast
|
| 91 |
+
|
| 92 |
+
model, cfg = load_model(
|
| 93 |
+
"checkpoints/reverso_small/checkpoint.pth",
|
| 94 |
+
"checkpoints/reverso_small/args.json",
|
| 95 |
+
device="cuda",
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
context = torch.full((1, 2048, 1), 5.0, device="cuda") # (batch, seq_len, 1)
|
| 99 |
+
predictions = forecast(
|
| 100 |
+
model, context,
|
| 101 |
+
prediction_length=96,
|
| 102 |
+
seq_len=cfg.seq_len,
|
| 103 |
+
output_token_len=cfg.output_token_len,
|
| 104 |
+
)
|
| 105 |
+
print(predictions.shape) # (1, 96, 1)
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
## Examples
|
| 109 |
+
|
| 110 |
+
Install the example dependencies first:
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
pip install -r example/requirements.txt
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
### Forecast Demo
|
| 117 |
+
|
| 118 |
+
Run Reverso on synthetic signals (constant, linear, sine, sawtooth, square):
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
python example/forecast_demo.py --signal all
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
Use `--signal sine` to run a single signal, or `--list` to see all options.
|
| 125 |
+
|
| 126 |
+
### Gift-Eval Benchmark
|
| 127 |
+
|
| 128 |
+
To reproduce the benchmark results, first follow the [Gift-Eval setup instructions](https://github.com/SalesforceAIResearch/gift-eval) to install the package and download the data. Then run:
|
| 129 |
+
|
| 130 |
+
```bash
|
| 131 |
+
python example/eval_gift.py \
|
| 132 |
+
--checkpoint checkpoints/reverso_small/checkpoint.pth \
|
| 133 |
+
--output-dir results/ \
|
| 134 |
+
--force-flip-invariance
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
> **Note:** Dependencies within Gift-Eval may conflict with those in Reverso. If you encounter issues, try upgrading `huggingface_hub`:
|
| 138 |
+
> ```bash
|
| 139 |
+
> pip install --upgrade huggingface_hub
|
| 140 |
+
> ```
|
| 141 |
+
> **Note:** While running this benchmark, it is recommended to use flip invariance, but this requires two forward passes of the model. The inference speed is also not fully optimized and could be further sped up.
|
| 142 |
+
|
| 143 |
+
## Available Checkpoints
|
| 144 |
+
|
| 145 |
+
| Model | Status | Path |
|
| 146 |
+
|---|---|---|
|
| 147 |
+
| Reverso-Small (550K) | Available | `checkpoints/reverso_small/` |
|
| 148 |
+
| Reverso (2.6M) | Coming soon | — |
|
| 149 |
+
| Reverso-Nano (200K) | Coming soon | — |
|
| 150 |
+
|
| 151 |
+
## Citation
|
| 152 |
+
|
| 153 |
+
```bibtex
|
| 154 |
+
@misc{fu2026reversoefficienttimeseries,
|
| 155 |
+
title={Reverso: Efficient Time Series Foundation Models for Zero-shot Forecasting},
|
| 156 |
+
author={Xinghong Fu and Yanhong Li and Georgios Papaioannou and Yoon Kim},
|
| 157 |
+
year={2026},
|
| 158 |
+
eprint={2602.17634},
|
| 159 |
+
archivePrefix={arXiv},
|
| 160 |
+
primaryClass={cs.LG},
|
| 161 |
+
url={https://arxiv.org/abs/2602.17634},
|
| 162 |
+
}
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
## License
|
| 166 |
+
|
| 167 |
+
This project is licensed under the [MIT License](LICENSE).
|
checkpoints/reverso/args.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seq_len": 2048,
|
| 3 |
+
"input_token_len": 2048,
|
| 4 |
+
"output_token_len": 48,
|
| 5 |
+
"e_layers": 8,
|
| 6 |
+
"d_model": 128,
|
| 7 |
+
"d_intermediate": 256,
|
| 8 |
+
"output_bottleneck_dim": 48,
|
| 9 |
+
"expand_v": 1.0,
|
| 10 |
+
"state_weaving": 1,
|
| 11 |
+
"gating_kernel_size": 3,
|
| 12 |
+
"main_module": "conv,attn,conv,attn",
|
| 13 |
+
"use_norm": true,
|
| 14 |
+
"learn_bias": 1
|
| 15 |
+
}
|
checkpoints/reverso_nano/args.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seq_len": 2048,
|
| 3 |
+
"input_token_len": 2048,
|
| 4 |
+
"output_token_len": 48,
|
| 5 |
+
"e_layers": 2,
|
| 6 |
+
"d_model": 32,
|
| 7 |
+
"d_intermediate": 256,
|
| 8 |
+
"output_bottleneck_dim": 48,
|
| 9 |
+
"expand_v": 1.0,
|
| 10 |
+
"state_weaving": 1,
|
| 11 |
+
"gating_kernel_size": 3,
|
| 12 |
+
"main_module": "conv,attn,conv,attn",
|
| 13 |
+
"use_norm": true,
|
| 14 |
+
"learn_bias": 1
|
| 15 |
+
}
|
checkpoints/reverso_small/args.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seq_len": 2048,
|
| 3 |
+
"input_token_len": 2048,
|
| 4 |
+
"output_token_len": 48,
|
| 5 |
+
"e_layers": 4,
|
| 6 |
+
"d_model": 64,
|
| 7 |
+
"d_intermediate": 256,
|
| 8 |
+
"output_bottleneck_dim": 48,
|
| 9 |
+
"expand_v": 1.0,
|
| 10 |
+
"state_weaving": 1,
|
| 11 |
+
"gating_kernel_size": 3,
|
| 12 |
+
"main_module": "conv,attn,conv,attn",
|
| 13 |
+
"use_norm": true,
|
| 14 |
+
"learn_bias": 1
|
| 15 |
+
}
|
checkpoints/reverso_small/checkpoint.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a475a728a7d6a625bead27b283abd4e7746d3a7099086e5a8cf6a23bb647502b
|
| 3 |
+
size 2252946
|
config/dataset_properties.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"m4_yearly": {"domain": "Econ/Fin", "frequency": "A", "num_variates": 1}, "m4_quarterly": {"domain": "Econ/Fin", "frequency": "Q", "num_variates": 1}, "m4_monthly": {"domain": "Econ/Fin", "frequency": "M", "num_variates": 1}, "m4_weekly": {"domain": "Econ/Fin", "frequency": "W", "num_variates": 1}, "m4_daily": {"domain": "Econ/Fin", "frequency": "D", "num_variates": 1}, "m4_hourly": {"domain": "Econ/Fin", "frequency": "H", "num_variates": 1}, "electricity": {"domain": "Energy", "frequency": "W", "num_variates": 1}, "ett1": {"domain": "Energy", "frequency": "W", "num_variates": 7}, "ett2": {"domain": "Energy", "frequency": "W", "num_variates": 7}, "solar": {"domain": "Energy", "frequency": "W", "num_variates": 1}, "hospital": {"domain": "Healthcare", "frequency": "M", "num_variates": 1}, "covid_deaths": {"domain": "Healthcare", "frequency": "D", "num_variates": 1}, "us_births": {"domain": "Healthcare", "frequency": "M", "num_variates": 1}, "saugeen": {"domain": "Nature", "frequency": "M", "num_variates": 1}, "temperature_rain": {"domain": "Nature", "frequency": "D", "num_variates": 1}, "kdd_cup_2018": {"domain": "Nature", "frequency": "D", "num_variates": 1}, "jena_weather": {"domain": "Nature", "frequency": "D", "num_variates": 21}, "car_parts": {"domain": "Sales", "frequency": "M", "num_variates": 1}, "restaurant": {"domain": "Sales", "frequency": "D", "num_variates": 1}, "hierarchical_sales": {"domain": "Sales", "frequency": "W-WED", "num_variates": 1}, "loop_seattle": {"domain": "Transport", "frequency": "D", "num_variates": 1}, "sz_taxi": {"domain": "Transport", "frequency": "H", "num_variates": 1}, "m_dense": {"domain": "Transport", "frequency": "D", "num_variates": 1}, "bitbrains_fast_storage": {"domain": "Web/CloudOps", "frequency": "H", "num_variates": 2}, "bitbrains_rnd": {"domain": "Web/CloudOps", "frequency": "H", "num_variates": 2}, "bizitobs_application": {"domain": "Web/CloudOps", "frequency": "10S", "num_variates": 2}, "bizitobs_service": {"domain": "Web/CloudOps", "frequency": "10S", "num_variates": 2}, "bizitobs_l2c": {"domain": "Web/CloudOps", "frequency": "H", "num_variates": 7}}
|
config/downsample_factors.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bizitobs_l2c/5t/medium": 7,
|
| 3 |
+
"bizitobs_l2c/5t/long": 7
|
| 4 |
+
}
|
example/eval_gift.py
ADDED
|
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GiftEval evaluation script for Reverso.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import argparse
|
| 8 |
+
import csv
|
| 9 |
+
from types import SimpleNamespace
|
| 10 |
+
from typing import List, Optional, Tuple
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
from reverso.forecast import load_checkpoint
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from torch.cuda.amp import autocast as autocast_fp
|
| 21 |
+
except Exception:
|
| 22 |
+
autocast_fp = None
|
| 23 |
+
|
| 24 |
+
def numpy_fill(arr: np.ndarray) -> np.ndarray:
|
| 25 |
+
mask = np.isnan(arr)
|
| 26 |
+
idx = np.where(~mask, np.arange(mask.shape[1]), 0)
|
| 27 |
+
np.maximum.accumulate(idx, axis=1, out=idx)
|
| 28 |
+
out = arr[np.arange(idx.shape[0])[:, None], idx]
|
| 29 |
+
return out
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ReversoPredictor:
|
| 33 |
+
"""GiftEval predictor for reverso.Model."""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
prediction_length: int,
|
| 38 |
+
checkpoint_path: Optional[str] = None,
|
| 39 |
+
device: str = "cuda",
|
| 40 |
+
seq_len: int = 2048,
|
| 41 |
+
input_token_len: int = 2048,
|
| 42 |
+
output_token_len: int = 48,
|
| 43 |
+
e_layers: int = 8,
|
| 44 |
+
d_model: int = 128,
|
| 45 |
+
d_intermediate: int = 512,
|
| 46 |
+
output_bottleneck_dim: int = 48,
|
| 47 |
+
expand_v: float = 1.0,
|
| 48 |
+
state_weaving: int = 1,
|
| 49 |
+
gating_kernel_size: int = 3,
|
| 50 |
+
main_module: str = "conv,attn,conv,attn,conv,attn,conv,attn",
|
| 51 |
+
num_samples: int = 100,
|
| 52 |
+
batch_size: int = 256,
|
| 53 |
+
use_amp: int = 1,
|
| 54 |
+
downsample_factor: int = 1,
|
| 55 |
+
force_flip_invariance: bool = False,
|
| 56 |
+
):
|
| 57 |
+
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
| 58 |
+
self.prediction_length = int(prediction_length)
|
| 59 |
+
self.num_samples = int(num_samples)
|
| 60 |
+
self.batch_size = int(batch_size)
|
| 61 |
+
self.seq_len = int(seq_len)
|
| 62 |
+
self.input_token_len = int(input_token_len)
|
| 63 |
+
self.output_token_len = int(output_token_len)
|
| 64 |
+
self.use_amp = int(use_amp)
|
| 65 |
+
self.downsample_factor = int(downsample_factor)
|
| 66 |
+
self.force_flip_invariance = bool(force_flip_invariance)
|
| 67 |
+
|
| 68 |
+
args = SimpleNamespace(
|
| 69 |
+
input_token_len=self.input_token_len,
|
| 70 |
+
output_token_len=self.output_token_len,
|
| 71 |
+
seq_len=self.seq_len,
|
| 72 |
+
d_model=int(d_model),
|
| 73 |
+
d_intermediate=int(d_intermediate),
|
| 74 |
+
use_norm=True,
|
| 75 |
+
learn_bias=1,
|
| 76 |
+
output_bottleneck_dim=int(output_bottleneck_dim),
|
| 77 |
+
expand_v=float(expand_v),
|
| 78 |
+
state_weaving=int(state_weaving),
|
| 79 |
+
gating_kernel_size=int(gating_kernel_size),
|
| 80 |
+
main_module=str(main_module),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
from reverso import model as model_impl
|
| 84 |
+
try:
|
| 85 |
+
self.model = model_impl.Model(args).to(self.device)
|
| 86 |
+
except RuntimeError as e:
|
| 87 |
+
if "CUDA" in str(e):
|
| 88 |
+
print(f"CUDA not usable ({e}); falling back to CPU.")
|
| 89 |
+
self.device = torch.device("cpu")
|
| 90 |
+
self.use_amp = 0
|
| 91 |
+
self.model = model_impl.Model(args).to(self.device)
|
| 92 |
+
else:
|
| 93 |
+
raise
|
| 94 |
+
self.model.eval()
|
| 95 |
+
|
| 96 |
+
if checkpoint_path is not None and os.path.isfile(checkpoint_path):
|
| 97 |
+
self._load_checkpoint(checkpoint_path)
|
| 98 |
+
else:
|
| 99 |
+
print("Warning: checkpoint_path not provided or file not found. Using randomly initialized weights.")
|
| 100 |
+
|
| 101 |
+
def _load_checkpoint(self, ckpt_path: str):
|
| 102 |
+
load_checkpoint(self.model, ckpt_path, device=str(self.device))
|
| 103 |
+
|
| 104 |
+
def _downsample_if_needed(self, series: torch.Tensor) -> Tuple[torch.Tensor, int]:
|
| 105 |
+
cur = series
|
| 106 |
+
if self.downsample_factor > 1:
|
| 107 |
+
cur = cur[::self.downsample_factor]
|
| 108 |
+
return cur, self.downsample_factor
|
| 109 |
+
|
| 110 |
+
def _left_pad_to_len(self, arr: np.ndarray, target_len: int) -> Tuple[np.ndarray, int]:
|
| 111 |
+
if arr.shape[0] >= target_len:
|
| 112 |
+
return arr[-target_len:], 0
|
| 113 |
+
pad_len = target_len - arr.shape[0]
|
| 114 |
+
fill_value = arr[0] if arr.shape[0] > 0 else 0.0
|
| 115 |
+
padding = np.full((pad_len,), fill_value, dtype=arr.dtype)
|
| 116 |
+
return np.concatenate([padding, arr], axis=0), pad_len
|
| 117 |
+
|
| 118 |
+
def _prepare_context_matrix(self, context: List[torch.Tensor]) -> Tuple[torch.Tensor, List[int]]:
|
| 119 |
+
xs = []
|
| 120 |
+
downsample_factors = []
|
| 121 |
+
|
| 122 |
+
for c in context:
|
| 123 |
+
cur, downsample_factor = self._downsample_if_needed(c)
|
| 124 |
+
downsample_factors.append(downsample_factor)
|
| 125 |
+
|
| 126 |
+
cur_np = cur.detach().cpu().float().numpy()
|
| 127 |
+
cur_np, _ = self._left_pad_to_len(cur_np, self.seq_len)
|
| 128 |
+
|
| 129 |
+
x2d = cur_np[None, :]
|
| 130 |
+
x_interp = np.copy(x2d)
|
| 131 |
+
series = x2d[0]
|
| 132 |
+
if np.any(np.isnan(series)):
|
| 133 |
+
valid_mask = ~np.isnan(series)
|
| 134 |
+
if np.sum(valid_mask) >= 2:
|
| 135 |
+
valid_indices = np.where(valid_mask)[0]
|
| 136 |
+
valid_values = series[valid_mask]
|
| 137 |
+
x_interp[0] = np.interp(np.arange(len(series)), valid_indices, valid_values)
|
| 138 |
+
else:
|
| 139 |
+
x_interp = numpy_fill(x2d)
|
| 140 |
+
ff = numpy_fill(x_interp)
|
| 141 |
+
bf = np.flip(numpy_fill(np.flip(x_interp, axis=1)), axis=1)
|
| 142 |
+
x_imp = np.where(np.isnan(ff), bf, ff)
|
| 143 |
+
x_imp = np.where(np.isnan(x_imp), 0.0, x_imp)
|
| 144 |
+
xs.append(x_imp[0])
|
| 145 |
+
|
| 146 |
+
x = torch.tensor(np.stack(xs), device=self.device, dtype=torch.float32).unsqueeze(-1)
|
| 147 |
+
return x, downsample_factors
|
| 148 |
+
|
| 149 |
+
def _decode_autoregressive(self, init_ctx: torch.Tensor, use_bf16: bool, downsample_factors: List[int]) -> torch.Tensor:
|
| 150 |
+
B, _, C = init_ctx.shape
|
| 151 |
+
roll_len = int(self.output_token_len)
|
| 152 |
+
|
| 153 |
+
target_pred_lens = [int(self.prediction_length) // int(max(1, df)) for df in downsample_factors]
|
| 154 |
+
max_target_pred_len = max(target_pred_lens)
|
| 155 |
+
steps = math.ceil(max_target_pred_len / roll_len)
|
| 156 |
+
preds: List[torch.Tensor] = []
|
| 157 |
+
batch_ctx = init_ctx
|
| 158 |
+
|
| 159 |
+
y_mark = torch.zeros(B, self.output_token_len, C, device=self.device, dtype=init_ctx.dtype)
|
| 160 |
+
|
| 161 |
+
for _ in range(steps):
|
| 162 |
+
x_in = batch_ctx[:, -self.seq_len:, :]
|
| 163 |
+
x_mark = torch.zeros_like(x_in)
|
| 164 |
+
|
| 165 |
+
if autocast_fp is not None and self.use_amp and use_bf16:
|
| 166 |
+
try:
|
| 167 |
+
with autocast_fp(dtype=torch.bfloat16):
|
| 168 |
+
outputs = self.model(x_in, x_mark, y_mark)
|
| 169 |
+
except Exception:
|
| 170 |
+
outputs = self.model(x_in, x_mark, y_mark)
|
| 171 |
+
else:
|
| 172 |
+
outputs = self.model(x_in, x_mark, y_mark)
|
| 173 |
+
|
| 174 |
+
out_chunk = outputs[:, -self.output_token_len:, :]
|
| 175 |
+
take_chunk = out_chunk[:, :roll_len, :]
|
| 176 |
+
preds.append(take_chunk)
|
| 177 |
+
batch_ctx = torch.cat([batch_ctx, take_chunk], dim=1)
|
| 178 |
+
|
| 179 |
+
return torch.cat(preds, dim=1)
|
| 180 |
+
|
| 181 |
+
@torch.no_grad()
|
| 182 |
+
def predict(self, test_data_input, use_bf16_if_available: bool = True):
|
| 183 |
+
from gluonts.itertools import batcher
|
| 184 |
+
from gluonts.model.forecast import SampleForecast
|
| 185 |
+
|
| 186 |
+
forecasts = []
|
| 187 |
+
use_bf16 = bool(
|
| 188 |
+
use_bf16_if_available
|
| 189 |
+
and self.device.type == "cuda"
|
| 190 |
+
and torch.cuda.is_available()
|
| 191 |
+
and torch.cuda.is_bf16_supported()
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
for batch in batcher(test_data_input, batch_size=self.batch_size):
|
| 195 |
+
targets = [torch.tensor(entry["target"], dtype=torch.float32) for entry in batch]
|
| 196 |
+
batch_ctx, downsample_factors = self._prepare_context_matrix(targets)
|
| 197 |
+
|
| 198 |
+
pred_pos = self._decode_autoregressive(batch_ctx, use_bf16, downsample_factors)
|
| 199 |
+
if self.force_flip_invariance:
|
| 200 |
+
pred_neg = self._decode_autoregressive(-batch_ctx, use_bf16, downsample_factors)
|
| 201 |
+
pred_full = 0.5 * (pred_pos - pred_neg)
|
| 202 |
+
else:
|
| 203 |
+
pred_full = pred_pos
|
| 204 |
+
|
| 205 |
+
if torch.isnan(pred_full).any():
|
| 206 |
+
pf_2d = pred_full.squeeze(-1).detach().cpu().numpy()
|
| 207 |
+
pf_2d = numpy_fill(pf_2d)
|
| 208 |
+
pred_full = torch.tensor(pf_2d, device=pred_full.device, dtype=pred_full.dtype).unsqueeze(-1)
|
| 209 |
+
|
| 210 |
+
pred_full_np = pred_full.float().squeeze(-1).detach().cpu().numpy()
|
| 211 |
+
pred_list = []
|
| 212 |
+
for i in range(len(downsample_factors)):
|
| 213 |
+
df = downsample_factors[i]
|
| 214 |
+
target_pred_len = int(self.prediction_length) // int(max(1, df))
|
| 215 |
+
seq = pred_full_np[i, :target_pred_len]
|
| 216 |
+
if df > 1:
|
| 217 |
+
old_len = len(seq)
|
| 218 |
+
new_len = int(self.prediction_length)
|
| 219 |
+
seq = np.interp(np.linspace(0, 1, new_len), np.linspace(0, 1, old_len), seq)
|
| 220 |
+
pred_list.append(seq)
|
| 221 |
+
pred_full_np = np.array(pred_list)
|
| 222 |
+
|
| 223 |
+
for i, ts in enumerate(batch):
|
| 224 |
+
start_date = ts["start"] + len(ts["target"])
|
| 225 |
+
samples = np.repeat(pred_full_np[i][None, :], self.num_samples, axis=0)
|
| 226 |
+
forecasts.append(SampleForecast(samples=samples, start_date=start_date))
|
| 227 |
+
|
| 228 |
+
return forecasts
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# ==========================
|
| 232 |
+
# GiftEval evaluation script
|
| 233 |
+
# ==========================
|
| 234 |
+
from gluonts.ev.metrics import (
|
| 235 |
+
MAE, MAPE, MASE, MSE, MSIS, ND, NRMSE, RMSE, SMAPE,
|
| 236 |
+
MeanWeightedSumQuantileLoss,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
METRICS = [
|
| 240 |
+
MSE(forecast_type="mean"),
|
| 241 |
+
MSE(forecast_type=0.5),
|
| 242 |
+
MAE(),
|
| 243 |
+
MASE(),
|
| 244 |
+
MAPE(),
|
| 245 |
+
SMAPE(),
|
| 246 |
+
MSIS(),
|
| 247 |
+
RMSE(),
|
| 248 |
+
NRMSE(),
|
| 249 |
+
ND(),
|
| 250 |
+
MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
|
| 251 |
+
]
|
| 252 |
+
|
| 253 |
+
PRETTY_NAMES = {
|
| 254 |
+
"saugeenday": "saugeen",
|
| 255 |
+
"temperature_rain_with_missing": "temperature_rain",
|
| 256 |
+
"kdd_cup_2018_with_missing": "kdd_cup_2018",
|
| 257 |
+
"car_parts_with_missing": "car_parts",
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
SHORT_DATASETS = "m4_yearly m4_quarterly m4_monthly m4_weekly m4_daily m4_hourly electricity/15T electricity/H electricity/D electricity/W solar/10T solar/H solar/D solar/W hospital covid_deaths us_births/D us_births/M us_births/W saugeenday/D saugeenday/M saugeenday/W temperature_rain_with_missing kdd_cup_2018_with_missing/H kdd_cup_2018_with_missing/D car_parts_with_missing restaurant hierarchical_sales/D hierarchical_sales/W LOOP_SEATTLE/5T LOOP_SEATTLE/H LOOP_SEATTLE/D SZ_TAXI/15T SZ_TAXI/H M_DENSE/H M_DENSE/D ett1/15T ett1/H ett1/D ett1/W ett2/W ett2/D jena_weather/10T jena_weather/H jena_weather/D bitbrains_fast_storage/5T bitbrains_fast_storage/H bitbrains_rnd/5T bitbrains_rnd/H bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
|
| 261 |
+
|
| 262 |
+
MED_LONG_DATASETS = "electricity/15T electricity/H solar/10T solar/H kdd_cup_2018_with_missing/H LOOP_SEATTLE/5T LOOP_SEATTLE/H SZ_TAXI/15T M_DENSE/H ett1/15T ett1/H ett2/15T ett2/H jena_weather/10T jena_weather/H bitbrains_fast_storage/5T bitbrains_rnd/5T bizitobs_application bizitobs_service bizitobs_l2c/5T bizitobs_l2c/H"
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def main():
|
| 266 |
+
parser = argparse.ArgumentParser(description="Run Reverso GiftEval across datasets")
|
| 267 |
+
parser.add_argument("--checkpoint", default='checkpoints/reverso_small/checkpoint.pth', help="Path to model checkpoint")
|
| 268 |
+
parser.add_argument("--json_path", default='checkpoints/reverso_small/args.json', help="Path to JSON file with model config overrides")
|
| 269 |
+
parser.add_argument("--output-dir", dest="output_dir", default='results/reverso_small', help="Output directory for results")
|
| 270 |
+
parser.add_argument("--dataset", default=None, help="Filter to specific dataset (substring match)")
|
| 271 |
+
parser.add_argument("--term", default=None, choices=["short", "medium", "long"], help="Filter to specific term")
|
| 272 |
+
parser.add_argument("--force-flip-invariance", dest="force_flip_invariance", action="store_true",
|
| 273 |
+
help="Average f(x) with -f(-x) for flip invariance")
|
| 274 |
+
parser.add_argument("--downsample-json", dest="downsample_json",
|
| 275 |
+
default="config/downsample_factors.json",
|
| 276 |
+
help="Path to JSON with downsample factors per dataset/term")
|
| 277 |
+
args = parser.parse_args()
|
| 278 |
+
|
| 279 |
+
# Load model config from JSON if provided
|
| 280 |
+
json_cfg = {}
|
| 281 |
+
if args.json_path and os.path.isfile(args.json_path):
|
| 282 |
+
with open(args.json_path, "r") as f:
|
| 283 |
+
json_cfg = json.load(f)
|
| 284 |
+
|
| 285 |
+
# Model hyperparameters
|
| 286 |
+
SEQ_LEN = int(json_cfg.get("seq_len", 2048))
|
| 287 |
+
INPUT_TOKEN_LEN = int(json_cfg.get("input_token_len", 2048))
|
| 288 |
+
OUTPUT_TOKEN_LEN = int(json_cfg.get("output_token_len", 48))
|
| 289 |
+
E_LAYERS = int(json_cfg.get("e_layers", 8))
|
| 290 |
+
D_MODEL = int(json_cfg.get("d_model", 128))
|
| 291 |
+
D_INTERMEDIATE = int(json_cfg.get("d_intermediate", 512))
|
| 292 |
+
OUTPUT_BOTTLENECK_DIM = int(json_cfg.get("output_bottleneck_dim", 48))
|
| 293 |
+
EXPAND_V = float(json_cfg.get("expand_v", 1.0))
|
| 294 |
+
STATE_WEAVING = int(json_cfg.get("state_weaving", 1))
|
| 295 |
+
GATING_KERNEL_SIZE = int(json_cfg.get("gating_kernel_size", 3))
|
| 296 |
+
MAIN_MODULE = str(json_cfg.get("main_module", "conv,attn,conv,attn,conv,attn,conv,attn"))
|
| 297 |
+
|
| 298 |
+
DEVICE = "cuda"
|
| 299 |
+
NUM_SAMPLES = 100
|
| 300 |
+
BATCH_SIZE = 256
|
| 301 |
+
USE_AMP = 1
|
| 302 |
+
|
| 303 |
+
downsample_map = {}
|
| 304 |
+
if os.path.isfile(args.downsample_json):
|
| 305 |
+
with open(args.downsample_json, "r") as f:
|
| 306 |
+
downsample_map = json.load(f)
|
| 307 |
+
|
| 308 |
+
# Setup datasets
|
| 309 |
+
all_datasets = sorted(set(SHORT_DATASETS.split() + MED_LONG_DATASETS.split()))
|
| 310 |
+
med_long_set = set(MED_LONG_DATASETS.split())
|
| 311 |
+
all_terms = ["short", "medium", "long"]
|
| 312 |
+
|
| 313 |
+
with open("config/dataset_properties.json", "r") as f:
|
| 314 |
+
dataset_properties = json.load(f)
|
| 315 |
+
|
| 316 |
+
os.environ.setdefault("GIFT_EVAL", os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data"))
|
| 317 |
+
|
| 318 |
+
if args.dataset:
|
| 319 |
+
all_datasets = [ds for ds in all_datasets if args.dataset in ds]
|
| 320 |
+
if not all_datasets:
|
| 321 |
+
print(f"No datasets found matching '{args.dataset}'")
|
| 322 |
+
return
|
| 323 |
+
|
| 324 |
+
if args.term:
|
| 325 |
+
all_terms = [args.term]
|
| 326 |
+
|
| 327 |
+
# Setup output
|
| 328 |
+
output_dir = args.output_dir or os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
|
| 329 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 330 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 331 |
+
csv_path = os.path.join(output_dir, f"all_results_{timestamp}.csv")
|
| 332 |
+
|
| 333 |
+
with open(csv_path, "w", newline="") as f:
|
| 334 |
+
writer = csv.writer(f)
|
| 335 |
+
writer.writerow([
|
| 336 |
+
"dataset", "model",
|
| 337 |
+
"eval_metrics/MSE[mean]", "eval_metrics/MSE[0.5]",
|
| 338 |
+
"eval_metrics/MAE[0.5]", "eval_metrics/MASE[0.5]",
|
| 339 |
+
"eval_metrics/MAPE[0.5]", "eval_metrics/sMAPE[0.5]",
|
| 340 |
+
"eval_metrics/MSIS", "eval_metrics/RMSE[mean]",
|
| 341 |
+
"eval_metrics/NRMSE[mean]", "eval_metrics/ND[0.5]",
|
| 342 |
+
"eval_metrics/mean_weighted_sum_quantile_loss",
|
| 343 |
+
"domain", "num_variates",
|
| 344 |
+
])
|
| 345 |
+
|
| 346 |
+
from gluonts.model import evaluate_model
|
| 347 |
+
from gluonts.time_feature import get_seasonality
|
| 348 |
+
from gift_eval.data import Dataset
|
| 349 |
+
|
| 350 |
+
print(f"Evaluating {len(all_datasets)} datasets, terms: {all_terms}")
|
| 351 |
+
print(f"Flip invariance: {args.force_flip_invariance}")
|
| 352 |
+
|
| 353 |
+
for ds_num, ds_name in enumerate(all_datasets):
|
| 354 |
+
if "/" in ds_name:
|
| 355 |
+
ds_key = PRETTY_NAMES.get(ds_name.split("/")[0].lower(), ds_name.split("/")[0].lower())
|
| 356 |
+
ds_freq = ds_name.split("/")[1]
|
| 357 |
+
else:
|
| 358 |
+
ds_key = PRETTY_NAMES.get(ds_name.lower(), ds_name.lower())
|
| 359 |
+
ds_freq = dataset_properties[ds_key]["frequency"]
|
| 360 |
+
|
| 361 |
+
print(f"[{ds_num + 1}/{len(all_datasets)}] {ds_name}")
|
| 362 |
+
|
| 363 |
+
for term in all_terms:
|
| 364 |
+
if term in ("medium", "long") and ds_name not in med_long_set:
|
| 365 |
+
continue
|
| 366 |
+
|
| 367 |
+
ds_config = f"{ds_key}/{ds_freq}/{term}"
|
| 368 |
+
probe = Dataset(name=ds_name, term=term, to_univariate=False)
|
| 369 |
+
to_univariate = probe.target_dim != 1
|
| 370 |
+
dataset = Dataset(name=ds_name, term=term, to_univariate=to_univariate)
|
| 371 |
+
season_length = get_seasonality(dataset.freq)
|
| 372 |
+
|
| 373 |
+
downsample_key = f"{ds_key}/{ds_freq}/{term}".lower()
|
| 374 |
+
downsample_factor = downsample_map.get(downsample_key, 1)
|
| 375 |
+
|
| 376 |
+
info = f" {term}: {len(dataset.test_data)} instances"
|
| 377 |
+
if downsample_factor > 1:
|
| 378 |
+
info += f", downsample={downsample_factor}"
|
| 379 |
+
print(info)
|
| 380 |
+
|
| 381 |
+
predictor = ReversoPredictor(
|
| 382 |
+
prediction_length=dataset.prediction_length,
|
| 383 |
+
checkpoint_path=args.checkpoint,
|
| 384 |
+
device=DEVICE,
|
| 385 |
+
seq_len=SEQ_LEN,
|
| 386 |
+
input_token_len=INPUT_TOKEN_LEN,
|
| 387 |
+
output_token_len=OUTPUT_TOKEN_LEN,
|
| 388 |
+
e_layers=E_LAYERS,
|
| 389 |
+
d_model=D_MODEL,
|
| 390 |
+
d_intermediate=D_INTERMEDIATE,
|
| 391 |
+
output_bottleneck_dim=OUTPUT_BOTTLENECK_DIM,
|
| 392 |
+
expand_v=EXPAND_V,
|
| 393 |
+
state_weaving=STATE_WEAVING,
|
| 394 |
+
gating_kernel_size=GATING_KERNEL_SIZE,
|
| 395 |
+
main_module=MAIN_MODULE,
|
| 396 |
+
num_samples=NUM_SAMPLES,
|
| 397 |
+
batch_size=BATCH_SIZE,
|
| 398 |
+
use_amp=USE_AMP,
|
| 399 |
+
downsample_factor=downsample_factor,
|
| 400 |
+
force_flip_invariance=args.force_flip_invariance,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
res = evaluate_model(
|
| 404 |
+
predictor,
|
| 405 |
+
test_data=dataset.test_data,
|
| 406 |
+
metrics=METRICS,
|
| 407 |
+
batch_size=BATCH_SIZE,
|
| 408 |
+
axis=None,
|
| 409 |
+
mask_invalid_label=True,
|
| 410 |
+
allow_nan_forecast=False,
|
| 411 |
+
seasonality=season_length,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
with open(csv_path, "a", newline="") as f:
|
| 415 |
+
writer = csv.writer(f)
|
| 416 |
+
writer.writerow([
|
| 417 |
+
ds_config, "reverso",
|
| 418 |
+
res["MSE[mean]"][0], res["MSE[0.5]"][0],
|
| 419 |
+
res["MAE[0.5]"][0], res["MASE[0.5]"][0],
|
| 420 |
+
res["MAPE[0.5]"][0], res["sMAPE[0.5]"][0],
|
| 421 |
+
res["MSIS"][0], res["RMSE[mean]"][0],
|
| 422 |
+
res["NRMSE[mean]"][0], res["ND[0.5]"][0],
|
| 423 |
+
res["mean_weighted_sum_quantile_loss"][0],
|
| 424 |
+
dataset_properties[ds_key]["domain"],
|
| 425 |
+
dataset_properties[ds_key]["num_variates"],
|
| 426 |
+
])
|
| 427 |
+
|
| 428 |
+
print(f"\nResults saved to: {csv_path}")
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
if __name__ == "__main__":
|
| 432 |
+
main()
|
example/forecast_demo.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Demo: autoregressive forecasting on simple synthetic signals.
|
| 3 |
+
|
| 4 |
+
Run all signals: python example/forecast_demo.py --signal all
|
| 5 |
+
Run one signal: python example/forecast_demo.py --signal sine
|
| 6 |
+
List available: python example/forecast_demo.py --list
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
|
| 14 |
+
from reverso.forecast import load_model, forecast
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Signal generators — each returns float32 array of length n
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
def signal_constant(n: int) -> np.ndarray:
|
| 22 |
+
return np.full(n, 5.0, dtype=np.float32)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def signal_linear(n: int) -> np.ndarray:
|
| 26 |
+
return np.linspace(0, 40, n).astype(np.float32)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def signal_sine(n: int) -> np.ndarray:
|
| 31 |
+
t = np.arange(n, dtype=np.float64)
|
| 32 |
+
return (5.0 * np.sin(2 * np.pi * t / 200)).astype(np.float32)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def signal_sawtooth(n: int) -> np.ndarray:
|
| 36 |
+
t = np.arange(n, dtype=np.float64)
|
| 37 |
+
period = 200
|
| 38 |
+
return (10.0 * (t % period) / period).astype(np.float32)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def signal_square(n: int) -> np.ndarray:
|
| 42 |
+
t = np.arange(n, dtype=np.float64)
|
| 43 |
+
return (5.0 * np.sign(np.sin(2 * np.pi * t / 200))).astype(np.float32)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
SIGNALS = {
|
| 48 |
+
"constant": ("Constant", signal_constant),
|
| 49 |
+
"linear": ("Linear trend", signal_linear),
|
| 50 |
+
"sine": ("Sine wave", signal_sine),
|
| 51 |
+
"sawtooth": ("Sawtooth wave", signal_sawtooth),
|
| 52 |
+
"square": ("Square wave", signal_square),
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def run_one(name, label, gen_fn, model, cfg, device, context_length, prediction_length,
|
| 57 |
+
output_dir, flip_invariance=False):
|
| 58 |
+
total_len = context_length + prediction_length
|
| 59 |
+
signal = gen_fn(total_len)
|
| 60 |
+
context_np = signal[:context_length]
|
| 61 |
+
ground_truth = signal[context_length:]
|
| 62 |
+
|
| 63 |
+
context_tensor = torch.tensor(context_np, device=device).unsqueeze(0).unsqueeze(-1)
|
| 64 |
+
pred_pos = forecast(
|
| 65 |
+
model, context_tensor, prediction_length,
|
| 66 |
+
seq_len=cfg.seq_len, output_token_len=cfg.output_token_len,
|
| 67 |
+
)
|
| 68 |
+
if flip_invariance:
|
| 69 |
+
pred_neg = forecast(
|
| 70 |
+
model, -context_tensor, prediction_length,
|
| 71 |
+
seq_len=cfg.seq_len, output_token_len=cfg.output_token_len,
|
| 72 |
+
)
|
| 73 |
+
preds_tensor = 0.5 * (pred_pos - pred_neg)
|
| 74 |
+
else:
|
| 75 |
+
preds_tensor = pred_pos
|
| 76 |
+
preds = preds_tensor[0, :, 0].float().cpu().numpy()
|
| 77 |
+
|
| 78 |
+
ctx_t = np.arange(context_length)
|
| 79 |
+
pred_t = np.arange(context_length, total_len)
|
| 80 |
+
|
| 81 |
+
fig, ax = plt.subplots(figsize=(14, 5))
|
| 82 |
+
ax.plot(ctx_t, context_np, color="steelblue", label="Context")
|
| 83 |
+
ax.plot(pred_t, ground_truth, color="gray", linestyle="--", label="Ground truth")
|
| 84 |
+
ax.plot(pred_t, preds, color="tomato", label="Forecast")
|
| 85 |
+
ax.axvline(context_length, color="black", linestyle=":", alpha=0.5)
|
| 86 |
+
ax.set_xlabel("Time step")
|
| 87 |
+
ax.set_ylabel("Value")
|
| 88 |
+
ax.set_title(f"Reverso: {label}")
|
| 89 |
+
ax.legend()
|
| 90 |
+
fig.tight_layout()
|
| 91 |
+
out_path = f"{output_dir}/{name}_forecast.png"
|
| 92 |
+
fig.savefig(out_path, dpi=150)
|
| 93 |
+
plt.close(fig)
|
| 94 |
+
print(f" {label:25s} -> {out_path}")
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def main():
|
| 98 |
+
parser = argparse.ArgumentParser(description="Reverso forecast demo on synthetic signals")
|
| 99 |
+
parser.add_argument("--signal", type=str, default="all",
|
| 100 |
+
help="Signal name, or 'all' to run every signal")
|
| 101 |
+
parser.add_argument("--list", action="store_true", help="List available signals and exit")
|
| 102 |
+
parser.add_argument("--checkpoint", type=str,
|
| 103 |
+
default="checkpoints/reverso_small/checkpoint.pth")
|
| 104 |
+
parser.add_argument("--args-json", type=str,
|
| 105 |
+
default="checkpoints/reverso_small/args.json")
|
| 106 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 107 |
+
parser.add_argument("--context-length", type=int, default=2048)
|
| 108 |
+
parser.add_argument("--prediction-length", type=int, default=480)
|
| 109 |
+
parser.add_argument("--output-dir", type=str, default="example")
|
| 110 |
+
parser.add_argument("--flip-invariance", action="store_true",
|
| 111 |
+
help="Average f(x) with -f(-x) for flip invariance")
|
| 112 |
+
args = parser.parse_args()
|
| 113 |
+
|
| 114 |
+
if args.list:
|
| 115 |
+
for name, (label, _) in SIGNALS.items():
|
| 116 |
+
print(f" {name:15s} {label}")
|
| 117 |
+
return
|
| 118 |
+
|
| 119 |
+
model, cfg = load_model(args.checkpoint, args.args_json, args.device)
|
| 120 |
+
print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")
|
| 121 |
+
|
| 122 |
+
if args.signal == "all":
|
| 123 |
+
to_run = list(SIGNALS.items())
|
| 124 |
+
else:
|
| 125 |
+
if args.signal not in SIGNALS:
|
| 126 |
+
print(f"Unknown signal '{args.signal}'. Use --list to see options.")
|
| 127 |
+
return
|
| 128 |
+
to_run = [(args.signal, SIGNALS[args.signal])]
|
| 129 |
+
|
| 130 |
+
for name, (label, gen_fn) in to_run:
|
| 131 |
+
run_one(name, label, gen_fn, model, cfg, args.device,
|
| 132 |
+
args.context_length, args.prediction_length, args.output_dir,
|
| 133 |
+
args.flip_invariance)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == "__main__":
|
| 137 |
+
main()
|
example/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matplotlib
|
| 2 |
+
gluonts~=0.15.1
|
| 3 |
+
python-dotenv==1.0.0
|
figures/gift_eval_pareto_overall.png
ADDED
|
Git LFS Details
|
figures/new_arch.png
ADDED
|
Git LFS Details
|
pyproject.toml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=64"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "reverso"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Efficient time-series foundation models for zero-shot forecasting"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = "MIT"
|
| 11 |
+
requires-python = ">=3.11"
|
| 12 |
+
dependencies = [
|
| 13 |
+
"torch>=2.6.0",
|
| 14 |
+
"numpy",
|
| 15 |
+
"pandas",
|
| 16 |
+
"flash-linear-attention",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.optional-dependencies]
|
| 20 |
+
examples = [
|
| 21 |
+
"matplotlib",
|
| 22 |
+
"gluonts~=0.15.1",
|
| 23 |
+
"python-dotenv>=1.0.0",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
[project.urls]
|
| 27 |
+
Repository = "https://github.com/shinfxh/reverso"
|
| 28 |
+
|
| 29 |
+
[tool.setuptools.packages.find]
|
| 30 |
+
include = ["reverso*"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.6.0
|
| 2 |
+
numpy
|
| 3 |
+
pandas
|
| 4 |
+
flash-linear-attention
|
reverso/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reverso: Efficient time-series foundation models for zero-shot forecasting."""
|
| 2 |
+
|
| 3 |
+
from reverso.model import Model
|
| 4 |
+
from reverso.forecast import forecast, load_checkpoint, load_model
|
| 5 |
+
|
| 6 |
+
__all__ = ["Model", "forecast", "load_checkpoint", "load_model"]
|
reverso/forecast.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Autoregressive forecasting utilities for Reverso."""
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from reverso.model import Model
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_checkpoint(model: Model, checkpoint_path: str, device: str = "cuda"):
|
| 12 |
+
"""Load a checkpoint into an existing Reverso model.
|
| 13 |
+
|
| 14 |
+
Handles common checkpoint formats (raw state_dict, or dicts keyed by
|
| 15 |
+
"model_state_dict", "state_dict", "model", "ema", "ema_state_dict")
|
| 16 |
+
and strips the "module." prefix left by DDP.
|
| 17 |
+
"""
|
| 18 |
+
raw = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 19 |
+
state_dict = raw
|
| 20 |
+
if isinstance(raw, dict):
|
| 21 |
+
for k in ("model_state_dict", "state_dict", "model", "ema", "ema_state_dict"):
|
| 22 |
+
if k in raw and isinstance(raw[k], dict):
|
| 23 |
+
state_dict = raw[k]
|
| 24 |
+
break
|
| 25 |
+
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
|
| 26 |
+
model.load_state_dict(state_dict, strict=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_model(checkpoint_path: str, args_json: str, device: str = "cuda"):
|
| 30 |
+
"""Load a Reverso model from a checkpoint and config JSON.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
(model, cfg) tuple.
|
| 34 |
+
"""
|
| 35 |
+
with open(args_json) as f:
|
| 36 |
+
cfg = SimpleNamespace(**json.load(f))
|
| 37 |
+
|
| 38 |
+
model = Model(cfg).to(device)
|
| 39 |
+
load_checkpoint(model, checkpoint_path, device)
|
| 40 |
+
model.eval()
|
| 41 |
+
return model, cfg
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def forecast(
|
| 46 |
+
model: Model,
|
| 47 |
+
context: torch.Tensor,
|
| 48 |
+
prediction_length: int,
|
| 49 |
+
seq_len: int,
|
| 50 |
+
output_token_len: int,
|
| 51 |
+
use_amp: bool = True,
|
| 52 |
+
) -> torch.Tensor:
|
| 53 |
+
"""Autoregressive multi-step forecast.
|
| 54 |
+
|
| 55 |
+
Follows the rollout pattern from eval_gift.py's _decode_autoregressive.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
model: Reverso Model (already on the target device, in eval mode).
|
| 59 |
+
context: Input context tensor of shape (B, L, 1).
|
| 60 |
+
prediction_length: Number of future steps to predict.
|
| 61 |
+
seq_len: Model's context window length (cfg.seq_len).
|
| 62 |
+
output_token_len: Steps produced per model call (cfg.output_token_len).
|
| 63 |
+
use_amp: Whether to use bfloat16 autocast (requires CUDA).
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Predictions tensor of shape (B, prediction_length, 1).
|
| 67 |
+
"""
|
| 68 |
+
device = context.device
|
| 69 |
+
B, _, C = context.shape
|
| 70 |
+
roll_len = output_token_len
|
| 71 |
+
steps = math.ceil(prediction_length / roll_len)
|
| 72 |
+
|
| 73 |
+
batch_ctx = context
|
| 74 |
+
preds = []
|
| 75 |
+
|
| 76 |
+
y_mark = torch.zeros(B, output_token_len, C, device=device, dtype=context.dtype)
|
| 77 |
+
|
| 78 |
+
for _ in range(steps):
|
| 79 |
+
x_in = batch_ctx[:, -seq_len:, :]
|
| 80 |
+
x_mark = torch.zeros_like(x_in)
|
| 81 |
+
|
| 82 |
+
if use_amp and device.type == "cuda":
|
| 83 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 84 |
+
outputs = model(x_in, x_mark, y_mark)
|
| 85 |
+
else:
|
| 86 |
+
outputs = model(x_in, x_mark, y_mark)
|
| 87 |
+
|
| 88 |
+
out_chunk = outputs[:, -output_token_len:, :]
|
| 89 |
+
take_chunk = out_chunk[:, :roll_len, :]
|
| 90 |
+
preds.append(take_chunk)
|
| 91 |
+
batch_ctx = torch.cat([batch_ctx, take_chunk], dim=1)
|
| 92 |
+
|
| 93 |
+
return torch.cat(preds, dim=1)[:, :prediction_length, :]
|
reverso/model.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Reverso: conv-attention hybrid for time series forecasting.
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from flashfftconv import FlashFFTConv
|
| 8 |
+
from fla.layers import DeltaNet
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Gating(nn.Module):
|
| 13 |
+
def __init__(self, channels, temporal_kernel=3):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.net = nn.Sequential(
|
| 16 |
+
nn.Conv1d(channels, channels, kernel_size=temporal_kernel,
|
| 17 |
+
padding=temporal_kernel // 2, groups=channels),
|
| 18 |
+
nn.SiLU(),
|
| 19 |
+
nn.Conv1d(channels, channels, kernel_size=1),
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
return torch.sigmoid(self.net(x))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class MLPBlock(nn.Module):
|
| 27 |
+
def __init__(self, d_in, d_out, d_intermediate=0):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.norm = nn.LayerNorm(d_out)
|
| 30 |
+
if d_intermediate and d_intermediate > 0:
|
| 31 |
+
self.linear = nn.Linear(d_in, d_intermediate)
|
| 32 |
+
self.linear_final = nn.Linear(d_intermediate, d_out)
|
| 33 |
+
else:
|
| 34 |
+
self.linear = nn.Linear(d_in, d_out)
|
| 35 |
+
self.linear_final = nn.Identity()
|
| 36 |
+
self.activation = nn.ReLU()
|
| 37 |
+
self.skip_linear = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
if x.ndim == 3:
|
| 41 |
+
x = x.permute(0, 2, 1)
|
| 42 |
+
residual = self.skip_linear(x)
|
| 43 |
+
y = self.linear(x)
|
| 44 |
+
y = self.activation(y)
|
| 45 |
+
y = self.linear_final(y)
|
| 46 |
+
y = self.norm(y)
|
| 47 |
+
y = residual + y
|
| 48 |
+
if y.ndim == 3:
|
| 49 |
+
y = y.permute(0, 2, 1)
|
| 50 |
+
return y
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class CNNBlock(nn.Module):
|
| 54 |
+
def __init__(self, channels, seq_len, flashfftconv, gating_kernel_size=3):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.flashfftconv = flashfftconv
|
| 57 |
+
self.k = nn.Parameter(torch.randn(channels, seq_len, dtype=torch.float32))
|
| 58 |
+
self.pregate = Gating(channels, gating_kernel_size)
|
| 59 |
+
self.activation = nn.ReLU()
|
| 60 |
+
self.norm = nn.LayerNorm(channels)
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
residual = x
|
| 64 |
+
x_conv = x.contiguous().to(torch.bfloat16)
|
| 65 |
+
pregate = self.pregate(x_conv.float()).to(x_conv.dtype)
|
| 66 |
+
postgate = torch.ones_like(x_conv)
|
| 67 |
+
out = self.flashfftconv(x_conv, self.k, pregate=pregate, postgate=postgate)
|
| 68 |
+
out = self.activation(out)
|
| 69 |
+
out = out.transpose(1, 2)
|
| 70 |
+
out = self.norm(out)
|
| 71 |
+
out = out.transpose(1, 2)
|
| 72 |
+
out = out + residual
|
| 73 |
+
return out
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class AttentionBlock(nn.Module):
|
| 77 |
+
def __init__(self, d_model, expand_v, state_weaving=False, is_intermediate=False):
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.state_weaving = state_weaving
|
| 80 |
+
self.is_intermediate = is_intermediate
|
| 81 |
+
self.attention = DeltaNet(
|
| 82 |
+
mode='chunk',
|
| 83 |
+
d_model=d_model,
|
| 84 |
+
expand_k=1.0,
|
| 85 |
+
expand_v=expand_v,
|
| 86 |
+
num_heads=4,
|
| 87 |
+
use_beta=True,
|
| 88 |
+
use_gate=False,
|
| 89 |
+
use_short_conv=True,
|
| 90 |
+
conv_size=4,
|
| 91 |
+
allow_neg_eigval=False,
|
| 92 |
+
qk_activation='silu',
|
| 93 |
+
qk_norm='l2',
|
| 94 |
+
layer_idx=0,
|
| 95 |
+
)
|
| 96 |
+
self.norm = nn.LayerNorm(d_model)
|
| 97 |
+
|
| 98 |
+
def forward(self, x):
|
| 99 |
+
x_t = x.transpose(1, 2)
|
| 100 |
+
residual = x_t
|
| 101 |
+
if self.state_weaving and self.is_intermediate:
|
| 102 |
+
x_t = x_t.clone()
|
| 103 |
+
x_t[:, 0:1, :] = x_t[:, 0:1, :] + x_t[:, -1:, :]
|
| 104 |
+
attn_out = self.attention(hidden_states=x_t, attention_mask=None)
|
| 105 |
+
if isinstance(attn_out, tuple):
|
| 106 |
+
out = attn_out[0]
|
| 107 |
+
else:
|
| 108 |
+
out = attn_out
|
| 109 |
+
out = self.norm(out)
|
| 110 |
+
out = out + residual
|
| 111 |
+
out = out.transpose(1, 2)
|
| 112 |
+
return out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Model(nn.Module):
|
| 116 |
+
"""
|
| 117 |
+
Reverso: conv-deltanet hybrid for time series forecasting.
|
| 118 |
+
"""
|
| 119 |
+
def __init__(self, configs):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.seq_len = configs.seq_len
|
| 122 |
+
self.input_token_len = configs.input_token_len
|
| 123 |
+
self.output_token_len = configs.output_token_len
|
| 124 |
+
self.d_model = configs.d_model
|
| 125 |
+
self.use_norm = configs.use_norm
|
| 126 |
+
|
| 127 |
+
self.embedding = nn.Linear(1, self.d_model, bias=False)
|
| 128 |
+
self.shared_flashfftconv = FlashFFTConv(self.seq_len, dtype=torch.bfloat16)
|
| 129 |
+
|
| 130 |
+
d_intermediate = configs.d_intermediate
|
| 131 |
+
expand_v = getattr(configs, 'expand_v', 1.0)
|
| 132 |
+
state_weaving = getattr(configs, 'state_weaving', False)
|
| 133 |
+
gating_kernel_size = getattr(configs, 'gating_kernel_size', 3)
|
| 134 |
+
module_list = [m.strip() for m in configs.main_module.split(',')]
|
| 135 |
+
e_layers = len(module_list)
|
| 136 |
+
|
| 137 |
+
layers = []
|
| 138 |
+
for i, layer_type in enumerate(module_list):
|
| 139 |
+
if layer_type == 'conv':
|
| 140 |
+
layers.append(CNNBlock(
|
| 141 |
+
self.d_model, self.seq_len, self.shared_flashfftconv, gating_kernel_size,
|
| 142 |
+
))
|
| 143 |
+
elif layer_type == 'attn':
|
| 144 |
+
is_intermediate = (i > 0) and (i < e_layers - 1)
|
| 145 |
+
layers.append(AttentionBlock(
|
| 146 |
+
self.d_model, expand_v, state_weaving, is_intermediate,
|
| 147 |
+
))
|
| 148 |
+
else:
|
| 149 |
+
raise ValueError(f'Invalid layer type: {layer_type}')
|
| 150 |
+
layers.append(MLPBlock(self.d_model, self.d_model, d_intermediate))
|
| 151 |
+
self.layers = nn.Sequential(*layers)
|
| 152 |
+
|
| 153 |
+
output_bottleneck_dim = getattr(configs, 'output_bottleneck_dim', self.output_token_len)
|
| 154 |
+
self.head = nn.Linear(self.input_token_len, output_bottleneck_dim, bias=configs.learn_bias)
|
| 155 |
+
self.simple_q_proj = nn.Linear(self.d_model, self.d_model)
|
| 156 |
+
self.key_proj = nn.Linear(self.d_model, self.d_model)
|
| 157 |
+
self.value_proj = nn.Linear(self.d_model, self.d_model)
|
| 158 |
+
self.out_proj = nn.Linear(self.d_model, 1)
|
| 159 |
+
|
| 160 |
+
def forward(self, x, x_mark=None, y_mark=None, **kwargs: Any):
|
| 161 |
+
B, L, C = x.shape
|
| 162 |
+
|
| 163 |
+
if self.use_norm:
|
| 164 |
+
x_min = x.min(1, keepdim=True)[0].detach()
|
| 165 |
+
x_max = x.max(1, keepdim=True)[0].detach()
|
| 166 |
+
x_range = torch.clamp(x_max - x_min, min=1e-5).detach()
|
| 167 |
+
x = (x - x_min) / x_range
|
| 168 |
+
means = x_min
|
| 169 |
+
stdev = x_range
|
| 170 |
+
|
| 171 |
+
x = self.embedding(x).transpose(1, 2)
|
| 172 |
+
|
| 173 |
+
dec_out = self.layers(x)
|
| 174 |
+
|
| 175 |
+
temp_out = self.head(dec_out).permute(0, 2, 1)
|
| 176 |
+
q = self.simple_q_proj(temp_out)
|
| 177 |
+
|
| 178 |
+
dec_out_perm = dec_out.permute(0, 2, 1)
|
| 179 |
+
k = self.key_proj(dec_out_perm)
|
| 180 |
+
v = self.value_proj(dec_out_perm)
|
| 181 |
+
|
| 182 |
+
attn = F.scaled_dot_product_attention(q, k, v)
|
| 183 |
+
dec_out = self.out_proj(attn)
|
| 184 |
+
|
| 185 |
+
if self.use_norm:
|
| 186 |
+
dec_out = dec_out * stdev + means
|
| 187 |
+
|
| 188 |
+
return dec_out
|
| 189 |
+
|
| 190 |
+
def forecast(self, x, x_mark=None, y_mark=None, **kwargs):
|
| 191 |
+
return self.forward(x, x_mark, y_mark, **kwargs)
|