bakhshaliyev committed
Commit 8a37944 · verified · 1 Parent(s): 589ffd0

Upload 17 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/main_result.png filter=lfs diff=lfs merge=lfs -text
37
+ figures/main_result2.png filter=lfs diff=lfs merge=lfs -text
38
+ figures/overview.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,93 @@
1
+ # Wave-Mask/Mix: Exploring Wavelet-Based Augmentations for Time Series Forecasting <a href="https://www.arxiv.org/abs/2408.10951" title="Read the paper on arXiv"><img src="https://img.shields.io/badge/arXiv-2408.10951-b31b1b.svg" height="20" align="center"></a>
2
+
3
+ <a href="https://www.arxiv.org/abs/2408.10951" style="color: #4285F4; font-size: 24px; font-weight: bold; text-decoration: none;">Paper (arXiv)</a>
4
+
5
+ The figure below depicts the training pipeline with wavelet augmentations: the look-back window and the forecasting horizon are concatenated before transformation and augmentation.
6
+ Batches of the generated synthetic data are sampled according to a predefined hyperparameter, the sampling rate. Each sampled batch is then split back into a look-back window and a target horizon and concatenated with the original data. The wavelet augmentations are Wavelet Masking (WaveMask) and Wavelet Mixing (WaveMix). Both techniques use the discrete wavelet transform (DWT) to decompose the signal into approximation and detail coefficients and then adjust those coefficients, effectively modifying frequency components across different time scales.
7
+
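+ As a condensed sketch of that sampling step (illustrative shapes and values; the actual training loop lives in `main_run/train.py`):
+
+ ```python
+ import torch
+ from augmentation.aug import augmentation
+
+ aug = augmentation()
+ seq_len, pred_len, sampling_rate = 336, 96, 0.5               # illustrative values
+ batch_x, batch_y = torch.randn(32, seq_len, 7), torch.randn(32, pred_len, 7)
+
+ xy = aug.wave_mask(batch_x, batch_y, rates=[0.2, 0.1, 0.1], wavelet='db1', level=2)
+ x2, y2 = xy[:, :seq_len, :], xy[:, -pred_len:, :]             # split back into window / horizon
+ keep = torch.randperm(x2.shape[0])[:int(x2.shape[0] * sampling_rate)]  # sampling rate
+ batch_x = torch.cat([batch_x, x2[keep]], dim=0)               # append synthetic data
+ batch_y = torch.cat([batch_y, y2[keep]], dim=0)
+ ```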
8
+ WaveMask selectively zeroes out wavelet coefficients at each decomposition level, introducing variability into the augmented data. WaveMix, in contrast, exchanges wavelet coefficients between two distinct instances of the dataset, enhancing its diversity.
9
+
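+ For intuition, a minimal sketch of the coefficient-level operations (assuming `pytorch_wavelets` is installed; the full implementations are in `augmentation/aug.py`):
+
+ ```python
+ import torch
+ from pytorch_wavelets import DWT1DForward, DWT1DInverse
+
+ x = torch.randn(32, 7, 432)                   # (batch, channels, window + horizon)
+ dwt, idwt = DWT1DForward(J=2, wave='db2'), DWT1DInverse(wave='db2')
+ cA, cDs = dwt(x)                              # approximation + per-level detail coefficients
+
+ # WaveMask: randomly zero out coefficients at every level
+ cA_m = cA.masked_fill(torch.rand_like(cA) < 0.2, 0)
+ cDs_m = [cD.masked_fill(torch.rand_like(cD) < 0.1, 0) for cD in cDs]
+ x_masked = idwt((cA_m, cDs_m))
+
+ # WaveMix: exchange coefficients with a shuffled copy of the batch
+ cA2, cDs2 = dwt(x[torch.randperm(x.size(0))])
+ m = torch.rand_like(cA) < 0.2
+ cA_mixed = cA.masked_fill(m, 0) + cA2.masked_fill(~m, 0)
+ ```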
10
+ To the best of our knowledge, this is the first study to conduct extensive experiments on multivariate time series using the discrete wavelet transform as an augmentation technique.
11
+
12
+
13
+
14
+ <div align=center>
15
+ <img src="./figures/overview.png" alt="Overview" width="700" style="margin-bottom: 40px; margin-top: 40px;"/>
16
+ </div>
17
+
18
+
19
+
20
+ The tables below compare our methods with the baselines in terms of Mean Squared Error (MSE) and Mean Absolute Error (MAE). The best result is shown in bold, and the second best is underlined.
21
+
22
+ <div align=center>
23
+ <img src="./figures/main_result.png" alt="Main Results" width="1000" height="500" style="margin-bottom: 40px; margin-top: 40px;"/>
24
+ </div>
25
+
26
+
27
+
28
+ <div align=center>
29
+ <img src="./figures/main_result2.png" alt="Main Results" width="1000" height="500" style="margin-bottom: 40px; margin-top: 40px;"/>
30
+ </div>
31
+
32
+
33
+ ## Dataset
34
+
35
+ You can obtain all datasets from https://drive.google.com/drive/folders/1ZOYpTUa82_jCcxIdTmyr0LXQfvaM9vIy. All of them are preprocessed and ready for training.
36
+
37
+ ```bash
38
+ mkdir dataset
39
+ ```
40
+ Please place all of them within the `./dataset` directory.
41
+
42
+ ## Quick Start
43
+
44
+ Clone the project
45
+
46
+ ```bash
47
+ git clone https://github.com/jafarbakhshaliyev/Wave-Augs.git
48
+ ```
49
+
50
+ Go to the project directory
51
+
52
+ ```bash
53
+ cd Wave-Augs
54
+ ```
55
+
56
+ Install dependencies
57
+
58
+ ```bash
59
+ pip install -r requirements.txt
60
+ ```
61
+
62
+ Train:
63
+
64
+ ```bash
65
+ sh scripts/etth1.sh
66
+ sh scripts/etth2.sh
67
+ sh scripts/weather.sh
68
+ sh scripts/ili.sh
69
+ ```
70
+
71
+ You can change the `percentage` argument to down-sample the training dataset for ablation studies, as in the example below.
72
+
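+ For example, a hypothetical run on 50% of the training data with WaveMask (all flags are defined in `run_main.py`):
+
+ ```bash
+ python3 -u ./run_main.py \
+   --is_training 1 --model DLinear --data ETTh1 --data_path ETTh1.csv \
+   --features M --seq_len 336 --pred_len 96 --enc_in 7 \
+   --percentage 50 --aug_type 3 --rates "[0.5, 0.3, 0.9, 0.9]" \
+   --wavelet db2 --level 3 --sampling_rate 0.2
+ ```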
73
+ ## Citation
74
+
75
+ If you find the code useful, please cite our paper:
76
+
77
+ ```
78
+ @misc{waveaug2024,
79
+ title={Wave-Mask/Mix: Exploring Wavelet-Based Augmentations for Time Series Forecasting},
80
+ author={Dona Arabi and Jafar Bakhshaliyev and Ayse Coskuner and Kiran Madhusudhanan and Kami Serdar Uckardes},
81
+ year={2024},
82
+ eprint={2408.10951},
83
+ archivePrefix={arXiv},
84
+ primaryClass={cs.LG},
85
+ url={https://arxiv.org/abs/2408.10951},
86
+ }
87
+ ```
88
+
89
+ Please remember to cite all the datasets and compared methods if you use them in your experiments.
90
+
91
+ ## Acknowledgements
92
+
93
+ We would like to express our gratitude to [Zeng et al.](https://arxiv.org/abs/2205.13504) for providing the datasets used in this project. We also acknowledge [Chen et al.](https://arxiv.org/abs/2302.09292) and [Zhang et al.](https://arxiv.org/abs/2303.14254) for their code frameworks, which served as the foundation for our codebase.
augmentation/aug.py ADDED
@@ -0,0 +1,248 @@
1
+ # =============================================================================
2
+ # The freq_mask, freq_mix methods are adapted from the following sources:
3
+ # Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
4
+ # arXiv preprint arXiv:2302.09292.
5
+ #
6
+ # The emd_aug and mix_aug methods for STAug are adapted from the following source:
7
+ # https://github.com/xiyuanzh/STAug/tree/main
8
+ # =============================================================================
9
+
10
+ import numpy as np
11
+ import torch
12
+ import pywt
13
+ from pytorch_wavelets import DWT1DForward, DWT1DInverse
14
+ from typing import List, Tuple
15
+
16
+ class augmentation():
17
+
18
+ """
19
+ A class for data augmentation techniques used for Time Series Forecasting.
20
+
21
+ Attributes:
22
+ None
23
+
24
+ Methods:
25
+ - freq_mask: Apply frequency masking to input data.
26
+ - freq_mix: Mix two input signals in the frequency domain.
27
+ - wave_mask: Apply wavelet-based masking to input data.
28
+ - wave_mix: Mix two input signals using wavelet transformation.
29
+ - emd_aug: Apply empirical mode decomposition (EMD) based augmentation.
30
+ - mix_aug: Mix two batches of data with a random interpolation factor.
31
+ """
32
+
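+ # Illustrative usage (shapes assumed; see main_run/train.py for the actual call sites):
+ #   aug = augmentation()
+ #   xy = aug.wave_mask(x, y, rates=[0.2, 0.1, 0.1], wavelet='db1', level=2)
+ #   x_aug, y_aug = xy[:, :seq_len, :], xy[:, -pred_len:, :]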
33
+ def __init__(self):
34
+ pass
35
+
36
+ @staticmethod
37
+ def freq_mask(x: torch.Tensor, y: torch.Tensor, rate: float = 0.5, dim: int = 1) -> torch.Tensor:
38
+ """
39
+ Apply frequency masking to input data.
40
+
41
+ Args:
42
+ - x (torch.Tensor): Look-back window.
43
+ - y (torch.Tensor): Target horizon.
44
+ - rate (float): Mask rate.
45
+ - dim (int): Dimension along which to concatenate and apply the Fourier transform.
46
+
47
+ Returns:
48
+ - torch.Tensor: Masked synthetic data tensor.
49
+ """
50
+
51
+ xy = torch.cat([x,y],dim=dim)  # concatenate window and horizon along the given dimension
52
+ xy_f = torch.fft.rfft(xy,dim=dim)
53
+ m = torch.FloatTensor(xy_f.shape).uniform_() < rate
54
+ freal = xy_f.real.masked_fill(m,0)
55
+ fimag = xy_f.imag.masked_fill(m,0)
56
+ xy_f = torch.complex(freal,fimag)
57
+ xy = torch.fft.irfft(xy_f,dim=dim)
58
+ return xy
59
+
60
+ @staticmethod
61
+ def freq_mix(x: torch.Tensor, y: torch.Tensor, rate: float = 0.5, dim: int = 1) -> torch.Tensor:
62
+ """
63
+ Mix two input signals in the frequency domain.
64
+
65
+ Args:
66
+ - x (torch.Tensor): Look-back window.
67
+ - y (torch.Tensor): Target horizon.
68
+ - rate (float): Mix rate.
69
+ - dim (int): Dimension along which to concatenate and apply the Fourier transform.
70
+
71
+ Returns:
72
+ - torch.Tensor: Mixed synthetic data tensor.
73
+ """
74
+
75
+ xy = torch.cat([x,y],dim=dim)
76
+ xy_f = torch.fft.rfft(xy,dim=dim)
77
+
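+ # Random mask over frequency bins; following FrAug, the dominant (largest-amplitude) components are intended to stay out of the mix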
78
+ m = torch.FloatTensor(xy_f.shape).uniform_() < rate
79
+ amp = abs(xy_f)
80
+ _,index = amp.sort(dim=dim, descending=True)
81
+ dominant_mask = index > 2
82
+ m = torch.bitwise_and(m, dominant_mask)
83
+ freal = xy_f.real.masked_fill(m,0)
84
+ fimag = xy_f.imag.masked_fill(m,0)
85
+
86
+ b_idx = np.arange(x.shape[0])
87
+ np.random.shuffle(b_idx)
88
+ x2, y2 = x[b_idx], y[b_idx]
89
+ xy2 = torch.cat([x2,y2],dim=dim)
90
+ xy2_f = torch.fft.rfft(xy2,dim=dim)
91
+
92
+ m = torch.bitwise_not(m)
93
+ freal2 = xy2_f.real.masked_fill(m,0)
94
+ fimag2 = xy2_f.imag.masked_fill(m,0)
95
+
96
+ freal += freal2
97
+ fimag += fimag2
98
+
99
+ xy_f = torch.complex(freal,fimag)
100
+
101
+ xy = torch.fft.irfft(xy_f,dim=dim)
102
+
103
+ return xy
104
+
105
+
106
+ @staticmethod
107
+ def wave_mask(x: torch.Tensor, y: torch.Tensor, rates: List[float], wavelet: str = 'db1', level: int = 2, dim: int= 1) -> torch.Tensor:
108
+ """
109
+ Apply wavelet-based masking to input data.
110
+
111
+ Args:
112
+ - x (torch.Tensor): Look-back window.
113
+ - y (torch.Tensor): Target horizon.
114
+ - rates (list of floats): List of mask rates for each wavelet level.
115
+ - wavelet (str): Type of wavelet to use.
116
+ - level (int): Number of decomposition levels.
117
+ - dim (int): Dimension along which to concatenate and apply DWT.
118
+
119
+ Returns:
120
+ - torch.Tensor: Masked synthetic data tensor.
121
+ """
122
+
123
+ xy = torch.cat([x, y], dim=1) # Concatenate along the time dimension
124
+ rate_tensor = torch.tensor(rates, device=x.device) # Convert rates to tensor
125
+
126
+ # Permute dimensions to match expected input: (batch_size, num_features, seq_len)
127
+ xy = xy.permute(0, 2, 1).to(x.device)
128
+
129
+ # Initialize & perform the DWT
130
+ dwt = DWT1DForward(J=level, wave=wavelet, mode='symmetric').to(x.device)
131
+ cA, cDs = dwt(xy)
132
+
133
+ mask = torch.rand_like(cA).to(cA.device) < rate_tensor[0]
134
+ cA = cA.masked_fill(mask, 0)
135
+
136
+ # Apply masking to detail coefficients
137
+ masked_cDs = []
138
+ for i, cD in enumerate(cDs, 1):
139
+ mask_cD = torch.rand_like(cD).to(cD.device) < rate_tensor[i] # Create mask
140
+ cD = cD.masked_fill(mask_cD, 0)
141
+ masked_cDs.append(cD)
142
+
143
+ # Initialize the inverse DWT & reconstruct the signal
144
+ idwt = DWT1DInverse(wave=wavelet, mode='symmetric').to(x.device)
145
+ reconstructed = idwt((cA, masked_cDs))
146
+
147
+ # Permute back to original shape: (batch_size, seq_len, num_features)
148
+ reconstructed = reconstructed.permute(0, 2, 1)
149
+
150
+ return reconstructed
151
+
152
+ @staticmethod
153
+ def wave_mix(x: torch.Tensor, y: torch.Tensor, rates: List[float], wavelet: str = 'db1', level: int = 2, dim: int = 1) -> torch.Tensor:
154
+ """
155
+ Mix two input signals using wavelet transformation.
156
+
157
+ Args:
158
+ - x (torch.Tensor): Look-back window.
159
+ - y (torch.Tensor): Target horizon.
160
+ - rates (list of floats): List of mix rates for each wavelet level.
161
+ - wavelet (str): Type of wavelet to use.
162
+ - level (int): Number of decomposition levels.
163
+ - dim (int): Dimension along which to concatenate and apply DWT.
164
+
165
+ Returns:
166
+ - torch.Tensor: Mixed synthetic data tensor.
167
+ """
168
+
169
+ xy = torch.cat([x, y], dim=1) # Concatenate along the time dimension
170
+ batch_size, _, _ = xy.shape
171
+ rate_tensor = torch.tensor(rates, device=x.device) # Convert rates to tensor
172
+
173
+ # Permute dimensions to match expected input: (batch_size, num_features, seq_len)
174
+ xy = xy.permute(0, 2, 1).to(x.device)
175
+
176
+ # Shuffle the batch for mixing
177
+ b_idx = torch.randperm(batch_size)
178
+ xy2 = xy[b_idx]
179
+
180
+ # Initialize & perform the DWT on the both signals
181
+ dwt = DWT1DForward(J=level, wave=wavelet, mode='symmetric').to(x.device)
182
+ cA1, cDs1 = dwt(xy)
183
+ cA2, cDs2 = dwt(xy2)
184
+
185
+ # Mix the approximation coefficients
186
+ mask = torch.rand_like(cA1).to(cA1.device) < rate_tensor[0] # Create mask
187
+ cA_mixed = cA1.masked_fill(mask, 0) + cA2.masked_fill(~mask, 0)
188
+
189
+ # Mix the coefficients
190
+ mixed_cDs = []
191
+ for i, (cD1, cD2) in enumerate(zip(cDs1, cDs2), 1):
192
+ mask = torch.rand_like(cD1).to(cD1.device) < rate_tensor[i] # Create mask
193
+ cD_mixed = cD1.masked_fill(mask, 0) + cD2.masked_fill(~mask, 0)
194
+ mixed_cDs.append(cD_mixed)
195
+
196
+ # Initialize the inverse DWT & reconstruct the signal
197
+ idwt = DWT1DInverse(wave=wavelet, mode='symmetric').to(x.device)
198
+ reconstructed = idwt((cA_mixed, mixed_cDs))
199
+
200
+ # Permute back to original shape: (batch_size, seq_len, num_features)
201
+ reconstructed = reconstructed.permute(0, 2, 1)
202
+
203
+ return reconstructed
204
+
205
+ # StAug: frequency-domain augmentation
206
+ def emd_aug(self, x: torch.Tensor) -> torch.Tensor:
207
+ """
208
+ Apply augmentation on empirical mode decomposition (EMD).
209
+
210
+ Args:
211
+ - x (torch.Tensor): Input tensor.
212
+
213
+ Returns:
214
+ - torch.Tensor: Augmented tensor.
215
+ """
216
+
217
+ b,n_imf,t,c = x.size()
218
+ inp = x.permute(0,2,1,3).reshape(b,t,n_imf*c) #b,t,n_imf,c -> b,t,n_imf*c
219
+ if torch.rand(1) >= 0.5:
220
+ w = 2 * torch.rand((b,1,n_imf*c), device=x.device)  # random IMF weights in [0, 2)
221
+ else:
222
+ w = torch.ones((b,1,n_imf*c), device=x.device)  # identity weights (no re-weighting)
223
+ w_exp = w.expand(-1,t,-1) #b,t,n_imf*c
224
+ out = w_exp * inp
225
+ out = out.reshape(b,t,n_imf,c).sum(dim=2) #b,t,c
226
+
227
+ return out
228
+
229
+ # StAug: time-domain augmentation
230
+ def mix_aug(self, batch_x: np.ndarray, batch_y: np.ndarray, lambd: float = 0.5) -> Tuple[np.ndarray, np.ndarray]:
231
+ """
232
+ Mix two batches of data with a random interpolation factor.
233
+
234
+ Args:
235
+ - batch_x (numpy.ndarray): Input batch 1.
236
+ - batch_y (numpy.ndarray): Input batch 2.
237
+ - lambd (float): Beta distribution parameter for interpolation.
238
+
239
+ Returns:
240
+ - Tuple[numpy.ndarray, numpy.ndarray]: The two mixed batches.
241
+ """
242
+
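+ # Mixup-style augmentation: convex combination of the batch with a shuffled copy of itself, weighted by lam ~ Beta(lambd, lambd)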
243
+ inds2 = np.random.permutation(len(batch_x))
244
+ lam = np.random.beta(lambd, lambd)
245
+ batch_x = lam * batch_x[inds2] + (1-lam) * batch_x
246
+ batch_y = lam * batch_y[inds2] + (1-lam) * batch_y
247
+
248
+ return batch_x, batch_y
dataset_loader/datasetloader.py ADDED
@@ -0,0 +1,431 @@
1
+ # =============================================================================
2
+ # The code originates from
3
+ # Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
4
+ # arXiv preprint arXiv:2302.09292.
5
+ # =============================================================================
6
+
7
+ from torch.utils.data import Dataset, DataLoader
8
+ import pandas as pd
9
+ import numpy as np
10
+ import os
11
+ from sklearn.preprocessing import StandardScaler
12
+ from decompositions.decomposition import emd_augment
13
+
14
+ class Dataset_ETT_hour(Dataset):
15
+ def __init__(self, root_path, flag='train', size=None,
16
+ features='S', data_path='ETTh1.csv',
17
+ target='OT', scale=True, freq='h', n_imf = 500, percentage = 100, params = None):
18
+ # size [seq_len, label_len, pred_len]
19
+ # info
20
+ if size is None:
21
+ self.seq_len = 24 * 4 * 4
22
+ self.label_len = 24 * 4
23
+ self.pred_len = 24 * 4
24
+ else:
25
+ self.seq_len = size[0]
26
+ self.label_len = size[1]
27
+ self.pred_len = size[2]
28
+ # init
29
+ assert flag in ['train', 'test', 'val']
30
+ type_map = {'train': 0, 'val': 1, 'test': 2}
31
+ self.set_type = type_map[flag]
32
+
33
+ self.features = features
34
+ self.target = target
35
+ self.scale = scale
36
+ self.freq = freq
37
+ self.n_imf = n_imf
38
+ self.percentage = percentage
39
+
40
+ self.root_path = root_path
41
+ self.data_path = data_path
42
+ self.params = params
43
+ self.__read_data__()
44
+
45
+ def __read_data__(self):
46
+ self.scaler = StandardScaler()
47
+ df_raw = pd.read_csv(os.path.join(self.root_path,
48
+ self.data_path))
49
+
50
+
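+ # Chronological ETT-hour split: 12 months train, 4 months val, 4 months test (30-day months, 24 steps per day)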
51
+ border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len]
52
+ border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24]
53
+ border1 = border1s[self.set_type]
54
+ border2 = border2s[self.set_type]
55
+
56
+ if self.features == 'M' or self.features == 'MS':
57
+ cols_data = df_raw.columns[1:]
58
+ df_data = df_raw[cols_data]
59
+ elif self.features == 'S':
60
+ df_data = df_raw[[self.target]]
61
+
62
+ if self.scale:
63
+ train_data = df_data[border1s[0]:border2s[0]]
64
+ train_length = int((self.percentage / 100) * len(train_data))
65
+ train_data = train_data[-train_length:]
66
+ self.scaler.fit(train_data.values)
67
+ data = self.scaler.transform(df_data.values)
68
+ else:
69
+ data = df_data.values
70
+
71
+ if self.set_type == 0 and self.params.aug_type == 5:
72
+ self.aug_data = emd_augment(data[border1:border2][-len(train_data):], self.seq_len+self.pred_len, n_IMF = self.n_imf)
73
+ else:
74
+ self.aug_data = np.zeros_like(data[border1:border2])
75
+
76
+ if self.set_type == 0:
77
+ self.data_x = data[border1:border2][-len(train_data):]
78
+ self.data_y = data[border1:border2][-len(train_data):]
79
+ else:
80
+ self.data_x = data[border1:border2]
81
+ self.data_y = data[border1:border2]
82
+
83
+
84
+ def __getitem__(self, index):
85
+ s_begin = index
86
+ s_end = s_begin + self.seq_len
87
+ r_begin = s_end - self.label_len
88
+ r_end = r_begin + self.label_len + self.pred_len
89
+
90
+ seq_x = self.data_x[s_begin:s_end]
91
+ seq_y = self.data_y[r_begin:r_end]
92
+
93
+ if self.params.aug_type == 5:
94
+ aug_data = self.aug_data[s_begin]
95
+ else:
96
+ aug_data = np.array([])
97
+
98
+ return seq_x, seq_y, aug_data
99
+
100
+ def __len__(self):
101
+ return len(self.data_x) - self.seq_len - self.pred_len + 1
102
+
103
+ def inverse_transform(self, data):
104
+ return self.scaler.inverse_transform(data)
105
+
106
+ class Dataset_ETT_minute(Dataset):
107
+ def __init__(self, root_path, flag='train', size=None,
108
+ features='S', data_path='ETTm1.csv',
109
+ target='OT', scale=True, freq='t', n_imf = 500, percentage = 100, params = None):
110
+ # size [seq_len, label_len, pred_len]
111
+ # info
112
+ if size is None:
113
+ self.seq_len = 24 * 4 * 4
114
+ self.label_len = 24 * 4
115
+ self.pred_len = 24 * 4
116
+ else:
117
+ self.seq_len = size[0]
118
+ self.label_len = size[1]
119
+ self.pred_len = size[2]
120
+ # init
121
+ assert flag in ['train', 'test', 'val']
122
+ type_map = {'train': 0, 'val': 1, 'test': 2}
123
+ self.set_type = type_map[flag]
124
+
125
+ self.features = features
126
+ self.target = target
127
+ self.scale = scale
128
+ self.freq = freq
129
+ self.n_imf = n_imf
130
+
131
+ self.root_path = root_path
132
+ self.data_path = data_path
133
+ self.percentage = percentage
134
+ self.params = params
135
+ self.__read_data__()
136
+
137
+ def __read_data__(self):
138
+ self.scaler = StandardScaler()
139
+ df_raw = pd.read_csv(os.path.join(self.root_path,
140
+ self.data_path))
141
+
142
+ border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len]
143
+ border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4]
144
+ border1 = border1s[self.set_type]
145
+ border2 = border2s[self.set_type]
146
+
147
+ if self.features == 'M' or self.features == 'MS':
148
+ cols_data = df_raw.columns[1:]
149
+ df_data = df_raw[cols_data]
150
+ elif self.features == 'S':
151
+ df_data = df_raw[[self.target]]
152
+
153
+ if self.scale:
154
+ train_data = df_data[border1s[0]:border2s[0]]
155
+ train_length = int((self.percentage / 100) * len(train_data))
156
+ train_data = train_data[-train_length:]
157
+ self.scaler.fit(train_data.values)
158
+ data = self.scaler.transform(df_data.values)
159
+ else:
160
+ data = df_data.values
161
+
162
+ if self.set_type == 0 and self.params.aug_type == 5:
163
+ self.aug_data = emd_augment(data[border1:border2][-len(train_data):], self.seq_len+self.pred_len, n_IMF = self.n_imf)
164
+ else:
165
+ self.aug_data = np.zeros_like(data[border1:border2])
166
+
167
+ if self.set_type == 0:
168
+ self.data_x = data[border1:border2][-len(train_data):]
169
+ self.data_y = data[border1:border2][-len(train_data):]
170
+ else:
171
+ self.data_x = data[border1:border2]
172
+ self.data_y = data[border1:border2]
173
+
174
+ def __getitem__(self, index):
175
+ s_begin = index
176
+ s_end = s_begin + self.seq_len
177
+ r_begin = s_end - self.label_len
178
+ r_end = r_begin + self.label_len + self.pred_len
179
+
180
+ seq_x = self.data_x[s_begin:s_end]
181
+ seq_y = self.data_y[r_begin:r_end]
182
+
183
+ if self.params.aug_type == 5:
184
+ aug_data = self.aug_data[s_begin]
185
+ else:
186
+ aug_data = np.array([])
187
+
188
+ return seq_x, seq_y, aug_data
189
+
190
+ def __len__(self):
191
+ return len(self.data_x) - self.seq_len - self.pred_len + 1
192
+
193
+ def inverse_transform(self, data):
194
+ return self.scaler.inverse_transform(data)
195
+
196
+
197
+ class Dataset_Custom(Dataset):
198
+ def __init__(self, root_path, flag='train', size=None,
199
+ features='S', data_path='ETTh1.csv', scale = True,
200
+ target='OT', freq='h', n_imf = 500, percentage = 100, params=None):
201
+ # size [seq_len, label_len, pred_len]
202
+ # info
203
+
204
+ self.seq_len = size[0]
205
+ self.label_len = size[1]
206
+ self.pred_len = size[2]
207
+ # init
208
+ assert flag in ['train', 'test', 'val']
209
+ type_map = {'train': 0, 'val': 1, 'test': 2}
210
+ self.set_type = type_map[flag]
211
+
212
+ self.features = features
213
+ self.target = target
214
+ self.freq = freq
215
+ self.scale = scale
216
+ self.n_imf = n_imf
217
+ self.percentage = percentage
218
+ self.root_path = root_path
219
+ self.data_path = data_path
220
+ self.params = params
221
+
222
+ self.__read_data__()
223
+
224
+ def __read_data__(self):
225
+
226
+ self.scaler = StandardScaler()
227
+ df_raw = pd.read_csv(os.path.join(self.root_path,
228
+ self.data_path))
229
+ '''
230
+ df_raw.columns: ['date', ...(other features), target feature]
231
+ '''
232
+ cols = list(df_raw.columns)
233
+ cols.remove(self.target)
234
+ cols.remove('date')
235
+ df_raw = df_raw[['date'] + cols + [self.target]]
236
+
237
+ num_train = int(len(df_raw) * 0.7)
238
+ num_test = int(len(df_raw) * 0.2)
239
+ num_vali = len(df_raw) - num_train - num_test
240
+ border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
241
+ border2s = [num_train, num_train + num_vali, len(df_raw)]
242
+ border1 = border1s[self.set_type]
243
+ border2 = border2s[self.set_type]
244
+
245
+ if self.features == 'M' or self.features == 'MS':
246
+ cols_data = df_raw.columns[1:]
247
+ df_data = df_raw[cols_data]
248
+ elif self.features == 'S':
249
+ df_data = df_raw[[self.target]]
250
+
251
+ if self.scale:
252
+ train_data = df_data[border1s[0]:border2s[0]]
253
+ train_length = int((self.percentage / 100) * len(train_data))
254
+ train_data = train_data[-train_length:]
255
+ self.scaler.fit(train_data.values)
256
+ data = self.scaler.transform(df_data.values)
257
+ else:
258
+ data = df_data.values
259
+
260
+ if self.set_type == 0 and self.params.aug_type == 5:
261
+ self.aug_data = emd_augment(data[border1:border2][-len(train_data):], self.seq_len+self.pred_len, n_IMF = self.n_imf)
262
+
263
+ else:
264
+ self.aug_data = np.zeros_like(data[border1:border2])
265
+
266
+ if self.set_type == 0:
267
+ self.data_x = data[border1:border2][-len(train_data):]
268
+ self.data_y = data[border1:border2][-len(train_data):]
269
+ else:
270
+ self.data_x = data[border1:border2]
271
+ self.data_y = data[border1:border2]
272
+
273
+ def __getitem__(self, index):
274
+ s_begin = index # 0
275
+ s_end = s_begin + self.seq_len
276
+ r_begin = s_end - self.label_len
277
+ r_end = r_begin + self.label_len + self.pred_len
278
+
279
+ seq_x = self.data_x[s_begin:s_end]
280
+ seq_y = self.data_y[r_begin:r_end]
281
+
282
+ if self.params.aug_type == 5:
283
+ aug_data = self.aug_data[s_begin]
284
+ else:
285
+ aug_data = np.array([])
286
+
287
+ return seq_x, seq_y, aug_data
288
+
289
+ def __len__(self):
290
+ return len(self.data_x) - self.seq_len - self.pred_len + 1
291
+
292
+ def inverse_transform(self, data):
293
+ return self.scaler.inverse_transform(data)
294
+
295
+
296
+ class Dataset_Pred(Dataset):
297
+
298
+ def __init__(self, root_path, flag='pred', size=None,
299
+ features='S', data_path='ETTh1.csv',
300
+ target='OT', scale=True, inverse=False, freq='15min', cols=None):
301
+ # size [seq_len, label_len, pred_len]
302
+ # info
303
+
304
+ self.seq_len = size[0]
305
+ self.label_len = size[1]
306
+ self.pred_len = size[2]
307
+ # init
308
+ assert flag in ['pred']
309
+
310
+ self.features = features
311
+ self.target = target
312
+ self.scale = scale
313
+ self.inverse = inverse
314
+ self.freq = freq
315
+ self.cols = cols
316
+ self.root_path = root_path
317
+ self.data_path = data_path
318
+ self.__read_data__()
319
+
320
+ def __read_data__(self):
321
+ self.scaler = StandardScaler()
322
+ df_raw = pd.read_csv(os.path.join(self.root_path,
323
+ self.data_path))
324
+ '''
325
+ df_raw.columns: ['date', ...(other features), target feature]
326
+ '''
327
+ if self.cols:
328
+ cols = self.cols.copy()
329
+ cols.remove(self.target)
330
+ else:
331
+ cols = list(df_raw.columns)
332
+ cols.remove(self.target)
333
+ cols.remove('date')
334
+ df_raw = df_raw[['date'] + cols + [self.target]]
335
+ border1 = len(df_raw) - self.seq_len
336
+ border2 = len(df_raw)
337
+
338
+ if self.features == 'M' or self.features == 'MS':
339
+ cols_data = df_raw.columns[1:]
340
+ df_data = df_raw[cols_data]
341
+ elif self.features == 'S':
342
+ df_data = df_raw[[self.target]]
343
+
344
+ if self.scale:
345
+ self.scaler.fit(df_data.values)
346
+ data = self.scaler.transform(df_data.values)
347
+ else:
348
+ data = df_data.values
349
+
350
+
351
+ self.data_x = data[border1:border2]
352
+ if self.inverse:
353
+ self.data_y = df_data.values[border1:border2]
354
+ else:
355
+ self.data_y = data[border1:border2]
356
+
357
+ def __getitem__(self, index):
358
+ s_begin = index
359
+ s_end = s_begin + self.seq_len
360
+ r_begin = s_end - self.label_len
361
+ r_end = r_begin + self.label_len + self.pred_len
362
+
363
+ seq_x = self.data_x[s_begin:s_end]
364
+ if self.inverse:
365
+ seq_y = self.data_x[r_begin:r_begin + self.label_len]
366
+ else:
367
+ seq_y = self.data_y[r_begin:r_begin + self.label_len]
368
+
369
+ aug_data = np.array([])
370
+
371
+ return seq_x, seq_y, aug_data
372
+
373
+
374
+ def __len__(self):
375
+ return len(self.data_x) - self.seq_len + 1
376
+
377
+ def inverse_transform(self, data):
378
+ return self.scaler.inverse_transform(data)
379
+
380
+ data_dict = {
381
+ 'ETTh1': Dataset_ETT_hour,
382
+ 'ETTh2': Dataset_ETT_hour,
383
+ 'ETTm1': Dataset_ETT_minute,
384
+ 'ETTm2': Dataset_ETT_minute,
385
+ 'custom': Dataset_Custom,
386
+ }
387
+
388
+
389
+ def data_provider(args, flag):
390
+ Data = data_dict[args.data]
391
+
392
+ if flag == 'test':
393
+ shuffle_flag = False
394
+ drop_last = True
395
+ batch_size = args.batch_size
396
+ freq = args.freq
397
+ nIMF = args.nIMF
398
+
399
+ elif flag == 'pred':
400
+ shuffle_flag = False
401
+ drop_last = False
402
+ batch_size = 1
403
+ freq = args.freq
404
+ Data = Dataset_Pred
405
+ else:
406
+ shuffle_flag = True
407
+ drop_last = True
408
+ batch_size = args.batch_size
409
+ freq = args.freq
410
+ nIMF = args.nIMF
411
+
412
+
413
+ data_set = Data(
414
+ root_path=args.root_path,
415
+ data_path=args.data_path,
416
+ flag=flag,
417
+ size=[args.seq_len, args.label_len, args.pred_len],
418
+ features=args.features,
419
+ target=args.target,
420
+ freq=freq,
421
+ n_imf = nIMF,
422
+ percentage = args.percentage,
423
+ params = args
424
+ )
425
+ data_loader = DataLoader(
426
+ data_set,
427
+ batch_size=batch_size,
428
+ shuffle=shuffle_flag,
429
+ num_workers=args.num_workers,
430
+ drop_last=drop_last)
431
+ return data_set, data_loader
decompositions/decomposition.py ADDED
@@ -0,0 +1,28 @@
1
+ # =============================================================================
2
+ # The code originates from
3
+ # https://github.com/xiyuanzh/STAug/tree/main
4
+ # =============================================================================
5
+
6
+ import numpy as np
7
+ from PyEMD import EMD
8
+
9
+ def emd_augment(data, sequence_length, n_IMF = 500):
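+ """
+ Decompose each channel with EMD, stack the IMFs in reverse order, and slide a
+ window of length `sequence_length` to build an array of shape
+ (len(data) - sequence_length + 1, max_imf, sequence_length, channels).
+ """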
10
+ n_imf, channel_num = n_IMF, data.shape[1]
11
+ emd_data = np.zeros((n_imf,data.shape[0],channel_num))
12
+ max_imf = 0
13
+ for ci in range(channel_num):
14
+ s = data[:, ci]
15
+ IMF = EMD().emd(s)
16
+ r_s = np.zeros((n_imf, data.shape[0]))
17
+ if len(IMF) > max_imf:
18
+ max_imf = len(IMF)
19
+ for i in range(len(IMF)):
20
+ r_s[i] = IMF[len(IMF)-1-i]
21
+ if(len(IMF)==0): r_s[0] = s
22
+ emd_data[:,:,ci] = r_s
23
+ if max_imf < n_imf:
24
+ emd_data = emd_data[:max_imf,:,:]
25
+ train_data_new = np.zeros((len(data)-sequence_length+1,max_imf,sequence_length,channel_num))
26
+ for i in range(len(data)-sequence_length+1):
27
+ train_data_new[i] = emd_data[:,i:i+sequence_length,:]
28
+ return train_data_new
figures/main_result.png ADDED

Git LFS Details

  • SHA256: 5c4d5c83bb34b4cde64d77119fd07328941286ac32ada65c2b44870c5ac281ad
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
figures/main_result2.png ADDED

Git LFS Details

  • SHA256: 637c0c230cd2d487b6c9a2812639bf47b4bb2d24d5aaf0af0571a8b318350087
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
figures/overview.png ADDED

Git LFS Details

  • SHA256: 78e0c25763f8034bfbebb5c00f04ea518613b3ea704a98b792b72aec0d20a9ea
  • Pointer size: 131 Bytes
  • Size of remote file: 113 kB
main_run/train.py ADDED
@@ -0,0 +1,290 @@
1
+ # =============================================================================
2
+ # The code originates from
3
+ # Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
4
+ # arXiv preprint arXiv:2302.09292.
5
+ # =============================================================================
6
+
7
+ import os
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+
12
+ from dataset_loader.datasetloader import data_provider
13
+ from models import DLinear
14
+ from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
15
+ from utils.metrics import metric
16
+
18
+ import torch.nn as nn
19
+ from torch import optim
20
+ import time
21
+ import warnings
22
+ import matplotlib.pyplot as plt
23
+ from augmentation.aug import augmentation
24
+
25
+
26
+ class Exp_Basic(object):
27
+ def __init__(self, args):
28
+ self.args = args
29
+ self.device = self._acquire_device()
30
+ self.model = self._build_model().to(self.device)
31
+
32
+ def _build_model(self):
33
+ raise NotImplementedError
35
+
36
+ def _acquire_device(self):
37
+ if self.args.use_gpu:
38
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(
39
+ self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
40
+ device = torch.device('cuda:{}'.format(self.args.gpu))
41
+ print('Use GPU: cuda:{}'.format(self.args.gpu))
42
+ else:
43
+ device = torch.device('cpu')
44
+ print('Use CPU')
45
+ return device
46
+
47
+ def _get_data(self):
48
+ pass
49
+
50
+ def vali(self):
51
+ pass
52
+
53
+ def train(self):
54
+ pass
55
+
56
+ def test(self):
57
+ pass
58
+
59
+ warnings.filterwarnings('ignore')
60
+
61
+ TYPES = {0: 'None',
62
+ 1: 'Freq-Mask',
63
+ 2: 'Freq-Mix',
64
+ 3: 'Wave-Mask',
65
+ 4: 'Wave-Mix',
66
+ 5: 'StAug'}
67
+
68
+ class Exp_Main(Exp_Basic):
69
+
70
+ def __init__(self, args):
71
+ super(Exp_Main, self).__init__(args)
72
+
73
+ def _build_model(self):
74
+
75
+ model_dict = {
76
+ 'DLinear': DLinear
77
+ }
78
+ model = model_dict[self.args.model].Model(self.args).float()
79
+ if self.args.use_multi_gpu and self.args.use_gpu:
80
+ model = nn.DataParallel(model, device_ids=self.args.device_ids)
81
+ return model
82
+
83
+ def _get_data(self, flag):
84
+ data_set, data_loader = data_provider(self.args, flag)
85
+ return data_set, data_loader
86
+
87
+ def _select_optimizer(self):
88
+ model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
89
+ return model_optim
90
+
91
+ def _select_criterion(self):
92
+ criterion = nn.MSELoss()
93
+ return criterion
94
+
95
+ def vali(self, vali_data, vali_loader, criterion):
96
+
97
+ total_loss = []
98
+ self.model.eval()
99
+ with torch.no_grad():
100
+ for i, (batch_x, batch_y, _) in enumerate(vali_loader):
101
+ batch_x = batch_x.float().to(self.device)
102
+ batch_y = batch_y.float()
103
+
104
+ outputs = self.model(batch_x)
105
+
106
+ f_dim = -1 if self.args.features == 'MS' else 0
107
+ outputs = outputs[:, -self.args.pred_len:, f_dim:]
108
+ batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
109
+
110
+ pred = outputs.detach().cpu()
111
+ true = batch_y.detach().cpu()
112
+
113
+ loss = criterion(pred, true)
114
+
115
+ total_loss.append(loss)
116
+ total_loss = np.average(total_loss)
117
+ self.model.train()
118
+ return total_loss
119
+
120
+ def train(self, setting):
121
+ train_data, train_loader = self._get_data(flag='train')
122
+ vali_data, vali_loader = self._get_data(flag='val')
123
+ test_data, test_loader = self._get_data(flag='test')
124
+
125
+ path = os.path.join(self.args.checkpoints, setting)
126
+ if not os.path.exists(path):
127
+ os.makedirs(path)
128
+
129
+ train_steps = len(train_loader)
130
+ early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
131
+ time_now = time.time()
132
+
133
+ model_optim = self._select_optimizer()
134
+ criterion = self._select_criterion()
135
+ time_now = time.time()
136
+
137
+ for epoch in range(self.args.train_epochs):
138
+ iter_count = 0
139
+ train_loss = []
140
+
141
+ self.model.train()
142
+ epoch_time = time.time()
143
+
144
+
145
+ for i, (batch_x, batch_y, aug_data) in enumerate(train_loader):
146
+
147
+
148
+ iter_count += 1
149
+ model_optim.zero_grad()
150
+ if self.args.aug_type == 5:
151
+ aug_data = aug_data.float().to(self.device)
152
+ else:
153
+ aug_data = None
154
+
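+ # Dispatch on aug_type (see TYPES): 1/2 append a fully augmented copy of the batch (FrAug);
+ # 3/4 append only a sampling_rate fraction of the synthetic batch (WaveMask/WaveMix);
+ # 5 replaces the batch via EMD re-weighting plus mixing (StAug)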
155
+ if self.args.aug_type:
156
+ aug = augmentation()
157
+ if self.args.aug_type == 1:
158
+ xy = aug.freq_mask(batch_x, batch_y[:, -self.args.pred_len:, :], rate=self.args.aug_rate, dim=1)
159
+ batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
160
+ batch_x = torch.cat([batch_x,batch_x2],dim=0)
161
+ batch_y = torch.cat([batch_y,batch_y2],dim=0)
162
+ elif self.args.aug_type == 2:
163
+ xy = aug.freq_mix(batch_x, batch_y[:, -self.args.pred_len:, :], rate=self.args.aug_rate, dim=1)
164
+ batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
165
+ batch_x = torch.cat([batch_x,batch_x2],dim=0)
166
+ batch_y = torch.cat([batch_y,batch_y2],dim=0)
167
+ elif self.args.aug_type == 3:
168
+ xy = aug.wave_mask(batch_x, batch_y[:, -self.args.pred_len:, :] ,rates = self.args.rates, wavelet =self.args.wavelet, level = self.args.level, dim = 1)
169
+ batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
170
+ sampling_steps = int(batch_x2.shape[0] * self.args.sampling_rate)
171
+ indices = torch.randperm(batch_x2.shape[0])[:sampling_steps]
172
+ batch_x2 = batch_x2[indices,:,:]
173
+ batch_y2 = batch_y2[indices,:,:]
174
+ batch_x = torch.cat([batch_x,batch_x2],dim=0)
175
+ batch_y = torch.cat([batch_y,batch_y2],dim=0)
176
+ elif self.args.aug_type == 4:
177
+ batch_x = batch_x.float().to(self.device)
178
+ batch_y = batch_y.float().to(self.device)
179
+ xy = aug.wave_mix(batch_x, batch_y[:, -self.args.pred_len:, :] ,rates = self.args.rates, wavelet = self.args.wavelet, level = self.args.level, dim = 1)
180
+ batch_x2, batch_y2 = xy[:, :self.args.seq_len, :], xy[:, -self.args.label_len-self.args.pred_len:, :]
181
+ sampling_steps = int(batch_x2.shape[0] * self.args.sampling_rate)
182
+ indices = torch.randperm(batch_x2.shape[0])[:sampling_steps]
183
+ batch_x2 = batch_x2[indices,:,:]
184
+ batch_y2 = batch_y2[indices,:,:]
185
+ batch_x = torch.cat([batch_x,batch_x2],dim=0)
186
+ batch_y = torch.cat([batch_y,batch_y2],dim=0)
187
+ elif self.args.aug_type == 5:
188
+ batch_x = batch_x.float().to(self.device)
189
+ batch_y = batch_y.float().to(self.device)
190
+ weighted_xy = aug.emd_aug(aug_data)
191
+ weighted_x, weighted_y = weighted_xy[:,:self.args.seq_len,:], weighted_xy[:,-self.args.label_len-self.args.pred_len:,:]
192
+ batch_x, batch_y = aug.mix_aug(weighted_x, weighted_y, lambd = self.args.aug_rate)
193
+
194
+
195
+ batch_x = batch_x.float().to(self.device)
196
+ batch_y = batch_y.float().to(self.device)
197
+
198
+ outputs = self.model(batch_x)
199
+
200
+ f_dim = -1 if self.args.features == 'MS' else 0
201
+ outputs = outputs[:, -self.args.pred_len:, f_dim:]
202
+ batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
203
+ loss = criterion(outputs, batch_y)
204
+ train_loss.append(loss.item())
205
+
206
+ if (i + 1) % 100 == 0:
207
+ print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
208
+ speed = (time.time() - time_now) / iter_count
209
+ left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
210
+ print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
211
+ iter_count = 0
212
+ time_now = time.time()
213
+
214
+ loss.backward()
215
+ model_optim.step()
216
+
217
+ print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
218
+ train_loss = np.average(train_loss)
219
+ vali_loss = self.vali(vali_data, vali_loader, criterion)
220
+ test_loss = self.vali(test_data, test_loader, criterion)
221
+
222
+ print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
223
+ epoch + 1, train_steps, train_loss, vali_loss, test_loss))
224
+
225
+ early_stopping(vali_loss, self.model, path)
226
+ if early_stopping.early_stop:
227
+ print("Early stopping")
228
+ break
229
+
230
+ adjust_learning_rate(model_optim, epoch + 1, self.args)
231
+
232
+ best_model_path = path + '/' + 'checkpoint.pth'
233
+ self.model.load_state_dict(torch.load(best_model_path))
234
+ min_val_loss = early_stopping.get_val_loss_min()
235
+
236
+ return self.model, min_val_loss
237
+
238
+ def test(self, setting, test=1):
239
+ test_data, test_loader = self._get_data(flag='test')
240
+
241
+ if test:
242
+ print('loading model')
243
+ self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth')))
244
+
245
+ preds = []
246
+ trues = []
247
+ inputx = []
248
+
249
+ self.model.eval()
250
+ with torch.no_grad():
251
+ for i, (batch_x, batch_y, _) in enumerate(test_loader):
252
+ batch_x = batch_x.float().to(self.device)
253
+ batch_y = batch_y.float().to(self.device)
254
+
255
+ outputs = self.model(batch_x)
256
+
257
+ f_dim = -1 if self.args.features == 'MS' else 0
258
+ outputs = outputs[:, -self.args.pred_len:, f_dim:]
259
+ batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
260
+ outputs = outputs.detach().cpu().numpy()
261
+ batch_y = batch_y.detach().cpu().numpy()
262
+
263
+ pred = outputs
264
+ true = batch_y
265
+
266
+ preds.append(pred)
267
+ trues.append(true)
268
+ inputx.append(batch_x.detach().cpu().numpy())
269
+
270
+ if self.args.test_flop:
271
+ test_params_flop((batch_x.shape[1],batch_x.shape[2]))
272
+ exit()
273
+
274
+ preds = np.array(preds)
275
+ trues = np.array(trues)
276
+ inputx = np.array(inputx)
277
+
278
+ preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
279
+ trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
280
+ inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])
281
+
282
+ mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
283
+ print('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
284
+ f = open(self.args.des + self.args.data + ".txt", 'a')
285
+ f.write(" \n")
286
+ f.write('{} --- Pred {} -> mse:{}, mae:{}, rse:{}'.format(TYPES[self.args.aug_type], self.args.pred_len, mse, mae, rse))
287
+ f.write('\n')
288
+ f.close()
289
+
290
+ return mse, mae, rse
models/DLinear.py ADDED
@@ -0,0 +1,91 @@
1
+ # =============================================================================
2
+ # The code originates from
3
+ # Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
4
+ # arXiv preprint arXiv:2302.09292.
5
+ # =============================================================================
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ class moving_avg(nn.Module):
11
+ """
12
+ Moving average block to highlight the trend of time series
13
+ """
14
+ def __init__(self, kernel_size, stride):
15
+ super(moving_avg, self).__init__()
16
+ self.kernel_size = kernel_size
17
+ self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
18
+
19
+ def forward(self, x):
20
+ # padding on the both ends of time series
21
+ front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
22
+ end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
23
+ x = torch.cat([front, x, end], dim=1)
24
+ x = self.avg(x.permute(0, 2, 1))
25
+ x = x.permute(0, 2, 1)
26
+ return x
27
+
28
+
29
+ class series_decomp(nn.Module):
30
+ """
31
+ Series decomposition block
32
+ """
33
+ def __init__(self, kernel_size):
34
+ super(series_decomp, self).__init__()
35
+ self.moving_avg = moving_avg(kernel_size, stride=1)
36
+
37
+ def forward(self, x):
38
+ moving_mean = self.moving_avg(x)
39
+ res = x - moving_mean
40
+ return res, moving_mean
41
+
42
+ class Model(nn.Module):
43
+ """
44
+ Decomposition-Linear
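+ Splits the input into a trend component (moving average) and a seasonal residual,
+ then maps each from seq_len to pred_len with a linear layer (one per channel when
+ --individual is set); the two projections are summed to form the forecast.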
45
+ """
46
+ def __init__(self, args):
47
+ super(Model, self).__init__()
48
+ self.seq_len = args.seq_len
49
+ self.pred_len = args.pred_len
50
+
51
+ # Decomposition kernel size
52
+ kernel_size = 25
53
+ self.decomposition = series_decomp(kernel_size)
54
+ self.individual = args.individual
55
+ self.channels = args.enc_in
56
+
57
+ if self.individual:
58
+ self.Linear_Seasonal = nn.ModuleList()
59
+ self.Linear_Trend = nn.ModuleList()
60
+
61
+ for i in range(self.channels):
62
+ self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len))
63
+ self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len))
64
+
65
+ # Use these two lines if you want to visualize the weights
66
+ # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
67
+ # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
68
+ else:
69
+ self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len)
70
+ self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len)
71
+
72
+ # Use these two lines if you want to visualize the weights
73
+ # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
74
+ # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len]))
75
+
76
+ def forward(self, x):
77
+ # x: [Batch, Input length, Channel]
78
+ seasonal_init, trend_init = self.decomposition(x)
79
+ seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1)
80
+ if self.individual:
81
+ seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device)
82
+ trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device)
83
+ for i in range(self.channels):
84
+ seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:])
85
+ trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:])
86
+ else:
87
+ seasonal_output = self.Linear_Seasonal(seasonal_init)
88
+ trend_output = self.Linear_Trend(trend_init)
89
+
90
+ x = seasonal_output + trend_output
91
+ return x.permute(0,2,1) # to [Batch, Output length, Channel]
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ numpy
2
+ matplotlib
3
+ pandas
4
+ scikit-learn
5
+ torch==1.9.0
6
+ einops
7
+ pywavelets
8
+ EMD-signal
9
+ tensorflow==2.11.0
10
+ pytorch_wavelets
run_main.py ADDED
@@ -0,0 +1,150 @@
1
+ # =============================================================================
2
+ # The code originates from
3
+ # Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
4
+ # arXiv preprint arXiv:2302.09292.
5
+ # =============================================================================
6
+
7
+ import argparse
8
+ import os
9
+ import ast
10
+ import torch
11
+ from main_run.train import Exp_Main
12
+ import random
13
+ import numpy as np
14
+
15
+ fix_seed = 2024
16
+ random.seed(fix_seed)
17
+ torch.manual_seed(fix_seed)
18
+ np.random.seed(fix_seed)
19
+
20
+ TYPES = {0: 'None',
21
+ 1: 'Freq-Mask',
22
+ 2: 'Freq-Mix',
23
+ 3: 'Wave-Mask',
24
+ 4: 'Wave-Mix',
25
+ 5: 'StAug'}
26
+
27
+
28
+ parser = argparse.ArgumentParser(description='Augmentations for Time Series Forecasting')
29
+
30
+ # basic config
31
+ parser.add_argument('--model', type=str, required=True, default='DLinear',
32
+ help='model name, options: [DLinear]')
33
+
34
+ # data loader
35
+ parser.add_argument('--is_training', type=int, required=True, default=1, help='status')
36
+ parser.add_argument('--data', type=str, required=True, default='ETTh1', help='dataset type')
37
+ parser.add_argument('--root_path', type=str, default='./dataset/', help='root path of the data file')
38
+ parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
39
+ parser.add_argument('--features', type=str, default='M',
40
+ help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
41
+ parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
42
+ parser.add_argument('--freq', type=str, default='h',
43
+ help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
44
+ parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')
45
+ parser.add_argument('--percentage', type=int, default=100, help='percentage of train data as a downsampling ratio')
46
+ parser.add_argument('--patience', type=int, default=12, help='early stopping patience')
47
+
48
+ # forecasting task
49
+ parser.add_argument('--seq_len', type=int, default=336, help='input sequence length')
50
+ parser.add_argument('--label_len', type=int, default=0, help='start token length')
51
+ parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
52
+
53
+
54
+ # DLinear
55
+ parser.add_argument('--individual', action='store_true', default=False, help='DLinear: a linear layer for each variate(channel) individually')
56
+ parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') # DLinear with --individual, use this hyperparameter as the number of channels
57
+
58
+ # optimization
59
+ parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
60
+ parser.add_argument('--itr', type=int, default=2, help='number of experiment repetitions')
61
+ parser.add_argument('--train_epochs', type=int, default=30, help='train epochs')
62
+ parser.add_argument('--batch_size', type=int, default=64, help='batch size of train input data')
63
+ parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
64
+ parser.add_argument('--des', type=str, default='test', help='exp description')
65
+ parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
66
+ parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
67
+
68
+ # GPU
69
+ parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
70
+ parser.add_argument('--gpu', type=int, default=0, help='gpu')
71
+ parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
72
+ parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multiple gpus')
73
+ parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')
74
+
75
+ # Augmentation
76
+ parser.add_argument('--aug_type', type=int, default=0, help='0: No augmentation, 1: Frequency Masking 2: Frequency Mixing 3: Wave Masking 4: Wave Mixing 5: StAug ')
77
+ parser.add_argument('--aug_rate', type=float, default=0.5, help='rate for FreqMask, FreqMix, and STAug')
78
+ parser.add_argument('--wavelet', type=str, default='db2', help='wavelet form for DWT')
79
+ parser.add_argument('--level', type=int, default=2, help='level for DWT')
80
+ parser.add_argument('--rates', type=str, default="[0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]",
81
+ help='List of float rates as a string, e.g., "[0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]"')
82
+ parser.add_argument('--nIMF', type=int, default=500, help='number of IMFs for EMD (STAug)')
83
+ parser.add_argument('--sampling_rate', type=float, default=0.5, help='sampling rate for WaveMask and WaveMix')
84
+
85
+
86
+
87
+
88
+ args = parser.parse_args()
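+ # Parse the '--rates' string into a Python list; wave_mask/wave_mix use level + 1 entries (approximation first, then one per detail level)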
89
+ args.rates = ast.literal_eval(args.rates)
90
+
91
+ args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
92
+
93
+ if args.use_gpu and args.use_multi_gpu:
94
+ args.devices = args.devices.replace(' ', '')
95
+ device_ids = args.devices.split(',')
96
+ args.device_ids = [int(id_) for id_ in device_ids]
97
+ args.gpu = args.device_ids[0]
98
+
99
+ print('Args in experiment:')
100
+ print(args)
101
+
102
+ Exp = Exp_Main
103
+
104
+ if args.is_training:
105
+ mse_avg, mae_avg, rse_avg = np.zeros(args.itr), np.zeros(args.itr), np.zeros(args.itr)
106
+ for ii in range(args.itr):
107
+ # setting record of experiments
108
+ setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_{}_{}_{}_{}'.format(
109
+ args.model,
110
+ args.data,
111
+ args.features,
112
+ args.seq_len,
113
+ args.label_len,
114
+ args.pred_len,
115
+ args.des, args.aug_type, args.percentage, ii)
116
+
117
+ exp = Exp(args) # set experiments
118
+ print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
119
+ exp.train(setting)
120
+
121
+ print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
122
+ mse, mae, rse = exp.test(setting)
123
+ mse_avg[ii] = mse
124
+ mae_avg[ii] = mae
125
+ rse_avg[ii] = rse
126
+
127
+ f = open("result-" + args.des + args.data + ".txt", 'a')
128
+ f.write('\n')
129
+ f.write('\n')
130
+ f.write("-------START FROM HERE-----")
131
+ f.write(TYPES[args.aug_type] + " \n")
132
+ f.write('avg mse:{}, avg mae:{} avg rse:{} std mse:{}, std mae:{} std rse:{}'.format(mse_avg.mean(), mae_avg.mean(), rse_avg.mean(), mse_avg.std(), mae_avg.std(), rse_avg.std()))
133
+ f.write('\n')
134
+ f.write('\n')
135
+ f.close()
136
+ torch.cuda.empty_cache()
137
+ else:
138
+ ii = 0
139
+ setting = '{}_{}_ft{}_sl{}_ll{}_pl{}_{}_{}_{}'.format(args.model,
140
+ args.data,
141
+ args.features,
142
+ args.seq_len,
143
+ args.label_len,
144
+ args.pred_len,
145
+ args.des, args.aug_type, ii)
146
+
147
+ exp = Exp(args) # set experiments
148
+ print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
149
+ exp.test(setting, test=1)
150
+ torch.cuda.empty_cache()
scripts/etth1.sh ADDED
@@ -0,0 +1,335 @@
+
+ if [ ! -d "./logs" ]; then
+ mkdir ./logs
+ fi
+
+ if [ ! -d "./logs/onerun" ]; then
+ mkdir ./logs/onerun
+ fi
+
+ seq_len=336
+ percentage=100
+ model_name=DLinear
+
+ # aug 0: None
+ # aug 1: Frequency Masking
+ # aug 2: Frequency Mixing
+ # aug 3: Wave Masking
+ # aug 4: Wave Mixing
+ # aug 5: STAug
+
+ pred_lens=(96 192 336 720)
+
+
+ # For Aug 0: None
+
+ for pred_len in "${pred_lens[@]}"; do
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len $pred_len \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'Etth1_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
+ done
+
+
+ # For Aug 1: Freq-Masking
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.5 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.5 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.4 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMask.log
+
+
+
+ # For Aug 2: Freq-Mixing
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.6 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMix.log
+
+ # For Aug 3: Wave Masking
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.5, 0.3, 0.9, 0.9, 0.0, 0.5, 0.3]" --wavelet 'db2' --level 3 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.0, 1.0, 0.2, 1.0, 0.4, 0.9, 0.3]" --wavelet 'db3' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.1, 0.9, 0.4, 0.8, 0.4, 0.9, 0.3]" --wavelet 'db25' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.4, 0.9, 0.5, 1.0, 0.3, 0.9, 0.3]" --wavelet 'db1' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMask.log
+
+ # For Aug 4: Wave Mixing
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.0, 0.9, 0.7, 0.7, 0.5, 0.7, 0.6]" --wavelet 'db3' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[1.0, 0.4, 0.6, 0.6, 0.3, 0.8, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.0, 0.9, 0.2, 0.2, 0.2, 0.1, 0.7]" --wavelet 'db3' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.9, 0.5, 0.7, 0.1, 0.2, 0.8]" --wavelet 'db5' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMix.log
+
+
+ # For Aug 5: STAug
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 100 >logs/onerun/$model_name'_'Etth1_$seq_len'_'96'_'0.9'_'$percentage'_'StAug.log
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 900 >logs/onerun/$model_name'_'Etth1_$seq_len'_'192'_'0.9'_'$percentage'_'StAug.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.8 --nIMF 200 >logs/onerun/$model_name'_'Etth1_$seq_len'_'336'_'0.9'_'$percentage'_'StAug.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh1.csv \
+ --model $model_name \
+ --data ETTh1 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h1-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.7 --nIMF 1000 >logs/onerun/$model_name'_'Etth1_$seq_len'_'720'_'0.9'_'$percentage'_'StAug.log
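The WaveMask runs above are driven by `--rates`, `--wavelet`, `--level`, and `--sampling_rate`. As a rough illustration of what the first three flags plausibly control, here is a minimal WaveMask-style sketch built on PyWavelets. It is an assumption based on the flag names and the method description, not the repository's actual implementation; in particular, how the seven-entry `--rates` list maps onto coefficient bands or channels is not visible here, so this sketch applies one rate per band and reuses the last entry beyond the decomposition depth.

```python
import numpy as np
import pywt

def wave_mask(x, rates, wavelet='db2', level=3, rng=None):
    """Illustrative WaveMask: zero a random fraction of DWT coefficients per band."""
    rng = rng or np.random.default_rng()
    coeffs = pywt.wavedec(x, wavelet, level=level)  # [cA_L, cD_L, ..., cD_1]
    for i, c in enumerate(coeffs):
        rate = rates[i] if i < len(rates) else rates[-1]
        keep = rng.random(c.shape) >= rate          # drop a coefficient with prob `rate`
        coeffs[i] = c * keep
    return pywt.waverec(coeffs, wavelet)[:len(x)]

# Look-back window and horizon concatenated into one series (432 = 336 + 96 here)
x = np.sin(np.linspace(0, 20, 432))
x_aug = wave_mask(x, rates=[0.5, 0.3, 0.9, 0.9], wavelet='db2', level=3)
```

Higher entries in `rates` remove more coefficients in the corresponding band, injecting more variability at that time scale.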
scripts/etth2.sh ADDED
@@ -0,0 +1,341 @@
+
+ if [ ! -d "./logs" ]; then
+ mkdir ./logs
+ fi
+
+ if [ ! -d "./logs/onerun" ]; then
+ mkdir ./logs/onerun
+ fi
+
+ seq_len=336
+ percentage=100
+ model_name=DLinear
+
+ # aug 0: None
+ # aug 1: Frequency Masking
+ # aug 2: Frequency Mixing
+ # aug 3: Wave Masking
+ # aug 4: Wave Mixing
+ # aug 5: STAug
+
+ pred_lens=(96 192 336 720)
+
+
+ # For Aug 0: None
+
+ for pred_len in "${pred_lens[@]}"; do
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len $pred_len \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'Etth2_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
+ done
+
+
+ # For Aug 1: Freq-Masking
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.6 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.6 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMask.log
+
+
+
+ # For Aug 2: Freq-Mixing
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.9 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMix.log
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMix.log
+
+
+
+
+ # For Aug 3: Wave Masking
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.4, 0.9, 0.0, 0.8, 0.3, 0.7, 0.2]" --wavelet 'db26' --level 2 --sampling_rate 0.5 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.6, 0.7, 0.5, 0.7, 0.8, 0.3, 0.5]" --wavelet 'db26' --level 2 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 0.7, 0.9, 0.4, 0.4, 0.1, 1.0]" --wavelet 'db1' --level 3 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.8, 0.9, 0.4, 0.9, 0.4, 0.4, 0.5]" --wavelet 'db5' --level 4 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMask.log
+
+ # For Aug 4: Wave Mixing
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.9, 0.9, 0.1, 0.5, 0.9, 0.6, 0.2]" --wavelet 'db25' --level 2 --sampling_rate 0.2 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.9, 0.4, 0.1, 0.8, 0.9, 0.5, 0.4]" --wavelet 'db1' --level 3 --sampling_rate 0.5 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.9, 0.1, 0.2, 0.5, 1.0, 0.5, 0.5]" --wavelet 'db25' --level 3 --sampling_rate 0.8 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.5, 0.1, 0.2, 0.7, 0.4, 0.9, 0.5]" --wavelet 'db5' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMix.log
+
+
+
+ # For Aug 5: STAug
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.4 --nIMF 2000 >logs/onerun/$model_name'_'Etth2_$seq_len'_'96'_'0.9'_'$percentage'_'StAug.log
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 200 >logs/onerun/$model_name'_'Etth2_$seq_len'_'192'_'0.9'_'$percentage'_'StAug.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.6 --nIMF 100 >logs/onerun/$model_name'_'Etth2_$seq_len'_'336'_'0.9'_'$percentage'_'StAug.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path ETTh2.csv \
+ --model $model_name \
+ --data ETTh2 \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 7 \
+ --des '100p-h2-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 5 --aug_rate 0.4 --nIMF 700 >logs/onerun/$model_name'_'Etth2_$seq_len'_'720'_'0.9'_'$percentage'_'StAug.log
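Analogously to the masking sketch after `scripts/etth1.sh`, the WaveMix runs exchange DWT coefficients between two instances of the dataset. A hedged, PyWavelets-based sketch of that idea (illustrative only; the per-level exchange probability is my reading of `--rates`):

```python
import numpy as np
import pywt

def wave_mix(x1, x2, rates, wavelet='db1', level=3, rng=None):
    """Illustrative WaveMix: per band, take each coefficient from x2 with prob `rate`."""
    rng = rng or np.random.default_rng()
    c1 = pywt.wavedec(x1, wavelet, level=level)
    c2 = pywt.wavedec(x2, wavelet, level=level)
    mixed = []
    for i, (a, b) in enumerate(zip(c1, c2)):
        rate = rates[i] if i < len(rates) else rates[-1]
        take_b = rng.random(a.shape) < rate
        mixed.append(np.where(take_b, b, a))
    return pywt.waverec(mixed, wavelet)[:len(x1)]
```

Because coefficients come from two real series rather than being zeroed, the mixed sample stays on-distribution while still differing from both parents.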
scripts/ili.sh ADDED
@@ -0,0 +1,342 @@
+
+ if [ ! -d "./logs" ]; then
+ mkdir ./logs
+ fi
+
+ if [ ! -d "./logs/onerun" ]; then
+ mkdir ./logs/onerun
+ fi
+
+ seq_len=36
+ percentage=100
+ model_name=DLinear
+
+ # aug 0: None
+ # aug 1: Frequency Masking
+ # aug 2: Frequency Mixing
+ # aug 3: Wave Masking
+ # aug 4: Wave Mixing
+ # aug 5: STAug
+
+ pred_lens=(24 36 48 60)
+
+
+
+ # For Aug 0: None
+
+ for pred_len in "${pred_lens[@]}"; do
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len $pred_len \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'ill_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
+ done
+
+
+ # For Aug 1: Freq-Masking
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 24 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 36 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 48 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 60 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.4'_'$percentage'_'FreqMask.log
+
+
+
+ # For Aug 2: Freq-Mixing
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 24 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 36 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.4'_'$percentage'_'FreqMix.log
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 48 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.4'_'$percentage'_'FreqMix.log
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 60 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.4'_'$percentage'_'FreqMix.log
+
+
+
+ # For Aug 3: Wave Masking
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 24 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.4, 0.8, 0.9, 0.7, 0.9, 0.0, 0.5]" --wavelet 'db25' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 36 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.6, 0.8, 0.3, 0.1, 0.9, 0.0, 0.5]" --wavelet 'db25' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 48 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 0.7, 1.0, 0.4, 0.4, 0.0, 0.5]" --wavelet 'db2' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 60 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 0.8, 0.5, 0.1, 0.9, 0.0, 0.5]" --wavelet 'db25' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.0'_'$percentage'_'WaveMask.log
+
+ # For Aug 4: Wave Mixing
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 24 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.8, 1.0, 0.0, 0.5, 0.7, 0.1]" --wavelet 'db1' --level 1 --sampling_rate 0.2 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 36 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 1.0, 0.9, 0.2, 0.1, 0.7, 0.1]" --wavelet 'db25' --level 1 --sampling_rate 0.8 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 48 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 1.0, 0.4, 0.5, 0.1, 0.6, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 60 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.9, 0.3, 0.9, 0.5, 0.7, 0.1]" --wavelet 'db1' --level 1 --sampling_rate 0.5 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.0'_'$percentage'_'WaveMix.log
+
+
+ # For Aug 5: STAug
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 24 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.7 --nIMF 200 >logs/onerun/$model_name'_'ill_$seq_len'_'24'_'0.9'_'$percentage'_'StAug.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 36 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.3 --nIMF 300 >logs/onerun/$model_name'_'ill_$seq_len'_'36'_'0.9'_'$percentage'_'StAug.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 48 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.9 --nIMF 300 >logs/onerun/$model_name'_'ill_$seq_len'_'48'_'0.9'_'$percentage'_'StAug.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path national_illness.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 60 \
+ --enc_in 7 \
+ --des '100p-ili-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 32 --learning_rate 0.01 --aug_type 5 --aug_rate 0.7 --nIMF 1000 >logs/onerun/$model_name'_'ill_$seq_len'_'60'_'0.9'_'$percentage'_'StAug.log
+
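All four scripts also pass `--sampling_rate`, which sets how much of the generated synthetic data is drawn per batch. A hypothetical sketch of that sampling step (the function and array names are illustrative, not the repo's API):

```python
import numpy as np

def mix_batches(original, augmented, sampling_rate, rng=None):
    """Draw a fraction of the augmented windows and append them to the originals."""
    rng = rng or np.random.default_rng()
    n = int(len(augmented) * sampling_rate)
    idx = rng.choice(len(augmented), size=n, replace=False)
    return np.concatenate([original, augmented[idx]], axis=0)

# e.g. sampling_rate=0.5 appends half as many synthetic windows as real ones
batch = mix_batches(np.zeros((32, 96, 7)), np.ones((32, 96, 7)), sampling_rate=0.5)
```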
scripts/weather.sh ADDED
@@ -0,0 +1,278 @@
+
+ if [ ! -d "./logs" ]; then
+ mkdir ./logs
+ fi
+
+ if [ ! -d "./logs/onerun" ]; then
+ mkdir ./logs/onerun
+ fi
+
+ seq_len=336
+ percentage=100
+ model_name=DLinear
+
+ # aug 0: None
+ # aug 1: Frequency Masking
+ # aug 2: Frequency Mixing
+ # aug 3: Wave Masking
+ # aug 4: Wave Mixing
+ # aug 5: STAug
+
+ pred_lens=(96 192 336 720)
+
+ # For Aug 0: None
+
+ for pred_len in "${pred_lens[@]}"; do
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len $pred_len \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 0 --aug_rate 0.0 >logs/onerun/$model_name'_'weather_$seq_len'_'$pred_len'_'0.0'_'$percentage'_'None.log
+ done
+
+
+ # For Aug 1: Freq-Masking
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 1 --aug_rate 0.9 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMask.log
+
+
+
+ # For Aug 2: Freq-Mixing
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.9 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.4'_'$percentage'_'FreqMix.log
+
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.1 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.4'_'$percentage'_'FreqMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 2 --aug_rate 0.9 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.4'_'$percentage'_'FreqMix.log
+
+ # For Aug 3: Wave Masking
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.2, 1.0, 0.4, 0.4, 0.9, 0.0, 0.5]" --wavelet 'db2' --level 2 --sampling_rate 0.5 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[0.1, 0.7, 0.1, 0.4, 0.5, 0.1, 0.7]" --wavelet 'db2' --level 1 --sampling_rate 0.5 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[1.0, 1.0, 0.0, 0.0, 0.3, 0.4, 0.2]" --wavelet 'db1' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMask.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 3 --rates "[1.0, 0.8, 0.6, 0.0, 0.2, 0.6, 0.5]" --wavelet 'db2' --level 1 --sampling_rate 0.5 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMask.log
+
+ # For Aug 4: Wave Mixing
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 96 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.5, 0.1, 0.2, 0.4, 0.7, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'96'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 192 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.2, 0.7, 1.0, 0.3, 0.2, 0.3, 0.1]" --wavelet 'db3' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'192'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 336 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.8, 0.6, 0.8, 0.6, 0.1, 1.0, 0.3]" --wavelet 'db2' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'336'_'0.0'_'$percentage'_'WaveMix.log
+
+ python3 -u ./run_main.py \
+ --is_training 1 \
+ --root_path ./dataset/ \
+ --data_path weather.csv \
+ --model $model_name \
+ --data custom \
+ --features M \
+ --seq_len $seq_len \
+ --pred_len 720 \
+ --enc_in 21 \
+ --des '100p-whr-' \
+ --percentage $percentage \
+ --itr 10 --batch_size 64 --learning_rate 0.01 --aug_type 4 --rates "[0.1, 0.1, 0.7, 0.5, 0.6, 0.5, 0.1]" --wavelet 'db1' --level 1 --sampling_rate 1.0 >logs/onerun/$model_name'_'weather_$seq_len'_'720'_'0.0'_'$percentage'_'WaveMix.log
+
utils/metrics.py ADDED
@@ -0,0 +1,49 @@
+ # =============================================================================
+ # The code originates from
+ # Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
+ # arXiv preprint arXiv:2302.09292.
+ # =============================================================================
+
+ import numpy as np
+
+ def RSE(pred, true):
+ return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
+
+
+ def CORR(pred, true):
+ u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
+ d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
+ d += 1e-12
+ return 0.01*(u / d).mean(-1)
+
+
+ def MAE(pred, true):
+ return np.mean(np.abs(pred - true))
+
+
+ def MSE(pred, true):
+ return np.mean((pred - true) ** 2)
+
+
+ def RMSE(pred, true):
+ return np.sqrt(MSE(pred, true))
+
+
+ def MAPE(pred, true):
+ return np.mean(np.abs((pred - true) / true))
+
+
+ def MSPE(pred, true):
+ return np.mean(np.square((pred - true) / true))
+
+
+ def metric(pred, true):
+ mae = MAE(pred, true)
+ mse = MSE(pred, true)
+ rmse = RMSE(pred, true)
+ mape = MAPE(pred, true)
+ mspe = MSPE(pred, true)
+ rse = RSE(pred, true)
+ corr = CORR(pred, true)
+
+ return mae, mse, rmse, mape, mspe, rse, corr
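A quick smoke test for `utils/metrics.py` on toy arrays, assuming the repository root is on `PYTHONPATH` so the module imports as `utils.metrics` (the shapes mirror a typical `(batch, pred_len, channels)` output):

```python
import numpy as np
from utils.metrics import metric

rng = np.random.default_rng(0)
pred = rng.standard_normal((32, 96, 7))  # hypothetical model output
true = rng.standard_normal((32, 96, 7))  # hypothetical ground truth

mae, mse, rmse, mape, mspe, rse, corr = metric(pred, true)
print(f'mse={mse:.4f} mae={mae:.4f} rse={rse:.4f}')
```

Note that MAPE and MSPE divide by `true`, so they blow up when the targets pass near zero; MSE, MAE, and RSE (the metrics the scripts log) are safe in that respect.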
utils/tools.py ADDED
@@ -0,0 +1,102 @@
+ # =============================================================================
+ # The code originates from
+ # Chen, M., Xu, Z., Zeng, A., & Xu, Q. (2023). "FrAug: Frequency Domain Augmentation for Time Series Forecasting".
+ # arXiv preprint arXiv:2302.09292.
+ # =============================================================================
+
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ plt.switch_backend('agg')
+
+ def adjust_learning_rate(optimizer, epoch, args):
+ # lr = args.learning_rate * (0.2 ** (epoch // 2))
+ if args.lradj == 'type1':
+ lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
+ elif args.lradj == 'type2':
+ lr_adjust = {
+ 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
+ 10: 5e-7, 15: 1e-7, 20: 5e-8
+ }
+ elif args.lradj == '3':
+ lr_adjust = {epoch: args.learning_rate if epoch < 10 else args.learning_rate*0.1}
+ elif args.lradj == '4':
+ lr_adjust = {epoch: args.learning_rate if epoch < 15 else args.learning_rate*0.1}
+ elif args.lradj == '5':
+ lr_adjust = {epoch: args.learning_rate if epoch < 25 else args.learning_rate*0.1}
+ elif args.lradj == '6':
+ lr_adjust = {epoch: args.learning_rate if epoch < 5 else args.learning_rate*0.1}
+ else:
+ lr_adjust = {}  # unknown lradj: leave the learning rate untouched
+ if epoch in lr_adjust.keys():
+ lr = lr_adjust[epoch]
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = lr
+ print('Updating learning rate to {}'.format(lr))
+
+
+ class EarlyStopping:
+ def __init__(self, patience=7, verbose=False, delta=0):
+ self.patience = patience
+ self.verbose = verbose
+ self.counter = 0
+ self.best_score = None
+ self.early_stop = False
+ self.val_loss_min = np.inf
+ self.delta = delta
+
+ def __call__(self, val_loss, model, path):
+ score = -val_loss
+ if self.best_score is None:
+ self.best_score = score
+ self.save_checkpoint(val_loss, model, path)
+ elif score < self.best_score + self.delta:
+ self.counter += 1
+ print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+ if self.counter >= self.patience:
+ self.early_stop = True
+ else:
+ self.best_score = score
+ self.save_checkpoint(val_loss, model, path)
+ self.counter = 0
+
+ def save_checkpoint(self, val_loss, model, path):
+ if self.verbose:
+ print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
+ torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
+ self.val_loss_min = val_loss
+
+ def get_val_loss_min(self):
+ return self.val_loss_min
+
+ def visual(true, preds=None, name='./pic/test.pdf'):
+ """
+ Results visualization
+ """
+ plt.figure()
+ plt.plot(true, label='GroundTruth', linewidth=2)
+ if preds is not None:
+ plt.plot(preds, label='Prediction', linewidth=2)
+ plt.legend()
+ plt.savefig(name, bbox_inches='tight')
+
+ def test_params_flop(model, x_shape):
+ """
+ If you want to test a Transformer-style model's FLOPs, you need to give a default value to the inputs in model.forward(); the following code can only pass one argument to forward()
+ """
+ model_params = 0
+ for parameter in model.parameters():
+ model_params += parameter.numel()
+ print('INFO: Trainable parameter count: {:.2f}M'.format(model_params / 1000000.0))
+ from ptflops import get_model_complexity_info
+ with torch.cuda.device(0):
+ macs, params = get_model_complexity_info(model.cuda(), x_shape, as_strings=True, print_per_layer_stat=True)
+ # print('Flops:' + flops)
+ # print('Params:' + params)
+ print('{:<30} {:<8}'.format('Computational complexity: ', macs))
+ print('{:<30} {:<8}'.format('Number of parameters: ', params))
+
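A minimal usage sketch for `EarlyStopping` above (the stand-in model and the validation losses are placeholders; the checkpoint lands in `path + '/checkpoint.pth'` exactly as the class writes it):

```python
import os
import torch
from utils.tools import EarlyStopping  # assumes the repo root is on PYTHONPATH

model = torch.nn.Linear(336, 96)                 # stand-in for the forecasting model
path = './checkpoints/demo'
os.makedirs(path, exist_ok=True)
early_stopping = EarlyStopping(patience=3, verbose=True)

# Hypothetical per-epoch validation losses: improves for three epochs, then stalls
for epoch, val_loss in enumerate([0.52, 0.47, 0.46, 0.48, 0.49, 0.50]):
    early_stopping(val_loss, model, path)
    if early_stopping.early_stop:
        print('Early stopping at epoch', epoch)
        break
```

The counter resets on every improvement and training halts once `patience` consecutive epochs fail to beat the best score, so the saved checkpoint always corresponds to the lowest validation loss seen.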