mlStocks-pred / model /mamba2.py
AlgoX's picture
feat : add mamba2 model
21e07e4
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_model_device(model):
return next(iter(model.parameters())).device
class CausalConv1d(nn.Module):
def __init__(self, hidden_size, kernel_size):
super().__init__()
self.hidden_size = hidden_size
self.kernel_size = kernel_size
self.conv = nn.Conv1d(
hidden_size, hidden_size, kernel_size, groups=hidden_size, bias=True
)
def init_state(self, batch_size: int, device: torch.device | None = None):
if device is None:
device = get_model_device(self)
return torch.zeros(
batch_size, self.hidden_size, self.kernel_size - 1, device=device
)
def forward(self, x: torch.Tensor, state: torch.Tensor):
x_with_state = torch.concat([state, x[:, :, None]], dim=-1)
out = self.conv(x_with_state)
new_state = x_with_state[:, :, 1:]
return out.squeeze(-1), new_state
class Mamba2(nn.Module):
def __init__(
self,
hidden_size: int,
inner_size: int | None = None,
head_size: int = 64,
bc_head_size: int = 128,
conv_kernel_size: int = 4,
):
super().__init__()
self.head_size = head_size
self.bc_head_size = bc_head_size
if inner_size is None:
inner_size = 2 * hidden_size
assert inner_size % head_size == 0
self.inner_size = inner_size
self.num_heads = inner_size // head_size
# Projections
self.input_proj = nn.Linear(hidden_size, inner_size, bias=False)
self.z_proj = nn.Linear(hidden_size, inner_size, bias=False)
self.b_proj = nn.Linear(hidden_size, bc_head_size, bias=False)
self.c_proj = nn.Linear(hidden_size, bc_head_size, bias=False)
self.dt_proj = nn.Linear(hidden_size, self.num_heads, bias=True)
# Convs
self.input_conv = CausalConv1d(inner_size, conv_kernel_size)
self.b_conv = CausalConv1d(bc_head_size, conv_kernel_size)
self.c_conv = CausalConv1d(bc_head_size, conv_kernel_size)
# Other parameters
self.a = nn.Parameter(-torch.empty(self.num_heads).uniform_(1, 16))
self.d = nn.Parameter(torch.ones(self.num_heads))
# Output
self.norm = nn.RMSNorm(inner_size, eps=1e-5)
self.out_proj = nn.Linear(inner_size, hidden_size, bias=False)
def init_state(self, batch_size: int, device: torch.device | None = None):
if device is None:
device = get_model_device(self)
conv_states = [
conv.init_state(batch_size, device)
for conv in [self.input_conv, self.b_conv, self.c_conv]
]
ssm_state = torch.zeros(
batch_size, self.num_heads, self.head_size, self.bc_head_size, device=device
)
return conv_states + [ssm_state]
def forward(self, t, state):
batch_size = t.shape[0]
x = self.input_proj(t)
z = self.z_proj(t)
b = self.b_proj(t)
c = self.c_proj(t)
dt = self.dt_proj(t)
x_conv_state, b_conv_state, c_conv_state, ssm_state = state
x, x_conv_state = self.input_conv(x, x_conv_state)
b, b_conv_state = self.b_conv(b, b_conv_state)
c, c_conv_state = self.c_conv(c, c_conv_state)
x = F.silu(x)
b = F.silu(b)
c = F.silu(c)
x = x.view(batch_size, self.num_heads, self.head_size)
dt = F.softplus(dt)
# new_state computation: h[t] = exp(A*dt) * h[t-1] + dt * B * x[t]
# [batch_size, num_heads]
decay = torch.exp(self.a[None] * dt)
# dt is [batch_size, num_heads]
# b is [batch_size, bc_head_size]
# x is [batch_size, head_size]
new_state_contrib = dt[:, :, None, None] * b[:, None, None] * x[:, :, :, None]
ssm_state = decay[:, :, None, None] * ssm_state + new_state_contrib
# output computation: y[t] = C @ h[t] + D * x[t]
# The accumulation in the product of C and h[t] is on the bc_head_size dimension
state_contrib = torch.einsum("bc,bnhc->bnh", c, ssm_state)
# d has shape [num_heads], broadcasting it to the shape of x.
y = state_contrib + self.d[None, :, None] * x
# Combine heads
y = y.view(batch_size, self.inner_size)
# Gate, normalization and out
y = y * F.silu(z)
y = self.norm(y)
output = self.out_proj(y)
new_state = [x_conv_state, b_conv_state, c_conv_state, ssm_state]
return output, new_state