BobShan committed on
Commit
167d698
·
verified ·
1 Parent(s): 9931ed0

Add files using upload-large-folder tool

Browse files
mamba_ssm/utils/generation.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
+
16
+
17
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    # Maximum sequence length the cache buffers were sized for.
    max_seqlen: int
    # Maximum batch size the cache buffers were sized for.
    max_batch_size: int
    # Number of tokens already processed, i.e. the current decoding position.
    seqlen_offset: int = 0
    # Offset into the batch dimension; not referenced elsewhere in this file.
    batch_size_offset: int = 0
    # Per-layer state buffers (filled by model.allocate_inference_cache in the
    # CUDA-graph path; see update_graph_cache).
    key_value_memory_dict: dict = field(default_factory=dict)
    # Optional per-sample lengths, shape (batch,); written by the CUDA-graph
    # code (capture_graph) on each replay.
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Rewind this object for a fresh generation of up to `max_seqlen` tokens.

        The `key_value_memory_dict` buffers are intentionally kept (reused in
        place); only sizes, the offset and the per-sample lengths are cleared.
        """
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()
35
+
36
+
37
def modify_logits_for_min_p_filtering(logits, min_p):
    """In-place: set every logit strictly below `min_p` to -inf.

    No-op unless 0 < min_p < 1 (the threshold is interpreted as a value in
    logit space by the caller; see `sample`).
    """
    if not 0.0 < min_p < 1.0:
        return
    logits.masked_fill_(logits < min_p, float("-Inf"))
43
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """In-place: set everything below the k-th largest logit (per row) to -inf."""
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(logits < kth_largest, float("-Inf"))
49
+
50
+
51
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """In-place nucleus filtering: set logits outside the top-p mass to -inf.

    No-op unless 0 < top_p < 1. Expects 2D logits (batch, vocab): the scatter
    below hard-codes dim 1.
    """
    if not 0.0 < top_p < 1.0:
        return
    # Sort ascending so the tokens to drop are the leading prefix of the
    # cumulative-probability curve.
    ascending_logits, ascending_idx = logits.sort(dim=-1, descending=False)
    cum_probs = ascending_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop tokens whose cumulative mass is still <= (1 - top_p); the token that
    # crosses the boundary (and everything above it) is kept.
    drop_sorted = cum_probs <= (1 - top_p)
    # Undo the sort: map the drop mask back to original vocabulary positions.
    drop = drop_sorted.scatter(1, ascending_idx, drop_sorted)
    logits.masked_fill_(drop, float("-inf"))
67
+
68
+
69
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
    """Apply repetition penalty in place. See https://arxiv.org/abs/1909.05858

    Arguments:
        logits: (batch_size, vocab_size)
        prev_output_tokens: (batch_size, seq_len) token ids to penalize
    Returns:
        The (mutated) `logits` tensor.
    """
    if repetition_penalty == 1.0:
        return logits
    previous_scores = logits.gather(1, prev_output_tokens)
    # Negative scores must be *multiplied* by the penalty to become less
    # likely; positive ones are divided.
    penalized = torch.where(
        previous_scores < 0,
        previous_scores * repetition_penalty,
        previous_scores / repetition_penalty,
    )
    logits.scatter_(1, prev_output_tokens, penalized)
    return logits
81
+
82
+
83
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Draw one token id per row of `logits`.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
        top_k: keep only the k highest logits (1 = greedy argmax, 0 = no limit)
        top_p: nucleus threshold, applied after top-k when both are active
        min_p: minimum-value filter, only used when top_k == 0
        temperature: softmax temperature applied before sampling
    Returns:
        LongTensor of shape (batch_size,)
    """
    # Greedy decoding: no filtering or randomness needed.
    if top_k == 1:
        return logits.argmax(dim=-1)
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    if top_k > 0:
        k = min(top_k, logits.size(-1))  # never request more than vocab_size
        top_logits, top_indices = torch.topk(logits, k, dim=-1)
        if temperature != 1.0:
            top_logits /= temperature
        modify_logits_for_top_p_filtering(top_logits, top_p)
        choice = torch.multinomial(
            torch.softmax(top_logits, dim=-1), num_samples=1
        ).squeeze(dim=-1)
        # Map positions within the top-k slice back to vocabulary ids.
        return top_indices[torch.arange(top_indices.shape[0], device=top_indices.device), choice]
    if min_p > 0.0:
        filtered = logits.clone()  # keep the caller's logits untouched
        # NOTE(review): `[..., 0].item()` reads a single scalar, which assumes
        # batch_size == 1 and treats the first logit as the maximum — confirm.
        threshold = filtered[..., 0].item() * min_p
        modify_logits_for_min_p_filtering(filtered, threshold)
        if temperature != 1.0:
            filtered /= temperature
        return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1).squeeze(dim=-1)
    # Pure top-p sampling; clone (or divide into a fresh tensor) so the
    # in-place top-p filter leaves the original logits intact.
    filtered = logits / temperature if temperature != 1.0 else logits.clone()
    modify_logits_for_top_p_filtering(filtered, top_p)
    return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1).squeeze(dim=-1)
