phoebeklett committed on
Commit
c09aabb
·
1 Parent(s): 75ee012

Delete modeling_mpt.py

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +0 -837
modeling_mpt.py DELETED
@@ -1,837 +0,0 @@
1
- # Adapted from https://github.com/mosaicml/llm-foundry
2
- # Classes changed: MPTModel, MPTForCausalLM
3
- # SPDX-License-Identifier: Apache-2.0
4
-
5
- """A simple, flexible implementation of a GPT model.
6
-
7
- Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
8
- """
9
-
10
- import math
11
- import warnings
12
- from typing import List, Optional, Tuple, Union
13
- import torch
14
- import torch.nn as nn
15
- import torch.nn.functional as F
16
- from torch.linalg import vector_norm
17
- import faiss
18
- from einops import rearrange
19
- from composer.utils import dist
20
- from omegaconf import DictConfig
21
-
22
- from transformers import (PreTrainedModel, PreTrainedTokenizer,
23
- PreTrainedTokenizerFast)
24
- from transformers.modeling_outputs import (BaseModelOutputWithPast,
25
- CausalLMOutputWithPast)
26
- from llmfoundry.models.layers.custom_embedding import SharedEmbedding
27
- from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
28
- from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
29
-
30
- from configuration import ExtendedMPTConfig
31
- from attention import attn_bias_shape, build_attn_bias
32
- from blocks import MPTBlock
33
- from utils import instantiate_from_config
34
-
35
- Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
36
-
37
class MPTPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the extended MPT classes below.

    Only class-level settings are declared here; there is no behaviour of
    its own.
    """

    # Configuration class associated with this model family.
    config_class = ExtendedMPTConfig
    # Prefix under which the base model is stored
    # (presumably consumed by the HuggingFace loading code -- confirm).
    base_model_prefix = 'model'
    # Module class names that should not be split
    # (presumably across devices by HuggingFace -- confirm).
    _no_split_modules = ['MPTBlock']
41
-
42
- class ExtendedMPTModel(MPTPreTrainedModel):
43
-
44
def __init__(self, config: ExtendedMPTConfig):
    """Build embeddings, transformer blocks, final norm and bias bookkeeping.

    Args:
        config: Model configuration. It is validated first, and its
            ``init_device`` is rewritten in place when it equals ``'mixed'``.

    Raises:
        NotImplementedError: If ``config.norm_type`` is not a key of
            ``NORM_CLASS_REGISTRY``.
    """
    config._validate_config()
    super().__init__(config)

    # Attention settings, copied onto the module for quick access.
    attn_cfg = config.attn_config
    self.attn_impl = attn_cfg['attn_impl']
    self.prefix_lm = attn_cfg['prefix_lm']
    self.attn_uses_sequence_id = attn_cfg['attn_uses_sequence_id']
    self.alibi = attn_cfg['alibi']
    self.alibi_bias_max = attn_cfg['alibi_bias_max']

    # Settings specific to the external-memory ("active externalism") path.
    self.mask_by_sim = attn_cfg['mask_by_sim']
    self.sim_threshold = attn_cfg['sim_threshold']
    self.topk = attn_cfg['topk']
    self.use_active_externalism = attn_cfg['use_active_externalism']
    self.use_active_externalism_by_layer = config.use_active_externalism_by_layer

    # 'mixed': local rank 0 initialises on cpu, every other rank on meta.
    if config.init_device == 'mixed':
        config.init_device = 'cpu' if dist.get_local_rank() == 0 else 'meta'

    norm_key = config.norm_type.lower()
    if norm_key not in NORM_CLASS_REGISTRY:
        norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
        raise NotImplementedError(
            f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).'
        )
    norm_class = NORM_CLASS_REGISTRY[norm_key]

    # CogView (https://arxiv.org/abs/2105.13290) and GLM-130B
    # (https://arxiv.org/abs/2210.02414) both report this helping with
    # stabilizing training.
    self.embedding_fraction = config.embedding_fraction

    device = config.init_device
    self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=device)
    if not self.alibi:
        # Learned position embeddings are only created when ALiBi is off.
        self.wpe = torch.nn.Embedding(config.max_seq_len,
                                      config.d_model,
                                      device=device)
    self.emb_drop = nn.Dropout(config.emb_pdrop)
    layers = [
        MPTBlock(device=config.init_device, **config.to_dict())
        for _ in range(config.n_layers)
    ]
    self.blocks = nn.ModuleList(layers)
    self.norm_f = norm_class(config.d_model, device=config.init_device)

    if config.init_device != 'meta':
        print(
            f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
        )
        self.apply(self.param_init_fn)

    self.is_causal = not self.prefix_lm

    # Attention-bias state; the tensors themselves are built lazily in
    # `_attn_bias`.
    self._attn_bias_initialized = False
    self.attn_bias = None
    self.attn_bias_shape = attn_bias_shape(
        self.attn_impl,
        config.n_heads,
        config.max_seq_len,
        self.alibi,
        prefix_lm=self.prefix_lm,
        causal=self.is_causal,
        use_sequence_id=self.attn_uses_sequence_id,
    )
    self._attn_bias_ae_initialized = False
    self.attn_bias_ae = None

    if self.config.no_bias:
        # Drop every bias parameter found on any sub-module.
        for module in self.modules():
            has_bias_param = hasattr(module, 'bias') and isinstance(
                module.bias, nn.Parameter)
            if not has_bias_param:
                continue
            if self.config.verbose:
                warnings.warn(
                    f'Removing bias ({module.bias}) from {module}.')
            module.register_parameter('bias', None)

    # Print verbose info
    if config.verbose and config.verbose > 2:
        print(self)
    init_config = self.config.init_config
    if 'verbose' not in init_config:
        init_config['verbose'] = self.config.verbose
    if self.config.init_config['verbose'] > 1:
        init_fn_name = self.config.init_config['name']
        warnings.warn(f'Using {init_fn_name} initialization.')
134
-
135
def get_input_embeddings(self):
    """Return the token-embedding module stored in ``self.wte``."""
    token_embedding = self.wte
    return token_embedding
137
-
138
def set_input_embeddings(self, value: nn.Embedding):
    """Replace the token-embedding module (``self.wte``) with ``value``."""
    setattr(self, 'wte', value)
140
-
141
- @torch.no_grad()
142
- def _attn_bias(
143
- self,
144
- device,
145
- dtype,
146
- attention_mask: Optional[torch.ByteTensor] = None,
147
- prefix_mask: Optional[torch.ByteTensor] = None,
148
- sequence_id: Optional[torch.LongTensor] = None,
149
- seq_len: Optional[int] = None,
150
- use_active_externalism:bool=None,
151
- topk=None,
152
- ):
153
- if not self._attn_bias_initialized:
154
- if self.attn_bias_shape:
155
- self.attn_bias = torch.zeros(self.attn_bias_shape,
156
- device=device,
157
- dtype=dtype)
158
- self.attn_bias = build_attn_bias(
159
- self.attn_impl,
160
- self.config.n_heads,
161
- self.config.max_seq_len,
162
- device=device,
163
- dtype=dtype,
164
- attn_bias = self.attn_bias,
165
- causal=self.is_causal,
166
- alibi=self.alibi,
167
- alibi_bias_max=self.alibi_bias_max
168
- )
169
- self._attn_bias_initialized = True
170
-
171
- if use_active_externalism:
172
- self.attn_bias_ae = build_attn_bias(
173
- self.attn_impl,
174
- self.config.n_heads,
175
- seq_len,
176
- device=device,
177
- dtype=dtype,
178
- causal=self.is_causal,
179
- alibi=self.alibi,
180
- alibi_bias_max=self.alibi_bias_max,
181
- for_ae=use_active_externalism,
182
- topk=topk
183
- )
184
-
185
- self._attn_bias_ae_initialized = True
186
-
187
- # flash does not support prefix_lm and will incorporate any
188
- # attention_mask inside the attention module
189
- if self.attn_impl == 'flash':
190
- return self.attn_bias, attention_mask
191
-
192
- if self.attn_bias is not None:
193
- # .to(*args, **kwargs) is a no-op if tensor is already on
194
- # specified device or of specificed dtype
195
- self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
196
-
197
- attn_bias = self.attn_bias
198
-
199
- if self.attn_bias_ae is not None:
200
- self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
201
- attn_bias_ae = self.attn_bias_ae
202
-
203
- # If using torch or triton, we incorporate the prefix_mask (if appropriate)
204
- if self.prefix_lm:
205
- assert isinstance(attn_bias, torch.Tensor) # pyright
206
- assert isinstance(prefix_mask, torch.Tensor) # pyright
207
- attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
208
-
209
- # If using torch or triton, we incorporate sequence_id (if appropriate)
210
- if self.attn_uses_sequence_id and sequence_id is not None:
211
- assert isinstance(attn_bias, torch.Tensor) # pyright
212
- attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
213
-
214
- # If using torch or triton, we incorporate attention_mask. This will output
215
- # None in place of attention_mask since it will not be further needed in the
216
- # attention modules.
217
- if attention_mask is not None:
218
- s_k = attention_mask.shape[-1]
219
- if attn_bias is None:
220
- attn_bias = torch.zeros((1, 1, 1, s_k),
221
- device=device,
222
- dtype=dtype)
223
- else:
224
- # clamp to 0 necessary for torch 2.0 compile()
225
- _s_k = max(0, attn_bias.size(-1) - s_k)
226
- attn_bias = attn_bias[:, :, :, _s_k:]
227
- if prefix_mask is not None and (attention_mask.shape !=
228
- prefix_mask.shape):
229
- raise ValueError(
230
- f'attention_mask shape={attention_mask.shape} ' +
231
- f'and prefix_mask shape={prefix_mask.shape} are not equal.')
232
- min_val = torch.finfo(attn_bias.dtype).min
233
- attn_bias = attn_bias.masked_fill(
234
- ~attention_mask.view(-1, 1, 1, s_k), min_val)
235
-
236
- return attn_bias, attn_bias_ae, None
237
-
238
- def _apply_prefix_mask(self, attn_bias: torch.Tensor,
239
- prefix_mask: torch.Tensor):
240
- s_k, s_q = attn_bias.shape[-2:]
241
- if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len):
242
- raise ValueError(
243
- 'attn_bias does not match the expected shape. ' +
244
- f'The last two dimensions should both be {self.config.max_length} '
245
- + f'but are {s_k} and {s_q}.')
246
- seq_len = prefix_mask.shape[-1]
247
- if seq_len > self.config.max_seq_len:
248
- raise ValueError(
249
- f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
250
- )
251
-
252
- # select seq_len subset of attn mask
253
- attn_bias = attn_bias[..., :seq_len, :seq_len]
254
-
255
- # Mix the causal max and the bidirectional mask to get the full
256
- # allowable attention (i.e. full = not accounting for padding yet)
257
- causal = torch.tril(
258
- torch.ones((seq_len, seq_len),
259
- dtype=torch.bool,
260
- device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
261
- prefix = prefix_mask.view(-1, 1, 1, seq_len)
262
- cannot_attend = ~torch.logical_or(causal, prefix.bool())
263
-
264
- min_val = torch.finfo(attn_bias.dtype).min
265
- attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
266
-
267
- return attn_bias
268
-
269
- def _apply_sequence_id(self, attn_bias: torch.Tensor,
270
- sequence_id: torch.LongTensor):
271
- seq_len = sequence_id.shape[-1]
272
- if seq_len > self.config.max_seq_len:
273
- raise ValueError(
274
- f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}'
275
- )
276
-
277
- # select seq_len subset of attn mask
278
- attn_bias = attn_bias[..., :seq_len, :seq_len]
279
-
280
- # Restrict attention to tokens that share the same value
281
- # in sequence_id
282
- cannot_attend = torch.logical_not(
283
- torch.eq(
284
- sequence_id.view(-1, seq_len, 1),
285
- sequence_id.view(-1, 1, seq_len),
286
- )).unsqueeze(1)
287
- min_val = torch.finfo(attn_bias.dtype).min
288
- attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
289
-
290
- return attn_bias
291
-
292
def forward(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
    attention_mask: Optional[torch.ByteTensor] = None,
    prefix_mask: Optional[torch.ByteTensor] = None,
    sequence_id: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    use_cache: Optional[bool] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    use_active_externalism: Optional[bool] = None,
    long_range_past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
    faiss_indexes: Tuple = None,
    topk: int = None,
):
    """Run the embedding layer, every block and the final norm.

    Args:
        input_ids: Token ids; dimension 1 is the sequence length.
        past_key_values: Optional per-layer key/value cache.
        attention_mask: Optional padding mask (cast to bool).
        prefix_mask: Required when the model is configured with ``prefix_lm``.
        sequence_id: Required in train mode when ``attn_uses_sequence_id``.
        return_dict: Must resolve to a truthy value.
        output_attentions: Only supported for ``attn_impl == 'torch'``.
        output_hidden_states: Collect the input of every block plus the final
            normalised output.
        use_cache: Whether to create/return the key/value cache.
        inputs_embeds: Must be ``None``.
        use_active_externalism: Overrides the configured default when given.
        long_range_past_key_values: Optional per-layer external memories,
            passed only to layers enabled in
            ``use_active_externalism_by_layer``.
        faiss_indexes: Optional faiss indexes, mutually exclusive with a
            per-layer manual memory.
        topk: Overrides the configured default when given.

    Returns:
        ``BaseModelOutputWithPast`` whose ``attentions`` field is the pair
        ``(all_self_attns, all_idx)``.

    Raises:
        NotImplementedError: For unsupported argument combinations.
        ValueError: For missing required masks or over-long inputs.
    """
    return_dict = (return_dict
                   if return_dict is not None else self.config.return_dict)
    use_cache = (use_cache
                 if use_cache is not None else self.config.use_cache)
    use_active_externalism = (use_active_externalism
                              if use_active_externalism is not None else
                              self.use_active_externalism)
    topk = (topk if topk is not None else self.topk)

    if attention_mask is not None:
        attention_mask = attention_mask.bool()

    if prefix_mask is not None:
        prefix_mask = prefix_mask.bool()

    # These args are passed in by keyword in huggingface's generate function
    # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
    # but have not yet been fully implemented in MPTModel
    if not return_dict:
        raise NotImplementedError(
            'return_dict False is not implemented yet for MPT')
    if output_attentions:
        if self.attn_impl != 'torch':
            raise NotImplementedError(
                'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
            )

    if (attention_mask is not None and
            attention_mask[:, 0].sum() != attention_mask.shape[0] and
            self.training):
        raise NotImplementedError(
            'MPT does not support training with left padding.')

    if self.prefix_lm and prefix_mask is None:
        raise ValueError(
            'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
        )

    # Raise a not implemented error if input_embeds is not None (this is an
    # arg in huggingface transformers and we need to support it for PEFT)
    if inputs_embeds is not None:
        raise NotImplementedError(
            'inputs_embeds is not implemented for MPT.')

    if self.training:
        if self.attn_uses_sequence_id and sequence_id is None:
            raise ValueError(
                'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True '
                + 'and the model is in train mode.')
        elif (self.attn_uses_sequence_id is False) and (sequence_id
                                                        is not None):
            warnings.warn(
                'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. '
                +
                'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
            )

    S = input_ids.size(1)

    assert (
        S <= self.config.max_seq_len
    ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

    # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
    # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
    past_seq_dim = 3 if self.attn_impl == 'torch' else 1

    tok_emb = self.wte(input_ids)  # type: ignore
    if self.alibi:
        x = tok_emb
    else:
        past_position = 0
        if past_key_values is not None:
            if len(past_key_values) != self.config.n_layers:
                raise ValueError(
                    f'past_key_values must provide a past_key_value for each attention '
                    +
                    f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).'
                )
            # Shift the position embedding by the `seq` dim of the past key.
            past_position = past_key_values[0][0].size(past_seq_dim)

        if S + past_position > self.config.max_seq_len:
            # Report the real current length (S).
            raise ValueError(
                f'Cannot forward input with past sequence length {past_position} and current sequence length '
                f'{S}, this model only supports total sequence length <= {self.config.max_seq_len}.'
            )
        pos = torch.arange(
            past_position,
            S + past_position,
            dtype=torch.long,
            device=input_ids.device,
        ).unsqueeze(0)
        if attention_mask is not None:
            # adjust the position indices to account for padding tokens
            pos = torch.clamp(
                pos - torch.cumsum((~attention_mask).to(torch.int32),
                                   dim=1)[:, past_position:],
                min=0,
            )

        pos_emb = self.wpe(pos)  # type: ignore
        x = tok_emb + pos_emb

    if self.embedding_fraction == 1:
        x = self.emb_drop(x)  # type: ignore
    else:
        # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
        x_shrunk = (x * self.embedding_fraction) + (
            x.detach() * (1 - self.embedding_fraction))
        assert isinstance(self.emb_drop, nn.Module)  # pyright
        x = self.emb_drop(x_shrunk)

    # Total key length (current + cached) used to size the
    # active-externalism bias. The cached length is read from the same `seq`
    # dimension as the position offset above, so it stays correct for every
    # attn_impl (identical to the last dimension for 'torch').
    seq_len = S
    if past_key_values is not None:
        seq_len += past_key_values[0][0].size(past_seq_dim)

    attn_bias, attn_bias_ae, attention_mask = self._attn_bias(
        device=x.device,
        dtype=torch.float32,
        attention_mask=attention_mask,
        prefix_mask=prefix_mask,
        sequence_id=sequence_id,
        seq_len=seq_len,
        use_active_externalism=use_active_externalism,
        topk=topk,
    )

    # initialize the past key values cache if it should be used
    if use_cache and past_key_values is None:
        past_key_values = [() for _ in range(self.config.n_layers)
                          ]  # type: ignore

    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    all_idx = () if output_attentions else None
    for b_idx, block in enumerate(self.blocks):  # type: ignore
        if output_hidden_states:
            assert all_hidden_states is not None  # pyright
            all_hidden_states = all_hidden_states + (x,)
        past_key_value = (past_key_values[b_idx]
                          if past_key_values is not None else None)
        # External memories are only handed to layers that opted in, and
        # only when active externalism is explicitly True.
        long_range_past_key_value = (
            long_range_past_key_values[b_idx] if
            (long_range_past_key_values is not None and
             self.use_active_externalism_by_layer[b_idx] and
             use_active_externalism is True) else None)

        if long_range_past_key_value is not None and faiss_indexes is not None:
            raise NotImplementedError(
                'Using faiss and passing key value pairs manually are mutually exclusive right now.'
            )

        x, attn_weights, past_key_value, reshaped_idx = block(
            x,
            past_key_value=past_key_value,
            long_range_past_key_value=long_range_past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            attn_bias_ae=attn_bias_ae,
            is_causal=self.is_causal,
            topk=topk,
            needs_weights=output_attentions,
            faiss_indexes=faiss_indexes,
            n_layers=self.config.n_layers,
            current_layer=b_idx,
            mask_by_sim=self.mask_by_sim,
            sim_threshold=self.sim_threshold,
        )
        if past_key_values is not None:
            past_key_values[b_idx] = past_key_value

        if output_attentions:
            assert all_self_attns is not None  # pyright
            all_self_attns = all_self_attns + (attn_weights,)

            assert all_idx is not None
            all_idx = all_idx + (reshaped_idx,)

    x = self.norm_f(x)  # type: ignore

    # add hidden states from the last decoder layer
    if output_hidden_states:
        assert all_hidden_states is not None  # pyright
        all_hidden_states = all_hidden_states + (x,)

    return BaseModelOutputWithPast(
        last_hidden_state=x,
        past_key_values=past_key_values,
        hidden_states=all_hidden_states,
        attentions=(all_self_attns, all_idx),
    )
498
-
499
- # Param Initialization, needed for device='meta' fast initialization
500
def param_init_fn(self, module):
    """Initialise ``module`` with the init function named in the config.

    The function is looked up in ``MODEL_INIT_REGISTRY`` under
    ``config.init_config['name']`` and receives the whole ``init_config``
    as extra keyword arguments.
    """
    cfg = self.config
    init_fn = MODEL_INIT_REGISTRY[cfg.init_config['name']]
    init_fn(
        module=module,
        n_layers=cfg.n_layers,
        d_model=cfg.d_model,
        **cfg.init_config,
    )
508
-
509
- # FSDP Wrap function
510
def fsdp_wrap_fn(self, module):
    """Return True exactly when ``module`` is an ``MPTBlock``."""
    is_block = isinstance(module, MPTBlock)
    return is_block
512
-
513
- # Activation Checkpointing
514
def activation_checkpointing_fn(self, module):
    """Return True exactly when ``module`` is an ``MPTBlock``."""
    is_block = isinstance(module, MPTBlock)
    return is_block
516
-
517
- class ExtendedMPTForCausalLM(MPTPreTrainedModel):
518
-
519
def __init__(self, config:ExtendedMPTConfig, external_memories=None):
    """Wrap an ``ExtendedMPTModel`` with a tied-embedding LM head.

    Args:
        config: Model configuration; a ``DictConfig`` is first converted via
            ``instantiate_from_config``.
        external_memories: Optional input kept in ``self._memories``. The
            first ``forward`` call turns it into a cache via
            ``generate_cache``.

    Raises:
        ValueError: If word embeddings are not tied, or ``logit_scale`` is
            an unrecognised string.
    """
    if isinstance(config, DictConfig):
        config = instantiate_from_config(config)

    super().__init__(config)
    if not config.tie_word_embeddings:
        raise ValueError(
            'MPTForCausalLM only supports tied word embeddings')

    print(f'Instantiating an MPTForCausalLM model from {__file__}')

    self.transformer: ExtendedMPTModel = ExtendedMPTModel(config)

    attn_cfg = config.attn_config
    self.use_active_externalism = attn_cfg['use_active_externalism']
    self.memory_type = attn_cfg['memory_type']
    self._memories = None
    self.memory_device = config.memory_device

    # Flag every direct child except ModuleList containers.
    for child in self.transformer.children():
        if isinstance(child, torch.nn.Module) and not isinstance(
                child, torch.nn.ModuleList):
            child._fsdp_wrap = True

    # enables scaling output logits; similar to a softmax "temperature"
    # PaLM paper uses scale 1/sqrt(config.d_model)
    logit_scale = config.logit_scale
    if isinstance(logit_scale, str):
        if logit_scale != 'inv_sqrt_d_model':
            raise ValueError(
                f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
            )
        logit_scale = 1 / math.sqrt(config.d_model)
    self.logit_scale = logit_scale

    if external_memories is not None:
        self._memories = external_memories
    # Starts as None; populated by `forward` or `set_memories`.
    self.memories = None
560
-
561
def set_memories(self, memories):
    """Store ``memories`` on ``self.memories`` for later ``forward`` calls."""
    setattr(self, 'memories', memories)
563
-
564
def empty_memories(self):
    """Reset ``self.memories`` to ``None``."""
    setattr(self, 'memories', None)
566
-
567
def get_input_embeddings(self):
    """Return the token-embedding module of the wrapped transformer."""
    decoder = self.transformer
    return decoder.wte
569
-
570
def set_input_embeddings(self, value):
    """Replace the token-embedding module of the wrapped transformer."""
    decoder = self.transformer
    decoder.wte = value
572
-
573
def get_output_embeddings(self):
    """Return the output embedding, which is the same ``wte`` module as the input one."""
    decoder = self.transformer
    return decoder.wte
575
-
576
def set_output_embeddings(self, new_embeddings):
    """Replace the output embedding; this writes the same ``wte`` slot as the input one."""
    decoder = self.transformer
    decoder.wte = new_embeddings
578
-
579
def set_decoder(self, decoder):
    """Swap the wrapped transformer for ``decoder``."""
    setattr(self, 'transformer', decoder)
581
-
582
def get_decoder(self):
    """Return the wrapped transformer."""
    decoder = self.transformer
    return decoder
584
-
585
def forward(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
    attention_mask: Optional[torch.ByteTensor] = None,
    prefix_mask: Optional[torch.ByteTensor] = None,
    sequence_id: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    use_cache: Optional[bool] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_active_externalism: Optional[bool]=None,
    topk:int=None
):
    """Run the transformer, project onto the vocabulary and optionally score.

    On the first call after construction with ``external_memories``, the
    memories are turned into a cache via ``generate_cache`` and kept on
    ``self.memories``. A ``list`` there is passed on as manual key/value
    memories; anything else is passed on as faiss indexes.

    Returns:
        ``CausalLMOutputWithPast`` with ``loss`` set only when ``labels`` is
        given.

    Raises:
        NotImplementedError: If ``inputs_embeds`` is not ``None``.
    """
    # Lazily build the memory cache from the raw inputs given at construction.
    if self._memories is not None and self.memories is None:
        self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)

    if return_dict is None:
        return_dict = self.config.return_dict
    if use_cache is None:
        use_cache = self.config.use_cache
    if use_active_externalism is None:
        use_active_externalism = self.use_active_externalism

    # if input_embeds is not none, raise a not implemented error
    if inputs_embeds is not None:
        raise NotImplementedError(
            'inputs_embeds has to be None (for hf/peft support).')

    # Exactly a list -> manual memories; anything else (including a missing
    # attribute, read as None) goes down the faiss route.
    stored = getattr(self, "memories", None)
    if type(stored) == list:
        long_range_past_key_values, faiss_indexes = stored, None
    else:
        long_range_past_key_values, faiss_indexes = None, stored

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.transformer(
        input_ids=input_ids,
        past_key_values=past_key_values,
        long_range_past_key_values=long_range_past_key_values,
        faiss_indexes=faiss_indexes,
        attention_mask=attention_mask,
        prefix_mask=prefix_mask,
        sequence_id=sequence_id,
        return_dict=return_dict,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        use_active_externalism=use_active_externalism,
        topk=topk
    )

    # move outputs to same device as weights for token embedding
    # needed to support HF `device_map`
    embedding = self.transformer.wte
    hidden = outputs.last_hidden_state.to(embedding.weight.device)
    logits = embedding(hidden, True)

    if self.logit_scale is not None:
        if self.logit_scale == 0:
            warnings.warn(
                f'Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.'
            )
        logits *= self.logit_scale

    loss = None
    if labels is not None:
        # Shift labels left by one; the last position has no target.
        _labels = torch.roll(labels, shifts=-1)
        _labels[:, -1] = -100
        flat_logits = logits.view(-1, logits.size(-1))
        flat_targets = _labels.to(logits.device).view(-1)
        loss = F.cross_entropy(flat_logits, flat_targets)

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
675
-
676
- # Param Initialization, needed for device='meta' fast initialization
677
def param_init_fn(self, module):
    """Initialise ``module`` with the init function named in the config.

    The function is looked up in ``MODEL_INIT_REGISTRY`` under
    ``config.init_config['name']`` and receives the whole ``init_config``
    as extra keyword arguments.
    """
    cfg = self.config
    init_fn = MODEL_INIT_REGISTRY[cfg.init_config['name']]
    init_fn(
        module=module,
        n_layers=cfg.n_layers,
        d_model=cfg.d_model,
        **cfg.init_config,
    )
685
-
686
- # FSDP Wrap function
687
def fsdp_wrap_fn(self, module):
    """Return True exactly when ``module`` is an ``MPTBlock``."""
    is_block = isinstance(module, MPTBlock)
    return is_block
689
-
690
- # Activation Checkpointing
691
def activation_checkpointing_fn(self, module):
    """Return True exactly when ``module`` is an ``MPTBlock``."""
    is_block = isinstance(module, MPTBlock)
    return is_block
693
-
694
def generate_cache(self,
                   input_ids:torch.LongTensor,
                   stride:int=512,
                   max_len:int=2048,
                   cache_type:str='manual'):
    """Encode ``input_ids`` window by window and collect the key/values.

    Windows start every ``stride`` tokens and span at most ``max_len``
    tokens. From each window only the positions not covered by an earlier
    window are handed to ``self.cache``.

    Args:
        input_ids: Token ids; the last dimension is the sequence length.
        stride: Step between window starts.
        max_len: Maximum window length.
        cache_type: ``'manual'`` or ``'faiss'``.

    Returns:
        The accumulated key/value list if ``self.cache`` produced one,
        otherwise the faiss indexes (``None`` when nothing was processed).

    Raises:
        NotImplementedError: For any other ``cache_type``.
    """
    if cache_type not in ['manual', 'faiss']:
        raise NotImplementedError(f"Cache type {cache_type} not implemented.")

    total_len = input_ids.size(-1)
    covered_until = 0
    kv_memory = None
    index_memory = None
    for start in range(0, total_len, stride):
        stop = min(start + max_len, total_len)
        # Number of positions in this window not seen by earlier windows.
        n_new = stop - covered_until
        window = input_ids[:, start:stop].to(self.device)
        with torch.no_grad():
            outputs = self.transformer(window, use_cache=True, use_active_externalism=False)
            # Keys are sliced along their 4th axis, values along their 3rd.
            fresh = [
                (layer_kv[0][:, :, :, -n_new:], layer_kv[1][:, :, -n_new:])
                for layer_kv in outputs.past_key_values
            ]
            kv_memory, index_memory = self.cache(
                fresh,
                cache_type,
                long_range_past_key_values=kv_memory,
                faiss_indexes=index_memory)

        covered_until = stop
        if stop == total_len:
            break

    return kv_memory if kv_memory is not None else index_memory
726
-
727
def cache(self,
          to_cache:List,
          cache_type:str='manual',
          long_range_past_key_values:List=None,
          faiss_indexes:faiss.IndexFlatIP=None,
          max_length_cache=100000,
          verbose=False):
    """Append new per-layer key/values to a manual cache or to faiss indexes.

    Args:
        to_cache: One ``(key, value)`` pair per layer. The einops patterns
            below read keys as ``b h d s`` and values as ``b h s d``.
        cache_type: ``'faiss'`` adds to faiss indexes; any other value
            extends the manual key/value list.
        long_range_past_key_values: Existing manual cache, or ``None``.
        faiss_indexes: Existing ``(key_index, key_value_index)`` pair, or
            ``None`` to create fresh indexes.
        max_length_cache: Manual caches longer than this along the key's
            last axis keep only their most recent entries.
        verbose: Print the resulting cache size.

    Returns:
        ``(long_range_past_key_values, faiss_pair_or_None)``.

    Raises:
        NotImplementedError: If both a manual cache and faiss indexes are
            passed.
    """
    if long_range_past_key_values is not None and faiss_indexes is not None:
        raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")

    use_faiss = cache_type == 'faiss'
    if use_faiss:
        n_heads = self.config.n_heads
        # One scaled one-hot row per (layer, head) pair, appended to each
        # normalised key before it is added to the key index.
        one_hot_encodings = F.one_hot(torch.arange(0, n_heads*self.config.n_layers))*10
        if faiss_indexes is None:
            key_dim = to_cache[0][0].size(-2)+one_hot_encodings.size(-1)
            kv_dim = to_cache[0][1].size(-1)*2
            faiss_indexes = (faiss.IndexFlatIP(key_dim), faiss.IndexFlatIP(kv_dim))
        kn_index, kv_index = faiss_indexes
        for layer_no, (k, v) in enumerate(to_cache):
            # Unit-normalise keys along the `d` axis, then move to cpu.
            k_n = (k/vector_norm(k, ord=2, dim=-2, keepdim=True)).to('cpu')
            head_tags = one_hot_encodings[n_heads*layer_no:n_heads*(layer_no+1)]
            head_tags = head_tags.unsqueeze(0).repeat_interleave(repeats=k.size(-1), dim=-2)
            k_n = torch.concat([rearrange(k_n, 'b h d s -> b (h s) d', h=n_heads), head_tags], dim=-1)
            kn_index.add(k_n.squeeze().numpy())

            k = rearrange(k, 'b h d s -> b (h s) d', h=n_heads)
            v = rearrange(v, 'b h s d -> b (h s) d', h=n_heads)
            # The second index stores value and raw key side by side.
            kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
    elif long_range_past_key_values is None:
        target = self.memory_device
        long_range_past_key_values = [(k.to(target), v.to(target)) for k, v in to_cache]
    else:
        target = self.memory_device
        long_range_past_key_values = [
            (
                torch.concat([old_k, new_k.to(target)], dim=3),
                torch.concat([old_v, new_v.to(target)], dim=2),
            )
            for (old_k, old_v), (new_k, new_v) in zip(
                long_range_past_key_values,
                to_cache[:len(long_range_past_key_values)])
        ] if len(to_cache) >= len(long_range_past_key_values) else [
            (
                torch.concat([kv[0], to_cache[ind][0].to(target)], dim=3),
                torch.concat([kv[1], to_cache[ind][1].to(target)], dim=2),
            )
            for ind, kv in enumerate(long_range_past_key_values)
        ]

    # set a limit on manual memory length
    if (long_range_past_key_values is not None and
            long_range_past_key_values[0][0].size(-1) > max_length_cache):
        long_range_past_key_values = [
            (kv[0][:, :, :, -max_length_cache:], kv[1][:, :, -max_length_cache:])
            for kv in long_range_past_key_values
        ]

    if verbose:
        if use_faiss:
            print(f"{kn_index.ntotal} keys in faiss index")
        else:
            print(f"{long_range_past_key_values[0][0].size(-1)} cached kvs")

    faiss_result = (kn_index, kv_index) if use_faiss else None
    return long_range_past_key_values, faiss_result
777
-
778
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    **kwargs,
):
    """Assemble the keyword arguments for one generation step.

    Args:
        input_ids: Token ids generated so far; reduced to the last token
            when a key/value cache is supplied.
        past_key_values: Optional key/value cache.
        inputs_embeds: Must be ``None``.
        **kwargs: Must contain ``attention_mask``; ``use_cache``,
            ``use_active_externalism`` and ``topk`` are forwarded if present.

    Returns:
        Dict with the inputs for ``forward``.

    Raises:
        NotImplementedError: For ``inputs_embeds``, right-padded masks, or
            ``prefix_lm`` combined with ``use_cache=False``.
    """
    if inputs_embeds is not None:
        raise NotImplementedError(
            'inputs_embeds is not implemented for MPT yet')

    attention_mask = kwargs['attention_mask'].bool()
    # Every row must end in a real (unmasked) token.
    n_rows = attention_mask.shape[0]
    if attention_mask[:, -1].sum() != n_rows:
        raise NotImplementedError(
            'MPT does not support generation with right padding.')

    sequence_id = None
    if self.transformer.attn_uses_sequence_id and self.training:
        sequence_id = torch.zeros_like(input_ids[:1])

    if past_key_values is not None:
        # With a cache only the newest token has to be fed.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    prefix_mask = None
    if self.transformer.prefix_lm:
        # Leverage a convenience of sequential generation!
        prefix_mask = torch.ones_like(attention_mask)
        # This requires that we're using the cache
        if kwargs.get('use_cache') == False:
            raise NotImplementedError(
                'MPT with prefix_lm=True does not support use_cache=False.')

    model_inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'prefix_mask': prefix_mask,
        'sequence_id': sequence_id,
        'past_key_values': past_key_values,
    }
    model_inputs['use_cache'] = kwargs.get('use_cache', True)
    model_inputs['use_active_externalism'] = kwargs.get('use_active_externalism')
    model_inputs['topk'] = kwargs.get('topk', None)
    return model_inputs
822
-
823
- @staticmethod
824
- def _reorder_cache(past_key_values, beam_idx):
825
- """Used by HuggingFace generate when using beam search with kv-caching.
826
-
827
- See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
828
- for an example in transformers.
829
- """
830
- reordered_past = []
831
- for layer_past in past_key_values:
832
- reordered_past += [
833
- tuple(
834
- past_state.index_select(0, beam_idx)
835
- for past_state in layer_past)
836
- ]
837
- return reordered_past