Upload 17 files
Browse files- .gitattributes +3 -0
- README.md +93 -0
- augmentation/aug.py +248 -0
- dataset_loader/datasetloader.py +431 -0
- decompositions/decomposition.py +28 -0
- figures/main_result.png +3 -0
- figures/main_result2.png +3 -0
- figures/overview.png +3 -0
- main_run/train.py +290 -0
- models/DLinear.py +91 -0
- requirements.txt +10 -0
- run_main.py +150 -0
- scripts/etth1.sh +335 -0
- scripts/etth2.sh +341 -0
- scripts/ili.sh +342 -0
- scripts/weather.sh +278 -0
- utils/metrics.py +49 -0
- utils/tools.py +102 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
figures/main_result.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
figures/main_result2.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
figures/overview.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Wave-Mask/Mix: Exploring Wavelet-Based Augmentations for Time Series Forecasting <a href="https://www.arxiv.org/abs/2408.10951" title="Read the paper on arXiv"><img src="https://img.shields.io/badge/arXiv-2408.10951-b31b1b.svg" margin-left="5px" height="20" align="center"></a>
|
| 2 |
+
|
| 3 |
+
<a href="https://www.arxiv.org/abs/2408.10951" style="color: #4285F4; font-size: 24px; font-weight: bold; text-decoration: none;">Paper(arXiv)</a>
|
| 4 |
+
|
| 5 |
+
The figure depicts a framework of the training stages incorporating wavelet augmentations, which involve the concatenation of
|
| 6 |
+
the look-back window and the forecasting horizon prior to transformation and augmentation. Batch sampling of the generated synthetic data is conducted according to a predefined hyperparameter called the sampling rate. These batches are subsequently used to split the data into the look-back window and target horizon, after which they are concatenated with the original data. Wavelet augmentations are Wavelet Masking (WaveMask) and Wavelet Mixing (WaveMix). These techniques utilize the discrete wavelet transform (DWT) to obtain wavelet coefficients (both approximation and detail coefficients) by breaking down the signal and adjusting these coefficients, in line with modifying frequency components across different time scales.
|
| 7 |
+
|
| 8 |
+
WaveMask selectively eliminates specific wavelet coefficients at each decomposition level, thereby introducing variability in the augmented data. Conversely, WaveMix exchanges wavelet coefficients from two distinct instances of the dataset, thereby enhancing the diversity of the augmented data.
|
| 9 |
+
|
| 10 |
+
To the best of our knowledge, this is the first study to conduct extensive experiments on multivariate time series using Discrete Wavelet Transform as an augmentation technique.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
<div align=center>
|
| 15 |
+
<img src="./figures/overview.png" alt="Overview" width="700" style="margin-bottom: 40px; margin-top: 40px;"/>
|
| 16 |
+
</div>
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
Tables present comparisons between our methods and baselines in terms of the metrics Mean Squared Error (MSE) and Mean Absolute Error (MAE). The best result is indicated in bold, while the second most favorable outcome is underlined.
|
| 21 |
+
|
| 22 |
+
<div align=center>
|
| 23 |
+
<img src="./figures/main_result.png" alt="Main Results" width="1000" height="500" style="margin-bottom: 40px; margin-top: 40px;"/>
|
| 24 |
+
</div>
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
<div align=center>
|
| 29 |
+
<img src="./figures/main_result2.png" alt="Main Results" width="1000" height="500" style="margin-bottom: 40px; margin-top: 40px;"/>
|
| 30 |
+
</div>
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
## Dataset
|
| 34 |
+
|
| 35 |
+
You can obtain all datasets under https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy. All of them are ready for training.
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
mkdir dataset
|
| 39 |
+
```
|
| 40 |
+
Please place all of them within the ```./dataset ``` directory.
|
| 41 |
+
|
| 42 |
+
## Quick Start
|
| 43 |
+
|
| 44 |
+
Clone the project
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
git clone https://github.com/jafarbakhshaliyev/Wave-Augs.git
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
Go to the project directory
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
cd Wave-Augs
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Install dependencies
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
pip install -r requirements.txt
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
Train:
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
sh scripts/etth1.sh
|
| 66 |
+
sh scripts/etth2.sh
|
| 67 |
+
sh scripts/weather.sh
|
| 68 |
+
sh scripts/ili.sh
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
You can change ```percentage``` to down-sample training dataset for ablation study.
|
| 72 |
+
|
| 73 |
+
## Citation
|
| 74 |
+
|
| 75 |
+
If you find the code useful, please cite our paper:
|
| 76 |
+
|
| 77 |
+
```
|
| 78 |
+
@misc{waveaug2024,
|
| 79 |
+
title={Wave-Mask/Mix: Exploring Wavelet-Based Augmentations for Time Series Forecasting},
|
| 80 |
+
author={Dona Arabi and Jafar Bakhshaliyev and Ayse Coskuner and Kiran Madhusudhanan and Kami Serdar Uckardes},
|
| 81 |
+
year={2024},
|
| 82 |
+
eprint={2408.10951},
|
| 83 |
+
archivePrefix={arXiv},
|
| 84 |
+
primaryClass={cs.LG},
|
| 85 |
+
url={https://arxiv.org/abs/2408.10951},
|
| 86 |
+
}
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
Please remember to cite all the datasets and compared methods if you use them in your experiments.
|
| 90 |
+
|
| 91 |
+
## Acknowledgements
|
| 92 |
+
|
| 93 |
+
We would like to express our gratitude to [Zheng et al.](https://arxiv.org/abs/2205.13504) for providing datasets used in this project. Additionally, we also acknowledge [Chen et al.](https://arxiv.org/abs/2302.09292) and [Zhang et al.](https://arxiv.org/abs/2303.14254) for their code frameworks, which served as the foundation for our codebase.
|
augmentation/aug.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The freq_mask, freq_mix methods are adapted from the following sources:
|
| 3 |
+
# Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
|
| 4 |
+
# arXiv preprint arXiv:2302.09292.
|
| 5 |
+
#
|
| 6 |
+
# The emd_aug and mix_aug for STAug method are adapted from the following source:
|
| 7 |
+
# https://github.com/xiyuanzh/STAug/tree/main
|
| 8 |
+
# =============================================================================
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import pywt
|
| 13 |
+
from pytorch_wavelets import DWT1DForward, DWT1DInverse
|
| 14 |
+
from typing import List, Tuple
|
| 15 |
+
|
| 16 |
+
class augmentation():
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
A class for data augmentation techniques used for Time Series Forecasting.
|
| 20 |
+
|
| 21 |
+
Attributes:
|
| 22 |
+
None
|
| 23 |
+
|
| 24 |
+
Methods:
|
| 25 |
+
- freq_mask: Apply frequency masking to input data.
|
| 26 |
+
- freq_mix: Mix two input signals in the frequency domain.
|
| 27 |
+
- wave_mask: Apply wavelet-based masking to input data.
|
| 28 |
+
- wave_mix: Mix two input signals using wavelet transformation.
|
| 29 |
+
- emd_aug: Apply empirical mode decomposition (EMD) based augmentation.
|
| 30 |
+
- mix_aug: Mix two batches of data with a random interpolation factor.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self):
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
@staticmethod
|
| 37 |
+
def freq_mask(x: torch.Tensor, y: torch.Tensor, rate: float = 0.5, dim: int = 1) -> torch.Tensor:
|
| 38 |
+
"""
|
| 39 |
+
Apply frequency masking to input data.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
- x (torch.Tensor): Look-back window.
|
| 43 |
+
- y (torch.Tensor): Target horizon.
|
| 44 |
+
- rate (float): Mask rate.
|
| 45 |
+
- dim (int): Dimension along to concatenate and apply Fourier Transform.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
- torch.Tensor: Masked synthetic data tensor.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
xy = torch.cat([x,y],dim=1)
|
| 52 |
+
xy_f = torch.fft.rfft(xy,dim=dim)
|
| 53 |
+
m = torch.FloatTensor(xy_f.shape).uniform_() < rate
|
| 54 |
+
freal = xy_f.real.masked_fill(m,0)
|
| 55 |
+
fimag = xy_f.imag.masked_fill(m,0)
|
| 56 |
+
xy_f = torch.complex(freal,fimag)
|
| 57 |
+
xy = torch.fft.irfft(xy_f,dim=dim)
|
| 58 |
+
return xy
|
| 59 |
+
|
| 60 |
+
@staticmethod
|
| 61 |
+
def freq_mix(x: torch.Tensor, y: torch.Tensor, rate: float = 0.5, dim: int = 1) -> torch.Tensor:
|
| 62 |
+
"""
|
| 63 |
+
Mix two input signals in the frequency domain.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
- x (torch.Tensor): Look-back window.
|
| 67 |
+
- y (torch.Tensor): Target horizon.
|
| 68 |
+
- rate (float): Mix rate.
|
| 69 |
+
- dim (int): Dimension along to concatenate and apply Fourier Transform.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
- torch.Tensor: Mixed synthetic data tensor.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
xy = torch.cat([x,y],dim=dim)
|
| 76 |
+
xy_f = torch.fft.rfft(xy,dim=dim)
|
| 77 |
+
|
| 78 |
+
m = torch.FloatTensor(xy_f.shape).uniform_() < rate
|
| 79 |
+
amp = abs(xy_f)
|
| 80 |
+
_,index = amp.sort(dim=dim, descending=True)
|
| 81 |
+
dominant_mask = index > 2
|
| 82 |
+
m = torch.bitwise_and(m, dominant_mask)
|
| 83 |
+
freal = xy_f.real.masked_fill(m,0)
|
| 84 |
+
fimag = xy_f.imag.masked_fill(m,0)
|
| 85 |
+
|
| 86 |
+
b_idx = np.arange(x.shape[0])
|
| 87 |
+
np.random.shuffle(b_idx)
|
| 88 |
+
x2, y2 = x[b_idx], y[b_idx]
|
| 89 |
+
xy2 = torch.cat([x2,y2],dim=dim)
|
| 90 |
+
xy2_f = torch.fft.rfft(xy2,dim=dim)
|
| 91 |
+
|
| 92 |
+
m = torch.bitwise_not(m)
|
| 93 |
+
freal2 = xy2_f.real.masked_fill(m,0)
|
| 94 |
+
fimag2 = xy2_f.imag.masked_fill(m,0)
|
| 95 |
+
|
| 96 |
+
freal += freal2
|
| 97 |
+
fimag += fimag2
|
| 98 |
+
|
| 99 |
+
xy_f = torch.complex(freal,fimag)
|
| 100 |
+
|
| 101 |
+
xy = torch.fft.irfft(xy_f,dim=dim)
|
| 102 |
+
|
| 103 |
+
return xy
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def wave_mask(x: torch.Tensor, y: torch.Tensor, rates: List[float], wavelet: str = 'db1', level: int = 2, dim: int= 1) -> torch.Tensor:
|
| 108 |
+
"""
|
| 109 |
+
Apply wavelet-based masking to input data.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
- x (torch.Tensor): Look-back window.
|
| 113 |
+
- y (torch.Tensor): Target horizon.
|
| 114 |
+
- rates (list of floats): List of mask rates for each wavelet level.
|
| 115 |
+
- wavelet (str): Type of wavelet to use.
|
| 116 |
+
- level (int): Number of decomposition levels.
|
| 117 |
+
- dim (int): Dimension along which to concatenate and apply DWT.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
- torch.Tensor: Masked synthetic data tensor.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
xy = torch.cat([x, y], dim=1) # Concatenate along the time dimension
|
| 124 |
+
rate_tensor = torch.tensor(rates, device=x.device) # Convert rates to tensor
|
| 125 |
+
|
| 126 |
+
# Permute dimensions to match expected input: (batch_size, num_features, seq_len)
|
| 127 |
+
xy = xy.permute(0, 2, 1).to(x.device)
|
| 128 |
+
|
| 129 |
+
# Initialize & perform the DWT
|
| 130 |
+
dwt = DWT1DForward(J=level, wave=wavelet, mode='symmetric').to(x.device)
|
| 131 |
+
cA, cDs = dwt(xy)
|
| 132 |
+
|
| 133 |
+
mask = torch.rand_like(cA).to(cA.device) < rate_tensor[0]
|
| 134 |
+
cA = cA.masked_fill(mask, 0)
|
| 135 |
+
|
| 136 |
+
# Apply masking to detail coefficients
|
| 137 |
+
masked_cDs = []
|
| 138 |
+
for i, cD in enumerate(cDs, 1):
|
| 139 |
+
mask_cD = torch.rand_like(cD).to(cD.device) < rate_tensor[i] # Create mask
|
| 140 |
+
cD = cD.masked_fill(mask_cD, 0)
|
| 141 |
+
masked_cDs.append(cD)
|
| 142 |
+
|
| 143 |
+
# Initialize the inverse DWT & reconstruct the signal
|
| 144 |
+
idwt = DWT1DInverse(wave=wavelet, mode='symmetric').to(x.device)
|
| 145 |
+
reconstructed = idwt((cA, masked_cDs))
|
| 146 |
+
|
| 147 |
+
# Permute back to original shape: (batch_size, seq_len, num_features)
|
| 148 |
+
reconstructed = reconstructed.permute(0, 2, 1)
|
| 149 |
+
|
| 150 |
+
return reconstructed
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def wave_mix(x: torch.Tensor, y: torch.Tensor, rates: List[float], wavelet: str = 'db1', level: int = 2, dim: int = 1) -> torch.Tensor:
|
| 154 |
+
"""
|
| 155 |
+
Mix two input signals using wavelet transformation.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
- x (torch.Tensor): Look-back window.
|
| 159 |
+
- y (torch.Tensor): Target horizon.
|
| 160 |
+
- rates (list of floats): List of mix rates for each wavelet level.
|
| 161 |
+
- wavelet (str): Type of wavelet to use.
|
| 162 |
+
- level (int): Number of decomposition levels.
|
| 163 |
+
- dim (int): Dimension along which to concatenate and apply DWT.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
- torch.Tensor: Mixed synthetic data tensor.
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
xy = torch.cat([x, y], dim=1) # Concatenate along the time dimension
|
| 170 |
+
batch_size, _, _ = xy.shape
|
| 171 |
+
rate_tensor = torch.tensor(rates, device=x.device) # Convert rates to tensor
|
| 172 |
+
|
| 173 |
+
# Permute dimensions to match expected input: (batch_size, num_features, seq_len)
|
| 174 |
+
xy = xy.permute(0, 2, 1).to(x.device)
|
| 175 |
+
|
| 176 |
+
# Shuffle the batch for mixing
|
| 177 |
+
b_idx = torch.randperm(batch_size)
|
| 178 |
+
xy2 = xy[b_idx]
|
| 179 |
+
|
| 180 |
+
# Initialize & perform the DWT on the both signals
|
| 181 |
+
dwt = DWT1DForward(J=level, wave=wavelet, mode='symmetric').to(x.device)
|
| 182 |
+
cA1, cDs1 = dwt(xy)
|
| 183 |
+
cA2, cDs2 = dwt(xy2)
|
| 184 |
+
|
| 185 |
+
# Mix the approximation coefficients
|
| 186 |
+
mask = torch.rand_like(cA1).to(cA1.device) < rate_tensor[0] # Create mask
|
| 187 |
+
cA_mixed = cA1.masked_fill(mask, 0) + cA2.masked_fill(~mask, 0)
|
| 188 |
+
|
| 189 |
+
# Mix the coefficients
|
| 190 |
+
mixed_cDs = []
|
| 191 |
+
for i, (cD1, cD2) in enumerate(zip(cDs1, cDs2), 1):
|
| 192 |
+
mask = torch.rand_like(cD1).to(cD1.device) < rate_tensor[i] # Create mask
|
| 193 |
+
cD_mixed = cD1.masked_fill(mask, 0) + cD2.masked_fill(~mask, 0)
|
| 194 |
+
mixed_cDs.append(cD_mixed)
|
| 195 |
+
|
| 196 |
+
# Initialize the inverse DWT & reconstruct the signal
|
| 197 |
+
idwt = DWT1DInverse(wave=wavelet, mode='symmetric').to(x.device)
|
| 198 |
+
reconstructed = idwt((cA_mixed, mixed_cDs))
|
| 199 |
+
|
| 200 |
+
# Permute back to original shape: (batch_size, seq_len, num_features)
|
| 201 |
+
reconstructed = reconstructed.permute(0, 2, 1)
|
| 202 |
+
|
| 203 |
+
return reconstructed
|
| 204 |
+
|
| 205 |
+
# StAug: frequency-domain augmentation
|
| 206 |
+
def emd_aug(self, x: torch.Tensor) -> torch.Tensor:
|
| 207 |
+
"""
|
| 208 |
+
Apply augmentation on empirical mode decomposition (EMD).
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
- x (torch.Tensor): Input tensor.
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
- torch.Tensor: Augmented tensor.
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
b,n_imf,t,c = x.size()
|
| 218 |
+
inp = x.permute(0,2,1,3).reshape(b,t,n_imf*c) #b,t,n_imf,c -> b,t,n_imf*c
|
| 219 |
+
if(torch.rand(1) >= 0.5):
|
| 220 |
+
w = 2 * torch.rand((b,1,n_imf*c)).cuda()
|
| 221 |
+
else:
|
| 222 |
+
w = torch.ones((b,1,n_imf*c)).cuda()
|
| 223 |
+
w_exp = w.expand(-1,t,-1) #b,t,n_imf*c
|
| 224 |
+
out = w_exp * inp
|
| 225 |
+
out = out.reshape(b,t,n_imf,c).sum(dim=2) #b,t,c
|
| 226 |
+
|
| 227 |
+
return out
|
| 228 |
+
|
| 229 |
+
# StAug: time-domain augmentation
|
| 230 |
+
def mix_aug(self, batch_x: np.ndarray, batch_y: np.ndarray, lambd: float = 0.5) -> Tuple[np.ndarray, np.ndarray]:
|
| 231 |
+
"""
|
| 232 |
+
Mix two batches of data with a random interpolation factor.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
- batch_x (numpy.ndarray): Input batch 1.
|
| 236 |
+
- batch_y (numpy.ndarray): Input batch 2.
|
| 237 |
+
- lambd (float): Beta distribution parameter for interpolation.
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
- numpy.ndarray: Mixed augmented batches.
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
inds2 = np.random.permutation(len(batch_x))
|
| 244 |
+
lam = np.random.beta(lambd, lambd)
|
| 245 |
+
batch_x = lam * batch_x[inds2] + (1-lam) * batch_x
|
| 246 |
+
batch_y = lam * batch_y[inds2] + (1-lam) * batch_y
|
| 247 |
+
|
| 248 |
+
return batch_x, batch_y
|
dataset_loader/datasetloader.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The code is originated from
|
| 3 |
+
# Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
|
| 4 |
+
# arXiv preprint arXiv:2302.09292.
|
| 5 |
+
# =============================================================================
|
| 6 |
+
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import numpy as np
|
| 10 |
+
import os
|
| 11 |
+
from sklearn.preprocessing import StandardScaler
|
| 12 |
+
from decompositions.decomposition import emd_augment
|
| 13 |
+
|
| 14 |
+
class Dataset_ETT_hour(Dataset):
|
| 15 |
+
def __init__(self, root_path, flag='train', size=None,
|
| 16 |
+
features='S', data_path='ETTh1.csv',
|
| 17 |
+
target='OT', scale=True, freq='h', n_imf = 500, percentage = 100, params = None):
|
| 18 |
+
# size [seq_len, label_len, pred_len]
|
| 19 |
+
# info
|
| 20 |
+
if size == None:
|
| 21 |
+
self.seq_len = 24 * 4 * 4
|
| 22 |
+
self.label_len = 24 * 4
|
| 23 |
+
self.pred_len = 24 * 4
|
| 24 |
+
else:
|
| 25 |
+
self.seq_len = size[0]
|
| 26 |
+
self.label_len = size[1]
|
| 27 |
+
self.pred_len = size[2]
|
| 28 |
+
# init
|
| 29 |
+
assert flag in ['train', 'test', 'val']
|
| 30 |
+
type_map = {'train': 0, 'val': 1, 'test': 2}
|
| 31 |
+
self.set_type = type_map[flag]
|
| 32 |
+
|
| 33 |
+
self.features = features
|
| 34 |
+
self.target = target
|
| 35 |
+
self.scale = scale
|
| 36 |
+
self.freq = freq
|
| 37 |
+
self.n_imf = n_imf
|
| 38 |
+
self.percentage = percentage
|
| 39 |
+
|
| 40 |
+
self.root_path = root_path
|
| 41 |
+
self.data_path = data_path
|
| 42 |
+
self.params = params
|
| 43 |
+
self.__read_data__()
|
| 44 |
+
|
| 45 |
+
def __read_data__(self):
|
| 46 |
+
self.scaler = StandardScaler()
|
| 47 |
+
df_raw = pd.read_csv(os.path.join(self.root_path,
|
| 48 |
+
self.data_path))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
|
| 52 |
+
border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
|
| 53 |
+
border1 = border1s[self.set_type]
|
| 54 |
+
border2 = border2s[self.set_type]
|
| 55 |
+
|
| 56 |
+
if self.features == 'M' or self.features == 'MS':
|
| 57 |
+
cols_data = df_raw.columns[1:]
|
| 58 |
+
df_data = df_raw[cols_data]
|
| 59 |
+
elif self.features == 'S':
|
| 60 |
+
df_data = df_raw[[self.target]]
|
| 61 |
+
|
| 62 |
+
if self.scale:
|
| 63 |
+
train_data = df_data[border1s[0]:border2s[0]]
|
| 64 |
+
train_length = int((self.percentage / 100) * len(train_data))
|
| 65 |
+
train_data = train_data[-train_length:]
|
| 66 |
+
self.scaler.fit(train_data.values)
|
| 67 |
+
data = self.scaler.transform(df_data.values)
|
| 68 |
+
else:
|
| 69 |
+
data = df_data.values
|
| 70 |
+
|
| 71 |
+
if self.set_type == 0 and self.params.aug_type == 5:
|
| 72 |
+
self.aug_data = emd_augment(data[border1:border2][-len(train_data):], self.seq_len+self.pred_len, n_IMF = self.n_imf)
|
| 73 |
+
else:
|
| 74 |
+
self.aug_data = np.zeros_like(data[border1:border2])
|
| 75 |
+
|
| 76 |
+
if self.set_type == 0:
|
| 77 |
+
self.data_x = data[border1:border2][-len(train_data):]
|
| 78 |
+
self.data_y = data[border1:border2][-len(train_data):]
|
| 79 |
+
else:
|
| 80 |
+
self.data_x = data[border1:border2]
|
| 81 |
+
self.data_y = data[border1:border2]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def __getitem__(self, index):
|
| 85 |
+
s_begin = index
|
| 86 |
+
s_end = s_begin + self.seq_len
|
| 87 |
+
r_begin = s_end - self.label_len
|
| 88 |
+
r_end = r_begin + self.label_len + self.pred_len
|
| 89 |
+
|
| 90 |
+
seq_x = self.data_x[s_begin:s_end]
|
| 91 |
+
seq_y = self.data_y[r_begin:r_end]
|
| 92 |
+
|
| 93 |
+
if self.params.aug_type == 5:
|
| 94 |
+
aug_data = self.aug_data[s_begin]
|
| 95 |
+
else:
|
| 96 |
+
aug_data = np.array([])
|
| 97 |
+
|
| 98 |
+
return seq_x, seq_y, aug_data
|
| 99 |
+
|
| 100 |
+
def __len__(self):
|
| 101 |
+
return len(self.data_x) - self.seq_len - self.pred_len + 1
|
| 102 |
+
|
| 103 |
+
def inverse_transform(self, data):
|
| 104 |
+
return self.scaler.inverse_transform(data)
|
| 105 |
+
|
| 106 |
+
class Dataset_ETT_minute(Dataset):
|
| 107 |
+
def __init__(self, root_path, flag='train', size=None,
|
| 108 |
+
features='S', data_path='ETTm1.csv',
|
| 109 |
+
target='OT', scale=True, freq='t', n_imf = 500, percentage = 100, params = None):
|
| 110 |
+
# size [seq_len, label_len, pred_len]
|
| 111 |
+
# info
|
| 112 |
+
if size == None:
|
| 113 |
+
self.seq_len = 24 * 4 * 4
|
| 114 |
+
self.label_len = 24 * 4
|
| 115 |
+
self.pred_len = 24 * 4
|
| 116 |
+
else:
|
| 117 |
+
self.seq_len = size[0]
|
| 118 |
+
self.label_len = size[1]
|
| 119 |
+
self.pred_len = size[2]
|
| 120 |
+
# init
|
| 121 |
+
assert flag in ['train', 'test', 'val']
|
| 122 |
+
type_map = {'train': 0, 'val': 1, 'test': 2}
|
| 123 |
+
self.set_type = type_map[flag]
|
| 124 |
+
|
| 125 |
+
self.features = features
|
| 126 |
+
self.target = target
|
| 127 |
+
self.scale = scale
|
| 128 |
+
self.freq = freq
|
| 129 |
+
self.n_imf = n_imf
|
| 130 |
+
|
| 131 |
+
self.root_path = root_path
|
| 132 |
+
self.data_path = data_path
|
| 133 |
+
self.percentage = percentage
|
| 134 |
+
self.params = params
|
| 135 |
+
self.__read_data__()
|
| 136 |
+
|
| 137 |
+
def __read_data__(self):
|
| 138 |
+
self.scaler = StandardScaler()
|
| 139 |
+
df_raw = pd.read_csv(os.path.join(self.root_path,
|
| 140 |
+
self.data_path))
|
| 141 |
+
|
| 142 |
+
border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
|
| 143 |
+
border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
|
| 144 |
+
border1 = border1s[self.set_type]
|
| 145 |
+
border2 = border2s[self.set_type]
|
| 146 |
+
|
| 147 |
+
if self.features == 'M' or self.features == 'MS':
|
| 148 |
+
cols_data = df_raw.columns[1:]
|
| 149 |
+
df_data = df_raw[cols_data]
|
| 150 |
+
elif self.features == 'S':
|
| 151 |
+
df_data = df_raw[[self.target]]
|
| 152 |
+
|
| 153 |
+
if self.scale:
|
| 154 |
+
train_data = df_data[border1s[0]:border2s[0]]
|
| 155 |
+
train_length = int((self.percentage / 100) * len(train_data))
|
| 156 |
+
train_data = train_data[-train_length:]
|
| 157 |
+
self.scaler.fit(train_data.values)
|
| 158 |
+
data = self.scaler.transform(df_data.values)
|
| 159 |
+
else:
|
| 160 |
+
data = df_data.values
|
| 161 |
+
|
| 162 |
+
if self.set_type == 0 and self.params.aug_type == 5:
|
| 163 |
+
self.aug_data = emd_augment(data[border1:border2][-len(train_data):], self.seq_len+self.pred_len, n_IMF = self.n_imf)
|
| 164 |
+
else:
|
| 165 |
+
self.aug_data = np.zeros_like(data[border1:border2])
|
| 166 |
+
|
| 167 |
+
if self.set_type == 0:
|
| 168 |
+
self.data_x = data[border1:border2][-len(train_data):]
|
| 169 |
+
self.data_y = data[border1:border2][-len(train_data):]
|
| 170 |
+
else:
|
| 171 |
+
self.data_x = data[border1:border2]
|
| 172 |
+
self.data_y = data[border1:border2]
|
| 173 |
+
|
| 174 |
+
def __getitem__(self, index):
|
| 175 |
+
s_begin = index
|
| 176 |
+
s_end = s_begin + self.seq_len
|
| 177 |
+
r_begin = s_end - self.label_len
|
| 178 |
+
r_end = r_begin + self.label_len + self.pred_len
|
| 179 |
+
|
| 180 |
+
seq_x = self.data_x[s_begin:s_end]
|
| 181 |
+
seq_y = self.data_y[r_begin:r_end]
|
| 182 |
+
|
| 183 |
+
if self.params.aug_type == 5:
|
| 184 |
+
aug_data = self.aug_data[s_begin]
|
| 185 |
+
else:
|
| 186 |
+
aug_data = np.array([])
|
| 187 |
+
|
| 188 |
+
return seq_x, seq_y, aug_data
|
| 189 |
+
|
| 190 |
+
def __len__(self):
|
| 191 |
+
return len(self.data_x) - self.seq_len - self.pred_len + 1
|
| 192 |
+
|
| 193 |
+
def inverse_transform(self, data):
|
| 194 |
+
return self.scaler.inverse_transform(data)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class Dataset_Custom(Dataset):
|
| 198 |
+
def __init__(self, root_path, flag='train', size=None,
|
| 199 |
+
features='S', data_path='ETTh1.csv', scale = True,
|
| 200 |
+
target='OT', freq='h', n_imf = 500, percentage = 100, params=None):
|
| 201 |
+
# size [seq_len, label_len, pred_len]
|
| 202 |
+
# info
|
| 203 |
+
|
| 204 |
+
self.seq_len = size[0]
|
| 205 |
+
self.label_len = size[1]
|
| 206 |
+
self.pred_len = size[2]
|
| 207 |
+
# init
|
| 208 |
+
assert flag in ['train', 'test', 'val']
|
| 209 |
+
type_map = {'train': 0, 'val': 1, 'test': 2}
|
| 210 |
+
self.set_type = type_map[flag]
|
| 211 |
+
|
| 212 |
+
self.features = features
|
| 213 |
+
self.target = target
|
| 214 |
+
self.freq = freq
|
| 215 |
+
self.scale = scale
|
| 216 |
+
self.n_imf = n_imf
|
| 217 |
+
self.percentage = percentage
|
| 218 |
+
self.root_path = root_path
|
| 219 |
+
self.data_path = data_path
|
| 220 |
+
self.params = params
|
| 221 |
+
|
| 222 |
+
self.__read_data__()
|
| 223 |
+
|
| 224 |
+
def __read_data__(self):
|
| 225 |
+
|
| 226 |
+
self.scaler = StandardScaler()
|
| 227 |
+
df_raw = pd.read_csv(os.path.join(self.root_path,
|
| 228 |
+
self.data_path))
|
| 229 |
+
'''
|
| 230 |
+
df_raw.columns: ['date', ...(other features), target feature]
|
| 231 |
+
'''
|
| 232 |
+
cols = list(df_raw.columns)
|
| 233 |
+
cols.remove(self.target)
|
| 234 |
+
cols.remove('date')
|
| 235 |
+
df_raw = df_raw[['date'] + cols + [self.target]]
|
| 236 |
+
|
| 237 |
+
num_train = int(len(df_raw) * 0.7)
|
| 238 |
+
num_test = int(len(df_raw) * 0.2)
|
| 239 |
+
num_vali = len(df_raw) - num_train - num_test
|
| 240 |
+
border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
|
| 241 |
+
border2s = [num_train, num_train + num_vali, len(df_raw)]
|
| 242 |
+
border1 = border1s[self.set_type]
|
| 243 |
+
border2 = border2s[self.set_type]
|
| 244 |
+
|
| 245 |
+
if self.features == 'M' or self.features == 'MS':
|
| 246 |
+
cols_data = df_raw.columns[1:]
|
| 247 |
+
df_data = df_raw[cols_data]
|
| 248 |
+
elif self.features == 'S':
|
| 249 |
+
df_data = df_raw[[self.target]]
|
| 250 |
+
|
| 251 |
+
if self.scale:
|
| 252 |
+
train_data = df_data[border1s[0]:border2s[0]]
|
| 253 |
+
train_length = int((self.percentage / 100) * len(train_data))
|
| 254 |
+
train_data = train_data[-train_length:]
|
| 255 |
+
self.scaler.fit(train_data.values)
|
| 256 |
+
data = self.scaler.transform(df_data.values)
|
| 257 |
+
else:
|
| 258 |
+
data = df_data.values
|
| 259 |
+
|
| 260 |
+
if self.set_type == 0 and self.params.aug_type == 5:
|
| 261 |
+
self.aug_data = emd_augment(data[border1:border2][-len(train_data):], self.seq_len+self.pred_len, n_IMF = self.n_imf)
|
| 262 |
+
|
| 263 |
+
else:
|
| 264 |
+
self.aug_data = np.zeros_like(data[border1:border2])
|
| 265 |
+
|
| 266 |
+
if self.set_type == 0:
|
| 267 |
+
self.data_x = data[border1:border2][-len(train_data):]
|
| 268 |
+
self.data_y = data[border1:border2][-len(train_data):]
|
| 269 |
+
else:
|
| 270 |
+
self.data_x = data[border1:border2]
|
| 271 |
+
self.data_y = data[border1:border2]
|
| 272 |
+
|
| 273 |
+
def __getitem__(self, index):
|
| 274 |
+
s_begin = index # 0
|
| 275 |
+
s_end = s_begin + self.seq_len
|
| 276 |
+
r_begin = s_end - self.label_len
|
| 277 |
+
r_end = r_begin + self.label_len + self.pred_len
|
| 278 |
+
|
| 279 |
+
seq_x = self.data_x[s_begin:s_end]
|
| 280 |
+
seq_y = self.data_y[r_begin:r_end]
|
| 281 |
+
|
| 282 |
+
if self.params.aug_type == 5:
|
| 283 |
+
aug_data = self.aug_data[s_begin]
|
| 284 |
+
else:
|
| 285 |
+
aug_data = np.array([])
|
| 286 |
+
|
| 287 |
+
return seq_x, seq_y, aug_data
|
| 288 |
+
|
| 289 |
+
def __len__(self):
|
| 290 |
+
return len(self.data_x) - self.seq_len - self.pred_len + 1
|
| 291 |
+
|
| 292 |
+
def inverse_transform(self, data):
|
| 293 |
+
return self.scaler.inverse_transform(data)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class Dataset_Pred(Dataset):
|
| 297 |
+
|
| 298 |
+
def __init__(self, root_path, flag='pred', size=None,
|
| 299 |
+
features='S', data_path='ETTh1.csv',
|
| 300 |
+
target='OT', scale=True, inverse=False, freq='15min', cols=None):
|
| 301 |
+
# size [seq_len, label_len, pred_len]
|
| 302 |
+
# info
|
| 303 |
+
|
| 304 |
+
self.seq_len = size[0]
|
| 305 |
+
self.label_len = size[1]
|
| 306 |
+
self.pred_len = size[2]
|
| 307 |
+
# init
|
| 308 |
+
assert flag in ['pred']
|
| 309 |
+
|
| 310 |
+
self.features = features
|
| 311 |
+
self.target = target
|
| 312 |
+
self.scale = scale
|
| 313 |
+
self.inverse = inverse
|
| 314 |
+
self.freq = freq
|
| 315 |
+
self.cols = cols
|
| 316 |
+
self.root_path = root_path
|
| 317 |
+
self.data_path = data_path
|
| 318 |
+
self.__read_data__()
|
| 319 |
+
|
| 320 |
+
def __read_data__(self):
|
| 321 |
+
self.scaler = StandardScaler()
|
| 322 |
+
df_raw = pd.read_csv(os.path.join(self.root_path,
|
| 323 |
+
self.data_path))
|
| 324 |
+
'''
|
| 325 |
+
df_raw.columns: ['date', ...(other features), target feature]
|
| 326 |
+
'''
|
| 327 |
+
if self.cols:
|
| 328 |
+
cols = self.cols.copy()
|
| 329 |
+
cols.remove(self.target)
|
| 330 |
+
else:
|
| 331 |
+
cols = list(df_raw.columns)
|
| 332 |
+
cols.remove(self.target)
|
| 333 |
+
cols.remove('date')
|
| 334 |
+
df_raw = df_raw[['date'] + cols + [self.target]]
|
| 335 |
+
border1 = len(df_raw) - self.seq_len
|
| 336 |
+
border2 = len(df_raw)
|
| 337 |
+
|
| 338 |
+
if self.features == 'M' or self.features == 'MS':
|
| 339 |
+
cols_data = df_raw.columns[1:]
|
| 340 |
+
df_data = df_raw[cols_data]
|
| 341 |
+
elif self.features == 'S':
|
| 342 |
+
df_data = df_raw[[self.target]]
|
| 343 |
+
|
| 344 |
+
if self.scale:
|
| 345 |
+
self.scaler.fit(df_data.values)
|
| 346 |
+
data = self.scaler.transform(df_data.values)
|
| 347 |
+
else:
|
| 348 |
+
data = df_data.values
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
self.data_x = data[border1:border2]
|
| 352 |
+
if self.inverse:
|
| 353 |
+
self.data_y = df_data.values[border1:border2]
|
| 354 |
+
else:
|
| 355 |
+
self.data_y = data[border1:border2]
|
| 356 |
+
|
| 357 |
+
def __getitem__(self, index):
|
| 358 |
+
s_begin = index
|
| 359 |
+
s_end = s_begin + self.seq_len
|
| 360 |
+
r_begin = s_end - self.label_len
|
| 361 |
+
r_end = r_begin + self.label_len + self.pred_len
|
| 362 |
+
|
| 363 |
+
seq_x = self.data_x[s_begin:s_end]
|
| 364 |
+
if self.inverse:
|
| 365 |
+
seq_y = self.data_x[r_begin:r_begin + self.label_len]
|
| 366 |
+
else:
|
| 367 |
+
seq_y = self.data_y[r_begin:r_begin + self.label_len]
|
| 368 |
+
|
| 369 |
+
aug_data = np.array([])
|
| 370 |
+
|
| 371 |
+
return seq_x, seq_y, aug_data
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def __len__(self):
|
| 375 |
+
return len(self.data_x) - self.seq_len + 1
|
| 376 |
+
|
| 377 |
+
def inverse_transform(self, data):
|
| 378 |
+
return self.scaler.inverse_transform(data)
|
| 379 |
+
|
| 380 |
+
data_dict = {
|
| 381 |
+
'ETTh1': Dataset_ETT_hour,
|
| 382 |
+
'ETTh2': Dataset_ETT_hour,
|
| 383 |
+
'ETTm1': Dataset_ETT_minute,
|
| 384 |
+
'ETTm2': Dataset_ETT_minute,
|
| 385 |
+
'custom': Dataset_Custom,
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def data_provider(args, flag):
|
| 390 |
+
Data = data_dict[args.data]
|
| 391 |
+
|
| 392 |
+
if flag == 'test':
|
| 393 |
+
shuffle_flag = False
|
| 394 |
+
drop_last = True
|
| 395 |
+
batch_size = args.batch_size
|
| 396 |
+
freq = args.freq
|
| 397 |
+
nIMF = args.nIMF
|
| 398 |
+
|
| 399 |
+
elif flag == 'pred':
|
| 400 |
+
shuffle_flag = False
|
| 401 |
+
drop_last = False
|
| 402 |
+
batch_size = 1
|
| 403 |
+
freq = args.freq
|
| 404 |
+
Data = Dataset_Pred
|
| 405 |
+
else:
|
| 406 |
+
shuffle_flag = True
|
| 407 |
+
drop_last = True
|
| 408 |
+
batch_size = args.batch_size
|
| 409 |
+
freq = args.freq
|
| 410 |
+
nIMF = args.nIMF
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
data_set = Data(
|
| 414 |
+
root_path=args.root_path,
|
| 415 |
+
data_path=args.data_path,
|
| 416 |
+
flag=flag,
|
| 417 |
+
size=[args.seq_len, args.label_len, args.pred_len],
|
| 418 |
+
features=args.features,
|
| 419 |
+
target=args.target,
|
| 420 |
+
freq=freq,
|
| 421 |
+
n_imf = nIMF,
|
| 422 |
+
percentage = args.percentage,
|
| 423 |
+
params = args
|
| 424 |
+
)
|
| 425 |
+
data_loader = DataLoader(
|
| 426 |
+
data_set,
|
| 427 |
+
batch_size=batch_size,
|
| 428 |
+
shuffle=shuffle_flag,
|
| 429 |
+
num_workers=args.num_workers,
|
| 430 |
+
drop_last=drop_last)
|
| 431 |
+
return data_set, data_loader
|
decompositions/decomposition.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The code is originated from
|
| 3 |
+
# https://github.com/xiyuanzh/STAug/tree/main
|
| 4 |
+
# =============================================================================
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PyEMD import EMD
|
| 8 |
+
|
| 9 |
+
def emd_augment(data, sequence_length, n_IMF = 500):
|
| 10 |
+
n_imf, channel_num = n_IMF, data.shape[1]
|
| 11 |
+
emd_data = np.zeros((n_imf,data.shape[0],channel_num))
|
| 12 |
+
max_imf = 0
|
| 13 |
+
for ci in range(channel_num):
|
| 14 |
+
s = data[:, ci]
|
| 15 |
+
IMF = EMD().emd(s)
|
| 16 |
+
r_s = np.zeros((n_imf, data.shape[0]))
|
| 17 |
+
if len(IMF) > max_imf:
|
| 18 |
+
max_imf = len(IMF)
|
| 19 |
+
for i in range(len(IMF)):
|
| 20 |
+
r_s[i] = IMF[len(IMF)-1-i]
|
| 21 |
+
if(len(IMF)==0): r_s[0] = s
|
| 22 |
+
emd_data[:,:,ci] = r_s
|
| 23 |
+
if max_imf < n_imf:
|
| 24 |
+
emd_data = emd_data[:max_imf,:,:]
|
| 25 |
+
train_data_new = np.zeros((len(data)-sequence_length+1,max_imf,sequence_length,channel_num))
|
| 26 |
+
for i in range(len(data)-sequence_length+1):
|
| 27 |
+
train_data_new[i] = emd_data[:,i:i+sequence_length,:]
|
| 28 |
+
return train_data_new
|
figures/main_result.png
ADDED
|
Git LFS Details
|
figures/main_result2.png
ADDED
|
Git LFS Details
|
figures/overview.png
ADDED
|
Git LFS Details
|
main_run/train.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The code is originated from
|
| 3 |
+
# Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
|
| 4 |
+
# arXiv preprint arXiv:2302.09292.
|
| 5 |
+
# =============================================================================
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from dataset_loader.datasetloader import data_provider
|
| 13 |
+
from models import DLinear
|
| 14 |
+
from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
|
| 15 |
+
from utils.metrics import metric
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from torch import optim
|
| 20 |
+
import time
|
| 21 |
+
import warnings
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
from augmentation.aug import augmentation
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Exp_Basic(object):
|
| 27 |
+
def __init__(self, args):
|
| 28 |
+
self.args = args
|
| 29 |
+
self.device = self._acquire_device()
|
| 30 |
+
self.model = self._build_model().to(self.device)
|
| 31 |
+
|
| 32 |
+
def _build_model(self):
|
| 33 |
+
raise NotImplementedError
|
| 34 |
+
return None
|
| 35 |
+
|
| 36 |
+
def _acquire_device(self):
|
| 37 |
+
if self.args.use_gpu:
|
| 38 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(
|
| 39 |
+
self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
|
| 40 |
+
device = torch.device('cuda:{}'.format(self.args.gpu))
|
| 41 |
+
print('Use GPU: cuda:{}'.format(self.args.gpu))
|
| 42 |
+
else:
|
| 43 |
+
device = torch.device('cpu')
|
| 44 |
+
print('Use CPU')
|
| 45 |
+
return device
|
| 46 |
+
|
| 47 |
+
def _get_data(self):
|
| 48 |
+
pass
|
| 49 |
+
|
| 50 |
+
def vali(self):
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
def train(self):
|
| 54 |
+
pass
|
| 55 |
+
|
| 56 |
+
def test(self):
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
warnings.filterwarnings('ignore')
|
| 60 |
+
|
| 61 |
+
TYPES = {0: 'None',
|
| 62 |
+
1: 'Freq-Mask',
|
| 63 |
+
2: 'Freq-Mix',
|
| 64 |
+
3: 'Wave-Mask',
|
| 65 |
+
4: 'Wave-Mix',
|
| 66 |
+
5: 'StAug'}
|
| 67 |
+
|
| 68 |
+
class Exp_Main(Exp_Basic):
|
| 69 |
+
|
| 70 |
+
def __init__(self, args):
|
| 71 |
+
super(Exp_Main, self).__init__(args)
|
| 72 |
+
|
| 73 |
+
def _build_model(self):
|
| 74 |
+
|
| 75 |
+
model_dict = {
|
| 76 |
+
'DLinear': DLinear
|
| 77 |
+
}
|
| 78 |
+
model = model_dict[self.args.model].Model(self.args).float()
|
| 79 |
+
if self.args.use_multi_gpu and self.args.use_gpu:
|
| 80 |
+
model = nn.DataParallel(model, device_ids=self.args.device_ids)
|
| 81 |
+
return model
|
| 82 |
+
|
| 83 |
+
def _get_data(self, flag):
|
| 84 |
+
data_set, data_loader = data_provider(self.args, flag)
|
| 85 |
+
return data_set, data_loader
|
| 86 |
+
|
| 87 |
+
def _select_optimizer(self):
|
| 88 |
+
model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
|
| 89 |
+
return model_optim
|
| 90 |
+
|
| 91 |
+
def _select_criterion(self):
|
| 92 |
+
criterion = nn.MSELoss()
|
| 93 |
+
return criterion
|
| 94 |
+
|
| 95 |
+
def vali(self, vali_data, vali_loader, criterion):
|
| 96 |
+
|
| 97 |
+
total_loss = []
|
| 98 |
+
self.model.eval()
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
for i, (batch_x, batch_y, _) in enumerate(vali_loader):
|
| 101 |
+
batch_x = batch_x.float().to(self.device)
|
| 102 |
+
batch_y = batch_y.float()
|
| 103 |
+
|
| 104 |
+
outputs = self.model(batch_x)
|
| 105 |
+
|
| 106 |
+
f_dim = -1 if self.args.features == 'MS' else 0
|
| 107 |
+
outputs = outputs[:, -self.args.pred_len:, f_dim:]
|
| 108 |
+
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
|
| 109 |
+
|
| 110 |
+
pred = outputs.detach().cpu()
|
| 111 |
+
true = batch_y.detach().cpu()
|
| 112 |
+
|
| 113 |
+
loss = criterion(pred, true)
|
| 114 |
+
|
| 115 |
+
total_loss.append(loss)
|
| 116 |
+
total_loss = np.average(total_loss)
|
| 117 |
+
self.model.train()
|
| 118 |
+
return total_loss
|
| 119 |
+
|
| 120 |
+
def train(self, setting):
|
| 121 |
+
train_data, train_loader = self._get_data(flag='train')
|
| 122 |
+
vali_data, vali_loader = self._get_data(flag='val')
|
| 123 |
+
test_data, test_loader = self._get_data(flag='test')
|
| 124 |
+
|
| 125 |
+
path = os.path.join(self.args.checkpoints, setting)
|
| 126 |
+
if not os.path.exists(path):
|
| 127 |
+
os.makedirs(path)
|
| 128 |
+
|
| 129 |
+
train_steps = len(train_loader)
|
| 130 |
+
early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
|
| 131 |
+
time_now = time.time()
|
| 132 |
+
|
| 133 |
+
model_optim = self._select_optimizer()
|
| 134 |
+
criterion = self._select_criterion()
|
| 135 |
+
time_now = time.time()
|
| 136 |
+
|
| 137 |
+
for epoch in range(self.args.train_epochs):
|
| 138 |
+
iter_count = 0
|
| 139 |
+
train_loss = []
|
| 140 |
+
|
| 141 |
+
self.model.train()
|
| 142 |
+
epoch_time = time.time()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
for i, (batch_x, batch_y, aug_data) in enumerate(train_loader):
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
iter_count += 1
|
| 149 |
+
model_optim.zero_grad()
|
| 150 |
+
if self.args.aug_type == 5:
|
| 151 |
+
aug_data = aug_data.float().to(self.device)
|
| 152 |
+
else:
|
| 153 |
+
aug_data = None
|
| 154 |
+
|
| 155 |
+
if self.args.aug_type:
|
| 156 |
+
aug = augmentation()
|
| 157 |
+
if self.args.aug_type == 1:
|
| 158 |
+
xy = aug.freq_mask(batch_x, batch_y[:, -self.args.pred_len:, :], rate=self.args.aug_rate, dim=1)
|
| 159 |
+
batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
|
| 160 |
+
batch_x = torch.cat([batch_x,batch_x2],dim=0)
|
| 161 |
+
batch_y = torch.cat([batch_y,batch_y2],dim=0)
|
| 162 |
+
elif self.args.aug_type == 2:
|
| 163 |
+
xy = aug.freq_mix(batch_x, batch_y[:, -self.args.pred_len:, :], rate=self.args.aug_rate, dim=1)
|
| 164 |
+
batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
|
| 165 |
+
batch_x = torch.cat([batch_x,batch_x2],dim=0)
|
| 166 |
+
batch_y = torch.cat([batch_y,batch_y2],dim=0)
|
| 167 |
+
elif self.args.aug_type == 3:
|
| 168 |
+
xy = aug.wave_mask(batch_x, batch_y[:, -self.args.pred_len:, :] ,rates = self.args.rates, wavelet =self.args.wavelet, level = self.args.level, dim = 1)
|
| 169 |
+
batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
|
| 170 |
+
sampling_steps = int(batch_x2.shape[0] * self.args.sampling_rate)
|
| 171 |
+
indices = torch.randperm(batch_x2.shape[0])[:sampling_steps]
|
| 172 |
+
batch_x2 = batch_x2[indices,:,:]
|
| 173 |
+
batch_y2 = batch_y2[indices,:,:]
|
| 174 |
+
batch_x = torch.cat([batch_x,batch_x2],dim=0)
|
| 175 |
+
batch_y = torch.cat([batch_y,batch_y2],dim=0)
|
| 176 |
+
elif self.args.aug_type == 4:
|
| 177 |
+
batch_x = batch_x.float().to(self.device)
|
| 178 |
+
batch_y = batch_y.float().to(self.device)
|
| 179 |
+
xy = aug.wave_mix(batch_x, batch_y[:, -self.args.pred_len:, :] ,rates = self.args.rates, wavelet = self.args.wavelet, level = self.args.level, dim = 1)
|
| 180 |
+
batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
|
| 181 |
+
sampling_steps = int(batch_x2.shape[0] * self.args.sampling_rate)
|
| 182 |
+
indices = torch.randperm(batch_x2.shape[0])[:sampling_steps]
|
| 183 |
+
batch_x2 = batch_x2[indices,:,:]
|
| 184 |
+
batch_y2 = batch_y2[indices,:,:]
|
| 185 |
+
batch_x = torch.cat([batch_x,batch_x2],dim=0)
|
| 186 |
+
batch_y = torch.cat([batch_y,batch_y2],dim=0)
|
| 187 |
+
elif self.args.aug_type == 5:
|
| 188 |
+
batch_x = batch_x.float().to(self.device)
|
| 189 |
+
batch_y = batch_y.float().to(self.device)
|
| 190 |
+
weighted_xy = aug.emd_aug(aug_data)
|
| 191 |
+
weighted_x, weighted_y = weighted_xy[:,:self.args.seq_len,:], weighted_xy[:,-self.args.label_len-self.args.pred_len:,:]
|
| 192 |
+
batch_x, batch_y = aug.mix_aug(weighted_x, weighted_y, lambd = self.args.aug_rate)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
batch_x = batch_x.float().to(self.device)
|
| 196 |
+
batch_y = batch_y.float().to(self.device)
|
| 197 |
+
|
| 198 |
+
outputs = self.model(batch_x)
|
| 199 |
+
|
| 200 |
+
f_dim = -1 if self.args.features == 'MS' else 0
|
| 201 |
+
outputs = outputs[:, -self.args.pred_len:, f_dim:]
|
| 202 |
+
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
|
| 203 |
+
loss = criterion(outputs, batch_y)
|
| 204 |
+
train_loss.append(loss.item())
|
| 205 |
+
|
| 206 |
+
if (i + 1) % 100 == 0:
|
| 207 |
+
print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
|
| 208 |
+
speed = (time.time() - time_now) / iter_count
|
| 209 |
+
left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
|
| 210 |
+
print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
|
| 211 |
+
iter_count = 0
|
| 212 |
+
time_now = time.time()
|
| 213 |
+
|
| 214 |
+
loss.backward()
|
| 215 |
+
model_optim.step()
|
| 216 |
+
|
| 217 |
+
print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
|
| 218 |
+
train_loss = np.average(train_loss)
|
| 219 |
+
vali_loss = self.vali(vali_data, vali_loader, criterion)
|
| 220 |
+
test_loss = self.vali(test_data, test_loader, criterion)
|
| 221 |
+
|
| 222 |
+
print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
|
| 223 |
+
epoch + 1, train_steps, train_loss, vali_loss, test_loss))
|
| 224 |
+
|
| 225 |
+
early_stopping(vali_loss, self.model, path)
|
| 226 |
+
if early_stopping.early_stop:
|
| 227 |
+
print("Early stopping")
|
| 228 |
+
break
|
| 229 |
+
|
| 230 |
+
adjust_learning_rate(model_optim, epoch + 1, self.args)
|
| 231 |
+
|
| 232 |
+
best_model_path = path + '/' + 'checkpoint.pth'
|
| 233 |
+
self.model.load_state_dict(torch.load(best_model_path))
|
| 234 |
+
min_val_loss = early_stopping.get_val_loss_min()
|
| 235 |
+
|
| 236 |
+
return self.model, min_val_loss
|
| 237 |
+
|
| 238 |
+
def test(self, setting, test=1):
|
| 239 |
+
test_data, test_loader = self._get_data(flag='test')
|
| 240 |
+
|
| 241 |
+
if test:
|
| 242 |
+
print('loading model')
|
| 243 |
+
self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))
|
| 244 |
+
|
| 245 |
+
preds = []
|
| 246 |
+
trues = []
|
| 247 |
+
inputx = []
|
| 248 |
+
|
| 249 |
+
self.model.eval()
|
| 250 |
+
with torch.no_grad():
|
| 251 |
+
for i, (batch_x, batch_y, _) in enumerate(test_loader):
|
| 252 |
+
batch_x = batch_x.float().to(self.device)
|
| 253 |
+
batch_y = batch_y.float().to(self.device)
|
| 254 |
+
|
| 255 |
+
outputs = self.model(batch_x)
|
| 256 |
+
|
| 257 |
+
f_dim = -1 if self.args.features == 'MS' else 0
|
| 258 |
+
outputs = outputs[:, -self.args.pred_len:, f_dim:]
|
| 259 |
+
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
|
| 260 |
+
outputs = outputs.detach().cpu().numpy()
|
| 261 |
+
batch_y = batch_y.detach().cpu().numpy()
|
| 262 |
+
|
| 263 |
+
pred = outputs #
|
| 264 |
+
true = batch_y
|
| 265 |
+
|
| 266 |
+
preds.append(pred)
|
| 267 |
+
trues.append(true)
|
| 268 |
+
inputx.append(batch_x.detach().cpu().numpy())
|
| 269 |
+
|
| 270 |
+
if self.args.test_flop:
|
| 271 |
+
test_params_flop((batch_x.shape[1],batch_x.shape[2]))
|
| 272 |
+
exit()
|
| 273 |
+
|
| 274 |
+
preds = np.array(preds)
|
| 275 |
+
trues = np.array(trues)
|
| 276 |
+
inputx = np.array(inputx)
|
| 277 |
+
|
| 278 |
+
preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
|
| 279 |
+
trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
|
| 280 |
+
inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])
|
| 281 |
+
|
| 282 |
+
mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
|
| 283 |
+
print('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
|
| 284 |
+
f = open(self.args.des + self.args.data + ".txt", 'a')
|
| 285 |
+
f.write(" \n")
|
| 286 |
+
f.write('{} --- Pred {} -> mse:{}, mae:{}, rse:{}'.format(TYPES[self.args.aug_type], self.args.pred_len, mse, mae, rse))
|
| 287 |
+
f.write('\n')
|
| 288 |
+
f.close()
|
| 289 |
+
|
| 290 |
+
return mse, mae, rse
|
models/DLinear.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The code is originated from
|
| 3 |
+
# Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
|
| 4 |
+
# arXiv preprint arXiv:2302.09292.
|
| 5 |
+
# =============================================================================
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
class moving_avg(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Moving average block to highlight the trend of time series
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, kernel_size, stride):
|
| 15 |
+
super(moving_avg, self).__init__()
|
| 16 |
+
self.kernel_size = kernel_size
|
| 17 |
+
self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
# padding on the both ends of time series
|
| 21 |
+
front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
| 22 |
+
end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
|
| 23 |
+
x = torch.cat([front, x, end], dim=1)
|
| 24 |
+
x = self.avg(x.permute(0, 2, 1))
|
| 25 |
+
x = x.permute(0, 2, 1)
|
| 26 |
+
return x
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class series_decomp(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
Series decomposition block
|
| 32 |
+
"""
|
| 33 |
+
def __init__(self, kernel_size):
|
| 34 |
+
super(series_decomp, self).__init__()
|
| 35 |
+
self.moving_avg = moving_avg(kernel_size, stride=1)
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
moving_mean = self.moving_avg(x)
|
| 39 |
+
res = x - moving_mean
|
| 40 |
+
return res, moving_mean
|
| 41 |
+
|
| 42 |
+
class Model(nn.Module):
|
| 43 |
+
"""
|
| 44 |
+
Decomposition-Linear
|
| 45 |
+
"""
|
| 46 |
+
def __init__(self, args):
|
| 47 |
+
super(Model, self).__init__()
|
| 48 |
+
self.seq_len = args.seq_len
|
| 49 |
+
self.pred_len = args.pred_len
|
| 50 |
+
|
| 51 |
+
# Decompsition Kernel Size
|
| 52 |
+
kernel_size = 25
|
| 53 |
+
self.decompsition = series_decomp(kernel_size)
|
| 54 |
+
self.individual = args.individual
|
| 55 |
+
self.channels = args.enc_in
|
| 56 |
+
|
| 57 |
+
if self.individual:
|
| 58 |
+
self.Linear_Seasonal = nn.ModuleList()
|
| 59 |
+
self.Linear_Trend = nn.ModuleList()
|
| 60 |
+
|
| 61 |
+
for i in range(self.channels):
|
| 62 |
+
self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))
|
| 63 |
+
self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))
|
| 64 |
+
|
| 65 |
+
# Use this two lines if you want to visualize the weights
|
| 66 |
+
# self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
| 67 |
+
# self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
| 68 |
+
else:
|
| 69 |
+
self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)
|
| 70 |
+
self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)
|
| 71 |
+
|
| 72 |
+
# Use this two lines if you want to visualize weights
|
| 73 |
+
# self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
| 74 |
+
# self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
# x: [Batch, Input length, Channel]
|
| 78 |
+
seasonal_init, trend_init = self.decompsition(x)
|
| 79 |
+
seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)
|
| 80 |
+
if self.individual:
|
| 81 |
+
seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)
|
| 82 |
+
trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)
|
| 83 |
+
for i in range(self.channels):
|
| 84 |
+
seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])
|
| 85 |
+
trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])
|
| 86 |
+
else:
|
| 87 |
+
seasonal_output = self.Linear_Seasonal(seasonal_init)
|
| 88 |
+
trend_output = self.Linear_Trend(trend_init)
|
| 89 |
+
|
| 90 |
+
x = seasonal_output + trend_output
|
| 91 |
+
return x.permute(0,2,1) # to [Batch, Output length, Channel]
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
matplotlib
|
| 3 |
+
pandas
|
| 4 |
+
scikit-learn
|
| 5 |
+
torch==1.9.0
|
| 6 |
+
einops
|
| 7 |
+
pywavelets
|
| 8 |
+
EMD-signal
|
| 9 |
+
tensorflow==2.11.0
|
| 10 |
+
pytorch_wavelets
|
run_main.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The code is originated from
|
| 3 |
+
# Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
|
| 4 |
+
# arXiv preprint arXiv:2302.09292.
|
| 5 |
+
# =============================================================================
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import os
|
| 9 |
+
import ast
|
| 10 |
+
import torch
|
| 11 |
+
from main_run.train import Exp_Main
|
| 12 |
+
import random
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
fix_seed = 2024
|
| 16 |
+
random.seed(fix_seed)
|
| 17 |
+
torch.manual_seed(fix_seed)
|
| 18 |
+
np.random.seed(fix_seed)
|
| 19 |
+
|
| 20 |
+
TYPES = {0: 'None',
|
| 21 |
+
1: 'Freq-Mask',
|
| 22 |
+
2: 'Freq-Mix',
|
| 23 |
+
3: 'Wave-Mask',
|
| 24 |
+
4: 'Wave-Mix',
|
| 25 |
+
5: 'StAug'}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
parser = argparse.ArgumentParser(description='Augmentations for Time Series Forecasting')
|
| 29 |
+
|
| 30 |
+
# basic config
|
| 31 |
+
parser.add_argument('--model', type=str, required=True, default='DLinear',
|
| 32 |
+
help='model name, options: [DLinear]')
|
| 33 |
+
|
| 34 |
+
# data loader
|
| 35 |
+
parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
|
| 36 |
+
parser.add_argument('--data', type=str, required=True, default='ETTh1', help='dataset type')
|
| 37 |
+
parser.add_argument('--root_path', type=str, default='./dataset/', help='root path of the data file')
|
| 38 |
+
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
|
| 39 |
+
parser.add_argument('--features', type=str, default='M',
|
| 40 |
+
help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
|
| 41 |
+
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
|
| 42 |
+
parser.add_argument('--freq', type=str, default='h',
|
| 43 |
+
help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
|
| 44 |
+
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
|
| 45 |
+
parser.add_argument('--percentage', type=int, default=100, help='percentage of train data as a downsampling ratio')
|
| 46 |
+
parser.add_argument('--patience', type=int, default=12, help='early stopping patience')
|
| 47 |
+
|
| 48 |
+
# forecasting task
|
| 49 |
+
parser.add_argument('--seq_len', type=int, default=336, help='input sequence length')
|
| 50 |
+
parser.add_argument('--label_len', type=int, default=0, help='start token length')
|
| 51 |
+
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# DLinear
|
| 55 |
+
parser.add_argument('--individual', action='store_true', default=False, help='DLinear: a linear layer for each variate(channel) individually')
|
| 56 |
+
parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') # DLinear with --individual, use this hyperparameter as the number of channels
|
| 57 |
+
|
| 58 |
+
# optimization
|
| 59 |
+
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
| 60 |
+
parser.add_argument('--itr', type=int, default=2, help='experiments times')
|
| 61 |
+
parser.add_argument('--train_epochs', type=int, default=30, help='train epochs')
|
| 62 |
+
parser.add_argument('--batch_size', type=int, default=64, help='batch size of train input data')
|
| 63 |
+
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
|
| 64 |
+
parser.add_argument('--des', type=str, default='test', help='exp description')
|
| 65 |
+
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
|
| 66 |
+
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
|
| 67 |
+
|
| 68 |
+
# GPU
|
| 69 |
+
parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
|
| 70 |
+
parser.add_argument('--gpu', type=int, default=0, help='gpu')
|
| 71 |
+
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
|
| 72 |
+
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
|
| 73 |
+
parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')
|
| 74 |
+
|
| 75 |
+
# Augmentation
|
| 76 |
+
parser.add_argument('--aug_type', type=int, default=0, help='0: No augmentation, 1: Frequency Masking 2: Frequency Mixing 3: Wave Masking 4: Wave Mixing 5: StAug ')
|
| 77 |
+
parser.add_argument('--aug_rate', type=float, default=0.5, help='rate for FreqMask, FreqMix, and STAug')
|
| 78 |
+
parser.add_argument('--wavelet', type=str, default='db2', help='wavelet form for DWT')
|
| 79 |
+
parser.add_argument('--level', type=int, default=2, help='level for DWT')
|
| 80 |
+
parser.add_argument('--rates', type=str, default="[0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]",
|
| 81 |
+
help='List of float rates as a string, e.g., "[0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]"')
|
| 82 |
+
parser.add_argument('--nIMF', type=int, default=500, help='number of IMFs for EMD (STAug)')
|
| 83 |
+
parser.add_argument('--sampling_rate', type=float, default=0.5, help='sampling rate for WaveMask and WaveMix')
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
args = parser.parse_args()
|
| 89 |
+
args.rates = ast.literal_eval(args.rates)
|
| 90 |
+
|
| 91 |
+
args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
|
| 92 |
+
|
| 93 |
+
if args.use_gpu and args.use_multi_gpu:
|
| 94 |
+
args.dvices = args.devices.replace(' ', '')
|
| 95 |
+
device_ids = args.devices.split(',')
|
| 96 |
+
args.device_ids = [int(id_) for id_ in device_ids]
|
| 97 |
+
args.gpu = args.device_ids[0]
|
| 98 |
+
|
| 99 |
+
print('Args in experiment:')
|
| 100 |
+
print(args)
|
| 101 |
+
|
| 102 |
+
Exp = Exp_Main
|
| 103 |
+
|
| 104 |
+
if args.is_training:
|
| 105 |
+
mse_avg, mae_avg, rse_avg = np.zeros(args.itr), np.zeros(args.itr), np.zeros(args.itr)
|
| 106 |
+
for ii in range(args.itr):
|
| 107 |
+
# setting record of experiments
|
| 108 |
+
setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_{}_{}_{}_{}'.format(
|
| 109 |
+
args.model,
|
| 110 |
+
args.data,
|
| 111 |
+
args.features,
|
| 112 |
+
args.seq_len,
|
| 113 |
+
args.label_len,
|
| 114 |
+
args.pred_len,
|
| 115 |
+
args.des, args.aug_type, args.percentage, ii)
|
| 116 |
+
|
| 117 |
+
exp = Exp(args) # set experiments
|
| 118 |
+
print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
|
| 119 |
+
exp.train(setting)
|
| 120 |
+
|
| 121 |
+
print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
|
| 122 |
+
mse, mae, rse = exp.test(setting)
|
| 123 |
+
mse_avg[ii] = mse
|
| 124 |
+
mae_avg[ii] = mae
|
| 125 |
+
rse_avg[ii] = rse
|
| 126 |
+
|
| 127 |
+
f = open("result-" + args.des + args.data + ".txt", 'a')
|
| 128 |
+
f.write('\n')
|
| 129 |
+
f.write('\n')
|
| 130 |
+
f.write("-------START FROM HERE-----")
|
| 131 |
+
f.write(TYPES[args.aug_type] + " \n")
|
| 132 |
+
f.write('avg mse:{}, avg mae:{} avg rse:{} std mse:{}, std mae:{} std rse:{}'.format(mse_avg.mean(), mae_avg.mean(), rse_avg.mean(), mse_avg.std(), mae_avg.std(), rse_avg.std()))
|
| 133 |
+
f.write('\n')
|
| 134 |
+
f.write('\n')
|
| 135 |
+
f.close()
|
| 136 |
+
torch.cuda.empty_cache()
|
| 137 |
+
else:
|
| 138 |
+
ii = 0
|
| 139 |
+
setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_{}_{}_{}'.format(args.model,
|
| 140 |
+
args.data,
|
| 141 |
+
args.features,
|
| 142 |
+
args.seq_len,
|
| 143 |
+
args.label_len,
|
| 144 |
+
args.pred_len,
|
| 145 |
+
args.des, args.aug_type, ii)
|
| 146 |
+
|
| 147 |
+
exp = Exp(args) # set experiments
|
| 148 |
+
print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
|
| 149 |
+
exp.test(setting, test=1)
|
| 150 |
+
torch.cuda.empty_cache()
|
scripts/etth1.sh
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
if [ ! -d "./logs" ]; then
|
| 3 |
+
mkdir ./logs
|
| 4 |
+
fi
|
| 5 |
+
|
| 6 |
+
if [ ! -d "./logs/onerun" ]; then
|
| 7 |
+
mkdir ./logs/onerun
|
| 8 |
+
fi
|
| 9 |
+
|
| 10 |
+
seq_len=336
|
| 11 |
+
percentage=100
|
| 12 |
+
model_name=DLinear
|
| 13 |
+
|
| 14 |
+
# aug 0: None
|
| 15 |
+
# aug 1: Frequency Masking
|
| 16 |
+
# aug 2: Frequency Mixing
|
| 17 |
+
# aug 3: Wave Masking
|
| 18 |
+
# aug 4: Wave Mixing
|
| 19 |
+
# aug 5: STAug
|
| 20 |
+
|
| 21 |
+
pred_lens=(96 192 336 720)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# For Aug 0: None
|
| 25 |
+
|
| 26 |
+
for pred_len in "${pred_lens[@]}"; do
|
| 27 |
+
python3 -u ./run_main.py \
|
| 28 |
+
--is_training 1 \
|
| 29 |
+
--root_path ./dataset/ \
|
| 30 |
+
--data_path ETTh1.csv \
|
| 31 |
+
--model $model_name \
|
| 32 |
+
--data ETTh1 \
|
| 33 |
+
--features M \
|
| 34 |
+
--seq_len $seq_len \
|
| 35 |
+
--pred_len $pred_len \
|
| 36 |
+
--enc_in 7 \
|
| 37 |
+
--des '100p-h1-' \
|
| 38 |
+
--percentage $percentage \
|
| 39 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'Etth1_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
|
| 40 |
+
done
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# For Aug 1: Freq-Masking
|
| 44 |
+
|
| 45 |
+
python3 -u ./run_main.py \
|
| 46 |
+
--is_training 1 \
|
| 47 |
+
--root_path ./dataset/ \
|
| 48 |
+
--data_path ETTh1.csv \
|
| 49 |
+
--model $model_name \
|
| 50 |
+
--data ETTh1 \
|
| 51 |
+
--features M \
|
| 52 |
+
--seq_len $seq_len \
|
| 53 |
+
--pred_len 96 \
|
| 54 |
+
--enc_in 7 \
|
| 55 |
+
--des '100p-h1-' \
|
| 56 |
+
--percentage $percentage \
|
| 57 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMask.log
|
| 58 |
+
|
| 59 |
+
python3 -u ./run_main.py \
|
| 60 |
+
--is_training 1 \
|
| 61 |
+
--root_path ./dataset/ \
|
| 62 |
+
--data_path ETTh1.csv \
|
| 63 |
+
--model $model_name \
|
| 64 |
+
--data ETTh1 \
|
| 65 |
+
--features M \
|
| 66 |
+
--seq_len $seq_len \
|
| 67 |
+
--pred_len 192 \
|
| 68 |
+
--enc_in 7 \
|
| 69 |
+
--des '100p-h1-' \
|
| 70 |
+
--percentage $percentage \
|
| 71 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.5 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMask.log
|
| 72 |
+
|
| 73 |
+
python3 -u ./run_main.py \
|
| 74 |
+
--is_training 1 \
|
| 75 |
+
--root_path ./dataset/ \
|
| 76 |
+
--data_path ETTh1.csv \
|
| 77 |
+
--model $model_name \
|
| 78 |
+
--data ETTh1 \
|
| 79 |
+
--features M \
|
| 80 |
+
--seq_len $seq_len \
|
| 81 |
+
--pred_len 336 \
|
| 82 |
+
--enc_in 7 \
|
| 83 |
+
--des '100p-h1-' \
|
| 84 |
+
--percentage $percentage \
|
| 85 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.5 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMask.log
|
| 86 |
+
|
| 87 |
+
python3 -u ./run_main.py \
|
| 88 |
+
--is_training 1 \
|
| 89 |
+
--root_path ./dataset/ \
|
| 90 |
+
--data_path ETTh1.csv \
|
| 91 |
+
--model $model_name \
|
| 92 |
+
--data ETTh1 \
|
| 93 |
+
--features M \
|
| 94 |
+
--seq_len $seq_len \
|
| 95 |
+
--pred_len 720 \
|
| 96 |
+
--enc_in 7 \
|
| 97 |
+
--des '100p-h1-' \
|
| 98 |
+
--percentage $percentage \
|
| 99 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.4 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMask.log
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# For Aug 2: Freq-Mixing
|
| 104 |
+
|
| 105 |
+
python3 -u ./run_main.py \
|
| 106 |
+
--is_training 1 \
|
| 107 |
+
--root_path ./dataset/ \
|
| 108 |
+
--data_path ETTh1.csv \
|
| 109 |
+
--model $model_name \
|
| 110 |
+
--data ETTh1 \
|
| 111 |
+
--features M \
|
| 112 |
+
--seq_len $seq_len \
|
| 113 |
+
--pred_len 96 \
|
| 114 |
+
--enc_in 7 \
|
| 115 |
+
--des '100p-h1-' \
|
| 116 |
+
--percentage $percentage \
|
| 117 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMix.log
|
| 118 |
+
|
| 119 |
+
python3 -u ./run_main.py \
|
| 120 |
+
--is_training 1 \
|
| 121 |
+
--root_path ./dataset/ \
|
| 122 |
+
--data_path ETTh1.csv \
|
| 123 |
+
--model $model_name \
|
| 124 |
+
--data ETTh1 \
|
| 125 |
+
--features M \
|
| 126 |
+
--seq_len $seq_len \
|
| 127 |
+
--pred_len 192 \
|
| 128 |
+
--enc_in 7 \
|
| 129 |
+
--des '100p-h1-' \
|
| 130 |
+
--percentage $percentage \
|
| 131 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMix.log
|
| 132 |
+
|
| 133 |
+
python3 -u ./run_main.py \
|
| 134 |
+
--is_training 1 \
|
| 135 |
+
--root_path ./dataset/ \
|
| 136 |
+
--data_path ETTh1.csv \
|
| 137 |
+
--model $model_name \
|
| 138 |
+
--data ETTh1 \
|
| 139 |
+
--features M \
|
| 140 |
+
--seq_len $seq_len \
|
| 141 |
+
--pred_len 336 \
|
| 142 |
+
--enc_in 7 \
|
| 143 |
+
--des '100p-h1-' \
|
| 144 |
+
--percentage $percentage \
|
| 145 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMix.log
|
| 146 |
+
|
| 147 |
+
python3 -u ./run_main.py \
|
| 148 |
+
--is_training 1 \
|
| 149 |
+
--root_path ./dataset/ \
|
| 150 |
+
--data_path ETTh1.csv \
|
| 151 |
+
--model $model_name \
|
| 152 |
+
--data ETTh1 \
|
| 153 |
+
--features M \
|
| 154 |
+
--seq_len $seq_len \
|
| 155 |
+
--pred_len 720 \
|
| 156 |
+
--enc_in 7 \
|
| 157 |
+
--des '100p-h1-' \
|
| 158 |
+
--percentage $percentage \
|
| 159 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.6 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMix.log
|
| 160 |
+
|
| 161 |
+
# For Aug 3: Wave Masking
|
| 162 |
+
|
| 163 |
+
python3 -u ./run_main.py \
|
| 164 |
+
--is_training 1 \
|
| 165 |
+
--root_path ./dataset/ \
|
| 166 |
+
--data_path ETTh1.csv \
|
| 167 |
+
--model $model_name \
|
| 168 |
+
--data ETTh1 \
|
| 169 |
+
--features M \
|
| 170 |
+
--seq_len $seq_len \
|
| 171 |
+
--pred_len 96 \
|
| 172 |
+
--enc_in 7 \
|
| 173 |
+
--des '100p-h1-' \
|
| 174 |
+
--percentage $percentage \
|
| 175 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.5, 0.3, 0.9, 0.9, 0.0, 0.5, 0.3]" --wavelet 'db2' --level 3 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMask.log
|
| 176 |
+
|
| 177 |
+
python3 -u ./run_main.py \
|
| 178 |
+
--is_training 1 \
|
| 179 |
+
--root_path ./dataset/ \
|
| 180 |
+
--data_path ETTh1.csv \
|
| 181 |
+
--model $model_name \
|
| 182 |
+
--data ETTh1 \
|
| 183 |
+
--features M \
|
| 184 |
+
--seq_len $seq_len \
|
| 185 |
+
--pred_len 192 \
|
| 186 |
+
--enc_in 7 \
|
| 187 |
+
--des '100p-h1-' \
|
| 188 |
+
--percentage $percentage \
|
| 189 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.0, 1.0, 0.2, 1.0, 0.4, 0.9, 0.3]" --wavelet 'db3' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMask.log
|
| 190 |
+
|
| 191 |
+
python3 -u ./run_main.py \
|
| 192 |
+
--is_training 1 \
|
| 193 |
+
--root_path ./dataset/ \
|
| 194 |
+
--data_path ETTh1.csv \
|
| 195 |
+
--model $model_name \
|
| 196 |
+
--data ETTh1 \
|
| 197 |
+
--features M \
|
| 198 |
+
--seq_len $seq_len \
|
| 199 |
+
--pred_len 336 \
|
| 200 |
+
--enc_in 7 \
|
| 201 |
+
--des '100p-h1-' \
|
| 202 |
+
--percentage $percentage \
|
| 203 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.1, 0.9, 0.4, 0.8, 0.4, 0.9, 0.3]" --wavelet 'db25' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMask.log
|
| 204 |
+
|
| 205 |
+
python3 -u ./run_main.py \
|
| 206 |
+
--is_training 1 \
|
| 207 |
+
--root_path ./dataset/ \
|
| 208 |
+
--data_path ETTh1.csv \
|
| 209 |
+
--model $model_name \
|
| 210 |
+
--data ETTh1 \
|
| 211 |
+
--features M \
|
| 212 |
+
--seq_len $seq_len \
|
| 213 |
+
--pred_len 720 \
|
| 214 |
+
--enc_in 7 \
|
| 215 |
+
--des '100p-h1-' \
|
| 216 |
+
--percentage $percentage \
|
| 217 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.4, 0.9, 0.5, 1.0, 0.3, 0.9, 0.3]" --wavelet 'db1' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMask.log
|
| 218 |
+
|
| 219 |
+
# For Aug 4: Wave Mixing
|
| 220 |
+
|
| 221 |
+
python3 -u ./run_main.py \
|
| 222 |
+
--is_training 1 \
|
| 223 |
+
--root_path ./dataset/ \
|
| 224 |
+
--data_path ETTh1.csv \
|
| 225 |
+
--model $model_name \
|
| 226 |
+
--data ETTh1 \
|
| 227 |
+
--features M \
|
| 228 |
+
--seq_len $seq_len \
|
| 229 |
+
--pred_len 96 \
|
| 230 |
+
--enc_in 7 \
|
| 231 |
+
--des '100p-h1-' \
|
| 232 |
+
--percentage $percentage \
|
| 233 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.0, 0.9, 0.7, 0.7, 0.5, 0.7, 0.6]" --wavelet 'db3' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMix.log
|
| 234 |
+
|
| 235 |
+
python3 -u ./run_main.py \
|
| 236 |
+
--is_training 1 \
|
| 237 |
+
--root_path ./dataset/ \
|
| 238 |
+
--data_path ETTh1.csv \
|
| 239 |
+
--model $model_name \
|
| 240 |
+
--data ETTh1 \
|
| 241 |
+
--features M \
|
| 242 |
+
--seq_len $seq_len \
|
| 243 |
+
--pred_len 192 \
|
| 244 |
+
--enc_in 7 \
|
| 245 |
+
--des '100p-h1-' \
|
| 246 |
+
--percentage $percentage \
|
| 247 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[1.0, 0.4, 0.6, 0.6, 0.3, 0.8, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMix.log
|
| 248 |
+
|
| 249 |
+
python3 -u ./run_main.py \
|
| 250 |
+
--is_training 1 \
|
| 251 |
+
--root_path ./dataset/ \
|
| 252 |
+
--data_path ETTh1.csv \
|
| 253 |
+
--model $model_name \
|
| 254 |
+
--data ETTh1 \
|
| 255 |
+
--features M \
|
| 256 |
+
--seq_len $seq_len \
|
| 257 |
+
--pred_len 336 \
|
| 258 |
+
--enc_in 7 \
|
| 259 |
+
--des '100p-h1-' \
|
| 260 |
+
--percentage $percentage \
|
| 261 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.0, 0.9, 0.2, 0.2, 0.2, 0.1, 0.7]" --wavelet 'db3' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMix.log
|
| 262 |
+
|
| 263 |
+
python3 -u ./run_main.py \
|
| 264 |
+
--is_training 1 \
|
| 265 |
+
--root_path ./dataset/ \
|
| 266 |
+
--data_path ETTh1.csv \
|
| 267 |
+
--model $model_name \
|
| 268 |
+
--data ETTh1 \
|
| 269 |
+
--features M \
|
| 270 |
+
--seq_len $seq_len \
|
| 271 |
+
--pred_len 720 \
|
| 272 |
+
--enc_in 7 \
|
| 273 |
+
--des '100p-h1-' \
|
| 274 |
+
--percentage $percentage \
|
| 275 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.9, 0.5, 0.7, 0.1, 0.2, 0.8]" --wavelet 'db5' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMix.log
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# For Aug 5: STAug
|
| 279 |
+
|
| 280 |
+
python3 -u ./run_main.py \
|
| 281 |
+
--is_training 1 \
|
| 282 |
+
--root_path ./dataset/ \
|
| 283 |
+
--data_path ETTh1.csv \
|
| 284 |
+
--model $model_name \
|
| 285 |
+
--data ETTh1 \
|
| 286 |
+
--features M \
|
| 287 |
+
--seq_len $seq_len \
|
| 288 |
+
--pred_len 96 \
|
| 289 |
+
--enc_in 7 \
|
| 290 |
+
--des '100p-h1-' \
|
| 291 |
+
--percentage $percentage \
|
| 292 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 100 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.9'_'$percentage'_'StAug.log
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
python3 -u ./run_main.py \
|
| 296 |
+
--is_training 1 \
|
| 297 |
+
--root_path ./dataset/ \
|
| 298 |
+
--data_path ETTh1.csv \
|
| 299 |
+
--model $model_name \
|
| 300 |
+
--data ETTh1 \
|
| 301 |
+
--features M \
|
| 302 |
+
--seq_len $seq_len \
|
| 303 |
+
--pred_len 192 \
|
| 304 |
+
--enc_in 7 \
|
| 305 |
+
--des '100p-h1-' \
|
| 306 |
+
--percentage $percentage \
|
| 307 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 900 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.9'_'$percentage'_'StAug.log
|
| 308 |
+
|
| 309 |
+
python3 -u ./run_main.py \
|
| 310 |
+
--is_training 1 \
|
| 311 |
+
--root_path ./dataset/ \
|
| 312 |
+
--data_path ETTh1.csv \
|
| 313 |
+
--model $model_name \
|
| 314 |
+
--data ETTh1 \
|
| 315 |
+
--features M \
|
| 316 |
+
--seq_len $seq_len \
|
| 317 |
+
--pred_len 336 \
|
| 318 |
+
--enc_in 7 \
|
| 319 |
+
--des '100p-h1-' \
|
| 320 |
+
--percentage $percentage \
|
| 321 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.8 --nIMF 200 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.9'_'$percentage'_'StAug.log
|
| 322 |
+
|
| 323 |
+
python3 -u ./run_main.py \
|
| 324 |
+
--is_training 1 \
|
| 325 |
+
--root_path ./dataset/ \
|
| 326 |
+
--data_path ETTh1.csv \
|
| 327 |
+
--model $model_name \
|
| 328 |
+
--data ETTh1 \
|
| 329 |
+
--features M \
|
| 330 |
+
--seq_len $seq_len \
|
| 331 |
+
--pred_len 720 \
|
| 332 |
+
--enc_in 7 \
|
| 333 |
+
--des '100p-h1-' \
|
| 334 |
+
--percentage $percentage \
|
| 335 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.7 --nIMF 1000 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.9'_'$percentage'_'StAug.log
|
scripts/etth2.sh
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
if [ ! -d "./logs" ]; then
|
| 3 |
+
mkdir ./logs
|
| 4 |
+
fi
|
| 5 |
+
|
| 6 |
+
if [ ! -d "./logs/onerun" ]; then
|
| 7 |
+
mkdir ./logs/onerun
|
| 8 |
+
fi
|
| 9 |
+
|
| 10 |
+
seq_len=336
|
| 11 |
+
percentage=100
|
| 12 |
+
model_name=DLinear
|
| 13 |
+
|
| 14 |
+
# aug 0: None
|
| 15 |
+
# aug 1: Frequency Masking
|
| 16 |
+
# aug 2: Frequency Mixing
|
| 17 |
+
# aug 3: Wave Masking
|
| 18 |
+
# aug 4: Wave Mixing
|
| 19 |
+
# aug 5: STAug
|
| 20 |
+
|
| 21 |
+
pred_lens=(96 192 336 720)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# For Aug 0: None
|
| 25 |
+
|
| 26 |
+
for pred_len in "${pred_lens[@]}"; do
|
| 27 |
+
python3 -u ./run_main.py \
|
| 28 |
+
--is_training 1 \
|
| 29 |
+
--root_path ./dataset/ \
|
| 30 |
+
--data_path ETTh2.csv \
|
| 31 |
+
--model $model_name \
|
| 32 |
+
--data ETTh2 \
|
| 33 |
+
--features M \
|
| 34 |
+
--seq_len $seq_len \
|
| 35 |
+
--pred_len $pred_len \
|
| 36 |
+
--enc_in 7 \
|
| 37 |
+
--des '100p-h2-' \
|
| 38 |
+
--percentage $percentage \
|
| 39 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'Etth2_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
|
| 40 |
+
done
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# For Aug 1: Freq-Masking
|
| 44 |
+
|
| 45 |
+
python3 -u ./run_main.py \
|
| 46 |
+
--is_training 1 \
|
| 47 |
+
--root_path ./dataset/ \
|
| 48 |
+
--data_path ETTh2.csv \
|
| 49 |
+
--model $model_name \
|
| 50 |
+
--data ETTh2 \
|
| 51 |
+
--features M \
|
| 52 |
+
--seq_len $seq_len \
|
| 53 |
+
--pred_len 96 \
|
| 54 |
+
--enc_in 7 \
|
| 55 |
+
--des '100p-h2-' \
|
| 56 |
+
--percentage $percentage \
|
| 57 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.6 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMask.log
|
| 58 |
+
|
| 59 |
+
python3 -u ./run_main.py \
|
| 60 |
+
--is_training 1 \
|
| 61 |
+
--root_path ./dataset/ \
|
| 62 |
+
--data_path ETTh2.csv \
|
| 63 |
+
--model $model_name \
|
| 64 |
+
--data ETTh2 \
|
| 65 |
+
--features M \
|
| 66 |
+
--seq_len $seq_len \
|
| 67 |
+
--pred_len 192 \
|
| 68 |
+
--enc_in 7 \
|
| 69 |
+
--des '100p-h2-' \
|
| 70 |
+
--percentage $percentage \
|
| 71 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.6 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMask.log
|
| 72 |
+
|
| 73 |
+
python3 -u ./run_main.py \
|
| 74 |
+
--is_training 1 \
|
| 75 |
+
--root_path ./dataset/ \
|
| 76 |
+
--data_path ETTh2.csv \
|
| 77 |
+
--model $model_name \
|
| 78 |
+
--data ETTh2 \
|
| 79 |
+
--features M \
|
| 80 |
+
--seq_len $seq_len \
|
| 81 |
+
--pred_len 336 \
|
| 82 |
+
--enc_in 7 \
|
| 83 |
+
--des '100p-h2-' \
|
| 84 |
+
--percentage $percentage \
|
| 85 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMask.log
|
| 86 |
+
|
| 87 |
+
python3 -u ./run_main.py \
|
| 88 |
+
--is_training 1 \
|
| 89 |
+
--root_path ./dataset/ \
|
| 90 |
+
--data_path ETTh2.csv \
|
| 91 |
+
--model $model_name \
|
| 92 |
+
--data ETTh2 \
|
| 93 |
+
--features M \
|
| 94 |
+
--seq_len $seq_len \
|
| 95 |
+
--pred_len 720 \
|
| 96 |
+
--enc_in 7 \
|
| 97 |
+
--des '100p-h2-' \
|
| 98 |
+
--percentage $percentage \
|
| 99 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMask.log
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# For Aug 2: Freq-Mixing
|
| 104 |
+
|
| 105 |
+
python3 -u ./run_main.py \
|
| 106 |
+
--is_training 1 \
|
| 107 |
+
--root_path ./dataset/ \
|
| 108 |
+
--data_path ETTh2.csv \
|
| 109 |
+
--model $model_name \
|
| 110 |
+
--data ETTh2 \
|
| 111 |
+
--features M \
|
| 112 |
+
--seq_len $seq_len \
|
| 113 |
+
--pred_len 96 \
|
| 114 |
+
--enc_in 7 \
|
| 115 |
+
--des '100p-h2-' \
|
| 116 |
+
--percentage $percentage \
|
| 117 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.9 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMix.log
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
python3 -u ./run_main.py \
|
| 121 |
+
--is_training 1 \
|
| 122 |
+
--root_path ./dataset/ \
|
| 123 |
+
--data_path ETTh2.csv \
|
| 124 |
+
--model $model_name \
|
| 125 |
+
--data ETTh2 \
|
| 126 |
+
--features M \
|
| 127 |
+
--seq_len $seq_len \
|
| 128 |
+
--pred_len 192 \
|
| 129 |
+
--enc_in 7 \
|
| 130 |
+
--des '100p-h2-' \
|
| 131 |
+
--percentage $percentage \
|
| 132 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMix.log
|
| 133 |
+
|
| 134 |
+
python3 -u ./run_main.py \
|
| 135 |
+
--is_training 1 \
|
| 136 |
+
--root_path ./dataset/ \
|
| 137 |
+
--data_path ETTh2.csv \
|
| 138 |
+
--model $model_name \
|
| 139 |
+
--data ETTh2 \
|
| 140 |
+
--features M \
|
| 141 |
+
--seq_len $seq_len \
|
| 142 |
+
--pred_len 336 \
|
| 143 |
+
--enc_in 7 \
|
| 144 |
+
--des '100p-h2-' \
|
| 145 |
+
--percentage $percentage \
|
| 146 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMix.log
|
| 147 |
+
|
| 148 |
+
python3 -u ./run_main.py \
|
| 149 |
+
--is_training 1 \
|
| 150 |
+
--root_path ./dataset/ \
|
| 151 |
+
--data_path ETTh2.csv \
|
| 152 |
+
--model $model_name \
|
| 153 |
+
--data ETTh2 \
|
| 154 |
+
--features M \
|
| 155 |
+
--seq_len $seq_len \
|
| 156 |
+
--pred_len 720 \
|
| 157 |
+
--enc_in 7 \
|
| 158 |
+
--des '100p-h2-' \
|
| 159 |
+
--percentage $percentage \
|
| 160 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMix.log
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# For Aug 3: Wave Masking
|
| 166 |
+
|
| 167 |
+
python3 -u ./run_main.py \
|
| 168 |
+
--is_training 1 \
|
| 169 |
+
--root_path ./dataset/ \
|
| 170 |
+
--data_path ETTh2.csv \
|
| 171 |
+
--model $model_name \
|
| 172 |
+
--data ETTh2 \
|
| 173 |
+
--features M \
|
| 174 |
+
--seq_len $seq_len \
|
| 175 |
+
--pred_len 96 \
|
| 176 |
+
--enc_in 7 \
|
| 177 |
+
--des '100p-h2-' \
|
| 178 |
+
--percentage $percentage \
|
| 179 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.4, 0.9, 0.0, 0.8, 0.3, 0.7, 0.2]" --wavelet 'db26' --level 2 --sampling_rate 0.5 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMask.log
|
| 180 |
+
|
| 181 |
+
python3 -u ./run_main.py \
|
| 182 |
+
--is_training 1 \
|
| 183 |
+
--root_path ./dataset/ \
|
| 184 |
+
--data_path ETTh2.csv \
|
| 185 |
+
--model $model_name \
|
| 186 |
+
--data ETTh2 \
|
| 187 |
+
--features M \
|
| 188 |
+
--seq_len $seq_len \
|
| 189 |
+
--pred_len 192 \
|
| 190 |
+
--enc_in 7 \
|
| 191 |
+
--des '100p-h2-' \
|
| 192 |
+
--percentage $percentage \
|
| 193 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.6, 0.7, 0.5, 0.7, 0.8, 0.3, 0.5]" --wavelet 'db26' --level 2 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMask.log
|
| 194 |
+
|
| 195 |
+
python3 -u ./run_main.py \
|
| 196 |
+
--is_training 1 \
|
| 197 |
+
--root_path ./dataset/ \
|
| 198 |
+
--data_path ETTh2.csv \
|
| 199 |
+
--model $model_name \
|
| 200 |
+
--data ETTh2 \
|
| 201 |
+
--features M \
|
| 202 |
+
--seq_len $seq_len \
|
| 203 |
+
--pred_len 336 \
|
| 204 |
+
--enc_in 7 \
|
| 205 |
+
--des '100p-h2-' \
|
| 206 |
+
--percentage $percentage \
|
| 207 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 0.7, 0.9, 0.4, 0.4, 0.1, 1.0]" --wavelet 'db1' --level 3 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMask.log
|
| 208 |
+
|
| 209 |
+
python3 -u ./run_main.py \
|
| 210 |
+
--is_training 1 \
|
| 211 |
+
--root_path ./dataset/ \
|
| 212 |
+
--data_path ETTh2.csv \
|
| 213 |
+
--model $model_name \
|
| 214 |
+
--data ETTh2 \
|
| 215 |
+
--features M \
|
| 216 |
+
--seq_len $seq_len \
|
| 217 |
+
--pred_len 720 \
|
| 218 |
+
--enc_in 7 \
|
| 219 |
+
--des '100p-h2-' \
|
| 220 |
+
--percentage $percentage \
|
| 221 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.8, 0.9, 0.4, 0.9, 0.4, 0.4, 0.5]" --wavelet 'db5' --level 4 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMask.log
|
| 222 |
+
|
| 223 |
+
# For Aug 4: Wave Mixing
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
python3 -u ./run_main.py \
|
| 227 |
+
--is_training 1 \
|
| 228 |
+
--root_path ./dataset/ \
|
| 229 |
+
--data_path ETTh2.csv \
|
| 230 |
+
--model $model_name \
|
| 231 |
+
--data ETTh2 \
|
| 232 |
+
--features M \
|
| 233 |
+
--seq_len $seq_len \
|
| 234 |
+
--pred_len 96 \
|
| 235 |
+
--enc_in 7 \
|
| 236 |
+
--des '100p-h2-' \
|
| 237 |
+
--percentage $percentage \
|
| 238 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.9, 0.9, 0.1, 0.5, 0.9, 0.6, 0.2]" --wavelet 'db25' --level 2 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMix.log
|
| 239 |
+
|
| 240 |
+
python3 -u ./run_main.py \
|
| 241 |
+
--is_training 1 \
|
| 242 |
+
--root_path ./dataset/ \
|
| 243 |
+
--data_path ETTh2.csv \
|
| 244 |
+
--model $model_name \
|
| 245 |
+
--data ETTh2 \
|
| 246 |
+
--features M \
|
| 247 |
+
--seq_len $seq_len \
|
| 248 |
+
--pred_len 192 \
|
| 249 |
+
--enc_in 7 \
|
| 250 |
+
--des '100p-h2-' \
|
| 251 |
+
--percentage $percentage \
|
| 252 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.9, 0.4, 0.1, 0.8, 0.9, 0.5, 0.4]" --wavelet 'db1' --level 3 --sampling_rate 0.5 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMix.log
|
| 253 |
+
|
| 254 |
+
python3 -u ./run_main.py \
|
| 255 |
+
--is_training 1 \
|
| 256 |
+
--root_path ./dataset/ \
|
| 257 |
+
--data_path ETTh2.csv \
|
| 258 |
+
--model $model_name \
|
| 259 |
+
--data ETTh2 \
|
| 260 |
+
--features M \
|
| 261 |
+
--seq_len $seq_len \
|
| 262 |
+
--pred_len 336 \
|
| 263 |
+
--enc_in 7 \
|
| 264 |
+
--des '100p-h2-' \
|
| 265 |
+
--percentage $percentage \
|
| 266 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.9, 0.1, 0.2, 0.5, 1.0, 0.5, 0.5]" --wavelet 'db25' --level 3 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMix.log
|
| 267 |
+
|
| 268 |
+
python3 -u ./run_main.py \
|
| 269 |
+
--is_training 1 \
|
| 270 |
+
--root_path ./dataset/ \
|
| 271 |
+
--data_path ETTh2.csv \
|
| 272 |
+
--model $model_name \
|
| 273 |
+
--data ETTh2 \
|
| 274 |
+
--features M \
|
| 275 |
+
--seq_len $seq_len \
|
| 276 |
+
--pred_len 720 \
|
| 277 |
+
--enc_in 7 \
|
| 278 |
+
--des '100p-h2-' \
|
| 279 |
+
--percentage $percentage \
|
| 280 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.5, 0.1, 0.2, 0.7, 0.4, 0.9, 0.5]" --wavelet 'db5' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMix.log
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# For Aug 5: STAug
|
| 285 |
+
|
| 286 |
+
python3 -u ./run_main.py \
|
| 287 |
+
--is_training 1 \
|
| 288 |
+
--root_path ./dataset/ \
|
| 289 |
+
--data_path ETTh2.csv \
|
| 290 |
+
--model $model_name \
|
| 291 |
+
--data ETTh2 \
|
| 292 |
+
--features M \
|
| 293 |
+
--seq_len $seq_len \
|
| 294 |
+
--pred_len 96 \
|
| 295 |
+
--enc_in 7 \
|
| 296 |
+
--des '100p-h2-' \
|
| 297 |
+
--percentage $percentage \
|
| 298 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.4 --nIMF 2000 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.9'_'$percentage'_'StAug.log
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
python3 -u ./run_main.py \
|
| 302 |
+
--is_training 1 \
|
| 303 |
+
--root_path ./dataset/ \
|
| 304 |
+
--data_path ETTh2.csv \
|
| 305 |
+
--model $model_name \
|
| 306 |
+
--data ETTh2 \
|
| 307 |
+
--features M \
|
| 308 |
+
--seq_len $seq_len \
|
| 309 |
+
--pred_len 192 \
|
| 310 |
+
--enc_in 7 \
|
| 311 |
+
--des '100p-h2-' \
|
| 312 |
+
--percentage $percentage \
|
| 313 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 200 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.9'_'$percentage'_'StAug.log
|
| 314 |
+
|
| 315 |
+
python3 -u ./run_main.py \
|
| 316 |
+
--is_training 1 \
|
| 317 |
+
--root_path ./dataset/ \
|
| 318 |
+
--data_path ETTh2.csv \
|
| 319 |
+
--model $model_name \
|
| 320 |
+
--data ETTh2 \
|
| 321 |
+
--features M \
|
| 322 |
+
--seq_len $seq_len \
|
| 323 |
+
--pred_len 336 \
|
| 324 |
+
--enc_in 7 \
|
| 325 |
+
--des '100p-h2-' \
|
| 326 |
+
--percentage $percentage \
|
| 327 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.6 --nIMF 100 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.9'_'$percentage'_'StAug.log
|
| 328 |
+
|
| 329 |
+
python3 -u ./run_main.py \
|
| 330 |
+
--is_training 1 \
|
| 331 |
+
--root_path ./dataset/ \
|
| 332 |
+
--data_path ETTh2.csv \
|
| 333 |
+
--model $model_name \
|
| 334 |
+
--data ETTh2 \
|
| 335 |
+
--features M \
|
| 336 |
+
--seq_len $seq_len \
|
| 337 |
+
--pred_len 720 \
|
| 338 |
+
--enc_in 7 \
|
| 339 |
+
--des '100p-h2-' \
|
| 340 |
+
--percentage $percentage \
|
| 341 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.4 --nIMF 700 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.9'_'$percentage'_'StAug.log
|
scripts/ili.sh
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
if [ ! -d "./logs" ]; then
|
| 3 |
+
mkdir ./logs
|
| 4 |
+
fi
|
| 5 |
+
|
| 6 |
+
if [ ! -d "./logs/onerun" ]; then
|
| 7 |
+
mkdir ./logs/onerun
|
| 8 |
+
fi
|
| 9 |
+
|
| 10 |
+
seq_len=36
|
| 11 |
+
percentage=100
|
| 12 |
+
model_name=DLinear
|
| 13 |
+
|
| 14 |
+
# aug 0: None
|
| 15 |
+
# aug 1: Frequency Masking
|
| 16 |
+
# aug 2: Frequency Mixing
|
| 17 |
+
# aug 3: Wave Masking
|
| 18 |
+
# aug 4: Wave Mixing
|
| 19 |
+
# aug 5: STAug
|
| 20 |
+
|
| 21 |
+
pred_lens=(24 36 48 60)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# For Aug 0: None
|
| 26 |
+
|
| 27 |
+
for pred_len in "${pred_lens[@]}"; do
|
| 28 |
+
python3 -u ./run_main.py \
|
| 29 |
+
--is_training 1 \
|
| 30 |
+
--root_path ./dataset/ \
|
| 31 |
+
--data_path national_illness.csv \
|
| 32 |
+
--model $model_name \
|
| 33 |
+
--data custom \
|
| 34 |
+
--features M \
|
| 35 |
+
--seq_len $seq_len \
|
| 36 |
+
--pred_len $pred_len \
|
| 37 |
+
--enc_in 7 \
|
| 38 |
+
--des '100p-ili-' \
|
| 39 |
+
--percentage $percentage \
|
| 40 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'ill_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
|
| 41 |
+
done
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# For Aug 1: Freq-Masking
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
python3 -u ./run_main.py \
|
| 48 |
+
--is_training 1 \
|
| 49 |
+
--root_path ./dataset/ \
|
| 50 |
+
--data_path national_illness.csv \
|
| 51 |
+
--model $model_name \
|
| 52 |
+
--data custom \
|
| 53 |
+
--features M \
|
| 54 |
+
--seq_len $seq_len \
|
| 55 |
+
--pred_len 24 \
|
| 56 |
+
--enc_in 7 \
|
| 57 |
+
--des '100p-ili-' \
|
| 58 |
+
--percentage $percentage \
|
| 59 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMask.log
|
| 60 |
+
|
| 61 |
+
python3 -u ./run_main.py \
|
| 62 |
+
--is_training 1 \
|
| 63 |
+
--root_path ./dataset/ \
|
| 64 |
+
--data_path national_illness.csv \
|
| 65 |
+
--model $model_name \
|
| 66 |
+
--data custom \
|
| 67 |
+
--features M \
|
| 68 |
+
--seq_len $seq_len \
|
| 69 |
+
--pred_len 36 \
|
| 70 |
+
--enc_in 7 \
|
| 71 |
+
--des '100p-ili-' \
|
| 72 |
+
--percentage $percentage \
|
| 73 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.4'_'$percentage'_'FreqMask.log
|
| 74 |
+
|
| 75 |
+
python3 -u ./run_main.py \
|
| 76 |
+
--is_training 1 \
|
| 77 |
+
--root_path ./dataset/ \
|
| 78 |
+
--data_path national_illness.csv \
|
| 79 |
+
--model $model_name \
|
| 80 |
+
--data custom \
|
| 81 |
+
--features M \
|
| 82 |
+
--seq_len $seq_len \
|
| 83 |
+
--pred_len 48 \
|
| 84 |
+
--enc_in 7 \
|
| 85 |
+
--des '100p-ili-' \
|
| 86 |
+
--percentage $percentage \
|
| 87 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.4'_'$percentage'_'FreqMask.log
|
| 88 |
+
|
| 89 |
+
python3 -u ./run_main.py \
|
| 90 |
+
--is_training 1 \
|
| 91 |
+
--root_path ./dataset/ \
|
| 92 |
+
--data_path national_illness.csv \
|
| 93 |
+
--model $model_name \
|
| 94 |
+
--data custom \
|
| 95 |
+
--features M \
|
| 96 |
+
--seq_len $seq_len \
|
| 97 |
+
--pred_len 60 \
|
| 98 |
+
--enc_in 7 \
|
| 99 |
+
--des '100p-ili-' \
|
| 100 |
+
--percentage $percentage \
|
| 101 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.4'_'$percentage'_'FreqMask.log
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# For Aug 2: Freq-Mixing
|
| 106 |
+
|
| 107 |
+
python3 -u ./run_main.py \
|
| 108 |
+
--is_training 1 \
|
| 109 |
+
--root_path ./dataset/ \
|
| 110 |
+
--data_path national_illness.csv \
|
| 111 |
+
--model $model_name \
|
| 112 |
+
--data custom \
|
| 113 |
+
--features M \
|
| 114 |
+
--seq_len $seq_len \
|
| 115 |
+
--pred_len 24 \
|
| 116 |
+
--enc_in 7 \
|
| 117 |
+
--des '100p-ili-' \
|
| 118 |
+
--percentage $percentage \
|
| 119 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.4'_'$percentage'_'FreqMix.log
|
| 120 |
+
|
| 121 |
+
python3 -u ./run_main.py \
|
| 122 |
+
--is_training 1 \
|
| 123 |
+
--root_path ./dataset/ \
|
| 124 |
+
--data_path national_illness.csv \
|
| 125 |
+
--model $model_name \
|
| 126 |
+
--data custom \
|
| 127 |
+
--features M \
|
| 128 |
+
--seq_len $seq_len \
|
| 129 |
+
--pred_len 36 \
|
| 130 |
+
--enc_in 7 \
|
| 131 |
+
--des '100p-ili-' \
|
| 132 |
+
--percentage $percentage \
|
| 133 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.4'_'$percentage'_'FreqMix.log
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
python3 -u ./run_main.py \
|
| 137 |
+
--is_training 1 \
|
| 138 |
+
--root_path ./dataset/ \
|
| 139 |
+
--data_path national_illness.csv \
|
| 140 |
+
--model $model_name \
|
| 141 |
+
--data custom \
|
| 142 |
+
--features M \
|
| 143 |
+
--seq_len $seq_len \
|
| 144 |
+
--pred_len 48 \
|
| 145 |
+
--enc_in 7 \
|
| 146 |
+
--des '100p-ili-' \
|
| 147 |
+
--percentage $percentage \
|
| 148 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.4'_'$percentage'_'FreqMix.log
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
python3 -u ./run_main2.py \
|
| 152 |
+
--is_training 1 \
|
| 153 |
+
--root_path ./dataset/ \
|
| 154 |
+
--data_path national_illness.csv \
|
| 155 |
+
--model $model_name \
|
| 156 |
+
--data custom \
|
| 157 |
+
--features M \
|
| 158 |
+
--seq_len $seq_len \
|
| 159 |
+
--pred_len 60 \
|
| 160 |
+
--enc_in 7 \
|
| 161 |
+
--des '100p-ili-' \
|
| 162 |
+
--percentage $percentage \
|
| 163 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.4'_'$percentage'_'FreqMix.log
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# For Aug 3: Wave Masking
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
python3 -u ./run_main.py \
|
| 171 |
+
--is_training 1 \
|
| 172 |
+
--root_path ./dataset/ \
|
| 173 |
+
--data_path national_illness.csv \
|
| 174 |
+
--model $model_name \
|
| 175 |
+
--data custom \
|
| 176 |
+
--features M \
|
| 177 |
+
--seq_len $seq_len \
|
| 178 |
+
--pred_len 24 \
|
| 179 |
+
--enc_in 7 \
|
| 180 |
+
--des '100p-ili-' \
|
| 181 |
+
--percentage $percentage \
|
| 182 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.4, 0.8, 0.9, 0.7, 0.9, 0.0, 0.5]" --wavelet 'db25' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.0'_'$percentage'_'WaveMask.log
|
| 183 |
+
|
| 184 |
+
python3 -u ./run_main.py \
|
| 185 |
+
--is_training 1 \
|
| 186 |
+
--root_path ./dataset/ \
|
| 187 |
+
--data_path national_illness.csv \
|
| 188 |
+
--model $model_name \
|
| 189 |
+
--data custom \
|
| 190 |
+
--features M \
|
| 191 |
+
--seq_len $seq_len \
|
| 192 |
+
--pred_len 36 \
|
| 193 |
+
--enc_in 7 \
|
| 194 |
+
--des '100p-ili-' \
|
| 195 |
+
--percentage $percentage \
|
| 196 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.6, 0.8, 0.3, 0.1, 0.9, 0.0, 0.5]" --wavelet 'db25' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.0'_'$percentage'_'WaveMask.log
|
| 197 |
+
|
| 198 |
+
python3 -u ./run_main.py \
|
| 199 |
+
--is_training 1 \
|
| 200 |
+
--root_path ./dataset/ \
|
| 201 |
+
--data_path national_illness.csv \
|
| 202 |
+
--model $model_name \
|
| 203 |
+
--data custom \
|
| 204 |
+
--features M \
|
| 205 |
+
--seq_len $seq_len \
|
| 206 |
+
--pred_len 48 \
|
| 207 |
+
--enc_in 7 \
|
| 208 |
+
--des '100p-ili-' \
|
| 209 |
+
--percentage $percentage \
|
| 210 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 0.7, 1.0, 0.4, 0.4, 0.0, 0.5]" --wavelet 'db2' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.0'_'$percentage'_'WaveMask.log
|
| 211 |
+
|
| 212 |
+
python3 -u ./run_main.py \
|
| 213 |
+
--is_training 1 \
|
| 214 |
+
--root_path ./dataset/ \
|
| 215 |
+
--data_path national_illness.csv \
|
| 216 |
+
--model $model_name \
|
| 217 |
+
--data custom \
|
| 218 |
+
--features M \
|
| 219 |
+
--seq_len $seq_len \
|
| 220 |
+
--pred_len 60 \
|
| 221 |
+
--enc_in 7 \
|
| 222 |
+
--des '100p-ili-' \
|
| 223 |
+
--percentage $percentage \
|
| 224 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 0.8, 0.5, 0.1, 0.9, 0.0, 0.5]" --wavelet 'db25' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.0'_'$percentage'_'WaveMask.log
|
| 225 |
+
|
| 226 |
+
# For Aug 4: Wave Mixing
|
| 227 |
+
|
| 228 |
+
python3 -u ./run_main.py \
|
| 229 |
+
--is_training 1 \
|
| 230 |
+
--root_path ./dataset/ \
|
| 231 |
+
--data_path national_illness.csv \
|
| 232 |
+
--model $model_name \
|
| 233 |
+
--data custom \
|
| 234 |
+
--features M \
|
| 235 |
+
--seq_len $seq_len \
|
| 236 |
+
--pred_len 24 \
|
| 237 |
+
--enc_in 7 \
|
| 238 |
+
--des '100p-ili-' \
|
| 239 |
+
--percentage $percentage \
|
| 240 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.8, 1.0, 0.0, 0.5, 0.7, 0.1]" --wavelet 'db1' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.0'_'$percentage'_'WaveMix.log
|
| 241 |
+
|
| 242 |
+
python3 -u ./run_main.py \
|
| 243 |
+
--is_training 1 \
|
| 244 |
+
--root_path ./dataset/ \
|
| 245 |
+
--data_path national_illness.csv \
|
| 246 |
+
--model $model_name \
|
| 247 |
+
--data custom \
|
| 248 |
+
--features M \
|
| 249 |
+
--seq_len $seq_len \
|
| 250 |
+
--pred_len 36 \
|
| 251 |
+
--enc_in 7 \
|
| 252 |
+
--des '100p-ili-' \
|
| 253 |
+
--percentage $percentage \
|
| 254 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 1.0, 0.9, 0.2, 0.1, 0.7, 0.1]" --wavelet 'db25' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.0'_'$percentage'_'WaveMix.log
|
| 255 |
+
|
| 256 |
+
python3 -u ./run_main.py \
|
| 257 |
+
--is_training 1 \
|
| 258 |
+
--root_path ./dataset/ \
|
| 259 |
+
--data_path national_illness.csv \
|
| 260 |
+
--model $model_name \
|
| 261 |
+
--data custom \
|
| 262 |
+
--features M \
|
| 263 |
+
--seq_len $seq_len \
|
| 264 |
+
--pred_len 48 \
|
| 265 |
+
--enc_in 7 \
|
| 266 |
+
--des '100p-ili-' \
|
| 267 |
+
--percentage $percentage \
|
| 268 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 1.0, 0.4, 0.5, 0.1, 0.6, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.0'_'$percentage'_'WaveMix.log
|
| 269 |
+
|
| 270 |
+
python3 -u ./run_main.py \
|
| 271 |
+
--is_training 1 \
|
| 272 |
+
--root_path ./dataset/ \
|
| 273 |
+
--data_path national_illness.csv \
|
| 274 |
+
--model $model_name \
|
| 275 |
+
--data custom \
|
| 276 |
+
--features M \
|
| 277 |
+
--seq_len $seq_len \
|
| 278 |
+
--pred_len 60 \
|
| 279 |
+
--enc_in 7 \
|
| 280 |
+
--des '100p-ili-' \
|
| 281 |
+
--percentage $percentage \
|
| 282 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.9, 0.3, 0.9, 0.5, 0.7, 0.1]" --wavelet 'db1' --level 1 --sampling_rate 0.5 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.0'_'$percentage'_'WaveMix.log
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# For Aug 5: STAug
|
| 286 |
+
|
| 287 |
+
python3 -u ./run_main.py \
|
| 288 |
+
--is_training 1 \
|
| 289 |
+
--root_path ./dataset/ \
|
| 290 |
+
--data_path national_illness.csv \
|
| 291 |
+
--model $model_name \
|
| 292 |
+
--data custom \
|
| 293 |
+
--features M \
|
| 294 |
+
--seq_len $seq_len \
|
| 295 |
+
--pred_len 24 \
|
| 296 |
+
--enc_in 7 \
|
| 297 |
+
--des '100p-ili-' \
|
| 298 |
+
--percentage $percentage \
|
| 299 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.7 --nIMF 200 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.9'_'$percentage'_'StAug.log
|
| 300 |
+
|
| 301 |
+
python3 -u ./run_main.py \
|
| 302 |
+
--is_training 1 \
|
| 303 |
+
--root_path ./dataset/ \
|
| 304 |
+
--data_path national_illness.csv \
|
| 305 |
+
--model $model_name \
|
| 306 |
+
--data custom \
|
| 307 |
+
--features M \
|
| 308 |
+
--seq_len $seq_len \
|
| 309 |
+
--pred_len 36 \
|
| 310 |
+
--enc_in 7 \
|
| 311 |
+
--des '100p-ili-' \
|
| 312 |
+
--percentage $percentage \
|
| 313 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.3 --nIMF 300 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.9'_'$percentage'_'StAug.log
|
| 314 |
+
|
| 315 |
+
python3 -u ./run_main.py \
|
| 316 |
+
--is_training 1 \
|
| 317 |
+
--root_path ./dataset/ \
|
| 318 |
+
--data_path national_illness.csv \
|
| 319 |
+
--model $model_name \
|
| 320 |
+
--data custom \
|
| 321 |
+
--features M \
|
| 322 |
+
--seq_len $seq_len \
|
| 323 |
+
--pred_len 48 \
|
| 324 |
+
--enc_in 7 \
|
| 325 |
+
--des '100p-ili-' \
|
| 326 |
+
--percentage $percentage \
|
| 327 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 300 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.9'_'$percentage'_'StAug.log
|
| 328 |
+
|
| 329 |
+
python3 -u ./run_main.py \
|
| 330 |
+
--is_training 1 \
|
| 331 |
+
--root_path ./dataset/ \
|
| 332 |
+
--data_path national_illness.csv \
|
| 333 |
+
--model $model_name \
|
| 334 |
+
--data custom \
|
| 335 |
+
--features M \
|
| 336 |
+
--seq_len $seq_len \
|
| 337 |
+
--pred_len 60 \
|
| 338 |
+
--enc_in 7 \
|
| 339 |
+
--des '100p-ili-' \
|
| 340 |
+
--percentage $percentage \
|
| 341 |
+
--itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.7 --nIMF 1000 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.9'_'$percentage'_'StAug.log
|
| 342 |
+
|
scripts/weather.sh
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
if [ ! -d "./logs" ]; then
|
| 3 |
+
mkdir ./logs
|
| 4 |
+
fi
|
| 5 |
+
|
| 6 |
+
if [ ! -d "./logs/onerun" ]; then
|
| 7 |
+
mkdir ./logs/onerun
|
| 8 |
+
fi
|
| 9 |
+
|
| 10 |
+
seq_len=336
|
| 11 |
+
percentage=100
|
| 12 |
+
model_name=DLinear
|
| 13 |
+
|
| 14 |
+
# aug 0: None
|
| 15 |
+
# aug 1: Frequency Masking
|
| 16 |
+
# aug 2: Frequency Mixing
|
| 17 |
+
# aug 3: Wave Masking
|
| 18 |
+
# aug 4: Wave Mixing
|
| 19 |
+
# aug 5: STAug
|
| 20 |
+
|
| 21 |
+
pred_lens=(96 192 336 720)
|
| 22 |
+
|
| 23 |
+
# For Aug 0: None
|
| 24 |
+
|
| 25 |
+
for pred_len in "${pred_lens[@]}"; do
|
| 26 |
+
python3 -u ./run_main.py \
|
| 27 |
+
--is_training 1 \
|
| 28 |
+
--root_path ./dataset/ \
|
| 29 |
+
--data_path weather.csv \
|
| 30 |
+
--model $model_name \
|
| 31 |
+
--data custom \
|
| 32 |
+
--features M \
|
| 33 |
+
--seq_len $seq_len \
|
| 34 |
+
--pred_len $pred_len \
|
| 35 |
+
--enc_in 21 \
|
| 36 |
+
--des '100p-whr-' \
|
| 37 |
+
--percentage $percentage \
|
| 38 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'weather_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
|
| 39 |
+
done
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# For Aug 1: Freq-Masking
|
| 43 |
+
|
| 44 |
+
python3 -u ./run_main.py \
|
| 45 |
+
--is_training 1 \
|
| 46 |
+
--root_path ./dataset/ \
|
| 47 |
+
--data_path weather.csv \
|
| 48 |
+
--model $model_name \
|
| 49 |
+
--data custom \
|
| 50 |
+
--features M \
|
| 51 |
+
--seq_len $seq_len \
|
| 52 |
+
--pred_len 96 \
|
| 53 |
+
--enc_in 21 \
|
| 54 |
+
--des '100p-whr-' \
|
| 55 |
+
--percentage $percentage \
|
| 56 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMask.log
|
| 57 |
+
|
| 58 |
+
python3 -u ./run_main.py \
|
| 59 |
+
--is_training 1 \
|
| 60 |
+
--root_path ./dataset/ \
|
| 61 |
+
--data_path weather.csv \
|
| 62 |
+
--model $model_name \
|
| 63 |
+
--data custom \
|
| 64 |
+
--features M \
|
| 65 |
+
--seq_len $seq_len \
|
| 66 |
+
--pred_len 192 \
|
| 67 |
+
--enc_in 21 \
|
| 68 |
+
--des '100p-whr-' \
|
| 69 |
+
--percentage $percentage \
|
| 70 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMask.log
|
| 71 |
+
|
| 72 |
+
python3 -u ./run_main.py \
|
| 73 |
+
--is_training 1 \
|
| 74 |
+
--root_path ./dataset/ \
|
| 75 |
+
--data_path weather.csv \
|
| 76 |
+
--model $model_name \
|
| 77 |
+
--data custom \
|
| 78 |
+
--features M \
|
| 79 |
+
--seq_len $seq_len \
|
| 80 |
+
--pred_len 336 \
|
| 81 |
+
--enc_in 21 \
|
| 82 |
+
--des '100p-whr-' \
|
| 83 |
+
--percentage $percentage \
|
| 84 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMask.log
|
| 85 |
+
|
| 86 |
+
python3 -u ./run_main.py \
|
| 87 |
+
--is_training 1 \
|
| 88 |
+
--root_path ./dataset/ \
|
| 89 |
+
--data_path weather.csv \
|
| 90 |
+
--model $model_name \
|
| 91 |
+
--data custom \
|
| 92 |
+
--features M \
|
| 93 |
+
--seq_len $seq_len \
|
| 94 |
+
--pred_len 720 \
|
| 95 |
+
--enc_in 21 \
|
| 96 |
+
--des '100p-whr-' \
|
| 97 |
+
--percentage $percentage \
|
| 98 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.9 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMask.log
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# For Aug 2: Freq-Mixing
|
| 103 |
+
|
| 104 |
+
python3 -u ./run_main.py \
|
| 105 |
+
--is_training 1 \
|
| 106 |
+
--root_path ./dataset/ \
|
| 107 |
+
--data_path weather.csv \
|
| 108 |
+
--model $model_name \
|
| 109 |
+
--data custom \
|
| 110 |
+
--features M \
|
| 111 |
+
--seq_len $seq_len \
|
| 112 |
+
--pred_len 96 \
|
| 113 |
+
--enc_in 21 \
|
| 114 |
+
--des '100p-whr-' \
|
| 115 |
+
--percentage $percentage \
|
| 116 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.9 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMix.log
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
python3 -u ./run_main.py \
|
| 120 |
+
--is_training 1 \
|
| 121 |
+
--root_path ./dataset/ \
|
| 122 |
+
--data_path weather.csv \
|
| 123 |
+
--model $model_name \
|
| 124 |
+
--data custom \
|
| 125 |
+
--features M \
|
| 126 |
+
--seq_len $seq_len \
|
| 127 |
+
--pred_len 192 \
|
| 128 |
+
--enc_in 21 \
|
| 129 |
+
--des '100p-whr-' \
|
| 130 |
+
--percentage $percentage \
|
| 131 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMix.log
|
| 132 |
+
|
| 133 |
+
python3 -u ./run_main.py \
|
| 134 |
+
--is_training 1 \
|
| 135 |
+
--root_path ./dataset/ \
|
| 136 |
+
--data_path weather.csv \
|
| 137 |
+
--model $model_name \
|
| 138 |
+
--data custom \
|
| 139 |
+
--features M \
|
| 140 |
+
--seq_len $seq_len \
|
| 141 |
+
--pred_len 336 \
|
| 142 |
+
--enc_in 21 \
|
| 143 |
+
--des '100p-whr-' \
|
| 144 |
+
--percentage $percentage \
|
| 145 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMix.log
|
| 146 |
+
|
| 147 |
+
python3 -u ./run_main.py \
|
| 148 |
+
--is_training 1 \
|
| 149 |
+
--root_path ./dataset/ \
|
| 150 |
+
--data_path weather.csv \
|
| 151 |
+
--model $model_name \
|
| 152 |
+
--data custom \
|
| 153 |
+
--features M \
|
| 154 |
+
--seq_len $seq_len \
|
| 155 |
+
--pred_len 720 \
|
| 156 |
+
--enc_in 21 \
|
| 157 |
+
--des '100p-whr-' \
|
| 158 |
+
--percentage $percentage \
|
| 159 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.9 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMix.log
|
| 160 |
+
|
| 161 |
+
# For Aug 3: Wave Masking
|
| 162 |
+
|
| 163 |
+
python3 -u ./run_main.py \
|
| 164 |
+
--is_training 1 \
|
| 165 |
+
--root_path ./dataset/ \
|
| 166 |
+
--data_path weather.csv \
|
| 167 |
+
--model $model_name \
|
| 168 |
+
--data custom \
|
| 169 |
+
--features M \
|
| 170 |
+
--seq_len $seq_len \
|
| 171 |
+
--pred_len 96 \
|
| 172 |
+
--enc_in 21 \
|
| 173 |
+
--des '100p-whr-' \
|
| 174 |
+
--percentage $percentage \
|
| 175 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 1.0, 0.4, 0.4, 0.9, 0.0, 0.5]" --wavelet 'db2' --level 2 --sampling_rate 0.5 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMask.log
|
| 176 |
+
|
| 177 |
+
python3 -u ./run_main.py \
|
| 178 |
+
--is_training 1 \
|
| 179 |
+
--root_path ./dataset/ \
|
| 180 |
+
--data_path weather.csv \
|
| 181 |
+
--model $model_name \
|
| 182 |
+
--data custom \
|
| 183 |
+
--features M \
|
| 184 |
+
--seq_len $seq_len \
|
| 185 |
+
--pred_len 192 \
|
| 186 |
+
--enc_in 21 \
|
| 187 |
+
--des '100p-whr-' \
|
| 188 |
+
--percentage $percentage \
|
| 189 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.1, 0.7, 0.1, 0.4, 0.5, 0.1, 0.7]" --wavelet 'db2' --level 1 --sampling_rate 0.5 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMask.log
|
| 190 |
+
|
| 191 |
+
python3 -u ./run_main.py \
|
| 192 |
+
--is_training 1 \
|
| 193 |
+
--root_path ./dataset/ \
|
| 194 |
+
--data_path weather.csv \
|
| 195 |
+
--model $model_name \
|
| 196 |
+
--data custom \
|
| 197 |
+
--features M \
|
| 198 |
+
--seq_len $seq_len \
|
| 199 |
+
--pred_len 336 \
|
| 200 |
+
--enc_in 21 \
|
| 201 |
+
--des '100p-whr-' \
|
| 202 |
+
--percentage $percentage \
|
| 203 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[1.0, 1.0, 0.0, 0.0, 0.3, 0.4, 0.2]" --wavelet 'db1' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMask.log
|
| 204 |
+
|
| 205 |
+
python3 -u ./run_main.py \
|
| 206 |
+
--is_training 1 \
|
| 207 |
+
--root_path ./dataset/ \
|
| 208 |
+
--data_path weather.csv \
|
| 209 |
+
--model $model_name \
|
| 210 |
+
--data custom \
|
| 211 |
+
--features M \
|
| 212 |
+
--seq_len $seq_len \
|
| 213 |
+
--pred_len 720 \
|
| 214 |
+
--enc_in 21 \
|
| 215 |
+
--des '100p-whr-' \
|
| 216 |
+
--percentage $percentage \
|
| 217 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[1.0, 0.8, 0.6, 0.0, 0.2, 0.6, 0.5]" --wavelet 'db2' --level 1 --sampling_rate 0.5 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMask.log
|
| 218 |
+
|
| 219 |
+
# For Aug 4: Wave Mixing
|
| 220 |
+
|
| 221 |
+
python3 -u ./run_main.py \
|
| 222 |
+
--is_training 1 \
|
| 223 |
+
--root_path ./dataset/ \
|
| 224 |
+
--data_path weather.csv \
|
| 225 |
+
--model $model_name \
|
| 226 |
+
--data custom \
|
| 227 |
+
--features M \
|
| 228 |
+
--seq_len $seq_len \
|
| 229 |
+
--pred_len 96 \
|
| 230 |
+
--enc_in 21 \
|
| 231 |
+
--des '100p-whr-' \
|
| 232 |
+
--percentage $percentage \
|
| 233 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.5, 0.1, 0.2, 0.4, 0.7, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMix.log
|
| 234 |
+
|
| 235 |
+
python3 -u ./run_main.py \
|
| 236 |
+
--is_training 1 \
|
| 237 |
+
--root_path ./dataset/ \
|
| 238 |
+
--data_path weather.csv \
|
| 239 |
+
--model $model_name \
|
| 240 |
+
--data custom \
|
| 241 |
+
--features M \
|
| 242 |
+
--seq_len $seq_len \
|
| 243 |
+
--pred_len 192 \
|
| 244 |
+
--enc_in 21 \
|
| 245 |
+
--des '100p-whr-' \
|
| 246 |
+
--percentage $percentage \
|
| 247 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.2, 0.7, 1.0, 0.3, 0.2, 0.3, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMix.log
|
| 248 |
+
|
| 249 |
+
python3 -u ./run_main.py \
|
| 250 |
+
--is_training 1 \
|
| 251 |
+
--root_path ./dataset/ \
|
| 252 |
+
--data_path weather.csv \
|
| 253 |
+
--model $model_name \
|
| 254 |
+
--data custom \
|
| 255 |
+
--features M \
|
| 256 |
+
--seq_len $seq_len \
|
| 257 |
+
--pred_len 336 \
|
| 258 |
+
--enc_in 21 \
|
| 259 |
+
--des '100p-whr-' \
|
| 260 |
+
--percentage $percentage \
|
| 261 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.8, 0.6, 0.8, 0.6, 0.1, 1.0, 0.3]" --wavelet 'db2' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMix.log
|
| 262 |
+
|
| 263 |
+
python3 -u ./run_main.py \
|
| 264 |
+
--is_training 1 \
|
| 265 |
+
--root_path ./dataset/ \
|
| 266 |
+
--data_path weather.csv \
|
| 267 |
+
--model $model_name \
|
| 268 |
+
--data custom \
|
| 269 |
+
--features M \
|
| 270 |
+
--seq_len $seq_len \
|
| 271 |
+
--pred_len 720 \
|
| 272 |
+
--enc_in 21 \
|
| 273 |
+
--des '100p-whr-' \
|
| 274 |
+
--percentage $percentage \
|
| 275 |
+
--itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.1, 0.7, 0.5, 0.6, 0.5, 0.1]" --wavelet 'db1' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMix.log
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The code is originated from
|
| 3 |
+
# Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
|
| 4 |
+
# arXiv preprint arXiv:2302.09292.
|
| 5 |
+
# =============================================================================
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
def RSE(pred, true):
|
| 10 |
+
return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def CORR(pred, true):
|
| 14 |
+
u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
|
| 15 |
+
d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
|
| 16 |
+
d += 1e-12
|
| 17 |
+
return 0.01*(u / d).mean(-1)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def MAE(pred, true):
|
| 21 |
+
return np.mean(np.abs(pred - true))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def MSE(pred, true):
|
| 25 |
+
return np.mean((pred - true) ** 2)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def RMSE(pred, true):
|
| 29 |
+
return np.sqrt(MSE(pred, true))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def MAPE(pred, true):
|
| 33 |
+
return np.mean(np.abs((pred - true) / true))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def MSPE(pred, true):
|
| 37 |
+
return np.mean(np.square((pred - true) / true))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def metric(pred, true):
|
| 41 |
+
mae = MAE(pred, true)
|
| 42 |
+
mse = MSE(pred, true)
|
| 43 |
+
rmse = RMSE(pred, true)
|
| 44 |
+
mape = MAPE(pred, true)
|
| 45 |
+
mspe = MSPE(pred, true)
|
| 46 |
+
rse = RSE(pred, true)
|
| 47 |
+
corr = CORR(pred, true)
|
| 48 |
+
|
| 49 |
+
return mae, mse, rmse, mape, mspe, rse, corr
|
utils/tools.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# The code is originated from
|
| 3 |
+
# Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
|
| 4 |
+
# arXiv preprint arXiv:2302.09292.
|
| 5 |
+
# =============================================================================
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
plt.switch_backend('agg')
|
| 12 |
+
|
| 13 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
| 14 |
+
# lr = args.learning_rate * (0.2 ** (epoch // 2))
|
| 15 |
+
if args.lradj == 'type1':
|
| 16 |
+
lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
|
| 17 |
+
elif args.lradj == 'type2':
|
| 18 |
+
lr_adjust = {
|
| 19 |
+
2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
|
| 20 |
+
10: 5e-7, 15: 1e-7, 20: 5e-8
|
| 21 |
+
}
|
| 22 |
+
elif args.lradj == '3':
|
| 23 |
+
lr_adjust = {epoch: args.learning_rate if epoch < 10 else args.learning_rate*0.1}
|
| 24 |
+
elif args.lradj == '4':
|
| 25 |
+
lr_adjust = {epoch: args.learning_rate if epoch < 15 else args.learning_rate*0.1}
|
| 26 |
+
elif args.lradj == '5':
|
| 27 |
+
lr_adjust = {epoch: args.learning_rate if epoch < 25 else args.learning_rate*0.1}
|
| 28 |
+
elif args.lradj == '6':
|
| 29 |
+
lr_adjust = {epoch: args.learning_rate if epoch < 5 else args.learning_rate*0.1}
|
| 30 |
+
if epoch in lr_adjust.keys():
|
| 31 |
+
lr = lr_adjust[epoch]
|
| 32 |
+
for param_group in optimizer.param_groups:
|
| 33 |
+
param_group['lr'] = lr
|
| 34 |
+
print('Updating learning rate to {}'.format(lr))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class EarlyStopping:
|
| 38 |
+
def __init__(self, patience=7, verbose=False, delta=0):
|
| 39 |
+
self.patience = patience
|
| 40 |
+
self.verbose = verbose
|
| 41 |
+
self.counter = 0
|
| 42 |
+
self.best_score = None
|
| 43 |
+
self.early_stop = False
|
| 44 |
+
self.val_loss_min = np.Inf
|
| 45 |
+
self.delta = delta
|
| 46 |
+
|
| 47 |
+
def __call__(self, val_loss, model, path):
|
| 48 |
+
score = -val_loss
|
| 49 |
+
if self.best_score is None:
|
| 50 |
+
self.best_score = score
|
| 51 |
+
self.save_checkpoint(val_loss, model, path)
|
| 52 |
+
elif score < self.best_score + self.delta:
|
| 53 |
+
self.counter += 1
|
| 54 |
+
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
|
| 55 |
+
if self.counter >= self.patience:
|
| 56 |
+
self.early_stop = True
|
| 57 |
+
else:
|
| 58 |
+
self.best_score = score
|
| 59 |
+
self.save_checkpoint(val_loss, model, path)
|
| 60 |
+
self.counter = 0
|
| 61 |
+
|
| 62 |
+
def save_checkpoint(self, val_loss, model, path):
|
| 63 |
+
if self.verbose:
|
| 64 |
+
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
|
| 65 |
+
torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
|
| 66 |
+
self.val_loss_min = val_loss
|
| 67 |
+
|
| 68 |
+
def get_val_loss_min(self):
|
| 69 |
+
return self.val_loss_min
|
| 70 |
+
|
| 71 |
+
def visual(true, preds=None, name='./pic/test.pdf'):
|
| 72 |
+
"""
|
| 73 |
+
Results visualization
|
| 74 |
+
"""
|
| 75 |
+
plt.figure()
|
| 76 |
+
plt.plot(true, label='GroundTruth', linewidth=2)
|
| 77 |
+
if preds is not None:
|
| 78 |
+
plt.plot(preds, label='Prediction', linewidth=2)
|
| 79 |
+
plt.legend()
|
| 80 |
+
plt.savefig(name, bbox_inches='tight')
|
| 81 |
+
|
| 82 |
+
def test_params_flop(model,x_shape):
|
| 83 |
+
"""
|
| 84 |
+
If you want to thest former's flop, you need to give default value to inputs in model.forward(), the following code can only pass one argument to forward()
|
| 85 |
+
"""
|
| 86 |
+
model_params = 0
|
| 87 |
+
for parameter in model.parameters():
|
| 88 |
+
model_params += parameter.numel()
|
| 89 |
+
print('INFO: Trainable parameter count: {:.2f}M'.format(model_params / 1000000.0))
|
| 90 |
+
from ptflops import get_model_complexity_info
|
| 91 |
+
with torch.cuda.device(0):
|
| 92 |
+
macs, params = get_model_complexity_info(model.cuda(), x_shape, as_strings=True, print_per_layer_stat=True)
|
| 93 |
+
# print('Flops:' + flops)
|
| 94 |
+
# print('Params:' + params)
|
| 95 |
+
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
|
| 96 |
+
print('{:<30} {:<8}'.format('Number of parameters: ', params))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
|