|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from data_provider.data_factory import data_provider |
|
|
|
|
|
|
|
|
|
|
|
class FourierFilter(nn.Module):
    """
    Fourier Filter: split a series into a time-variant and a time-invariant term.

    The rFFT bins listed in ``mask_spectrum`` are zeroed in the spectrum taken
    along dim 1. The inverse transform of what remains is returned as the
    time-variant term, and the residual ``x - x_var`` as the time-invariant term.
    """

    def __init__(self, mask_spectrum):
        """
        Args:
            mask_spectrum: indices of the rFFT bins (along dim 1) to suppress.
        """
        super(FourierFilter, self).__init__()
        self.mask_spectrum = mask_spectrum

    def forward(self, x):
        """
        Args:
            x: real tensor indexed as three-dimensional, transformed along dim 1.

        Returns:
            Tuple ``(x_var, x_inv)``; both have the shape of ``x`` and sum to ``x``.
        """
        xf = torch.fft.rfft(x, dim=1)
        mask = torch.ones_like(xf)
        mask[:, self.mask_spectrum, :] = 0
        # Give irfft the original length explicitly: by default it returns an
        # even length 2 * (bins - 1), which is one sample short for odd-length
        # input and would make the subtraction below fail.
        x_var = torch.fft.irfft(xf * mask, n=x.shape[1], dim=1)
        x_inv = x - x_var

        return x_var, x_inv
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    '''
    Multilayer perceptron to encode/decode high dimension representation of sequential data.

    Layout: Linear(f_in, hidden_dim) -> activation -> Dropout, then
    ``hidden_layers - 2`` further (Linear -> activation -> Dropout) stages of
    width ``hidden_dim``, and a final Linear(hidden_dim, f_out).
    '''

    def __init__(self,
                 f_in,
                 f_out,
                 hidden_dim=128,
                 hidden_layers=2,
                 dropout=0.05,
                 activation='tanh'):
        """
        Args:
            f_in: size of the last input dimension.
            f_out: size of the last output dimension.
            hidden_dim: width of every hidden Linear layer.
            hidden_layers: only values above 2 add extra hidden stages.
            dropout: probability passed to every nn.Dropout.
            activation: 'relu' or 'tanh'.

        Raises:
            NotImplementedError: for any other ``activation`` value.
        """
        super(MLP, self).__init__()
        self.f_in = f_in
        self.f_out = f_out
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.dropout = dropout

        # One activation instance is created and reused at every stage.
        for name, factory in (('relu', nn.ReLU), ('tanh', nn.Tanh)):
            if activation == name:
                self.activation = factory()
                break
        else:
            raise NotImplementedError

        stack = [nn.Linear(f_in, hidden_dim), self.activation, nn.Dropout(dropout)]
        for _ in range(hidden_layers - 2):
            stack.extend((nn.Linear(hidden_dim, hidden_dim),
                          self.activation,
                          nn.Dropout(dropout)))
        stack.append(nn.Linear(hidden_dim, f_out))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.layers(x)
|
|
|
|
|
|
|
|
class KPLayer(nn.Module):
    """
    A demonstration of finding one step transition of linear system by DMD iteratively.

    The transition K is fitted by least squares on consecutive snapshots
    (``z[:, :-1] @ K ~= z[:, 1:]``) and then applied repeatedly to roll the
    last snapshot forward.
    """

    def __init__(self):
        super(KPLayer, self).__init__()

        # Transition matrix from the most recent call, shaped (B, E, E).
        self.K = None

    def one_step_forward(self, z, return_rec=False, return_K=False):
        """
        Fit K on the snapshots in ``z`` and predict one step ahead.

        Args:
            z: tensor of shape (B, input_len, E) with input_len > 1.
            return_rec: when True, also return the reconstruction of ``z``.
            return_K: accepted but not used by this implementation.

        Returns:
            ``z_pred`` of shape (B, 1, E), or ``(z_rec, z_pred)`` when
            ``return_rec`` is True, where ``z_rec`` has the shape of ``z``.
        """
        B, input_len, E = z.shape
        assert input_len > 1, 'snapshots number should be larger than 1'
        x, y = z[:, :-1], z[:, 1:]

        # Least-squares solution of x @ K = y for every batch element.
        self.K = torch.linalg.lstsq(x, y).solution
        if torch.isnan(self.K).any():
            print('Encounter K with nan, replace K by identity matrix')
            # Build the fallback in K's own dtype and device so the bmm calls
            # below keep working for inputs that are not default-dtype tensors.
            self.K = torch.eye(
                self.K.shape[1], dtype=self.K.dtype, device=self.K.device
            ).unsqueeze(0).repeat(B, 1, 1)

        z_pred = torch.bmm(z[:, -1:], self.K)
        if return_rec:
            # The first snapshot has no predecessor, so it is kept as is.
            z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)
            return z_rec, z_pred

        return z_pred

    def forward(self, z, pred_len=1):
        """
        Reconstruct ``z`` and predict ``pred_len`` future snapshots.

        Args:
            z: tensor of shape (B, input_len, E).
            pred_len: number of steps to roll forward, at least 1.

        Returns:
            ``(z_rec, z_preds)`` with shapes (B, input_len, E) and (B, pred_len, E).
        """
        assert pred_len >= 1, 'prediction length should not be less than 1'
        z_rec, z_pred = self.one_step_forward(z, return_rec=True)
        z_preds = [z_pred]
        for _ in range(1, pred_len):
            # Each further step applies the same fitted transition.
            z_pred = torch.bmm(z_pred, self.K)
            z_preds.append(z_pred)
        z_preds = torch.cat(z_preds, dim=1)
        return z_rec, z_preds
