Upload Time-Series-Library
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +16 -0
- Autoformer.csv +2 -0
- CONTRIBUTING.md +20 -0
- DLinear.csv +2 -0
- Informer.csv +2 -0
- LICENSE +21 -0
- README.md +173 -0
- Reformer.csv +2 -0
- data_provider/__init__.py +1 -0
- data_provider/calculate_window_len.py +127 -0
- data_provider/data_factory.py +88 -0
- data_provider/data_loader.py +1029 -0
- data_provider/load.py +31 -0
- data_provider/m4.py +141 -0
- data_provider/uea.py +125 -0
- dataset/m4/Daily-test.csv +0 -0
- dataset/m4/Daily-train.csv +3 -0
- dataset/m4/Hourly-test.csv +0 -0
- dataset/m4/Hourly-train.csv +0 -0
- dataset/m4/M4-info.csv +0 -0
- dataset/m4/Monthly-test.csv +0 -0
- dataset/m4/Monthly-train.csv +3 -0
- dataset/m4/Quarterly-test.csv +0 -0
- dataset/m4/Quarterly-train.csv +3 -0
- dataset/m4/Weekly-test.csv +360 -0
- dataset/m4/Weekly-train.csv +0 -0
- dataset/m4/Yearly-test.csv +0 -0
- dataset/m4/Yearly-train.csv +3 -0
- dataset/m4/submission-Naive2.csv +3 -0
- dataset/m4/test.npz +3 -0
- dataset/m4/training.npz +3 -0
- dataset/poly/polymarket_data_processed_Crypto_test.jsonl +3 -0
- dataset/poly/polymarket_data_processed_Election_test.jsonl +3 -0
- dataset/poly/polymarket_data_processed_Other_test.jsonl +3 -0
- dataset/poly/polymarket_data_processed_Politics_test.jsonl +3 -0
- dataset/poly/polymarket_data_processed_Sports_test.jsonl +3 -0
- dataset/poly/polymarket_data_processed_dev.jsonl +3 -0
- dataset/poly/polymarket_data_processed_test.jsonl +3 -0
- dataset/poly/polymarket_data_processed_train.jsonl +3 -0
- exp/__init__.py +0 -0
- exp/exp_anomaly_detection.py +207 -0
- exp/exp_basic.py +79 -0
- exp/exp_classification.py +191 -0
- exp/exp_imputation.py +228 -0
- exp/exp_long_term_forecasting.py +268 -0
- exp/exp_short_term_forecasting.py +302 -0
- kalshi_results/Autoformer/Autoformer_results.csv +2 -0
- kalshi_results/DLinear/DLinear_results.csv +2 -0
- layers/AutoCorrelation.py +163 -0
- layers/Autoformer_EncDec.py +203 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
dataset/m4/Daily-train.csv filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
dataset/m4/Monthly-train.csv filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
dataset/m4/Quarterly-train.csv filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
dataset/m4/Yearly-train.csv filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
dataset/m4/submission-Naive2.csv filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
dataset/poly/polymarket_data_processed_Crypto_test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
dataset/poly/polymarket_data_processed_Election_test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
dataset/poly/polymarket_data_processed_Other_test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
dataset/poly/polymarket_data_processed_Politics_test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
dataset/poly/polymarket_data_processed_Sports_test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
dataset/poly/polymarket_data_processed_dev.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
dataset/poly/polymarket_data_processed_test.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
dataset/poly/polymarket_data_processed_train.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
pic/dataset.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
tutorial/conv.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
tutorial/fft.png filter=lfs diff=lfs merge=lfs -text
|
Autoformer.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
|
| 2 |
+
0,0.39830628510682603,0.32073805520677084,0.3818132712824491,0.31051893144830756,0.3828605477887506,0.308403942223605,0.4182624238680359,0.3396595166127513,0.39639385459888354,0.3268426575370005,0.37177661055284744,0.2993196869106722
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Instructions for Contributing to TSlib
|
| 2 |
+
|
| 3 |
+
Sincerely thanks to all the researchers who want to use or contribute to TSlib.
|
| 4 |
+
|
| 5 |
+
Since our team may not have enough time to fix all the bugs and catch up with the latest model, your contribution is essential to this project.
|
| 6 |
+
|
| 7 |
+
### (1) Fix Bug
|
| 8 |
+
|
| 9 |
+
You can directly propose a pull request and add detailed descriptions to the comment, such as [this pull request](https://github.com/thuml/Time-Series-Library/pull/498).
|
| 10 |
+
|
| 11 |
+
### (2) Add a new time series model
|
| 12 |
+
|
| 13 |
+
Thanks to creative researchers, extensive great TS models are presented, which advance this community significantly. If you want to add your model to TSlib, here are some instructions:
|
| 14 |
+
|
| 15 |
+
- Propose an issue to describe your model and give a link to your paper and official code. We will discuss whether your model is suitable for this library, such as [this issue](https://github.com/thuml/Time-Series-Library/issues/346).
|
| 16 |
+
- Propose a pull request in a similar style as TSlib, which means adding an additional file to ./models and providing corresponding scripts for reproduction, such as [this pull request](https://github.com/thuml/Time-Series-Library/pull/446).
|
| 17 |
+
|
| 18 |
+
Note: Given that there are a lot of TS models that have been proposed, we may not have enough time to judge which model can be a remarkable supplement to the current library. Thus, we decide ONLY to add the officially published paper to our library. Peer review can be a reliable criterion.
|
| 19 |
+
|
| 20 |
+
Thanks again for your valuable contributions.
|
DLinear.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
|
| 2 |
+
0,0.3961972827838396,0.31730998926485693,0.37965161394318636,0.3065756351649381,0.3828125705658356,0.3088032054398474,0.4141024259329349,0.33436751696330935,0.3924656681263207,0.3223706541409388,0.37120005835210995,0.29826033160423604
|
Informer.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
|
| 2 |
+
0,0.3974158067085375,0.31855139807181687,0.381309328795267,0.30970417269120676,0.38429771883487124,0.3100599430647862,0.4150583272414749,0.33584398433823054,0.3975700532168396,0.3287726753077512,0.3721965853006681,0.2996654640008122
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2021 THUML @ Tsinghua University
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Time Series Library (TSLib)
|
| 2 |
+
TSLib is an open-source library for deep learning researchers, especially for deep time series analysis.
|
| 3 |
+
|
| 4 |
+
We provide a neat code base to evaluate advanced deep time series models or develop your model, which covers five mainstream tasks: **long- and short-term forecasting, imputation, anomaly detection, and classification.**
|
| 5 |
+
|
| 6 |
+
:triangular_flag_on_post:**News** (2024.10) We have included [[TimeXer]](https://arxiv.org/abs/2402.19072), which defined a practical forecasting paradigm: Forecasting with Exogenous Variables. Considering both practicability and computation efficiency, we believe the new forecasting paradigm defined in TimeXer can be the "right" task for future research.
|
| 7 |
+
|
| 8 |
+
:triangular_flag_on_post:**News** (2024.10) Our lab has open-sourced [[OpenLTM]](https://github.com/thuml/OpenLTM), which provides a distinct pretrain-finetuning paradigm compared to TSLib. If you are interested in Large Time Series Models, you may find this repository helpful.
|
| 9 |
+
|
| 10 |
+
:triangular_flag_on_post:**News** (2024.07) We wrote a comprehensive survey of [[Deep Time Series Models]](https://arxiv.org/abs/2407.13278) with a rigorous benchmark based on TSLib. In this paper, we summarized the design principles of current time series models supported by insightful experiments, hoping to be helpful to future research.
|
| 11 |
+
|
| 12 |
+
:triangular_flag_on_post:**News** (2024.04) Many thanks for the great work from [frecklebars](https://github.com/thuml/Time-Series-Library/pull/378). The famous sequential model [Mamba](https://arxiv.org/abs/2312.00752) has been included in our library. See [this file](https://github.com/thuml/Time-Series-Library/blob/main/models/Mamba.py), where you need to install `mamba_ssm` with pip at first.
|
| 13 |
+
|
| 14 |
+
:triangular_flag_on_post:**News** (2024.03) Given the inconsistent look-back length of various papers, we split the long-term forecasting in the leaderboard into two categories: Look-Back-96 and Look-Back-Searching. We recommend researchers read [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2), which includes both look-back length settings in experiments for scientific rigor.
|
| 15 |
+
|
| 16 |
+
:triangular_flag_on_post:**News** (2023.10) We add an implementation to [iTransformer](https://arxiv.org/abs/2310.06625), which is the state-of-the-art model for long-term forecasting. The official code and complete scripts of iTransformer can be found [here](https://github.com/thuml/iTransformer).
|
| 17 |
+
|
| 18 |
+
:triangular_flag_on_post:**News** (2023.09) We added a detailed [tutorial](https://github.com/thuml/Time-Series-Library/blob/main/tutorial/TimesNet_tutorial.ipynb) for [TimesNet](https://openreview.net/pdf?id=ju_Uqw384Oq) and this library, which is quite friendly to beginners of deep time series analysis.
|
| 19 |
+
|
| 20 |
+
:triangular_flag_on_post:**News** (2023.02) We release the TSlib as a comprehensive benchmark and code base for time series models, which is extended from our previous GitHub repository [Autoformer](https://github.com/thuml/Autoformer).
|
| 21 |
+
|
| 22 |
+
## Leaderboard for Time Series Analysis
|
| 23 |
+
|
| 24 |
+
Till March 2024, the top three models for five different tasks are:
|
| 25 |
+
|
| 26 |
+
| Model<br>Ranking | Long-term<br>Forecasting<br>Look-Back-96 | Long-term<br/>Forecasting<br/>Look-Back-Searching | Short-term<br>Forecasting | Imputation | Classification | Anomaly<br>Detection |
|
| 27 |
+
| ---------------- | ----------------------------------------------------- | ----------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | -------------------------------------------------- |
|
| 28 |
+
| 🥇 1st | [TimeXer](https://arxiv.org/abs/2402.19072) | [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) |
|
| 29 |
+
| 🥈 2nd | [iTransformer](https://arxiv.org/abs/2310.06625) | [PatchTST](https://github.com/yuqinie98/PatchTST) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [FEDformer](https://github.com/MAZiqing/FEDformer) |
|
| 30 |
+
| 🥉 3rd | [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2) | [DLinear](https://arxiv.org/pdf/2205.13504.pdf) | [FEDformer](https://github.com/MAZiqing/FEDformer) | [Autoformer](https://github.com/thuml/Autoformer) | [Informer](https://github.com/zhouhaoyi/Informer2020) | [Autoformer](https://github.com/thuml/Autoformer) |
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
**Note: We will keep updating this leaderboard.** If you have proposed advanced and awesome models, you can send us your paper/code link or raise a pull request. We will add them to this repo and update the leaderboard as soon as possible.
|
| 34 |
+
|
| 35 |
+
**Compared models of this leaderboard.** ☑ means that their codes have already been included in this repo.
|
| 36 |
+
- [x] **TimeXer** - TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables [[NeurIPS 2024]](https://arxiv.org/abs/2402.19072) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimeXer.py)
|
| 37 |
+
- [x] **TimeMixer** - TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting [[ICLR 2024]](https://openreview.net/pdf?id=7oLshfEIC2) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimeMixer.py).
|
| 38 |
+
- [x] **TSMixer** - TSMixer: An All-MLP Architecture for Time Series Forecasting [[arXiv 2023]](https://arxiv.org/pdf/2303.06053.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TSMixer.py)
|
| 39 |
+
- [x] **iTransformer** - iTransformer: Inverted Transformers Are Effective for Time Series Forecasting [[ICLR 2024]](https://arxiv.org/abs/2310.06625) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/iTransformer.py).
|
| 40 |
+
- [x] **PatchTST** - A Time Series is Worth 64 Words: Long-term Forecasting with Transformers [[ICLR 2023]](https://openreview.net/pdf?id=Jbdc0vTOcol) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/PatchTST.py).
|
| 41 |
+
- [x] **TimesNet** - TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis [[ICLR 2023]](https://openreview.net/pdf?id=ju_Uqw384Oq) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimesNet.py).
|
| 42 |
+
- [x] **DLinear** - Are Transformers Effective for Time Series Forecasting? [[AAAI 2023]](https://arxiv.org/pdf/2205.13504.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/DLinear.py).
|
| 43 |
+
- [x] **LightTS** - Less Is More: Fast Multivariate Time Series Forecasting with Light Sampling-oriented MLP Structures [[arXiv 2022]](https://arxiv.org/abs/2207.01186) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/LightTS.py).
|
| 44 |
+
- [x] **ETSformer** - ETSformer: Exponential Smoothing Transformers for Time-series Forecasting [[arXiv 2022]](https://arxiv.org/abs/2202.01381) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/ETSformer.py).
|
| 45 |
+
- [x] **Non-stationary Transformer** - Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting [[NeurIPS 2022]](https://openreview.net/pdf?id=ucNDIDRNjjv) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Nonstationary_Transformer.py).
|
| 46 |
+
- [x] **FEDformer** - FEDformer: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [[ICML 2022]](https://proceedings.mlr.press/v162/zhou22g.html) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/FEDformer.py).
|
| 47 |
+
- [x] **Pyraformer** - Pyraformer: Low-complexity Pyramidal Attention for Long-range Time Series Modeling and Forecasting [[ICLR 2022]](https://openreview.net/pdf?id=0EXmFzUn5I) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Pyraformer.py).
|
| 48 |
+
- [x] **Autoformer** - Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [[NeurIPS 2021]](https://openreview.net/pdf?id=I55UqU-M11y) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Autoformer.py).
|
| 49 |
+
- [x] **Informer** - Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting [[AAAI 2021]](https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Informer.py).
|
| 50 |
+
- [x] **Reformer** - Reformer: The Efficient Transformer [[ICLR 2020]](https://openreview.net/forum?id=rkgNKkHtvB) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Reformer.py).
|
| 51 |
+
- [x] **Transformer** - Attention is All You Need [[NeurIPS 2017]](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Transformer.py).
|
| 52 |
+
|
| 53 |
+
See our latest paper [[TimesNet]](https://arxiv.org/abs/2210.02186) for the comprehensive benchmark. We will release a real-time updated online version soon.
|
| 54 |
+
|
| 55 |
+
**Newly added baselines.** We will add them to the leaderboard after a comprehensive evaluation.
|
| 56 |
+
- [x] **MultiPatchFormer** - A multiscale model for multivariate time series forecasting [[Scientific Reports 2025]](https://www.nature.com/articles/s41598-024-82417-4) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/MultiPatchFormer.py)
|
| 57 |
+
- [x] **WPMixer** - WPMixer: Efficient Multi-Resolution Mixing for Long-Term Time Series Forecasting [[AAAI 2025]](https://arxiv.org/abs/2412.17176) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/WPMixer.py)
|
| 58 |
+
- [x] **PAttn** - Are Language Models Actually Useful for Time Series Forecasting? [[NeurIPS 2024]](https://arxiv.org/pdf/2406.16964) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/PAttn.py)
|
| 59 |
+
- [x] **Mamba** - Mamba: Linear-Time Sequence Modeling with Selective State Spaces [[arXiv 2023]](https://arxiv.org/abs/2312.00752) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Mamba.py)
|
| 60 |
+
- [x] **SegRNN** - SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting [[arXiv 2023]](https://arxiv.org/abs/2308.11200.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/SegRNN.py).
|
| 61 |
+
- [x] **Koopa** - Koopa: Learning Non-stationary Time Series Dynamics with Koopman Predictors [[NeurIPS 2023]](https://arxiv.org/pdf/2305.18803.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Koopa.py).
|
| 62 |
+
- [x] **FreTS** - Frequency-domain MLPs are More Effective Learners in Time Series Forecasting [[NeurIPS 2023]](https://arxiv.org/pdf/2311.06184.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/FreTS.py).
|
| 63 |
+
- [x] **MICN** - MICN: Multi-scale Local and Global Context Modeling for Long-term Series Forecasting [[ICLR 2023]](https://openreview.net/pdf?id=zt53IDUR1U)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/MICN.py).
|
| 64 |
+
- [x] **Crossformer** - Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting [[ICLR 2023]](https://openreview.net/pdf?id=vSVLM2j9eie)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Crossformer.py).
|
| 65 |
+
- [x] **TiDE** - Long-term Forecasting with TiDE: Time-series Dense Encoder [[arXiv 2023]](https://arxiv.org/pdf/2304.08424.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TiDE.py).
|
| 66 |
+
- [x] **SCINet** - SCINet: Time Series Modeling and Forecasting with Sample Convolution and Interaction [[NeurIPS 2022]](https://openreview.net/pdf?id=AyajSjTAzmg)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/SCINet.py).
|
| 67 |
+
- [x] **FiLM** - FiLM: Frequency improved Legendre Memory Model for Long-term Time Series Forecasting [[NeurIPS 2022]](https://openreview.net/forum?id=zTQdHSQUQWc)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/FiLM.py).
|
| 68 |
+
- [x] **TFT** - Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting [[arXiv 2019]](https://arxiv.org/abs/1912.09363)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TemporalFusionTransformer.py).
|
| 69 |
+
|
| 70 |
+
## Usage
|
| 71 |
+
|
| 72 |
+
1. Install Python 3.8. For convenience, execute the following command.
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
pip install -r requirements.txt
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
2. Prepare Data. You can obtain the well pre-processed datasets from [[Google Drive]](https://drive.google.com/drive/folders/13Cg1KYOlzM5C7K8gK8NfC-F3EYxkM3D2?usp=sharing) or [[Baidu Drive]](https://pan.baidu.com/s/1r3KhGd0Q9PJIUZdfEYoymg?pwd=i9iy), Then place the downloaded data in the folder`./dataset`. Here is a summary of supported datasets.
|
| 79 |
+
|
| 80 |
+
<p align="center">
|
| 81 |
+
<img src=".\pic\dataset.png" height = "200" alt="" align=center />
|
| 82 |
+
</p>
|
| 83 |
+
|
| 84 |
+
3. Train and evaluate model. We provide the experiment scripts for all benchmarks under the folder `./scripts/`. You can reproduce the experiment results as the following examples:
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
# long-term forecast
|
| 88 |
+
bash ./scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh
|
| 89 |
+
# short-term forecast
|
| 90 |
+
bash ./scripts/short_term_forecast/TimesNet_M4.sh
|
| 91 |
+
# imputation
|
| 92 |
+
bash ./scripts/imputation/ETT_script/TimesNet_ETTh1.sh
|
| 93 |
+
# anomaly detection
|
| 94 |
+
bash ./scripts/anomaly_detection/PSM/TimesNet.sh
|
| 95 |
+
# classification
|
| 96 |
+
bash ./scripts/classification/TimesNet.sh
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
4. Develop your own model.
|
| 100 |
+
|
| 101 |
+
- Add the model file to the folder `./models`. You can follow the `./models/Transformer.py`.
|
| 102 |
+
- Include the newly added model in the `Exp_Basic.model_dict` of `./exp/exp_basic.py`.
|
| 103 |
+
- Create the corresponding scripts under the folder `./scripts`.
|
| 104 |
+
|
| 105 |
+
Note:
|
| 106 |
+
|
| 107 |
+
(1) About classification: Since we include all five tasks in a unified code base, the accuracy of each subtask may fluctuate but the average performance can be reproduced (even a bit better). We have provided the reproduced checkpoints [here](https://github.com/thuml/Time-Series-Library/issues/494).
|
| 108 |
+
|
| 109 |
+
(2) About anomaly detection: Some discussion about the adjustment strategy in anomaly detection can be found [here](https://github.com/thuml/Anomaly-Transformer/issues/14). The key point is that the adjustment strategy corresponds to an event-level metric.
|
| 110 |
+
|
| 111 |
+
## Citation
|
| 112 |
+
|
| 113 |
+
If you find this repo useful, please cite our paper.
|
| 114 |
+
|
| 115 |
+
```
|
| 116 |
+
@inproceedings{wu2023timesnet,
|
| 117 |
+
title={TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis},
|
| 118 |
+
author={Haixu Wu and Tengge Hu and Yong Liu and Hang Zhou and Jianmin Wang and Mingsheng Long},
|
| 119 |
+
booktitle={International Conference on Learning Representations},
|
| 120 |
+
year={2023},
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
@article{wang2024tssurvey,
|
| 124 |
+
title={Deep Time Series Models: A Comprehensive Survey and Benchmark},
|
| 125 |
+
author={Yuxuan Wang and Haixu Wu and Jiaxiang Dong and Yong Liu and Mingsheng Long and Jianmin Wang},
|
| 126 |
+
booktitle={arXiv preprint arXiv:2407.13278},
|
| 127 |
+
year={2024},
|
| 128 |
+
}
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
## Contact
|
| 132 |
+
If you have any questions or suggestions, feel free to contact our maintenance team:
|
| 133 |
+
|
| 134 |
+
Current:
|
| 135 |
+
- Haixu Wu (Ph.D. student, wuhx23@mails.tsinghua.edu.cn)
|
| 136 |
+
- Yong Liu (Ph.D. student, liuyong21@mails.tsinghua.edu.cn)
|
| 137 |
+
- Yuxuan Wang (Ph.D. student, wangyuxu22@mails.tsinghua.edu.cn)
|
| 138 |
+
- Huikun Weng (Undergraduate, wenghk22@mails.tsinghua.edu.cn)
|
| 139 |
+
|
| 140 |
+
Previous:
|
| 141 |
+
- Tengge Hu (Master student, htg21@mails.tsinghua.edu.cn)
|
| 142 |
+
- Haoran Zhang (Master student, z-hr20@mails.tsinghua.edu.cn)
|
| 143 |
+
- Jiawei Guo (Undergraduate, guo-jw21@mails.tsinghua.edu.cn)
|
| 144 |
+
|
| 145 |
+
Or describe it in Issues.
|
| 146 |
+
|
| 147 |
+
## Acknowledgement
|
| 148 |
+
|
| 149 |
+
This project is supported by the National Key R&D Program of China (2021YFB1715200).
|
| 150 |
+
|
| 151 |
+
This library is constructed based on the following repos:
|
| 152 |
+
|
| 153 |
+
- Forecasting: https://github.com/thuml/Autoformer.
|
| 154 |
+
|
| 155 |
+
- Anomaly Detection: https://github.com/thuml/Anomaly-Transformer.
|
| 156 |
+
|
| 157 |
+
- Classification: https://github.com/thuml/Flowformer.
|
| 158 |
+
|
| 159 |
+
All the experiment datasets are public, and we obtain them from the following links:
|
| 160 |
+
|
| 161 |
+
- Long-term Forecasting and Imputation: https://github.com/thuml/Autoformer.
|
| 162 |
+
|
| 163 |
+
- Short-term Forecasting: https://github.com/ServiceNow/N-BEATS.
|
| 164 |
+
|
| 165 |
+
- Anomaly Detection: https://github.com/thuml/Anomaly-Transformer.
|
| 166 |
+
|
| 167 |
+
- Classification: https://www.timeseriesclassification.com/.
|
| 168 |
+
|
| 169 |
+
## All Thanks To Our Contributors
|
| 170 |
+
|
| 171 |
+
<a href="https://github.com/thuml/Time-Series-Library/graphs/contributors">
|
| 172 |
+
<img src="https://contrib.rocks/image?repo=thuml/Time-Series-Library" />
|
| 173 |
+
</a>
|
Reformer.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
|
| 2 |
+
0,0.39759935719372996,0.3176409474228964,0.3804696526717909,0.3091429407280413,0.38329537668949676,0.30926928011712945,0.4144651556154994,0.3346732748303968,0.39416373838022656,0.3253543121596952,0.3719121830741643,0.2993834586683424
|
data_provider/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
data_provider/calculate_window_len.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from collections import Counter
|
| 3 |
+
|
| 4 |
+
def analyze_window_history_lengths(filepath):
|
| 5 |
+
"""
|
| 6 |
+
统计JSONL文件中所有window_history的长度分布
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
filepath: JSONL文件的路径
|
| 10 |
+
|
| 11 |
+
Returns:
|
| 12 |
+
Counter对象,包含长度分布统计
|
| 13 |
+
"""
|
| 14 |
+
length_counter = Counter()
|
| 15 |
+
total_records = 0
|
| 16 |
+
total_breakpoints = 0
|
| 17 |
+
total_window_histories = 0
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 21 |
+
for line_num, line in enumerate(f, 1):
|
| 22 |
+
line = line.strip()
|
| 23 |
+
if not line:
|
| 24 |
+
continue
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
data = json.loads(line)
|
| 28 |
+
total_records += 1
|
| 29 |
+
|
| 30 |
+
# 检查是否有daily_breakpoints字段
|
| 31 |
+
if 'daily_breakpoints' not in data:
|
| 32 |
+
print(f"警告: 第{line_num}行没有'daily_breakpoints'字段")
|
| 33 |
+
continue
|
| 34 |
+
|
| 35 |
+
daily_breakpoints = data['daily_breakpoints']
|
| 36 |
+
|
| 37 |
+
# 遍历每个breakpoint
|
| 38 |
+
for i, breakpoint in enumerate(daily_breakpoints):
|
| 39 |
+
total_breakpoints += 1
|
| 40 |
+
|
| 41 |
+
# 检查是否有window_history字段
|
| 42 |
+
if 'window_history' not in breakpoint:
|
| 43 |
+
print(f"警告: 第{line_num}行的breakpoint[{i}]没有'window_history'字段")
|
| 44 |
+
continue
|
| 45 |
+
|
| 46 |
+
window_history = breakpoint['window_history']
|
| 47 |
+
|
| 48 |
+
# 统计长度
|
| 49 |
+
if isinstance(window_history, list):
|
| 50 |
+
length = len(window_history)
|
| 51 |
+
length_counter[length] += 1
|
| 52 |
+
total_window_histories += 1
|
| 53 |
+
else:
|
| 54 |
+
print(f"警告: 第{line_num}行的breakpoint[{i}]的'window_history'不是列表")
|
| 55 |
+
|
| 56 |
+
except json.JSONDecodeError as e:
|
| 57 |
+
print(f"JSON解析错误在第{line_num}行: {e}")
|
| 58 |
+
continue
|
| 59 |
+
|
| 60 |
+
return length_counter, total_records, total_breakpoints, total_window_histories
|
| 61 |
+
|
| 62 |
+
except FileNotFoundError:
|
| 63 |
+
print(f"文件未找到: {filepath}")
|
| 64 |
+
return None, 0, 0, 0
|
| 65 |
+
except Exception as e:
|
| 66 |
+
print(f"发生错误: {e}")
|
| 67 |
+
return None, 0, 0, 0
|
| 68 |
+
|
| 69 |
+
def print_statistics(length_counter, total_records, total_breakpoints, total_window_histories):
|
| 70 |
+
"""
|
| 71 |
+
打印统计结果
|
| 72 |
+
"""
|
| 73 |
+
print("\n" + "="*60)
|
| 74 |
+
print("统计摘要:")
|
| 75 |
+
print("="*60)
|
| 76 |
+
print(f"总JSON记录数: {total_records}")
|
| 77 |
+
print(f"总breakpoints数: {total_breakpoints}")
|
| 78 |
+
print(f"总window_history数: {total_window_histories}")
|
| 79 |
+
|
| 80 |
+
if not length_counter:
|
| 81 |
+
print("\n没有找到任何window_history数据")
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
print("\n" + "="*60)
|
| 85 |
+
print("Window History 长度分布:")
|
| 86 |
+
print("="*60)
|
| 87 |
+
print(f"{'长度':<10} {'数量':<10} {'百分比':<10} {'分布图'}")
|
| 88 |
+
print("-"*60)
|
| 89 |
+
|
| 90 |
+
# 按长度排序
|
| 91 |
+
for length in sorted(length_counter.keys()):
|
| 92 |
+
count = length_counter[length]
|
| 93 |
+
percentage = (count / total_window_histories) * 100
|
| 94 |
+
bar = '█' * int(percentage / 2) # 每个█代表2%
|
| 95 |
+
print(f"{length:<10} {count:<10} {percentage:>6.2f}% {bar}")
|
| 96 |
+
|
| 97 |
+
print("\n" + "="*60)
|
| 98 |
+
print("统计信息:")
|
| 99 |
+
print("="*60)
|
| 100 |
+
print(f"最小长度: {min(length_counter.keys())}")
|
| 101 |
+
print(f"最大长度: {max(length_counter.keys())}")
|
| 102 |
+
|
| 103 |
+
# 计算平均长度
|
| 104 |
+
total_length = sum(length * count for length, count in length_counter.items())
|
| 105 |
+
avg_length = total_length / total_window_histories
|
| 106 |
+
print(f"平均长度: {avg_length:.2f}")
|
| 107 |
+
|
| 108 |
+
# 计算中位数
|
| 109 |
+
sorted_lengths = []
|
| 110 |
+
for length, count in sorted(length_counter.items()):
|
| 111 |
+
sorted_lengths.extend([length] * count)
|
| 112 |
+
median_length = sorted_lengths[len(sorted_lengths) // 2]
|
| 113 |
+
print(f"中位数长度: {median_length}")
|
| 114 |
+
|
| 115 |
+
print("="*60)
|
| 116 |
+
|
| 117 |
+
# 使用示例
|
| 118 |
+
if __name__ == "__main__":
|
| 119 |
+
# 替换成你的JSONL文件路径
|
| 120 |
+
filepath = "/data/haofeiy2/social-world-model/data/splitted_polymarket/polymarket_data_processed_with_news_train_2024-11-01.jsonl"
|
| 121 |
+
|
| 122 |
+
print(f"正在分析文件: {filepath}")
|
| 123 |
+
length_counter, total_records, total_breakpoints, total_window_histories = \
|
| 124 |
+
analyze_window_history_lengths(filepath)
|
| 125 |
+
|
| 126 |
+
if length_counter is not None:
|
| 127 |
+
print_statistics(length_counter, total_records, total_breakpoints, total_window_histories)
|
data_provider/data_factory.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \
|
| 2 |
+
MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, Dataset_Poly, Dataset_Kalshi
|
| 3 |
+
from data_provider.uea import collate_fn
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
data_dict = {
|
| 7 |
+
'ETTh1': Dataset_ETT_hour,
|
| 8 |
+
'ETTh2': Dataset_ETT_hour,
|
| 9 |
+
'ETTm1': Dataset_ETT_minute,
|
| 10 |
+
'ETTm2': Dataset_ETT_minute,
|
| 11 |
+
'custom': Dataset_Custom,
|
| 12 |
+
'm4': Dataset_M4,
|
| 13 |
+
'poly': Dataset_Poly,
|
| 14 |
+
'kalshi': Dataset_Kalshi,
|
| 15 |
+
'PSM': PSMSegLoader,
|
| 16 |
+
'MSL': MSLSegLoader,
|
| 17 |
+
'SMAP': SMAPSegLoader,
|
| 18 |
+
'SMD': SMDSegLoader,
|
| 19 |
+
'SWAT': SWATSegLoader,
|
| 20 |
+
'UEA': UEAloader
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def data_provider(args, flag):
|
| 25 |
+
Data = data_dict[args.data]
|
| 26 |
+
timeenc = 0 if args.embed != 'timeF' else 1
|
| 27 |
+
|
| 28 |
+
shuffle_flag = False if (flag == 'test' or flag == 'TEST') else True
|
| 29 |
+
drop_last = False
|
| 30 |
+
batch_size = args.batch_size
|
| 31 |
+
freq = args.freq
|
| 32 |
+
|
| 33 |
+
if args.task_name == 'anomaly_detection':
|
| 34 |
+
drop_last = False
|
| 35 |
+
data_set = Data(
|
| 36 |
+
args = args,
|
| 37 |
+
root_path=args.root_path,
|
| 38 |
+
win_size=args.seq_len,
|
| 39 |
+
flag=flag,
|
| 40 |
+
)
|
| 41 |
+
print(flag, len(data_set))
|
| 42 |
+
data_loader = DataLoader(
|
| 43 |
+
data_set,
|
| 44 |
+
batch_size=batch_size,
|
| 45 |
+
shuffle=shuffle_flag,
|
| 46 |
+
num_workers=args.num_workers,
|
| 47 |
+
drop_last=drop_last)
|
| 48 |
+
return data_set, data_loader
|
| 49 |
+
elif args.task_name == 'classification':
|
| 50 |
+
drop_last = False
|
| 51 |
+
data_set = Data(
|
| 52 |
+
args = args,
|
| 53 |
+
root_path=args.root_path,
|
| 54 |
+
flag=flag,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
data_loader = DataLoader(
|
| 58 |
+
data_set,
|
| 59 |
+
batch_size=batch_size,
|
| 60 |
+
shuffle=shuffle_flag,
|
| 61 |
+
num_workers=args.num_workers,
|
| 62 |
+
drop_last=drop_last,
|
| 63 |
+
collate_fn=lambda x: collate_fn(x, max_len=args.seq_len)
|
| 64 |
+
)
|
| 65 |
+
return data_set, data_loader
|
| 66 |
+
else:
|
| 67 |
+
if args.data == 'm4':
|
| 68 |
+
drop_last = False
|
| 69 |
+
data_set = Data(
|
| 70 |
+
args = args,
|
| 71 |
+
root_path=args.root_path,
|
| 72 |
+
data_path=args.data_path,
|
| 73 |
+
flag=flag,
|
| 74 |
+
size=[args.seq_len, args.label_len, args.pred_len],
|
| 75 |
+
features=args.features,
|
| 76 |
+
target=args.target,
|
| 77 |
+
timeenc=timeenc,
|
| 78 |
+
freq=freq,
|
| 79 |
+
seasonal_patterns=args.seasonal_patterns
|
| 80 |
+
)
|
| 81 |
+
print(flag, len(data_set))
|
| 82 |
+
data_loader = DataLoader(
|
| 83 |
+
data_set,
|
| 84 |
+
batch_size=batch_size,
|
| 85 |
+
shuffle=shuffle_flag,
|
| 86 |
+
num_workers=args.num_workers,
|
| 87 |
+
drop_last=drop_last)
|
| 88 |
+
return data_set, data_loader
|
data_provider/data_loader.py
ADDED
|
@@ -0,0 +1,1029 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import glob
|
| 5 |
+
import re
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from sklearn.preprocessing import StandardScaler
|
| 9 |
+
from utils.timefeatures import time_features
|
| 10 |
+
from data_provider.m4 import M4Dataset, M4Meta
|
| 11 |
+
from data_provider.uea import subsample, interpolate_missing, Normalizer
|
| 12 |
+
from sktime.datasets import load_from_tsfile_to_dataframe
|
| 13 |
+
import warnings
|
| 14 |
+
from utils.augmentation import run_augmentation_single
|
| 15 |
+
import json
|
| 16 |
+
|
| 17 |
+
warnings.filterwarnings('ignore')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Dataset_ETT_hour(Dataset):
|
| 21 |
+
def __init__(self, args, root_path, flag='train', size=None,
|
| 22 |
+
features='S', data_path='ETTh1.csv',
|
| 23 |
+
target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
|
| 24 |
+
# size [seq_len, label_len, pred_len]
|
| 25 |
+
self.args = args
|
| 26 |
+
# info
|
| 27 |
+
if size == None:
|
| 28 |
+
self.seq_len = 24 * 4 * 4
|
| 29 |
+
self.label_len = 24 * 4
|
| 30 |
+
self.pred_len = 24 * 4
|
| 31 |
+
else:
|
| 32 |
+
self.seq_len = size[0]
|
| 33 |
+
self.label_len = size[1]
|
| 34 |
+
self.pred_len = size[2]
|
| 35 |
+
# init
|
| 36 |
+
assert flag in ['train', 'test', 'val']
|
| 37 |
+
type_map = {'train': 0, 'val': 1, 'test': 2}
|
| 38 |
+
self.set_type = type_map[flag]
|
| 39 |
+
|
| 40 |
+
self.features = features
|
| 41 |
+
self.target = target
|
| 42 |
+
self.scale = scale
|
| 43 |
+
self.timeenc = timeenc
|
| 44 |
+
self.freq = freq
|
| 45 |
+
|
| 46 |
+
self.root_path = root_path
|
| 47 |
+
self.data_path = data_path
|
| 48 |
+
self.__read_data__()
|
| 49 |
+
|
| 50 |
+
def __read_data__(self):
|
| 51 |
+
self.scaler = StandardScaler()
|
| 52 |
+
df_raw = pd.read_csv(os.path.join(self.root_path,
|
| 53 |
+
self.data_path))
|
| 54 |
+
|
| 55 |
+
border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
|
| 56 |
+
border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
|
| 57 |
+
border1 = border1s[self.set_type]
|
| 58 |
+
border2 = border2s[self.set_type]
|
| 59 |
+
|
| 60 |
+
if self.features == 'M' or self.features == 'MS':
|
| 61 |
+
cols_data = df_raw.columns[1:]
|
| 62 |
+
df_data = df_raw[cols_data]
|
| 63 |
+
elif self.features == 'S':
|
| 64 |
+
df_data = df_raw[[self.target]]
|
| 65 |
+
|
| 66 |
+
if self.scale:
|
| 67 |
+
train_data = df_data[border1s[0]:border2s[0]]
|
| 68 |
+
self.scaler.fit(train_data.values)
|
| 69 |
+
data = self.scaler.transform(df_data.values)
|
| 70 |
+
else:
|
| 71 |
+
data = df_data.values
|
| 72 |
+
|
| 73 |
+
df_stamp = df_raw[['date']][border1:border2]
|
| 74 |
+
df_stamp['date'] = pd.to_datetime(df_stamp.date)
|
| 75 |
+
if self.timeenc == 0:
|
| 76 |
+
df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
|
| 77 |
+
df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
|
| 78 |
+
df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
|
| 79 |
+
df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
|
| 80 |
+
data_stamp = df_stamp.drop(['date'], 1).values
|
| 81 |
+
elif self.timeenc == 1:
|
| 82 |
+
data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
|
| 83 |
+
data_stamp = data_stamp.transpose(1, 0)
|
| 84 |
+
|
| 85 |
+
self.data_x = data[border1:border2]
|
| 86 |
+
self.data_y = data[border1:border2]
|
| 87 |
+
|
| 88 |
+
if self.set_type == 0 and self.args.augmentation_ratio > 0:
|
| 89 |
+
self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
|
| 90 |
+
|
| 91 |
+
self.data_stamp = data_stamp
|
| 92 |
+
|
| 93 |
+
def __getitem__(self, index):
|
| 94 |
+
s_begin = index
|
| 95 |
+
s_end = s_begin + self.seq_len
|
| 96 |
+
r_begin = s_end - self.label_len
|
| 97 |
+
r_end = r_begin + self.label_len + self.pred_len
|
| 98 |
+
|
| 99 |
+
seq_x = self.data_x[s_begin:s_end]
|
| 100 |
+
seq_y = self.data_y[r_begin:r_end]
|
| 101 |
+
seq_x_mark = self.data_stamp[s_begin:s_end]
|
| 102 |
+
seq_y_mark = self.data_stamp[r_begin:r_end]
|
| 103 |
+
|
| 104 |
+
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
| 105 |
+
|
| 106 |
+
def __len__(self):
|
| 107 |
+
return len(self.data_x) - self.seq_len - self.pred_len + 1
|
| 108 |
+
|
| 109 |
+
def inverse_transform(self, data):
|
| 110 |
+
return self.scaler.inverse_transform(data)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Dataset_ETT_minute(Dataset):
|
| 114 |
+
def __init__(self, args, root_path, flag='train', size=None,
|
| 115 |
+
features='S', data_path='ETTm1.csv',
|
| 116 |
+
target='OT', scale=True, timeenc=0, freq='t', seasonal_patterns=None):
|
| 117 |
+
# size [seq_len, label_len, pred_len]
|
| 118 |
+
self.args = args
|
| 119 |
+
# info
|
| 120 |
+
if size == None:
|
| 121 |
+
self.seq_len = 24 * 4 * 4
|
| 122 |
+
self.label_len = 24 * 4
|
| 123 |
+
self.pred_len = 24 * 4
|
| 124 |
+
else:
|
| 125 |
+
self.seq_len = size[0]
|
| 126 |
+
self.label_len = size[1]
|
| 127 |
+
self.pred_len = size[2]
|
| 128 |
+
# init
|
| 129 |
+
assert flag in ['train', 'test', 'val']
|
| 130 |
+
type_map = {'train': 0, 'val': 1, 'test': 2}
|
| 131 |
+
self.set_type = type_map[flag]
|
| 132 |
+
|
| 133 |
+
self.features = features
|
| 134 |
+
self.target = target
|
| 135 |
+
self.scale = scale
|
| 136 |
+
self.timeenc = timeenc
|
| 137 |
+
self.freq = freq
|
| 138 |
+
|
| 139 |
+
self.root_path = root_path
|
| 140 |
+
self.data_path = data_path
|
| 141 |
+
self.__read_data__()
|
| 142 |
+
|
| 143 |
+
def __read_data__(self):
|
| 144 |
+
self.scaler = StandardScaler()
|
| 145 |
+
df_raw = pd.read_csv(os.path.join(self.root_path,
|
| 146 |
+
self.data_path))
|
| 147 |
+
|
| 148 |
+
border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
|
| 149 |
+
border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
|
| 150 |
+
border1 = border1s[self.set_type]
|
| 151 |
+
border2 = border2s[self.set_type]
|
| 152 |
+
|
| 153 |
+
if self.features == 'M' or self.features == 'MS':
|
| 154 |
+
cols_data = df_raw.columns[1:]
|
| 155 |
+
df_data = df_raw[cols_data]
|
| 156 |
+
elif self.features == 'S':
|
| 157 |
+
df_data = df_raw[[self.target]]
|
| 158 |
+
|
| 159 |
+
if self.scale:
|
| 160 |
+
train_data = df_data[border1s[0]:border2s[0]]
|
| 161 |
+
self.scaler.fit(train_data.values)
|
| 162 |
+
data = self.scaler.transform(df_data.values)
|
| 163 |
+
else:
|
| 164 |
+
data = df_data.values
|
| 165 |
+
|
| 166 |
+
df_stamp = df_raw[['date']][border1:border2]
|
| 167 |
+
df_stamp['date'] = pd.to_datetime(df_stamp.date)
|
| 168 |
+
if self.timeenc == 0:
|
| 169 |
+
df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
|
| 170 |
+
df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
|
| 171 |
+
df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
|
| 172 |
+
df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
|
| 173 |
+
df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1)
|
| 174 |
+
df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
|
| 175 |
+
data_stamp = df_stamp.drop(['date'], 1).values
|
| 176 |
+
elif self.timeenc == 1:
|
| 177 |
+
data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
|
| 178 |
+
data_stamp = data_stamp.transpose(1, 0)
|
| 179 |
+
|
| 180 |
+
self.data_x = data[border1:border2]
|
| 181 |
+
self.data_y = data[border1:border2]
|
| 182 |
+
|
| 183 |
+
if self.set_type == 0 and self.args.augmentation_ratio > 0:
|
| 184 |
+
self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
|
| 185 |
+
|
| 186 |
+
self.data_stamp = data_stamp
|
| 187 |
+
|
| 188 |
+
def __getitem__(self, index):
|
| 189 |
+
s_begin = index
|
| 190 |
+
s_end = s_begin + self.seq_len
|
| 191 |
+
r_begin = s_end - self.label_len
|
| 192 |
+
r_end = r_begin + self.label_len + self.pred_len
|
| 193 |
+
|
| 194 |
+
seq_x = self.data_x[s_begin:s_end]
|
| 195 |
+
seq_y = self.data_y[r_begin:r_end]
|
| 196 |
+
seq_x_mark = self.data_stamp[s_begin:s_end]
|
| 197 |
+
seq_y_mark = self.data_stamp[r_begin:r_end]
|
| 198 |
+
|
| 199 |
+
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
| 200 |
+
|
| 201 |
+
def __len__(self):
|
| 202 |
+
return len(self.data_x) - self.seq_len - self.pred_len + 1
|
| 203 |
+
|
| 204 |
+
def inverse_transform(self, data):
|
| 205 |
+
return self.scaler.inverse_transform(data)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class Dataset_Custom(Dataset):
|
| 209 |
+
def __init__(self, args, root_path, flag='train', size=None,
|
| 210 |
+
features='S', data_path='ETTh1.csv',
|
| 211 |
+
target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
|
| 212 |
+
# size [seq_len, label_len, pred_len]
|
| 213 |
+
self.args = args
|
| 214 |
+
# info
|
| 215 |
+
if size == None:
|
| 216 |
+
self.seq_len = 24 * 4 * 4
|
| 217 |
+
self.label_len = 24 * 4
|
| 218 |
+
self.pred_len = 24 * 4
|
| 219 |
+
else:
|
| 220 |
+
self.seq_len = size[0]
|
| 221 |
+
self.label_len = size[1]
|
| 222 |
+
self.pred_len = size[2]
|
| 223 |
+
# init
|
| 224 |
+
assert flag in ['train', 'test', 'val']
|
| 225 |
+
type_map = {'train': 0, 'val': 1, 'test': 2}
|
| 226 |
+
self.set_type = type_map[flag]
|
| 227 |
+
|
| 228 |
+
self.features = features
|
| 229 |
+
self.target = target
|
| 230 |
+
self.scale = scale
|
| 231 |
+
self.timeenc = timeenc
|
| 232 |
+
self.freq = freq
|
| 233 |
+
|
| 234 |
+
self.root_path = root_path
|
| 235 |
+
self.data_path = data_path
|
| 236 |
+
self.__read_data__()
|
| 237 |
+
|
| 238 |
+
def __read_data__(self):
|
| 239 |
+
self.scaler = StandardScaler()
|
| 240 |
+
df_raw = pd.read_csv(os.path.join(self.root_path,
|
| 241 |
+
self.data_path))
|
| 242 |
+
|
| 243 |
+
'''
|
| 244 |
+
df_raw.columns: ['date', ...(other features), target feature]
|
| 245 |
+
'''
|
| 246 |
+
cols = list(df_raw.columns)
|
| 247 |
+
cols.remove(self.target)
|
| 248 |
+
cols.remove('date')
|
| 249 |
+
df_raw = df_raw[['date'] + cols + [self.target]]
|
| 250 |
+
num_train = int(len(df_raw) * 0.7)
|
| 251 |
+
num_test = int(len(df_raw) * 0.2)
|
| 252 |
+
num_vali = len(df_raw) - num_train - num_test
|
| 253 |
+
border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
|
| 254 |
+
border2s = [num_train, num_train + num_vali, len(df_raw)]
|
| 255 |
+
border1 = border1s[self.set_type]
|
| 256 |
+
border2 = border2s[self.set_type]
|
| 257 |
+
|
| 258 |
+
if self.features == 'M' or self.features == 'MS':
|
| 259 |
+
cols_data = df_raw.columns[1:]
|
| 260 |
+
df_data = df_raw[cols_data]
|
| 261 |
+
elif self.features == 'S':
|
| 262 |
+
df_data = df_raw[[self.target]]
|
| 263 |
+
|
| 264 |
+
if self.scale:
|
| 265 |
+
train_data = df_data[border1s[0]:border2s[0]]
|
| 266 |
+
self.scaler.fit(train_data.values)
|
| 267 |
+
data = self.scaler.transform(df_data.values)
|
| 268 |
+
else:
|
| 269 |
+
data = df_data.values
|
| 270 |
+
|
| 271 |
+
df_stamp = df_raw[['date']][border1:border2]
|
| 272 |
+
df_stamp['date'] = pd.to_datetime(df_stamp.date)
|
| 273 |
+
if self.timeenc == 0:
|
| 274 |
+
df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1)
|
| 275 |
+
df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1)
|
| 276 |
+
df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1)
|
| 277 |
+
df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1)
|
| 278 |
+
data_stamp = df_stamp.drop(['date'], 1).values
|
| 279 |
+
elif self.timeenc == 1:
|
| 280 |
+
data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
|
| 281 |
+
data_stamp = data_stamp.transpose(1, 0)
|
| 282 |
+
|
| 283 |
+
self.data_x = data[border1:border2]
|
| 284 |
+
self.data_y = data[border1:border2]
|
| 285 |
+
|
| 286 |
+
if self.set_type == 0 and self.args.augmentation_ratio > 0:
|
| 287 |
+
self.data_x, self.data_y, augmentation_tags = run_augmentation_single(self.data_x, self.data_y, self.args)
|
| 288 |
+
|
| 289 |
+
self.data_stamp = data_stamp
|
| 290 |
+
|
| 291 |
+
def __getitem__(self, index):
|
| 292 |
+
s_begin = index
|
| 293 |
+
s_end = s_begin + self.seq_len
|
| 294 |
+
r_begin = s_end - self.label_len
|
| 295 |
+
r_end = r_begin + self.label_len + self.pred_len
|
| 296 |
+
|
| 297 |
+
seq_x = self.data_x[s_begin:s_end]
|
| 298 |
+
seq_y = self.data_y[r_begin:r_end]
|
| 299 |
+
seq_x_mark = self.data_stamp[s_begin:s_end]
|
| 300 |
+
seq_y_mark = self.data_stamp[r_begin:r_end]
|
| 301 |
+
|
| 302 |
+
return seq_x, seq_y, seq_x_mark, seq_y_mark
|
| 303 |
+
|
| 304 |
+
def __len__(self):
|
| 305 |
+
return len(self.data_x) - self.seq_len - self.pred_len + 1
|
| 306 |
+
|
| 307 |
+
def inverse_transform(self, data):
|
| 308 |
+
return self.scaler.inverse_transform(data)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class Dataset_M4(Dataset):
|
| 312 |
+
def __init__(self, args, root_path, flag='pred', size=None,
|
| 313 |
+
features='S', data_path='ETTh1.csv',
|
| 314 |
+
target='OT', scale=False, inverse=False, timeenc=0, freq='15min',
|
| 315 |
+
seasonal_patterns='Yearly'):
|
| 316 |
+
# size [seq_len, label_len, pred_len]
|
| 317 |
+
# init
|
| 318 |
+
self.features = features
|
| 319 |
+
self.target = target
|
| 320 |
+
self.scale = scale
|
| 321 |
+
self.inverse = inverse
|
| 322 |
+
self.timeenc = timeenc
|
| 323 |
+
self.root_path = root_path
|
| 324 |
+
|
| 325 |
+
self.seq_len = size[0]
|
| 326 |
+
self.label_len = size[1]
|
| 327 |
+
self.pred_len = size[2]
|
| 328 |
+
|
| 329 |
+
self.seasonal_patterns = seasonal_patterns
|
| 330 |
+
self.history_size = M4Meta.history_size[seasonal_patterns]
|
| 331 |
+
self.window_sampling_limit = int(self.history_size * self.pred_len)
|
| 332 |
+
self.flag = flag
|
| 333 |
+
|
| 334 |
+
self.__read_data__()
|
| 335 |
+
|
| 336 |
+
def __read_data__(self):
|
| 337 |
+
# M4Dataset.initialize()
|
| 338 |
+
if self.flag == 'train':
|
| 339 |
+
dataset = M4Dataset.load(training=True, dataset_file=self.root_path)
|
| 340 |
+
else:
|
| 341 |
+
dataset = M4Dataset.load(training=False, dataset_file=self.root_path)
|
| 342 |
+
training_values = np.array(
|
| 343 |
+
[v[~np.isnan(v)] for v in
|
| 344 |
+
dataset.values[dataset.groups == self.seasonal_patterns]]) # split different frequencies
|
| 345 |
+
self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]])
|
| 346 |
+
self.timeseries = [ts for ts in training_values]
|
| 347 |
+
import pdb
|
| 348 |
+
pdb.set_trace()
|
| 349 |
+
|
| 350 |
+
def __getitem__(self, index):
|
| 351 |
+
insample = np.zeros((self.seq_len, 1))
|
| 352 |
+
insample_mask = np.zeros((self.seq_len, 1))
|
| 353 |
+
outsample = np.zeros((self.pred_len + self.label_len, 1))
|
| 354 |
+
outsample_mask = np.zeros((self.pred_len + self.label_len, 1)) # m4 dataset
|
| 355 |
+
|
| 356 |
+
sampled_timeseries = self.timeseries[index]
|
| 357 |
+
cut_point = np.random.randint(low=max(1, len(sampled_timeseries) - self.window_sampling_limit),
|
| 358 |
+
high=len(sampled_timeseries),
|
| 359 |
+
size=1)[0]
|
| 360 |
+
|
| 361 |
+
insample_window = sampled_timeseries[max(0, cut_point - self.seq_len):cut_point]
|
| 362 |
+
insample[-len(insample_window):, 0] = insample_window
|
| 363 |
+
insample_mask[-len(insample_window):, 0] = 1.0
|
| 364 |
+
outsample_window = sampled_timeseries[
|
| 365 |
+
max(0, cut_point - self.label_len):min(len(sampled_timeseries), cut_point + self.pred_len)]
|
| 366 |
+
outsample[:len(outsample_window), 0] = outsample_window
|
| 367 |
+
outsample_mask[:len(outsample_window), 0] = 1.0
|
| 368 |
+
return insample, outsample, insample_mask, outsample_mask
|
| 369 |
+
|
| 370 |
+
def __len__(self):
|
| 371 |
+
return len(self.timeseries)
|
| 372 |
+
|
| 373 |
+
def inverse_transform(self, data):
|
| 374 |
+
return self.scaler.inverse_transform(data)
|
| 375 |
+
|
| 376 |
+
def last_insample_window(self):
|
| 377 |
+
"""
|
| 378 |
+
The last window of insample size of all timeseries.
|
| 379 |
+
This function does not support batching and does not reshuffle timeseries.
|
| 380 |
+
|
| 381 |
+
:return: Last insample window of all timeseries. Shape "timeseries, insample size"
|
| 382 |
+
"""
|
| 383 |
+
insample = np.zeros((len(self.timeseries), self.seq_len))
|
| 384 |
+
insample_mask = np.zeros((len(self.timeseries), self.seq_len))
|
| 385 |
+
for i, ts in enumerate(self.timeseries):
|
| 386 |
+
ts_last_window = ts[-self.seq_len:]
|
| 387 |
+
insample[i, -len(ts):] = ts_last_window
|
| 388 |
+
insample_mask[i, -len(ts):] = 1.0
|
| 389 |
+
return insample, insample_mask
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class Dataset_Poly(Dataset):
|
| 395 |
+
def __init__(self, args, root_path, flag='pred', size=None,
|
| 396 |
+
features='S', data_path='ETTh1.csv',
|
| 397 |
+
target='OT', scale=False, inverse=False, timeenc=0, freq='15min',
|
| 398 |
+
seasonal_patterns='Yearly'):
|
| 399 |
+
self.args = args
|
| 400 |
+
self.features = features
|
| 401 |
+
self.target = target
|
| 402 |
+
self.scale = scale
|
| 403 |
+
self.inverse = inverse
|
| 404 |
+
self.timeenc = timeenc
|
| 405 |
+
self.root_path = root_path
|
| 406 |
+
self.flag = flag
|
| 407 |
+
|
| 408 |
+
# 从 size 或 args 获取参数
|
| 409 |
+
# size = [seq_len, label_len, pred_len]
|
| 410 |
+
if size is not None:
|
| 411 |
+
self.seq_len = size[0] # 建议设为 16
|
| 412 |
+
self.label_len = size[1] # 建议设为 1
|
| 413 |
+
self.pred_len = size[2] # 建议设为 1
|
| 414 |
+
else:
|
| 415 |
+
self.seq_len = getattr(args, 'seq_len', 16)
|
| 416 |
+
self.label_len = getattr(args, 'label_len', 1)
|
| 417 |
+
self.pred_len = getattr(args, 'pred_len', 1)
|
| 418 |
+
|
| 419 |
+
self.__read_data__()
|
| 420 |
+
|
| 421 |
+
def __read_data__(self):
|
| 422 |
+
# 基础路径
|
| 423 |
+
base_path = self.root_path
|
| 424 |
+
category_path = self.root_path
|
| 425 |
+
|
| 426 |
+
# 文件路径映射 (flag -> (目录, 文件名))
|
| 427 |
+
file_map = {
|
| 428 |
+
'train': (base_path, 'polymarket_data_processed_with_news_train_2025-11-01.jsonl'),
|
| 429 |
+
'val': (base_path, 'polymarket_data_processed_with_news_train_2025-11-01.jsonl'),
|
| 430 |
+
'test': (base_path, 'polymarket_data_processed_with_news_test_2025-11-01.jsonl'),
|
| 431 |
+
# 分类别测试集(不同目录)
|
| 432 |
+
'test_Crypto': (category_path, 'polymarket_data_processed_with_news_test_2025-11-01_crypto.jsonl'),
|
| 433 |
+
'test_Politics': (category_path, 'polymarket_data_processed_with_news_test_2025-11-01_politics.jsonl'),
|
| 434 |
+
'test_Election': (category_path, 'polymarket_data_processed_with_news_test_2025-11-01_election.jsonl'),
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
self.timeseries = []
|
| 438 |
+
|
| 439 |
+
# 获取文件路径
|
| 440 |
+
if self.flag in file_map:
|
| 441 |
+
dir_path, file_name = file_map[self.flag]
|
| 442 |
+
file_path = os.path.join(dir_path, file_name)
|
| 443 |
+
else:
|
| 444 |
+
file_path = os.path.join(base_path, f'polymarket_data_processed_with_news_{self.flag}_2025-11-01.jsonl')
|
| 445 |
+
|
| 446 |
+
if not os.path.exists(file_path):
|
| 447 |
+
raise FileNotFoundError(f"Data file not found: {file_path}")
|
| 448 |
+
|
| 449 |
+
all_samples = []
|
| 450 |
+
skipped = 0
|
| 451 |
+
total = 0
|
| 452 |
+
|
| 453 |
+
with open(file_path, 'r') as fcc_file:
|
| 454 |
+
for line in fcc_file:
|
| 455 |
+
obj = json.loads(line)
|
| 456 |
+
|
| 457 |
+
if 'daily_breakpoints' not in obj:
|
| 458 |
+
continue
|
| 459 |
+
|
| 460 |
+
for bp in obj['daily_breakpoints']:
|
| 461 |
+
total += 1
|
| 462 |
+
window_history = bp.get('window_history', [])
|
| 463 |
+
|
| 464 |
+
if len(window_history) < self.seq_len + 1:
|
| 465 |
+
skipped += 1
|
| 466 |
+
continue
|
| 467 |
+
|
| 468 |
+
prices = [e['p'] for e in window_history]
|
| 469 |
+
all_samples.append(prices)
|
| 470 |
+
|
| 471 |
+
# 对 train/val 做 80/20 切分
|
| 472 |
+
if self.flag == 'train':
|
| 473 |
+
split_idx = int(len(all_samples) * 0.8)
|
| 474 |
+
self.timeseries = all_samples[:split_idx]
|
| 475 |
+
elif self.flag == 'val':
|
| 476 |
+
split_idx = int(len(all_samples) * 0.8)
|
| 477 |
+
self.timeseries = all_samples[split_idx:]
|
| 478 |
+
else:
|
| 479 |
+
self.timeseries = all_samples
|
| 480 |
+
|
| 481 |
+
print(f"[{self.flag}] Loaded {len(self.timeseries)} samples from {os.path.basename(file_path)}")
|
| 482 |
+
print(f"[{self.flag}] Skipped {skipped}/{total} (seq_len requirement: {self.seq_len + 1})")
|
| 483 |
+
|
| 484 |
+
def __getitem__(self, index):
    """Return one (insample, outsample, insample_mask, outsample_mask) sample.

    Shapes: insample (seq_len, 1), outsample (label_len + pred_len, 1);
    both masks are all-ones arrays of the matching shape.
    """
    sampled_timeseries = self.timeseries[index]

    # ========== insample: (seq_len, 1) ==========
    # Take the last seq_len+1 points; the first seq_len of them are the input.
    # E.g. with a window_history of length 17, [-17:-1] yields 16 input points.
    insample = np.zeros((self.seq_len, 1), dtype=np.float32)
    insample_mask = np.zeros((self.seq_len, 1), dtype=np.float32)

    # Input: from the -(seq_len+1)-th point up to (but excluding) the final
    # "after" point.
    input_prices = sampled_timeseries[-(self.seq_len + 1):-1]
    insample[:, 0] = input_prices
    insample_mask[:, 0] = 1.0

    # ========== outsample: (label_len + pred_len, 1) ==========
    # label_len=1 corresponds to the "before" point, pred_len=1 to "after".
    outsample = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)
    outsample_mask = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)

    # outsample = the last label_len + pred_len points,
    # i.e. [before, after] when label_len=1 and pred_len=1.
    outsample[:, 0] = sampled_timeseries[-(self.label_len + self.pred_len):]
    outsample_mask[:, 0] = 1.0

    return insample, outsample, insample_mask, outsample_mask
|
| 509 |
+
|
| 510 |
+
def __len__(self):
    # One sample per breakpoint price window collected by __read_data__.
    return len(self.timeseries)
|
| 512 |
+
|
| 513 |
+
def inverse_transform(self, data):
    """Undo feature scaling if a scaler was fitted; otherwise pass through."""
    scaler = getattr(self, 'scaler', None)
    if scaler is None:
        return data
    return scaler.inverse_transform(data)
|
| 517 |
+
|
| 518 |
+
def last_insample_window(self):
    """Return the final input window of every series, for inference.

    Returns (insample, insample_mask), each of shape (num_series, seq_len);
    the mask is all ones.
    """
    n_series = len(self.timeseries)
    windows = [series[-(self.seq_len + 1):-1] for series in self.timeseries]
    insample = np.asarray(windows, dtype=np.float32).reshape(n_series, self.seq_len)
    insample_mask = np.ones((n_series, self.seq_len), dtype=np.float32)
    return insample, insample_mask
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class PSMSegLoader(Dataset):
    """Sliding-window segment loader for the PSM anomaly-detection dataset.

    Windows of length ``win_size`` are taken every ``step`` rows. The scaler
    is fitted on the train CSV only and applied to both splits; ``val`` is the
    last 20% of the scaled train split. Train/val items are paired with the
    first ``win_size`` test labels as placeholder targets (labels are only
    meaningful for the test split).
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()

        # Drop the first (timestamp/id) column and replace NaNs with zeros.
        train_values = np.nan_to_num(
            pd.read_csv(os.path.join(root_path, 'train.csv')).values[:, 1:])
        self.scaler.fit(train_values)
        self.train = self.scaler.transform(train_values)

        test_values = np.nan_to_num(
            pd.read_csv(os.path.join(root_path, 'test.csv')).values[:, 1:])
        self.test = self.scaler.transform(test_values)

        # Validation split: final 20% of the (already scaled) training data.
        self.val = self.train[int(len(self.train) * 0.8):]
        self.test_labels = pd.read_csv(os.path.join(root_path, 'test_label.csv')).values[:, 1:]
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        """Number of strided windows in the active split."""
        if self.flag == "train":
            source = self.train
        elif self.flag == 'val':
            source = self.val
        elif self.flag == 'test':
            source = self.test
        else:
            # Fallback: non-overlapping windows over the test split.
            return (self.test.shape[0] - self.win_size) // self.win_size + 1
        return (source.shape[0] - self.win_size) // self.step + 1

    def __getitem__(self, index):
        """Return (window, labels) as float32 arrays for the active split."""
        start = index * self.step
        stop = start + self.win_size
        if self.flag == "train":
            # Placeholder labels: the first win_size test labels.
            return np.float32(self.train[start:stop]), np.float32(self.test_labels[0:self.win_size])
        if self.flag == 'val':
            return np.float32(self.val[start:stop]), np.float32(self.test_labels[0:self.win_size])
        if self.flag == 'test':
            return np.float32(self.test[start:stop]), np.float32(self.test_labels[start:stop])
        # Fallback: non-overlapping windows aligned to win_size.
        begin = start // self.step * self.win_size
        return (np.float32(self.test[begin:begin + self.win_size]),
                np.float32(self.test_labels[begin:begin + self.win_size]))
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
class MSLSegLoader(Dataset):
    """Sliding-window loader for the MSL (NASA Mars Science Laboratory)
    anomaly-detection dataset.

    Windows of length ``win_size`` are taken every ``step`` rows. The scaler
    is fitted on the train array only and applied to both splits; ``val`` is
    the last 20% of the scaled train split. Train/val items are paired with
    the first ``win_size`` test labels as placeholder targets (labels are
    only meaningful for the test split).
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        data = np.load(os.path.join(root_path, "MSL_train.npy"))
        # Fit normalization statistics on the raw training data only.
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_data = np.load(os.path.join(root_path, "MSL_test.npy"))
        self.test = self.scaler.transform(test_data)
        self.train = data
        data_len = len(self.train)
        # Validation split: final 20% of the (already scaled) training data.
        self.val = self.train[(int)(data_len * 0.8):]
        self.test_labels = np.load(os.path.join(root_path, "MSL_test_label.npy"))
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        # Number of strided windows in the active split; the fallback branch
        # counts non-overlapping windows over the test split.
        if self.flag == "train":
            return (self.train.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'val'):
            return (self.val.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'test'):
            return (self.test.shape[0] - self.win_size) // self.step + 1
        else:
            return (self.test.shape[0] - self.win_size) // self.win_size + 1

    def __getitem__(self, index):
        # Convert the dataset index into a row offset.
        index = index * self.step
        if self.flag == "train":
            # Placeholder labels: the first win_size test labels.
            return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'val'):
            return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'test'):
            return np.float32(self.test[index:index + self.win_size]), np.float32(
                self.test_labels[index:index + self.win_size])
        else:
            # Fallback: non-overlapping windows aligned to win_size.
            return np.float32(self.test[
                              index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
                self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class SMAPSegLoader(Dataset):
    """Sliding-window loader for the SMAP (NASA Soil Moisture Active Passive)
    anomaly-detection dataset.

    Same scheme as the other SegLoaders: scaler fitted on train only, ``val``
    is the last 20% of the scaled train split, train/val items carry the
    first ``win_size`` test labels as placeholder targets.
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        data = np.load(os.path.join(root_path, "SMAP_train.npy"))
        # Fit normalization statistics on the raw training data only.
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_data = np.load(os.path.join(root_path, "SMAP_test.npy"))
        self.test = self.scaler.transform(test_data)
        self.train = data
        data_len = len(self.train)
        # Validation split: final 20% of the (already scaled) training data.
        self.val = self.train[(int)(data_len * 0.8):]
        self.test_labels = np.load(os.path.join(root_path, "SMAP_test_label.npy"))
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        # Number of strided windows in the active split; the fallback branch
        # counts non-overlapping windows over the test split.

        if self.flag == "train":
            return (self.train.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'val'):
            return (self.val.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'test'):
            return (self.test.shape[0] - self.win_size) // self.step + 1
        else:
            return (self.test.shape[0] - self.win_size) // self.win_size + 1

    def __getitem__(self, index):
        # Convert the dataset index into a row offset.
        index = index * self.step
        if self.flag == "train":
            # Placeholder labels: the first win_size test labels.
            return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'val'):
            return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'test'):
            return np.float32(self.test[index:index + self.win_size]), np.float32(
                self.test_labels[index:index + self.win_size])
        else:
            # Fallback: non-overlapping windows aligned to win_size.
            return np.float32(self.test[
                              index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
                self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
class SMDSegLoader(Dataset):
    """Sliding-window loader for the SMD (Server Machine Dataset)
    anomaly-detection dataset.

    Note the default ``step=100``: by default windows are taken every 100
    rows (strongly subsampled) rather than every row. Otherwise the scheme
    matches the other SegLoaders: scaler fitted on train only, ``val`` is the
    last 20% of the scaled train split, placeholder labels for train/val.
    """

    def __init__(self, args, root_path, win_size, step=100, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        data = np.load(os.path.join(root_path, "SMD_train.npy"))
        # Fit normalization statistics on the raw training data only.
        self.scaler.fit(data)
        data = self.scaler.transform(data)
        test_data = np.load(os.path.join(root_path, "SMD_test.npy"))
        self.test = self.scaler.transform(test_data)
        self.train = data
        data_len = len(self.train)
        # Validation split: final 20% of the (already scaled) training data.
        self.val = self.train[(int)(data_len * 0.8):]
        self.test_labels = np.load(os.path.join(root_path, "SMD_test_label.npy"))

    def __len__(self):
        # Number of strided windows in the active split; the fallback branch
        # counts non-overlapping windows over the test split.
        if self.flag == "train":
            return (self.train.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'val'):
            return (self.val.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'test'):
            return (self.test.shape[0] - self.win_size) // self.step + 1
        else:
            return (self.test.shape[0] - self.win_size) // self.win_size + 1

    def __getitem__(self, index):
        # Convert the dataset index into a row offset.
        index = index * self.step
        if self.flag == "train":
            # Placeholder labels: the first win_size test labels.
            return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'val'):
            return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'test'):
            return np.float32(self.test[index:index + self.win_size]), np.float32(
                self.test_labels[index:index + self.win_size])
        else:
            # Fallback: non-overlapping windows aligned to win_size.
            return np.float32(self.test[
                              index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
                self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
class SWATSegLoader(Dataset):
    """Sliding-window loader for the SWaT (Secure Water Treatment)
    anomaly-detection dataset.

    CSV layout: the last column of each file holds the label; all other
    columns are sensor features. The scaler is fitted on the train features
    only; ``val`` is the last 20% of the scaled train split; train/val items
    carry the first ``win_size`` test labels as placeholder targets.
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()

        train_data = pd.read_csv(os.path.join(root_path, 'swat_train2.csv'))
        test_data = pd.read_csv(os.path.join(root_path, 'swat2.csv'))
        # Last column is the anomaly label; the rest are features.
        labels = test_data.values[:, -1:]
        train_data = train_data.values[:, :-1]
        test_data = test_data.values[:, :-1]

        # Fit normalization statistics on the raw training features only.
        self.scaler.fit(train_data)
        train_data = self.scaler.transform(train_data)
        test_data = self.scaler.transform(test_data)
        self.train = train_data
        self.test = test_data
        data_len = len(self.train)
        # Validation split: final 20% of the (already scaled) training data.
        self.val = self.train[(int)(data_len * 0.8):]
        self.test_labels = labels
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        """
        Number of strided windows in the active split (the fallback branch
        counts non-overlapping windows over the test split).
        """
        if self.flag == "train":
            return (self.train.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'val'):
            return (self.val.shape[0] - self.win_size) // self.step + 1
        elif (self.flag == 'test'):
            return (self.test.shape[0] - self.win_size) // self.step + 1
        else:
            return (self.test.shape[0] - self.win_size) // self.win_size + 1

    def __getitem__(self, index):
        # Convert the dataset index into a row offset.
        index = index * self.step
        if self.flag == "train":
            # Placeholder labels: the first win_size test labels.
            return np.float32(self.train[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'val'):
            return np.float32(self.val[index:index + self.win_size]), np.float32(self.test_labels[0:self.win_size])
        elif (self.flag == 'test'):
            return np.float32(self.test[index:index + self.win_size]), np.float32(
                self.test_labels[index:index + self.win_size])
        else:
            # Fallback: non-overlapping windows aligned to win_size.
            return np.float32(self.test[
                              index // self.step * self.win_size:index // self.step * self.win_size + self.win_size]), np.float32(
                self.test_labels[index // self.step * self.win_size:index // self.step * self.win_size + self.win_size])
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class UEAloader(Dataset):
    """
    Dataset class for datasets included in:
        Time Series Classification Archive (www.timeseriesclassification.com)
    Argument:
        limit_size: float in (0, 1) for debug
    Attributes:
        all_df: (num_samples * seq_len, num_columns) dataframe indexed by integer indices, with multiple rows corresponding to the same index (sample).
            Each row is a time step; Each column contains either metadata (e.g. timestamp) or a feature.
        feature_df: (num_samples * seq_len, feat_dim) dataframe; contains the subset of columns of `all_df` which correspond to selected features
        feature_names: names of columns contained in `feature_df` (same as feature_df.columns)
        all_IDs: (num_samples,) series of IDs contained in `all_df`/`feature_df` (same as all_df.index.unique() )
        labels_df: (num_samples, num_labels) pd.DataFrame of label(s) for each sample
        max_seq_len: maximum sequence (time series) length. If None, script argument `max_seq_len` will be used.
            (Moreover, script argument overrides this attribute)
    """

    def __init__(self, args, root_path, file_list=None, limit_size=None, flag=None):
        self.args = args
        self.root_path = root_path
        self.flag = flag
        self.all_df, self.labels_df = self.load_all(root_path, file_list=file_list, flag=flag)
        self.all_IDs = self.all_df.index.unique()  # all sample IDs (integer indices 0 ... num_samples-1)

        if limit_size is not None:
            if limit_size > 1:
                limit_size = int(limit_size)
            else:  # interpret as proportion if in (0, 1]
                limit_size = int(limit_size * len(self.all_IDs))
            self.all_IDs = self.all_IDs[:limit_size]
            self.all_df = self.all_df.loc[self.all_IDs]

        # use all features
        self.feature_names = self.all_df.columns
        self.feature_df = self.all_df

        # pre_process: per-dataset normalization via the project's Normalizer helper
        normalizer = Normalizer()
        self.feature_df = normalizer.normalize(self.feature_df)
        print(len(self.all_IDs))

    def load_all(self, root_path, file_list=None, flag=None):
        """
        Loads datasets from ts files contained in `root_path` into a dataframe, optionally choosing from `pattern`
        Args:
            root_path: directory containing all individual .ts files
            file_list: optionally, provide a list of file paths within `root_path` to consider.
                Otherwise, entire `root_path` contents will be used.
        Returns:
            all_df: a single (possibly concatenated) dataframe with all data corresponding to specified files
            labels_df: dataframe containing label(s) for each sample
        """
        # Select paths for training and evaluation
        if file_list is None:
            data_paths = glob.glob(os.path.join(root_path, '*'))  # list of all paths
        else:
            data_paths = [os.path.join(root_path, p) for p in file_list]
        if len(data_paths) == 0:
            raise Exception('No files found using: {}'.format(os.path.join(root_path, '*')))
        if flag is not None:
            # Keep only paths whose name matches the split flag (e.g. "TRAIN"/"TEST").
            data_paths = list(filter(lambda x: re.search(flag, x), data_paths))
        input_paths = [p for p in data_paths if os.path.isfile(p) and p.endswith('.ts')]
        if len(input_paths) == 0:
            pattern='*.ts'
            raise Exception("No .ts files found using pattern: '{}'".format(pattern))

        all_df, labels_df = self.load_single(input_paths[0])  # a single file contains dataset

        return all_df, labels_df

    def load_single(self, filepath):
        """Parse one .ts file into (all_df, labels_df) in the long format
        documented on the class (one row per time step, index = sample ID)."""
        df, labels = load_from_tsfile_to_dataframe(filepath, return_separate_X_and_y=True,
                                                   replace_missing_vals_with='NaN')
        labels = pd.Series(labels, dtype="category")
        self.class_names = labels.cat.categories
        labels_df = pd.DataFrame(labels.cat.codes,
                                 dtype=np.int8)  # int8-32 gives an error when using nn.CrossEntropyLoss

        lengths = df.applymap(
            lambda x: len(x)).values  # (num_samples, num_dimensions) array containing the length of each series

        horiz_diffs = np.abs(lengths - np.expand_dims(lengths[:, 0], -1))

        if np.sum(horiz_diffs) > 0:  # if any row (sample) has varying length across dimensions
            df = df.applymap(subsample)

        lengths = df.applymap(lambda x: len(x)).values
        vert_diffs = np.abs(lengths - np.expand_dims(lengths[0, :], 0))
        if np.sum(vert_diffs) > 0:  # if any column (dimension) has varying length across samples
            self.max_seq_len = int(np.max(lengths[:, 0]))
        else:
            self.max_seq_len = lengths[0, 0]

        # First create a (seq_len, feat_dim) dataframe for each sample, indexed by a single integer ("ID" of the sample)
        # Then concatenate into a (num_samples * seq_len, feat_dim) dataframe, with multiple rows corresponding to the
        # sample index (i.e. the same scheme as all datasets in this project)

        df = pd.concat((pd.DataFrame({col: df.loc[row, col] for col in df.columns}).reset_index(drop=True).set_index(
            pd.Series(lengths[row, 0] * [row])) for row in range(df.shape[0])), axis=0)

        # Replace NaN values by interpolating within each sample group
        grp = df.groupby(by=df.index)
        df = grp.transform(interpolate_missing)

        return df, labels_df

    def instance_norm(self, case):
        # Per-instance standardization; only applied for EthanolConcentration.
        if self.root_path.count('EthanolConcentration') > 0:  # special process for numerical stability
            mean = case.mean(0, keepdim=True)
            case = case - mean
            stdev = torch.sqrt(torch.var(case, dim=1, keepdim=True, unbiased=False) + 1e-5)
            case /= stdev
            return case
        else:
            return case

    def __getitem__(self, ind):
        """Return (features, labels) tensors for one sample; optionally
        augmented during training via run_augmentation_single."""
        batch_x = self.feature_df.loc[self.all_IDs[ind]].values
        labels = self.labels_df.loc[self.all_IDs[ind]].values
        if self.flag == "TRAIN" and self.args.augmentation_ratio > 0:
            num_samples = len(self.all_IDs)
            num_columns = self.feature_df.shape[1]
            # NOTE(review): assumes all samples share the same seq_len here —
            # the division below is only exact for equal-length series.
            seq_len = int(self.feature_df.shape[0] / num_samples)
            batch_x = batch_x.reshape((1, seq_len, num_columns))
            batch_x, labels, augmentation_tags = run_augmentation_single(batch_x, labels, self.args)

            batch_x = batch_x.reshape((1 * seq_len, num_columns))

        return self.instance_norm(torch.from_numpy(batch_x)), \
               torch.from_numpy(labels)

    def __len__(self):
        # One item per sample ID.
        return len(self.all_IDs)
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
class Dataset_Kalshi(Dataset):
    """Kalshi prediction-market dataset of per-breakpoint price windows.

    Each JSONL record carries a ``daily_breakpoints`` list; each breakpoint
    has a ``window_history`` list of ``{'p': price}`` entries. One sample is
    the price list of one breakpoint: ``__getitem__`` slices it into a
    ``seq_len``-point input window and a ``label_len + pred_len``-point
    target whose final point is the "after" value to predict.

    The unused ``features``/``data_path``/``target``/... parameters keep the
    signature compatible with the other dataset classes in this provider.
    """

    def __init__(self, args, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=False, inverse=False, timeenc=0, freq='15min',
                 seasonal_patterns='Yearly'):
        self.args = args
        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.root_path = root_path
        self.flag = flag

        # size = [seq_len, label_len, pred_len]; fall back to args
        # (or to the defaults 16/1/1) when no explicit size is given.
        if size is not None:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        else:
            self.seq_len = getattr(args, 'seq_len', 16)
            self.label_len = getattr(args, 'label_len', 1)
            self.pred_len = getattr(args, 'pred_len', 1)

        self.__read_data__()

    def __read_data__(self):
        """Load the price windows for the current split from JSONL files.

        Raises FileNotFoundError when the resolved file does not exist.
        """
        base_path = self.root_path
        category_path = base_path

        # flag -> (directory, file name)
        file_map = {
            'train': (base_path, 'kalshi_data_processed_with_news_train_2025-11-01.jsonl'),
            'val': (base_path, 'kalshi_data_processed_with_news_train_2025-11-01.jsonl'),
            'test': (base_path, 'kalshi_data_processed_with_news_test_2025-11-01.jsonl'),
            # Per-category test splits.
            'test_Companies': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_companies.jsonl'),
            'test_Economics': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_economics.jsonl'),
            'test_Entertainment': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_entertainment.jsonl'),
            'test_Mentions': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_mentions.jsonl'),
            'test_Politics': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_politics.jsonl'),
        }

        self.timeseries = []

        if self.flag in file_map:
            dir_path, file_name = file_map[self.flag]
            file_path = os.path.join(dir_path, file_name)
        else:
            # Fix: use the same 2025-11-01 date as every mapped file and the
            # parallel Polymarket dataset (was 2024-11-01, a copy-paste typo
            # pointing at a nonexistent file).
            file_path = os.path.join(base_path, f'kalshi_data_processed_with_news_{self.flag}_2025-11-01.jsonl')

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Data file not found: {file_path}")

        all_samples = []
        skipped = 0
        total = 0

        with open(file_path, 'r') as fcc_file:
            for line in fcc_file:
                obj = json.loads(line)

                if 'daily_breakpoints' not in obj:
                    continue

                for bp in obj['daily_breakpoints']:
                    total += 1
                    window_history = bp.get('window_history', [])

                    # Need seq_len inputs plus the final "after" target point.
                    if len(window_history) < self.seq_len + 1:
                        skipped += 1
                        continue

                    prices = [e['p'] for e in window_history]
                    all_samples.append(prices)

        # 80/20 split of the shared train file into train/val.
        if self.flag == 'train':
            split_idx = int(len(all_samples) * 0.8)
            self.timeseries = all_samples[:split_idx]
        elif self.flag == 'val':
            split_idx = int(len(all_samples) * 0.8)
            self.timeseries = all_samples[split_idx:]
        else:
            self.timeseries = all_samples

        print(f"[{self.flag}] Loaded {len(self.timeseries)} samples from {os.path.basename(file_path)}")
        print(f"[{self.flag}] Skipped {skipped}/{total} (seq_len requirement: {self.seq_len + 1})")

    def __getitem__(self, index):
        """Return (insample, outsample, insample_mask, outsample_mask).

        Shapes: insample (seq_len, 1), outsample (label_len + pred_len, 1);
        both masks are all-ones arrays of the matching shape.
        """
        sampled_timeseries = self.timeseries[index]

        # insample: the last seq_len+1 points minus the final "after" point.
        insample = np.zeros((self.seq_len, 1), dtype=np.float32)
        insample_mask = np.zeros((self.seq_len, 1), dtype=np.float32)

        input_prices = sampled_timeseries[-(self.seq_len + 1):-1]
        insample[:, 0] = input_prices
        insample_mask[:, 0] = 1.0

        # outsample: the last label_len + pred_len points,
        # i.e. [before, after] when label_len=1 and pred_len=1.
        outsample = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)
        outsample_mask = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)

        outsample[:, 0] = sampled_timeseries[-(self.label_len + self.pred_len):]
        outsample_mask[:, 0] = 1.0

        return insample, outsample, insample_mask, outsample_mask

    def __len__(self):
        # One sample per breakpoint price window collected by __read_data__.
        return len(self.timeseries)

    def inverse_transform(self, data):
        """Undo feature scaling if a scaler was fitted; otherwise pass through."""
        if hasattr(self, 'scaler') and self.scaler is not None:
            return self.scaler.inverse_transform(data)
        return data

    def last_insample_window(self):
        """Return the final input window of every series, for inference.

        Returns (insample, insample_mask), each of shape
        (num_series, seq_len); the mask is all ones.
        """
        insample = np.zeros((len(self.timeseries), self.seq_len), dtype=np.float32)
        insample_mask = np.zeros((len(self.timeseries), self.seq_len), dtype=np.float32)

        for i, ts in enumerate(self.timeseries):
            input_prices = ts[-(self.seq_len + 1):-1]
            insample[i, :] = input_prices
            insample_mask[i, :] = 1.0

        return insample, insample_mask
|
data_provider/load.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
def load_first_jsonl_record(filepath):
    """Read and parse the first line of a JSONL file.

    Returns the decoded object, or None (printing a diagnostic message) when
    the file is empty, missing, malformed, or any other error occurs.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            head = handle.readline().strip()
        if not head:
            print("文件为空")
            return None
        return json.loads(head)
    except FileNotFoundError:
        print(f"文件未找到: {filepath}")
    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e}")
    except Exception as e:
        print(f"发生错误: {e}")
    return None
|
| 21 |
+
|
| 22 |
+
# Usage example
if __name__ == "__main__":
    filepath = "/data/haofeiy2/social-world-model/data/splitted_polymarket/polymarket_data_processed_with_news_train_2024-11-01.jsonl"

    first_record = load_first_jsonl_record(filepath)
    # Fix: removed a stray breakpoint() left over from debugging; it dropped
    # every (including non-interactive) run into the pdb debugger here.

    if first_record:
        print("第一条记录:")
        print(json.dumps(first_record, ensure_ascii=False, indent=2))
|
data_provider/m4.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is provided for the purposes of scientific reproducibility
|
| 2 |
+
# under the following limited license from Element AI Inc. The code is an
|
| 3 |
+
# implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
|
| 4 |
+
# expansion analysis for interpretable time series forecasting,
|
| 5 |
+
# https://arxiv.org/abs/1905.10437). The copyright to the source code is
|
| 6 |
+
# licensed under the Creative Commons - Attribution-NonCommercial 4.0
|
| 7 |
+
# International license (CC BY-NC 4.0):
|
| 8 |
+
# https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
|
| 9 |
+
# for the benefit of third parties or internally in production) requires an
|
| 10 |
+
# explicit license. The subject-matter of the N-BEATS model and associated
|
| 11 |
+
# materials are the property of Element AI Inc. and may be subject to patent
|
| 12 |
+
# protection. No license to patents is granted hereunder (whether express or
|
| 13 |
+
# implied). Copyright © 2020 Element AI Inc. All rights reserved.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
M4 Dataset
|
| 17 |
+
"""
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
from collections import OrderedDict
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from glob import glob
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import pandas as pd
|
| 26 |
+
import patoolib
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
import logging
|
| 29 |
+
import os
|
| 30 |
+
import pathlib
|
| 31 |
+
import sys
|
| 32 |
+
from urllib import request
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def url_file_name(url: str) -> str:
    """
    Extract file name from url.

    :param url: URL to extract file name from.
    :return: File name (empty string for an empty url).
    """
    if not url:
        return ''
    return url.rsplit('/', 1)[-1]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def download(url: str, file_path: str) -> None:
    """
    Download a file to the given path.

    Skips the download entirely when `file_path` already exists; otherwise
    creates any missing parent directories first.

    :param url: URL to download
    :param file_path: Where to download the content.
    """

    def progress(count, block_size, total_size):
        # urlretrieve report hook: print an in-place progress percentage.
        progress_pct = float(count * block_size) / float(total_size) * 100.0
        sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct))
        sys.stdout.flush()

    if not os.path.isfile(file_path):
        # Spoof a browser user agent; some mirrors reject Python's default.
        opener = request.build_opener()
        opener.addheaders = [('User-agent', 'Mozilla/5.0')]
        request.install_opener(opener)
        pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True)
        f, _ = request.urlretrieve(url, file_path, progress)
        sys.stdout.write('\n')
        sys.stdout.flush()
        file_info = os.stat(f)
        logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.')
    else:
        # Already present: report its size instead of re-downloading.
        file_info = os.stat(file_path)
        logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.')
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass()
class M4Dataset:
    """In-memory view of the M4 competition data: per-series metadata from
    M4-info.csv plus the cached series values."""
    # Columns taken 1:1 from M4-info.csv in load() below.
    ids: np.ndarray          # M4id column: series identifiers
    groups: np.ndarray       # SP column: seasonal pattern per series
    frequencies: np.ndarray  # Frequency column
    horizons: np.ndarray     # Horizon column: forecast length per series
    values: np.ndarray       # series values loaded from the cache file

    @staticmethod
    def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset':
        """
        Load cached dataset.

        :param training: Load training part if training is True, test part otherwise.
        :param dataset_file: Directory containing 'M4-info.csv' and the cached
            'training.npz' / 'test.npz' files.
        :return: An M4Dataset populated from the metadata CSV and the cache.
        """
        info_file = os.path.join(dataset_file, 'M4-info.csv')
        train_cache_file = os.path.join(dataset_file, 'training.npz')
        test_cache_file = os.path.join(dataset_file, 'test.npz')
        m4_info = pd.read_csv(info_file)
        # NOTE(review): the cache files carry an .npz extension but the result
        # of np.load(..., allow_pickle=True) is stored directly as `values`;
        # presumably each cache is a pickled object array of variable-length
        # per-series arrays — confirm against the script that wrote the caches.
        m4dataset = M4Dataset(ids=m4_info.M4id.values,
                              groups=m4_info.SP.values,
                              frequencies=m4_info.Frequency.values,
                              horizons=m4_info.Horizon.values,
                              values=np.load(
                                  train_cache_file if training else test_cache_file,
                                  allow_pickle=True))
        return m4dataset
|
| 104 |
+
@dataclass()
class M4Meta:
    """Static metadata for the six M4 seasonal patterns (class-level constants)."""
    seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly']
    horizons = [6, 8, 18, 13, 14, 48]
    frequencies = [1, 4, 12, 1, 1, 24]
    # Per-pattern prediction length, keyed by seasonal pattern name.
    horizons_map = dict(zip(seasonal_patterns, horizons))
    # Per-pattern sampling frequency, keyed by seasonal pattern name.
    frequency_map = dict(zip(seasonal_patterns, frequencies))
    # History-window multiplier per pattern (values from interpretable.gin).
    history_size = dict(zip(seasonal_patterns, [1.5, 1.5, 1.5, 10, 10, 10]))
| 134 |
+
|
| 135 |
+
def load_m4_info() -> pd.DataFrame:
    """
    Load M4Info file.

    :return: Pandas DataFrame of M4Info.
    """
    # INFO_FILE_PATH is a module-level constant defined alongside the other
    # M4 cache-path constants.
    info_df = pd.read_csv(INFO_FILE_PATH)
    return info_df
data_provider/uea.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def collate_fn(data, max_len=None):
    """Build mini-batch tensors from a list of (X, y) tuples.

    Args:
        data: len(batch_size) list of tuples (X, y).
            - X: torch tensor of shape (seq_length, feat_dim); variable seq_length.
            - y: torch tensor of shape (num_labels,): class indices or numerical
              targets (classification or regression); num_labels > 1 for multi-task.
        max_len: global fixed sequence length, for architectures requiring fixed
            length input. Longer sequences are clipped, shorter are padded with 0s.
            Defaults to the longest sequence in the batch.

    Returns:
        X: (batch_size, padded_length, feat_dim) torch tensor of padded features.
        targets: (batch_size, num_labels) torch tensor of stacked labels.
        padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep
            vector at this position, 0 means padding.
    """
    features, labels = zip(*data)

    # Original sequence length for each time series in the batch.
    seq_lens = [feat.shape[0] for feat in features]
    if max_len is None:
        max_len = max(seq_lens)

    # Copy each (possibly clipped) sequence into a zero-padded 3D tensor.
    feat_dim = features[0].shape[-1]
    X = torch.zeros(len(features), max_len, feat_dim)
    for i, (feat, n) in enumerate(zip(features, seq_lens)):
        keep = min(n, max_len)
        X[i, :keep, :] = feat[:keep, :]

    targets = torch.stack(labels, dim=0)  # (batch_size, num_labels)

    # Boolean keep-mask over padded positions ("1" = real time step).
    padding_masks = padding_mask(torch.tensor(seq_lens, dtype=torch.int16),
                                 max_len=max_len)

    return X, targets, padding_masks
|
| 45 |
+
def padding_mask(lengths, max_len=None):
    """
    Used to mask padded positions: creates a (batch_size, max_len) boolean mask
    from a tensor of sequence lengths, where 1 means keep element at this
    position (time step).

    :param lengths: 1-D integer tensor of per-sample sequence lengths.
    :param max_len: padded length of the mask; defaults to lengths.max().
    :return: (batch_size, max_len) boolean tensor.
    """
    batch_size = lengths.numel()
    # BUG FIX: torch.Tensor has no `max_val()` method, so the old default
    # `lengths.max_val()` raised AttributeError whenever max_len was omitted.
    if max_len is None:
        max_len = int(lengths.max())
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))
|
| 58 |
+
class Normalizer(object):
    """
    Normalizes dataframe across ALL contained rows (time steps). Different from
    per-sample normalization.
    """

    def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None):
        """
        Args:
            norm_type: choose from:
                "standardization", "minmax": normalizes dataframe across ALL contained rows (time steps)
                "per_sample_std", "per_sample_minmax": normalizes each sample separately (i.e. across only its own rows)
            mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values
        """
        self.norm_type = norm_type
        self.mean = mean
        self.std = std
        self.min_val = min_val
        self.max_val = max_val

    def normalize(self, df):
        """
        Args:
            df: input dataframe
        Returns:
            df: normalized dataframe
        """
        eps = np.finfo(float).eps

        if self.norm_type == "standardization":
            if self.mean is None:
                # Cache global statistics on first use so later calls
                # (e.g. on a test set) reuse the same scaling.
                self.mean = df.mean()
                self.std = df.std()
            return (df - self.mean) / (self.std + eps)

        if self.norm_type == "minmax":
            if self.max_val is None:
                self.max_val = df.max()
                self.min_val = df.min()
            return (df - self.min_val) / (self.max_val - self.min_val + eps)

        if self.norm_type == "per_sample_std":
            # Each sample is identified by its (repeated) index value.
            grouped = df.groupby(by=df.index)
            return (df - grouped.transform('mean')) / grouped.transform('std')

        if self.norm_type == "per_sample_minmax":
            grouped = df.groupby(by=df.index)
            sample_min = grouped.transform('min')
            return (df - sample_min) / (grouped.transform('max') - sample_min + eps)

        raise NameError(f'Normalize method "{self.norm_type}" not implemented')
|
| 110 |
+
def interpolate_missing(y):
    """
    Replaces NaN values in pd.Series `y` using linear interpolation.
    """
    # Only pay for the interpolation when something is actually missing.
    if not y.isna().any():
        return y
    return y.interpolate(method='linear', limit_direction='both')
+
|
| 119 |
+
def subsample(y, limit=256, factor=2):
    """
    If a given Series is longer than `limit`, returns the sequence subsampled
    by the specified integer `factor` (with a fresh integer index); otherwise
    the Series is returned unchanged.
    """
    if len(y) <= limit:
        return y
    return y[::factor].reset_index(drop=True)
dataset/m4/Daily-test.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/Daily-train.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78e94591c60c06309f1e544fd7b2ccba28f3616b01913dd336a1d1d98483a1ec
|
| 3 |
+
size 95765153
|
dataset/m4/Hourly-test.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/Hourly-train.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/M4-info.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/Monthly-test.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/Monthly-train.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:63aaa56198b4a22279a2fec541f56921f3bed6d8e0f84e550206df9276f2e9b9
|
| 3 |
+
size 91655432
|
dataset/m4/Quarterly-test.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/Quarterly-train.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4450678cd493aa965c14bde984a5b62a66d8c3b2f7af1e53779bc91ff008ccdc
|
| 3 |
+
size 38788547
|
dataset/m4/Weekly-test.csv
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"V1","V2","V3","V4","V5","V6","V7","V8","V9","V10","V11","V12","V13","V14"
|
| 2 |
+
"W1","35397.16","35808.59","35808.59","36246.14","36246.14","36403.7","36403.7","36150.2","36150.2","35790.55","35790.55","34066.95","34066.95"
|
| 3 |
+
"W2","3608.061","3624.368","3537.64","3587.344","3662.088","3658.634","3622.148","3591.48","3638.559","3635.781","3522.328","3458.631","3354.758"
|
| 4 |
+
"W3","9602.4","9895.9","9915.9","9726.6","9626.5","9966.1","10008.5","9911.7","9760.7","9968.2","10089.1","10005.9","9823.3"
|
| 5 |
+
"W4","2336.56","2373.14","2457.51","2586","2237.7","2877.94","2497.74","3525.34","3053.33","3151.07","3425.79","3505.61","4098.69"
|
| 6 |
+
"W5","1499","1734","1391","1833","1637","2122","1757","2079","2477","1957","1907","1752","1330"
|
| 7 |
+
"W6","502","568","528","632","606","785","779","834","906","1372","1080","938","752"
|
| 8 |
+
"W7","3117","2579","3199","1716","4558","3001","3263","3110","5899","7148","4048","2707","2579"
|
| 9 |
+
"W8","3430","2870","4620","4890","7470","6960","6430","3970","3650","5730","4500","2870","3270"
|
| 10 |
+
"W9","8100","7481","8953","8087","9724","8993","10011","9705","8397","7560","8669","7574","7106"
|
| 11 |
+
"W10","5122.14","5104.68","5104.68","5104.68","5565.7","5565.7","5565.7","5565.7","5565.7","5544.94","5544.94","5544.94","5544.94"
|
| 12 |
+
"W11","337.19","359.52","357.62","368.22","328.77","391.31","340.06","323.19","323.3","339.94","316.45","295.99","280.07"
|
| 13 |
+
"W12","1092","1085","1078","1071","1064","1057","1050","1043","1092","1085","1078","1071","1064"
|
| 14 |
+
"W13","2373.445","2373.105","2453.974","2390.353","2324.83","2337.239","2375.971","2398.802","2350.108","2363.517","2484.652","2558.814","2653.311"
|
| 15 |
+
"W14","2341.516","2407.573","2435.3","2361.92","2310.927","2322.117","2408.718","2476.214","2361.203","2393.436","2552.503","2554.351","2500.651"
|
| 16 |
+
"W15","4231","4311.68","3848.09","4149.38","3633.49","3553.49","3743.59","4626.91","4235.48","4288.81","5601.72","5737.57","5232.13"
|
| 17 |
+
"W16","4231","4311.68","3848.09","4149.38","3633.49","3553.49","3743.59","4626.91","4235.48","4288.81","5601.72","5737.57","5232.13"
|
| 18 |
+
"W17","12027.16","12027.2","12027.24","11985.62","11985.74","12432.86","12432.99","12418.26","12418.32","12418.37","12418.43","12418.49","12323.22"
|
| 19 |
+
"W18","4334.57","4334.61","4334.65","4351.65","4351.77","3846.99","3954.85","3996.48","3992.55","3992.61","3992.68","3992.75","4029.96"
|
| 20 |
+
"W19","2754","3035","2591","2808","2324","4033","3146","3402","3677","4790","7520","4445","2960"
|
| 21 |
+
"W20","3942","4184","3374","4291","2657","4258","4050","3730","3311","3154","4551","4677","3944"
|
| 22 |
+
"W21","2203","2234","1786","2582","2838","2949","2483","2184","3809","3395","3194","2744","2054"
|
| 23 |
+
"W22","3780","4430","3380","4100","2710","3980","3720","3300","3460","2760","4730","3930","3180"
|
| 24 |
+
"W23","753.8","767.6","722.9","769.5","536.6","756.4","813","732.2","617.9","617","901.1","872.9","739.2"
|
| 25 |
+
"W24","6510","6109","5215","6798","4354","10041","7637","6137","7002","8575","19189","12646","6956"
|
| 26 |
+
"W25","1431","1679","1348","1870","2110","2492","2701","2649","2745","2676","3017","2669","2017"
|
| 27 |
+
"W26","2672","3061","2947","4293","3409","4660","5136","6235","5580","4585","6168","4943","3945"
|
| 28 |
+
"W27","2144","2144","1752","2314","2645","2643","3120","3302","3802","2762","3713","4473","2643"
|
| 29 |
+
"W28","7793","3341","3003","3507","2555","3167","3300","3507","8089","8507","5611","4081","3025"
|
| 30 |
+
"W29","851","1017","907","1075","833","1716","1339","1312","1345","1322","2142","1732","1143"
|
| 31 |
+
"W30","3623","4016","3075","4373","3773","5121","4129","4785","6509","7561","6183","5485","4223"
|
| 32 |
+
"W31","708","829","796","1026","1483","1412","1665","1696","2150","1948","1778","1156","1053"
|
| 33 |
+
"W32","5029","3990","5803","3980","6833","5247","6023","7372","6529","7042","6642","6607","7307"
|
| 34 |
+
"W33","1635.1","1470","1831.2","1860.7","2952.6","2076","1926.5","2127.6","2526.5","3130.5","2901.6","1965.4","1930.4"
|
| 35 |
+
"W34","1528","1485","1676","1434","2556","2017","1752","1575","1478","2618","2045","1480","1464"
|
| 36 |
+
"W35","1240","938","1641","1225","2143","2298","1561","1253","1257","2191","2137","1662","1565"
|
| 37 |
+
"W36","18469.2","14697.3","14490.1","16890.6","13307.5","11488","11021.4","13786","17773.2","20125.2","13053.1","14418.3","18561.2"
|
| 38 |
+
"W37","2891.8","2870","2839.7","2925.6","2858.2","2873.3","2857.3","2832.8","2818.5","2843","2863.1","2838.6","2827.4"
|
| 39 |
+
"W38","3050.13","2991.26","2966.63","3079.93","3027.55","3046.13","3052.31","3028.4","3042.15","3060.21","3044.19","3005.94","2999.31"
|
| 40 |
+
"W39","5820.04","5811.86","5764.85","5774.63","5782.51","5729.81","5736.21","5784.8","5783.06","5742.76","5732.81","5730.19","5678.13"
|
| 41 |
+
"W40","5816.795","5807.545","5761.31","5766.93","5777.23","5727.54","5732.42","5780.82","5776.885","5739.61","5732.305","5721.025","5674.69"
|
| 42 |
+
"W41","6138.53","6046.39","6014.57","6020.12","6111.92","6065.35","6090.68","6170.24","6206.4","6161.53","6092.33","6028.87","6039.88"
|
| 43 |
+
"W42","6145.38","6055.38","6021.96","6036.22","6123.1","6070.16","6098.74","6178.74","6219.68","6168.3","6093.4","6048.22","6047.21"
|
| 44 |
+
"W43","6141.955","6050.885","6018.265","6028.17","6117.51","6067.755","6094.71","6174.49","6213.04","6164.915","6092.865","6038.545","6043.545"
|
| 45 |
+
"W44","5813.55","5803.23","5757.77","5759.23","5771.95","5725.27","5728.63","5776.84","5770.71","5736.46","5731.8","5711.86","5671.25"
|
| 46 |
+
"W45","2934.68","2967.95","3071.57","3101.51","3101.51","3136.93","3189.06","3165.62","3069.02","2990.9","2965.64","2997.27","3001.98"
|
| 47 |
+
"W46","3156.835","3088.77","3129.65","3163.045","3114.31","3110.356","3123.937","3136.711","3158.378","3095.771","3042.689","3026.794","3062.599"
|
| 48 |
+
"W47","2933.38","2967.36","3070.66","3098.73","3098.73","3134.37","3186.88","3164.56","3067.65","2988.52","2963.47","2994.78","2999.43"
|
| 49 |
+
"W48","2964.555","2996.025","3000.705","2941.08","2935.96","2938.24","2936.72","2936.66","2882.2","2862.63","2867.64","2875.68","2871.67"
|
| 50 |
+
"W49","3193.28","3248.07","3406.28","3378.86","3335.78","3331.52","3375.54","3365.51","3264.59","3155.58","3087.64","3128.35","3161.7"
|
| 51 |
+
"W50","3194.69","3248.72","3407.29","3381.89","3338.78","3334.24","3377.85","3366.64","3266.05","3158.09","3089.9","3130.95","3164.39"
|
| 52 |
+
"W51","6549.4","6749.6","6672.9","6664.1","6601.3","6601.8","6483.1","6488.7","6431.8","6468.2","6418.9","6383.5","6403.6"
|
| 53 |
+
"W52","6980.8","7049.9","6985.1","7060.9","7021.9","7015","6921.2","6940.5","6942","6894.3","6783.9","6801.8","6773.4"
|
| 54 |
+
"W53","1915.5","1845.1","1867.4","2142","2309.8","2248.6","2182.6","2050.6","1986.8","1865.6","1912.3","1982.7","1933.5"
|
| 55 |
+
"W54","2483","2491","2474","2491","2483","2460","2473","2448","2429","2411","2434","2445","2439"
|
| 56 |
+
"W55","5056","4936","4551","4461","4515","4725","4863","5097","5191","5195","5328","5342","5207"
|
| 57 |
+
"W56","3030","3099","3193","3261","3334","3440","3538","3633","3733","3814","3877","3929","3978"
|
| 58 |
+
"W57","2070","1970","1830","1730","1640","1550","1500","1460","1460","1390","1370","1370","1410"
|
| 59 |
+
"W58","3485","3651","3681","3704","3737","3907","3747","3820","3842","3649","3414","3116","2984"
|
| 60 |
+
"W59","2350","2040","2030","1600","1160","2140","2780","1910","2180","2170","2010","1690","1880"
|
| 61 |
+
"W60","9160.47","9180.81","9257.59","9244.22","9195.97","9383.28","9480.43","9506.87","9431.19","9450.03","9630.17","9660.81","9601.98"
|
| 62 |
+
"W61","10128.3","10802.8","10802.8","10971.1","10971.1","10436.8","10436.8","11025.5","11025.5","10135.8","10135.8","10934.7","10934.7"
|
| 63 |
+
"W62","7208.8","7061.6","7061.6","7018.9","7018.9","7157.6","7157.6","7144.4","7144.4","7536.5","7536.5","7620.1","7620.1"
|
| 64 |
+
"W63","6062.1","6110.6","6110.6","6119.9","6119.9","6115.9","6115.9","6194.7","6194.7","6386.4","6386.4","6649.3","6649.3"
|
| 65 |
+
"W64","14834","12734","12772","13825","15080","12824","12879","13769","14767","13645","12832","13629","15073"
|
| 66 |
+
"W65","20461","17999","18027","19270","20726","18118","18144","19242","20438","19136","18135","19163","20871"
|
| 67 |
+
"W66","7318.7","7580.2","7605.2","7418.2","7320.8","7633.4","7663.4","7564.1","7418.4","7616","7729","7645","7464.2"
|
| 68 |
+
"W67","13361","13641","13485","13635","13869","13787","13846","13810","13958","13995","13987","14185","14413"
|
| 69 |
+
"W68","6925.91","6949.42","6997.28","7002.45","6998.8","6970.53","7019.43","7039.14","7041.96","6974.79","6967.61","6968.64","6957.49"
|
| 70 |
+
"W69","6930.32","6983.87","7009.79","7004","7003.5","6971.28","7036.53","7049.97","7035.15","6965.37","6967.94","6962.49","6960.41"
|
| 71 |
+
"W70","2612.74","2607.29","2583.84","2601.24","2597.5","2591.6","2648.13","2649.71","2654.19","2655.63","2614.58","2613.64","2614.93"
|
| 72 |
+
"W71","1930","1927","1927","1907","1906","1903","1903","1902","1897","1897","1896","1893","1892"
|
| 73 |
+
"W72","1930","1927","1927","1907","1906","1903","1903","1902","1897","1897","1896","1893","1892"
|
| 74 |
+
"W73","1930","1927","1927","1907","1906","1903","1903","1902","1897","1897","1896","1893","1892"
|
| 75 |
+
"W74","4494.5","3936.22","4013.09","3951.31","3701.88","3380.45","3652.31","3893.49","4156.82","4353.71","4927.72","5405.49","6528.25"
|
| 76 |
+
"W75","23442.9","23935.8","24274.4","23870.6","23235.6","23930.5","25498.8","25416.7","23773.6","25931.8","25269.5","24984.5","25674.5"
|
| 77 |
+
"W76","440","440","440","440","440","440","440","440","240","240","240","240","240"
|
| 78 |
+
"W77","3977.66","4541.82","5291.64","4690.9","4198.24","4610.46","4725.98","4681.88","3926.7","3826.87","4457.07","4650.71","4440.09"
|
| 79 |
+
"W78","3923.9","4525.96","5268.49","4149.73","4106.37","4300.7","4930.68","4726.57","3929.57","4190.49","4428.39","4249.56","4221.58"
|
| 80 |
+
"W79","1560","1720","1770","1820","1860","1930","1790","1660","1620","1600","1590","1640","1670"
|
| 81 |
+
"W80","1720","1760","1810","1840","1950","1820","1670","1620","1600","1580","1640","1660","1760"
|
| 82 |
+
"W81","23057.9","24044.6","23942.2","23970.5","23519","23811.9","24756.7","24982.2","24388.4","24379.1","25518.2","24728","25659.7"
|
| 83 |
+
"W82","4417.421","4434.675","4430.035","4412.757","4414.728","4419.787","4422.151","4410.932","4409.458","4417.371","4423.566","4427.224","4414.39"
|
| 84 |
+
"W83","6001.5","5974.5","5931.6","6028.6","6053","6064","6017.2","5972.1","5997.6","5991.1","5939.8","5939.5","5940.6"
|
| 85 |
+
"W84","6016.7","5923.8","5945","6050.1","6069.5","6022.3","5944.7","5985.6","5996","5978","5937.4","5937.5","5944"
|
| 86 |
+
"W85","17289","17312","17335","17364","17429","17501","17567","17632","17683","17718","17753","17791","17808"
|
| 87 |
+
"W86","17297","17320","17344","17385","17454","17522","17601","17661","17696","17731","17766","17801","17791"
|
| 88 |
+
"W87","1094.56","1094.52","1092.83","1091.55","1098.78","1099.02","1096.83","1093.89","1089.57","1099.3","1098.45","1095.05","1088.79"
|
| 89 |
+
"W88","12154.01","12042.04","12027.43","11998.81","12089.79","12025.85","12079.76","12053.18","12179.34","12133.8","12093.14","12163.75","12117.78"
|
| 90 |
+
"W89","1185.62","1180.63","1201.97","1132.26","1087.99","1095.31","1151.72","1115.15","1107.33","1089.42","1084.26","1103.96","1071.63"
|
| 91 |
+
"W90","1828.128","1861.782","1872.236","1854.954","1852.089","1854.091","1868.669","1893.618","1859.457","1860.784","1871.368","1829.142","1818.838"
|
| 92 |
+
"W91","19466.9","19798.45","19924.33","19681.8","19608.88","19636.22","19838.41","20051.33","19701.9","19697.26","19797.94","19395.38","19260.01"
|
| 93 |
+
"W92","21837.03","22832.08","22198.61","22399.36","22658.55","22969.77","22908.66","22311.88","21853.35","22670.54","22686.31","21102.33","21032.34"
|
| 94 |
+
"W93","1748.563","1753.833","1740.646","1741.955","1753.649","1760.835","1747.62","1745.674","1751.342","1758.483","1747.428","1749.617","1753.035"
|
| 95 |
+
"W94","1802.54","1986.53","1931.14","1831.39","1948.19","1948.12","2000.27","2077.83","2068.5","2093.54","2174.32","2052.07","2108"
|
| 96 |
+
"W95","1621.22","1764.77","1679.71","1638.8","1728.67","1733.06","1799.69","1857.32","1899.37","1907.46","1936.32","1863.71","1912.86"
|
| 97 |
+
"W96","11355.272","11385.864","11303.634","11309.627","11397.101","11427.142","11459.31","11388.745","11451.752","11478.884","11545.812","11497.115","11568.007"
|
| 98 |
+
"W97","447.83","454.85","482.02","472.14","455.09","481.15","524.38","519.49","485.51","467.8","486.97","483.66","457.56"
|
| 99 |
+
"W98","584.3","590.59","610.91","597.94","585.15","610.26","651.93","646.63","617.37","598.04","617.17","615.46","586.87"
|
| 100 |
+
"W99","14837.74","14751.79","14883.29","15009.8","14963.45","14930.14","14847.66","14828.55","14864.72","14893.35","14811.6","14898.23","14855.39"
|
| 101 |
+
"W100","11474.49","11875.5","11517.02","11305.8","11423.14","11647.5","11632.38","11729.89","11679.04","11740.51","11934.32","11741.1","11634.37"
|
| 102 |
+
"W101","9871.498","9910.685","9815.305","9808.647","9900.756","9934.128","9974.544","9905.89","9965.28","9989.549","10064.652","10007.292","10082.468"
|
| 103 |
+
"W102","16004.79","15859.41","15729.3","15616.87","15749.81","15642.21","15716.89","15789.79","15961.13","15986.34","15871.65","15813.95","15743.94"
|
| 104 |
+
"W103","4153.94","4375.42","4063.47","3975.35","4202.22","4212.67","4171.08","4025.43","4260.1","4159.33","4272.43","3968.74","3905.78"
|
| 105 |
+
"W104","1026.84","1019.01","1009.31","1003.54","1007.61","1004.8","998.68","991.28","991.98","988.73","979.2","976.82","968.28"
|
| 106 |
+
"W105","9135.8","9155.8","9123.07","9168.48","9097.38","9109.48","9061.98","9058.14","9033.5","9064.66","9038.48","9025.31","9001.11"
|
| 107 |
+
"W106","2162.966","2167.806","2153.199","2154.003","2164.848","2171.359","2157.906","2154.995","2160.496","2165.439","2154.43","2155.905","2159.078"
|
| 108 |
+
"W107","24019.73","23983.72","24168.42","24240.09","24481.23","24419.29","24441.25","24333.51","24398.35","24222.27","24359.23","24297.26","24345.73"
|
| 109 |
+
"W108","15926.84","16078.017","15963.238","15970.428","16056.811","16086.342","16092.328","16042.099","16027.739","16096.719","16115.75","15955.488","15951.666"
|
| 110 |
+
"W109","14171.645","14318.713","14198.442","14206.103","14300.423","14334.406","14340.354","14293.978","14288.988","14354.158","14370.531","14207.81","14198.367"
|
| 111 |
+
"W110","7331.84","7307.69","7396.85","7438.91","7519.81","7592.29","7571.98","7464.36","7484.2","7484.94","7528.27","7483.09","7522.26"
|
| 112 |
+
"W111","1080.57","1080.53","1078.63","1077.26","1084.46","1084.89","1082.71","1079.78","1076.06","1085.53","1084.67","1081.1","1074.27"
|
| 113 |
+
"W112","9609.45","9520.42","9512.01","9471.4","9562.31","9488.15","9556.86","9534.86","9647.41","9594.13","9576.58","9662.18","9644.83"
|
| 114 |
+
"W113","754.7","766.1","766.27","714.66","682.99","700.3","763.74","747.68","742.31","717.41","712.44","734.81","726.45"
|
| 115 |
+
"W114","10463.22","10370.99","10322.48","10331.1","10275.58","10318.71","10216.97","10282.47","10238.62","10071.88","10200.06","10227.22","10233.33"
|
| 116 |
+
"W115","11217.92","11137.09","11088.75","11045.76","10958.57","11019.01","10980.71","11030.15","10980.93","10789.29","10912.5","10962.03","10959.78"
|
| 117 |
+
"W116","14539.1","14894.1","13984.77","13907.03","14293.59","14685.26","15017.67","14100.07","14408.13","14676.51","15205.38","13875.72","14094.91"
|
| 118 |
+
"W117","7037.56","7043.25","7026.12","7062.94","7075.16","7057.07","7064.06","7097.6","7249.93","7267.89","7328.08","7388.03","7480.86"
|
| 119 |
+
"W118","1747.852","1753.13","1739.954","1741.265","1752.962","1760.145","1746.934","1744.992","1750.66","1757.801","1746.75","1748.939","1752.358"
|
| 120 |
+
"W119","1029.75","1180.41","1140.58","1060.38","1154.44","1152.64","1175.8","1244.19","1225.38","1262.37","1332.13","1221.9","1270.41"
|
| 121 |
+
"W120","855.09","989.28","916.28","889.25","943.99","943.1","996.81","1044.71","1069.8","1090.56","1126.05","1064.67","1106.98"
|
| 122 |
+
"W121","10465.971","10496.016","10403.031","10399.53","10497.08","10528.68","10569.283","10488.464","10554.103","10585.971","10651.968","10589.396","10672.134"
|
| 123 |
+
"W122","329.99","321.94","332.45","333.74","338.72","353.4","396.27","383.57","366.9","359.23","363.53","378.46","369.19"
|
| 124 |
+
"W123","440.2","431.45","435.3","432.77","440.92","454.74","497.32","484.31","473.21","461.88","466.64","483.21","471.99"
|
| 125 |
+
"W124","8434.6","8404.06","8370.62","8440.35","8445.77","8485.14","8424.83","8351.47","8365.51","8411.94","8472.83","8434.49","8421.44"
|
| 126 |
+
"W125","10927.21","11218.64","10974.46","10771.82","10862.93","11054.22","11031.79","11089.91","11069.27","11163.08","11300.69","11152.85","11055.06"
|
| 127 |
+
"W126","9622.511","9655.61","9565.969","9555.495","9652.503","9680.166","9726.8","9653.317","9717.552","9744.777","9804.685","9745.947","9829.99"
|
| 128 |
+
"W127","11173.39","10918.46","10850.02","10756.78","10815.1","10764.33","10842.66","10886.6","10968.92","11026.69","11039.92","11213.69","11323.96"
|
| 129 |
+
"W128","3659.91","3836.54","3603.72","3518.12","3731.7","3682.25","3703.91","3552.16","3731.15","3699.87","3741.75","3541.9","3452.87"
|
| 130 |
+
"W129","1007.32","1000.05","990.76","985.07","988.6","986.32","981.23","974.06","974.39","971.47","961.91","960.25","951.77"
|
| 131 |
+
"W130","8142.17","8164.91","8155.02","8197.46","8121.84","8135.96","8115.75","8106.23","8078.34","8105.2","8077.74","8059.49","8032.97"
|
| 132 |
+
"W131","2161.865","2166.713","2152.117","2152.923","2163.772","2170.279","2156.83","2153.923","2159.431","2164.374","2153.369","2154.844","2158.018"
|
| 133 |
+
"W132","22735.69","22695.8","22873.78","22912.73","23157.6","23052.45","23109.41","23066.29","23133.89","22943.61","23070.07","23056.38","23135.75"
|
| 134 |
+
"W133","13883.352","13943.084","13817.27","13799.658","13890.442","13929.041","13978.179","13895.258","13956.526","13983.295","14057.238","13958.295","14011.606"
|
| 135 |
+
"W134","12132.819","12188.302","12057.06","12039.798","12138.491","12181.638","12230.64","12151.594","12222.102","12245.022","12316.353","12215.006","12262.678"
|
| 136 |
+
"W135","6190.89","6166.87","6248.29","6256.87","6346.31","6377.59","6382.05","6333.99","6350.53","6326.36","6362.01","6363.97","6432.69"
|
| 137 |
+
"W136","6432.1","6492.9","6523.6","6596.4","6759.9","6701.7","6745.7","6721.2","6644","6667.8","6661.7","6458.5","6508.5"
|
| 138 |
+
"W137","7126.057","7071.44","7078.15","7037.05","7071.02","7029.9","7096.12","7087.92","7166.37","7130.38","7131.41","7203.06","7162.08"
|
| 139 |
+
"W138","387.13","396.14","404.7","363.68","322.02","328.5","392.13","382.78","364.21","339.84","334.59","363.64","337.18"
|
| 140 |
+
"W139","7833.402","7763.08","7697.07","7691.19","7674.28","7754.88","7648.39","7689.17","7639.02","7478.78","7613.66","7563.24","7531.5"
|
| 141 |
+
"W140","8220.532","8159.22","8101.77","8054.87","7996.3","8083.38","8040.52","8071.95","8003.23","7818.62","7948.25","7926.88","7868.68"
|
| 142 |
+
"W141","10971.829","11241.17","10370","10322.61","10592.83","10873.84","11160.36","10316.67","10680.77","10809.59","11455.11","10315.26","10644.41"
|
| 143 |
+
"W142","5388.424","5388.93","5368.95","5390.96","5401.97","5389.98","5392.11","5412.92","5525.45","5540.92","5580.59","5622.11","5684.54"
|
| 144 |
+
"W143","11374.106","11419.08","11304.46","11312.03","11412.35","11481.37","11381.15","11365.37","11386.84","11445.25","11290.92","11301.3","11310.58"
|
| 145 |
+
"W144","1005.48","1152.75","1113.19","1035.17","1124.86","1122.3","1147.56","1215.37","1191.06","1226.87","1298.58","1187.11","1237.65"
|
| 146 |
+
"W145","8345.9","9654.6","8929.2","8681.6","9187.9","9169.4","9726.4","10199.3","10397.5","10589.6","10966.5","10340.4","10782"
|
| 147 |
+
"W146","6801.5412","6826.213","6737.471","6733.834","6810.25","6824.878","6855.421","6783.048","6832.483","6856.145","6914.958","6859.368","6934.528"
|
| 148 |
+
"W147","120.66","117.59","111.32","110.66","110.46","111.68","120.32","117.41","125.32","121.24","112.52","122.99","118.27"
|
| 149 |
+
"W148","222.63","218.63","205.95","201.16","204.31","204.6","212.82","209.47","223.37","215.24","206.88","219.05","212.55"
|
| 150 |
+
"W149","4374.969","4342.37","4303.19","4370.16","4392.91","4423.87","4362.98","4303.89","4317.18","4359.9","4426.36","4381.17","4372.66"
|
| 151 |
+
"W150","8104.576","8397.21","8157.4","7991.08","8073.41","8245.39","8235.18","8301.8","8245.5","8365.11","8466.19","8322.25","8219.59"
|
| 152 |
+
"W151","6364.0443","6391.976","6307.152","6296.818","6370.959","6382.491","6419.123","6352.659","6400.765","6420.155","6472.322","6421.251","6497.262"
|
| 153 |
+
"W152","8624.177","8443.2","8390.94","8266.24","8299.13","8283.85","8356.8","8416.68","8466.06","8541.06","8572.49","8732.67","8749.33"
|
| 154 |
+
"W153","2990.33","3154.84","2945.55","2862.16","3032.25","2990.3","3001.65","2879.11","3049.37","2997.1","3046.23","2862.51","2802.91"
|
| 155 |
+
"W154","8643.9","8597.8","8504.5","8468.7","8454","8415","8361.8","8296.2","8313.1","8300.7","8219.2","8228.2","8136.1"
|
| 156 |
+
"W155","4568.001","4591.74","4592.62","4636.82","4562.67","4572.28","4557.9","4538.75","4523.75","4521.02","4499.58","4485.14","4473.5"
|
| 157 |
+
"W156","5432.391","5451.52","5443.07","5483.69","5408.07","5413.78","5394.08","5368.37","5355.06","5351.09","5321.5","5307.96","5287.11"
|
| 158 |
+
"W157","1431.923","1435.919","1422.959","1423.154","1432.548","1438.745","1428.38","1425.897","1427.808","1431.297","1415.852","1416.006","1416.477"
|
| 159 |
+
"W158","16585.796","16545.69","16718.18","16774.2","16973.97","16859.12","16888.19","16801.33","16854.51","16677.41","16792.37","16756.94","16839.65"
|
| 160 |
+
"W159","9265.8308","9320.682","9198.628","9181.263","9248.034","9273.226","9311.038","9229.568","9282.625","9295.202","9366.068","9272.725","9316.549"
|
| 161 |
+
"W160","8060.0957","8112.114","7987.147","7969.756","8043.952","8072.951","8110.896","8034.681","8093.054","8099.825","8169.184","8075.646","8114.811"
|
| 162 |
+
"W161","12143.281","12130.04","12220.91","12262.51","12368.93","12230.26","12271.89","12244.95","12279.44","12134.14","12219.63","12187.07","12201.34"
|
| 163 |
+
"W162","4442.515","4415.65","4497.27","4511.69","4605.04","4628.86","4616.3","4556.38","4575.07","4543.27","4572.74","4569.87","4638.31"
|
| 164 |
+
"W163","3806.6","3812.9","3801.8","3787","3826","3831.2","3817.2","3802.3","3799.9","3852","3840.2","3822.6","3798.4"
|
| 165 |
+
"W164","24833.93","24489.8","24338.6","24343.5","24912.9","24582.5","24607.4","24469.4","24810.4","24637.5","24451.7","24591.2","24827.5"
|
| 166 |
+
"W165","3675.7","3699.6","3615.7","3509.8","3609.7","3718","3716.1","3649","3781","3775.7","3778.5","3711.7","3892.7"
|
| 167 |
+
"W166","2629.818","2607.91","2625.41","2639.91","2601.3","2563.83","2568.58","2593.3","2599.6","2593.1","2586.4","2663.98","2701.83"
|
| 168 |
+
"W167","2997.388","2977.87","2986.98","2990.89","2962.27","2935.63","2940.19","2958.2","2977.7","2970.67","2964.25","3035.15","3091.1"
|
| 169 |
+
"W168","3567.271","3652.93","3614.77","3584.42","3700.76","3811.42","3857.31","3783.4","3727.36","3866.92","3750.27","3560.46","3450.5"
|
| 170 |
+
"W169","1649.136","1654.32","1657.17","1671.98","1673.19","1667.09","1671.95","1684.68","1724.48","1726.97","1747.49","1765.92","1796.32"
|
| 171 |
+
"W170","3593.378","3600.1","3605.11","3626.18","3631.12","3625.32","3631.86","3645.91","3674.92","3680.2","3698.89","3719","3754.53"
|
| 172 |
+
"W171","6104.414","6112.22","6095.08","6100.62","6117.27","6120.08","6088.19","6084.55","6119.76","6132.76","6176.58","6188.09","6213"
|
| 173 |
+
"W172","3664.4298","3669.803","3665.56","3665.696","3686.83","3703.802","3713.862","3705.416","3721.62","3729.826","3737.01","3730.028","3737.606"
|
| 174 |
+
"W173","3696.3","3536.5","3504.4","3481.7","3684.6","3749.7","3684.8","4354.7","3618.1","4381.4","3985.8","3381.9","3209.2"
|
| 175 |
+
"W174","2822.634","2821.43","2817.06","2780.74","2789.52","2808.83","2796.61","2788.11","2823.77","2797.97","2834.5","2830.6","2835.47"
|
| 176 |
+
"W175","3258.4667","3263.634","3258.817","3258.677","3281.544","3297.675","3307.677","3300.658","3316.787","3324.622","3332.363","3324.696","3332.728"
|
| 177 |
+
"W176","2549.213","2475.26","2459.08","2490.54","2515.97","2480.48","2485.86","2469.92","2502.86","2485.63","2467.43","2481.02","2574.63"
|
| 178 |
+
"W177","6695.8","6817","6581.7","6559.6","6994.5","6919.5","7022.6","6730.5","6817.8","7027.7","6955.2","6793.9","6499.6"
|
| 179 |
+
"W178","1429.3","1402.7","1403.1","1382","1432","1448.2","1450.5","1444.4","1430.8","1414","1399.9","1374.3","1381.6"
|
| 180 |
+
"W179","2709.779","2713.39","2711.95","2713.77","2713.77","2722.18","2721.67","2737.86","2723.28","2754.11","2756.24","2751.53","2745.86"
|
| 181 |
+
"W180","7299.42","7307.94","7291.58","7297.69","7312.24","7315.34","7284.5","7280.26","7316.23","7330.77","7375.17","7388.38","7415.41"
|
| 182 |
+
"W181","8859.673","8863.5","8867.55","8852.3","8897.4","8915.51","8942.89","9002.82","9002.66","9020.31","9033.94","9050.97","9041.96"
|
| 183 |
+
"W182","1375.3","1376.8","1377.8","1382.1","1392.3","1392.1","1472.6","1476.2","1488.5","1487.9","1491.4","1494.2","1492.8"
|
| 184 |
+
"W183","6149.894","6150.11","6155.6","6138.53","6183.63","6193.33","6221.22","6264.96","6279.38","6266.2","6277.7","6299.44","6296.1"
|
| 185 |
+
"W184","4617.5212","4622.402","4618.642","4618.395","4642.408","4655.815","4667.141","4665.69","4673.901","4688.093","4691.17","4685.57","4695.057"
|
| 186 |
+
"W185","4072.7233","4076.188","4069.913","4070.042","4094.539","4108.687","4119.744","4116.913","4129.048","4145.197","4147.169","4139.36","4147.867"
|
| 187 |
+
"W186","1748.375","1751.22","1751.02","1745.18","1741.27","1748.73","1765.75","1777.61","1775.46","1783.09","1789.27","1794.1","1794.38"
|
| 188 |
+
"W187","942.8","948.4","944.3","940.1","945.4","960","936.4","939.7","939.3","934.4","920.4","913.7","937.7"
|
| 189 |
+
"W188","2235","2247","2197","2225","2192","2240","2149","2191","2159","2176","2252","2201","2227"
|
| 190 |
+
"W189","2128","2141","2165","2108","2221","2365","2180","2244","2272","2246","2038","1965","2069"
|
| 191 |
+
"W190","456.6855011","460.7204403","453.8686316","441.8651676","445.577063","453.0822772","437.6623557","439.8819108","430.3451587","426.0420738","419.6364978","411.3800853","418.5663313"
|
| 192 |
+
"W191","6635.2293","6711.3531","6879.4553","7031.4599","7280.6471","7287.6782","7199.4766","7075.0888","7495.1075","7853.2658","7738.1914","7432.3491","7384.1148"
|
| 193 |
+
"W192","1160.330918","1165.783708","1130.218848","1144.012678","1116.237318","1143.744038","1056.720308","1119.294178","1064.618758","1048.372548","1135.591008","1105.256478","1120.579848"
|
| 194 |
+
"W193","5897.9498","5780.9284","5693.2279","5734.5912","5836.8002","5884.0807","5840.0028","5820.8905","5977.0669","6176.512","6055.0658","6023.6219","6056.9425"
|
| 195 |
+
"W194","483.26481","501.30087","495.69513","505.966605","490.018459","506.016419","506.059303","488.243623","494.494998","507.554868","508.900208","491.403635","498.588915"
|
| 196 |
+
"W195","984.43264","1007.22011","977.57691","956.13304","989.18684","978.19025","972.65437","932.85298","995.85932","965.36357","978.65649","978.76934","995.1377"
|
| 197 |
+
"W196","9829.0437","9985.4137","10293.6715","10359.0839","9848.4779","9843.3371","9851.5289","10083.1847","9901.8945","10063.8348","9912.0124","9947.1883","10295.3118"
|
| 198 |
+
"W197","3349.318574","3535.511928","3464.898719","3501.769843","3412.296678","3438.179535","3346.9989","3340.722945","3385.474665","3372.953029","3373.89518","3370.010094","3361.798917"
|
| 199 |
+
"W198","2494.02866","2504.633975","2516.196395","2459.919613","2562.934537","2680.654134","2526.370987","2563.657085","2549.711739","2523.642748","2295.089991","2323.388696","2404.942722"
|
| 200 |
+
"W199","853.411584","717.379492","739.09963","715.596157","753.511419","764.419871","759.384134","791.933924","759.283655","774.301921","848.021256","804.524732","889.772763"
|
| 201 |
+
"W200","1847.36095","1741.76012","1857.16426","1737.89313","1834.76119","1945.47024","1795.47817","1859.5506","1993.599","1853.21632","1763.33719","1134.09872","1700.70876"
|
| 202 |
+
"W201","1967","2006","2007","1992","1974","1963","1958","1941","1986","1972","1970","1973","2025"
|
| 203 |
+
"W202","4203","4253","4204","4217","4166","4203","4106","4133","4145","4147","4222","4175","4252"
|
| 204 |
+
"W203","5510","5380","5370","5260","5250","5110","5260","5050","4770","4630","4340","4720","5060"
|
| 205 |
+
"W204","7304","7318.9","7362.3","7295.9","7296.4","7304.9","7339.5","7345.8","7326.6","7325.2","7360.2","7382.9","7378.7"
|
| 206 |
+
"W205","12491.95","13013.34","13209.75","12981","13537.85","13132.64","13691.5","13967.95","13867.98","13856.59","14381.13","14510.66","14226.31"
|
| 207 |
+
"W206","1050.5","1093.4","1222","1223.7","1043.1","930.1","1070.1","1341","966.5","1052.4","938.6","1193.9","1217.3"
|
| 208 |
+
"W207","11744.3","11322.8","11200.7","12733.5","14568.8","20567.8","16736.3","13883.6","13380.5","13299.4","15151.2","14482.4","14773.8"
|
| 209 |
+
"W208","29018.3","25380.9","24903.8","27184.9","23801.6","22076","21381.2","24022.8","28176.7","30520","23750.7","25196.6","29393"
|
| 210 |
+
"W209","4185.45","4185.45","4185.45","4185.45","4185.45","3944.18","3944.18","3944.18","3944.18","3944.18","3944.18","3944.18","3944.18"
|
| 211 |
+
"W210","8204.21","8496.79","8690.18","8796.65","9308.62","8889.88","9285.78","9353.08","9523.59","9788.87","9638.59","9552.7","9535.58"
|
| 212 |
+
"W211","19138.99","19322.67","19505.88","19671.03","19823.39","19748.99","19869.76","20103.52","20304.31","20497.56","20640.15","20839.37","21017.61"
|
| 213 |
+
"W212","5461.05","5465.87","5464.51","5474.37","5898.01","5958.73","5908.25","5890.65","5889.96","5886.71","5863.1","5846.83","5831.24"
|
| 214 |
+
"W213","3232.2","3400.7","3606","3686.4","3553.5","3071.9","3189.4","3471.8","3135.3","3578.1","3467.9","3637.1","3466"
|
| 215 |
+
"W214","1824.5","1743.9","1794","1809.4","1827.3","1908.2","1847","1751.8","1898.1","1880.5","1978.2","1886.8","1902.4"
|
| 216 |
+
"W215","2366.43","2369.53","2390.59","2419.53","2406.42","2491.02","2478.39","2482.26","2470.39","2449.71","2447.86","2447.49","2469.36"
|
| 217 |
+
"W216","3140.68","3143.77","3165.33","3194.56","3177.71","3278.54","3265.51","3269.02","3257.17","3236.28","3234.39","3231.67","3247.35"
|
| 218 |
+
"W217","9877.09","9890.943","9890","9800","9897","9700","9000","9700","9907.7","9907","9700","9910","9900"
|
| 219 |
+
"W218","9932.821","9936.393","9000","9700","9943","9700","9940.001","9700","9948.552","9951.79","9000","9945","9000"
|
| 220 |
+
"W219","1091","1084","1077","1070","1063","1056","1092","1085","1078","1071","1064","1057","1050"
|
| 221 |
+
"W220","1820","1813","1806","1799","1792","1785","1820","1813","1806","1799","1792","1785","1778"
|
| 222 |
+
"W221","7743.8","7448.1","6802.5","6961.8","7160.7","7338.7","6925.2","7142.4","7651.6","7691.4","7589.4","7706.7","7590.4"
|
| 223 |
+
"W222","7727.1","7515.8","7002.1","7245.4","7361.2","7562.2","7879.1","8082.1","8025.9","8605","8503.5","8405.7","8763.5"
|
| 224 |
+
"W223","26994","26695","28522","27712","27511","27231","27409","27856","27639","26848","27687","27342","27494"
|
| 225 |
+
"W224","1706.741","1739.879","1741.962","1763.156","1740.79","1857.234","1746.231","2071.371","1997.189","2097.67","2111.099","2253.543","2460.929"
|
| 226 |
+
"W225","1930.9","1929.5","1958.5","1935.1","1926.6","2119.8","1982.7","2016.5","2039.6","2311.4","2792.7","2319.2","2147.7"
|
| 227 |
+
"W226","2865.2","2798.5","2576.4","2805.7","2486","2628.3","2494.4","2560.8","2487.5","2543","2524.4","2603.3","2591.2"
|
| 228 |
+
"W227","1365.7","1392.9","1389.1","1439.9","1431.2","1665.4","1561.1","1606.1","1570.6","1804.3","1935.1","1801.3","1744.1"
|
| 229 |
+
"W228","8171","8157","8180","8378","8086","8624","8408","8643","8586","8856","8908","9137","8873"
|
| 230 |
+
"W229","465.53","438.43","443.29","443.83","391.38","476.51","405.11","448.58","416.24","441.32","428.96","432.6","449.93"
|
| 231 |
+
"W230","3117.6","2928.4","2783.4","2987.1","2533.1","3403.5","3054.5","3060.2","3022.7","3426.2","4886.7","4280.2","3533.6"
|
| 232 |
+
"W231","599.6","646.8","672","706.3","755.7","899.6","917.3","1036.2","1119.2","1283","1384.5","1475.1","1504.4"
|
| 233 |
+
"W232","1880.4","1889.2","1864.5","1961.1","2009.7","2211.1","2235","2421.8","2774","3052","3018.4","3171.4","3206.1"
|
| 234 |
+
"W233","1036.6","1036.6","994","1040.9","1024.8","1199.4","1106.1","1197.6","1296","1448.2","1485.5","1600.9","1695"
|
| 235 |
+
"W234","2117","1891.1","2178.6","1910.3","1818.5","1984.8","1958.6","1953.1","2009.7","2158.7","2640.2","2550.4","2678.7"
|
| 236 |
+
"W235","434.3","451.5","493.3","532.4","535.7","647.1","661.3","710.3","763.1","827.1","937.4","1003","1010.4"
|
| 237 |
+
"W236","3265.2","3299.2","3276.9","3338.6","3299.6","3574.8","3557.3","3591.5","3690.3","3977.2","4463.9","4347","4358.4"
|
| 238 |
+
"W237","1070.7","1079","1034.4","1099.3","1055.1","1200.8","1155.9","1202.1","1206.6","1378.2","1438.1","1340.1","1262.4"
|
| 239 |
+
"W238","430","420.4","399.8","415.2","435.7","548.9","538","631.4","714.5","831.1","925.2","953","917.2"
|
| 240 |
+
"W239","3696","3792","3780","3928","3927","4330","4349","4500","4765","5057","5945","5664","5722"
|
| 241 |
+
"W240","2586.5","2560.4","2776.9","2698.6","3157.7","2938.2","3040.8","3197.8","3367.8","3726.5","3724","3796.8","3728.7"
|
| 242 |
+
"W241","1040.84","1015","1098.25","1031.9","1168.21","1271.09","1311.2","1252.43","1363.56","1496.95","1592.62","1614.64","1511.61"
|
| 243 |
+
"W242","1475.3","1468.6","1494.2","1447.2","1643.9","1534.3","1587.3","1580.3","1927.1","2321.8","1935.5","1728.8","1741.4"
|
| 244 |
+
"W243","1245","1330","1551","1550","2150","2514","2855","3230","3335","3511","3890","3812","3828"
|
| 245 |
+
"W244","806.8","841.6","883.8","914.1","988.6","1108","1183","1226.4","1258","1325.9","1434.3","1480.2","1488.3"
|
| 246 |
+
"W245","2291.2","2266.7","2332","2266.8","2538.8","2482.1","2508.7","2508.2","2705.3","2972.8","3072.2","3059","2984.6"
|
| 247 |
+
"W246","1138.7","1098","1159.8","1136.4","1476.2","1476.6","1385.7","1392.2","1435.2","1579.6","1807.3","1769.8","1686.5"
|
| 248 |
+
"W247","4410.5","4581.6","4723.5","4791.8","5421","5510.5","5612.2","5829.8","6084.9","6239.6","6659.6","6860.6","6854.3"
|
| 249 |
+
"W248","16190.4","16913.4","16913.4","17091","17091","16552.7","16552.7","17220.2","17220.2","16522.2","16522.2","17584","17584"
|
| 250 |
+
"W249","21254.72","21655.06","21655.06","22030.28","22030.28","22104.93","22104.93","21822.79","21822.79","21445.17","21445.17","19654.69","19654.69"
|
| 251 |
+
"W250","3446.7","3205.7","3207.3","3331.9","3479.1","3223.9","3228.3","3338.6","3461.4","3330.3","3232.7","3339.3","3518.5"
|
| 252 |
+
"W251","6671","6836","6803","6884","6879","6940","7042","7117","7136","7116","7164","7197","7229"
|
| 253 |
+
"W252","8557.6","8835.3","8859","8662.7","8563.5","8897.2","8930.5","8826.7","8675.2","8886.2","9003.2","8917.6","8732.1"
|
| 254 |
+
"W253","12389","12551","12538","12445","12427","12637","12671","12626","12568","12702","12742","12726","12679"
|
| 255 |
+
"W254","5953.6","5901.6","5905.5","6044.6","6050.7","5968.3","5881.8","6015.5","6050.5","6027","5981.2","6126.6","6147.7"
|
| 256 |
+
"W255","1901","1898","1895","1884","1878","1866","1830","1829","1849","1870","1872","1874","1875"
|
| 257 |
+
"W256","2932","2986","2914.5","3028.8","3291.8","2001.3","2192.4","2057.1","2318.1","2584.8","2583.4","2598.3","2704.1"
|
| 258 |
+
"W257","4017.6","4017.9","4018.2","4018.8","4024.6","4026.8","4037","4037.9","4039.9","4041.8","4044.2","4044.3","4044.2"
|
| 259 |
+
"W258","2097.4","2102.3","2087.1","2130.7","2087.5","2024.5","1983.4","1973.8","1998.2","1975.9","1931.8","1928.8","1940.1"
|
| 260 |
+
"W259","6698","6436","5556","6939","6746","5285","5924","5200","6741","6343","5893","5953","5594"
|
| 261 |
+
"W260","583","503","527","504","643","523","522","1437","464","626","475","785","966"
|
| 262 |
+
"W261","4687.3","4661.5","4573.9","4712.8","4699.2","4555.3","4629.3","4557.9","4714","4676.1","4633.5","4639.5","4603.6"
|
| 263 |
+
"W262","4382.033","4394.197","4378.668","4378.375","4378.166","4390.312","4401.788","4382.618","4382.058","4401.802","4397.108","4379.233","4378.908"
|
| 264 |
+
"W263","2296","2219.6","2214.9","2044.6","2043.2","1968","1966.2","1746.1","1724.1","1701.4","1717.7","1603.2","1617"
|
| 265 |
+
"W264","4071.7","4155","4247.5","4042","3908.2","3923.1","3563.7","3806.9","4360.7","4563.2","4354.9","4323.4","3629"
|
| 266 |
+
"W265","2057.7","2149.37","2022.19","2061.06","2031.96","1904.36","1962.29","2017.42","2020.85","1963.84","1942.14","1973.34","2009.61"
|
| 267 |
+
"W266","2072.76","2162.78","2042.96","2108.32","2022.16","1989.55","2020.05","2071.25","2013.96","1993.61","1976.23","1966.25","2030.46"
|
| 268 |
+
"W267","3031.2","2833","3121.8","3662.4","3801.4","3635.9","3732.4","2989.6","4445.3","4025.2","4599.9","4366.5","3613.2"
|
| 269 |
+
"W268","1306.27","1382.07","1271.27","1314.49","1308.47","1194.08","1259.07","1298.64","1304.74","1285.92","1251.62","1308.33","1369.04"
|
| 270 |
+
"W269","1258.78","1335.44","1229.92","1305.02","1251.08","1192.61","1232.56","1280.13","1217.91","1237.3","1200.59","1223","1325.46"
|
| 271 |
+
"W270","7727.9","8061.2","7905.6","7710.1","7937.5","7954.8","8244.7","8336.4","8431.2","8311.7","8421.9","8301.7","8375.9"
|
| 272 |
+
"W271","7661.3","7754.9","7634.3","7495.5","7846.8","7899.6","8028.8","8126.1","8295.7","8169","8102.7","7990.4","8058.8"
|
| 273 |
+
"W272","974.11","971.93","949.5","952.55","956.53","955.04","928.78","934.69","937.57","942.2","943.45","949.25","951.63"
|
| 274 |
+
"W273","6542.2","6602.2","6631.8","6704.4","6867.5","6809.7","6853.3","6828.4","6750.5","6774.3","6767.8","6564.6","6614.5"
|
| 275 |
+
"W274","751.43","767.3","750.92","746.57","723.49","710.28","703.22","718.78","716.11","677.92","690.52","665.01","640.57"
|
| 276 |
+
"W275","8139.8","8273.4","8130.4","8033","7710.8","7969.4","7874.9","7911.2","7960.5","7563.1","7756.4","7432.5","7050"
|
| 277 |
+
"W276","11409.5","11408.2","11485.6","11820.4","11735","12147","11899.3","11303.7","11336.7","11585.8","11662.6","11191.2","10895.7"
|
| 278 |
+
"W277","3257.323","3271.85","3284.499","3248.314","3162.48","3345.109","3283.453","3260.614","3232.998","3257.269","3299.522","3306.657","3294.326"
|
| 279 |
+
"W278","1261.984","1332.9","1226.2","1267.01","1266.63","1153.61","1217.88","1256.22","1260.13","1249.54","1214.15","1271.59","1333.46"
|
| 280 |
+
"W279","1212.443","1284.81","1181.98","1256.58","1204.17","1146.31","1184.67","1231.88","1164.92","1195.17","1157.35","1182.41","1284.83"
|
| 281 |
+
"W280","5447.979","5462.14","5487.29","5483.53","5478.69","5471.28","5473.97","5487.77","5448.53","5428.96","5440.01","5462.1","5471.9"
|
| 282 |
+
"W281","1756","1849","1859","1774","1799","1895","1793","1730","1815","1919","1963","1977","2010"
|
| 283 |
+
"W282","2144.366551","2155.275445","2134.710461","2155.964119","2136.584281","2111.357748","2151.224229","2102.505934","2097.694159","2081.464003","2116.373998","2109.410461","2135.181639"
|
| 284 |
+
"W283","338.554154","335.62152","339.293614","334.767926","346.838264","348.526527","321.768721","345.408096","330.985181","329.278867","317.836472","335.474349","303.970852"
|
| 285 |
+
"W284","572.83","578.84","601.86","577.06","575.18","583.88","583.42","586.89","590.86","592.64","588.47","514.43","504.2"
|
| 286 |
+
"W285","692.124","687.527","691.085","705.06","706.426","697.053","691.771","695.644","716.192","722.715","721.569","721.882","739.725"
|
| 287 |
+
"W286","10549","10683.7","10413.7","10294.3","10494.1","10588","10359.7","10236.7","10403.6","10394.8","10697.7","10778.3","10831.8"
|
| 288 |
+
"W287","2187.16","2205.63","2214.48","2228.7","2312.59","2368.47","2384.82","2353.87","2374.86","2351.83","2331.47","2340.08","2351.46"
|
| 289 |
+
"W288","7742.4","7742.4","7747.4","7750.3","7712.9","7875.2","7871.1","7867.6","7867.8","7865.8","7865.3","7841.8","7779.9"
|
| 290 |
+
"W289","9790","9800","10160","9960","9920","9710","10070","9970","10000","9770","9840","9840","9750"
|
| 291 |
+
"W290","15593.6","15096.7","13956.7","14349.8","14700.6","15026.8","14956.2","15354.3","15842.4","16424.9","16261.7","16238.3","16523.1"
|
| 292 |
+
"W291","7931","7981.7","7898.8","8095.8","7630.4","8375.1","8041","8141.3","7431.5","7448.9","7899","7716.7","7367"
|
| 293 |
+
"W292","17618.9","17397.7","17308.2","17341.1","16890.5","17910.6","17573.8","17610.4","16916.2","16970.8","17256","16901","15827.8"
|
| 294 |
+
"W293","1873","1748","3351","2824","2482","1905","1900","1817","1985","1709","1949","1641","1941"
|
| 295 |
+
"W294","1255.8","1480.5","1196.1","1618.8","1318.5","1427.8","1197.9","1301.9","996","1153.8","1343.5","1270.2","1210.1"
|
| 296 |
+
"W295","3090","3218","3071","3054","3124","3185","3205","3158","3235","3157","3002","3736","3258"
|
| 297 |
+
"W296","3427","2655","2957","3214","2609","1843","2775","3629","3045","2894","2500","3242","3062"
|
| 298 |
+
"W297","5520","4651","4787","5396","4341","3130","4231","4702","5425","4859","4622","5740","5070"
|
| 299 |
+
"W298","3806","2828","3516","3576","2979","1842","3514","4137","3570","3348","3245","4174","3695"
|
| 300 |
+
"W299","3952","3296","3338","3522","3138","2128","2955","3950","3260","3003","2981","3182","3275"
|
| 301 |
+
"W300","2374","1834","1881","1979","1699","1155","1765","2300","1820","1777","1643","1919","1892"
|
| 302 |
+
"W301","3392","2854","2689","2867","2815","2185","2861","2947","2624","2433","2365","2874","2649"
|
| 303 |
+
"W302","3121","2733","2775","3159","2695","2770","3158","2702","3445","3021","3534","4221","3385"
|
| 304 |
+
"W303","5797","4859","5145","5401","4446","3132","4343","5185","5489","5312","5005","6408","5480"
|
| 305 |
+
"W304","3253","2595","2754","2963","2634","1609","2817","3632","3136","2790","2432","3129","3024"
|
| 306 |
+
"W305","3358","2770","2791","2943","2574","1892","2428","3243","2828","2730","2560","3425","2957"
|
| 307 |
+
"W306","2558","2048","2187","2435","1885","1225","2385","3028","2388","2236","2135","2696","2497"
|
| 308 |
+
"W307","3460","3044","3400","3706","2874","2274","2800","3050","3473","2985","3099","3948","3311"
|
| 309 |
+
"W308","2573","2124","2185","2420","1932","1351","1891","2117","2349","2289","2047","2592","2279"
|
| 310 |
+
"W309","7378","6631","6561","7425","7372","6258","8446","7995","6836","6070","5230","6259","6478"
|
| 311 |
+
"W310","5183","4647","5227","5734","5549","4025","5253","6645","5251","4651","4383","5649","5316"
|
| 312 |
+
"W311","5535","4926","4816","5557","5197","3914","4856","5834","5014","4424","3970","4932","4835"
|
| 313 |
+
"W312","4895","4096","4309","4967","5689","3706","4645","4642","4373","3965","3795","4215","4198"
|
| 314 |
+
"W313","2115","2203","2102","2090","2138","2180","2194","2162","2214","2161","2055","2557","2230"
|
| 315 |
+
"W314","9441","8986","8449","9450","9509","7357","8471","9510","8440","8063","7471","8021","8301"
|
| 316 |
+
"W315","5316","5008","5037","5095","5226","3920","4911","5467","4717","4589","4346","4605","4745"
|
| 317 |
+
"W316","3210","2787","2867","3217","3157","2388","2739","3447","3038","2883","2938","3259","3113"
|
| 318 |
+
"W317","4508","3991","3886","4226","4297","3077","3922","4642","4072","3787","3349","4066","3983"
|
| 319 |
+
"W318","3655","3808","3633","3613","3695","3768","3792","3736","3827","3735","3552","4420","3854"
|
| 320 |
+
"W319","2547","2678","2381","2733","2894","1965","2326","3097","2363","2247","2099","2524","2466"
|
| 321 |
+
"W320","2572","2584","2776","3036","3647","2357","2695","3086","2379","1989","1862","2062","2276"
|
| 322 |
+
"W321","6172","5932","5665","6110","5342","4145","5273","5132","6994","6018","5745","6651","6108"
|
| 323 |
+
"W322","2626","2370","2401","2830","3284","2103","2352","2865","2483","2381","2317","2530","2515"
|
| 324 |
+
"W323","5324","4761","4912","5730","6273","4394","5408","6709","5429","5251","4661","5976","5605"
|
| 325 |
+
"W324","3057","2539","2419","2741","2917","2209","2798","3546","3183","3140","2856","3481","3241"
|
| 326 |
+
"W325","3215","3336","3183","3166","3238","3301","3322","3274","3354","3273","3112","3872","3377"
|
| 327 |
+
"W326","3852","3618","3667","4098","3958","2932","3769","4767","4198","3692","3344","4281","2884"
|
| 328 |
+
"W327","9791","8592","8763","9250","7963","6586","9213","11666","10414","9639","9268","11945","10586"
|
| 329 |
+
"W328","2826","2377","2613","2854","2688","1744","2439","3290","2774","2447","2203","2916","2726"
|
| 330 |
+
"W329","11786","9960","9820","10747","10786","7966","9968","12266","10587","8995","8982","11725","10511"
|
| 331 |
+
"W330","3899","3130","3356","3606","3534","2475","3407","4577","4188","3678","3283","4192","3984"
|
| 332 |
+
"W331","6871","5873","6170","6691","6293","4403","6148","7417","6375","5700","5761","6663","7417"
|
| 333 |
+
"W332","3191","3095","2805","3275","3230","2471","3448","4534","3545","3233","3179","3157","3530"
|
| 334 |
+
"W333","2774","2724","2885","3325","3629","2345","3088","3348","2915","2738","2573","3229","2961"
|
| 335 |
+
"W334","4019","3939","4008","4178","4154","3076","3917","4989","4219","4007","3832","4543","4318"
|
| 336 |
+
"W335","5003","4375","4539","4653","5501","3930","4734","5515","5429","4309","3932","4650","4767"
|
| 337 |
+
"W336","1984","1637","2053","2085","2375","1537","2031","2317","1744","1468","1250","1730","1702"
|
| 338 |
+
"W337","1682","2188","1939","2411","2429","1734","1835","2171","1857","1650","1562","2022","1350"
|
| 339 |
+
"W338","2838","2690","2470","3041","3116","2190","2571","3022","2481","2196","2106","2391","2614"
|
| 340 |
+
"W339","4385","3907","3606","3890","3564","2431","3703","4398","3699","3437","3386","4085","3801"
|
| 341 |
+
"W340","7431","6881","7184","7855","7520","6190","7526","7912","7010","6089","6089","6444","6709"
|
| 342 |
+
"W341","5896","5259","5643","6241","6242","4539","5038","6269","5306","4688","4423","5340","5205"
|
| 343 |
+
"W342","1717.1","1304.4","1335.5","1379.2","1276.9","1007.2","1323","1575.6","1307","1264.5","1118.6","1349.1","1323"
|
| 344 |
+
"W343","4274","3596","3596","4510","4182","2859","3579","4723","4013","3409","3143","4375","3933"
|
| 345 |
+
"W344","3730","3182","3165","3463","3098","1955","2964","4003","3550","3348","3014","3879","3559"
|
| 346 |
+
"W345","3807","3360","3266","3406","3122","2212","2982","3988","3389","3160","3126","3584","3988"
|
| 347 |
+
"W346","7012","5532","5818","5845","5240","3442","5592","6704","5512","5418","5004","5866","5701"
|
| 348 |
+
"W347","4217","3343","3621","3938","3702","2575","3277","4485","3790","3487","3158","4196","3823"
|
| 349 |
+
"W348","4952","4054","3637","4411","4431","2762","3708","4300","5288","4673","3860","5166","4657"
|
| 350 |
+
"W349","3449","3593","3428","3409","3487","3556","3578","3526","3612","3525","3352","4171","3637"
|
| 351 |
+
"W350","6407","5330","5444","5744","4513","3424","6361","6498","5367","5326","5042","6151","5677"
|
| 352 |
+
"W351","2083","1704","1622","1626","1306","1010","1719","2189","1661","1753","1591","2057","1850"
|
| 353 |
+
"W352","1339","1395","1331","1324","1354","1381","1389","1369","1402","1369","1302","1619","1412"
|
| 354 |
+
"W353","2197","1711","1661","1535","1444","953","1526","1840","1722","1540","1534","1671","1661"
|
| 355 |
+
"W354","3513","3203","2907","3031","2335","1945","3178","4175","3651","3512","3573","4511","3884"
|
| 356 |
+
"W355","2777","2893","2760","2745","2807","2862","2880","2838","2908","2838","2698","3357","2928"
|
| 357 |
+
"W356","4387","3772","3660","3878","3131","2405","4081","4811","4170","4017","3824","4661","4297"
|
| 358 |
+
"W357","3938","3363","3249","3515","3097","2275","3421","4166","3807","3434","3303","4024","3747"
|
| 359 |
+
"W358","6907","5711","5496","5862","5380","3660","5159","6787","5596","5445","5345","6156","5866"
|
| 360 |
+
"W359","4458","5098","4518","4973","3973","2613","3476","4213","3386","3627","3299","3743","3965"
|
dataset/m4/Weekly-train.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/Yearly-test.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset/m4/Yearly-train.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6aa668d4f04654c0bf5f07b266a51c556f0df2faf4709f852653755396a249e4
|
| 3 |
+
size 25355736
|
dataset/m4/submission-Naive2.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9610c625d6dee654ec458dcd07fb39373d065efc33b4f17d94c50b2aef00ef9
|
| 3 |
+
size 23409576
|
dataset/m4/test.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:134249cd801d51507af9de3c77b4c789da4de96cb5a2d2fe2988cac138820d32
|
| 3 |
+
size 20778257
|
dataset/m4/training.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1799e5334656df0e7ff68ea016f555aa23e5460890e5f92a59f907deba5a7914
|
| 3 |
+
size 270896295
|
dataset/poly/polymarket_data_processed_Crypto_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8197ef536a3f9ef9399dc3205d39ca5cd3ac3c38bb5df2033ef74198402f2b91
|
| 3 |
+
size 34744628
|
dataset/poly/polymarket_data_processed_Election_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2ccfdbf65e9fc4f173daa3cfdc1ddf771a064319beb82c723c014b30ff0a1a3
|
| 3 |
+
size 30018418
|
dataset/poly/polymarket_data_processed_Other_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:200fb59b5a8d640d9e27e5fe88a2e323f157ab905e67f5209f95b38234873fe7
|
| 3 |
+
size 70988203
|
dataset/poly/polymarket_data_processed_Politics_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7468321988e461dfdb449526db0c67274e77180c78d9d999cb7a075c683fc01
|
| 3 |
+
size 186831387
|
dataset/poly/polymarket_data_processed_Sports_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4574cd57806a46ba5b8576a0ed5467fa81c6ed82a86188d255ef3ddd6bbf9dc5
|
| 3 |
+
size 27514781
|
dataset/poly/polymarket_data_processed_dev.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:260d436493fe1b4b69b1d5f35808c1ce90088cf5a51bc7a6433b96039baef23f
|
| 3 |
+
size 456387925
|
dataset/poly/polymarket_data_processed_test.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fca43646308d0b5c098f5ac529b9b1bb5703f334c26c327e2c3696de32602625
|
| 3 |
+
size 420577190
|
dataset/poly/polymarket_data_processed_train.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a08508154ae529fad7a4cbb174ad47d6702e99bae04cf18b7310c52c6ff3f4c4
|
| 3 |
+
size 2609542664
|
exp/__init__.py
ADDED
|
File without changes
|
exp/exp_anomaly_detection.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, adjustment
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
import torch.multiprocessing

# Avoid "too many open files" when DataLoader workers share tensors.
torch.multiprocessing.set_sharing_strategy('file_system')
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np

warnings.filterwarnings('ignore')


class Exp_Anomaly_Detection(Exp_Basic):
    """Reconstruction-based anomaly-detection experiment.

    The model is trained to reconstruct its own input with an MSE objective.
    At test time, the per-timestep reconstruction error ("energy") is
    thresholded at the (100 - anomaly_ratio)-th percentile of the combined
    train+test error distribution to flag anomalous points.
    """

    def __init__(self, args):
        super(Exp_Anomaly_Detection, self).__init__(args)

    def _build_model(self):
        """Instantiate the model selected by ``args.model`` (float32 weights)."""
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return (dataset, dataloader) for the given split flag."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        """Adam optimizer over all model parameters."""
        return optim.Adam(self.model.parameters(), lr=self.args.learning_rate)

    def _select_criterion(self):
        """Mean-squared reconstruction loss."""
        return nn.MSELoss()

    def vali(self, vali_data, vali_loader, criterion):
        """Return the mean reconstruction loss over ``vali_loader``.

        Labels are ignored: anomaly detection is trained unsupervised on
        reconstruction quality alone.
        """
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, _) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)

                # Reconstruction pass; the extra (mark/decoder) inputs are unused.
                outputs = self.model(batch_x, None, None, None)

                # For 'MS' (multivariate input, univariate target) score only
                # the last channel; otherwise score all channels.
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, :, f_dim:]
                pred = outputs.detach().cpu()
                true = batch_x.detach().cpu()

                loss = criterion(pred, true)
                # Store a plain float rather than a 0-d tensor so np.average
                # does not have to coerce tensors.
                total_loss.append(loss.item())
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train with early stopping on validation loss; returns the model
        reloaded from its best checkpoint.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)

                outputs = self.model(batch_x, None, None, None)

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, :, f_dim:]
                # Self-supervised objective: reconstruct the input itself.
                loss = criterion(outputs, batch_x)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            # EarlyStopping also checkpoints the best model into `path`.
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Score the test split and report accuracy / P / R / F1.

        Pipeline: (1) collect reconstruction energies on train data,
        (2) derive a percentile threshold from train+test energies,
        (3) threshold test energies into binary predictions,
        (4) apply the standard point-adjustment protocol before scoring.
        """
        test_data, test_loader = self._get_data(flag='test')
        train_data, train_loader = self._get_data(flag='train')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        attens_energy = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        # Per-element loss: `reduce=False` is deprecated — use reduction='none'.
        self.anomaly_criterion = nn.MSELoss(reduction='none')

        # (1) statistic on the train set
        with torch.no_grad():
            for i, (batch_x, batch_y) in enumerate(train_loader):
                batch_x = batch_x.float().to(self.device)
                # reconstruction
                outputs = self.model(batch_x, None, None, None)
                # per-timestep energy: mean squared error over channels
                score = torch.mean(self.anomaly_criterion(batch_x, outputs), dim=-1)
                score = score.detach().cpu().numpy()
                attens_energy.append(score)

        attens_energy = np.concatenate(attens_energy, axis=0).reshape(-1)
        train_energy = np.array(attens_energy)

        # (2) find the threshold
        attens_energy = []
        test_labels = []
        for i, (batch_x, batch_y) in enumerate(test_loader):
            batch_x = batch_x.float().to(self.device)
            # reconstruction
            outputs = self.model(batch_x, None, None, None)
            # criterion
            score = torch.mean(self.anomaly_criterion(batch_x, outputs), dim=-1)
            score = score.detach().cpu().numpy()
            attens_energy.append(score)
            test_labels.append(batch_y)

        attens_energy = np.concatenate(attens_energy, axis=0).reshape(-1)
        test_energy = np.array(attens_energy)
        combined_energy = np.concatenate([train_energy, test_energy], axis=0)
        # Top anomaly_ratio% of energies are treated as anomalous.
        threshold = np.percentile(combined_energy, 100 - self.args.anomaly_ratio)
        print("Threshold :", threshold)

        # (3) evaluation on the test set
        pred = (test_energy > threshold).astype(int)
        test_labels = np.concatenate(test_labels, axis=0).reshape(-1)
        test_labels = np.array(test_labels)
        gt = test_labels.astype(int)

        print("pred: ", pred.shape)
        print("gt: ", gt.shape)

        # (4) detection adjustment: if any point of a true anomaly segment is
        # detected, the whole segment counts as detected.
        gt, pred = adjustment(gt, pred)

        pred = np.array(pred)
        gt = np.array(gt)
        print("pred: ", pred.shape)
        print("gt: ", gt.shape)

        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, support = precision_recall_fscore_support(gt, pred, average='binary')
        print("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(
            accuracy, precision,
            recall, f_score))

        # Context manager guarantees the file handle is closed on any path.
        with open("result_anomaly_detection.txt", 'a') as f:
            f.write(setting + " \n")
            f.write("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(
                accuracy, precision,
                recall, f_score))
            f.write('\n')
            f.write('\n')
        return
|
exp/exp_basic.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import torch
from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
    Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
    Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, PAttn, TimeXer, \
    WPMixer, MultiPatchFormer


class Exp_Basic(object):
    """Base class for all experiments.

    Resolves the model class from ``args.model``, acquires the compute
    device, and builds the model. Subclasses implement ``_build_model`` and
    the train/vali/test hooks.
    """

    def __init__(self, args):
        self.args = args
        # Name -> module mapping; each module exposes a `Model` class.
        self.model_dict = {
            'TimesNet': TimesNet,
            'Autoformer': Autoformer,
            'Transformer': Transformer,
            'Nonstationary_Transformer': Nonstationary_Transformer,
            'DLinear': DLinear,
            'FEDformer': FEDformer,
            'Informer': Informer,
            'LightTS': LightTS,
            'Reformer': Reformer,
            'ETSformer': ETSformer,
            'PatchTST': PatchTST,
            'Pyraformer': Pyraformer,
            'MICN': MICN,
            'Crossformer': Crossformer,
            'FiLM': FiLM,
            'iTransformer': iTransformer,
            'Koopa': Koopa,
            'TiDE': TiDE,
            'FreTS': FreTS,
            'MambaSimple': MambaSimple,
            'TimeMixer': TimeMixer,
            'TSMixer': TSMixer,
            'SegRNN': SegRNN,
            'TemporalFusionTransformer': TemporalFusionTransformer,
            "SCINet": SCINet,
            'PAttn': PAttn,
            'TimeXer': TimeXer,
            'WPMixer': WPMixer,
            'MultiPatchFormer': MultiPatchFormer
        }
        # Mamba requires the optional mamba_ssm package, so it is imported
        # lazily only when actually requested.
        if args.model == 'Mamba':
            print('Please make sure you have successfully installed mamba_ssm')
            from models import Mamba
            self.model_dict['Mamba'] = Mamba

        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        """Construct and return the model; must be overridden by subclasses."""
        # The dead `return None` that followed this raise was unreachable
        # and has been removed.
        raise NotImplementedError

    def _acquire_device(self):
        """Select cuda / mps / cpu based on args and return a torch.device."""
        if self.args.use_gpu and self.args.gpu_type == 'cuda':
            # Restrict visible devices before creating the device handle.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(
                self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        elif self.args.use_gpu and self.args.gpu_type == 'mps':
            device = torch.device('mps')
            print('Use GPU: mps')
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _get_data(self):
        """Hook: return dataset/dataloader for a split. Overridden by subclasses."""
        pass

    def vali(self):
        """Hook: validation loop. Overridden by subclasses."""
        pass

    def train(self):
        """Hook: training loop. Overridden by subclasses."""
        pass

    def test(self):
        """Hook: evaluation loop. Overridden by subclasses."""
        pass
|
exp/exp_classification.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, cal_accuracy
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
import pdb

warnings.filterwarnings('ignore')


class Exp_Classification(Exp_Basic):
    """Sequence-classification experiment (UEA-style datasets).

    Model hyper-parameters that depend on the data (sequence length, input
    channels, number of classes) are inferred from the datasets before the
    model is built.
    """

    def __init__(self, args):
        super(Exp_Classification, self).__init__(args)

    def _build_model(self):
        # model input depends on data
        train_data, train_loader = self._get_data(flag='TRAIN')
        test_data, test_loader = self._get_data(flag='TEST')
        self.args.seq_len = max(train_data.max_seq_len, test_data.max_seq_len)
        self.args.pred_len = 0  # classification: no forecasting horizon
        self.args.enc_in = train_data.feature_df.shape[1]
        self.args.num_class = len(train_data.class_names)
        # model init
        model = self.model_dict[self.args.model].Model(self.args).float()
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return (dataset, dataloader) for the given split flag."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        # model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        model_optim = optim.RAdam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        return nn.CrossEntropyLoss()

    def vali(self, vali_data, vali_loader, criterion):
        """Return (mean loss, accuracy) over a validation loader."""
        total_loss = []
        preds = []
        trues = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, label, padding_mask) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)

                pred = outputs.detach().cpu()
                # squeeze(-1) (not squeeze()) so a final batch of size 1 does
                # not have its batch dimension collapsed, which would break
                # CrossEntropyLoss. Matches the train-loop target shaping.
                loss = criterion(pred, label.long().squeeze(-1).cpu())
                total_loss.append(loss)

                preds.append(outputs.detach())
                trues.append(label)

        total_loss = np.average(total_loss)

        preds = torch.cat(preds, 0)
        trues = torch.cat(trues, 0)
        # Explicit dim=1: implicit-dim softmax is deprecated and ambiguous.
        probs = torch.nn.functional.softmax(preds, dim=1)  # (total_samples, num_classes) est. prob. for each class and sample
        predictions = torch.argmax(probs, dim=1).cpu().numpy()  # (total_samples,) int class index for each sample
        trues = trues.flatten().cpu().numpy()
        accuracy = cal_accuracy(predictions, trues)

        self.model.train()
        return total_loss, accuracy

    def train(self, setting):
        """Train with early stopping on (negated) validation accuracy."""
        train_data, train_loader = self._get_data(flag='TRAIN')
        vali_data, vali_loader = self._get_data(flag='TEST')
        test_data, test_loader = self._get_data(flag='TEST')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()

            for i, (batch_x, label, padding_mask) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)
                loss = criterion(outputs, label.long().squeeze(-1))
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                # Clip gradients to stabilize training.
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=4.0)
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss, val_accuracy = self.vali(vali_data, vali_loader, criterion)
            test_loss, test_accuracy = self.vali(test_data, test_loader, criterion)

            print(
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.3f} Vali Loss: {3:.3f} Vali Acc: {4:.3f} Test Loss: {5:.3f} Test Acc: {6:.3f}"
                .format(epoch + 1, train_steps, train_loss, vali_loss, val_accuracy, test_loss, test_accuracy))
            # EarlyStopping minimizes its metric, so pass negative accuracy.
            early_stopping(-val_accuracy, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Evaluate on the TEST split and append accuracy to a result file."""
        test_data, test_loader = self._get_data(flag='TEST')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        preds = []
        trues = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, label, padding_mask) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)

                preds.append(outputs.detach())
                trues.append(label)

        preds = torch.cat(preds, 0)
        trues = torch.cat(trues, 0)
        print('test shape:', preds.shape, trues.shape)

        # Explicit dim=1: implicit-dim softmax is deprecated and ambiguous.
        probs = torch.nn.functional.softmax(preds, dim=1)  # (total_samples, num_classes) est. prob. for each class and sample
        predictions = torch.argmax(probs, dim=1).cpu().numpy()  # (total_samples,) int class index for each sample
        trues = trues.flatten().cpu().numpy()
        accuracy = cal_accuracy(predictions, trues)

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        print('accuracy:{}'.format(accuracy))
        file_name = 'result_classification.txt'
        # Context manager guarantees the file handle is closed on any path.
        with open(os.path.join(folder_path, file_name), 'a') as f:
            f.write(setting + " \n")
            f.write('accuracy:{}'.format(accuracy))
            f.write('\n')
            f.write('\n')
        return
|
exp/exp_imputation.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np

warnings.filterwarnings('ignore')


class Exp_Imputation(Exp_Basic):
    """Missing-value imputation experiment.

    A random Bernoulli mask (rate = args.mask_rate) zeroes out entries of the
    input; the model is trained to recover the masked values, and the loss /
    metrics are computed only on the masked positions (mask == 0).
    """

    def __init__(self, args):
        super(Exp_Imputation, self).__init__(args)

    def _build_model(self):
        # Instantiate the model selected by args.model (float32 weights).
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        # Return (dataset, dataloader) for the given split flag.
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        # Adam optimizer over all model parameters.
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        # MSE between imputed and ground-truth values (restricted to masked
        # positions at the call sites).
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Return the mean masked-reconstruction loss over ``vali_loader``."""
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)

                # random mask
                B, T, N = batch_x.shape
                """
                B = batch size
                T = seq len
                N = number of features
                """
                # mask entries: 0 = artificially removed (to be imputed),
                # 1 = observed. Thresholding a uniform sample gives a
                # Bernoulli(mask_rate) pattern of missing values.
                mask = torch.rand((B, T, N)).to(self.device)
                mask[mask <= self.args.mask_rate] = 0  # masked
                mask[mask > self.args.mask_rate] = 1  # remained
                inp = batch_x.masked_fill(mask == 0, 0)

                outputs = self.model(inp, batch_x_mark, None, None, mask)

                # 'MS' keeps only the last (target) channel; otherwise all.
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, :, f_dim:]

                # add support for MS
                batch_x = batch_x[:, :, f_dim:]
                mask = mask[:, :, f_dim:]

                pred = outputs.detach().cpu()
                true = batch_x.detach().cpu()
                mask = mask.detach().cpu()

                # Loss is evaluated only where values were hidden (mask == 0).
                loss = criterion(pred[mask == 0], true[mask == 0])
                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train with early stopping on validation loss; returns the model
        reloaded from its best checkpoint.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)

                # random mask (fresh pattern every batch; 0 = to impute)
                B, T, N = batch_x.shape
                mask = torch.rand((B, T, N)).to(self.device)
                mask[mask <= self.args.mask_rate] = 0  # masked
                mask[mask > self.args.mask_rate] = 1  # remained
                inp = batch_x.masked_fill(mask == 0, 0)

                outputs = self.model(inp, batch_x_mark, None, None, mask)

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, :, f_dim:]

                # add support for MS
                batch_x = batch_x[:, :, f_dim:]
                mask = mask[:, :, f_dim:]

                # Only masked positions contribute to the training loss.
                loss = criterion(outputs[mask == 0], batch_x[mask == 0])
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            # EarlyStopping also checkpoints the best model into `path`.
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Evaluate imputation quality on the test split.

        Saves periodic visualizations of (ground truth vs. imputed) series,
        then reports MAE/MSE/RMSE/MAPE/MSPE computed on masked positions only.
        """
        test_data, test_loader = self._get_data(flag='test')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        preds = []
        trues = []
        masks = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)

                # random mask (same construction as in train/vali)
                B, T, N = batch_x.shape
                mask = torch.rand((B, T, N)).to(self.device)
                mask[mask <= self.args.mask_rate] = 0  # masked
                mask[mask > self.args.mask_rate] = 1  # remained
                inp = batch_x.masked_fill(mask == 0, 0)

                # imputation
                outputs = self.model(inp, batch_x_mark, None, None, mask)

                # eval
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, :, f_dim:]

                # add support for MS
                batch_x = batch_x[:, :, f_dim:]
                mask = mask[:, :, f_dim:]

                outputs = outputs.detach().cpu().numpy()
                pred = outputs
                true = batch_x.detach().cpu().numpy()
                preds.append(pred)
                trues.append(true)
                masks.append(mask.detach().cpu())

                if i % 20 == 0:
                    # Plot the last channel of the first sample: observed
                    # values kept as-is, masked values replaced by predictions.
                    filled = true[0, :, -1].copy()
                    filled = filled * mask[0, :, -1].detach().cpu().numpy() + \
                             pred[0, :, -1] * (1 - mask[0, :, -1].detach().cpu().numpy())
                    visual(true[0, :, -1], filled, os.path.join(folder_path, str(i) + '.pdf'))

        preds = np.concatenate(preds, 0)
        trues = np.concatenate(trues, 0)
        masks = np.concatenate(masks, 0)
        print('test shape:', preds.shape, trues.shape)

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        # Metrics restricted to positions that were actually imputed.
        mae, mse, rmse, mape, mspe = metric(preds[masks == 0], trues[masks == 0])
        print('mse:{}, mae:{}'.format(mse, mae))
        f = open("result_imputation.txt", 'a')
        f.write(setting + " \n")
        f.write('mse:{}, mae:{}'.format(mse, mae))
        f.write('\n')
        f.write('\n')
        f.close()

        np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))
        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)
        return
|
exp/exp_long_term_forecasting.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data_provider.data_factory import data_provider
|
| 2 |
+
from exp.exp_basic import Exp_Basic
|
| 3 |
+
from utils.tools import EarlyStopping, adjust_learning_rate, visual
|
| 4 |
+
from utils.metrics import metric
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch import optim
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
import warnings
|
| 11 |
+
import numpy as np
|
| 12 |
+
from utils.dtw_metric import dtw, accelerated_dtw
|
| 13 |
+
from utils.augmentation import run_augmentation, run_augmentation_single
|
| 14 |
+
|
| 15 |
+
warnings.filterwarnings('ignore')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Exp_Long_Term_Forecast(Exp_Basic):
    """Long-term forecasting experiment.

    Builds a model from ``Exp_Basic.model_dict`` and runs the train /
    validate / test loops over sliding-window batches of shape
    (seq_len -> pred_len) produced by ``data_provider``.
    """

    def __init__(self, args):
        # args: experiment namespace holding model/data/optimization settings.
        super(Exp_Long_Term_Forecast, self).__init__(args)

    def _build_model(self):
        # Look up the model class registered under args.model and build it in
        # float32; optionally wrap for multi-GPU data parallelism.
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        # flag: dataset split name ('train' / 'val' / 'test').
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Evaluate on ``vali_loader`` and return the mean batch loss (numpy float)."""
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: last label_len observed steps followed by
                # zeros as placeholders for the pred_len forecast horizon
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                # 'MS' mode (multivariate in, univariate out) scores only the
                # last channel; otherwise all channels are scored
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion(pred, true)

                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train with early stopping and per-epoch LR adjustment.

        setting: experiment identifier string; names the checkpoint directory.
        Returns the model re-loaded from the best (early-stopped) checkpoint.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            # gradient scaler for mixed-precision training
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: label_len known steps + pred_len zeros
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                        f_dim = -1 if self.args.features == 'MS' else 0
                        outputs = outputs[:, -self.args.pred_len:, f_dim:]
                        batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                        loss = criterion(outputs, batch_y)
                        train_loss.append(loss.item())
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                    f_dim = -1 if self.args.features == 'MS' else 0
                    outputs = outputs[:, -self.args.pred_len:, f_dim:]
                    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                    loss = criterion(outputs, batch_y)
                    train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    # progress report with a rough ETA over the remaining iterations
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            # early stopping tracks validation loss and saves the best checkpoint
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Evaluate on the test split; saves plots, metrics, and arrays to disk.

        setting: experiment identifier (checkpoint / results directory name).
        test: when truthy, reload the checkpoint from ./checkpoints/<setting>/.
        """
        test_data, test_loader = self._get_data(flag='test')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        preds = []
        trues = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                f_dim = -1 if self.args.features == 'MS' else 0
                # keep all channels for now: the optional inverse transform
                # needs the full feature dimension; the f_dim slice follows it
                outputs = outputs[:, -self.args.pred_len:, :]
                batch_y = batch_y[:, -self.args.pred_len:, :].to(self.device)
                outputs = outputs.detach().cpu().numpy()
                batch_y = batch_y.detach().cpu().numpy()
                if test_data.scale and self.args.inverse:
                    shape = batch_y.shape
                    # if the model emitted fewer channels than the scaler
                    # expects, tile the output to match before inverting
                    if outputs.shape[-1] != batch_y.shape[-1]:
                        outputs = np.tile(outputs, [1, 1, int(batch_y.shape[-1] / outputs.shape[-1])])
                    outputs = test_data.inverse_transform(outputs.reshape(shape[0] * shape[1], -1)).reshape(shape)
                    batch_y = test_data.inverse_transform(batch_y.reshape(shape[0] * shape[1], -1)).reshape(shape)

                outputs = outputs[:, :, f_dim:]
                batch_y = batch_y[:, :, f_dim:]

                pred = outputs
                true = batch_y

                preds.append(pred)
                trues.append(true)
                if i % 20 == 0:
                    # every 20th batch, plot history + ground truth vs. history
                    # + prediction for the last channel of sample 0
                    input = batch_x.detach().cpu().numpy()
                    if test_data.scale and self.args.inverse:
                        shape = input.shape
                        input = test_data.inverse_transform(input.reshape(shape[0] * shape[1], -1)).reshape(shape)
                    gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0)
                    pd = np.concatenate((input[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        print('test shape:', preds.shape, trues.shape)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        print('test shape:', preds.shape, trues.shape)

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        # dtw calculation (expensive; opt-in via args.use_dtw)
        if self.args.use_dtw:
            dtw_list = []
            manhattan_distance = lambda x, y: np.abs(x - y)
            for i in range(preds.shape[0]):
                x = preds[i].reshape(-1, 1)
                y = trues[i].reshape(-1, 1)
                if i % 100 == 0:
                    print("calculating dtw iter:", i)
                d, _, _, _ = accelerated_dtw(x, y, dist=manhattan_distance)
                dtw_list.append(d)
            dtw = np.array(dtw_list).mean()
        else:
            dtw = 'Not calculated'

        mae, mse, rmse, mape, mspe = metric(preds, trues)
        print('mse:{}, mae:{}, dtw:{}'.format(mse, mae, dtw))
        f = open("result_long_term_forecast.txt", 'a')
        f.write(setting + " \n")
        f.write('mse:{}, mae:{}, dtw:{}'.format(mse, mae, dtw))
        f.write('\n')
        f.write('\n')
        f.close()

        np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))
        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)

        return
|
exp/exp_short_term_forecasting.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from data_provider.data_factory import data_provider
|
| 2 |
+
from data_provider.m4 import M4Meta
|
| 3 |
+
from exp.exp_basic import Exp_Basic
|
| 4 |
+
from utils.tools import EarlyStopping, adjust_learning_rate, visual
|
| 5 |
+
from utils.losses import mape_loss, mase_loss, smape_loss
|
| 6 |
+
from utils.m4_summary import M4Summary
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch import optim
|
| 11 |
+
import os
|
| 12 |
+
import time
|
| 13 |
+
import warnings
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
warnings.filterwarnings('ignore')
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def mse(output, label):
    """Mean squared error between prediction and ground-truth arrays."""
    return np.mean(np.square(output - label))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def rmse(output, label):
    """Root mean squared error between prediction and ground-truth arrays."""
    squared_error = (output - label) ** 2
    return np.sqrt(np.mean(squared_error))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def mae(output, label):
    """Mean absolute error between prediction and ground-truth arrays."""
    return np.mean(np.abs(output - label))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def mape(output, label):
    """Mean absolute percentage error (in percent).

    Positions where the label is exactly zero are excluded to avoid
    division by zero.
    """
    nonzero = label != 0
    ratio = (output[nonzero] - label[nonzero]) / label[nonzero]
    return np.mean(np.abs(ratio)) * 100
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Exp_Short_Term_Forecast(Exp_Basic):
    """Short-term forecasting experiment (M4 / kalshi / poly datasets).

    Unlike the long-term experiment, validation and testing run over the
    whole stored series (``loader.dataset.timeseries``) in large manual
    inference batches rather than iterating the DataLoader.
    """

    def __init__(self, args):
        super(Exp_Short_Term_Forecast, self).__init__(args)

    def _build_model(self):
        # Dataset-specific horizon setup must run before model construction
        # because the model reads seq_len / pred_len from args.
        if self.args.data == 'm4':
            # M4: horizon and seasonality come from the M4Meta lookup tables.
            self.args.pred_len = M4Meta.horizons_map[self.args.seasonal_patterns]
            self.args.seq_len = 2 * self.args.pred_len
            self.args.label_len = self.args.pred_len
            self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns]
        elif self.args.data == 'kalshi':
            # kalshi: lengths come from the command line as-is; no special handling.
            self.args.frequency_map = 1
        elif self.args.data == 'poly':
            # poly: lengths come from the command line as-is; no special handling.
            self.args.frequency_map = 1

        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        # flag: split name understood by data_provider ('train', 'val', 'test', ...).
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self, loss_name='MSE'):
        # Map the loss name from args.loss to a loss module.
        # NOTE(review): an unknown name falls through and returns None.
        if loss_name == 'MSE':
            return nn.MSELoss()
        elif loss_name == 'MAPE':
            return mape_loss()
        elif loss_name == 'MASE':
            return mase_loss()
        elif loss_name == 'SMAPE':
            return smape_loss()

    def _prepare_data_from_timeseries(self, timeseries):
        """Extract (x, y) evaluation pairs from raw stored series.

        Each element of ``timeseries`` is expected to be at least
        seq_len + 1 points long.
        x: the last seq_len points excluding the final one, returned as a
           (B, seq_len, 1) float tensor on self.device.
        y: the last pred_len points, returned as a (B, pred_len) numpy array.
        """
        x_list = []
        y_list = []

        for ts in timeseries:
            # x = ts[-(seq_len+1):-1], y = ts[-pred_len:]
            x_list.append(ts[-(self.args.seq_len + 1):-1])
            y_list.append(ts[-self.args.pred_len:])

        x = torch.tensor(x_list, dtype=torch.float32).to(self.device)
        x = x.unsqueeze(-1)  # (B, seq_len, 1)
        y = np.array(y_list)  # (B, pred_len)

        return x, y

    def train(self, setting):
        """Train with early stopping on validation MSE.

        setting: experiment identifier; names the checkpoint directory.
        Returns the model re-loaded from the best checkpoint.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion(self.args.loss)
        mse_loss = nn.MSELoss()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)  # (B, seq_len, 1)
                batch_y = batch_y.float().to(self.device)  # (B, label_len + pred_len, 1)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: [label_len ground-truth steps, pred_len zeros]
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float().to(self.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # time marks are not used for this task (passed as None)
                outputs = self.model(batch_x, None, dec_inp, None)

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

                # plain MSE, or an M4-style loss that also consumes the input
                # history, the frequency, and the target marks
                if self.args.loss == 'MSE':
                    loss = mse_loss(outputs, batch_y)
                else:
                    batch_y_mark = batch_y_mark[:, -self.args.pred_len:, f_dim:].to(self.device)
                    loss = criterion(batch_x, self.args.frequency_map, outputs, batch_y, batch_y_mark)

                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_loader)
            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss))

            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def vali(self, vali_loader):
        """Validate on the full stored series; returns the MSE as a float."""
        timeseries = vali_loader.dataset.timeseries
        x, y = self._prepare_data_from_timeseries(timeseries)

        self.model.eval()
        with torch.no_grad():
            B, _, C = x.shape

            # decoder input
            dec_inp = torch.zeros((B, self.args.pred_len, C)).float().to(self.device)
            # the last label_len points of x act as the decoder's known context
            dec_inp = torch.cat([x[:, -self.args.label_len:, :], dec_inp], dim=1).float()

            # batched inference to avoid OOM
            outputs = torch.zeros((B, self.args.pred_len, C)).float()
            batch_size = 500
            id_list = np.arange(0, B, batch_size)
            id_list = np.append(id_list, B)

            for i in range(len(id_list) - 1):
                start_idx, end_idx = id_list[i], id_list[i + 1]
                outputs[start_idx:end_idx, :, :] = self.model(
                    x[start_idx:end_idx], None,
                    dec_inp[start_idx:end_idx], None
                ).detach().cpu()

            f_dim = -1 if self.args.features == 'MS' else 0
            outputs = outputs[:, -self.args.pred_len:, f_dim:]

            pred = outputs.numpy()
            true = y

            loss = mse(pred, true)

        self.model.train()
        return loss
    def test(self, setting, test=0):
        """Evaluate on several test splits and write RMSE/MAE to a CSV.

        setting: experiment identifier (checkpoint directory name).
        test: when truthy, reload the checkpoint before evaluating.
        Returns the flat list of per-split metric values.
        """
        # test splits to evaluate; missing splits are skipped, not fatal
        flags = ['test', 'test_Companies', 'test_Economics', 'test_Entertainment', 'test_Mentions', 'test_Politics']
        # flags = ['test', 'test_Crypto', 'test_Politics', 'test_Election']

        results = []
        columns = []

        for flag in flags:
            try:
                _, test_loader = self._get_data(flag=flag)
            except Exception as e:
                print(f"Skipping {flag}: {e}")
                continue

            timeseries = test_loader.dataset.timeseries

            if len(timeseries) == 0:
                print(f"[{flag}] No samples, skipping...")
                continue

            x, y = self._prepare_data_from_timeseries(timeseries)

            if test:
                # NOTE(review): the checkpoint is re-loaded once per flag
                print('Loading model...')
                self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

            folder_path = './test_results/' + setting + '/'
            if not os.path.exists(folder_path):
                os.makedirs(folder_path)

            self.model.eval()
            with torch.no_grad():
                B, _, C = x.shape

                # decoder input
                dec_inp = torch.zeros((B, self.args.pred_len, C)).float().to(self.device)
                dec_inp = torch.cat([x[:, -self.args.label_len:, :], dec_inp], dim=1).float()

                # batched inference
                outputs = torch.zeros((B, self.args.pred_len, C)).float().to(self.device)
                batch_size = 500
                id_list = np.arange(0, B, batch_size)
                id_list = np.append(id_list, B)

                for i in range(len(id_list) - 1):
                    start_idx, end_idx = id_list[i], id_list[i + 1]
                    outputs[start_idx:end_idx, :, :] = self.model(
                        x[start_idx:end_idx], None,
                        dec_inp[start_idx:end_idx], None
                    )
                    if start_idx % 1000 == 0:
                        print(f"Processed {start_idx}/{B}")

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                preds = outputs.detach().cpu().numpy()
                trues = y

                print(f'[{flag}] Test shape: {preds.shape}')

                # compute metrics via the module-level helpers
                rmse_val = rmse(preds, trues)
                mae_val = mae(preds, trues)

                columns.extend([f'{flag}_rmse', f'{flag}_mae'])
                results.extend([rmse_val, mae_val])

                print(f'[{flag}] RMSE: {rmse_val:.6f}, MAE: {mae_val:.6f}')

        # save results as a one-row CSV, one rmse/mae column pair per split
        folder_path = f'./{self.args.data}_results/' + self.args.model + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        df = pd.DataFrame([results], columns=columns)
        result_path = os.path.join(folder_path, f'{self.args.model}_results.csv')
        df.to_csv(result_path, index=False)
        print(f'Results saved to {result_path}')

        return results
|
kalshi_results/Autoformer/Autoformer_results.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
test_rmse,test_mae,test_Companies_rmse,test_Companies_mae,test_Economics_rmse,test_Economics_mae,test_Entertainment_rmse,test_Entertainment_mae,test_Mentions_rmse,test_Mentions_mae,test_Politics_rmse,test_Politics_mae
|
| 2 |
+
0.40856921686443876,0.3217217668927483,0.364304113186295,0.28937346579139905,0.5096567834520302,0.3970817213714233,0.40597254578512976,0.3254350355656028,0.3871389816162235,0.32037402264013803,0.40295968229454693,0.3157070307474669
|
kalshi_results/DLinear/DLinear_results.csv
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
test_rmse,test_mae,test_Companies_rmse,test_Companies_mae,test_Economics_rmse,test_Economics_mae,test_Entertainment_rmse,test_Entertainment_mae,test_Mentions_rmse,test_Mentions_mae,test_Politics_rmse,test_Politics_mae
|
| 2 |
+
0.3891896238849517,0.30230516746317077,0.351860807826818,0.27699875265589174,0.48147644981667514,0.37157798860447333,0.3897480740571339,0.30954725156348983,0.3710907844730615,0.3075220013497144,0.3831593945812524,0.29365767848840196
|
layers/AutoCorrelation.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import math
|
| 7 |
+
from math import sqrt
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AutoCorrelation(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
AutoCorrelation Mechanism with the following two phases:
|
| 14 |
+
(1) period-based dependencies discovery
|
| 15 |
+
(2) time delay aggregation
|
| 16 |
+
This block can replace the self-attention family mechanism seamlessly.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        # factor: scales how many top delays are aggregated
        #   (top_k = int(factor * log(length)) in the agg methods below).
        # mask_flag / scale: stored for attention-interface parity; their use
        #   is not visible in this excerpt — confirm against forward().
        # output_attention: presumably whether forward() also returns the
        #   correlation map — forward() is not visible here; verify.
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
|
| 26 |
+
|
| 27 |
+
    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        values: (batch, head, channel, length) tensor to aggregate
                (layout established by the shape[1..3] reads below).
        corr:   correlation scores; averaged over dims 1 and 2, so it is
                expected to share values' leading layout.
        Returns a tensor shaped like ``values``.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k: number of delays grows logarithmically with length
        top_k = int(self.factor * math.log(length))
        # average correlation over head and channel dims -> (batch, length)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        # training phase picks one batch-shared set of delay indices
        # (mean over the batch dim before topk)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr: normalize the selected correlations into mixing weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: weighted sum of the series rolled by each top delay
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg
|
| 50 |
+
|
| 51 |
+
def time_delay_agg_inference(self, values, corr):
|
| 52 |
+
"""
|
| 53 |
+
SpeedUp version of Autocorrelation (a batch-normalization style design)
|
| 54 |
+
This is for the inference phase.
|
| 55 |
+
"""
|
| 56 |
+
batch = values.shape[0]
|
| 57 |
+
head = values.shape[1]
|
| 58 |
+
channel = values.shape[2]
|
| 59 |
+
length = values.shape[3]
|
| 60 |
+
# index init
|
| 61 |
+
init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
|
| 62 |
+
# find top k
|
| 63 |
+
top_k = int(self.factor * math.log(length))
|
| 64 |
+
mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
|
| 65 |
+
weights, delay = torch.topk(mean_value, top_k, dim=-1)
|
| 66 |
+
# update corr
|
| 67 |
+
tmp_corr = torch.softmax(weights, dim=-1)
|
| 68 |
+
# aggregation
|
| 69 |
+
tmp_values = values.repeat(1, 1, 1, 2)
|
| 70 |
+
delays_agg = torch.zeros_like(values).float()
|
| 71 |
+
for i in range(top_k):
|
| 72 |
+
tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
|
| 73 |
+
pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
|
| 74 |
+
delays_agg = delays_agg + pattern * \
|
| 75 |
+
(tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
|
| 76 |
+
return delays_agg
|
| 77 |
+
|
| 78 |
+
def time_delay_agg_full(self, values, corr):
|
| 79 |
+
"""
|
| 80 |
+
Standard version of Autocorrelation
|
| 81 |
+
"""
|
| 82 |
+
batch = values.shape[0]
|
| 83 |
+
head = values.shape[1]
|
| 84 |
+
channel = values.shape[2]
|
| 85 |
+
length = values.shape[3]
|
| 86 |
+
# index init
|
| 87 |
+
init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
|
| 88 |
+
# find top k
|
| 89 |
+
top_k = int(self.factor * math.log(length))
|
| 90 |
+
weights, delay = torch.topk(corr, top_k, dim=-1)
|
| 91 |
+
# update corr
|
| 92 |
+
tmp_corr = torch.softmax(weights, dim=-1)
|
| 93 |
+
# aggregation
|
| 94 |
+
tmp_values = values.repeat(1, 1, 1, 2)
|
| 95 |
+
delays_agg = torch.zeros_like(values).float()
|
| 96 |
+
for i in range(top_k):
|
| 97 |
+
tmp_delay = init_index + delay[..., i].unsqueeze(-1)
|
| 98 |
+
pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
|
| 99 |
+
delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
|
| 100 |
+
return delays_agg
|
| 101 |
+
|
| 102 |
+
def forward(self, queries, keys, values, attn_mask):
|
| 103 |
+
B, L, H, E = queries.shape
|
| 104 |
+
_, S, _, D = values.shape
|
| 105 |
+
if L > S:
|
| 106 |
+
zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
|
| 107 |
+
values = torch.cat([values, zeros], dim=1)
|
| 108 |
+
keys = torch.cat([keys, zeros], dim=1)
|
| 109 |
+
else:
|
| 110 |
+
values = values[:, :L, :, :]
|
| 111 |
+
keys = keys[:, :L, :, :]
|
| 112 |
+
|
| 113 |
+
# period-based dependencies
|
| 114 |
+
q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
|
| 115 |
+
k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
|
| 116 |
+
res = q_fft * torch.conj(k_fft)
|
| 117 |
+
corr = torch.fft.irfft(res, dim=-1)
|
| 118 |
+
|
| 119 |
+
# time delay agg
|
| 120 |
+
if self.training:
|
| 121 |
+
V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
|
| 122 |
+
else:
|
| 123 |
+
V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
|
| 124 |
+
|
| 125 |
+
if self.output_attention:
|
| 126 |
+
return (V.contiguous(), corr.permute(0, 3, 1, 2))
|
| 127 |
+
else:
|
| 128 |
+
return (V.contiguous(), None)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class AutoCorrelationLayer(nn.Module):
    """Wraps an auto-correlation mechanism with per-head linear projections.

    Projects queries/keys/values into (batch, len, n_heads, d) views, delegates
    to the wrapped correlation block, then merges the heads and projects the
    result back to d_model.
    """

    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        # Default per-head sizes split d_model evenly across heads.
        if not d_keys:
            d_keys = d_model // n_heads
        if not d_values:
            d_values = d_model // n_heads

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Project and split into heads.
        q = self.query_projection(queries).view(B, L, H, -1)
        k = self.key_projection(keys).view(B, S, H, -1)
        v = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(q, k, v, attn_mask)

        # Merge heads and project back to the model dimension.
        merged = out.view(B, L, -1)
        return self.out_projection(merged), attn
|
layers/Autoformer_EncDec.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class my_Layernorm(nn.Module):
    """LayerNorm variant for the seasonal component.

    Applies a standard LayerNorm over the channel dimension and then removes
    the mean taken over the time axis, so the output is zero-mean along dim 1.
    """

    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # Subtract the time-axis mean; keepdim lets broadcasting expand it
        # back over the sequence length.
        seq_mean = normed.mean(dim=1, keepdim=True)
        return normed - seq_mean
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class moving_avg(nn.Module):
    """Moving average over the time dimension of a (batch, time, feature)
    series, with replicate padding at both ends, used to extract the trend."""

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Replicate the first/last time step so the average stays defined
        # at the sequence edges.
        pad = (self.kernel_size - 1) // 2
        head = x[:, :1, :].repeat(1, pad, 1)
        tail = x[:, -1:, :].repeat(1, pad, 1)
        padded = torch.cat((head, x, tail), dim=1)
        # AvgPool1d expects (batch, channels, time).
        smoothed = self.avg(padded.transpose(1, 2))
        return smoothed.transpose(1, 2)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class series_decomp(nn.Module):
    """Series decomposition block.

    Splits a series into (seasonal residual, trend): the trend is a moving
    average with the given kernel size, the residual is what remains.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class series_decomp_multi(nn.Module):
    """
    Multiple Series decomposition block from FEDformer.

    Runs one series_decomp per kernel size and returns the unweighted average
    of the seasonal parts and of the trend parts.
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        # Fix: register the per-kernel decompositions as submodules. A plain
        # Python list hides them from .to(), .children() and state_dict().
        self.series_decomp = nn.ModuleList(series_decomp(kernel) for kernel in kernel_size)

    def forward(self, x):
        moving_mean = []
        res = []
        for func in self.series_decomp:
            sea, moving_avg = func(x)
            moving_mean.append(moving_avg)
            res.append(sea)

        # Simple average across kernel sizes.
        sea = sum(res) / len(res)
        moving_mean = sum(moving_mean) / len(moving_mean)
        return sea, moving_mean
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class EncoderLayer(nn.Module):
    """Autoformer encoder layer with the progressive decomposition architecture.

    Auto-correlation attention followed by a position-wise conv feed-forward,
    with a series decomposition after each residual connection; the trend
    parts are discarded in the encoder.
    """

    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        # kernel_size=1 convolutions act as position-wise linear layers.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        attn_out, attn = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        # Residual + decomposition; keep only the seasonal part.
        x, _ = self.decomp1(x + self.dropout(attn_out))
        # Feed-forward over channels; Conv1d needs (batch, channels, time).
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        res, _ = self.decomp2(x + hidden)
        return res, attn
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class Encoder(nn.Module):
    """Autoformer encoder.

    A stack of attention layers, optionally interleaved with conv layers,
    followed by an optional normalization layer.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is None:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            # zip pairs each conv with its preceding attention layer; the
            # final attention layer then runs without a conv and without the
            # mask, matching the original control flow.
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class DecoderLayer(nn.Module):
    """Autoformer decoder layer with the progressive decomposition architecture.

    Self attention, cross attention and a conv feed-forward block, each
    followed by a series decomposition. The seasonal part flows on as x; the
    three trend parts are summed and projected to c_out channels.
    """

    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        # kernel_size=1 convolutions act as position-wise linear layers.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        # Circular padding keeps the projected trend the same length as the input.
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1,
                                    padding_mode='circular', bias=False)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        self_out = self.self_attention(
            x, x, x,
            attn_mask=x_mask
        )[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))
        cross_out = self.cross_attention(
            x, cross, cross,
            attn_mask=cross_mask
        )[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))
        # Feed-forward over channels; Conv1d needs (batch, channels, time).
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend3 = self.decomp3(x + hidden)

        # Sum the trends from all three decompositions, then project them.
        residual_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class Decoder(nn.Module):
    """Autoformer decoder.

    Stacks decoder layers, accumulating the residual trend emitted by each
    layer onto the running trend estimate, then applies the optional norm and
    output projection to the seasonal stream.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # Fix: the declared default trend=None previously crashed on the
            # `trend + residual_trend` addition; start the accumulator from
            # the first residual trend instead. Callers that pass an initial
            # trend tensor are unaffected.
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend