|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from mamba_ssm import Mamba |
|
|
|
|
|
from layers.Embed import DataEmbedding |
|
|
|
|
|
class Model(nn.Module):
    """Sequence model: data embedding -> one Mamba block -> linear head.

    The input is standardized with statistics taken along dim 1, run through
    ``embedding -> mamba -> out_layer``, and then mapped back to the original
    scale using the same statistics.
    """

    def __init__(self, configs):
        """Build the sub-modules from ``configs``.

        Args:
            configs: Object exposing the attributes read below:
                ``task_name``, ``pred_len``, ``d_model``, ``expand``,
                ``enc_in``, ``embed``, ``freq``, ``dropout``, ``d_ff``,
                ``d_conv`` and ``c_out``.
        """
        super().__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        # Derived sizes kept as attributes; nothing else in this class reads
        # them, but they are preserved for any external consumer.
        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)

        self.embedding = DataEmbedding(
            configs.enc_in,
            configs.d_model,
            configs.embed,
            configs.freq,
            configs.dropout,
        )

        # NOTE: the Mamba state size is taken from ``configs.d_ff``.
        self.mamba = Mamba(
            d_model=configs.d_model,
            d_state=configs.d_ff,
            d_conv=configs.d_conv,
            expand=configs.expand,
        )

        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    def forecast(self, x_enc, x_mark_enc):
        """Run the network on a standardized copy of ``x_enc`` and rescale.

        Args:
            x_enc: Input tensor. Statistics are computed along dim 1;
                presumably shaped (batch, seq_len, channels) -- TODO confirm
                against the caller.
            x_mark_enc: Forwarded unchanged to ``self.embedding``.

        Returns:
            The output of ``self.out_layer`` multiplied by the input's
            standard deviation and shifted by the input's mean.
        """
        # Statistics are detached so no gradient flows through them.
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        # The 1e-5 term keeps the division below finite for constant inputs.
        std_enc = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5
        ).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        x = self.mamba(x)
        x_out = self.out_layer(x)

        # Undo the standardization applied above.
        x_out = x_out * std_enc + mean_enc
        return x_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``self.task_name``.

        ``x_dec``, ``x_mark_dec`` and ``mask`` are accepted but not used by
        this method.

        Returns:
            For ``'short_term_forecast'`` and ``'long_term_forecast'``: the
            last ``self.pred_len`` positions along dim 1 of
            ``self.forecast(x_enc, x_mark_enc)``. For any other task name:
            ``None``.
        """
        if self.task_name in ('short_term_forecast', 'long_term_forecast'):
            x_out = self.forecast(x_enc, x_mark_enc)
            return x_out[:, -self.pred_len:, :]

        # Other tasks are not implemented; callers keep receiving None.
        return None
|
|
|
|
|
|