|
|
|
|
|
|
|
|
class KPLayerApprox(nn.Module):
    """
    Find koopman transition of linear system by DMD with multistep K approximation.

    Instead of rolling forward one step at a time, a matrix power of the
    fitted one-step transition is applied to a whole window of snapshots.
    """

    def __init__(self):
        super(KPLayerApprox, self).__init__()

        # One-step transition from the most recent call, shaped (B, E, E).
        self.K = None
        # Multistep transition (a matrix power of K) from the most recent call.
        self.K_step = None

    @staticmethod
    def _batched_identity(ref):
        """Return identity matrices with the batch size, width, dtype and device of ``ref``."""
        eye = torch.eye(ref.shape[1], dtype=ref.dtype, device=ref.device)
        return eye.unsqueeze(0).repeat(ref.shape[0], 1, 1)

    def forward(self, z, pred_len=1):
        """
        Reconstruct ``z`` and predict ``pred_len`` future snapshots.

        Args:
            z: tensor of shape (B, input_len, E) with input_len > 1.
            pred_len: number of future snapshots to produce.

        Returns:
            ``(z_rec, z_pred)``; ``z_rec`` has the shape of ``z`` and
            ``z_pred`` has ``pred_len`` snapshots along dim 1.
        """
        B, input_len, E = z.shape
        assert input_len > 1, 'snapshots number should be larger than 1'
        x, y = z[:, :-1], z[:, 1:]

        # Least-squares solution of x @ K = y for every batch element.
        self.K = torch.linalg.lstsq(x, y).solution

        if torch.isnan(self.K).any():
            print('Encounter K with nan, replace K by identity matrix')
            # The fallback follows K's dtype/device so later bmm calls stay valid.
            self.K = self._batched_identity(self.K)

        # The first snapshot has no predecessor, so it is kept as is.
        z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1)

        # Short horizons jump pred_len steps at once; longer ones jump a whole
        # window (input_len steps) per iteration.
        short_horizon = pred_len <= input_len
        self.K_step = torch.linalg.matrix_power(self.K, pred_len if short_horizon else input_len)
        if torch.isnan(self.K_step).any():
            print('Encounter multistep K with nan, replace it by identity matrix')
            self.K_step = self._batched_identity(self.K_step)

        if short_horizon:
            z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step)
        else:
            temp_z_pred, all_pred = z, []
            for _ in range(math.ceil(pred_len / input_len)):
                temp_z_pred = torch.bmm(temp_z_pred, self.K_step)
                all_pred.append(temp_z_pred)
            # The last window may overshoot; keep the first pred_len snapshots.
            z_pred = torch.cat(all_pred, dim=1)[:, :pred_len, :]

        return z_rec, z_pred
