Abdullah-Nazhat committed on
Commit
d11d0a9
·
verified ·
1 Parent(s): 007d4b5

Update mamba.py

Browse files
Files changed (1) hide show
  1. mamba.py +68 -150
mamba.py CHANGED
@@ -8,47 +8,30 @@ import torch.nn.functional as F
8
 
9
  from pscan import pscan
10
 
11
- """
12
 
13
- This file closely follows the mamba_simple.py from the official Mamba implementation, and the mamba-minimal by @johnma2006.
14
- The major differences are :
15
- -the convolution is done with torch.nn.Conv1d
16
- -the selective scan is done in PyTorch
17
-
18
- A sequential version of the selective scan is also available for comparison.
19
-
20
- - A Mamba model is composed of several layers, which are ResidualBlock.
21
- - A ResidualBlock is composed of a MambaBlock, a normalization, and a residual connection : ResidualBlock(x) = mamba(norm(x)) + x
22
- - This leaves us with the MambaBlock : its input x is (B, L, D) and its output y is also (B, L, D) (B=batch size, L=seq len, D=model dim).
23
- First, we expand x into (B, L, 2*ED) (where E is usually 2) and split it into x and z, each (B, L, ED).
24
- Then, we apply the short 1d conv to x, followed by an activation function (silu), then the SSM.
25
- We then multiply it by silu(z).
26
- See Figure 3 of the paper (page 8) for a visual representation of a MambaBlock.
27
-
28
- """
29
 
30
  @dataclass
31
  class MambaConfig:
32
- d_model: int # D
33
  n_layers: int
34
  dt_rank: Union[int, str] = 'auto'
35
- d_state: int = 16 # N in paper/comments
36
- expand_factor: int = 2 # E in paper/comments
37
  d_conv: int = 4
38
 
39
  dt_min: float = 0.001
40
  dt_max: float = 0.1
41
- dt_init: str = "random" # "random" or "constant"
42
  dt_scale: float = 1.0
43
  dt_init_floor = 1e-4
44
 
45
  bias: bool = False
46
  conv_bias: bool = True
47
 
48
- pscan: bool = True # use parallel scan mode or sequential mode when training
49
 
50
  def __post_init__(self):
51
- self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
52
 
53
  if self.dt_rank == 'auto':
54
  self.dt_rank = math.ceil(self.d_model / 16)
@@ -60,26 +43,20 @@ class Mamba(nn.Module):
60
  self.config = config
61
 
62
  self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
63
- #self.norm_f = RMSNorm(config.d_model)
64
 
65
  def forward(self, x):
66
- # x : (B, L, D)
67
-
68
- # y : (B, L, D)
69
 
70
  for layer in self.layers:
71
  x = layer(x)
72
 
73
- #x = self.norm_f(x)
74
 
75
  return x
76
 
77
  def step(self, x, caches):
78
- # x : (B, L, D)
79
- # caches : [cache(layer) for all layers], cache : (h, inputs)
80
-
81
- # y : (B, L, D)
82
- # caches : [cache(layer) for all layers], cache : (h, inputs)
83
 
84
  for i, layer in enumerate(self.layers):
85
  x, caches[i] = layer.step(x, caches[i])
@@ -94,21 +71,13 @@ class ResidualBlock(nn.Module):
94
  self.norm = RMSNorm(config.d_model)
95
 
96
  def forward(self, x):
97
- # x : (B, L, D)
98
-
99
- # output : (B, L, D)
100
 
101
  output = self.mixer(self.norm(x)) + x
102
  return output
103
 
104
  def step(self, x, cache):
105
- # x : (B, D)
106
- # cache : (h, inputs)
107
- # h : (B, ED, N)
108
- # inputs: (B, ED, d_conv-1)
109
-
110
- # output : (B, D)
111
- # cache : (h, inputs)
112
 
113
  output, cache = self.mixer.step(self.norm(x), cache)
114
  output = output + x
@@ -120,7 +89,7 @@ class MambaBlock(nn.Module):
120
 
121
  self.config = config
122
 
123
- # projects block input from D to 2*ED (two branches)
124
  self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
125
 
126
  self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
@@ -128,14 +97,13 @@ class MambaBlock(nn.Module):
128
  groups=config.d_inner,
129
  padding=config.d_conv - 1)
130
 
131
- # projects x to input-dependent Δ, B, C
132
  self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
133
 
134
- # projects Δ from dt_rank to d_inner
135
  self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