118
+
119
+
120
@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    cg=False,
    enable_timing=False,
    output_scores=False,
    streamer: Optional[TextStreamer] = None
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        top_k, top_p, min_p, temperature: forwarded to `sample` each step
        repetition_penalty: 1.0 disables the penalty
        eos_token_id: stop once ALL sequences emit this token in the same step
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
        vocab_size: if set, logits are truncated to [..., :vocab_size] before sampling
        cg: use cached CUDA graphs for the single-token decoding steps
        enable_timing: print total prompt+decode wall time (CUDA events)
        output_scores: keep a clone of the pre-sampling logits at every step
        streamer: optional HF TextStreamer; receives the prompt and each new token
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    # Push the prompt to the streamer immediately so it can display it.
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        # CUDA-graph path: build (or reuse) captured graphs sized for this
        # (batch_size, max_length); the cache lives on the model object.
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        """Run one model step and return logits for the last position only."""
        # seqlen_offset == 0 means we are still processing the prompt.
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            # Single decoding step: every sample sits at the same position.
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            # Eager path (always used for the prompt, and for all steps if cg=False).
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            # Replay the pre-captured CUDA graph for this (batch, seqlen) shape.
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        # Optionally restrict sampling to the first `vocab_size` entries
        # (e.g. when the embedding was padded for efficiency).
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        """Sample the next token, or take it from teacher_outputs (teacher forcing)."""
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        """Stop on all-EOS or when max_length would be exceeded; never before step 1."""
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    # NOTE(review): CUDA events are constructed even when enable_timing is
    # False — harmless then, but this assumes a CUDA-capable build; confirm
    # for CPU-only runs.
    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)

    if enable_timing:
        start.record()
    scores, sequences = [], [input_ids]
    # sequences_cat accumulates every token generated so far; it is only
    # maintained when the repetition penalty actually needs it.
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        logits = get_logits(sequences[-1], inference_params)
        if output_scores:
            scores.append(logits.clone())
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(logits, inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                logits, sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    # top_k == 1 is the greedy short-circuit in `sample`, so mirror HF's output class choice.
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
245
+
246
+
247
class GenerationMixin:
    """Mixin that gives a model a HF-style `generate` entry point backed by `decode`."""

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Models must override this to provide per-layer state buffers for decoding.
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        min_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        """Generate token ids; returns the full output object or just `.sequences`."""
        result = decode(
            input_ids,
            self,
            max_length,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            temperature=temperature,
            output_scores=output_scores,
            **kwargs,
        )
        if not output_scores:
            result.scores = None
        return result if return_dict_in_generate else result.sequences
269
+
270
+
271
@dataclass
class DecodingCGCache:
    """Holds captured CUDA graphs (and their shared state) for fast decoding.

    `device`, `dtype` and `mempool` are un-annotated, so the dataclass
    machinery treats them as plain class attributes rather than `__init__`
    fields; they are (re)assigned in `update_graph_cache`.
    """

    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    # Maps (batch_size, decoding_seqlen) -> graph-replay callable from capture_graph.
    callables: dict = field(default_factory=dict)
    mempool = None
    # Shared InferenceParams whose buffers the captured graphs read/write.
    inference_params: Optional[InferenceParams] = None
    # Dispatch function installed by update_graph_cache.
    run: Optional[Callable] = None
281
+
282
+
283
@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    dtype=None,
    n_warmups=2,
):
    """Create or refresh the CUDA-graph cache used for single-step decoding.

    The cache (a DecodingCGCache) is fully invalidated and its buffers
    reallocated whenever the model's device/dtype change or the requested
    batch_size/max_seqlen exceed what it was built for; otherwise only the
    missing (batch_size, decoding_seqlen) graphs are captured.

    Arguments:
        model: must expose parameters() and allocate_inference_cache(...)
        cache: existing DecodingCGCache or None to start fresh
        batch_size, seqlen_og, max_seqlen: generation geometry
        decoding_seqlens: per-step token counts to capture graphs for
        dtype: cache dtype; defaults to the model's parameter dtype
        n_warmups: warmup iterations before each graph capture
    Returns:
        The (possibly new) cache with `.run` and `.inference_params` set.
    """
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        # Drop references before reallocating so the old graph memory can be freed.
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        # One shared memory pool so all captured graphs can reuse allocations.
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        # Route to the graph captured for this exact (batch, step-length) shape.
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache
341
+
342
+
343
def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    """Capture one CUDA graph for a fixed (batch_size, decoding_seqlen) step.

    Returns a `run(new_input_ids, new_position_ids, seqlen)` callable that
    copies the inputs into the graph's static buffers, replays the graph, and
    returns a clone of the logits. The static `input_ids`/`position_ids`
    tensors and `inference_params` buffers are captured by reference, which is
    why inputs must be copied in rather than passed.
    """
    device = next(iter(model.parameters())).device
    # Static placeholder tensors; the captured graph always reads from these.
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    # Capture at the worst-case offset so the graph is valid for any step.
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        """Replay the captured graph with fresh inputs; returns cloned logits."""
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        # Clone: the graph overwrites `logits` in place on every replay.
        return logits.clone()

    # Restore the caller's offset; capture temporarily moved it to the end.
    inference_params.seqlen_offset = seqlen_offset_og
    return run
