AlgoX commited on
Commit
21e07e4
·
1 Parent(s): fd1ab91

feat : add mamba2 model

Browse files
Files changed (1) hide show
  1. model/mamba2.py +134 -0
model/mamba2.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def get_model_device(model):
7
+ return next(iter(model.parameters())).device
8
+
9
+
10
+ class CausalConv1d(nn.Module):
11
+
12
+ def __init__(self, hidden_size, kernel_size):
13
+ super().__init__()
14
+ self.hidden_size = hidden_size
15
+ self.kernel_size = kernel_size
16
+ self.conv = nn.Conv1d(
17
+ hidden_size, hidden_size, kernel_size, groups=hidden_size, bias=True
18
+ )
19
+
20
+ def init_state(self, batch_size: int, device: torch.device | None = None):
21
+ if device is None:
22
+ device = get_model_device(self)
23
+ return torch.zeros(
24
+ batch_size, self.hidden_size, self.kernel_size - 1, device=device
25
+ )
26
+
27
+ def forward(self, x: torch.Tensor, state: torch.Tensor):
28
+ x_with_state = torch.concat([state, x[:, :, None]], dim=-1)
29
+ out = self.conv(x_with_state)
30
+ new_state = x_with_state[:, :, 1:]
31
+ return out.squeeze(-1), new_state
32
+
33
+
34
+ class Mamba2(nn.Module):
35
+ def __init__(
36
+ self,
37
+ hidden_size: int,
38
+ inner_size: int | None = None,
39
+ head_size: int = 64,
40
+ bc_head_size: int = 128,
41
+ conv_kernel_size: int = 4,
42
+ ):
43
+ super().__init__()
44
+
45
+ self.head_size = head_size
46
+ self.bc_head_size = bc_head_size
47
+ if inner_size is None:
48
+ inner_size = 2 * hidden_size
49
+ assert inner_size % head_size == 0
50
+ self.inner_size = inner_size
51
+ self.num_heads = inner_size // head_size
52
+
53
+ # Projections
54
+ self.input_proj = nn.Linear(hidden_size, inner_size, bias=False)
55
+ self.z_proj = nn.Linear(hidden_size, inner_size, bias=False)
56
+ self.b_proj = nn.Linear(hidden_size, bc_head_size, bias=False)
57
+ self.c_proj = nn.Linear(hidden_size, bc_head_size, bias=False)
58
+ self.dt_proj = nn.Linear(hidden_size, self.num_heads, bias=True)
59
+
60
+ # Convs
61
+ self.input_conv = CausalConv1d(inner_size, conv_kernel_size)
62
+ self.b_conv = CausalConv1d(bc_head_size, conv_kernel_size)
63
+ self.c_conv = CausalConv1d(bc_head_size, conv_kernel_size)
64
+
65
+ # Other parameters
66
+ self.a = nn.Parameter(-torch.empty(self.num_heads).uniform_(1, 16))
67
+ self.d = nn.Parameter(torch.ones(self.num_heads))
68
+
69
+ # Output
70
+ self.norm = nn.RMSNorm(inner_size, eps=1e-5)
71
+ self.out_proj = nn.Linear(inner_size, hidden_size, bias=False)
72
+
73
+ def init_state(self, batch_size: int, device: torch.device | None = None):
74
+ if device is None:
75
+ device = get_model_device(self)
76
+ conv_states = [
77
+ conv.init_state(batch_size, device)
78
+ for conv in [self.input_conv, self.b_conv, self.c_conv]
79
+ ]
80
+ ssm_state = torch.zeros(
81
+ batch_size, self.num_heads, self.head_size, self.bc_head_size, device=device
82
+ )
83
+ return conv_states + [ssm_state]
84
+
85
+ def forward(self, t, state):
86
+ batch_size = t.shape[0]
87
+
88
+ x = self.input_proj(t)
89
+ z = self.z_proj(t)
90
+ b = self.b_proj(t)
91
+ c = self.c_proj(t)
92
+ dt = self.dt_proj(t)
93
+
94
+ x_conv_state, b_conv_state, c_conv_state, ssm_state = state
95
+ x, x_conv_state = self.input_conv(x, x_conv_state)
96
+ b, b_conv_state = self.b_conv(b, b_conv_state)
97
+ c, c_conv_state = self.c_conv(c, c_conv_state)
98
+ x = F.silu(x)
99
+ b = F.silu(b)
100
+ c = F.silu(c)
101
+
102
+ x = x.view(batch_size, self.num_heads, self.head_size)
103
+ dt = F.softplus(dt)
104
+
105
+ # new_state computation: h[t] = exp(A*dt) * h[t-1] + dt * B * x[t]
106
+ # [batch_size, num_heads]
107
+
108
+ decay = torch.exp(self.a[None] * dt)
109
+
110
+ # dt is [batch_size, num_heads]
111
+ # b is [batch_size, bc_head_size]
112
+ # x is [batch_size, head_size]
113
+
114
+ new_state_contrib = dt[:, :, None, None] * b[:, None, None] * x[:, :, :, None]
115
+ ssm_state = decay[:, :, None, None] * ssm_state + new_state_contrib
116
+
117
+ # output computation: y[t] = C @ h[t] + D * x[t]
118
+
119
+ # The accumulation in the product of C and h[t] is on the bc_head_size dimension
120
+
121
+ state_contrib = torch.einsum("bc,bnhc->bnh", c, ssm_state)
122
+ # d has shape [num_heads], broadcasting it to the shape of x.
123
+
124
+ y = state_contrib + self.d[None, :, None] * x
125
+
126
+ # Combine heads
127
+ y = y.view(batch_size, self.inner_size)
128
+ # Gate, normalization and out
129
+ y = y * F.silu(z)
130
+ y = self.norm(y)
131
+ output = self.out_proj(y)
132
+
133
+ new_state = [x_conv_state, b_conv_state, c_conv_state, ssm_state]
134
+ return output, new_state