lwaekfjlk committed on
Commit
f9d3aeb
·
verified ·
1 Parent(s): 1031f83

Upload Time-Series-Library

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +16 -0
  2. Autoformer.csv +2 -0
  3. CONTRIBUTING.md +20 -0
  4. DLinear.csv +2 -0
  5. Informer.csv +2 -0
  6. LICENSE +21 -0
  7. README.md +173 -0
  8. Reformer.csv +2 -0
  9. data_provider/__init__.py +1 -0
  10. data_provider/calculate_window_len.py +127 -0
  11. data_provider/data_factory.py +88 -0
  12. data_provider/data_loader.py +1029 -0
  13. data_provider/load.py +31 -0
  14. data_provider/m4.py +141 -0
  15. data_provider/uea.py +125 -0
  16. dataset/m4/Daily-test.csv +0 -0
  17. dataset/m4/Daily-train.csv +3 -0
  18. dataset/m4/Hourly-test.csv +0 -0
  19. dataset/m4/Hourly-train.csv +0 -0
  20. dataset/m4/M4-info.csv +0 -0
  21. dataset/m4/Monthly-test.csv +0 -0
  22. dataset/m4/Monthly-train.csv +3 -0
  23. dataset/m4/Quarterly-test.csv +0 -0
  24. dataset/m4/Quarterly-train.csv +3 -0
  25. dataset/m4/Weekly-test.csv +360 -0
  26. dataset/m4/Weekly-train.csv +0 -0
  27. dataset/m4/Yearly-test.csv +0 -0
  28. dataset/m4/Yearly-train.csv +3 -0
  29. dataset/m4/submission-Naive2.csv +3 -0
  30. dataset/m4/test.npz +3 -0
  31. dataset/m4/training.npz +3 -0
  32. dataset/poly/polymarket_data_processed_Crypto_test.jsonl +3 -0
  33. dataset/poly/polymarket_data_processed_Election_test.jsonl +3 -0
  34. dataset/poly/polymarket_data_processed_Other_test.jsonl +3 -0
  35. dataset/poly/polymarket_data_processed_Politics_test.jsonl +3 -0
  36. dataset/poly/polymarket_data_processed_Sports_test.jsonl +3 -0
  37. dataset/poly/polymarket_data_processed_dev.jsonl +3 -0
  38. dataset/poly/polymarket_data_processed_test.jsonl +3 -0
  39. dataset/poly/polymarket_data_processed_train.jsonl +3 -0
  40. exp/__init__.py +0 -0
  41. exp/exp_anomaly_detection.py +207 -0
  42. exp/exp_basic.py +79 -0
  43. exp/exp_classification.py +191 -0
  44. exp/exp_imputation.py +228 -0
  45. exp/exp_long_term_forecasting.py +268 -0
  46. exp/exp_short_term_forecasting.py +302 -0
  47. kalshi_results/Autoformer/Autoformer_results.csv +2 -0
  48. kalshi_results/DLinear/DLinear_results.csv +2 -0
  49. layers/AutoCorrelation.py +163 -0
  50. layers/Autoformer_EncDec.py +203 -0