mamba_ssm/utils/hf.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import torch
4
+
5
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6
+ from transformers.utils.hub import cached_file
7
+
8
+
9
def load_config_hf(model_name):
    """Fetch and parse `model_name`'s JSON config from the Hugging Face Hub.

    Arguments:
        model_name: repo id (or local path) understood by transformers' cached_file.
    Returns:
        The parsed config as a dict.
    Raises:
        FileNotFoundError: if the config file cannot be resolved.
    """
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    # cached_file returns None (rather than raising) when the entry is missing;
    # fail with a clear error instead of the TypeError from open(None).
    if resolved_archive_file is None:
        raise FileNotFoundError(f"Config file {CONFIG_NAME!r} not found for {model_name!r}")
    # Context manager: the original `json.load(open(...))` leaked the file handle.
    with open(resolved_archive_file) as f:
        return json.load(f)
12
+
13
+
14
def load_state_dict_hf(model_name, device=None, dtype=None):
    """Load `model_name`'s checkpoint weights from the Hugging Face Hub.

    Arguments:
        model_name: repo id (or local path) understood by transformers' cached_file.
        device: final device for the tensors (None leaves torch.load's placement).
        dtype: optional dtype to convert every tensor to before the device move.
    Returns:
        dict mapping parameter names to tensors.
    """
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Bug fix: the original function `return`ed on the line above, making the
    # conversion below unreachable, so the `dtype` argument was silently ignored.
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    # .to(device=None) is a no-op, so this is safe when no device was requested.
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
mamba_ssm/utils/torch.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from functools import partial
3
+ from typing import Callable
4
+
5
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    """Adapt an AMP decorator so old `torch.cuda.amp`-style call sites keep working.

    When `cuda_amp_deprecated` is True, the wrapped decorator is the new
    `torch.amp` variant, which needs an explicit `device_type`; we force it
    to "cuda" on every call.
    """
    def _shim(*args, **kwargs):
        call_kwargs = dict(kwargs)
        if cuda_amp_deprecated:
            call_kwargs["device_type"] = "cuda"
        return dec(*args, **call_kwargs)
    return _shim
11
+
12
+
13
# Newer PyTorch moved the AMP autograd helpers to `torch.amp` (with a required
# `device_type` argument) and deprecated the `torch.cuda.amp` variants; pick
# whichever this build provides.
if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
    from torch.amp import custom_bwd, custom_fwd  # type: ignore[attr-defined]
    deprecated = True
else:
    from torch.cuda.amp import custom_bwd, custom_fwd
    deprecated = False

# Rebind so callers can use the old calling convention regardless of version.
custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)