|
|
import math |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import rearrange, repeat, einsum |
|
|
|
|
|
from layers.Embed import DataEmbedding |
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """
    Mamba: linear-time sequence modelling with selective state spaces, O(L).

    Paper link: https://arxiv.org/abs/2312.00752
    Implementation reference: https://github.com/johnma2006/mamba-minimal/

    Pipeline: per-series standardisation -> DataEmbedding -> ``e_layers``
    residual Mamba blocks -> RMSNorm -> bias-free linear head to ``c_out``
    channels -> de-standardisation. Only the forecasting tasks are handled.
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        d_model = configs.d_model
        # Width of the expanded stream inside every Mamba block.
        self.d_inner = d_model * configs.expand
        # Rank of the low-rank projection that produces the SSM step size delta.
        self.dt_rank = math.ceil(d_model / 16)

        # Keep this construction order: parameter initialisation draws from the
        # global RNG in the order submodules are created.
        self.embedding = DataEmbedding(configs.enc_in, d_model, configs.embed, configs.freq, configs.dropout)
        self.layers = nn.ModuleList(
            ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)
        )
        self.norm = RMSNorm(d_model)
        self.out_layer = nn.Linear(d_model, configs.c_out, bias=False)

    def forecast(self, x_enc, x_mark_enc):
        """Run the network on standardised input and map the output back.

        Each series is standardised along dim 1 with its mean and population
        std (``unbiased=False``, plus 1e-5 inside the sqrt). Both statistics
        are detached from the autograd graph.

        Args:
            x_enc: encoder values; assumed (batch, seq_len, enc_in) — TODO confirm.
            x_mark_enc: time-mark features passed straight to DataEmbedding.

        Returns:
            The head output de-standardised with the same statistics.

        NOTE(review): the statistics have ``enc_in`` channels, while the head
        emits ``c_out`` channels. The final affine step therefore relies on
        broadcasting, which presumably means c_out == enc_in or c_out == 1.
        """
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        centred = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(centred, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        scaled = centred / std_enc

        hidden = self.embedding(scaled, x_mark_enc)
        for block in self.layers:
            hidden = block(hidden)
        projected = self.out_layer(self.norm(hidden))

        # Undo the standardisation applied above.
        return projected * std_enc + mean_enc

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the last ``pred_len`` steps of :meth:`forecast`.

        Only 'short_term_forecast' and 'long_term_forecast' are supported. Any
        other ``task_name`` yields ``None``. ``x_dec``, ``x_mark_dec`` and
        ``mask`` are accepted for interface compatibility and are unused.
        """
        if self.task_name not in ('short_term_forecast', 'long_term_forecast'):
            # Other tasks are not implemented for this model.
            return None
        prediction = self.forecast(x_enc, x_mark_enc)
        return prediction[:, -self.pred_len:, :]
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper computing ``x + MambaBlock(RMSNorm(x))``."""

    def __init__(self, configs, d_inner, dt_rank):
        super(ResidualBlock, self).__init__()
        # Mixer is created before the norm, as in the original (same init order).
        self.mixer = MambaBlock(configs, d_inner, dt_rank)
        self.norm = RMSNorm(configs.d_model)

    def forward(self, x):
        """Normalise, mix along the sequence, and add the skip connection."""
        mixed = self.mixer(self.norm(x))
        return x + mixed
|
|
|
|
|
class MambaBlock(nn.Module):
    """One Mamba mixer: gated causal depthwise conv followed by a selective SSM.

    Args:
        configs: config object. This block reads ``d_model``, ``d_conv`` and
            ``d_ff``; ``d_ff`` is used as the SSM state size ``n``.
        d_inner: width of the expanded inner stream.
        dt_rank: rank of the low-rank projection that produces delta.

    ``forward`` maps a (batch, length, d_model) tensor to the same shape.
    """

    def __init__(self, configs, d_inner, dt_rank):
        super(MambaBlock, self).__init__()
        self.d_inner = d_inner
        self.dt_rank = dt_rank

        # One projection yields both the SSM input stream and the gate stream.
        self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)

        # Depthwise conv. The padding of d_conv - 1, together with keeping only
        # the first l outputs in forward(), makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=configs.d_conv,
            padding=configs.d_conv - 1,
            groups=self.d_inner,
        )

        # Input-dependent projections producing delta (low rank), B and C.
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)

        # Lifts the low-rank delta back to d_inner channels.
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Every channel starts with A = -[1, 2, ..., n]. A is stored as log(-A)
        # so that A = -exp(A_log) stays strictly negative.
        A = torch.arange(1, configs.d_ff + 1).repeat(self.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)

    def forward(self, x):
        """Mix a (b, l, d_model) tensor along the sequence axis (Figure 3, Section 3.4)."""
        _, seq_len, _ = x.shape

        x, res = self.in_proj(x).split(split_size=[self.d_inner, self.d_inner], dim=-1)

        # Conv1d wants channels first. Keep the first seq_len outputs so that
        # position t only sees inputs at positions <= t.
        x = self.conv1d(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x = F.silu(x)

        y = self.ssm(x)
        y = y * F.silu(res)  # gate with the residual stream

        return self.out_proj(y)

    def ssm(self, x):
        """Selective state-space model (Algorithm 2, Section 3.2).

        delta, B and C are computed from ``x``. The recurrence runs in
        ``torch.promote_types(x.dtype, torch.float32)``: float32 for
        half/bfloat16/float32 inputs, float64 for double inputs. Every operand
        of the scan therefore shares one dtype. The result is cast back to
        ``x.dtype``.
        """
        d_in, n = self.A_log.shape
        # Never scan below float32, but keep float64 when the module is double.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)

        A = -torch.exp(self.A_log.to(compute_dtype))  # (d_in, n), strictly negative
        D = self.D.to(compute_dtype)

        delta, B, C = self.x_proj(x).split(split_size=[self.dt_rank, n, n], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # positive step sizes

        y = self.selective_scan(
            x.to(compute_dtype),
            delta.to(compute_dtype),
            A,
            B.to(compute_dtype),
            C.to(compute_dtype),
            D,
        )
        return y.to(x.dtype)

    def selective_scan(self, u, delta, A, B, C, D):
        """Run the discretised SSM recurrence sequentially over the length axis.

        Args:
            u: (b, l, d_in) input.
            delta: (b, l, d_in) step sizes.
            A: (d_in, n) per-channel diagonal state transition.
            B, C: (b, l, n) input-dependent projections.
            D: (d_in,) skip weights.

        Returns:
            (b, l, d_in) tensor with ``y_t = <h_t, C_t> + D * u_t``, where
            ``h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * u_t`` and
            ``h_0 = 0``.
        """
        b, l, d_in = u.shape
        n = A.shape[1]

        if l == 0:
            # Empty sequence: nothing to scan (torch.stack([]) would raise).
            return u * D

        deltaA = torch.exp(torch.einsum("bld,dn->bldn", delta, A))
        deltaB_u = torch.einsum("bld,bln,bld->bldn", delta, B, u)

        # Allocate the hidden state with the operands' dtype/device so the
        # recurrence does not silently change precision.
        h = torch.zeros((b, d_in, n), device=deltaA.device, dtype=deltaA.dtype)
        ys = []
        for i in range(l):
            h = deltaA[:, i] * h + deltaB_u[:, i]
            ys.append(torch.einsum("bdn,bn->bd", h, C[:, i, :]))

        y = torch.stack(ys, dim=1)
        return y + u * D
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``x / sqrt(mean(x ** 2, dim=-1) + eps) * weight``. There is no
    mean subtraction and no bias.
    """

    def __init__(self, d_model, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to ones.
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
|
|
|