|
|
|
|
|
|
|
|
class TimeVarKP(nn.Module):
    """
    Koopman Predictor with DMD (analytical solution of Koopman operator).
    Utilizes local variations within individual sliding windows to predict
    the future of the time-variant term.
    """

    def __init__(self,
                 enc_in=8,
                 input_len=96,
                 pred_len=96,
                 seg_len=24,
                 dynamic_dim=128,
                 encoder=None,
                 decoder=None,
                 multistep=False,
                 ):
        """
        Args:
            enc_in: number of variates (last dimension of the input).
            input_len: length of the lookback window.
            pred_len: length of the forecast window.
            seg_len: length of one segment; each segment becomes one snapshot.
            dynamic_dim: latent width; stored but not otherwise used in this class.
            encoder: module applied to the flattened segments.
            decoder: module applied to the latent snapshots.
            multistep: use KPLayerApprox instead of KPLayer when truthy.
        """
        super(TimeVarKP, self).__init__()
        self.enc_in = enc_in
        self.input_len = input_len
        self.pred_len = pred_len
        self.seg_len = seg_len
        self.dynamic_dim = dynamic_dim
        self.multistep = multistep
        self.encoder = encoder
        self.decoder = decoder

        # Segment counts for the lookback and the forecast, rounded up.
        self.freq = math.ceil(input_len / seg_len)
        self.step = math.ceil(pred_len / seg_len)
        # Steps missing for the lookback to be a whole number of segments.
        self.padding_len = seg_len * self.freq - input_len

        if multistep:
            self.dynamics = KPLayerApprox()
        else:
            self.dynamics = KPLayer()

    def _decode_segments(self, latent, batch, n_seg, length):
        """Decode latent snapshots and unfold them into the first ``length`` time steps."""
        series = self.decoder(latent)
        series = series.reshape(batch, n_seg, self.seg_len, self.enc_in)
        return series.reshape(batch, -1, self.enc_in)[:, :length, :]

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, L, C).

        Returns:
            ``(x_rec, x_pred)`` with shapes (B, input_len, enc_in) and
            (B, pred_len, enc_in).
        """
        batch, seq_len, _ = x.shape

        # Prepend the last padding_len steps so the window splits evenly.
        # NOTE(review): the padding goes in front, but the backcast below keeps
        # the first input_len steps, so it looks shifted relative to x whenever
        # padding_len > 0 - confirm this is intended.
        tail = x[:, seq_len - self.padding_len:, :]
        padded = torch.cat((tail, x), dim=1)

        # One flattened segment per snapshot.
        segments = torch.stack(padded.chunk(self.freq, dim=1), dim=1)
        snapshots = segments.reshape(batch, self.freq, -1)

        latent = self.encoder(snapshots)
        latent_rec, latent_pred = self.dynamics(latent, self.step)

        x_rec = self._decode_segments(latent_rec, batch, self.freq, self.input_len)
        x_pred = self._decode_segments(latent_pred, batch, self.step, self.pred_len)

        return x_rec, x_pred
|
|
|
|
|
|
|
|
class TimeInvKP(nn.Module):
    """
    Koopman Predictor with learnable Koopman operator.
    Utilizes lookback and forecast window snapshots to predict the future of
    the time-invariant term.
    """

    def __init__(self,
                 input_len=96,
                 pred_len=96,
                 dynamic_dim=128,
                 encoder=None,
                 decoder=None):
        """
        Args:
            input_len: length of the lookback window (stored only).
            pred_len: length of the forecast window (stored only).
            dynamic_dim: width of the square, bias-free operator ``K``.
            encoder: module applied before ``K``.
            decoder: module applied after ``K``.
        """
        super(TimeInvKP, self).__init__()
        self.dynamic_dim = dynamic_dim
        self.input_len = input_len
        self.pred_len = pred_len
        self.encoder = encoder
        self.decoder = decoder

        # Start K from an orthogonal matrix: the product of the singular-vector
        # factors of a random Gaussian matrix. torch.linalg.svd replaces the
        # deprecated torch.svd and returns V already transposed (Vh).
        K_init = torch.randn(self.dynamic_dim, self.dynamic_dim)
        U, _, Vh = torch.linalg.svd(K_init)
        self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False)
        with torch.no_grad():
            self.K.weight.copy_(torch.mm(U, Vh))

    def forward(self, x):
        """
        Args:
            x: tensor whose dims 1 and 2 are swapped before encoding.

        Returns:
            decoder(K(encoder(x^T))) with dims 1 and 2 swapped back.
        """
        res = x.transpose(1, 2)
        res = self.encoder(res)
        res = self.K(res)
        res = self.decoder(res)
        res = res.transpose(1, 2)

        return res
|
|
|
|
|
|
|
|
class Model(nn.Module):
    '''
    Koopa forecaster: stacked blocks that each split the residual series into
    time-variant and time-invariant parts and forecast both with Koopman predictors.

    Paper link: https://arxiv.org/pdf/2305.18803.pdf
    '''

    def __init__(self, configs, dynamic_dim=128, hidden_dim=64, hidden_layers=2, num_blocks=3, multistep=False,
                 alpha=0.2):
        """
        configs: object providing task_name, enc_in, seq_len and pred_len; it is
            also passed to data_provider to load the training split
        dynamic_dim: int, latent dimension of koopman embedding
        hidden_dim: int, hidden dimension of en/decoder
        hidden_layers: int, number of hidden layers of en/decoder
        num_blocks: int, number of Koopa blocks
        multistep: bool, whether to use approximation for multistep K
        alpha: float, spectrum filter ratio (fraction of frequency bins masked)

        The segment length of the time-variant predictor is set to pred_len.
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.enc_in = configs.enc_in
        self.input_len = configs.seq_len
        self.pred_len = configs.pred_len

        self.seg_len = self.pred_len
        self.num_blocks = num_blocks
        self.dynamic_dim = dynamic_dim
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.multistep = multistep
        self.alpha = alpha
        # Reads the training split, so constructing the model performs data loading.
        self.mask_spectrum = self._get_mask_spectrum(configs)

        self.disentanglement = FourierFilter(self.mask_spectrum)

        # Time-invariant branch: one encoder/decoder pair shared by all blocks.
        self.time_inv_encoder = MLP(f_in=self.input_len, f_out=self.dynamic_dim, activation='relu',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_inv_decoder = MLP(f_in=self.dynamic_dim, f_out=self.pred_len, activation='relu',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_inv_kps = nn.ModuleList([
            TimeInvKP(input_len=self.input_len,
                      pred_len=self.pred_len,
                      dynamic_dim=self.dynamic_dim,
                      encoder=self.time_inv_encoder,
                      decoder=self.time_inv_decoder)
            for _ in range(self.num_blocks)])

        # Time-variant branch: one encoder/decoder pair shared by all blocks.
        self.time_var_encoder = MLP(f_in=self.seg_len * self.enc_in, f_out=self.dynamic_dim, activation='tanh',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_var_decoder = MLP(f_in=self.dynamic_dim, f_out=self.seg_len * self.enc_in, activation='tanh',
                                    hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers)
        self.time_var_kps = nn.ModuleList([
            TimeVarKP(enc_in=configs.enc_in,
                      input_len=self.input_len,
                      pred_len=self.pred_len,
                      seg_len=self.seg_len,
                      dynamic_dim=self.dynamic_dim,
                      encoder=self.time_var_encoder,
                      decoder=self.time_var_decoder,
                      multistep=self.multistep)
            for _ in range(self.num_blocks)])

    def _get_mask_spectrum(self, configs):
        """
        get shared frequency spectrums

        Sums, over all training batches, the rFFT amplitude (along dim 1) of
        the first element of each batch averaged over dims 0 and 2, and returns
        the indices of the top ``alpha`` fraction of frequency bins.

        Raises:
            ValueError: if the training loader yields no batches.
        """
        _, train_loader = data_provider(configs, 'train')
        amps = None
        for data in train_loader:
            lookback_window = data[0]
            batch_amps = abs(torch.fft.rfft(lookback_window, dim=1)).mean(dim=0).mean(dim=1)
            amps = batch_amps if amps is None else amps + batch_amps
        if amps is None:
            raise ValueError('training loader is empty, cannot compute the shared frequency spectrum')
        mask_spectrum = amps.topk(int(amps.shape[0] * self.alpha)).indices
        return mask_spectrum

    def forecast(self, x_enc):
        """
        Forecast from the lookback window ``x_enc`` (normalized along dim 1).

        Returns the summed block forecasts mapped back to the input scale.
        """
        # Normalize with the window's own statistics; they are detached so no
        # gradient flows through them.
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        # Each block forecasts from the current residual, then removes its
        # time-variant backcast before the next block.
        residual, forecast = x_enc, None
        for i in range(self.num_blocks):
            time_var_input, time_inv_input = self.disentanglement(residual)
            time_inv_output = self.time_inv_kps[i](time_inv_input)
            time_var_backcast, time_var_output = self.time_var_kps[i](time_var_input)
            residual = residual - time_var_backcast
            if forecast is None:
                forecast = (time_inv_output + time_var_output)
            else:
                forecast += (time_inv_output + time_var_output)

        # Undo the normalization.
        res = forecast * std_enc + mean_enc

        return res

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """
        Run the model for the configured task.

        Only ``x_enc`` is used. Returns the last ``pred_len`` steps of the
        forecast for 'long_term_forecast', and None for any other task name.
        """
        if self.task_name == 'long_term_forecast':
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]
        # No other task is implemented here.
        return None
|
|
|