136
 
137
- # dt initialization
138
- # dt weights
139
  dt_init_std = config.dt_rank**-0.5 * config.dt_scale
140
  if config.dt_init == "constant":
141
  nn.init.constant_(self.dt_proj.weight, dt_init_std)
@@ -144,63 +112,56 @@ class MambaBlock(nn.Module):
144
  else:
145
  raise NotImplementedError
146
 
147
- # dt bias
148
  dt = torch.exp(
149
  torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
150
  ).clamp(min=config.dt_init_floor)
151
- inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
152
  with torch.no_grad():
153
  self.dt_proj.bias.copy_(inv_dt)
154
- #self.dt_proj.bias._no_reinit = True # initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
155
- # todo : explain why removed
156
-
157
- # S4D real initialization
158
  A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
159
- self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
160
  self.D = nn.Parameter(torch.ones(config.d_inner))
161
 
162
- # projects block output from ED back to D
163
  self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
164
 
165
  def forward(self, x):
166
- # x : (B, L, D)
167
 
168
- # y : (B, L, D)
169
 
170
  _, L, _ = x.shape
171
 
172
- xz = self.in_proj(x) # (B, L, 2*ED)
173
- x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)
174
 
175
- # x branch
176
- x = x.transpose(1, 2) # (B, ED, L)
177
- x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
178
- x = x.transpose(1, 2) # (B, L, ED)
179
 
180
  x = F.silu(x)
181
  y = self.ssm(x)
182
 
183
- # z branch
184
  z = F.silu(z)
185
 
186
  output = y * z
187
- output = self.out_proj(output) # (B, L, D)
188
 
189
  return output
190
 
191
  def ssm(self, x):
192
- # x : (B, L, ED)
193
 
194
- # y : (B, L, ED)
195
-
196
- A = -torch.exp(self.A_log.float()) # (ED, N)
197
  D = self.D.float()
198
- # TODO remove .float()
199
 
200
- deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
201
 
202
  delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
203
- delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)
204
 
205
  if self.config.pscan:
206
  y = self.selective_scan(x, delta, A, B, C, D)
@@ -210,44 +171,30 @@ class MambaBlock(nn.Module):
210
  return y
211
 
212
  def selective_scan(self, x, delta, A, B, C, D):
213
- # x : (B, L, ED)
214
- # Δ : (B, L, ED)
215
- # A : (ED, N)
216
- # B : (B, L, N)
217
- # C : (B, L, N)
218
- # D : (ED)
219
-
220
- # y : (B, L, ED)
221
 
222
- deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
223
- deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
224
 
225
- BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
226
 
227
  hs = pscan(deltaA, BX)
228
 
229
- y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
230
 
231
  y = y + D * x
232
 
233
  return y
234
 
235
  def selective_scan_seq(self, x, delta, A, B, C, D):
236
- # x : (B, L, ED)
237
- # Δ : (B, L, ED)
238
- # A : (ED, N)
239
- # B : (B, L, N)
240
- # C : (B, L, N)
241
- # D : (ED)
242
-
243
- # y : (B, L, ED)
244
 
245
  _, L, _ = x.shape
246
 
247
- deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
248
- deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
249
 
250
- BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
251
 
252
  h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
253
  hs = []
@@ -256,103 +203,74 @@ class MambaBlock(nn.Module):
256
  h = deltaA[:, t] * h + BX[:, t]
257
  hs.append(h)
258
 
259
- hs = torch.stack(hs, dim=1) # (B, L, ED, N)
260
 
261
- y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
262
 
263
  y = y + D * x
264
 
265
  return y
266
 
267
- # -------------------------- inference -------------------------- #
268
- """
269
- Concerning auto-regressive inference
270
-
271
- The cool part of using Mamba : the cost of each inference step is constant with respect to sequence length
272
- We just have to keep in cache, for each layer, two things :
273
- - the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
274
- - the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
275
- (d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
276
- (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)
277
-
278
- Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
279
- h is (B, ED, N), and inputs is (B, ED, d_conv-1)
280
- The MambaBlock.step() receives this cache and, along with the output, also returns the updated cache for the next call.
281
-
282
- The cache object is initialized as follows : (None, torch.zeros()).
283
- When h is None, the selective scan function detects it and starts with h=0.
284
- The torch.zeros() isn't a problem (it's the same as just feeding the input, because the conv1d is padded)
285
-
286
- As we need one such cache variable per layer, we store a caches object, which is simply a list of cache objects. (See mamba_lm.py)
287
- """
288
 
