flpelerin committed on
Commit
d5344c9
·
1 Parent(s): 8b5543e

Delete mamba.py

Browse files
Files changed (1) hide show
  1. mamba.py +0 -343
mamba.py DELETED
@@ -1,343 +0,0 @@
1
- """Simple, minimal implementation of Mamba in one file of PyTorch.
2
-
3
- Source: https://github.com/johnma2006/mamba-minimal
4
-
5
- Suggest reading the following before/while reading the code:
6
- [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
7
- https://arxiv.org/abs/2312.00752
8
- [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
9
- https://srush.github.io/annotated-s4
10
-
11
- Glossary:
12
- b: batch size (`B` in Mamba paper [1] Algorithm 2)
13
- l: sequence length (`L` in [1] Algorithm 2)
14
- d or d_model: hidden dim
15
- n or d_state: latent state dim (`N` in [1] Algorithm 2)
16
- expand: expansion factor (`E` in [1] Section 3.4)
17
- d_in or d_inner: d * expand (`D` in [1] Algorithm 2)
18
- A, B, C, D: state space parameters (See any state space representation formula)
19
- (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
20
- Δ or delta: input-dependent step size
21
- dt_rank: rank of Δ (See [1] Section 3.6 "Parameterization of ∆")
22
-
23
- """
24
- from __future__ import annotations
25
- import math
26
- import json
27
- import torch
28
- import torch.nn as nn
29
- import torch.nn.functional as F
30
- from dataclasses import dataclass
31
- from einops import rearrange, repeat, einsum
32
-
33
-
34
@dataclass
class ModelArgs:
    """Hyperparameters for a Mamba model (names follow [1] Algorithm 2)."""
    d_model: int                  # hidden dim (`d`)
    n_layer: int                  # number of residual blocks
    vocab_size: int               # padded up to pad_vocab_size_multiple in __post_init__
    d_state: int = 16             # latent state dim (`n`)
    expand: int = 2               # expansion factor (`E`)
    # NOTE: was `Union[int, str]`, but `Union` is never imported in this file
    # (there is no `typing` import); PEP 604 syntax needs no import here.
    dt_rank: int | str = 'auto'   # rank of Δ; 'auto' -> ceil(d_model / 16)
    d_conv: int = 4               # depthwise conv kernel size
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        # Derived size: d_inner = expand * d_model (`d_in` in the glossary).
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round vocab_size up to the next multiple of pad_vocab_size_multiple.
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
56
-
57
-
58
class Mamba(nn.Module):
    """Full Mamba language model: embedding -> n_layer residual blocks -> norm -> LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList(ResidualBlock(args) for _ in range(args.n_layer))
        self.norm_f = RMSNorm(args.d_model)

        # Output projection shares weights with the input embedding.
        # See "Weight Tying" paper
        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """
        hidden = self.embedding(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        return self.lm_head(self.norm_f(hidden))

    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into model.

        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        Returns:
            model: Mamba model with weights loaded

        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def fetch_config(name):
            resolved = cached_file(name, CONFIG_NAME,
                                   _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved))

        def fetch_weights(name, device=None, dtype=None):
            resolved = cached_file(name, WEIGHTS_NAME,
                                   _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved, weights_only=True, map_location='cpu', mmap=True)

        cfg = fetch_config(pretrained_model_name)
        model = Mamba(ModelArgs(
            d_model=cfg['d_model'],
            n_layer=cfg['n_layer'],
            vocab_size=cfg['vocab_size'],
        ))

        # Checkpoint keys carry a 'backbone.' prefix that this module layout does not use.
        raw = fetch_weights(pretrained_model_name)
        model.load_state_dict({k.replace('backbone.', ''): v for k, v in raw.items()})

        return model
143
-
144
-
145
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around a MambaBlock: x + mixer(norm(x)).

    Note: the official repo chains residual blocks that look like
        [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
    (first Add a no-op) purely so the fused Add->Norm kernel can be used.
    The [Norm -> Mamba -> Add] ordering here is numerically equivalent and simpler.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
        """
        residual = x
        return residual + self.mixer(self.norm(x))
177
-
178
-
179
class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].

        NOTE: rewritten to use only torch-native ops (transpose / Tensor.repeat /
        broadcasting / torch.einsum) instead of the third-party `einops` package,
        which this block needed for nothing torch cannot do itself.
        """
        super().__init__()
        self.args = args

        # Projects d -> 2*d_in; the halves become the SSM path and the gate (res).
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # Depthwise (groups == channels) causal conv over the sequence dimension.
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # A stored in log space; each of the d_inner rows is [1, 2, ..., d_state]
        # (was einops repeat 'n -> d n').
        A = torch.arange(1, args.d_state + 1).repeat(args.d_inner, 1)  # (d_in, n)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)

        Returns:
            output: shape (b, l, d)

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        # Conv1d wants (b, channels, l); trim the right-side padding back to length l
        # so the conv stays causal.
        x = x.transpose(1, 2)            # (b, d_in, l)
        x = self.conv1d(x)[:, :, :l]
        x = x.transpose(1, 2)            # (b, l, d_in)

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)  # gate the SSM output with the residual branch

        return self.out_proj(y)

    def ssm(self, x):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)

        Returns:
            output: shape (b, l, d_in)

        """
        (d_in, n) = self.A_log.shape

        # A, D are input-independent (see [1] Section 3.5.2 "Interpretation of A");
        # Δ, B, C are input-dependent — the **selective** part of Mamba.
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, u, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Classic discrete state space formula:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
        except B, C (and the step size delta, used for discretization) depend on the input.

        Args:
            u: shape (b, l, d_in)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        """
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B):
        # - A via zero-order hold (ZOH), see Section 2 Equation 4 in [1]
        # - B via a simplified Euler step (per the authors, the simplification on B
        #   barely changes performance).
        # Broadcasting replaces the former einops.einsum outer products.
        deltaA = torch.exp(delta.unsqueeze(-1) * A)             # (b, l, d_in, n)
        deltaB_u = (delta * u).unsqueeze(-1) * B.unsqueeze(2)   # (b, l, d_in, n)

        # Sequential scan (see scan_SSM() in The Annotated S4 [2]); the official
        # implementation uses a much faster hardware-aware parallel scan.
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            ys.append(torch.einsum('bdn,bn->bd', x, C[:, i, :]))
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        return y + u * D
328
-
329
-
330
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescale by 1/RMS(x) over the last dim, then
    apply a learned per-channel weight (no mean subtraction, no bias)."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # eps keeps rsqrt finite when the input is all zeros.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
343
-