.gitattributes CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ dataset/m4/Daily-train.csv filter=lfs diff=lfs merge=lfs -text
37
+ dataset/m4/Monthly-train.csv filter=lfs diff=lfs merge=lfs -text
38
+ dataset/m4/Quarterly-train.csv filter=lfs diff=lfs merge=lfs -text
39
+ dataset/m4/Yearly-train.csv filter=lfs diff=lfs merge=lfs -text
40
+ dataset/m4/submission-Naive2.csv filter=lfs diff=lfs merge=lfs -text
41
+ dataset/poly/polymarket_data_processed_Crypto_test.jsonl filter=lfs diff=lfs merge=lfs -text
42
+ dataset/poly/polymarket_data_processed_Election_test.jsonl filter=lfs diff=lfs merge=lfs -text
43
+ dataset/poly/polymarket_data_processed_Other_test.jsonl filter=lfs diff=lfs merge=lfs -text
44
+ dataset/poly/polymarket_data_processed_Politics_test.jsonl filter=lfs diff=lfs merge=lfs -text
45
+ dataset/poly/polymarket_data_processed_Sports_test.jsonl filter=lfs diff=lfs merge=lfs -text
46
+ dataset/poly/polymarket_data_processed_dev.jsonl filter=lfs diff=lfs merge=lfs -text
47
+ dataset/poly/polymarket_data_processed_test.jsonl filter=lfs diff=lfs merge=lfs -text
48
+ dataset/poly/polymarket_data_processed_train.jsonl filter=lfs diff=lfs merge=lfs -text
49
+ pic/dataset.png filter=lfs diff=lfs merge=lfs -text
50
+ tutorial/conv.png filter=lfs diff=lfs merge=lfs -text
51
+ tutorial/fft.png filter=lfs diff=lfs merge=lfs -text
Autoformer.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
2
+ 0,0.39830628510682603,0.32073805520677084,0.3818132712824491,0.31051893144830756,0.3828605477887506,0.308403942223605,0.4182624238680359,0.3396595166127513,0.39639385459888354,0.3268426575370005,0.37177661055284744,0.2993196869106722
CONTRIBUTING.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Instructions for Contributing to TSlib
2
+
3
+ Sincerely thanks to all the researchers who want to use or contribute to TSlib.
4
+
5
+ Since our team may not have enough time to fix all the bugs and catch up with the latest model, your contribution is essential to this project.
6
+
7
+ ### (1) Fix Bug
8
+
9
+ You can directly propose a pull request and add detailed descriptions to the comment, such as [this pull request](https://github.com/thuml/Time-Series-Library/pull/498).
10
+
11
+ ### (2) Add a new time series model
12
+
13
+ Thanks to creative researchers, extensive great TS models are presented, which advance this community significantly. If you want to add your model to TSlib, here are some instructions:
14
+
15
+ - Propose an issue to describe your model and give a link to your paper and official code. We will discuss whether your model is suitable for this library, such as [this issue](https://github.com/thuml/Time-Series-Library/issues/346).
16
+ - Propose a pull request in a similar style as TSlib, which means adding an additional file to ./models and providing corresponding scripts for reproduction, such as [this pull request](https://github.com/thuml/Time-Series-Library/pull/446).
17
+
18
+ Note: Given that there are a lot of TS models that have been proposed, we may not have enough time to judge which model can be a remarkable supplement to the current library. Thus, we decide ONLY to add the officially published paper to our library. Peer review can be a reliable criterion.
19
+
20
+ Thanks again for your valuable contributions.
DLinear.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
2
+ 0,0.3961972827838396,0.31730998926485693,0.37965161394318636,0.3065756351649381,0.3828125705658356,0.3088032054398474,0.4141024259329349,0.33436751696330935,0.3924656681263207,0.3223706541409388,0.37120005835210995,0.29826033160423604
Informer.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
2
+ 0,0.3974158067085375,0.31855139807181687,0.381309328795267,0.30970417269120676,0.38429771883487124,0.3100599430647862,0.4150583272414749,0.33584398433823054,0.3975700532168396,0.3287726753077512,0.3721965853006681,0.2996654640008122
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 THUML @ Tsinghua University
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Time Series Library (TSLib)
2
+ TSLib is an open-source library for deep learning researchers, especially for deep time series analysis.
3
+
4
+ We provide a neat code base to evaluate advanced deep time series models or develop your model, which covers five mainstream tasks: **long- and short-term forecasting, imputation, anomaly detection, and classification.**
5
+
6
+ :triangular_flag_on_post:**News** (2024.10) We have included [[TimeXer]](https://arxiv.org/abs/2402.19072), which defined a practical forecasting paradigm: Forecasting with Exogenous Variables. Considering both practicability and computation efficiency, we believe the new forecasting paradigm defined in TimeXer can be the "right" task for future research.
7
+
8
+ :triangular_flag_on_post:**News** (2024.10) Our lab has open-sourced [[OpenLTM]](https://github.com/thuml/OpenLTM), which provides a distinct pretrain-finetuning paradigm compared to TSLib. If you are interested in Large Time Series Models, you may find this repository helpful.
9
+
10
+ :triangular_flag_on_post:**News** (2024.07) We wrote a comprehensive survey of [[Deep Time Series Models]](https://arxiv.org/abs/2407.13278) with a rigorous benchmark based on TSLib. In this paper, we summarized the design principles of current time series models supported by insightful experiments, hoping to be helpful to future research.
11
+
12
+ :triangular_flag_on_post:**News** (2024.04) Many thanks for the great work from [frecklebars](https://github.com/thuml/Time-Series-Library/pull/378). The famous sequential model [Mamba](https://arxiv.org/abs/2312.00752) has been included in our library. See [this file](https://github.com/thuml/Time-Series-Library/blob/main/models/Mamba.py), where you need to install `mamba_ssm` with pip at first.
13
+
14
+ :triangular_flag_on_post:**News** (2024.03) Given the inconsistent look-back length of various papers, we split the long-term forecasting in the leaderboard into two categories: Look-Back-96 and Look-Back-Searching. We recommend researchers read [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2), which includes both look-back length settings in experiments for scientific rigor.
15
+
16
+ :triangular_flag_on_post:**News** (2023.10) We add an implementation to [iTransformer](https://arxiv.org/abs/2310.06625), which is the state-of-the-art model for long-term forecasting. The official code and complete scripts of iTransformer can be found [here](https://github.com/thuml/iTransformer).
17
+
18
+ :triangular_flag_on_post:**News** (2023.09) We added a detailed [tutorial](https://github.com/thuml/Time-Series-Library/blob/main/tutorial/TimesNet_tutorial.ipynb) for [TimesNet](https://openreview.net/pdf?id=ju_Uqw384Oq) and this library, which is quite friendly to beginners of deep time series analysis.
19
+
20
+ :triangular_flag_on_post:**News** (2023.02) We release the TSlib as a comprehensive benchmark and code base for time series models, which is extended from our previous GitHub repository [Autoformer](https://github.com/thuml/Autoformer).
21
+
22
+ ## Leaderboard for Time Series Analysis
23
+
24
+ Till March 2024, the top three models for five different tasks are:
25
+
26
+ | Model<br>Ranking | Long-term<br>Forecasting<br>Look-Back-96 | Long-term<br/>Forecasting<br/>Look-Back-Searching | Short-term<br>Forecasting | Imputation | Classification | Anomaly<br>Detection |
27
+ | ---------------- | ----------------------------------------------------- | ----------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | -------------------------------------------------- |
28
+ | 🥇 1st | [TimeXer](https://arxiv.org/abs/2402.19072) | [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) | [TimesNet](https://arxiv.org/abs/2210.02186) |
29
+ | 🥈 2nd | [iTransformer](https://arxiv.org/abs/2310.06625) | [PatchTST](https://github.com/yuqinie98/PatchTST) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [Non-stationary<br/>Transformer](https://github.com/thuml/Nonstationary_Transformers) | [FEDformer](https://github.com/MAZiqing/FEDformer) |
30
+ | 🥉 3rd | [TimeMixer](https://openreview.net/pdf?id=7oLshfEIC2) | [DLinear](https://arxiv.org/pdf/2205.13504.pdf) | [FEDformer](https://github.com/MAZiqing/FEDformer) | [Autoformer](https://github.com/thuml/Autoformer) | [Informer](https://github.com/zhouhaoyi/Informer2020) | [Autoformer](https://github.com/thuml/Autoformer) |
31
+
32
+
33
+ **Note: We will keep updating this leaderboard.** If you have proposed advanced and awesome models, you can send us your paper/code link or raise a pull request. We will add them to this repo and update the leaderboard as soon as possible.
34
+
35
+ **Compared models of this leaderboard.** ☑ means that their codes have already been included in this repo.
36
+ - [x] **TimeXer** - TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables [[NeurIPS 2024]](https://arxiv.org/abs/2402.19072) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimeXer.py)
37
+ - [x] **TimeMixer** - TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting [[ICLR 2024]](https://openreview.net/pdf?id=7oLshfEIC2) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimeMixer.py).
38
+ - [x] **TSMixer** - TSMixer: An All-MLP Architecture for Time Series Forecasting [[arXiv 2023]](https://arxiv.org/pdf/2303.06053.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TSMixer.py)
39
+ - [x] **iTransformer** - iTransformer: Inverted Transformers Are Effective for Time Series Forecasting [[ICLR 2024]](https://arxiv.org/abs/2310.06625) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/iTransformer.py).
40
+ - [x] **PatchTST** - A Time Series is Worth 64 Words: Long-term Forecasting with Transformers [[ICLR 2023]](https://openreview.net/pdf?id=Jbdc0vTOcol) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/PatchTST.py).
41
+ - [x] **TimesNet** - TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis [[ICLR 2023]](https://openreview.net/pdf?id=ju_Uqw384Oq) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TimesNet.py).
42
+ - [x] **DLinear** - Are Transformers Effective for Time Series Forecasting? [[AAAI 2023]](https://arxiv.org/pdf/2205.13504.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/DLinear.py).
43
+ - [x] **LightTS** - Less Is More: Fast Multivariate Time Series Forecasting with Light Sampling-oriented MLP Structures [[arXiv 2022]](https://arxiv.org/abs/2207.01186) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/LightTS.py).
44
+ - [x] **ETSformer** - ETSformer: Exponential Smoothing Transformers for Time-series Forecasting [[arXiv 2022]](https://arxiv.org/abs/2202.01381) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/ETSformer.py).
45
+ - [x] **Non-stationary Transformer** - Non-stationary Transformers: Exploring the Stationarity in Time Series Forecasting [[NeurIPS 2022]](https://openreview.net/pdf?id=ucNDIDRNjjv) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Nonstationary_Transformer.py).
46
+ - [x] **FEDformer** - FEDformer: Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting [[ICML 2022]](https://proceedings.mlr.press/v162/zhou22g.html) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/FEDformer.py).
47
+ - [x] **Pyraformer** - Pyraformer: Low-complexity Pyramidal Attention for Long-range Time Series Modeling and Forecasting [[ICLR 2022]](https://openreview.net/pdf?id=0EXmFzUn5I) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Pyraformer.py).
48
+ - [x] **Autoformer** - Autoformer: Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting [[NeurIPS 2021]](https://openreview.net/pdf?id=I55UqU-M11y) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Autoformer.py).
49
+ - [x] **Informer** - Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting [[AAAI 2021]](https://ojs.aaai.org/index.php/AAAI/article/view/17325/17132) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Informer.py).
50
+ - [x] **Reformer** - Reformer: The Efficient Transformer [[ICLR 2020]](https://openreview.net/forum?id=rkgNKkHtvB) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Reformer.py).
51
+ - [x] **Transformer** - Attention is All You Need [[NeurIPS 2017]](https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Transformer.py).
52
+
53
+ See our latest paper [[TimesNet]](https://arxiv.org/abs/2210.02186) for the comprehensive benchmark. We will release a real-time updated online version soon.
54
+
55
+ **Newly added baselines.** We will add them to the leaderboard after a comprehensive evaluation.
56
+ - [x] **MultiPatchFormer** - A multiscale model for multivariate time series forecasting [[Scientific Reports 2025]](https://www.nature.com/articles/s41598-024-82417-4) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/MultiPatchFormer.py)
57
+ - [x] **WPMixer** - WPMixer: Efficient Multi-Resolution Mixing for Long-Term Time Series Forecasting [[AAAI 2025]](https://arxiv.org/abs/2412.17176) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/WPMixer.py)
58
+ - [x] **PAttn** - Are Language Models Actually Useful for Time Series Forecasting? [[NeurIPS 2024]](https://arxiv.org/pdf/2406.16964) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/PAttn.py)
59
+ - [x] **Mamba** - Mamba: Linear-Time Sequence Modeling with Selective State Spaces [[arXiv 2023]](https://arxiv.org/abs/2312.00752) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Mamba.py)
60
+ - [x] **SegRNN** - SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting [[arXiv 2023]](https://arxiv.org/abs/2308.11200.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/SegRNN.py).
61
+ - [x] **Koopa** - Koopa: Learning Non-stationary Time Series Dynamics with Koopman Predictors [[NeurIPS 2023]](https://arxiv.org/pdf/2305.18803.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Koopa.py).
62
+ - [x] **FreTS** - Frequency-domain MLPs are More Effective Learners in Time Series Forecasting [[NeurIPS 2023]](https://arxiv.org/pdf/2311.06184.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/FreTS.py).
63
+ - [x] **MICN** - MICN: Multi-scale Local and Global Context Modeling for Long-term Series Forecasting [[ICLR 2023]](https://openreview.net/pdf?id=zt53IDUR1U)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/MICN.py).
64
+ - [x] **Crossformer** - Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting [[ICLR 2023]](https://openreview.net/pdf?id=vSVLM2j9eie)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/Crossformer.py).
65
+ - [x] **TiDE** - Long-term Forecasting with TiDE: Time-series Dense Encoder [[arXiv 2023]](https://arxiv.org/pdf/2304.08424.pdf) [[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TiDE.py).
66
+ - [x] **SCINet** - SCINet: Time Series Modeling and Forecasting with Sample Convolution and Interaction [[NeurIPS 2022]](https://openreview.net/pdf?id=AyajSjTAzmg)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/SCINet.py).
67
+ - [x] **FiLM** - FiLM: Frequency improved Legendre Memory Model for Long-term Time Series Forecasting [[NeurIPS 2022]](https://openreview.net/forum?id=zTQdHSQUQWc)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/FiLM.py).
68
+ - [x] **TFT** - Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting [[arXiv 2019]](https://arxiv.org/abs/1912.09363)[[Code]](https://github.com/thuml/Time-Series-Library/blob/main/models/TemporalFusionTransformer.py).
69
+
70
+ ## Usage
71
+
72
+ 1. Install Python 3.8. For convenience, execute the following command.
73
+
74
+ ```
75
+ pip install -r requirements.txt
76
+ ```
77
+
78
+ 2. Prepare Data. You can obtain the well pre-processed datasets from [[Google Drive]](https://drive.google.com/drive/folders/13Cg1KYOlzM5C7K8gK8NfC-F3EYxkM3D2?usp=sharing) or [[Baidu Drive]](https://pan.baidu.com/s/1r3KhGd0Q9PJIUZdfEYoymg?pwd=i9iy), then place the downloaded data in the folder `./dataset`. Here is a summary of supported datasets.
79
+
80
+ <p align="center">
81
+ <img src="./pic/dataset.png" height = "200" alt="" align=center />
82
+ </p>
83
+
84
+ 3. Train and evaluate model. We provide the experiment scripts for all benchmarks under the folder `./scripts/`. You can reproduce the experiment results as the following examples:
85
+
86
+ ```
87
+ # long-term forecast
88
+ bash ./scripts/long_term_forecast/ETT_script/TimesNet_ETTh1.sh
89
+ # short-term forecast
90
+ bash ./scripts/short_term_forecast/TimesNet_M4.sh
91
+ # imputation
92
+ bash ./scripts/imputation/ETT_script/TimesNet_ETTh1.sh
93
+ # anomaly detection
94
+ bash ./scripts/anomaly_detection/PSM/TimesNet.sh
95
+ # classification
96
+ bash ./scripts/classification/TimesNet.sh
97
+ ```
98
+
99
+ 4. Develop your own model.
100
+
101
+ - Add the model file to the folder `./models`. You can follow the `./models/Transformer.py`.
102
+ - Include the newly added model in the `Exp_Basic.model_dict` of `./exp/exp_basic.py`.
103
+ - Create the corresponding scripts under the folder `./scripts`.
104
+
105
+ Note:
106
+
107
+ (1) About classification: Since we include all five tasks in a unified code base, the accuracy of each subtask may fluctuate but the average performance can be reproduced (even a bit better). We have provided the reproduced checkpoints [here](https://github.com/thuml/Time-Series-Library/issues/494).
108
+
109
+ (2) About anomaly detection: Some discussion about the adjustment strategy in anomaly detection can be found [here](https://github.com/thuml/Anomaly-Transformer/issues/14). The key point is that the adjustment strategy corresponds to an event-level metric.
110
+
111
+ ## Citation
112
+
113
+ If you find this repo useful, please cite our paper.
114
+
115
+ ```
116
+ @inproceedings{wu2023timesnet,
117
+ title={TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis},
118
+ author={Haixu Wu and Tengge Hu and Yong Liu and Hang Zhou and Jianmin Wang and Mingsheng Long},
119
+ booktitle={International Conference on Learning Representations},
120
+ year={2023},
121
+ }
122
+
123
+ @article{wang2024tssurvey,
124
+ title={Deep Time Series Models: A Comprehensive Survey and Benchmark},
125
+ author={Yuxuan Wang and Haixu Wu and Jiaxiang Dong and Yong Liu and Mingsheng Long and Jianmin Wang},
126
+ journal={arXiv preprint arXiv:2407.13278},
127
+ year={2024},
128
+ }
129
+ ```
130
+
131
+ ## Contact
132
+ If you have any questions or suggestions, feel free to contact our maintenance team:
133
+
134
+ Current:
135
+ - Haixu Wu (Ph.D. student, wuhx23@mails.tsinghua.edu.cn)
136
+ - Yong Liu (Ph.D. student, liuyong21@mails.tsinghua.edu.cn)
137
+ - Yuxuan Wang (Ph.D. student, wangyuxu22@mails.tsinghua.edu.cn)
138
+ - Huikun Weng (Undergraduate, wenghk22@mails.tsinghua.edu.cn)
139
+
140
+ Previous:
141
+ - Tengge Hu (Master student, htg21@mails.tsinghua.edu.cn)
142
+ - Haoran Zhang (Master student, z-hr20@mails.tsinghua.edu.cn)
143
+ - Jiawei Guo (Undergraduate, guo-jw21@mails.tsinghua.edu.cn)
144
+
145
+ Or describe it in Issues.
146
+
147
+ ## Acknowledgement
148
+
149
+ This project is supported by the National Key R&D Program of China (2021YFB1715200).
150
+
151
+ This library is constructed based on the following repos:
152
+
153
+ - Forecasting: https://github.com/thuml/Autoformer.
154
+
155
+ - Anomaly Detection: https://github.com/thuml/Anomaly-Transformer.
156
+
157
+ - Classification: https://github.com/thuml/Flowformer.
158
+
159
+ All the experiment datasets are public, and we obtain them from the following links:
160
+
161
+ - Long-term Forecasting and Imputation: https://github.com/thuml/Autoformer.
162
+
163
+ - Short-term Forecasting: https://github.com/ServiceNow/N-BEATS.
164
+
165
+ - Anomaly Detection: https://github.com/thuml/Anomaly-Transformer.
166
+
167
+ - Classification: https://www.timeseriesclassification.com/.
168
+
169
+ ## All Thanks To Our Contributors
170
+
171
+ <a href="https://github.com/thuml/Time-Series-Library/graphs/contributors">
172
+ <img src="https://contrib.rocks/image?repo=thuml/Time-Series-Library" />
173
+ </a>
Reformer.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,Politics_testrmse,Politics_testmae,Sports_testrmse,Sports_testmae,Crypto_testrmse,Crypto_testmae,Election_testrmse,Election_testmae,Other_testrmse,Other_testmae,testrmse,testmae
2
+ 0,0.39759935719372996,0.3176409474228964,0.3804696526717909,0.3091429407280413,0.38329537668949676,0.30926928011712945,0.4144651556154994,0.3346732748303968,0.39416373838022656,0.3253543121596952,0.3719121830741643,0.2993834586683424
data_provider/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
data_provider/calculate_window_len.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import Counter
3
+
4
def analyze_window_history_lengths(filepath):
    """Scan a JSONL file and tally the length distribution of every
    ``window_history`` list found under each record's ``daily_breakpoints``.

    Args:
        filepath: Path to the JSONL file to analyze (one JSON object per line).

    Returns:
        Tuple ``(length_counter, total_records, total_breakpoints,
        total_window_histories)``. ``length_counter`` is a ``Counter`` mapping
        window-history length -> occurrence count, or ``None`` when the file
        is missing or an unexpected error occurs (the three totals are then 0).
    """
    length_counter = Counter()
    total_records = 0
    total_breakpoints = 0
    total_window_histories = 0

    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Skip blank lines between records.
                    continue

                try:
                    data = json.loads(line)
                    total_records += 1

                    # Records without breakpoints are reported but not counted.
                    if 'daily_breakpoints' not in data:
                        print(f"警告: 第{line_num}行没有'daily_breakpoints'字段")
                        continue

                    daily_breakpoints = data['daily_breakpoints']

                    # NOTE: loop variable renamed from `breakpoint` to avoid
                    # shadowing the `breakpoint` builtin.
                    for i, bp in enumerate(daily_breakpoints):
                        total_breakpoints += 1

                        if 'window_history' not in bp:
                            print(f"警告: 第{line_num}行的breakpoint[{i}]没有'window_history'字段")
                            continue

                        window_history = bp['window_history']

                        # Only list-valued histories contribute to the tally.
                        if isinstance(window_history, list):
                            length_counter[len(window_history)] += 1
                            total_window_histories += 1
                        else:
                            print(f"警告: 第{line_num}行的breakpoint[{i}]的'window_history'不是列表")

                except json.JSONDecodeError as e:
                    # Malformed line: report it and keep scanning the rest.
                    print(f"JSON解析错误在第{line_num}行: {e}")
                    continue

        return length_counter, total_records, total_breakpoints, total_window_histories

    except FileNotFoundError:
        print(f"文件未找到: {filepath}")
        return None, 0, 0, 0
    except Exception as e:
        # Best-effort analysis tool: report the error and return the
        # sentinel result instead of raising.
        print(f"发生错误: {e}")
        return None, 0, 0, 0
68
+
69
def print_statistics(length_counter, total_records, total_breakpoints, total_window_histories):
    """Print a human-readable report of the window-history length stats.

    Emits an overall summary, a per-length histogram with a text bar chart,
    and min/max/mean/median figures. When ``length_counter`` is empty, only
    the summary plus a "no data" notice is printed.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("统计摘要:")
    print(rule)
    print(f"总JSON记录数: {total_records}")
    print(f"总breakpoints数: {total_breakpoints}")
    print(f"总window_history数: {total_window_histories}")

    if not length_counter:
        print("\n没有找到任何window_history数据")
        return

    print("\n" + rule)
    print("Window History 长度分布:")
    print(rule)
    print(f"{'长度':<10} {'数量':<10} {'百分比':<10} {'分布图'}")
    print("-" * 60)

    ordered = sorted(length_counter.items())
    for length, count in ordered:
        pct = count / total_window_histories * 100
        # One block glyph per 2% of the total.
        print(f"{length:<10} {count:<10} {pct:>6.2f}% {'█' * int(pct / 2)}")

    print("\n" + rule)
    print("统计信息:")
    print(rule)
    print(f"最小长度: {ordered[0][0]}")
    print(f"最大长度: {ordered[-1][0]}")

    mean = sum(length * count for length, count in ordered) / total_window_histories
    print(f"平均长度: {mean:.2f}")

    # Median = element at index N//2 of the expanded sorted sequence,
    # located via a cumulative walk instead of materializing the list.
    target = sum(length_counter.values()) // 2
    seen = 0
    for length, count in ordered:
        if seen + count > target:
            print(f"中位数长度: {length}")
            break
        seen += count

    print(rule)
116
+
117
# Example usage: analyze one JSONL file and print its statistics.
if __name__ == "__main__":
    # Replace with the path to your own JSONL file.
    filepath = "/data/haofeiy2/social-world-model/data/splitted_polymarket/polymarket_data_processed_with_news_train_2024-11-01.jsonl"

    print(f"正在分析文件: {filepath}")
    stats = analyze_window_history_lengths(filepath)
    length_counter, total_records, total_breakpoints, total_window_histories = stats

    # A None counter signals that the file was missing or unreadable.
    if length_counter is not None:
        print_statistics(length_counter, total_records, total_breakpoints, total_window_histories)
data_provider/data_factory.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_M4, PSMSegLoader, \
2
+ MSLSegLoader, SMAPSegLoader, SMDSegLoader, SWATSegLoader, UEAloader, Dataset_Poly, Dataset_Kalshi
3
+ from data_provider.uea import collate_fn
4
+ from torch.utils.data import DataLoader
5
+
6
# Maps the value of the ``--data`` experiment argument to the Dataset class
# that loads it; consumed by data_provider() below. Keys must match the
# dataset names used in the run scripts.
data_dict = {
    'ETTh1': Dataset_ETT_hour,
    'ETTh2': Dataset_ETT_hour,
    'ETTm1': Dataset_ETT_minute,
    'ETTm2': Dataset_ETT_minute,
    'custom': Dataset_Custom,
    'm4': Dataset_M4,
    'poly': Dataset_Poly,
    'kalshi': Dataset_Kalshi,
    'PSM': PSMSegLoader,
    'MSL': MSLSegLoader,
    'SMAP': SMAPSegLoader,
    'SMD': SMDSegLoader,
    'SWAT': SWATSegLoader,
    'UEA': UEAloader
}
23
+
24
+ def data_provider(args, flag):
25
+ Data = data_dict[args.data]
26
+ timeenc = 0 if args.embed != 'timeF' else 1
27
+
28
+ shuffle_flag = False if (flag == 'test' or flag == 'TEST') else True
29
+ drop_last = False
30
+ batch_size = args.batch_size
31
+ freq = args.freq
32
+
33
+ if args.task_name == 'anomaly_detection':
34
+ drop_last = False
35
+ data_set = Data(
36
+ args = args,
37
+ root_path=args.root_path,
38
+ win_size=args.seq_len,
39
+ flag=flag,
40
+ )
41
+ print(flag, len(data_set))
42
+ data_loader = DataLoader(
43
+ data_set,
44
+ batch_size=batch_size,
45
+ shuffle=shuffle_flag,
46
+ num_workers=args.num_workers,
47
+ drop_last=drop_last)
48
+ return data_set, data_loader
49
+ elif args.task_name == 'classification':
50
+ drop_last = False
51
+ data_set = Data(
52
+ args = args,
53
+ root_path=args.root_path,
54
+ flag=flag,
55
+ )
56
+
57
+ data_loader = DataLoader(
58
+ data_set,
59
+ batch_size=batch_size,
60
+ shuffle=shuffle_flag,
61
+ num_workers=args.num_workers,
62
+ drop_last=drop_last,
63
+ collate_fn=lambda x: collate_fn(x, max_len=args.seq_len)
64
+ )
65
+ return data_set, data_loader
66
+ else:
67
+ if args.data == 'm4':
68
+ drop_last = False
69
+ data_set = Data(
70
+ args = args,
71
+ root_path=args.root_path,
72
+ data_path=args.data_path,
73
+ flag=flag,
74
+ size=[args.seq_len, args.label_len, args.pred_len],
75
+ features=args.features,
76
+ target=args.target,
77
+ timeenc=timeenc,
78
+ freq=freq,
79
+ seasonal_patterns=args.seasonal_patterns
80
+ )
81
+ print(flag, len(data_set))
82
+ data_loader = DataLoader(
83
+ data_set,
84
+ batch_size=batch_size,
85
+ shuffle=shuffle_flag,
86
+ num_workers=args.num_workers,
87
+ drop_last=drop_last)
88
+ return data_set, data_loader
data_provider/data_loader.py ADDED
@@ -0,0 +1,1029 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import glob
5
+ import re
6
+ import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from sklearn.preprocessing import StandardScaler
9
+ from utils.timefeatures import time_features
10
+ from data_provider.m4 import M4Dataset, M4Meta
11
+ from data_provider.uea import subsample, interpolate_missing, Normalizer
12
+ from sktime.datasets import load_from_tsfile_to_dataframe
13
+ import warnings
14
+ from utils.augmentation import run_augmentation_single
15
+ import json
16
+
17
+ warnings.filterwarnings('ignore')
18
+
19
+
20
class Dataset_ETT_hour(Dataset):
    """Sliding-window forecasting dataset for the hourly ETT files (ETTh1/ETTh2).

    Splits follow the standard ETT protocol: first 12 months train, next 4
    months validation, following 4 months test (30-day months, hourly steps).

    Each item is (seq_x, seq_y, seq_x_mark, seq_y_mark): seq_x is a
    ``seq_len`` input window, seq_y spans ``label_len + pred_len`` steps
    starting ``label_len`` steps before the end of seq_x, and the *_mark
    arrays carry the matching calendar/time features.
    """

    def __init__(self, args, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
        # size: [seq_len, label_len, pred_len]
        self.args = args
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features  # 'M'/'MS': all columns; 'S': target column only
        self.target = target
        self.scale = scale
        self.timeenc = timeenc  # 0: integer calendar features, 1: time_features()
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Load the CSV, slice split borders, scale, and build time marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # 12 / 4 / 4 month borders; val/test windows reach back seq_len steps
        # so their first samples have full history.
        border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
        border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit the scaler on the train split only to avoid leakage.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # FIX: the previous code passed a positional `1` to Series.apply
            # (the deprecated `convert_dtype` slot) and a positional axis to
            # DataFrame.drop, which is removed in pandas >= 2.0. Use keyword
            # forms instead; behavior is unchanged on older pandas.
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

        if self.set_type == 0 and self.args.augmentation_ratio > 0:
            self.data_x, self.data_y, augmentation_tags = run_augmentation_single(
                self.data_x, self.data_y, self.args)

        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        # Decoder input overlaps the last label_len steps of the encoder input.
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of valid window start positions.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original data scale."""
        return self.scaler.inverse_transform(data)
111
+
112
+
113
class Dataset_ETT_minute(Dataset):
    """Sliding-window forecasting dataset for the 15-minute ETT files (ETTm1/ETTm2).

    Same 12/4/4-month split protocol as the hourly variant, but with four
    steps per hour, and an extra quarter-hour calendar feature.
    """

    def __init__(self, args, root_path, flag='train', size=None,
                 features='S', data_path='ETTm1.csv',
                 target='OT', scale=True, timeenc=0, freq='t', seasonal_patterns=None):
        # size: [seq_len, label_len, pred_len]
        self.args = args
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features  # 'M'/'MS': all columns; 'S': target column only
        self.target = target
        self.scale = scale
        self.timeenc = timeenc  # 0: integer calendar features, 1: time_features()
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Load the CSV, slice split borders, scale, and build time marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        # Hourly borders scaled by 4 (15-minute resolution).
        border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
        border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit the scaler on the train split only to avoid leakage.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # FIX: the previous code passed a positional `1` to Series.apply
            # (the deprecated `convert_dtype` slot) and a positional axis to
            # DataFrame.drop, which is removed in pandas >= 2.0. Use keyword
            # forms instead; behavior is unchanged on older pandas.
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
            df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute)
            # Bucket minutes into quarter-hour indices 0..3.
            df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15)
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

        if self.set_type == 0 and self.args.augmentation_ratio > 0:
            self.data_x, self.data_y, augmentation_tags = run_augmentation_single(
                self.data_x, self.data_y, self.args)

        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        # Decoder input overlaps the last label_len steps of the encoder input.
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of valid window start positions.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original data scale."""
        return self.scaler.inverse_transform(data)
206
+
207
+
208
class Dataset_Custom(Dataset):
    """Sliding-window forecasting dataset for generic CSVs with a 'date' column.

    Columns are reordered to ['date', ...other features..., target] and the
    rows are split 70% train / 10% val / 20% test in temporal order.
    """

    def __init__(self, args, root_path, flag='train', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=True, timeenc=0, freq='h', seasonal_patterns=None):
        # size: [seq_len, label_len, pred_len]
        self.args = args
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len = size[0]
            self.label_len = size[1]
            self.pred_len = size[2]
        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]

        self.features = features  # 'M'/'MS': all columns; 'S': target column only
        self.target = target
        self.scale = scale
        self.timeenc = timeenc  # 0: integer calendar features, 1: time_features()
        self.freq = freq

        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

    def __read_data__(self):
        """Load the CSV, reorder columns, slice split borders, scale, build marks."""
        self.scaler = StandardScaler()
        df_raw = pd.read_csv(os.path.join(self.root_path,
                                          self.data_path))

        '''
        df_raw.columns: ['date', ...(other features), target feature]
        '''
        cols = list(df_raw.columns)
        cols.remove(self.target)
        cols.remove('date')
        df_raw = df_raw[['date'] + cols + [self.target]]
        # 70/10/20 chronological split; val/test windows reach back seq_len
        # steps so their first samples have full history.
        num_train = int(len(df_raw) * 0.7)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]

        if self.features == 'M' or self.features == 'MS':
            cols_data = df_raw.columns[1:]
            df_data = df_raw[cols_data]
        elif self.features == 'S':
            df_data = df_raw[[self.target]]

        if self.scale:
            # Fit the scaler on the train split only to avoid leakage.
            train_data = df_data[border1s[0]:border2s[0]]
            self.scaler.fit(train_data.values)
            data = self.scaler.transform(df_data.values)
        else:
            data = df_data.values

        df_stamp = df_raw[['date']][border1:border2]
        df_stamp['date'] = pd.to_datetime(df_stamp.date)
        if self.timeenc == 0:
            # FIX: the previous code passed a positional `1` to Series.apply
            # (the deprecated `convert_dtype` slot) and a positional axis to
            # DataFrame.drop, which is removed in pandas >= 2.0. Use keyword
            # forms instead; behavior is unchanged on older pandas.
            df_stamp['month'] = df_stamp.date.apply(lambda row: row.month)
            df_stamp['day'] = df_stamp.date.apply(lambda row: row.day)
            df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday())
            df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour)
            data_stamp = df_stamp.drop(columns=['date']).values
        elif self.timeenc == 1:
            data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)

        self.data_x = data[border1:border2]
        self.data_y = data[border1:border2]

        if self.set_type == 0 and self.args.augmentation_ratio > 0:
            self.data_x, self.data_y, augmentation_tags = run_augmentation_single(
                self.data_x, self.data_y, self.args)

        self.data_stamp = data_stamp

    def __getitem__(self, index):
        s_begin = index
        s_end = s_begin + self.seq_len
        # Decoder input overlaps the last label_len steps of the encoder input.
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len

        seq_x = self.data_x[s_begin:s_end]
        seq_y = self.data_y[r_begin:r_end]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]

        return seq_x, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        # Number of valid window start positions.
        return len(self.data_x) - self.seq_len - self.pred_len + 1

    def inverse_transform(self, data):
        """Map scaled values back to the original data scale."""
        return self.scaler.inverse_transform(data)
309
+
310
+
311
class Dataset_M4(Dataset):
    """M4 competition dataset with random-window sampling.

    Each ``__getitem__`` draws a random cut point inside the last
    ``history_size * pred_len`` steps of one series and returns zero-padded,
    masked insample/outsample windows around it.
    """

    def __init__(self, args, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=False, inverse=False, timeenc=0, freq='15min',
                 seasonal_patterns='Yearly'):
        # size: [seq_len, label_len, pred_len]
        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.root_path = root_path

        self.seq_len = size[0]
        self.label_len = size[1]
        self.pred_len = size[2]

        self.seasonal_patterns = seasonal_patterns
        self.history_size = M4Meta.history_size[seasonal_patterns]
        # Random cut points are limited to this many steps from the series end.
        self.window_sampling_limit = int(self.history_size * self.pred_len)
        self.flag = flag

        self.__read_data__()

    def __read_data__(self):
        """Load the requested split and keep only the selected frequency group."""
        if self.flag == 'train':
            dataset = M4Dataset.load(training=True, dataset_file=self.root_path)
        else:
            dataset = M4Dataset.load(training=False, dataset_file=self.root_path)
        # Strip NaN padding; series share a frequency group but not a length.
        training_values = np.array(
            [v[~np.isnan(v)] for v in
             dataset.values[dataset.groups == self.seasonal_patterns]])  # split different frequencies
        self.ids = np.array([i for i in dataset.ids[dataset.groups == self.seasonal_patterns]])
        self.timeseries = [ts for ts in training_values]
        # FIX: removed a leftover `import pdb; pdb.set_trace()` debug
        # breakpoint that halted execution on every data load.

    def __getitem__(self, index):
        """Sample one random (insample, outsample) window pair from series `index`."""
        insample = np.zeros((self.seq_len, 1))
        insample_mask = np.zeros((self.seq_len, 1))
        outsample = np.zeros((self.pred_len + self.label_len, 1))
        outsample_mask = np.zeros((self.pred_len + self.label_len, 1))  # m4 dataset

        sampled_timeseries = self.timeseries[index]
        # Random cut point within the allowed sampling window; note this makes
        # training samples non-deterministic by design.
        cut_point = np.random.randint(low=max(1, len(sampled_timeseries) - self.window_sampling_limit),
                                      high=len(sampled_timeseries),
                                      size=1)[0]

        # Right-align the history before the cut point; mask marks real values.
        insample_window = sampled_timeseries[max(0, cut_point - self.seq_len):cut_point]
        insample[-len(insample_window):, 0] = insample_window
        insample_mask[-len(insample_window):, 0] = 1.0
        outsample_window = sampled_timeseries[
            max(0, cut_point - self.label_len):min(len(sampled_timeseries), cut_point + self.pred_len)]
        outsample[:len(outsample_window), 0] = outsample_window
        outsample_mask[:len(outsample_window), 0] = 1.0
        return insample, outsample, insample_mask, outsample_mask

    def __len__(self):
        return len(self.timeseries)

    def inverse_transform(self, data):
        # NOTE(review): this class never creates a scaler (scale defaults to
        # False), so calling this raises AttributeError; kept unchanged for
        # interface parity with the other datasets.
        return self.scaler.inverse_transform(data)

    def last_insample_window(self):
        """
        The last window of insample size of all timeseries.
        This function does not support batching and does not reshuffle timeseries.

        :return: Last insample window of all timeseries. Shape "timeseries, insample size"
        """
        insample = np.zeros((len(self.timeseries), self.seq_len))
        insample_mask = np.zeros((len(self.timeseries), self.seq_len))
        for i, ts in enumerate(self.timeseries):
            ts_last_window = ts[-self.seq_len:]
            # -len(ts) clamps to the row start when ts is longer than seq_len,
            # so the seq_len-long window still fills the whole row.
            insample[i, -len(ts):] = ts_last_window
            insample_mask[i, -len(ts):] = 1.0
        return insample, insample_mask
390
+
391
+
392
+
393
+
394
class Dataset_Poly(Dataset):
    """Polymarket prediction-market price dataset loaded from JSONL files.

    Each JSONL record carries 'daily_breakpoints', and each breakpoint's
    'window_history' is a list of {'p': price} entries. Every sufficiently
    long window becomes one sample: the last seq_len+1 prices are split into
    a seq_len input window and the final point(s) as the target.
    """

    def __init__(self, args, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=False, inverse=False, timeenc=0, freq='15min',
                 seasonal_patterns='Yearly'):
        self.args = args
        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.root_path = root_path
        self.flag = flag

        # Read the window sizes from `size` or fall back to args.
        # size = [seq_len, label_len, pred_len]
        if size is not None:
            self.seq_len = size[0]  # suggested: 16
            self.label_len = size[1]  # suggested: 1
            self.pred_len = size[2]  # suggested: 1
        else:
            self.seq_len = getattr(args, 'seq_len', 16)
            self.label_len = getattr(args, 'label_len', 1)
            self.pred_len = getattr(args, 'pred_len', 1)

        self.__read_data__()

    def __read_data__(self):
        """Resolve the split's JSONL file and load all qualifying price windows."""
        # Base directories (currently both point at root_path).
        base_path = self.root_path
        category_path = self.root_path

        # flag -> (directory, file name) mapping.
        file_map = {
            'train': (base_path, 'polymarket_data_processed_with_news_train_2025-11-01.jsonl'),
            'val': (base_path, 'polymarket_data_processed_with_news_train_2025-11-01.jsonl'),
            'test': (base_path, 'polymarket_data_processed_with_news_test_2025-11-01.jsonl'),
            # Per-category test sets (kept separate in case they move to
            # another directory later).
            'test_Crypto': (category_path, 'polymarket_data_processed_with_news_test_2025-11-01_crypto.jsonl'),
            'test_Politics': (category_path, 'polymarket_data_processed_with_news_test_2025-11-01_politics.jsonl'),
            'test_Election': (category_path, 'polymarket_data_processed_with_news_test_2025-11-01_election.jsonl'),
        }

        self.timeseries = []

        # Resolve the file path for this split; unknown flags fall back to a
        # templated file name.
        if self.flag in file_map:
            dir_path, file_name = file_map[self.flag]
            file_path = os.path.join(dir_path, file_name)
        else:
            file_path = os.path.join(base_path, f'polymarket_data_processed_with_news_{self.flag}_2025-11-01.jsonl')

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Data file not found: {file_path}")

        all_samples = []
        skipped = 0
        total = 0

        with open(file_path, 'r') as fcc_file:
            for line in fcc_file:
                obj = json.loads(line)

                if 'daily_breakpoints' not in obj:
                    continue

                for bp in obj['daily_breakpoints']:
                    total += 1
                    window_history = bp.get('window_history', [])

                    # Need seq_len input points plus at least one target point.
                    if len(window_history) < self.seq_len + 1:
                        skipped += 1
                        continue

                    prices = [e['p'] for e in window_history]
                    all_samples.append(prices)

        # train/val share one source file and are split 80/20 here; note the
        # split is positional, not shuffled.
        if self.flag == 'train':
            split_idx = int(len(all_samples) * 0.8)
            self.timeseries = all_samples[:split_idx]
        elif self.flag == 'val':
            split_idx = int(len(all_samples) * 0.8)
            self.timeseries = all_samples[split_idx:]
        else:
            self.timeseries = all_samples

        print(f"[{self.flag}] Loaded {len(self.timeseries)} samples from {os.path.basename(file_path)}")
        print(f"[{self.flag}] Skipped {skipped}/{total} (seq_len requirement: {self.seq_len + 1})")

    def __getitem__(self, index):
        """Return (insample, outsample, insample_mask, outsample_mask) for one window."""
        sampled_timeseries = self.timeseries[index]

        # ========== insample: (seq_len, 1) ==========
        # Take the last seq_len+1 points; the first seq_len of those are the
        # input (e.g. a 17-point window yields points [-17:-1] as input).
        insample = np.zeros((self.seq_len, 1), dtype=np.float32)
        insample_mask = np.zeros((self.seq_len, 1), dtype=np.float32)

        # Input: from the (seq_len+1)-th-from-last point up to, but excluding,
        # the final ("after") point.
        input_prices = sampled_timeseries[-(self.seq_len + 1):-1]
        insample[:, 0] = input_prices
        insample_mask[:, 0] = 1.0

        # ========== outsample: (label_len + pred_len, 1) ==========
        # With label_len=1 and pred_len=1 this is the [before, after] pair.
        outsample = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)
        outsample_mask = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)

        # outsample = the last label_len + pred_len points of the window.
        outsample[:, 0] = sampled_timeseries[-(self.label_len + self.pred_len):]
        outsample_mask[:, 0] = 1.0

        return insample, outsample, insample_mask, outsample_mask

    def __len__(self):
        return len(self.timeseries)

    def inverse_transform(self, data):
        # No scaler is created by this class; the guard makes this a no-op
        # unless one is attached externally.
        if hasattr(self, 'scaler') and self.scaler is not None:
            return self.scaler.inverse_transform(data)
        return data

    def last_insample_window(self):
        """Return the last input window of every series, for inference."""
        insample = np.zeros((len(self.timeseries), self.seq_len), dtype=np.float32)
        insample_mask = np.zeros((len(self.timeseries), self.seq_len), dtype=np.float32)

        for i, ts in enumerate(self.timeseries):
            input_prices = ts[-(self.seq_len + 1):-1]
            insample[i, :] = input_prices
            insample_mask[i, :] = 1.0

        return insample, insample_mask
529
+
530
+
531
class PSMSegLoader(Dataset):
    """Sliding-window loader for the PSM anomaly-detection benchmark.

    Train/val windows are paired with a dummy label slice (labels are only
    meaningful for the test split); any flag other than train/val/test yields
    non-overlapping windows over the test set (thresholding mode).
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        raw_train = np.nan_to_num(pd.read_csv(os.path.join(root_path, 'train.csv')).values[:, 1:])
        self.scaler.fit(raw_train)
        raw_test = np.nan_to_num(pd.read_csv(os.path.join(root_path, 'test.csv')).values[:, 1:])
        self.test = self.scaler.transform(raw_test)
        self.train = self.scaler.transform(raw_train)
        # Validation reuses the last 20% of the (scaled) training data.
        self.val = self.train[int(len(self.train) * 0.8):]
        self.test_labels = pd.read_csv(os.path.join(root_path, 'test_label.csv')).values[:, 1:]
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        """Number of windows available for the active split."""
        if self.flag == "train":
            length = self.train.shape[0]
        elif self.flag == 'val':
            length = self.val.shape[0]
        else:
            length = self.test.shape[0]
        # Thresholding mode strides by a full window instead of `step`.
        stride = self.step if self.flag in ("train", "val", "test") else self.win_size
        return (length - self.win_size) // stride + 1

    def __getitem__(self, index):
        start = index * self.step
        w = self.win_size
        if self.flag == "train":
            return np.float32(self.train[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'val':
            return np.float32(self.val[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'test':
            return np.float32(self.test[start:start + w]), np.float32(self.test_labels[start:start + w])
        # Thresholding mode: non-overlapping, window-aligned blocks.
        block = start // self.step * w
        return np.float32(self.test[block:block + w]), np.float32(self.test_labels[block:block + w])
576
+
577
+
578
class MSLSegLoader(Dataset):
    """Sliding-window loader for the MSL anomaly-detection benchmark.

    Train/val windows are paired with a dummy label slice (labels are only
    meaningful for the test split); any flag other than train/val/test yields
    non-overlapping windows over the test set (thresholding mode).
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        raw_train = np.load(os.path.join(root_path, "MSL_train.npy"))
        self.scaler.fit(raw_train)
        raw_test = np.load(os.path.join(root_path, "MSL_test.npy"))
        self.test = self.scaler.transform(raw_test)
        self.train = self.scaler.transform(raw_train)
        # Validation reuses the last 20% of the (scaled) training data.
        self.val = self.train[int(len(self.train) * 0.8):]
        self.test_labels = np.load(os.path.join(root_path, "MSL_test_label.npy"))
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        """Number of windows available for the active split."""
        if self.flag == "train":
            length = self.train.shape[0]
        elif self.flag == 'val':
            length = self.val.shape[0]
        else:
            length = self.test.shape[0]
        # Thresholding mode strides by a full window instead of `step`.
        stride = self.step if self.flag in ("train", "val", "test") else self.win_size
        return (length - self.win_size) // stride + 1

    def __getitem__(self, index):
        start = index * self.step
        w = self.win_size
        if self.flag == "train":
            return np.float32(self.train[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'val':
            return np.float32(self.val[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'test':
            return np.float32(self.test[start:start + w]), np.float32(self.test_labels[start:start + w])
        # Thresholding mode: non-overlapping, window-aligned blocks.
        block = start // self.step * w
        return np.float32(self.test[block:block + w]), np.float32(self.test_labels[block:block + w])
619
+
620
+
621
class SMAPSegLoader(Dataset):
    """Sliding-window loader for the SMAP anomaly-detection benchmark.

    Train/val windows are paired with a dummy label slice (labels are only
    meaningful for the test split); any flag other than train/val/test yields
    non-overlapping windows over the test set (thresholding mode).
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        raw_train = np.load(os.path.join(root_path, "SMAP_train.npy"))
        self.scaler.fit(raw_train)
        raw_test = np.load(os.path.join(root_path, "SMAP_test.npy"))
        self.test = self.scaler.transform(raw_test)
        self.train = self.scaler.transform(raw_train)
        # Validation reuses the last 20% of the (scaled) training data.
        self.val = self.train[int(len(self.train) * 0.8):]
        self.test_labels = np.load(os.path.join(root_path, "SMAP_test_label.npy"))
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        """Number of windows available for the active split."""
        if self.flag == "train":
            length = self.train.shape[0]
        elif self.flag == 'val':
            length = self.val.shape[0]
        else:
            length = self.test.shape[0]
        # Thresholding mode strides by a full window instead of `step`.
        stride = self.step if self.flag in ("train", "val", "test") else self.win_size
        return (length - self.win_size) // stride + 1

    def __getitem__(self, index):
        start = index * self.step
        w = self.win_size
        if self.flag == "train":
            return np.float32(self.train[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'val':
            return np.float32(self.val[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'test':
            return np.float32(self.test[start:start + w]), np.float32(self.test_labels[start:start + w])
        # Thresholding mode: non-overlapping, window-aligned blocks.
        block = start // self.step * w
        return np.float32(self.test[block:block + w]), np.float32(self.test_labels[block:block + w])
663
+
664
+
665
class SMDSegLoader(Dataset):
    """Sliding-window loader for the SMD anomaly-detection benchmark.

    Uses a larger default stride (step=100) than the other segment loaders.
    Train/val windows are paired with a dummy label slice (labels are only
    meaningful for the test split); any flag other than train/val/test yields
    non-overlapping windows over the test set (thresholding mode).
    """

    def __init__(self, args, root_path, win_size, step=100, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()
        raw_train = np.load(os.path.join(root_path, "SMD_train.npy"))
        self.scaler.fit(raw_train)
        raw_test = np.load(os.path.join(root_path, "SMD_test.npy"))
        self.test = self.scaler.transform(raw_test)
        self.train = self.scaler.transform(raw_train)
        # Validation reuses the last 20% of the (scaled) training data.
        self.val = self.train[int(len(self.train) * 0.8):]
        self.test_labels = np.load(os.path.join(root_path, "SMD_test_label.npy"))

    def __len__(self):
        """Number of windows available for the active split."""
        if self.flag == "train":
            length = self.train.shape[0]
        elif self.flag == 'val':
            length = self.val.shape[0]
        else:
            length = self.test.shape[0]
        # Thresholding mode strides by a full window instead of `step`.
        stride = self.step if self.flag in ("train", "val", "test") else self.win_size
        return (length - self.win_size) // stride + 1

    def __getitem__(self, index):
        start = index * self.step
        w = self.win_size
        if self.flag == "train":
            return np.float32(self.train[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'val':
            return np.float32(self.val[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'test':
            return np.float32(self.test[start:start + w]), np.float32(self.test_labels[start:start + w])
        # Thresholding mode: non-overlapping, window-aligned blocks.
        block = start // self.step * w
        return np.float32(self.test[block:block + w]), np.float32(self.test_labels[block:block + w])
704
+
705
+
706
class SWATSegLoader(Dataset):
    """Sliding-window loader for the SWaT anomaly-detection benchmark.

    The last CSV column is the anomaly label; the remaining columns are
    features. Train/val windows are paired with a dummy label slice (labels
    are only meaningful for the test split); any flag other than
    train/val/test yields non-overlapping windows over the test set
    (thresholding mode).
    """

    def __init__(self, args, root_path, win_size, step=1, flag="train"):
        self.flag = flag
        self.step = step
        self.win_size = win_size
        self.scaler = StandardScaler()

        train_df = pd.read_csv(os.path.join(root_path, 'swat_train2.csv'))
        test_df = pd.read_csv(os.path.join(root_path, 'swat2.csv'))
        labels = test_df.values[:, -1:]          # last column = anomaly label
        train_arr = train_df.values[:, :-1]
        test_arr = test_df.values[:, :-1]

        self.scaler.fit(train_arr)
        self.train = self.scaler.transform(train_arr)
        self.test = self.scaler.transform(test_arr)
        # Validation reuses the last 20% of the (scaled) training data.
        self.val = self.train[int(len(self.train) * 0.8):]
        self.test_labels = labels
        print("test:", self.test.shape)
        print("train:", self.train.shape)

    def __len__(self):
        """Number of windows available for the active split."""
        if self.flag == "train":
            length = self.train.shape[0]
        elif self.flag == 'val':
            length = self.val.shape[0]
        else:
            length = self.test.shape[0]
        # Thresholding mode strides by a full window instead of `step`.
        stride = self.step if self.flag in ("train", "val", "test") else self.win_size
        return (length - self.win_size) // stride + 1

    def __getitem__(self, index):
        start = index * self.step
        w = self.win_size
        if self.flag == "train":
            return np.float32(self.train[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'val':
            return np.float32(self.val[start:start + w]), np.float32(self.test_labels[0:w])
        if self.flag == 'test':
            return np.float32(self.test[start:start + w]), np.float32(self.test_labels[start:start + w])
        # Thresholding mode: non-overlapping, window-aligned blocks.
        block = start // self.step * w
        return np.float32(self.test[block:block + w]), np.float32(self.test_labels[block:block + w])
756
+
757
+
758
class UEAloader(Dataset):
    """
    Dataset class for datasets included in:
        Time Series Classification Archive (www.timeseriesclassification.com)
    Argument:
        limit_size: float in (0, 1) for debug
    Attributes:
        all_df: (num_samples * seq_len, num_columns) dataframe indexed by integer indices, with multiple rows corresponding to the same index (sample).
            Each row is a time step; Each column contains either metadata (e.g. timestamp) or a feature.
        feature_df: (num_samples * seq_len, feat_dim) dataframe; contains the subset of columns of `all_df` which correspond to selected features
        feature_names: names of columns contained in `feature_df` (same as feature_df.columns)
        all_IDs: (num_samples,) series of IDs contained in `all_df`/`feature_df` (same as all_df.index.unique() )
        labels_df: (num_samples, num_labels) pd.DataFrame of label(s) for each sample
        max_seq_len: maximum sequence (time series) length. If None, script argument `max_seq_len` will be used.
            (Moreover, script argument overrides this attribute)
    """

    def __init__(self, args, root_path, file_list=None, limit_size=None, flag=None):
        self.args = args
        self.root_path = root_path
        self.flag = flag
        # A single .ts file holds the entire split, so everything is loaded up front.
        self.all_df, self.labels_df = self.load_all(root_path, file_list=file_list, flag=flag)
        self.all_IDs = self.all_df.index.unique()  # all sample IDs (integer indices 0 ... num_samples-1)

        if limit_size is not None:
            if limit_size > 1:
                limit_size = int(limit_size)  # absolute number of samples
            else:  # interpret as proportion if in (0, 1]
                limit_size = int(limit_size * len(self.all_IDs))
            self.all_IDs = self.all_IDs[:limit_size]
            self.all_df = self.all_df.loc[self.all_IDs]

        # use all features
        self.feature_names = self.all_df.columns
        self.feature_df = self.all_df

        # pre_process: standardize across all rows of the split
        normalizer = Normalizer()
        self.feature_df = normalizer.normalize(self.feature_df)
        print(len(self.all_IDs))

    def load_all(self, root_path, file_list=None, flag=None):
        """
        Loads datasets from ts files contained in `root_path` into a dataframe, optionally choosing from `pattern`
        Args:
            root_path: directory containing all individual .ts files
            file_list: optionally, provide a list of file paths within `root_path` to consider.
                Otherwise, entire `root_path` contents will be used.
            flag: optional regex; only paths matching it (via re.search) are kept.
        Returns:
            all_df: a single (possibly concatenated) dataframe with all data corresponding to specified files
            labels_df: dataframe containing label(s) for each sample
        """
        # Select paths for training and evaluation
        if file_list is None:
            data_paths = glob.glob(os.path.join(root_path, '*'))  # list of all paths
        else:
            data_paths = [os.path.join(root_path, p) for p in file_list]
        if len(data_paths) == 0:
            raise Exception('No files found using: {}'.format(os.path.join(root_path, '*')))
        if flag is not None:
            data_paths = list(filter(lambda x: re.search(flag, x), data_paths))
        input_paths = [p for p in data_paths if os.path.isfile(p) and p.endswith('.ts')]
        if len(input_paths) == 0:
            pattern='*.ts'
            raise Exception("No .ts files found using pattern: '{}'".format(pattern))

        # Only the first matching file is read — one file per split is assumed.
        all_df, labels_df = self.load_single(input_paths[0])  # a single file contains dataset

        return all_df, labels_df

    def load_single(self, filepath):
        """Parse one .ts file into (feature dataframe indexed by sample id, labels dataframe)."""
        df, labels = load_from_tsfile_to_dataframe(filepath, return_separate_X_and_y=True,
                                                   replace_missing_vals_with='NaN')
        labels = pd.Series(labels, dtype="category")
        self.class_names = labels.cat.categories
        labels_df = pd.DataFrame(labels.cat.codes,
                                 dtype=np.int8)  # int8-32 gives an error when using nn.CrossEntropyLoss

        lengths = df.applymap(
            lambda x: len(x)).values  # (num_samples, num_dimensions) array containing the length of each series

        horiz_diffs = np.abs(lengths - np.expand_dims(lengths[:, 0], -1))

        if np.sum(horiz_diffs) > 0:  # if any row (sample) has varying length across dimensions
            df = df.applymap(subsample)

        # Re-measure after possible subsampling.
        lengths = df.applymap(lambda x: len(x)).values
        vert_diffs = np.abs(lengths - np.expand_dims(lengths[0, :], 0))
        if np.sum(vert_diffs) > 0:  # if any column (dimension) has varying length across samples
            self.max_seq_len = int(np.max(lengths[:, 0]))
        else:
            self.max_seq_len = lengths[0, 0]

        # First create a (seq_len, feat_dim) dataframe for each sample, indexed by a single integer ("ID" of the sample)
        # Then concatenate into a (num_samples * seq_len, feat_dim) dataframe, with multiple rows corresponding to the
        # sample index (i.e. the same scheme as all datasets in this project)

        df = pd.concat((pd.DataFrame({col: df.loc[row, col] for col in df.columns}).reset_index(drop=True).set_index(
            pd.Series(lengths[row, 0] * [row])) for row in range(df.shape[0])), axis=0)

        # Replace NaN values (linear interpolation within each sample)
        grp = df.groupby(by=df.index)
        df = grp.transform(interpolate_missing)

        return df, labels_df

    def instance_norm(self, case):
        # Per-sample standardization, applied only for the EthanolConcentration dataset.
        if self.root_path.count('EthanolConcentration') > 0:  # special process for numerical stability
            mean = case.mean(0, keepdim=True)
            case = case - mean
            stdev = torch.sqrt(torch.var(case, dim=1, keepdim=True, unbiased=False) + 1e-5)
            case /= stdev
            return case
        else:
            return case

    def __getitem__(self, ind):
        """Return (features tensor, labels tensor) for sample `ind`, optionally augmented."""
        batch_x = self.feature_df.loc[self.all_IDs[ind]].values
        labels = self.labels_df.loc[self.all_IDs[ind]].values
        if self.flag == "TRAIN" and self.args.augmentation_ratio > 0:
            num_samples = len(self.all_IDs)
            num_columns = self.feature_df.shape[1]
            # NOTE(review): assumes every sample has the same length; verify for
            # archives with variable-length series.
            seq_len = int(self.feature_df.shape[0] / num_samples)
            batch_x = batch_x.reshape((1, seq_len, num_columns))
            batch_x, labels, augmentation_tags = run_augmentation_single(batch_x, labels, self.args)

            batch_x = batch_x.reshape((1 * seq_len, num_columns))

        return self.instance_norm(torch.from_numpy(batch_x)), \
               torch.from_numpy(labels)

    def __len__(self):
        return len(self.all_IDs)
891
+
892
+
893
class Dataset_Kalshi(Dataset):
    """
    Dataset over Kalshi prediction-market price histories stored as JSONL.

    Each record's `daily_breakpoints[*].window_history` becomes one sample;
    only samples with at least `seq_len + 1` price points are kept.
    """

    def __init__(self, args, root_path, flag='pred', size=None,
                 features='S', data_path='ETTh1.csv',
                 target='OT', scale=False, inverse=False, timeenc=0, freq='15min',
                 seasonal_patterns='Yearly'):
        self.args = args
        self.features = features
        self.target = target
        self.scale = scale
        self.inverse = inverse
        self.timeenc = timeenc
        self.root_path = root_path
        self.flag = flag

        # Read window sizes from `size`, falling back to args.
        # size = [seq_len, label_len, pred_len]
        if size is not None:
            self.seq_len = size[0]  # recommended: 16
            self.label_len = size[1]  # recommended: 1
            self.pred_len = size[2]  # recommended: 1
        else:
            self.seq_len = getattr(args, 'seq_len', 16)
            self.label_len = getattr(args, 'label_len', 1)
            self.pred_len = getattr(args, 'pred_len', 1)

        self.__read_data__()

    def __read_data__(self):
        """Load the JSONL file for the current flag and build `self.timeseries`."""
        # Base directory holding the JSONL splits
        base_path = self.root_path
        category_path = base_path

        # flag -> (directory, file name)
        file_map = {
            'train': (base_path, 'kalshi_data_processed_with_news_train_2025-11-01.jsonl'),
            'val': (base_path, 'kalshi_data_processed_with_news_train_2025-11-01.jsonl'),
            'test': (base_path, 'kalshi_data_processed_with_news_test_2025-11-01.jsonl'),
            # Per-category test splits (may live in a different directory)
            'test_Companies': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_companies.jsonl'),
            'test_Economics': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_economics.jsonl'),
            'test_Entertainment': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_entertainment.jsonl'),
            'test_Mentions': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_mentions.jsonl'),
            'test_Politics': (category_path, 'kalshi_data_processed_with_news_test_2025-11-01_politics.jsonl'),
        }

        self.timeseries = []

        # Resolve the file path for this split
        if self.flag in file_map:
            dir_path, file_name = file_map[self.flag]
            file_path = os.path.join(dir_path, file_name)
        else:
            # NOTE(review): fallback uses the 2024-11-01 date while file_map
            # uses 2025-11-01 — confirm this mismatch is intentional.
            file_path = os.path.join(base_path, f'kalshi_data_processed_with_news_{self.flag}_2024-11-01.jsonl')

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Data file not found: {file_path}")

        all_samples = []
        skipped = 0
        total = 0

        with open(file_path, 'r') as fcc_file:
            for line in fcc_file:
                obj = json.loads(line)

                if 'daily_breakpoints' not in obj:
                    continue

                for bp in obj['daily_breakpoints']:
                    total += 1
                    window_history = bp.get('window_history', [])

                    # Need seq_len input points plus one target point.
                    if len(window_history) < self.seq_len + 1:
                        skipped += 1
                        continue

                    prices = [e['p'] for e in window_history]
                    all_samples.append(prices)

        # 80/20 train/val split of the same source file
        if self.flag == 'train':
            split_idx = int(len(all_samples) * 0.8)
            self.timeseries = all_samples[:split_idx]
        elif self.flag == 'val':
            split_idx = int(len(all_samples) * 0.8)
            self.timeseries = all_samples[split_idx:]
        else:
            self.timeseries = all_samples

        print(f"[{self.flag}] Loaded {len(self.timeseries)} samples from {os.path.basename(file_path)}")
        print(f"[{self.flag}] Skipped {skipped}/{total} (seq_len requirement: {self.seq_len + 1})")

    def __getitem__(self, index):
        """Return (insample, outsample, insample_mask, outsample_mask) float32 arrays."""
        sampled_timeseries = self.timeseries[index]

        # ========== insample: (seq_len, 1) ==========
        # Take the last seq_len+1 points; the first seq_len of them are the input.
        # E.g. with a window_history of length 17, [-17:-1] gives 16 input points.
        insample = np.zeros((self.seq_len, 1), dtype=np.float32)
        insample_mask = np.zeros((self.seq_len, 1), dtype=np.float32)

        # Input: points from -(seq_len+1) up to but excluding the final point.
        input_prices = sampled_timeseries[-(self.seq_len + 1):-1]
        insample[:, 0] = input_prices
        insample_mask[:, 0] = 1.0

        # ========== outsample: (label_len + pred_len, 1) ==========
        # label_len=1 corresponds to the "before" point, pred_len=1 to the "after" point.
        outsample = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)
        outsample_mask = np.zeros((self.label_len + self.pred_len, 1), dtype=np.float32)

        # outsample = the last label_len+pred_len points,
        # i.e. [before, after] when label_len=1 and pred_len=1.
        outsample[:, 0] = sampled_timeseries[-(self.label_len + self.pred_len):]
        outsample_mask[:, 0] = 1.0

        return insample, outsample, insample_mask, outsample_mask

    def __len__(self):
        return len(self.timeseries)

    def inverse_transform(self, data):
        # No scaler is ever set in this class; this is a pass-through kept for
        # API compatibility with the other dataset classes.
        if hasattr(self, 'scaler') and self.scaler is not None:
            return self.scaler.inverse_transform(data)
        return data

    def last_insample_window(self):
        """Return the final input window (and its mask) of every series, for inference."""
        insample = np.zeros((len(self.timeseries), self.seq_len), dtype=np.float32)
        insample_mask = np.zeros((len(self.timeseries), self.seq_len), dtype=np.float32)

        for i, ts in enumerate(self.timeseries):
            input_prices = ts[-(self.seq_len + 1):-1]
            insample[i, :] = input_prices
            insample_mask[i, :] = 1.0

        return insample, insample_mask
data_provider/load.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
def load_first_jsonl_record(filepath):
    """Read the first line of a JSONL file and return it parsed, or None on any failure."""
    try:
        with open(filepath, 'r', encoding='utf-8') as fh:
            head = fh.readline().strip()
    except FileNotFoundError:
        print(f"文件未找到: {filepath}")
        return None
    except Exception as e:
        print(f"发生错误: {e}")
        return None

    if not head:
        print("文件为空")
        return None

    try:
        return json.loads(head)
    except json.JSONDecodeError as e:
        print(f"JSON解析错误: {e}")
        return None
    except Exception as e:
        print(f"发生错误: {e}")
        return None
21
+
22
# Usage example
if __name__ == "__main__":
    filepath = "/data/haofeiy2/social-world-model/data/splitted_polymarket/polymarket_data_processed_with_news_train_2024-11-01.jsonl"

    # Bug fix: removed a leftover `breakpoint()` that dropped every run into
    # the debugger right after loading the record.
    first_record = load_first_jsonl_record(filepath)

    if first_record:
        print("第一条记录:")
        print(json.dumps(first_record, ensure_ascii=False, indent=2))
data_provider/m4.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is provided for the purposes of scientific reproducibility
2
+ # under the following limited license from Element AI Inc. The code is an
3
+ # implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
4
+ # expansion analysis for interpretable time series forecasting,
5
+ # https://arxiv.org/abs/1905.10437). The copyright to the source code is
6
+ # licensed under the Creative Commons - Attribution-NonCommercial 4.0
7
+ # International license (CC BY-NC 4.0):
8
+ # https://creativecommons.org/licenses/by-nc/4.0/. Any commercial use (whether
9
+ # for the benefit of third parties or internally in production) requires an
10
+ # explicit license. The subject-matter of the N-BEATS model and associated
11
+ # materials are the property of Element AI Inc. and may be subject to patent
12
+ # protection. No license to patents is granted hereunder (whether express or
13
+ # implied). Copyright © 2020 Element AI Inc. All rights reserved.
14
+
15
+ """
16
+ M4 Dataset
17
+ """
18
+ import logging
19
+ import os
20
+ from collections import OrderedDict
21
+ from dataclasses import dataclass
22
+ from glob import glob
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ import patoolib
27
+ from tqdm import tqdm
28
+ import logging
29
+ import os
30
+ import pathlib
31
+ import sys
32
+ from urllib import request
33
+
34
+
35
def url_file_name(url: str) -> str:
    """
    Extract the trailing path component (file name) of a URL.

    :param url: URL to extract file name from.
    :return: File name, or '' for an empty URL.
    """
    if not url:
        return ''
    return url.rsplit('/', 1)[-1]
43
+
44
+
45
def download(url: str, file_path: str) -> None:
    """
    Download a file to the given path (no-op if the file already exists).

    :param url: URL to download
    :param file_path: Where to download the content.
    """

    def progress(count, block_size, total_size):
        # urlretrieve reports total_size as -1 (or 0) when the server omits
        # Content-Length; guard against division by zero / nonsense percentages.
        if total_size > 0:
            progress_pct = float(count * block_size) / float(total_size) * 100.0
            sys.stdout.write('\rDownloading {} to {} {:.1f}%'.format(url, file_path, progress_pct))
            sys.stdout.flush()

    if not os.path.isfile(file_path):
        # Some mirrors reject urllib's default user agent, so spoof a browser.
        opener = request.build_opener()
        opener.addheaders = [('User-agent', 'Mozilla/5.0')]
        request.install_opener(opener)
        pathlib.Path(os.path.dirname(file_path)).mkdir(parents=True, exist_ok=True)
        f, _ = request.urlretrieve(url, file_path, progress)
        sys.stdout.write('\n')
        sys.stdout.flush()
        file_info = os.stat(f)
        logging.info(f'Successfully downloaded {os.path.basename(file_path)} {file_info.st_size} bytes.')
    else:
        file_info = os.stat(file_path)
        logging.info(f'File already exists: {file_path} {file_info.st_size} bytes.')
71
+
72
+
73
@dataclass()
class M4Dataset:
    """In-memory view of the cached M4 dataset: per-series metadata plus values."""
    ids: np.ndarray          # series identifiers (M4id column of M4-info.csv)
    groups: np.ndarray       # seasonal-pattern group per series (SP column)
    frequencies: np.ndarray  # Frequency column of M4-info.csv
    horizons: np.ndarray     # Horizon column of M4-info.csv
    values: np.ndarray       # series values loaded from the .npz cache (allow_pickle)

    @staticmethod
    def load(training: bool = True, dataset_file: str = '../dataset/m4') -> 'M4Dataset':
        """
        Load cached dataset.

        :param training: Load training part if training is True, test part otherwise.
        :param dataset_file: Directory holding M4-info.csv plus the
            training.npz / test.npz value caches.
        """
        info_file = os.path.join(dataset_file, 'M4-info.csv')
        train_cache_file = os.path.join(dataset_file, 'training.npz')
        test_cache_file = os.path.join(dataset_file, 'test.npz')
        m4_info = pd.read_csv(info_file)
        m4dataset = M4Dataset(ids=m4_info.M4id.values,
                              groups=m4_info.SP.values,
                              frequencies=m4_info.Frequency.values,
                              horizons=m4_info.Horizon.values,
                              values=np.load(
                                  train_cache_file if training else test_cache_file,
                                  allow_pickle=True))
        return m4dataset
102
+
103
+
104
@dataclass()
class M4Meta:
    """Static metadata for the six M4 seasonal patterns."""
    seasonal_patterns = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly']
    horizons = [6, 8, 18, 13, 14, 48]  # forecast horizon, parallel to seasonal_patterns
    frequencies = [1, 4, 12, 1, 1, 24]  # sampling frequency, parallel to seasonal_patterns
    horizons_map = {
        'Yearly': 6,
        'Quarterly': 8,
        'Monthly': 18,
        'Weekly': 13,
        'Daily': 14,
        'Hourly': 48
    }  # different predict length
    frequency_map = {
        'Yearly': 1,
        'Quarterly': 4,
        'Monthly': 12,
        'Weekly': 1,
        'Daily': 1,
        'Hourly': 24
    }
    history_size = {
        'Yearly': 1.5,
        'Quarterly': 1.5,
        'Monthly': 1.5,
        'Weekly': 10,
        'Daily': 10,
        'Hourly': 10
    }  # from interpretable.gin
133
+
134
+
135
def load_m4_info(info_file_path: str = os.path.join('../dataset/m4', 'M4-info.csv')) -> pd.DataFrame:
    """
    Load the M4-info metadata file.

    Bug fix: the original referenced `INFO_FILE_PATH`, which is defined nowhere
    in this module, so every call raised NameError. The path is now a
    parameter whose default matches the directory used by ``M4Dataset.load``.

    :param info_file_path: Path to M4-info.csv.
    :return: Pandas DataFrame of M4Info.
    """
    return pd.read_csv(info_file_path)
data_provider/uea.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+
6
+
7
def collate_fn(data, max_len=None):
    """Assemble a padded mini-batch from a list of (X, y) tuples.

    Args:
        data: len(batch_size) list of tuples (X, y).
            - X: torch tensor of shape (seq_length, feat_dim); variable seq_length.
            - y: torch tensor of shape (num_labels,) : class indices or numerical targets
              (for classification or regression, respectively). num_labels > 1 for multi-task models
        max_len: global fixed sequence length. Used for architectures requiring fixed length input,
            where the batch length cannot vary dynamically. Longer sequences are clipped, shorter are padded with 0s
    Returns:
        X: (batch_size, padded_length, feat_dim) torch tensor of features
        targets: (batch_size, num_labels) torch tensor of labels
        padding_masks: (batch_size, padded_length) boolean tensor, 1 means keep vector at this position, 0 means padding
    """
    features, labels = zip(*data)

    # Original length of every series in the batch; the pad length defaults to
    # the longest one.
    seq_lens = [sample.shape[0] for sample in features]
    if max_len is None:
        max_len = max(seq_lens)

    feat_dim = features[0].shape[-1]
    X = torch.zeros(len(data), max_len, feat_dim)  # zero-padded batch
    for i, (sample, length) in enumerate(zip(features, seq_lens)):
        end = min(length, max_len)  # clip sequences longer than max_len
        X[i, :end, :] = sample[:end, :]

    targets = torch.stack(labels, dim=0)  # (batch_size, num_labels)

    padding_masks = padding_mask(torch.tensor(seq_lens, dtype=torch.int16),
                                 max_len=max_len)  # (batch_size, padded_length), "1" = keep

    return X, targets, padding_masks
43
+
44
+
45
def padding_mask(lengths, max_len=None):
    """
    Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths,
    where 1 means keep element at this position (time step).

    :param lengths: 1-D integer tensor of sequence lengths.
    :param max_len: mask width; defaults to the largest length in the batch.
    """
    batch_size = lengths.numel()
    if max_len is None:
        # Bug fix: torch tensors have .max(), not .max_val() — the original
        # raised AttributeError whenever max_len was omitted. An explicit
        # None check also avoids the `or` trick misfiring on max_len == 0.
        max_len = int(lengths.max())
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))
56
+
57
+
58
class Normalizer(object):
    """
    Normalizes a dataframe either across ALL rows (time steps) or per sample
    (i.e. across only the rows sharing an index), depending on `norm_type`.
    """

    def __init__(self, norm_type='standardization', mean=None, std=None, min_val=None, max_val=None):
        """
        Args:
            norm_type: one of:
                "standardization", "minmax": statistics computed over the whole dataframe
                "per_sample_std", "per_sample_minmax": statistics computed per sample (per index value)
            mean, std, min_val, max_val: optional (num_feat,) Series of pre-computed values
        """
        self.norm_type = norm_type
        self.mean = mean
        self.std = std
        self.min_val = min_val
        self.max_val = max_val

    def normalize(self, df):
        """Return the normalized dataframe; global statistics are cached on first use."""
        eps = np.finfo(float).eps

        if self.norm_type == "standardization":
            if self.mean is None:
                self.mean = df.mean()
                self.std = df.std()
            return (df - self.mean) / (self.std + eps)

        if self.norm_type == "minmax":
            if self.max_val is None:
                self.max_val = df.max()
                self.min_val = df.min()
            return (df - self.min_val) / (self.max_val - self.min_val + eps)

        if self.norm_type == "per_sample_std":
            grouped = df.groupby(by=df.index)
            return (df - grouped.transform('mean')) / grouped.transform('std')

        if self.norm_type == "per_sample_minmax":
            grouped = df.groupby(by=df.index)
            min_vals = grouped.transform('min')
            return (df - min_vals) / (grouped.transform('max') - min_vals + eps)

        raise (NameError(f'Normalize method "{self.norm_type}" not implemented'))
108
+
109
+
110
def interpolate_missing(y):
    """Fill NaN entries of a pd.Series via linear interpolation (both directions)."""
    if not y.isna().any():
        return y
    return y.interpolate(method='linear', limit_direction='both')
117
+
118
+
119
def subsample(y, limit=256, factor=2):
    """Downsample a Series by `factor` when longer than `limit`; otherwise return it unchanged."""
    if len(y) <= limit:
        return y
    return y[::factor].reset_index(drop=True)
dataset/m4/Daily-test.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/Daily-train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78e94591c60c06309f1e544fd7b2ccba28f3616b01913dd336a1d1d98483a1ec
3
+ size 95765153
dataset/m4/Hourly-test.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/Hourly-train.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/M4-info.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/Monthly-test.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/Monthly-train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63aaa56198b4a22279a2fec541f56921f3bed6d8e0f84e550206df9276f2e9b9
3
+ size 91655432
dataset/m4/Quarterly-test.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/Quarterly-train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4450678cd493aa965c14bde984a5b62a66d8c3b2f7af1e53779bc91ff008ccdc
3
+ size 38788547
dataset/m4/Weekly-test.csv ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "V1","V2","V3","V4","V5","V6","V7","V8","V9","V10","V11","V12","V13","V14"
2
+ "W1","35397.16","35808.59","35808.59","36246.14","36246.14","36403.7","36403.7","36150.2","36150.2","35790.55","35790.55","34066.95","34066.95"
3
+ "W2","3608.061","3624.368","3537.64","3587.344","3662.088","3658.634","3622.148","3591.48","3638.559","3635.781","3522.328","3458.631","3354.758"
4
+ "W3","9602.4","9895.9","9915.9","9726.6","9626.5","9966.1","10008.5","9911.7","9760.7","9968.2","10089.1","10005.9","9823.3"
5
+ "W4","2336.56","2373.14","2457.51","2586","2237.7","2877.94","2497.74","3525.34","3053.33","3151.07","3425.79","3505.61","4098.69"
6
+ "W5","1499","1734","1391","1833","1637","2122","1757","2079","2477","1957","1907","1752","1330"
7
+ "W6","502","568","528","632","606","785","779","834","906","1372","1080","938","752"
8
+ "W7","3117","2579","3199","1716","4558","3001","3263","3110","5899","7148","4048","2707","2579"
9
+ "W8","3430","2870","4620","4890","7470","6960","6430","3970","3650","5730","4500","2870","3270"
10
+ "W9","8100","7481","8953","8087","9724","8993","10011","9705","8397","7560","8669","7574","7106"
11
+ "W10","5122.14","5104.68","5104.68","5104.68","5565.7","5565.7","5565.7","5565.7","5565.7","5544.94","5544.94","5544.94","5544.94"
12
+ "W11","337.19","359.52","357.62","368.22","328.77","391.31","340.06","323.19","323.3","339.94","316.45","295.99","280.07"
13
+ "W12","1092","1085","1078","1071","1064","1057","1050","1043","1092","1085","1078","1071","1064"
14
+ "W13","2373.445","2373.105","2453.974","2390.353","2324.83","2337.239","2375.971","2398.802","2350.108","2363.517","2484.652","2558.814","2653.311"
15
+ "W14","2341.516","2407.573","2435.3","2361.92","2310.927","2322.117","2408.718","2476.214","2361.203","2393.436","2552.503","2554.351","2500.651"
16
+ "W15","4231","4311.68","3848.09","4149.38","3633.49","3553.49","3743.59","4626.91","4235.48","4288.81","5601.72","5737.57","5232.13"
17
+ "W16","4231","4311.68","3848.09","4149.38","3633.49","3553.49","3743.59","4626.91","4235.48","4288.81","5601.72","5737.57","5232.13"
18
+ "W17","12027.16","12027.2","12027.24","11985.62","11985.74","12432.86","12432.99","12418.26","12418.32","12418.37","12418.43","12418.49","12323.22"
19
+ "W18","4334.57","4334.61","4334.65","4351.65","4351.77","3846.99","3954.85","3996.48","3992.55","3992.61","3992.68","3992.75","4029.96"
20
+ "W19","2754","3035","2591","2808","2324","4033","3146","3402","3677","4790","7520","4445","2960"
21
+ "W20","3942","4184","3374","4291","2657","4258","4050","3730","3311","3154","4551","4677","3944"
22
+ "W21","2203","2234","1786","2582","2838","2949","2483","2184","3809","3395","3194","2744","2054"
23
+ "W22","3780","4430","3380","4100","2710","3980","3720","3300","3460","2760","4730","3930","3180"
24
+ "W23","753.8","767.6","722.9","769.5","536.6","756.4","813","732.2","617.9","617","901.1","872.9","739.2"
25
+ "W24","6510","6109","5215","6798","4354","10041","7637","6137","7002","8575","19189","12646","6956"
26
+ "W25","1431","1679","1348","1870","2110","2492","2701","2649","2745","2676","3017","2669","2017"
27
+ "W26","2672","3061","2947","4293","3409","4660","5136","6235","5580","4585","6168","4943","3945"
28
+ "W27","2144","2144","1752","2314","2645","2643","3120","3302","3802","2762","3713","4473","2643"
29
+ "W28","7793","3341","3003","3507","2555","3167","3300","3507","8089","8507","5611","4081","3025"
30
+ "W29","851","1017","907","1075","833","1716","1339","1312","1345","1322","2142","1732","1143"
31
+ "W30","3623","4016","3075","4373","3773","5121","4129","4785","6509","7561","6183","5485","4223"
32
+ "W31","708","829","796","1026","1483","1412","1665","1696","2150","1948","1778","1156","1053"
33
+ "W32","5029","3990","5803","3980","6833","5247","6023","7372","6529","7042","6642","6607","7307"
34
+ "W33","1635.1","1470","1831.2","1860.7","2952.6","2076","1926.5","2127.6","2526.5","3130.5","2901.6","1965.4","1930.4"
35
+ "W34","1528","1485","1676","1434","2556","2017","1752","1575","1478","2618","2045","1480","1464"
36
+ "W35","1240","938","1641","1225","2143","2298","1561","1253","1257","2191","2137","1662","1565"
37
+ "W36","18469.2","14697.3","14490.1","16890.6","13307.5","11488","11021.4","13786","17773.2","20125.2","13053.1","14418.3","18561.2"
38
+ "W37","2891.8","2870","2839.7","2925.6","2858.2","2873.3","2857.3","2832.8","2818.5","2843","2863.1","2838.6","2827.4"
39
+ "W38","3050.13","2991.26","2966.63","3079.93","3027.55","3046.13","3052.31","3028.4","3042.15","3060.21","3044.19","3005.94","2999.31"
40
+ "W39","5820.04","5811.86","5764.85","5774.63","5782.51","5729.81","5736.21","5784.8","5783.06","5742.76","5732.81","5730.19","5678.13"
41
+ "W40","5816.795","5807.545","5761.31","5766.93","5777.23","5727.54","5732.42","5780.82","5776.885","5739.61","5732.305","5721.025","5674.69"
42
+ "W41","6138.53","6046.39","6014.57","6020.12","6111.92","6065.35","6090.68","6170.24","6206.4","6161.53","6092.33","6028.87","6039.88"
43
+ "W42","6145.38","6055.38","6021.96","6036.22","6123.1","6070.16","6098.74","6178.74","6219.68","6168.3","6093.4","6048.22","6047.21"
44
+ "W43","6141.955","6050.885","6018.265","6028.17","6117.51","6067.755","6094.71","6174.49","6213.04","6164.915","6092.865","6038.545","6043.545"
45
+ "W44","5813.55","5803.23","5757.77","5759.23","5771.95","5725.27","5728.63","5776.84","5770.71","5736.46","5731.8","5711.86","5671.25"
46
+ "W45","2934.68","2967.95","3071.57","3101.51","3101.51","3136.93","3189.06","3165.62","3069.02","2990.9","2965.64","2997.27","3001.98"
47
+ "W46","3156.835","3088.77","3129.65","3163.045","3114.31","3110.356","3123.937","3136.711","3158.378","3095.771","3042.689","3026.794","3062.599"
48
+ "W47","2933.38","2967.36","3070.66","3098.73","3098.73","3134.37","3186.88","3164.56","3067.65","2988.52","2963.47","2994.78","2999.43"
49
+ "W48","2964.555","2996.025","3000.705","2941.08","2935.96","2938.24","2936.72","2936.66","2882.2","2862.63","2867.64","2875.68","2871.67"
50
+ "W49","3193.28","3248.07","3406.28","3378.86","3335.78","3331.52","3375.54","3365.51","3264.59","3155.58","3087.64","3128.35","3161.7"
51
+ "W50","3194.69","3248.72","3407.29","3381.89","3338.78","3334.24","3377.85","3366.64","3266.05","3158.09","3089.9","3130.95","3164.39"
52
+ "W51","6549.4","6749.6","6672.9","6664.1","6601.3","6601.8","6483.1","6488.7","6431.8","6468.2","6418.9","6383.5","6403.6"
53
+ "W52","6980.8","7049.9","6985.1","7060.9","7021.9","7015","6921.2","6940.5","6942","6894.3","6783.9","6801.8","6773.4"
54
+ "W53","1915.5","1845.1","1867.4","2142","2309.8","2248.6","2182.6","2050.6","1986.8","1865.6","1912.3","1982.7","1933.5"
55
+ "W54","2483","2491","2474","2491","2483","2460","2473","2448","2429","2411","2434","2445","2439"
56
+ "W55","5056","4936","4551","4461","4515","4725","4863","5097","5191","5195","5328","5342","5207"
57
+ "W56","3030","3099","3193","3261","3334","3440","3538","3633","3733","3814","3877","3929","3978"
58
+ "W57","2070","1970","1830","1730","1640","1550","1500","1460","1460","1390","1370","1370","1410"
59
+ "W58","3485","3651","3681","3704","3737","3907","3747","3820","3842","3649","3414","3116","2984"
60
+ "W59","2350","2040","2030","1600","1160","2140","2780","1910","2180","2170","2010","1690","1880"
61
+ "W60","9160.47","9180.81","9257.59","9244.22","9195.97","9383.28","9480.43","9506.87","9431.19","9450.03","9630.17","9660.81","9601.98"
62
+ "W61","10128.3","10802.8","10802.8","10971.1","10971.1","10436.8","10436.8","11025.5","11025.5","10135.8","10135.8","10934.7","10934.7"
63
+ "W62","7208.8","7061.6","7061.6","7018.9","7018.9","7157.6","7157.6","7144.4","7144.4","7536.5","7536.5","7620.1","7620.1"
64
+ "W63","6062.1","6110.6","6110.6","6119.9","6119.9","6115.9","6115.9","6194.7","6194.7","6386.4","6386.4","6649.3","6649.3"
65
+ "W64","14834","12734","12772","13825","15080","12824","12879","13769","14767","13645","12832","13629","15073"
66
+ "W65","20461","17999","18027","19270","20726","18118","18144","19242","20438","19136","18135","19163","20871"
67
+ "W66","7318.7","7580.2","7605.2","7418.2","7320.8","7633.4","7663.4","7564.1","7418.4","7616","7729","7645","7464.2"
68
+ "W67","13361","13641","13485","13635","13869","13787","13846","13810","13958","13995","13987","14185","14413"
69
+ "W68","6925.91","6949.42","6997.28","7002.45","6998.8","6970.53","7019.43","7039.14","7041.96","6974.79","6967.61","6968.64","6957.49"
70
+ "W69","6930.32","6983.87","7009.79","7004","7003.5","6971.28","7036.53","7049.97","7035.15","6965.37","6967.94","6962.49","6960.41"
71
+ "W70","2612.74","2607.29","2583.84","2601.24","2597.5","2591.6","2648.13","2649.71","2654.19","2655.63","2614.58","2613.64","2614.93"
72
+ "W71","1930","1927","1927","1907","1906","1903","1903","1902","1897","1897","1896","1893","1892"
73
+ "W72","1930","1927","1927","1907","1906","1903","1903","1902","1897","1897","1896","1893","1892"
74
+ "W73","1930","1927","1927","1907","1906","1903","1903","1902","1897","1897","1896","1893","1892"
75
+ "W74","4494.5","3936.22","4013.09","3951.31","3701.88","3380.45","3652.31","3893.49","4156.82","4353.71","4927.72","5405.49","6528.25"
76
+ "W75","23442.9","23935.8","24274.4","23870.6","23235.6","23930.5","25498.8","25416.7","23773.6","25931.8","25269.5","24984.5","25674.5"
77
+ "W76","440","440","440","440","440","440","440","440","240","240","240","240","240"
78
+ "W77","3977.66","4541.82","5291.64","4690.9","4198.24","4610.46","4725.98","4681.88","3926.7","3826.87","4457.07","4650.71","4440.09"
79
+ "W78","3923.9","4525.96","5268.49","4149.73","4106.37","4300.7","4930.68","4726.57","3929.57","4190.49","4428.39","4249.56","4221.58"
80
+ "W79","1560","1720","1770","1820","1860","1930","1790","1660","1620","1600","1590","1640","1670"
81
+ "W80","1720","1760","1810","1840","1950","1820","1670","1620","1600","1580","1640","1660","1760"
82
+ "W81","23057.9","24044.6","23942.2","23970.5","23519","23811.9","24756.7","24982.2","24388.4","24379.1","25518.2","24728","25659.7"
83
+ "W82","4417.421","4434.675","4430.035","4412.757","4414.728","4419.787","4422.151","4410.932","4409.458","4417.371","4423.566","4427.224","4414.39"
84
+ "W83","6001.5","5974.5","5931.6","6028.6","6053","6064","6017.2","5972.1","5997.6","5991.1","5939.8","5939.5","5940.6"
85
+ "W84","6016.7","5923.8","5945","6050.1","6069.5","6022.3","5944.7","5985.6","5996","5978","5937.4","5937.5","5944"
86
+ "W85","17289","17312","17335","17364","17429","17501","17567","17632","17683","17718","17753","17791","17808"
87
+ "W86","17297","17320","17344","17385","17454","17522","17601","17661","17696","17731","17766","17801","17791"
88
+ "W87","1094.56","1094.52","1092.83","1091.55","1098.78","1099.02","1096.83","1093.89","1089.57","1099.3","1098.45","1095.05","1088.79"
89
+ "W88","12154.01","12042.04","12027.43","11998.81","12089.79","12025.85","12079.76","12053.18","12179.34","12133.8","12093.14","12163.75","12117.78"
90
+ "W89","1185.62","1180.63","1201.97","1132.26","1087.99","1095.31","1151.72","1115.15","1107.33","1089.42","1084.26","1103.96","1071.63"
91
+ "W90","1828.128","1861.782","1872.236","1854.954","1852.089","1854.091","1868.669","1893.618","1859.457","1860.784","1871.368","1829.142","1818.838"
92
+ "W91","19466.9","19798.45","19924.33","19681.8","19608.88","19636.22","19838.41","20051.33","19701.9","19697.26","19797.94","19395.38","19260.01"
93
+ "W92","21837.03","22832.08","22198.61","22399.36","22658.55","22969.77","22908.66","22311.88","21853.35","22670.54","22686.31","21102.33","21032.34"
94
+ "W93","1748.563","1753.833","1740.646","1741.955","1753.649","1760.835","1747.62","1745.674","1751.342","1758.483","1747.428","1749.617","1753.035"
95
+ "W94","1802.54","1986.53","1931.14","1831.39","1948.19","1948.12","2000.27","2077.83","2068.5","2093.54","2174.32","2052.07","2108"
96
+ "W95","1621.22","1764.77","1679.71","1638.8","1728.67","1733.06","1799.69","1857.32","1899.37","1907.46","1936.32","1863.71","1912.86"
97
+ "W96","11355.272","11385.864","11303.634","11309.627","11397.101","11427.142","11459.31","11388.745","11451.752","11478.884","11545.812","11497.115","11568.007"
98
+ "W97","447.83","454.85","482.02","472.14","455.09","481.15","524.38","519.49","485.51","467.8","486.97","483.66","457.56"
99
+ "W98","584.3","590.59","610.91","597.94","585.15","610.26","651.93","646.63","617.37","598.04","617.17","615.46","586.87"
100
+ "W99","14837.74","14751.79","14883.29","15009.8","14963.45","14930.14","14847.66","14828.55","14864.72","14893.35","14811.6","14898.23","14855.39"
101
+ "W100","11474.49","11875.5","11517.02","11305.8","11423.14","11647.5","11632.38","11729.89","11679.04","11740.51","11934.32","11741.1","11634.37"
102
+ "W101","9871.498","9910.685","9815.305","9808.647","9900.756","9934.128","9974.544","9905.89","9965.28","9989.549","10064.652","10007.292","10082.468"
103
+ "W102","16004.79","15859.41","15729.3","15616.87","15749.81","15642.21","15716.89","15789.79","15961.13","15986.34","15871.65","15813.95","15743.94"
104
+ "W103","4153.94","4375.42","4063.47","3975.35","4202.22","4212.67","4171.08","4025.43","4260.1","4159.33","4272.43","3968.74","3905.78"
105
+ "W104","1026.84","1019.01","1009.31","1003.54","1007.61","1004.8","998.68","991.28","991.98","988.73","979.2","976.82","968.28"
106
+ "W105","9135.8","9155.8","9123.07","9168.48","9097.38","9109.48","9061.98","9058.14","9033.5","9064.66","9038.48","9025.31","9001.11"
107
+ "W106","2162.966","2167.806","2153.199","2154.003","2164.848","2171.359","2157.906","2154.995","2160.496","2165.439","2154.43","2155.905","2159.078"
108
+ "W107","24019.73","23983.72","24168.42","24240.09","24481.23","24419.29","24441.25","24333.51","24398.35","24222.27","24359.23","24297.26","24345.73"
109
+ "W108","15926.84","16078.017","15963.238","15970.428","16056.811","16086.342","16092.328","16042.099","16027.739","16096.719","16115.75","15955.488","15951.666"
110
+ "W109","14171.645","14318.713","14198.442","14206.103","14300.423","14334.406","14340.354","14293.978","14288.988","14354.158","14370.531","14207.81","14198.367"
111
+ "W110","7331.84","7307.69","7396.85","7438.91","7519.81","7592.29","7571.98","7464.36","7484.2","7484.94","7528.27","7483.09","7522.26"
112
+ "W111","1080.57","1080.53","1078.63","1077.26","1084.46","1084.89","1082.71","1079.78","1076.06","1085.53","1084.67","1081.1","1074.27"
113
+ "W112","9609.45","9520.42","9512.01","9471.4","9562.31","9488.15","9556.86","9534.86","9647.41","9594.13","9576.58","9662.18","9644.83"
114
+ "W113","754.7","766.1","766.27","714.66","682.99","700.3","763.74","747.68","742.31","717.41","712.44","734.81","726.45"
115
+ "W114","10463.22","10370.99","10322.48","10331.1","10275.58","10318.71","10216.97","10282.47","10238.62","10071.88","10200.06","10227.22","10233.33"
116
+ "W115","11217.92","11137.09","11088.75","11045.76","10958.57","11019.01","10980.71","11030.15","10980.93","10789.29","10912.5","10962.03","10959.78"
117
+ "W116","14539.1","14894.1","13984.77","13907.03","14293.59","14685.26","15017.67","14100.07","14408.13","14676.51","15205.38","13875.72","14094.91"
118
+ "W117","7037.56","7043.25","7026.12","7062.94","7075.16","7057.07","7064.06","7097.6","7249.93","7267.89","7328.08","7388.03","7480.86"
119
+ "W118","1747.852","1753.13","1739.954","1741.265","1752.962","1760.145","1746.934","1744.992","1750.66","1757.801","1746.75","1748.939","1752.358"
120
+ "W119","1029.75","1180.41","1140.58","1060.38","1154.44","1152.64","1175.8","1244.19","1225.38","1262.37","1332.13","1221.9","1270.41"
121
+ "W120","855.09","989.28","916.28","889.25","943.99","943.1","996.81","1044.71","1069.8","1090.56","1126.05","1064.67","1106.98"
122
+ "W121","10465.971","10496.016","10403.031","10399.53","10497.08","10528.68","10569.283","10488.464","10554.103","10585.971","10651.968","10589.396","10672.134"
123
+ "W122","329.99","321.94","332.45","333.74","338.72","353.4","396.27","383.57","366.9","359.23","363.53","378.46","369.19"
124
+ "W123","440.2","431.45","435.3","432.77","440.92","454.74","497.32","484.31","473.21","461.88","466.64","483.21","471.99"
125
+ "W124","8434.6","8404.06","8370.62","8440.35","8445.77","8485.14","8424.83","8351.47","8365.51","8411.94","8472.83","8434.49","8421.44"
126
+ "W125","10927.21","11218.64","10974.46","10771.82","10862.93","11054.22","11031.79","11089.91","11069.27","11163.08","11300.69","11152.85","11055.06"
127
+ "W126","9622.511","9655.61","9565.969","9555.495","9652.503","9680.166","9726.8","9653.317","9717.552","9744.777","9804.685","9745.947","9829.99"
128
+ "W127","11173.39","10918.46","10850.02","10756.78","10815.1","10764.33","10842.66","10886.6","10968.92","11026.69","11039.92","11213.69","11323.96"
129
+ "W128","3659.91","3836.54","3603.72","3518.12","3731.7","3682.25","3703.91","3552.16","3731.15","3699.87","3741.75","3541.9","3452.87"
130
+ "W129","1007.32","1000.05","990.76","985.07","988.6","986.32","981.23","974.06","974.39","971.47","961.91","960.25","951.77"
131
+ "W130","8142.17","8164.91","8155.02","8197.46","8121.84","8135.96","8115.75","8106.23","8078.34","8105.2","8077.74","8059.49","8032.97"
132
+ "W131","2161.865","2166.713","2152.117","2152.923","2163.772","2170.279","2156.83","2153.923","2159.431","2164.374","2153.369","2154.844","2158.018"
133
+ "W132","22735.69","22695.8","22873.78","22912.73","23157.6","23052.45","23109.41","23066.29","23133.89","22943.61","23070.07","23056.38","23135.75"
134
+ "W133","13883.352","13943.084","13817.27","13799.658","13890.442","13929.041","13978.179","13895.258","13956.526","13983.295","14057.238","13958.295","14011.606"
135
+ "W134","12132.819","12188.302","12057.06","12039.798","12138.491","12181.638","12230.64","12151.594","12222.102","12245.022","12316.353","12215.006","12262.678"
136
+ "W135","6190.89","6166.87","6248.29","6256.87","6346.31","6377.59","6382.05","6333.99","6350.53","6326.36","6362.01","6363.97","6432.69"
137
+ "W136","6432.1","6492.9","6523.6","6596.4","6759.9","6701.7","6745.7","6721.2","6644","6667.8","6661.7","6458.5","6508.5"
138
+ "W137","7126.057","7071.44","7078.15","7037.05","7071.02","7029.9","7096.12","7087.92","7166.37","7130.38","7131.41","7203.06","7162.08"
139
+ "W138","387.13","396.14","404.7","363.68","322.02","328.5","392.13","382.78","364.21","339.84","334.59","363.64","337.18"
140
+ "W139","7833.402","7763.08","7697.07","7691.19","7674.28","7754.88","7648.39","7689.17","7639.02","7478.78","7613.66","7563.24","7531.5"
141
+ "W140","8220.532","8159.22","8101.77","8054.87","7996.3","8083.38","8040.52","8071.95","8003.23","7818.62","7948.25","7926.88","7868.68"
142
+ "W141","10971.829","11241.17","10370","10322.61","10592.83","10873.84","11160.36","10316.67","10680.77","10809.59","11455.11","10315.26","10644.41"
143
+ "W142","5388.424","5388.93","5368.95","5390.96","5401.97","5389.98","5392.11","5412.92","5525.45","5540.92","5580.59","5622.11","5684.54"
144
+ "W143","11374.106","11419.08","11304.46","11312.03","11412.35","11481.37","11381.15","11365.37","11386.84","11445.25","11290.92","11301.3","11310.58"
145
+ "W144","1005.48","1152.75","1113.19","1035.17","1124.86","1122.3","1147.56","1215.37","1191.06","1226.87","1298.58","1187.11","1237.65"
146
+ "W145","8345.9","9654.6","8929.2","8681.6","9187.9","9169.4","9726.4","10199.3","10397.5","10589.6","10966.5","10340.4","10782"
147
+ "W146","6801.5412","6826.213","6737.471","6733.834","6810.25","6824.878","6855.421","6783.048","6832.483","6856.145","6914.958","6859.368","6934.528"
148
+ "W147","120.66","117.59","111.32","110.66","110.46","111.68","120.32","117.41","125.32","121.24","112.52","122.99","118.27"
149
+ "W148","222.63","218.63","205.95","201.16","204.31","204.6","212.82","209.47","223.37","215.24","206.88","219.05","212.55"
150
+ "W149","4374.969","4342.37","4303.19","4370.16","4392.91","4423.87","4362.98","4303.89","4317.18","4359.9","4426.36","4381.17","4372.66"
151
+ "W150","8104.576","8397.21","8157.4","7991.08","8073.41","8245.39","8235.18","8301.8","8245.5","8365.11","8466.19","8322.25","8219.59"
152
+ "W151","6364.0443","6391.976","6307.152","6296.818","6370.959","6382.491","6419.123","6352.659","6400.765","6420.155","6472.322","6421.251","6497.262"
153
+ "W152","8624.177","8443.2","8390.94","8266.24","8299.13","8283.85","8356.8","8416.68","8466.06","8541.06","8572.49","8732.67","8749.33"
154
+ "W153","2990.33","3154.84","2945.55","2862.16","3032.25","2990.3","3001.65","2879.11","3049.37","2997.1","3046.23","2862.51","2802.91"
155
+ "W154","8643.9","8597.8","8504.5","8468.7","8454","8415","8361.8","8296.2","8313.1","8300.7","8219.2","8228.2","8136.1"
156
+ "W155","4568.001","4591.74","4592.62","4636.82","4562.67","4572.28","4557.9","4538.75","4523.75","4521.02","4499.58","4485.14","4473.5"
157
+ "W156","5432.391","5451.52","5443.07","5483.69","5408.07","5413.78","5394.08","5368.37","5355.06","5351.09","5321.5","5307.96","5287.11"
158
+ "W157","1431.923","1435.919","1422.959","1423.154","1432.548","1438.745","1428.38","1425.897","1427.808","1431.297","1415.852","1416.006","1416.477"
159
+ "W158","16585.796","16545.69","16718.18","16774.2","16973.97","16859.12","16888.19","16801.33","16854.51","16677.41","16792.37","16756.94","16839.65"
160
+ "W159","9265.8308","9320.682","9198.628","9181.263","9248.034","9273.226","9311.038","9229.568","9282.625","9295.202","9366.068","9272.725","9316.549"
161
+ "W160","8060.0957","8112.114","7987.147","7969.756","8043.952","8072.951","8110.896","8034.681","8093.054","8099.825","8169.184","8075.646","8114.811"
162
+ "W161","12143.281","12130.04","12220.91","12262.51","12368.93","12230.26","12271.89","12244.95","12279.44","12134.14","12219.63","12187.07","12201.34"
163
+ "W162","4442.515","4415.65","4497.27","4511.69","4605.04","4628.86","4616.3","4556.38","4575.07","4543.27","4572.74","4569.87","4638.31"
164
+ "W163","3806.6","3812.9","3801.8","3787","3826","3831.2","3817.2","3802.3","3799.9","3852","3840.2","3822.6","3798.4"
165
+ "W164","24833.93","24489.8","24338.6","24343.5","24912.9","24582.5","24607.4","24469.4","24810.4","24637.5","24451.7","24591.2","24827.5"
166
+ "W165","3675.7","3699.6","3615.7","3509.8","3609.7","3718","3716.1","3649","3781","3775.7","3778.5","3711.7","3892.7"
167
+ "W166","2629.818","2607.91","2625.41","2639.91","2601.3","2563.83","2568.58","2593.3","2599.6","2593.1","2586.4","2663.98","2701.83"
168
+ "W167","2997.388","2977.87","2986.98","2990.89","2962.27","2935.63","2940.19","2958.2","2977.7","2970.67","2964.25","3035.15","3091.1"
169
+ "W168","3567.271","3652.93","3614.77","3584.42","3700.76","3811.42","3857.31","3783.4","3727.36","3866.92","3750.27","3560.46","3450.5"
170
+ "W169","1649.136","1654.32","1657.17","1671.98","1673.19","1667.09","1671.95","1684.68","1724.48","1726.97","1747.49","1765.92","1796.32"
171
+ "W170","3593.378","3600.1","3605.11","3626.18","3631.12","3625.32","3631.86","3645.91","3674.92","3680.2","3698.89","3719","3754.53"
172
+ "W171","6104.414","6112.22","6095.08","6100.62","6117.27","6120.08","6088.19","6084.55","6119.76","6132.76","6176.58","6188.09","6213"
173
+ "W172","3664.4298","3669.803","3665.56","3665.696","3686.83","3703.802","3713.862","3705.416","3721.62","3729.826","3737.01","3730.028","3737.606"
174
+ "W173","3696.3","3536.5","3504.4","3481.7","3684.6","3749.7","3684.8","4354.7","3618.1","4381.4","3985.8","3381.9","3209.2"
175
+ "W174","2822.634","2821.43","2817.06","2780.74","2789.52","2808.83","2796.61","2788.11","2823.77","2797.97","2834.5","2830.6","2835.47"
176
+ "W175","3258.4667","3263.634","3258.817","3258.677","3281.544","3297.675","3307.677","3300.658","3316.787","3324.622","3332.363","3324.696","3332.728"
177
+ "W176","2549.213","2475.26","2459.08","2490.54","2515.97","2480.48","2485.86","2469.92","2502.86","2485.63","2467.43","2481.02","2574.63"
178
+ "W177","6695.8","6817","6581.7","6559.6","6994.5","6919.5","7022.6","6730.5","6817.8","7027.7","6955.2","6793.9","6499.6"
179
+ "W178","1429.3","1402.7","1403.1","1382","1432","1448.2","1450.5","1444.4","1430.8","1414","1399.9","1374.3","1381.6"
180
+ "W179","2709.779","2713.39","2711.95","2713.77","2713.77","2722.18","2721.67","2737.86","2723.28","2754.11","2756.24","2751.53","2745.86"
181
+ "W180","7299.42","7307.94","7291.58","7297.69","7312.24","7315.34","7284.5","7280.26","7316.23","7330.77","7375.17","7388.38","7415.41"
182
+ "W181","8859.673","8863.5","8867.55","8852.3","8897.4","8915.51","8942.89","9002.82","9002.66","9020.31","9033.94","9050.97","9041.96"
183
+ "W182","1375.3","1376.8","1377.8","1382.1","1392.3","1392.1","1472.6","1476.2","1488.5","1487.9","1491.4","1494.2","1492.8"
184
+ "W183","6149.894","6150.11","6155.6","6138.53","6183.63","6193.33","6221.22","6264.96","6279.38","6266.2","6277.7","6299.44","6296.1"
185
+ "W184","4617.5212","4622.402","4618.642","4618.395","4642.408","4655.815","4667.141","4665.69","4673.901","4688.093","4691.17","4685.57","4695.057"
186
+ "W185","4072.7233","4076.188","4069.913","4070.042","4094.539","4108.687","4119.744","4116.913","4129.048","4145.197","4147.169","4139.36","4147.867"
187
+ "W186","1748.375","1751.22","1751.02","1745.18","1741.27","1748.73","1765.75","1777.61","1775.46","1783.09","1789.27","1794.1","1794.38"
188
+ "W187","942.8","948.4","944.3","940.1","945.4","960","936.4","939.7","939.3","934.4","920.4","913.7","937.7"
189
+ "W188","2235","2247","2197","2225","2192","2240","2149","2191","2159","2176","2252","2201","2227"
190
+ "W189","2128","2141","2165","2108","2221","2365","2180","2244","2272","2246","2038","1965","2069"
191
+ "W190","456.6855011","460.7204403","453.8686316","441.8651676","445.577063","453.0822772","437.6623557","439.8819108","430.3451587","426.0420738","419.6364978","411.3800853","418.5663313"
192
+ "W191","6635.2293","6711.3531","6879.4553","7031.4599","7280.6471","7287.6782","7199.4766","7075.0888","7495.1075","7853.2658","7738.1914","7432.3491","7384.1148"
193
+ "W192","1160.330918","1165.783708","1130.218848","1144.012678","1116.237318","1143.744038","1056.720308","1119.294178","1064.618758","1048.372548","1135.591008","1105.256478","1120.579848"
194
+ "W193","5897.9498","5780.9284","5693.2279","5734.5912","5836.8002","5884.0807","5840.0028","5820.8905","5977.0669","6176.512","6055.0658","6023.6219","6056.9425"
195
+ "W194","483.26481","501.30087","495.69513","505.966605","490.018459","506.016419","506.059303","488.243623","494.494998","507.554868","508.900208","491.403635","498.588915"
196
+ "W195","984.43264","1007.22011","977.57691","956.13304","989.18684","978.19025","972.65437","932.85298","995.85932","965.36357","978.65649","978.76934","995.1377"
197
+ "W196","9829.0437","9985.4137","10293.6715","10359.0839","9848.4779","9843.3371","9851.5289","10083.1847","9901.8945","10063.8348","9912.0124","9947.1883","10295.3118"
198
+ "W197","3349.318574","3535.511928","3464.898719","3501.769843","3412.296678","3438.179535","3346.9989","3340.722945","3385.474665","3372.953029","3373.89518","3370.010094","3361.798917"
199
+ "W198","2494.02866","2504.633975","2516.196395","2459.919613","2562.934537","2680.654134","2526.370987","2563.657085","2549.711739","2523.642748","2295.089991","2323.388696","2404.942722"
200
+ "W199","853.411584","717.379492","739.09963","715.596157","753.511419","764.419871","759.384134","791.933924","759.283655","774.301921","848.021256","804.524732","889.772763"
201
+ "W200","1847.36095","1741.76012","1857.16426","1737.89313","1834.76119","1945.47024","1795.47817","1859.5506","1993.599","1853.21632","1763.33719","1134.09872","1700.70876"
202
+ "W201","1967","2006","2007","1992","1974","1963","1958","1941","1986","1972","1970","1973","2025"
203
+ "W202","4203","4253","4204","4217","4166","4203","4106","4133","4145","4147","4222","4175","4252"
204
+ "W203","5510","5380","5370","5260","5250","5110","5260","5050","4770","4630","4340","4720","5060"
205
+ "W204","7304","7318.9","7362.3","7295.9","7296.4","7304.9","7339.5","7345.8","7326.6","7325.2","7360.2","7382.9","7378.7"
206
+ "W205","12491.95","13013.34","13209.75","12981","13537.85","13132.64","13691.5","13967.95","13867.98","13856.59","14381.13","14510.66","14226.31"
207
+ "W206","1050.5","1093.4","1222","1223.7","1043.1","930.1","1070.1","1341","966.5","1052.4","938.6","1193.9","1217.3"
208
+ "W207","11744.3","11322.8","11200.7","12733.5","14568.8","20567.8","16736.3","13883.6","13380.5","13299.4","15151.2","14482.4","14773.8"
209
+ "W208","29018.3","25380.9","24903.8","27184.9","23801.6","22076","21381.2","24022.8","28176.7","30520","23750.7","25196.6","29393"
210
+ "W209","4185.45","4185.45","4185.45","4185.45","4185.45","3944.18","3944.18","3944.18","3944.18","3944.18","3944.18","3944.18","3944.18"
211
+ "W210","8204.21","8496.79","8690.18","8796.65","9308.62","8889.88","9285.78","9353.08","9523.59","9788.87","9638.59","9552.7","9535.58"
212
+ "W211","19138.99","19322.67","19505.88","19671.03","19823.39","19748.99","19869.76","20103.52","20304.31","20497.56","20640.15","20839.37","21017.61"
213
+ "W212","5461.05","5465.87","5464.51","5474.37","5898.01","5958.73","5908.25","5890.65","5889.96","5886.71","5863.1","5846.83","5831.24"
214
+ "W213","3232.2","3400.7","3606","3686.4","3553.5","3071.9","3189.4","3471.8","3135.3","3578.1","3467.9","3637.1","3466"
215
+ "W214","1824.5","1743.9","1794","1809.4","1827.3","1908.2","1847","1751.8","1898.1","1880.5","1978.2","1886.8","1902.4"
216
+ "W215","2366.43","2369.53","2390.59","2419.53","2406.42","2491.02","2478.39","2482.26","2470.39","2449.71","2447.86","2447.49","2469.36"
217
+ "W216","3140.68","3143.77","3165.33","3194.56","3177.71","3278.54","3265.51","3269.02","3257.17","3236.28","3234.39","3231.67","3247.35"
218
+ "W217","9877.09","9890.943","9890","9800","9897","9700","9000","9700","9907.7","9907","9700","9910","9900"
219
+ "W218","9932.821","9936.393","9000","9700","9943","9700","9940.001","9700","9948.552","9951.79","9000","9945","9000"
220
+ "W219","1091","1084","1077","1070","1063","1056","1092","1085","1078","1071","1064","1057","1050"
221
+ "W220","1820","1813","1806","1799","1792","1785","1820","1813","1806","1799","1792","1785","1778"
222
+ "W221","7743.8","7448.1","6802.5","6961.8","7160.7","7338.7","6925.2","7142.4","7651.6","7691.4","7589.4","7706.7","7590.4"
223
+ "W222","7727.1","7515.8","7002.1","7245.4","7361.2","7562.2","7879.1","8082.1","8025.9","8605","8503.5","8405.7","8763.5"
224
+ "W223","26994","26695","28522","27712","27511","27231","27409","27856","27639","26848","27687","27342","27494"
225
+ "W224","1706.741","1739.879","1741.962","1763.156","1740.79","1857.234","1746.231","2071.371","1997.189","2097.67","2111.099","2253.543","2460.929"
226
+ "W225","1930.9","1929.5","1958.5","1935.1","1926.6","2119.8","1982.7","2016.5","2039.6","2311.4","2792.7","2319.2","2147.7"
227
+ "W226","2865.2","2798.5","2576.4","2805.7","2486","2628.3","2494.4","2560.8","2487.5","2543","2524.4","2603.3","2591.2"
228
+ "W227","1365.7","1392.9","1389.1","1439.9","1431.2","1665.4","1561.1","1606.1","1570.6","1804.3","1935.1","1801.3","1744.1"
229
+ "W228","8171","8157","8180","8378","8086","8624","8408","8643","8586","8856","8908","9137","8873"
230
+ "W229","465.53","438.43","443.29","443.83","391.38","476.51","405.11","448.58","416.24","441.32","428.96","432.6","449.93"
231
+ "W230","3117.6","2928.4","2783.4","2987.1","2533.1","3403.5","3054.5","3060.2","3022.7","3426.2","4886.7","4280.2","3533.6"
232
+ "W231","599.6","646.8","672","706.3","755.7","899.6","917.3","1036.2","1119.2","1283","1384.5","1475.1","1504.4"
233
+ "W232","1880.4","1889.2","1864.5","1961.1","2009.7","2211.1","2235","2421.8","2774","3052","3018.4","3171.4","3206.1"
234
+ "W233","1036.6","1036.6","994","1040.9","1024.8","1199.4","1106.1","1197.6","1296","1448.2","1485.5","1600.9","1695"
235
+ "W234","2117","1891.1","2178.6","1910.3","1818.5","1984.8","1958.6","1953.1","2009.7","2158.7","2640.2","2550.4","2678.7"
236
+ "W235","434.3","451.5","493.3","532.4","535.7","647.1","661.3","710.3","763.1","827.1","937.4","1003","1010.4"
237
+ "W236","3265.2","3299.2","3276.9","3338.6","3299.6","3574.8","3557.3","3591.5","3690.3","3977.2","4463.9","4347","4358.4"
238
+ "W237","1070.7","1079","1034.4","1099.3","1055.1","1200.8","1155.9","1202.1","1206.6","1378.2","1438.1","1340.1","1262.4"
239
+ "W238","430","420.4","399.8","415.2","435.7","548.9","538","631.4","714.5","831.1","925.2","953","917.2"
240
+ "W239","3696","3792","3780","3928","3927","4330","4349","4500","4765","5057","5945","5664","5722"
241
+ "W240","2586.5","2560.4","2776.9","2698.6","3157.7","2938.2","3040.8","3197.8","3367.8","3726.5","3724","3796.8","3728.7"
242
+ "W241","1040.84","1015","1098.25","1031.9","1168.21","1271.09","1311.2","1252.43","1363.56","1496.95","1592.62","1614.64","1511.61"
243
+ "W242","1475.3","1468.6","1494.2","1447.2","1643.9","1534.3","1587.3","1580.3","1927.1","2321.8","1935.5","1728.8","1741.4"
244
+ "W243","1245","1330","1551","1550","2150","2514","2855","3230","3335","3511","3890","3812","3828"
245
+ "W244","806.8","841.6","883.8","914.1","988.6","1108","1183","1226.4","1258","1325.9","1434.3","1480.2","1488.3"
246
+ "W245","2291.2","2266.7","2332","2266.8","2538.8","2482.1","2508.7","2508.2","2705.3","2972.8","3072.2","3059","2984.6"
247
+ "W246","1138.7","1098","1159.8","1136.4","1476.2","1476.6","1385.7","1392.2","1435.2","1579.6","1807.3","1769.8","1686.5"
248
+ "W247","4410.5","4581.6","4723.5","4791.8","5421","5510.5","5612.2","5829.8","6084.9","6239.6","6659.6","6860.6","6854.3"
249
+ "W248","16190.4","16913.4","16913.4","17091","17091","16552.7","16552.7","17220.2","17220.2","16522.2","16522.2","17584","17584"
250
+ "W249","21254.72","21655.06","21655.06","22030.28","22030.28","22104.93","22104.93","21822.79","21822.79","21445.17","21445.17","19654.69","19654.69"
251
+ "W250","3446.7","3205.7","3207.3","3331.9","3479.1","3223.9","3228.3","3338.6","3461.4","3330.3","3232.7","3339.3","3518.5"
252
+ "W251","6671","6836","6803","6884","6879","6940","7042","7117","7136","7116","7164","7197","7229"
253
+ "W252","8557.6","8835.3","8859","8662.7","8563.5","8897.2","8930.5","8826.7","8675.2","8886.2","9003.2","8917.6","8732.1"
254
+ "W253","12389","12551","12538","12445","12427","12637","12671","12626","12568","12702","12742","12726","12679"
255
+ "W254","5953.6","5901.6","5905.5","6044.6","6050.7","5968.3","5881.8","6015.5","6050.5","6027","5981.2","6126.6","6147.7"
256
+ "W255","1901","1898","1895","1884","1878","1866","1830","1829","1849","1870","1872","1874","1875"
257
+ "W256","2932","2986","2914.5","3028.8","3291.8","2001.3","2192.4","2057.1","2318.1","2584.8","2583.4","2598.3","2704.1"
258
+ "W257","4017.6","4017.9","4018.2","4018.8","4024.6","4026.8","4037","4037.9","4039.9","4041.8","4044.2","4044.3","4044.2"
259
+ "W258","2097.4","2102.3","2087.1","2130.7","2087.5","2024.5","1983.4","1973.8","1998.2","1975.9","1931.8","1928.8","1940.1"
260
+ "W259","6698","6436","5556","6939","6746","5285","5924","5200","6741","6343","5893","5953","5594"
261
+ "W260","583","503","527","504","643","523","522","1437","464","626","475","785","966"
262
+ "W261","4687.3","4661.5","4573.9","4712.8","4699.2","4555.3","4629.3","4557.9","4714","4676.1","4633.5","4639.5","4603.6"
263
+ "W262","4382.033","4394.197","4378.668","4378.375","4378.166","4390.312","4401.788","4382.618","4382.058","4401.802","4397.108","4379.233","4378.908"
264
+ "W263","2296","2219.6","2214.9","2044.6","2043.2","1968","1966.2","1746.1","1724.1","1701.4","1717.7","1603.2","1617"
265
+ "W264","4071.7","4155","4247.5","4042","3908.2","3923.1","3563.7","3806.9","4360.7","4563.2","4354.9","4323.4","3629"
266
+ "W265","2057.7","2149.37","2022.19","2061.06","2031.96","1904.36","1962.29","2017.42","2020.85","1963.84","1942.14","1973.34","2009.61"
267
+ "W266","2072.76","2162.78","2042.96","2108.32","2022.16","1989.55","2020.05","2071.25","2013.96","1993.61","1976.23","1966.25","2030.46"
268
+ "W267","3031.2","2833","3121.8","3662.4","3801.4","3635.9","3732.4","2989.6","4445.3","4025.2","4599.9","4366.5","3613.2"
269
+ "W268","1306.27","1382.07","1271.27","1314.49","1308.47","1194.08","1259.07","1298.64","1304.74","1285.92","1251.62","1308.33","1369.04"
270
+ "W269","1258.78","1335.44","1229.92","1305.02","1251.08","1192.61","1232.56","1280.13","1217.91","1237.3","1200.59","1223","1325.46"
271
+ "W270","7727.9","8061.2","7905.6","7710.1","7937.5","7954.8","8244.7","8336.4","8431.2","8311.7","8421.9","8301.7","8375.9"
272
+ "W271","7661.3","7754.9","7634.3","7495.5","7846.8","7899.6","8028.8","8126.1","8295.7","8169","8102.7","7990.4","8058.8"
273
+ "W272","974.11","971.93","949.5","952.55","956.53","955.04","928.78","934.69","937.57","942.2","943.45","949.25","951.63"
274
+ "W273","6542.2","6602.2","6631.8","6704.4","6867.5","6809.7","6853.3","6828.4","6750.5","6774.3","6767.8","6564.6","6614.5"
275
+ "W274","751.43","767.3","750.92","746.57","723.49","710.28","703.22","718.78","716.11","677.92","690.52","665.01","640.57"
276
+ "W275","8139.8","8273.4","8130.4","8033","7710.8","7969.4","7874.9","7911.2","7960.5","7563.1","7756.4","7432.5","7050"
277
+ "W276","11409.5","11408.2","11485.6","11820.4","11735","12147","11899.3","11303.7","11336.7","11585.8","11662.6","11191.2","10895.7"
278
+ "W277","3257.323","3271.85","3284.499","3248.314","3162.48","3345.109","3283.453","3260.614","3232.998","3257.269","3299.522","3306.657","3294.326"
279
+ "W278","1261.984","1332.9","1226.2","1267.01","1266.63","1153.61","1217.88","1256.22","1260.13","1249.54","1214.15","1271.59","1333.46"
280
+ "W279","1212.443","1284.81","1181.98","1256.58","1204.17","1146.31","1184.67","1231.88","1164.92","1195.17","1157.35","1182.41","1284.83"
281
+ "W280","5447.979","5462.14","5487.29","5483.53","5478.69","5471.28","5473.97","5487.77","5448.53","5428.96","5440.01","5462.1","5471.9"
282
+ "W281","1756","1849","1859","1774","1799","1895","1793","1730","1815","1919","1963","1977","2010"
283
+ "W282","2144.366551","2155.275445","2134.710461","2155.964119","2136.584281","2111.357748","2151.224229","2102.505934","2097.694159","2081.464003","2116.373998","2109.410461","2135.181639"
284
+ "W283","338.554154","335.62152","339.293614","334.767926","346.838264","348.526527","321.768721","345.408096","330.985181","329.278867","317.836472","335.474349","303.970852"
285
+ "W284","572.83","578.84","601.86","577.06","575.18","583.88","583.42","586.89","590.86","592.64","588.47","514.43","504.2"
286
+ "W285","692.124","687.527","691.085","705.06","706.426","697.053","691.771","695.644","716.192","722.715","721.569","721.882","739.725"
287
+ "W286","10549","10683.7","10413.7","10294.3","10494.1","10588","10359.7","10236.7","10403.6","10394.8","10697.7","10778.3","10831.8"
288
+ "W287","2187.16","2205.63","2214.48","2228.7","2312.59","2368.47","2384.82","2353.87","2374.86","2351.83","2331.47","2340.08","2351.46"
289
+ "W288","7742.4","7742.4","7747.4","7750.3","7712.9","7875.2","7871.1","7867.6","7867.8","7865.8","7865.3","7841.8","7779.9"
290
+ "W289","9790","9800","10160","9960","9920","9710","10070","9970","10000","9770","9840","9840","9750"
291
+ "W290","15593.6","15096.7","13956.7","14349.8","14700.6","15026.8","14956.2","15354.3","15842.4","16424.9","16261.7","16238.3","16523.1"
292
+ "W291","7931","7981.7","7898.8","8095.8","7630.4","8375.1","8041","8141.3","7431.5","7448.9","7899","7716.7","7367"
293
+ "W292","17618.9","17397.7","17308.2","17341.1","16890.5","17910.6","17573.8","17610.4","16916.2","16970.8","17256","16901","15827.8"
294
+ "W293","1873","1748","3351","2824","2482","1905","1900","1817","1985","1709","1949","1641","1941"
295
+ "W294","1255.8","1480.5","1196.1","1618.8","1318.5","1427.8","1197.9","1301.9","996","1153.8","1343.5","1270.2","1210.1"
296
+ "W295","3090","3218","3071","3054","3124","3185","3205","3158","3235","3157","3002","3736","3258"
297
+ "W296","3427","2655","2957","3214","2609","1843","2775","3629","3045","2894","2500","3242","3062"
298
+ "W297","5520","4651","4787","5396","4341","3130","4231","4702","5425","4859","4622","5740","5070"
299
+ "W298","3806","2828","3516","3576","2979","1842","3514","4137","3570","3348","3245","4174","3695"
300
+ "W299","3952","3296","3338","3522","3138","2128","2955","3950","3260","3003","2981","3182","3275"
301
+ "W300","2374","1834","1881","1979","1699","1155","1765","2300","1820","1777","1643","1919","1892"
302
+ "W301","3392","2854","2689","2867","2815","2185","2861","2947","2624","2433","2365","2874","2649"
303
+ "W302","3121","2733","2775","3159","2695","2770","3158","2702","3445","3021","3534","4221","3385"
304
+ "W303","5797","4859","5145","5401","4446","3132","4343","5185","5489","5312","5005","6408","5480"
305
+ "W304","3253","2595","2754","2963","2634","1609","2817","3632","3136","2790","2432","3129","3024"
306
+ "W305","3358","2770","2791","2943","2574","1892","2428","3243","2828","2730","2560","3425","2957"
307
+ "W306","2558","2048","2187","2435","1885","1225","2385","3028","2388","2236","2135","2696","2497"
308
+ "W307","3460","3044","3400","3706","2874","2274","2800","3050","3473","2985","3099","3948","3311"
309
+ "W308","2573","2124","2185","2420","1932","1351","1891","2117","2349","2289","2047","2592","2279"
310
+ "W309","7378","6631","6561","7425","7372","6258","8446","7995","6836","6070","5230","6259","6478"
311
+ "W310","5183","4647","5227","5734","5549","4025","5253","6645","5251","4651","4383","5649","5316"
312
+ "W311","5535","4926","4816","5557","5197","3914","4856","5834","5014","4424","3970","4932","4835"
313
+ "W312","4895","4096","4309","4967","5689","3706","4645","4642","4373","3965","3795","4215","4198"
314
+ "W313","2115","2203","2102","2090","2138","2180","2194","2162","2214","2161","2055","2557","2230"
315
+ "W314","9441","8986","8449","9450","9509","7357","8471","9510","8440","8063","7471","8021","8301"
316
+ "W315","5316","5008","5037","5095","5226","3920","4911","5467","4717","4589","4346","4605","4745"
317
+ "W316","3210","2787","2867","3217","3157","2388","2739","3447","3038","2883","2938","3259","3113"
318
+ "W317","4508","3991","3886","4226","4297","3077","3922","4642","4072","3787","3349","4066","3983"
319
+ "W318","3655","3808","3633","3613","3695","3768","3792","3736","3827","3735","3552","4420","3854"
320
+ "W319","2547","2678","2381","2733","2894","1965","2326","3097","2363","2247","2099","2524","2466"
321
+ "W320","2572","2584","2776","3036","3647","2357","2695","3086","2379","1989","1862","2062","2276"
322
+ "W321","6172","5932","5665","6110","5342","4145","5273","5132","6994","6018","5745","6651","6108"
323
+ "W322","2626","2370","2401","2830","3284","2103","2352","2865","2483","2381","2317","2530","2515"
324
+ "W323","5324","4761","4912","5730","6273","4394","5408","6709","5429","5251","4661","5976","5605"
325
+ "W324","3057","2539","2419","2741","2917","2209","2798","3546","3183","3140","2856","3481","3241"
326
+ "W325","3215","3336","3183","3166","3238","3301","3322","3274","3354","3273","3112","3872","3377"
327
+ "W326","3852","3618","3667","4098","3958","2932","3769","4767","4198","3692","3344","4281","2884"
328
+ "W327","9791","8592","8763","9250","7963","6586","9213","11666","10414","9639","9268","11945","10586"
329
+ "W328","2826","2377","2613","2854","2688","1744","2439","3290","2774","2447","2203","2916","2726"
330
+ "W329","11786","9960","9820","10747","10786","7966","9968","12266","10587","8995","8982","11725","10511"
331
+ "W330","3899","3130","3356","3606","3534","2475","3407","4577","4188","3678","3283","4192","3984"
332
+ "W331","6871","5873","6170","6691","6293","4403","6148","7417","6375","5700","5761","6663","7417"
333
+ "W332","3191","3095","2805","3275","3230","2471","3448","4534","3545","3233","3179","3157","3530"
334
+ "W333","2774","2724","2885","3325","3629","2345","3088","3348","2915","2738","2573","3229","2961"
335
+ "W334","4019","3939","4008","4178","4154","3076","3917","4989","4219","4007","3832","4543","4318"
336
+ "W335","5003","4375","4539","4653","5501","3930","4734","5515","5429","4309","3932","4650","4767"
337
+ "W336","1984","1637","2053","2085","2375","1537","2031","2317","1744","1468","1250","1730","1702"
338
+ "W337","1682","2188","1939","2411","2429","1734","1835","2171","1857","1650","1562","2022","1350"
339
+ "W338","2838","2690","2470","3041","3116","2190","2571","3022","2481","2196","2106","2391","2614"
340
+ "W339","4385","3907","3606","3890","3564","2431","3703","4398","3699","3437","3386","4085","3801"
341
+ "W340","7431","6881","7184","7855","7520","6190","7526","7912","7010","6089","6089","6444","6709"
342
+ "W341","5896","5259","5643","6241","6242","4539","5038","6269","5306","4688","4423","5340","5205"
343
+ "W342","1717.1","1304.4","1335.5","1379.2","1276.9","1007.2","1323","1575.6","1307","1264.5","1118.6","1349.1","1323"
344
+ "W343","4274","3596","3596","4510","4182","2859","3579","4723","4013","3409","3143","4375","3933"
345
+ "W344","3730","3182","3165","3463","3098","1955","2964","4003","3550","3348","3014","3879","3559"
346
+ "W345","3807","3360","3266","3406","3122","2212","2982","3988","3389","3160","3126","3584","3988"
347
+ "W346","7012","5532","5818","5845","5240","3442","5592","6704","5512","5418","5004","5866","5701"
348
+ "W347","4217","3343","3621","3938","3702","2575","3277","4485","3790","3487","3158","4196","3823"
349
+ "W348","4952","4054","3637","4411","4431","2762","3708","4300","5288","4673","3860","5166","4657"
350
+ "W349","3449","3593","3428","3409","3487","3556","3578","3526","3612","3525","3352","4171","3637"
351
+ "W350","6407","5330","5444","5744","4513","3424","6361","6498","5367","5326","5042","6151","5677"
352
+ "W351","2083","1704","1622","1626","1306","1010","1719","2189","1661","1753","1591","2057","1850"
353
+ "W352","1339","1395","1331","1324","1354","1381","1389","1369","1402","1369","1302","1619","1412"
354
+ "W353","2197","1711","1661","1535","1444","953","1526","1840","1722","1540","1534","1671","1661"
355
+ "W354","3513","3203","2907","3031","2335","1945","3178","4175","3651","3512","3573","4511","3884"
356
+ "W355","2777","2893","2760","2745","2807","2862","2880","2838","2908","2838","2698","3357","2928"
357
+ "W356","4387","3772","3660","3878","3131","2405","4081","4811","4170","4017","3824","4661","4297"
358
+ "W357","3938","3363","3249","3515","3097","2275","3421","4166","3807","3434","3303","4024","3747"
359
+ "W358","6907","5711","5496","5862","5380","3660","5159","6787","5596","5445","5345","6156","5866"
360
+ "W359","4458","5098","4518","4973","3973","2613","3476","4213","3386","3627","3299","3743","3965"
dataset/m4/Weekly-train.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/Yearly-test.csv ADDED
The diff for this file is too large to render. See raw diff
 
dataset/m4/Yearly-train.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6aa668d4f04654c0bf5f07b266a51c556f0df2faf4709f852653755396a249e4
3
+ size 25355736
dataset/m4/submission-Naive2.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9610c625d6dee654ec458dcd07fb39373d065efc33b4f17d94c50b2aef00ef9
3
+ size 23409576
dataset/m4/test.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:134249cd801d51507af9de3c77b4c789da4de96cb5a2d2fe2988cac138820d32
3
+ size 20778257
dataset/m4/training.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1799e5334656df0e7ff68ea016f555aa23e5460890e5f92a59f907deba5a7914
3
+ size 270896295
dataset/poly/polymarket_data_processed_Crypto_test.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8197ef536a3f9ef9399dc3205d39ca5cd3ac3c38bb5df2033ef74198402f2b91
3
+ size 34744628
dataset/poly/polymarket_data_processed_Election_test.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2ccfdbf65e9fc4f173daa3cfdc1ddf771a064319beb82c723c014b30ff0a1a3
3
+ size 30018418
dataset/poly/polymarket_data_processed_Other_test.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:200fb59b5a8d640d9e27e5fe88a2e323f157ab905e67f5209f95b38234873fe7
3
+ size 70988203
dataset/poly/polymarket_data_processed_Politics_test.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7468321988e461dfdb449526db0c67274e77180c78d9d999cb7a075c683fc01
3
+ size 186831387
dataset/poly/polymarket_data_processed_Sports_test.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4574cd57806a46ba5b8576a0ed5467fa81c6ed82a86188d255ef3ddd6bbf9dc5
3
+ size 27514781
dataset/poly/polymarket_data_processed_dev.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:260d436493fe1b4b69b1d5f35808c1ce90088cf5a51bc7a6433b96039baef23f
3
+ size 456387925
dataset/poly/polymarket_data_processed_test.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fca43646308d0b5c098f5ac529b9b1bb5703f334c26c327e2c3696de32602625
3
+ size 420577190
dataset/poly/polymarket_data_processed_train.jsonl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a08508154ae529fad7a4cbb174ad47d6702e99bae04cf18b7310c52c6ff3f4c4
3
+ size 2609542664
exp/__init__.py ADDED
File without changes
exp/exp_anomaly_detection.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_provider.data_factory import data_provider
2
+ from exp.exp_basic import Exp_Basic
3
+ from utils.tools import EarlyStopping, adjust_learning_rate, adjustment
4
+ from sklearn.metrics import precision_recall_fscore_support
5
+ from sklearn.metrics import accuracy_score
6
+ import torch.multiprocessing
7
+
8
+ torch.multiprocessing.set_sharing_strategy('file_system')
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import optim
12
+ import os
13
+ import time
14
+ import warnings
15
+ import numpy as np
16
+
17
+ warnings.filterwarnings('ignore')
18
+
19
+
20
class Exp_Anomaly_Detection(Exp_Basic):
    """Reconstruction-based anomaly detection experiment.

    The model is trained to reconstruct its input windows; at test time,
    timesteps whose reconstruction error exceeds a percentile threshold
    (computed over the pooled train+test error distribution) are flagged
    as anomalies, then scored with the point-adjust protocol.
    """

    def __init__(self, args):
        super(Exp_Anomaly_Detection, self).__init__(args)

    def _build_model(self):
        # Resolve the model class by its CLI name via the registry built in Exp_Basic.
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Return the average reconstruction loss over *vali_loader*."""
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, _) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)

                outputs = self.model(batch_x, None, None, None)

                # For 'MS' (multivariate input, univariate target) score only
                # the last channel; otherwise keep all channels.
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, :, f_dim:]
                pred = outputs.detach().cpu()
                # Fix: slice the target the same way as the prediction so the
                # MSE is not broadcast across mismatched channel counts when
                # features == 'MS' (same slicing as exp_imputation).
                true = batch_x[:, :, f_dim:].detach().cpu()

                loss = criterion(pred, true)
                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train the reconstruction model with early stopping on the vali loss.

        Checkpoints go to ``args.checkpoints/setting``; the best checkpoint is
        reloaded before returning the model.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)

                outputs = self.model(batch_x, None, None, None)

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, :, f_dim:]
                # Fix: compare against the same channel slice as the output
                # (the original compared the sliced output against the full
                # input for 'MS', broadcasting the MSE across channels).
                loss = criterion(outputs, batch_x[:, :, f_dim:])
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Score train/test sets, derive a percentile threshold, and report
        point-adjusted accuracy / precision / recall / F1.

        Results are appended to ``result_anomaly_detection.txt``.
        """
        test_data, test_loader = self._get_data(flag='test')
        train_data, train_loader = self._get_data(flag='train')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        attens_energy = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        # Fix: 'reduce=False' is deprecated in PyTorch; reduction='none' is
        # the supported spelling and keeps the per-element error map.
        self.anomaly_criterion = nn.MSELoss(reduction='none')

        # (1) statistic on the train set
        with torch.no_grad():
            for i, (batch_x, batch_y) in enumerate(train_loader):
                batch_x = batch_x.float().to(self.device)
                # reconstruction
                outputs = self.model(batch_x, None, None, None)
                # per-timestep anomaly score = feature-mean squared error
                score = torch.mean(self.anomaly_criterion(batch_x, outputs), dim=-1)
                score = score.detach().cpu().numpy()
                attens_energy.append(score)

        attens_energy = np.concatenate(attens_energy, axis=0).reshape(-1)
        train_energy = np.array(attens_energy)

        # (2) find the threshold
        attens_energy = []
        test_labels = []
        # Fix: wrap inference in no_grad (the original tracked gradients here,
        # wasting memory for a pure scoring pass).
        with torch.no_grad():
            for i, (batch_x, batch_y) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                # reconstruction
                outputs = self.model(batch_x, None, None, None)
                # criterion
                score = torch.mean(self.anomaly_criterion(batch_x, outputs), dim=-1)
                score = score.detach().cpu().numpy()
                attens_energy.append(score)
                test_labels.append(batch_y)

        attens_energy = np.concatenate(attens_energy, axis=0).reshape(-1)
        test_energy = np.array(attens_energy)
        combined_energy = np.concatenate([train_energy, test_energy], axis=0)
        # anomaly_ratio is a percentage: keep the top args.anomaly_ratio% as anomalous.
        threshold = np.percentile(combined_energy, 100 - self.args.anomaly_ratio)
        print("Threshold :", threshold)

        # (3) evaluation on the test set
        pred = (test_energy > threshold).astype(int)
        test_labels = np.concatenate(test_labels, axis=0).reshape(-1)
        test_labels = np.array(test_labels)
        gt = test_labels.astype(int)

        print("pred: ", pred.shape)
        print("gt: ", gt.shape)

        # (4) detection adjustment (point-adjust protocol from utils.tools)
        gt, pred = adjustment(gt, pred)

        pred = np.array(pred)
        gt = np.array(gt)
        print("pred: ", pred.shape)
        print("gt: ", gt.shape)

        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, support = precision_recall_fscore_support(gt, pred, average='binary')
        print("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(
            accuracy, precision,
            recall, f_score))

        # 'with' guarantees the handle is closed even if a write fails.
        with open("result_anomaly_detection.txt", 'a') as f:
            f.write(setting + " \n")
            f.write("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(
                accuracy, precision,
                recall, f_score))
            f.write('\n')
            f.write('\n')
        return
exp/exp_basic.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
4
+ Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
5
+ Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, TemporalFusionTransformer, SCINet, PAttn, TimeXer, \
6
+ WPMixer, MultiPatchFormer
7
+
8
+
9
class Exp_Basic(object):
    """Base class for all experiments.

    Holds the parsed CLI ``args``, a registry mapping model names to their
    modules, the selected torch device, and the built model. Subclasses
    implement ``_build_model`` plus the task-specific ``_get_data`` /
    ``vali`` / ``train`` / ``test`` hooks.
    """

    def __init__(self, args):
        self.args = args
        # Registry: --model CLI name -> module whose `Model` class the
        # subclass's _build_model() instantiates.
        self.model_dict = {
            'TimesNet': TimesNet,
            'Autoformer': Autoformer,
            'Transformer': Transformer,
            'Nonstationary_Transformer': Nonstationary_Transformer,
            'DLinear': DLinear,
            'FEDformer': FEDformer,
            'Informer': Informer,
            'LightTS': LightTS,
            'Reformer': Reformer,
            'ETSformer': ETSformer,
            'PatchTST': PatchTST,
            'Pyraformer': Pyraformer,
            'MICN': MICN,
            'Crossformer': Crossformer,
            'FiLM': FiLM,
            'iTransformer': iTransformer,
            'Koopa': Koopa,
            'TiDE': TiDE,
            'FreTS': FreTS,
            'MambaSimple': MambaSimple,
            'TimeMixer': TimeMixer,
            'TSMixer': TSMixer,
            'SegRNN': SegRNN,
            'TemporalFusionTransformer': TemporalFusionTransformer,
            "SCINet": SCINet,
            'PAttn': PAttn,
            'TimeXer': TimeXer,
            'WPMixer': WPMixer,
            'MultiPatchFormer': MultiPatchFormer
        }
        # Mamba needs the optional mamba_ssm package, so import it lazily
        # only when the user actually selects it.
        if args.model == 'Mamba':
            print('Please make sure you have successfully installed mamba_ssm')
            from models import Mamba
            self.model_dict['Mamba'] = Mamba

        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        # Subclasses must construct and return the task model.
        # (Fix: removed the unreachable `return None` that followed the raise.)
        raise NotImplementedError

    def _acquire_device(self):
        """Select the torch device (cuda / mps / cpu) from ``self.args``."""
        if self.args.use_gpu and self.args.gpu_type == 'cuda':
            # Restrict visible devices before creating the device handle.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(
                self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        elif self.args.use_gpu and self.args.gpu_type == 'mps':
            device = torch.device('mps')
            print('Use GPU: mps')
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _get_data(self):
        pass

    def vali(self):
        pass

    def train(self):
        pass

    def test(self):
        pass
exp/exp_classification.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_provider.data_factory import data_provider
2
+ from exp.exp_basic import Exp_Basic
3
+ from utils.tools import EarlyStopping, adjust_learning_rate, cal_accuracy
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import optim
7
+ import os
8
+ import time
9
+ import warnings
10
+ import numpy as np
11
+ import pdb
12
+
13
+ warnings.filterwarnings('ignore')
14
+
15
+
16
class Exp_Classification(Exp_Basic):
    """Whole-sequence classification experiment (UEA-style datasets)."""

    def __init__(self, args):
        super(Exp_Classification, self).__init__(args)

    def _build_model(self):
        # Model input dimensions depend on the data, so peek at the
        # train/test datasets before instantiating the model.
        train_data, train_loader = self._get_data(flag='TRAIN')
        test_data, test_loader = self._get_data(flag='TEST')
        self.args.seq_len = max(train_data.max_seq_len, test_data.max_seq_len)
        self.args.pred_len = 0
        self.args.enc_in = train_data.feature_df.shape[1]
        self.args.num_class = len(train_data.class_names)
        # model init
        model = self.model_dict[self.args.model].Model(self.args).float()
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        # model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        model_optim = optim.RAdam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        criterion = nn.CrossEntropyLoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Return ``(mean_loss, accuracy)`` over *vali_loader*."""
        total_loss = []
        preds = []
        trues = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, label, padding_mask) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)

                pred = outputs.detach().cpu()
                # NOTE(review): squeeze() with no dim also drops the batch axis
                # when a batch holds a single sample — confirm loaders never
                # yield batch size 1 here.
                loss = criterion(pred, label.long().squeeze().cpu())
                total_loss.append(loss)

                preds.append(outputs.detach())
                trues.append(label)

        total_loss = np.average(total_loss)

        preds = torch.cat(preds, 0)
        trues = torch.cat(trues, 0)
        # Fix: pass dim explicitly; implicit-dim softmax is deprecated.
        probs = torch.nn.functional.softmax(preds, dim=1)  # (total_samples, num_classes) est. prob. for each class and sample
        predictions = torch.argmax(probs, dim=1).cpu().numpy()  # (total_samples,) int class index for each sample
        trues = trues.flatten().cpu().numpy()
        accuracy = cal_accuracy(predictions, trues)

        self.model.train()
        return total_loss, accuracy

    def train(self, setting):
        """Train the classifier, early-stopping on validation accuracy.

        NOTE: validation reuses the TEST split (no separate val split here).
        """
        train_data, train_loader = self._get_data(flag='TRAIN')
        vali_data, vali_loader = self._get_data(flag='TEST')
        test_data, test_loader = self._get_data(flag='TEST')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()

            for i, (batch_x, label, padding_mask) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)
                loss = criterion(outputs, label.long().squeeze(-1))
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                # Clip gradients to stabilise training of the classifier.
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=4.0)
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss, val_accuracy = self.vali(vali_data, vali_loader, criterion)
            test_loss, test_accuracy = self.vali(test_data, test_loader, criterion)

            print(
                "Epoch: {0}, Steps: {1} | Train Loss: {2:.3f} Vali Loss: {3:.3f} Vali Acc: {4:.3f} Test Loss: {5:.3f} Test Acc: {6:.3f}"
                .format(epoch + 1, train_steps, train_loss, vali_loss, val_accuracy, test_loss, test_accuracy))
            # EarlyStopping minimises its score, so negate accuracy to
            # early-stop on maximum validation accuracy.
            early_stopping(-val_accuracy, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Evaluate accuracy on the TEST split; append the result to
        ``results/<setting>/result_classification.txt``."""
        test_data, test_loader = self._get_data(flag='TEST')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))

        preds = []
        trues = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, label, padding_mask) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)

                preds.append(outputs.detach())
                trues.append(label)

        preds = torch.cat(preds, 0)
        trues = torch.cat(trues, 0)
        print('test shape:', preds.shape, trues.shape)

        # Fix: pass dim explicitly; implicit-dim softmax is deprecated.
        probs = torch.nn.functional.softmax(preds, dim=1)  # (total_samples, num_classes) est. prob. for each class and sample
        predictions = torch.argmax(probs, dim=1).cpu().numpy()  # (total_samples,) int class index for each sample
        trues = trues.flatten().cpu().numpy()
        accuracy = cal_accuracy(predictions, trues)

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        print('accuracy:{}'.format(accuracy))
        file_name = 'result_classification.txt'
        # 'with' guarantees the handle is closed even if a write fails.
        with open(os.path.join(folder_path, file_name), 'a') as f:
            f.write(setting + " \n")
            f.write('accuracy:{}'.format(accuracy))
            f.write('\n')
            f.write('\n')
        return
exp/exp_imputation.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_provider.data_factory import data_provider
2
+ from exp.exp_basic import Exp_Basic
3
+ from utils.tools import EarlyStopping, adjust_learning_rate, visual
4
+ from utils.metrics import metric
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import optim
8
+ import os
9
+ import time
10
+ import warnings
11
+ import numpy as np
12
+
13
+ warnings.filterwarnings('ignore')
14
+
15
+
16
+ class Exp_Imputation(Exp_Basic):
17
+ def __init__(self, args):
18
+ super(Exp_Imputation, self).__init__(args)
19
+
20
+ def _build_model(self):
21
+ model = self.model_dict[self.args.model].Model(self.args).float()
22
+
23
+ if self.args.use_multi_gpu and self.args.use_gpu:
24
+ model = nn.DataParallel(model, device_ids=self.args.device_ids)
25
+ return model
26
+
27
+ def _get_data(self, flag):
28
+ data_set, data_loader = data_provider(self.args, flag)
29
+ return data_set, data_loader
30
+
31
+ def _select_optimizer(self):
32
+ model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
33
+ return model_optim
34
+
35
+ def _select_criterion(self):
36
+ criterion = nn.MSELoss()
37
+ return criterion
38
+
39
+ def vali(self, vali_data, vali_loader, criterion):
40
+ total_loss = []
41
+ self.model.eval()
42
+ with torch.no_grad():
43
+ for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
44
+ batch_x = batch_x.float().to(self.device)
45
+ batch_x_mark = batch_x_mark.float().to(self.device)
46
+
47
+ # random mask
48
+ B, T, N = batch_x.shape
49
+ """
50
+ B = batch size
51
+ T = seq len
52
+ N = number of features
53
+ """
54
+ mask = torch.rand((B, T, N)).to(self.device)
55
+ mask[mask <= self.args.mask_rate] = 0 # masked
56
+ mask[mask > self.args.mask_rate] = 1 # remained
57
+ inp = batch_x.masked_fill(mask == 0, 0)
58
+
59
+ outputs = self.model(inp, batch_x_mark, None, None, mask)
60
+
61
+ f_dim = -1 if self.args.features == 'MS' else 0
62
+ outputs = outputs[:, :, f_dim:]
63
+
64
+ # add support for MS
65
+ batch_x = batch_x[:, :, f_dim:]
66
+ mask = mask[:, :, f_dim:]
67
+
68
+ pred = outputs.detach().cpu()
69
+ true = batch_x.detach().cpu()
70
+ mask = mask.detach().cpu()
71
+
72
+ loss = criterion(pred[mask == 0], true[mask == 0])
73
+ total_loss.append(loss)
74
+ total_loss = np.average(total_loss)
75
+ self.model.train()
76
+ return total_loss
77
+
78
+ def train(self, setting):
79
+ train_data, train_loader = self._get_data(flag='train')
80
+ vali_data, vali_loader = self._get_data(flag='val')
81
+ test_data, test_loader = self._get_data(flag='test')
82
+
83
+ path = os.path.join(self.args.checkpoints, setting)
84
+ if not os.path.exists(path):
85
+ os.makedirs(path)
86
+
87
+ time_now = time.time()
88
+
89
+ train_steps = len(train_loader)
90
+ early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
91
+
92
+ model_optim = self._select_optimizer()
93
+ criterion = self._select_criterion()
94
+
95
+ for epoch in range(self.args.train_epochs):
96
+ iter_count = 0
97
+ train_loss = []
98
+
99
+ self.model.train()
100
+ epoch_time = time.time()
101
+ for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
102
+ iter_count += 1
103
+ model_optim.zero_grad()
104
+
105
+ batch_x = batch_x.float().to(self.device)
106
+ batch_x_mark = batch_x_mark.float().to(self.device)
107
+
108
+ # random mask
109
+ B, T, N = batch_x.shape
110
+ mask = torch.rand((B, T, N)).to(self.device)
111
+ mask[mask <= self.args.mask_rate] = 0 # masked
112
+ mask[mask > self.args.mask_rate] = 1 # remained
113
+ inp = batch_x.masked_fill(mask == 0, 0)
114
+
115
+ outputs = self.model(inp, batch_x_mark, None, None, mask)
116
+
117
+ f_dim = -1 if self.args.features == 'MS' else 0
118
+ outputs = outputs[:, :, f_dim:]
119
+
120
+ # add support for MS
121
+ batch_x = batch_x[:, :, f_dim:]
122
+ mask = mask[:, :, f_dim:]
123
+
124
+ loss = criterion(outputs[mask == 0], batch_x[mask == 0])
125
+ train_loss.append(loss.item())
126
+
127
+ if (i + 1) % 100 == 0:
128
+ print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
129
+ speed = (time.time() - time_now) / iter_count
130
+ left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
131
+ print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
132
+ iter_count = 0
133
+ time_now = time.time()
134
+
135
+ loss.backward()
136
+ model_optim.step()
137
+
138
+ print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
139
+ train_loss = np.average(train_loss)
140
+ vali_loss = self.vali(vali_data, vali_loader, criterion)
141
+ test_loss = self.vali(test_data, test_loader, criterion)
142
+
143
+ print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
144
+ epoch + 1, train_steps, train_loss, vali_loss, test_loss))
145
+ early_stopping(vali_loss, self.model, path)
146
+ if early_stopping.early_stop:
147
+ print("Early stopping")
148
+ break
149
+ adjust_learning_rate(model_optim, epoch + 1, self.args)
150
+
151
+ best_model_path = path + '/' + 'checkpoint.pth'
152
+ self.model.load_state_dict(torch.load(best_model_path))
153
+
154
+ return self.model
155
+
156
+ def test(self, setting, test=0):
157
+ test_data, test_loader = self._get_data(flag='test')
158
+ if test:
159
+ print('loading model')
160
+ self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))
161
+
162
+ preds = []
163
+ trues = []
164
+ masks = []
165
+ folder_path = './test_results/' + setting + '/'
166
+ if not os.path.exists(folder_path):
167
+ os.makedirs(folder_path)
168
+
169
+ self.model.eval()
170
+ with torch.no_grad():
171
+ for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
172
+ batch_x = batch_x.float().to(self.device)
173
+ batch_x_mark = batch_x_mark.float().to(self.device)
174
+
175
+ # random mask
176
+ B, T, N = batch_x.shape
177
+ mask = torch.rand((B, T, N)).to(self.device)
178
+ mask[mask <= self.args.mask_rate] = 0 # masked
179
+ mask[mask > self.args.mask_rate] = 1 # remained
180
+ inp = batch_x.masked_fill(mask == 0, 0)
181
+
182
+ # imputation
183
+ outputs = self.model(inp, batch_x_mark, None, None, mask)
184
+
185
+ # eval
186
+ f_dim = -1 if self.args.features == 'MS' else 0
187
+ outputs = outputs[:, :, f_dim:]
188
+
189
+ # add support for MS
190
+ batch_x = batch_x[:, :, f_dim:]
191
+ mask = mask[:, :, f_dim:]
192
+
193
+ outputs = outputs.detach().cpu().numpy()
194
+ pred = outputs
195
+ true = batch_x.detach().cpu().numpy()
196
+ preds.append(pred)
197
+ trues.append(true)
198
+ masks.append(mask.detach().cpu())
199
+
200
+ if i % 20 == 0:
201
+ filled = true[0, :, -1].copy()
202
+ filled = filled * mask[0, :, -1].detach().cpu().numpy() + \
203
+ pred[0, :, -1] * (1 - mask[0, :, -1].detach().cpu().numpy())
204
+ visual(true[0, :, -1], filled, os.path.join(folder_path, str(i) + '.pdf'))
205
+
206
+ preds = np.concatenate(preds, 0)
207
+ trues = np.concatenate(trues, 0)
208
+ masks = np.concatenate(masks, 0)
209
+ print('test shape:', preds.shape, trues.shape)
210
+
211
+ # result save
212
+ folder_path = './results/' + setting + '/'
213
+ if not os.path.exists(folder_path):
214
+ os.makedirs(folder_path)
215
+
216
+ mae, mse, rmse, mape, mspe = metric(preds[masks == 0], trues[masks == 0])
217
+ print('mse:{}, mae:{}'.format(mse, mae))
218
+ f = open("result_imputation.txt", 'a')
219
+ f.write(setting + " \n")
220
+ f.write('mse:{}, mae:{}'.format(mse, mae))
221
+ f.write('\n')
222
+ f.write('\n')
223
+ f.close()
224
+
225
+ np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))
226
+ np.save(folder_path + 'pred.npy', preds)
227
+ np.save(folder_path + 'true.npy', trues)
228
+ return
exp/exp_long_term_forecasting.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_provider.data_factory import data_provider
2
+ from exp.exp_basic import Exp_Basic
3
+ from utils.tools import EarlyStopping, adjust_learning_rate, visual
4
+ from utils.metrics import metric
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import optim
8
+ import os
9
+ import time
10
+ import warnings
11
+ import numpy as np
12
+ from utils.dtw_metric import dtw, accelerated_dtw
13
+ from utils.augmentation import run_augmentation, run_augmentation_single
14
+
15
+ warnings.filterwarnings('ignore')
16
+
17
+
18
class Exp_Long_Term_Forecast(Exp_Basic):
    """Experiment wrapper for long-term time-series forecasting.

    Handles model construction, training (with optional AMP), validation,
    and testing (metric computation, optional DTW, result dumps and plots).
    """

    def __init__(self, args):
        super(Exp_Long_Term_Forecast, self).__init__(args)

    def _build_model(self):
        # Instantiate the model class registered in Exp_Basic.model_dict.
        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Return the average criterion loss over a validation loader."""
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: known label_len values followed by zeros for pred_len
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                # for 'MS' keep only the target (last) channel
                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                loss = criterion(pred, true)

                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train the model; early-stop on validation loss, restore the best checkpoint."""
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: known label_len values followed by zeros for pred_len
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                # encoder - decoder; under AMP the loss is computed inside autocast
                f_dim = -1 if self.args.features == 'MS' else 0
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                        outputs = outputs[:, -self.args.pred_len:, f_dim:]
                        batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                        loss = criterion(outputs, batch_y)
                        train_loss.append(loss.item())
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                    outputs = outputs[:, -self.args.pred_len:, f_dim:]
                    batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
                    loss = criterion(outputs, batch_y)
                    train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting, test=0):
        """Evaluate on the test split; save metrics, predictions and sample plots."""
        test_data, test_loader = self._get_data(flag='test')
        if test:
            print('loading model')
            # map_location keeps CPU-only evaluation working for GPU-saved checkpoints
            self.model.load_state_dict(
                torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth'),
                           map_location=self.device))

        preds = []
        trues = []
        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: known label_len values followed by zeros for pred_len
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
                # encoder - decoder
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, :]
                batch_y = batch_y[:, -self.args.pred_len:, :].to(self.device)
                outputs = outputs.detach().cpu().numpy()
                batch_y = batch_y.detach().cpu().numpy()
                if test_data.scale and self.args.inverse:
                    shape = batch_y.shape
                    # channel-independent models may emit fewer channels; tile to match
                    if outputs.shape[-1] != batch_y.shape[-1]:
                        outputs = np.tile(outputs, [1, 1, int(batch_y.shape[-1] / outputs.shape[-1])])
                    outputs = test_data.inverse_transform(outputs.reshape(shape[0] * shape[1], -1)).reshape(shape)
                    batch_y = test_data.inverse_transform(batch_y.reshape(shape[0] * shape[1], -1)).reshape(shape)

                outputs = outputs[:, :, f_dim:]
                batch_y = batch_y[:, :, f_dim:]

                pred = outputs
                true = batch_y

                preds.append(pred)
                trues.append(true)
                if i % 20 == 0:
                    # renamed from `input` to avoid shadowing the builtin
                    input_seq = batch_x.detach().cpu().numpy()
                    if test_data.scale and self.args.inverse:
                        shape = input_seq.shape
                        input_seq = test_data.inverse_transform(input_seq.reshape(shape[0] * shape[1], -1)).reshape(shape)
                    gt = np.concatenate((input_seq[0, :, -1], true[0, :, -1]), axis=0)
                    pd = np.concatenate((input_seq[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        print('test shape:', preds.shape, trues.shape)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        print('test shape:', preds.shape, trues.shape)

        # result save
        folder_path = './results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        # dtw calculation (local name dtw_value avoids shadowing the imported dtw function)
        if self.args.use_dtw:
            dtw_list = []
            manhattan_distance = lambda a, b: np.abs(a - b)
            for i in range(preds.shape[0]):
                x = preds[i].reshape(-1, 1)
                y = trues[i].reshape(-1, 1)
                if i % 100 == 0:
                    print("calculating dtw iter:", i)
                d, _, _, _ = accelerated_dtw(x, y, dist=manhattan_distance)
                dtw_list.append(d)
            dtw_value = np.array(dtw_list).mean()
        else:
            dtw_value = 'Not calculated'

        mae, mse, rmse, mape, mspe = metric(preds, trues)
        print('mse:{}, mae:{}, dtw:{}'.format(mse, mae, dtw_value))
        # context manager guarantees the results file is closed even on error
        with open("result_long_term_forecast.txt", 'a') as f:
            f.write(setting + " \n")
            f.write('mse:{}, mae:{}, dtw:{}'.format(mse, mae, dtw_value))
            f.write('\n')
            f.write('\n')

        np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))
        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)

        return
exp/exp_short_term_forecasting.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_provider.data_factory import data_provider
2
+ from data_provider.m4 import M4Meta
3
+ from exp.exp_basic import Exp_Basic
4
+ from utils.tools import EarlyStopping, adjust_learning_rate, visual
5
+ from utils.losses import mape_loss, mase_loss, smape_loss
6
+ from utils.m4_summary import M4Summary
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import optim
11
+ import os
12
+ import time
13
+ import warnings
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ warnings.filterwarnings('ignore')
18
+
19
+
20
def mse(output, label):
    """Mean squared error between prediction and ground-truth arrays."""
    return np.mean(np.square(output - label))
23
+
24
+
25
def rmse(output, label):
    """Root mean squared error between prediction and ground-truth arrays."""
    squared_errors = np.square(output - label)
    return np.sqrt(np.mean(squared_errors))
28
+
29
+
30
def mae(output, label):
    """Mean absolute error between prediction and ground-truth arrays."""
    return np.mean(np.abs(output - label))
33
+
34
+
35
def mape(output, label):
    """Mean absolute percentage error (in percent), skipping zero labels.

    Entries where the label is exactly zero are excluded to avoid
    division by zero.
    """
    nonzero = label != 0
    ratios = np.abs((output[nonzero] - label[nonzero]) / label[nonzero])
    return np.mean(ratios) * 100
40
+
41
+
42
class Exp_Short_Term_Forecast(Exp_Basic):
    """Experiment wrapper for short-term forecasting (m4 / kalshi / poly data).

    Training iterates mini-batches; validation and testing run one batched
    forward pass over the whole `dataset.timeseries` collection.
    """

    def __init__(self, args):
        super(Exp_Short_Term_Forecast, self).__init__(args)

    def _build_model(self):
        if self.args.data == 'm4':
            # M4 fixes the horizon / input / label lengths per seasonal pattern
            self.args.pred_len = M4Meta.horizons_map[self.args.seasonal_patterns]
            self.args.seq_len = 2 * self.args.pred_len
            self.args.label_len = self.args.pred_len
            self.args.frequency_map = M4Meta.frequency_map[self.args.seasonal_patterns]
        elif self.args.data == 'kalshi':
            # kalshi uses the lengths passed on the command line; no special handling
            self.args.frequency_map = 1
        elif self.args.data == 'poly':
            # poly uses the lengths passed on the command line; no special handling
            self.args.frequency_map = 1

        model = self.model_dict[self.args.model].Model(self.args).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self, loss_name='MSE'):
        if loss_name == 'MSE':
            return nn.MSELoss()
        elif loss_name == 'MAPE':
            return mape_loss()
        elif loss_name == 'MASE':
            return mase_loss()
        elif loss_name == 'SMAPE':
            return smape_loss()

    def _prepare_data_from_timeseries(self, timeseries):
        """Extract model inputs and targets from raw series.

        Each element of `timeseries` is assumed to be at least
        seq_len + 1 points long (not checked here; shorter series would
        produce ragged batches — TODO confirm upstream guarantee).
        x: the seq_len points preceding the last point; y: the last pred_len points.
        """
        x_list = []
        y_list = []

        for ts in timeseries:
            # x = ts[-(seq_len+1):-1], y = ts[-pred_len:]
            x_list.append(ts[-(self.args.seq_len + 1):-1])
            y_list.append(ts[-self.args.pred_len:])

        x = torch.tensor(x_list, dtype=torch.float32).to(self.device)
        x = x.unsqueeze(-1)  # (B, seq_len, 1)
        y = np.array(y_list)  # (B, pred_len)

        return x, y

    def train(self, setting):
        """Train the model; early-stop on validation loss, restore the best checkpoint."""
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion(self.args.loss)
        mse_loss = nn.MSELoss()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)  # (B, seq_len, 1)
                batch_y = batch_y.float().to(self.device)  # (B, label_len + pred_len, 1)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input: [label_len known values, pred_len zeros]
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float().to(self.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

                outputs = self.model(batch_x, None, dec_inp, None)

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

                # plain MSE unless one of the M4-style losses was requested
                if self.args.loss == 'MSE':
                    loss = mse_loss(outputs, batch_y)
                else:
                    batch_y_mark = batch_y_mark[:, -self.args.pred_len:, f_dim:].to(self.device)
                    loss = criterion(batch_x, self.args.frequency_map, outputs, batch_y, batch_y_mark)

                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_loader)
            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss))

            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def vali(self, vali_loader):
        """Compute validation MSE over the loader's whole `timeseries` set."""
        timeseries = vali_loader.dataset.timeseries
        x, y = self._prepare_data_from_timeseries(timeseries)

        self.model.eval()
        with torch.no_grad():
            B, _, C = x.shape

            # decoder input: last label_len observed points followed by zeros
            dec_inp = torch.zeros((B, self.args.pred_len, C)).float().to(self.device)
            dec_inp = torch.cat([x[:, -self.args.label_len:, :], dec_inp], dim=1).float()

            # chunked inference to avoid OOM
            outputs = torch.zeros((B, self.args.pred_len, C)).float()
            batch_size = 500
            id_list = np.arange(0, B, batch_size)
            id_list = np.append(id_list, B)

            for i in range(len(id_list) - 1):
                start_idx, end_idx = id_list[i], id_list[i + 1]
                outputs[start_idx:end_idx, :, :] = self.model(
                    x[start_idx:end_idx], None,
                    dec_inp[start_idx:end_idx], None
                ).detach().cpu()

            f_dim = -1 if self.args.features == 'MS' else 0
            outputs = outputs[:, -self.args.pred_len:, f_dim:]

            pred = outputs.numpy()
            true = y

            loss = mse(pred, true)

        self.model.train()
        return loss

    def test(self, setting, test=0):
        """Evaluate on one or more test splits; write RMSE/MAE per split to a CSV."""
        # evaluate the overall split plus per-category splits
        flags = ['test', 'test_Companies', 'test_Economics', 'test_Entertainment', 'test_Mentions', 'test_Politics']
        # flags = ['test', 'test_Crypto', 'test_Politics', 'test_Election']

        if test:
            print('Loading model...')
            # load once up front (previously reloaded per split); map_location
            # keeps CPU-only evaluation working for GPU-saved checkpoints
            self.model.load_state_dict(
                torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth'),
                           map_location=self.device))

        folder_path = './test_results/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        results = []
        columns = []

        for flag in flags:
            try:
                _, test_loader = self._get_data(flag=flag)
            except Exception as e:
                print(f"Skipping {flag}: {e}")
                continue

            timeseries = test_loader.dataset.timeseries

            if len(timeseries) == 0:
                print(f"[{flag}] No samples, skipping...")
                continue

            x, y = self._prepare_data_from_timeseries(timeseries)

            self.model.eval()
            with torch.no_grad():
                B, _, C = x.shape

                # decoder input: last label_len observed points followed by zeros
                dec_inp = torch.zeros((B, self.args.pred_len, C)).float().to(self.device)
                dec_inp = torch.cat([x[:, -self.args.label_len:, :], dec_inp], dim=1).float()

                # chunked inference to avoid OOM
                outputs = torch.zeros((B, self.args.pred_len, C)).float().to(self.device)
                batch_size = 500
                id_list = np.arange(0, B, batch_size)
                id_list = np.append(id_list, B)

                for i in range(len(id_list) - 1):
                    start_idx, end_idx = id_list[i], id_list[i + 1]
                    outputs[start_idx:end_idx, :, :] = self.model(
                        x[start_idx:end_idx], None,
                        dec_inp[start_idx:end_idx], None
                    )
                    if start_idx % 1000 == 0:
                        print(f"Processed {start_idx}/{B}")

                f_dim = -1 if self.args.features == 'MS' else 0
                outputs = outputs[:, -self.args.pred_len:, f_dim:]
                preds = outputs.detach().cpu().numpy()
                trues = y

            print(f'[{flag}] Test shape: {preds.shape}')

            # NOTE(review): preds is (B, pred_len, C) while trues is (B, pred_len);
            # numpy broadcasting may pair them unexpectedly when pred_len > 1 —
            # confirm shapes against the dataset / consider preds.squeeze(-1).
            rmse_val = rmse(preds, trues)
            mae_val = mae(preds, trues)

            columns.extend([f'{flag}_rmse', f'{flag}_mae'])
            results.extend([rmse_val, mae_val])

            print(f'[{flag}] RMSE: {rmse_val:.6f}, MAE: {mae_val:.6f}')

        # persist results
        folder_path = f'./{self.args.data}_results/' + self.args.model + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        df = pd.DataFrame([results], columns=columns)
        result_path = os.path.join(folder_path, f'{self.args.model}_results.csv')
        df.to_csv(result_path, index=False)
        print(f'Results saved to {result_path}')

        return results
kalshi_results/Autoformer/Autoformer_results.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ test_rmse,test_mae,test_Companies_rmse,test_Companies_mae,test_Economics_rmse,test_Economics_mae,test_Entertainment_rmse,test_Entertainment_mae,test_Mentions_rmse,test_Mentions_mae,test_Politics_rmse,test_Politics_mae
2
+ 0.40856921686443876,0.3217217668927483,0.364304113186295,0.28937346579139905,0.5096567834520302,0.3970817213714233,0.40597254578512976,0.3254350355656028,0.3871389816162235,0.32037402264013803,0.40295968229454693,0.3157070307474669
kalshi_results/DLinear/DLinear_results.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ test_rmse,test_mae,test_Companies_rmse,test_Companies_mae,test_Economics_rmse,test_Economics_mae,test_Entertainment_rmse,test_Entertainment_mae,test_Mentions_rmse,test_Mentions_mae,test_Politics_rmse,test_Politics_mae
2
+ 0.3891896238849517,0.30230516746317077,0.351860807826818,0.27699875265589174,0.48147644981667514,0.37157798860447333,0.3897480740571339,0.30954725156348983,0.3710907844730615,0.3075220013497144,0.3831593945812524,0.29365767848840196
layers/AutoCorrelation.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import math
7
+ from math import sqrt
8
+ import os
9
+
10
+
11
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.
    """

    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        # factor scales the number of selected delays: top_k = factor * log(length).
        # NOTE(review): scale and mask_flag are stored but never read in this class,
        # and self.dropout is never applied in forward — kept for attention-interface
        # parity; confirm intent upstream.
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        values: (batch, head, channel, length); corr: same layout.
        Selects one shared set of top-k delays (averaged over batch/head/channel)
        and sums delay-rolled copies of `values`, weighted by softmaxed correlation.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k: delays chosen from the batch-averaged correlation
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr: normalize the k correlation scores into weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: roll along time by each selected delay and accumulate
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Unlike the training variant, delays are selected per sample; gathering
        from a doubled (wrapped) copy of values emulates a circular shift.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init: base time indices used to build shifted gather indices
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
        # find top k delays per sample from batch/head/channel-averaged corr
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr: normalize the k correlation scores into weights
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: doubling the sequence makes index+delay a circular lookup
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Top-k delays are chosen independently per (batch, head, channel) position.
        Not called from forward; kept as the reference implementation.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation via circular gather from the doubled sequence
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        # queries: (B, L, H, E); keys/values: (B, S, H, D). attn_mask is unused here.
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # pad keys/values with zeros up to the query length.
            # NOTE(review): zeros are shaped like queries (last dim E) — assumes
            # E == D when this branch is taken; confirm with callers.
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: cross-correlation of q and k computed
        # in the frequency domain (Wiener–Khinchin style rfft/irfft)
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)

        # time delay agg: the fast per-phase variants (full version unused)
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)
129
+
130
+
131
class AutoCorrelationLayer(nn.Module):
    """Multi-head wrapper around an auto-correlation block.

    Projects queries/keys/values to per-head subspaces, applies the inner
    correlation mechanism, then projects the merged heads back to d_model.
    """

    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        # default per-head dimensions split d_model evenly across heads
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # split the projected features into heads: (B, len, H, depth)
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        out, attn = self.inner_correlation(q, k, v, attn_mask)
        # merge heads back and project to the model dimension
        merged = out.view(batch, q_len, -1)

        return self.out_projection(merged), attn
layers/Autoformer_EncDec.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part: standard LayerNorm
    followed by subtraction of the per-sequence temporal mean, keeping the
    seasonal component zero-centered along the time axis.
    """

    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # subtract the mean over the time dimension (broadcasts over time)
        temporal_mean = torch.mean(normed, dim=1, keepdim=True)
        return normed - temporal_mean
19
+
20
+
21
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.
    Pads both ends by replicating the edge time steps so the output keeps
    the input length when stride == 1.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # replicate the first/last time step (kernel_size-1)//2 times on each side
        pad = (self.kernel_size - 1) // 2
        head = x[:, :1, :].repeat(1, pad, 1)
        tail = x[:, -1:, :].repeat(1, pad, 1)
        padded = torch.cat([head, x, tail], dim=1)
        # AvgPool1d pools over the last dim, so move time to the end and back
        smoothed = self.avg(padded.permute(0, 2, 1))
        return smoothed.permute(0, 2, 1)
39
+
40
+
41
class series_decomp(nn.Module):
    """
    Series decomposition block: splits a series into its residual
    (seasonal) part and its moving-average (trend) part.
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        # returns (seasonal residual, trend)
        return seasonal, trend
54
+
55
+
56
class series_decomp_multi(nn.Module):
    """
    Multiple Series decomposition block from FEDformer.

    Runs several series_decomp blocks with different kernel sizes and
    averages their seasonal and trend outputs.
    """

    def __init__(self, kernel_size):
        super(series_decomp_multi, self).__init__()
        self.kernel_size = kernel_size
        # nn.ModuleList (not a plain Python list) so the sub-decompositions are
        # registered as submodules and follow .to(device) / .parameters() /
        # state_dict(); a plain list silently detaches them from the module tree.
        self.series_decomp = nn.ModuleList([series_decomp(kernel) for kernel in kernel_size])

    def forward(self, x):
        moving_mean = []
        res = []
        for func in self.series_decomp:
            # `trend` renamed from `moving_avg` to avoid shadowing the
            # module-level moving_avg class
            sea, trend = func(x)
            moving_mean.append(trend)
            res.append(sea)

        # average across kernel sizes
        sea = sum(res) / len(res)
        moving_mean = sum(moving_mean) / len(moving_mean)
        return sea, moving_mean
77
+
78
+
79
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture:
    auto-correlation attention, then a 1x1-conv feed-forward block, each
    followed by a series decomposition that keeps only the seasonal part.
    """

    def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        # self-attention (auto-correlation) with residual, then drop the trend
        attn_out, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x, _ = self.decomp1(x + self.dropout(attn_out))
        # position-wise feed-forward as two 1x1 convolutions over channels
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        res, _ = self.decomp2(x + y)
        return res, attn
107
+
108
+
109
class Encoder(nn.Module):
    """
    Autoformer encoder: a stack of attention layers, optionally interleaved
    with convolutional layers, followed by an optional normalization layer.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = None if conv_layers is None else nn.ModuleList(conv_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is None:
            # Plain stack: every attention layer receives the same mask.
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        else:
            # zip stops at the shorter conv list; the final attention layer
            # runs afterwards without a conv and without the mask.
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x)
            attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
138
+
139
+
140
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    Self-attention, cross-attention, and a position-wise feed-forward block,
    each followed by a series decomposition. The three extracted trend parts
    are summed and projected to c_out channels; the seasonal stream and the
    projected trend are returned separately.
    """

    def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None,
                 moving_avg=25, dropout=0.1, activation="relu"):
        super(DecoderLayer, self).__init__()
        # Default feed-forward width is 4x the model dimension.
        d_ff = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        # kernel_size=1 convs act as position-wise linear layers.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False)
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        # Circular padding preserves sequence length while mapping the
        # trend from d_model to c_out channels.
        self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1,
                                    padding=1, padding_mode='circular', bias=False)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        # Self-attention sub-block; decomposition extracts trend1.
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))
        # Cross-attention against the encoder output; extracts trend2.
        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))
        # Feed-forward sub-block; extracts trend3.
        y = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))
        x, trend3 = self.decomp3(x + y)

        # Accumulate all trend components, then project to output channels.
        residual_trend = trend1 + trend2 + trend3
        residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
180
+
181
+
182
class Decoder(nn.Module):
    """
    Autoformer decoder: a stack of decoder layers that refine a seasonal
    stream while accumulating a trend series, followed by an optional
    normalization layer and an optional output projection.
    """

    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        # Each layer updates the seasonal stream and emits a trend increment
        # that is added onto the running trend.
        for decoder_layer in self.layers:
            x, residual_trend = decoder_layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            trend = trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend