|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from layers.Autoformer_EncDec import series_decomp |
|
|
from layers.Embed import DataEmbedding_wo_pos |
|
|
from layers.StandardNorm import Normalize |
|
|
|
|
|
|
|
|
class DFT_series_decomp(nn.Module):
    """
    Series decomposition block based on the discrete Fourier transform.

    For every row (all leading dimensions), the ``top_k`` largest-magnitude
    non-DC frequency bins of the last dimension are kept as the seasonal
    part.  Everything else, including the DC / mean bin, stays in the trend
    part, so ``x_season + x_trend == x``.
    """

    def __init__(self, top_k=5):
        super(DFT_series_decomp, self).__init__()
        self.top_k = top_k

    def forward(self, x):
        """Split ``x`` into ``(x_season, x_trend)`` along its last dimension.

        NOTE(review): the transform runs over the last axis.  Inside
        PastDecomposableMixing the inputs appear to be (batch, time, d_model),
        which would decompose across the feature axis rather than across time
        -- confirm this is intended.

        Args:
            x: real-valued tensor; its last dimension is transformed.

        Returns:
            Tuple ``(x_season, x_trend)``, both shaped like ``x``.
        """
        length = x.size(-1)
        xf = torch.fft.rfft(x, dim=-1)
        # Magnitudes only decide which bins survive, so they need no gradient.
        freq = xf.abs().detach()
        # Exclude the DC bin of every row (not the first batch element).
        freq[..., 0] = 0

        # Only non-DC bins can be selected; this also keeps topk's k in range.
        k = min(self.top_k, freq.size(-1) - 1)
        if k <= 0:
            x_season = torch.zeros_like(x)
        else:
            # Per-row threshold: the k-th largest magnitude of each row.
            kth_largest = torch.topk(freq, k, dim=-1).values[..., -1:]
            drop = freq < kth_largest
            drop[..., 0] = True  # the DC component always belongs to the trend
            xf = xf.masked_fill(drop, 0)
            # Passing n makes odd-length inputs round-trip to their original size.
            x_season = torch.fft.irfft(xf, n=length, dim=-1)

        x_trend = x - x_season
        return x_season, x_trend
|
|
|
|
|
|
|
|
class MultiScaleSeasonMixing(nn.Module):
    """
    Bottom-up mixing of seasonal patterns.

    Starting at the finest scale, each already-mixed scale is projected down
    to the next coarser resolution and added to that scale's own seasonal
    component.
    """

    def __init__(self, configs):
        super(MultiScaleSeasonMixing, self).__init__()

        # Layer i maps length seq_len // w**i to length seq_len // w**(i + 1).
        self.down_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),
                )
                for i in range(configs.down_sampling_layers)
            ]
        )

    def forward(self, season_list):
        """Mix seasonal components from fine to coarse.

        Args:
            season_list: non-empty list of 3-D tensors, ordered finest to
                coarsest.  The last dimension is the one the Linear layers act
                on (the time axis after the caller's ``permute(0, 2, 1)``).

        Returns:
            List of mixed tensors in the same order, each passed through
            ``permute(0, 2, 1)``.  A single-scale list is returned (permuted)
            unchanged instead of raising IndexError.
        """
        out_high = season_list[0]
        out_season_list = [out_high.permute(0, 2, 1)]
        for i, out_low in enumerate(season_list[1:]):
            # Residual: coarser scale + projection of the mixed finer scale.
            out_high = out_low + self.down_sampling_layers[i](out_high)
            out_season_list.append(out_high.permute(0, 2, 1))
        return out_season_list
|
|
|
|
|
|
|
|
class MultiScaleTrendMixing(nn.Module):
    """
    Top-down mixing of trend patterns.

    Starting at the coarsest scale, each already-mixed scale is projected up
    to the next finer resolution and added to that scale's own trend
    component.
    """

    def __init__(self, configs):
        super(MultiScaleTrendMixing, self).__init__()

        # Layer j (built for i = L-1, ..., 0) maps length seq_len // w**(i + 1)
        # up to seq_len // w**i, so index 0 handles the coarsest step.
        self.up_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    ),
                )
                for i in reversed(range(configs.down_sampling_layers))
            ]
        )

    def forward(self, trend_list):
        """Mix trend components from coarse to fine.

        Args:
            trend_list: non-empty list of 3-D tensors, ordered finest to
                coarsest (the same order as the season list).  The list itself
                is not modified.

        Returns:
            List of mixed tensors ordered finest to coarsest, each passed
            through ``permute(0, 2, 1)``.  A single-scale list is returned
            (permuted) unchanged instead of raising IndexError.
        """
        coarse_first = list(reversed(trend_list))
        out_low = coarse_first[0]
        out_trend_list = [out_low.permute(0, 2, 1)]
        for i, out_high in enumerate(coarse_first[1:]):
            # Residual: finer scale + up-projection of the mixed coarser scale.
            out_low = out_high + self.up_sampling_layers[i](out_low)
            out_trend_list.append(out_low.permute(0, 2, 1))
        out_trend_list.reverse()
        return out_trend_list
|
|
|
|
|
|
|
|
class PastDecomposableMixing(nn.Module):
    """
    Decompose every scale into season and trend, mix the seasons bottom-up
    and the trends top-down, then recombine them per scale.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window

        self.layer_norm = nn.LayerNorm(configs.d_model)
        self.dropout = nn.Dropout(configs.dropout)
        self.channel_independence = configs.channel_independence

        method = configs.decomp_method
        if method == 'moving_avg':
            decomposer = series_decomp(configs.moving_avg)
        elif method == "dft_decomp":
            decomposer = DFT_series_decomp(configs.top_k)
        else:
            raise ValueError('decompsition is error')
        self.decompsition = decomposer

        def feed_forward():
            # d_model -> d_ff -> d_model position-wise MLP.
            return nn.Sequential(
                nn.Linear(in_features=configs.d_model, out_features=configs.d_ff),
                nn.GELU(),
                nn.Linear(in_features=configs.d_ff, out_features=configs.d_model),
            )

        # Channel-mixing MLP applied to season and trend before scale mixing.
        if not configs.channel_independence:
            self.cross_layer = feed_forward()

        self.mixing_multi_scale_season = MultiScaleSeasonMixing(configs)
        self.mixing_multi_scale_trend = MultiScaleTrendMixing(configs)

        self.out_cross_layer = feed_forward()

    def forward(self, x_list):
        """Mix a list of per-scale 3-D tensors; returns a list of the same length.

        Dimension 1 of every input is treated as its length, and each output
        is truncated back to that length.
        """
        lengths = [steps for _, steps, _ in (x.size() for x in x_list)]

        seasons, trends = [], []
        for x in x_list:
            season_part, trend_part = self.decompsition(x)
            if not self.channel_independence:
                season_part, trend_part = self.cross_layer(season_part), self.cross_layer(trend_part)
            # Swap dims 1 and 2 so the mixing Linear layers act on the length axis.
            seasons.append(season_part.permute(0, 2, 1))
            trends.append(trend_part.permute(0, 2, 1))

        mixed_seasons = self.mixing_multi_scale_season(seasons)
        mixed_trends = self.mixing_multi_scale_trend(trends)

        results = []
        for source, season_out, trend_out, steps in zip(x_list, mixed_seasons, mixed_trends, lengths):
            combined = season_out + trend_out
            if self.channel_independence:
                combined = source + self.out_cross_layer(combined)
            results.append(combined[:, :steps, :])
        return results
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """
    TimeMixer: multi-scale past-decomposable-mixing model.

    The input series is down-sampled into ``configs.down_sampling_layers + 1``
    resolutions.  Each resolution is embedded, ``configs.e_layers``
    PastDecomposableMixing blocks mix the scales, and a task-specific head
    (chosen by ``configs.task_name``) produces the output.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window
        self.channel_independence = configs.channel_independence
        self.pdm_blocks = nn.ModuleList([PastDecomposableMixing(configs)
                                         for _ in range(configs.e_layers)])

        # Season/trend split used by pre_enc in channel-mixing mode.
        self.preprocess = series_decomp(configs.moving_avg)
        self.enc_in = configs.enc_in

        # In channel-independent mode each variate is embedded separately, so
        # the embedding sees a single input channel.
        if self.channel_independence:
            self.enc_embedding = DataEmbedding_wo_pos(1, configs.d_model, configs.embed, configs.freq,
                                                      configs.dropout)
        else:
            self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                                      configs.dropout)

        self.layer = configs.e_layers

        # One normalizer per scale (the original resolution plus each down-sampled one).
        self.normalize_layers = torch.nn.ModuleList(
            [
                Normalize(self.configs.enc_in, affine=True, non_norm=configs.use_norm == 0)
                for i in range(configs.down_sampling_layers + 1)
            ]
        )

        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # Per-scale temporal projection: seq_len // w**i -> pred_len.
            self.predict_layers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.pred_len,
                    )
                    for i in range(configs.down_sampling_layers + 1)
                ]
            )

            if self.channel_independence:
                self.projection_layer = nn.Linear(
                    configs.d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(
                    configs.d_model, configs.c_out, bias=True)

                # Residual path over the trend part; only out_projection uses it,
                # and only in channel-mixing mode.
                self.out_res_layers = torch.nn.ModuleList([
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    )
                    for i in range(configs.down_sampling_layers + 1)
                ])

                self.regression_layers = torch.nn.ModuleList(
                    [
                        torch.nn.Linear(
                            configs.seq_len // (configs.down_sampling_window ** i),
                            configs.pred_len,
                        )
                        for i in range(configs.down_sampling_layers + 1)
                    ]
                )

        if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            if self.channel_independence:
                self.projection_layer = nn.Linear(
                    configs.d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(
                    configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

        # The 'conv' down-sampler is registered once here.  Previously a fresh,
        # randomly initialised, CPU-only Conv1d was built on every forward call,
        # so it was never trained, never saved, and broke on GPU inputs.
        if getattr(configs, 'down_sampling_method', None) == 'conv':
            padding = 1 if torch.__version__ >= '1.5.0' else 2
            self.down_sampling_conv = nn.Conv1d(in_channels=configs.enc_in, out_channels=configs.enc_in,
                                                kernel_size=3, padding=padding,
                                                stride=configs.down_sampling_window,
                                                padding_mode='circular',
                                                bias=False)

    def out_projection(self, dec_out, i, out_res):
        """Project scale ``i``'s features to outputs and add the residual path.

        ``out_res`` is permuted so its dimension 1 becomes the last axis, then
        mapped by ``out_res_layers[i]`` and ``regression_layers[i]`` (length
        seq_len // w**i -> pred_len) and permuted back.
        NOTE(review): adding it to the projected ``dec_out`` presumably
        requires enc_in == c_out -- confirm.
        """
        dec_out = self.projection_layer(dec_out)
        out_res = out_res.permute(0, 2, 1)
        out_res = self.out_res_layers[i](out_res)
        out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
        dec_out = dec_out + out_res
        return dec_out

    def pre_enc(self, x_list):
        """Return ``(inputs_to_embed, residual_parts)``.

        In channel-independent mode the inputs pass through unchanged and the
        residual part is None.  Otherwise every scale is split by
        ``self.preprocess``: the first output is embedded and the second feeds
        ``out_projection``.
        """
        if self.channel_independence:
            return (x_list, None)
        else:
            out1_list = []
            out2_list = []
            for x in x_list:
                x_1, x_2 = self.preprocess(x)
                out1_list.append(x_1)
                out2_list.append(x_2)
            return (out1_list, out2_list)

    def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
        """Build the multi-scale pyramid of inputs (and time marks).

        Args:
            x_enc: tensor whose dimension 1 is time; pooling/conv run after
                swapping dims 1 and 2.
            x_mark_enc: time-feature tensor or None; sub-sampled with stride
                ``down_sampling_window`` along dimension 1.

        Returns:
            ``(x_list, x_mark_list)``.  ``x_list`` is always a list, finest
            scale first, even when no down-sampling method is configured.
            ``x_mark_list`` is a list, or None when ``x_mark_enc`` is None.

        NOTE(review): pooling yields floor(T / w) steps while ``[::w]`` yields
        ceil(T / w), so marks and values disagree in length when T is not
        divisible by w.
        """
        method = self.configs.down_sampling_method
        window = self.configs.down_sampling_window
        if method == 'max':
            down_pool = torch.nn.MaxPool1d(window, return_indices=False)
        elif method == 'avg':
            down_pool = torch.nn.AvgPool1d(window)
        elif method == 'conv':
            down_pool = self.down_sampling_conv
        else:
            # No down-sampling: still return lists so callers iterate over scales,
            # not over the batch dimension of a bare tensor.
            return [x_enc], (None if x_mark_enc is None else [x_mark_enc])

        x_enc_sampling_list = [x_enc]
        x_mark_sampling_list = [x_mark_enc]
        x_current = x_enc.permute(0, 2, 1)  # pooling / conv expect channels in dim 1
        mark_current = x_mark_enc
        for _ in range(self.configs.down_sampling_layers):
            x_current = down_pool(x_current)
            x_enc_sampling_list.append(x_current.permute(0, 2, 1))
            if mark_current is not None:
                mark_current = mark_current[:, ::window, :]
                x_mark_sampling_list.append(mark_current)

        return x_enc_sampling_list, (x_mark_sampling_list if x_mark_enc is not None else None)

    def _prepare_scales(self, x_scales, x_mark_scales, normalize):
        """Optionally normalize each scale and fold variates into the batch.

        In channel-independent mode each 3-D input (B, T, N) becomes
        (B * N, T, 1) with row ``b * N + n`` holding variate ``n`` of sample
        ``b``.  Marks are expanded with ``repeat_interleave`` so they line up
        with that row order.  Returns ``(x_list, x_mark_list)``; the mark list
        is empty when ``x_mark_scales`` is None.
        """
        x_list, x_mark_list = [], []
        for i, x in enumerate(x_scales):
            B, T, N = x.size()
            if normalize:
                x = self.normalize_layers[i](x, 'norm')
            if self.channel_independence:
                x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
            x_list.append(x)
            if x_mark_scales is not None:
                x_mark = x_mark_scales[i]
                if self.channel_independence:
                    # repeat(N, 1, 1) would tile the batch and pair row b*N+n
                    # with sample (b*N+n) % B; interleave keeps it with sample b.
                    x_mark = x_mark.repeat_interleave(N, dim=0)
                x_mark_list.append(x_mark)
        return x_list, x_mark_list

    def _unfold_channels(self, dec_out, B):
        """Undo the channel-independent folding of a projected output.

        Channel-independent projections are reshaped to (B, c_out, -1) and
        permuted to put c_out last (this presumably requires the number of
        variates to equal c_out).  Channel-mixing projections already end in
        c_out and are returned unchanged; reshaping them would scramble values.
        """
        if self.channel_independence:
            return dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
        return dec_out

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast ``pred_len`` steps; ``x_dec`` / ``x_mark_dec`` are unused."""
        B = x_enc.size(0)
        x_scales, x_mark_scales = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
        x_list, x_mark_list = self._prepare_scales(x_scales, x_mark_scales, normalize=True)

        x_list = self.pre_enc(x_list)
        if x_mark_scales is not None:
            enc_out_list = [self.enc_embedding(x, x_mark) for x, x_mark in zip(x_list[0], x_mark_list)]
        else:
            enc_out_list = [self.enc_embedding(x, None) for x in x_list[0]]

        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)

        # Sum the per-scale forecasts, then undo the finest-scale normalization.
        dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
        dec_out = self.normalize_layers[0](dec_out, 'denorm')
        return dec_out

    def future_multi_mixing(self, B, enc_out_list, x_list):
        """Turn each scale's encoding into a ``pred_len`` forecast.

        ``x_list`` is the ``(inputs, residuals)`` pair returned by pre_enc.
        In channel-independent mode the (B * N)-row output is reshaped to
        (B, c_out, pred_len) and permuted, which presumably requires N == c_out.
        """
        dec_out_list = []
        if self.channel_independence:
            for i, enc_out in enumerate(enc_out_list):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)
                dec_out = self.projection_layer(dec_out)
                dec_out = dec_out.reshape(B, self.configs.c_out, self.pred_len).permute(0, 2, 1).contiguous()
                dec_out_list.append(dec_out)
        else:
            for i, enc_out, out_res in zip(range(len(x_list[0])), enc_out_list, x_list[1]):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)
                dec_out = self.out_projection(dec_out, i, out_res)
                dec_out_list.append(dec_out)
        return dec_out_list

    def classification(self, x_enc, x_mark_enc):
        """Classify sequences from the finest-scale encoding.

        NOTE(review): ``x_mark_enc`` is multiplied in as a per-step mask,
        presumably a (B, T) padding mask -- confirm.  This path neither
        normalizes nor folds channels, so channel-independent mode (whose
        embedding expects one input channel) looks unsupported here.
        """
        x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
        x_list = x_enc

        enc_out_list = [self.enc_embedding(x, None) for x in x_list]

        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        enc_out = enc_out_list[0]
        output = self.act(enc_out)
        output = self.dropout(output)
        # Zero out padded time steps.
        output = output * x_mark_enc.unsqueeze(-1)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)
        return output

    def anomaly_detection(self, x_enc):
        """Reconstruct ``x_enc`` at its original resolution."""
        B, T, N = x_enc.size()
        x_scales, _ = self.__multi_scale_process_inputs(x_enc, None)
        x_list, _ = self._prepare_scales(x_scales, None, normalize=True)

        enc_out_list = [self.enc_embedding(x, None) for x in x_list]

        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        dec_out = self.projection_layer(enc_out_list[0])
        dec_out = self._unfold_channels(dec_out, B)
        dec_out = self.normalize_layers[0](dec_out, 'denorm')
        return dec_out

    def imputation(self, x_enc, x_mark_enc, mask):
        """Impute missing values, using ``mask == 1`` for observed entries.

        Per-column statistics are computed over observed entries only.
        NOTE(review): the mean divides the full sum by the observed count, so
        it assumes missing entries of ``x_enc`` are already zero -- confirm
        with callers.  Fully-masked columns use a count of 1 to avoid NaN.
        """
        observed = torch.sum(mask == 1, dim=1).clamp(min=1)
        means = torch.sum(x_enc, dim=1) / observed
        means = means.unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / observed + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc /= stdev

        B, T, N = x_enc.size()
        # Time marks are not fed to the embedding in this task, so they are not down-sampled.
        x_scales, _ = self.__multi_scale_process_inputs(x_enc, None)
        x_list, _ = self._prepare_scales(x_scales, None, normalize=False)

        enc_out_list = [self.enc_embedding(x, None) for x in x_list]

        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        dec_out = self.projection_layer(enc_out_list[0])
        dec_out = self._unfold_channels(dec_out, B)

        # Undo the per-column standardization.
        dec_out = dec_out * \
                  (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        dec_out = dec_out + \
                  (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the head selected by ``self.task_name``.

        Raises:
            ValueError: for an unsupported task name.
        """
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        if self.task_name == 'imputation':
            return self.imputation(x_enc, x_mark_enc, mask)
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc)
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)
        raise ValueError('Other tasks not implemented yet')
|
|
|