289
  def step(self, x, cache):
290
- # x : (B, D)
291
- # cache : (h, inputs)
292
- # h : (B, ED, N)
293
- # inputs : (B, ED, d_conv-1)
294
-
295
- # y : (B, D)
296
- # cache : (h, inputs)
297
 
298
  h, inputs = cache
299
 
300
- xz = self.in_proj(x) # (B, 2*ED)
301
- x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)
302
 
303
- # x branch
304
  x_cache = x.unsqueeze(2)
305
- x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)
306
 
307
  x = F.silu(x)
308
  y, h = self.ssm_step(x, h)
309
 
310
- # z branch
311
  z = F.silu(z)
312
 
313
  output = y * z
314
  output = self.out_proj(output) # (B, D)
315
 
316
- # prepare cache for next call
317
  inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
318
  cache = (h, inputs)
319
 
320
  return output, cache
321
 
322
  def ssm_step(self, x, h):
323
- # x : (B, ED)
324
- # h : (B, ED, N)
325
 
326
- # y : (B, ED)
327
- # h : (B, ED, N)
328
-
329
- A = -torch.exp(self.A_log.float()) # (ED, N) # todo : ne pas le faire tout le temps, puisque c'est indépendant de la timestep
330
  D = self.D.float()
331
- # TODO remove .float()
332
 
333
- deltaBC = self.x_proj(x) # (B, dt_rank+2*N)
334
 
335
- delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
336
- delta = F.softplus(self.dt_proj(delta)) # (B, ED)
337
 
338
- deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
339
- deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)
340
 
341
- BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)
342
 
343
  if h is None:
344
- h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
345
 
346
- h = deltaA * h + BX # (B, ED, N)
347
 
348
- y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)
349
 
350
  y = y + D * x
351
 
352
- # todo : why h.squeeze(1) ??
353
  return y, h.squeeze(1)
354
 
355
- # taken straight from https://github.com/johnma2006/mamba-minimal/blob/master/model.py
356
  class RMSNorm(nn.Module):
357
  def __init__(self, d_model: int, eps: float = 1e-5):
358
  super().__init__()
 
8
 
9
  from pscan import pscan
10
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  @dataclass
14
  class MambaConfig:
15
+ d_model: int
16
  n_layers: int
17
  dt_rank: Union[int, str] = 'auto'
18
+ d_state: int = 16
19
+ expand_factor: int = 2
20
  d_conv: int = 4
21
 
22
  dt_min: float = 0.001
23
  dt_max: float = 0.1
24
+ dt_init: str = "random"
25
  dt_scale: float = 1.0
26
  dt_init_floor = 1e-4
27
 
28
  bias: bool = False
29
  conv_bias: bool = True
30
 
31
+ pscan: bool = True
32
 
33
  def __post_init__(self):
34
+ self.d_inner = self.expand_factor * self.d_model
35
 
36
  if self.dt_rank == 'auto':
37
  self.dt_rank = math.ceil(self.d_model / 16)
 
43
  self.config = config
44
 
45
  self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
46
+
47
 
48
  def forward(self, x):
49
+
 
 
50
 
51
  for layer in self.layers:
52
  x = layer(x)
53
 
54
+
55
 
56
  return x
57
 
58
  def step(self, x, caches):
59
+
 
 
 
 
60
 
61
  for i, layer in enumerate(self.layers):
62
  x, caches[i] = layer.step(x, caches[i])
 
71
  self.norm = RMSNorm(config.d_model)
72
 
73
  def forward(self, x):
74
+
 
 
75
 
76
  output = self.mixer(self.norm(x)) + x
77
  return output
78
 
79
  def step(self, x, cache):
80
+
 
 
 
 
 
 
81
 
82
  output, cache = self.mixer.step(self.norm(x), cache)
83
  output = output + x
 
89
 
90
  self.config = config
91
 
92
+
93
  self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
94
 
95
  self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
 
97
  groups=config.d_inner,
98
  padding=config.d_conv - 1)
99
 
100
+
101
  self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
102
 
103
+
104
  self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
105
 
106
+
 
107
  dt_init_std = config.dt_rank**-0.5 * config.dt_scale
108
  if config.dt_init == "constant":
109
  nn.init.constant_(self.dt_proj.weight, dt_init_std)
 
112
  else:
113
  raise NotImplementedError
114
 
115
+
116
  dt = torch.exp(
117
  torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
118
  ).clamp(min=config.dt_init_floor)
119
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
120
  with torch.no_grad():
121
  self.dt_proj.bias.copy_(inv_dt)
122
+
 
 
 
123
  A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
124
+ self.A_log = nn.Parameter(torch.log(A))
125
  self.D = nn.Parameter(torch.ones(config.d_inner))
126
 
127
+
128
  self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
129
 
130
  def forward(self, x):
 
131
 
 
132
 
133
  _, L, _ = x.shape
134
 
135
+ xz = self.in_proj(x)
136
+ x, z = xz.chunk(2, dim=-1)
137
 
138
+
139
+ x = x.transpose(1, 2)
140
+ x = self.conv1d(x)[:, :, :L]
141
+ x = x.transpose(1, 2)
142
 
143
  x = F.silu(x)
144
  y = self.ssm(x)
145
 
146
+
147
  z = F.silu(z)
148
 
149
  output = y * z
150
+ output = self.out_proj(output)
151
 
152
  return output
153
 
154
  def ssm(self, x):
155
+
156
 
157
+ A = -torch.exp(self.A_log.float())
 
 
158
  D = self.D.float()
159
+
160
 
161
+ deltaBC = self.x_proj(x)
162
 
163
  delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
164
+ delta = F.softplus(self.dt_proj(delta))
165
 
166
  if self.config.pscan:
167
  y = self.selective_scan(x, delta, A, B, C, D)
 
171
  return y
172
 
173
  def selective_scan(self, x, delta, A, B, C, D):
174
+
 
 
 
 
 
 
 
175
 
176
+ deltaA = torch.exp(delta.unsqueeze(-1) * A)
177
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
178
 
179
+ BX = deltaB * (x.unsqueeze(-1))
180
 
181
  hs = pscan(deltaA, BX)
182
 
183
+ y = (hs @ C.unsqueeze(-1)).squeeze(3)
184
 
185
  y = y + D * x
186
 
187
  return y
188
 
189
  def selective_scan_seq(self, x, delta, A, B, C, D):
190
+
 
 
 
 
 
 
 
191
 
192
  _, L, _ = x.shape
193
 
194
+ deltaA = torch.exp(delta.unsqueeze(-1) * A)
195
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
196
 
197
+ BX = deltaB * (x.unsqueeze(-1))
198
 
199
  h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
200
  hs = []
 
203
  h = deltaA[:, t] * h + BX[:, t]
204
  hs.append(h)
205
 
206
+ hs = torch.stack(hs, dim=1)
207
 
208
+ y = (hs @ C.unsqueeze(-1)).squeeze(3)
209
 
210
  y = y + D * x
211
 
212
  return y
213
 
214
+
215
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  def step(self, x, cache):
218
+
 
 
 
 
 
 
219
 
220
  h, inputs = cache
221
 
222
+ xz = self.in_proj(x)
223
+ x, z = xz.chunk(2, dim=1)
224
 
225
+
226
  x_cache = x.unsqueeze(2)
227
+ x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1]
228
 
229
  x = F.silu(x)
230
  y, h = self.ssm_step(x, h)
231
 
232
+
233
  z = F.silu(z)
234
 
235
  output = y * z
236
  output = self.out_proj(output) # (B, D)
237
 
238
+
239
  inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
240
  cache = (h, inputs)
241
 
242
  return output, cache
243
 
244
  def ssm_step(self, x, h):
245
+
 
246
 
247
+ A = -torch.exp(self.A_log.float())
 
 
 
248
  D = self.D.float()
249
+
250
 
251
+ deltaBC = self.x_proj(x)
252
 
253
+ delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
254
+ delta = F.softplus(self.dt_proj(delta))
255
 
256
+ deltaA = torch.exp(delta.unsqueeze(-1) * A)
257
+ deltaB = delta.unsqueeze(-1) * B.unsqueeze(1)
258
 
259
+ BX = deltaB * (x.unsqueeze(-1))
260
 
261
  if h is None:
262
+ h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)
263
 
264
+ h = deltaA * h + BX
265
 
266
+ y = (h @ C.unsqueeze(-1)).squeeze(2)
267
 
268
  y = y + D * x
269
 
270
+
271
  return y, h.squeeze(1)
272
 
273
+
274
  class RMSNorm(nn.Module):
275
  def __init__(self, d_model: int, eps: float = 1e-5):
276
  super().__init__()