d-Matrix commited on
Commit
71da3e8
·
verified ·
1 Parent(s): 9a56cac

Update modeling_opt.py

Browse files
Files changed (1) hide show
  1. modeling_opt.py +270 -620
modeling_opt.py CHANGED
@@ -13,44 +13,38 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  """ PyTorch OPT model."""
 
16
  from typing import List, Optional, Tuple, Union
17
 
18
  import torch
19
- import torch.nn.functional as F
20
  import torch.utils.checkpoint
21
  from torch import nn
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
 
24
- from transformers.activations import ACT2FN
25
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
26
- from transformers.modeling_outputs import (
27
  BaseModelOutputWithPast,
28
  CausalLMOutputWithPast,
29
  QuestionAnsweringModelOutput,
30
  SequenceClassifierOutputWithPast,
31
  )
32
- from transformers.modeling_utils import PreTrainedModel
33
- from transformers.utils import (
34
  add_code_sample_docstrings,
35
  add_start_docstrings,
36
  add_start_docstrings_to_model_forward,
37
- is_flash_attn_2_available,
38
- is_flash_attn_greater_or_equal_2_10,
39
  logging,
40
  replace_return_docstrings,
41
  )
42
  from .configuration_opt import OPTConfig
43
-
44
-
45
- if is_flash_attn_2_available():
46
- from flash_attn import flash_attn_func, flash_attn_varlen_func
47
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
48
 
49
 
50
  logger = logging.get_logger(__name__)
51
 
52
  _CHECKPOINT_FOR_DOC = "facebook/opt-350m"
53
  _CONFIG_FOR_DOC = "OPTConfig"
 
54
 
55
  # Base model docstring
56
  _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
@@ -71,45 +65,36 @@ OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
71
  # See all OPT models at https://huggingface.co/models?filter=opt
72
  ]
73
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
76
- def _get_unpad_data(attention_mask):
77
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
78
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
79
- max_seqlen_in_batch = seqlens_in_batch.max().item()
80
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
81
- return (
82
- indices,
83
- cu_seqlens,
84
- max_seqlen_in_batch,
85
- )
86
-
87
-
88
- # class OPTLearnedPositionalEmbedding(nn.Embedding):
89
- # """
90
- # This module learns positional embeddings up to a fixed maximum size.
91
- # """
92
 
93
- # def __init__(self, num_embeddings: int, embedding_dim: int):
94
- # # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
95
- # # and adjust num_embeddings appropriately. Other models don't have this hack
96
- # self.offset = 2
97
- # super().__init__(num_embeddings + self.offset, embedding_dim)
98
 
99
- # def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
100
- # """`input_ids_shape` is expected to be [bsz x seqlen]."""
101
- # attention_mask = attention_mask.long()
 
 
 
102
 
103
- # # create positions depending on attention_mask
104
- # positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
105
 
106
- # # cut positions if `past_key_values_length` is > 0
107
- # positions = positions[:, past_key_values_length:]
108
 
109
- # return super().forward(positions + self.offset)
110
 
111
 
112
- class OPTLearnedPositionalEmbedding(nn.Module):
113
  """
114
  This module learns positional embeddings up to a fixed maximum size.
115
  """
@@ -117,25 +102,20 @@ class OPTLearnedPositionalEmbedding(nn.Module):
117
  def __init__(self, num_embeddings: int, embedding_dim: int):
118
  # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
119
  # and adjust num_embeddings appropriately. Other models don't have this hack
120
- super().__init__()
121
  self.offset = 2
122
- self.embeddings = nn.Embedding(num_embeddings + self.offset, embedding_dim)
123
 
124
- def forward(
125
- self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
126
- ):
127
  """`input_ids_shape` is expected to be [bsz x seqlen]."""
128
  attention_mask = attention_mask.long()
129
 
130
  # create positions depending on attention_mask
131
- positions = (
132
- torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
133
- ).long() - 1
134
 
135
  # cut positions if `past_key_values_length` is > 0
136
  positions = positions[:, past_key_values_length:]
137
 
138
- return self.embeddings(positions + self.offset)
139
 
140
 
141
  class OPTAttention(nn.Module):
@@ -143,64 +123,36 @@ class OPTAttention(nn.Module):
143
 
144
  def __init__(
145
  self,
146
- config: OPTConfig,
 
 
147
  is_decoder: bool = False,
148
- **kwargs,
149
  ):
150
  super().__init__()
151
- self.config = config
152
-
153
- def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs):
154
- """
155
- If a the deprecated argument `fn_arg_name` is passed, raise a deprecation
156
- warning and return that value, otherwise take the equivalent config.config_arg_name
157
- """
158
- val = None
159
- if fn_arg_name in kwargs:
160
- logging.warning(
161
- "Passing in {fn_arg_name} to {self.__class__.__name__} is deprecated and won't be supported from "
162
- "v4.39. Please set it in the config instead"
163
- )
164
- val = kwargs.pop(fn_arg_name)
165
- else:
166
- val = getattr(config, config_arg_name)
167
- return val
168
 
169
- self.embed_dim = _handle_deprecated_argument(
170
- "hidden_size", config, "embed_dim", kwargs
171
- )
172
- self.num_heads = _handle_deprecated_argument(
173
- "num_attention_heads", config, "num_heads", kwargs
174
- )
175
- self.dropout = _handle_deprecated_argument(
176
- "attention_dropout", config, "dropout", kwargs
177
- )
178
- self.enable_bias = _handle_deprecated_argument(
179
- "enable_bias", config, "bias", kwargs
180
- )
181
-
182
- self.head_dim = self.embed_dim // self.num_heads
183
- self.is_causal = True
184
-
185
- if (self.head_dim * self.num_heads) != self.embed_dim:
186
  raise ValueError(
187
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
188
- f" and `num_heads`: {self.num_heads})."
189
  )
190
  self.scaling = self.head_dim**-0.5
191
  self.is_decoder = is_decoder
192
 
193
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
194
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
195
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
196
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
 
 
 
197
 
198
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
199
- return (
200
- tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
201
- .transpose(1, 2)
202
- .contiguous()
203
- )
204
 
205
  def forward(
206
  self,
@@ -270,25 +222,16 @@ class OPTAttention(nn.Module):
270
  raise ValueError(
271
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
272
  )
273
- attn_weights = (
274
- attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
275
- + attention_mask
276
- )
277
- attn_weights = torch.max(
278
- attn_weights,
279
- torch.tensor(
280
- torch.finfo(attn_weights.dtype).min, device=attn_weights.device
281
- ),
282
- )
283
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
284
 
285
  # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
 
286
  if attn_weights.dtype == torch.float16:
287
- attn_weights = nn.functional.softmax(
288
- attn_weights, dim=-1, dtype=torch.float32
289
- ).to(torch.float16)
290
  else:
291
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
292
 
293
  if layer_head_mask is not None:
294
  if layer_head_mask.size() != (self.num_heads,):
@@ -296,9 +239,7 @@ class OPTAttention(nn.Module):
296
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
297
  f" {layer_head_mask.size()}"
298
  )
299
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
300
- bsz, self.num_heads, tgt_len, src_len
301
- )
302
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
303
 
304
  if output_attentions:
@@ -306,19 +247,12 @@ class OPTAttention(nn.Module):
306
  # make sure that attn_weights keeps its gradient.
307
  # In order to do so, attn_weights have to be reshaped
308
  # twice and have to be reused in the following
309
- attn_weights_reshaped = attn_weights.view(
310
- bsz, self.num_heads, tgt_len, src_len
311
- )
312
- attn_weights = attn_weights_reshaped.view(
313
- bsz * self.num_heads, tgt_len, src_len
314
- )
315
  else:
316
  attn_weights_reshaped = None
317
 
318
- attn_probs = nn.functional.dropout(
319
- attn_weights, p=self.dropout, training=self.training
320
- )
321
-
322
  attn_output = torch.bmm(attn_probs, value_states)
323
 
324
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
@@ -339,297 +273,36 @@ class OPTAttention(nn.Module):
339
  return attn_output, attn_weights_reshaped, past_key_value
340
 
341
 
342
- class OptFlashAttention2(OPTAttention):
343
- """
344
- OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
345
- The only required change would be on the forward pass where it needs to correctly call the public API of flash
346
- attention and deal with padding tokens in case the input contains any of them.
347
- """
348
-
349
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
350
- def __init__(self, *args, **kwargs):
351
- super().__init__(*args, **kwargs)
352
-
353
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
354
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
355
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
356
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
357
-
358
- def forward(
359
- self,
360
- hidden_states: torch.Tensor,
361
- key_value_states: Optional[torch.Tensor] = None,
362
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
363
- attention_mask: Optional[torch.Tensor] = None,
364
- layer_head_mask: Optional[torch.Tensor] = None,
365
- output_attentions: bool = False,
366
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
367
- """Input shape: Batch x Time x Channel"""
368
-
369
- # if key_value_states are provided this layer is used as a cross-attention layer
370
- # for the decoder
371
- is_cross_attention = key_value_states is not None
372
-
373
- bsz, _, _ = hidden_states.size()
374
-
375
- # get query proj
376
- query_states = self.q_proj(hidden_states)
377
- # get key, value proj
378
- if is_cross_attention and past_key_value is not None:
379
- # reuse k,v, cross_attentions
380
- key_states = past_key_value[0]
381
- value_states = past_key_value[1]
382
- elif is_cross_attention:
383
- # cross_attentions
384
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
385
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
386
- elif past_key_value is not None:
387
- # reuse k, v, self_attention
388
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
389
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
390
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
391
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
392
- else:
393
- # self_attention
394
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
395
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
396
-
397
- if self.is_decoder:
398
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
399
- # Further calls to cross_attention layer can then reuse all cross-attention
400
- # key/value_states (first "if" case)
401
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
402
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
403
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
404
- # if encoder bi-directional self-attention `past_key_value` is always `None`
405
- past_key_value = (key_states, value_states)
406
-
407
- query_length = query_states.shape[1]
408
- tgt_len = key_states.shape[-2]
409
-
410
- # Flash attention requires the input to have the shape
411
- # batch_size x seq_length x head_dim x hidden_dim
412
- query_states = query_states.view(
413
- bsz, query_length, self.num_heads, self.head_dim
414
- )
415
- key_states = key_states.transpose(1, 2).view(
416
- bsz, tgt_len, self.num_heads, self.head_dim
417
- )
418
- value_states = value_states.transpose(1, 2).view(
419
- bsz, tgt_len, self.num_heads, self.head_dim
420
- )
421
-
422
- attn_dropout = self.dropout if self.training else 0.0
423
-
424
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
425
- # therefore the input hidden states gets silently casted in float32. Hence, we need
426
- # cast them back in float16 just to be sure everything works as expected.
427
- input_dtype = query_states.dtype
428
- if input_dtype == torch.float32:
429
- if torch.is_autocast_enabled():
430
- target_dtype = torch.get_autocast_gpu_dtype()
431
- # Handle the case where the model is quantized
432
- elif hasattr(self.config, "_pre_quantization_dtype"):
433
- target_dtype = self.config._pre_quantization_dtype
434
- else:
435
- target_dtype = self.q_proj.weight.dtype
436
-
437
- logger.warning_once(
438
- f"The input hidden states seems to be silently casted in float32, this might be related to"
439
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
440
- f" {target_dtype}."
441
- )
442
-
443
- query_states = query_states.to(target_dtype)
444
- key_states = key_states.to(target_dtype)
445
- value_states = value_states.to(target_dtype)
446
-
447
- attn_output = self._flash_attention_forward(
448
- query_states,
449
- key_states,
450
- value_states,
451
- attention_mask,
452
- query_length,
453
- dropout=attn_dropout,
454
- )
455
-
456
- attn_weights_reshaped = attn_output.reshape(
457
- bsz, query_length, self.num_heads * self.head_dim
458
- )
459
- attn_output = self.out_proj(attn_weights_reshaped)
460
-
461
- if not output_attentions:
462
- attn_weights_reshaped = None
463
-
464
- return attn_output, attn_weights_reshaped, past_key_value
465
-
466
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
467
- def _flash_attention_forward(
468
- self,
469
- query_states,
470
- key_states,
471
- value_states,
472
- attention_mask,
473
- query_length,
474
- dropout=0.0,
475
- softmax_scale=None,
476
- ):
477
- """
478
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
479
- first unpad the input, then computes the attention scores and pad the final attention scores.
480
-
481
- Args:
482
- query_states (`torch.Tensor`):
483
- Input query states to be passed to Flash Attention API
484
- key_states (`torch.Tensor`):
485
- Input key states to be passed to Flash Attention API
486
- value_states (`torch.Tensor`):
487
- Input value states to be passed to Flash Attention API
488
- attention_mask (`torch.Tensor`):
489
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
490
- position of padding tokens and 1 for the position of non-padding tokens.
491
- dropout (`int`, *optional*):
492
- Attention dropout
493
- softmax_scale (`float`, *optional*):
494
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
495
- """
496
- if not self._flash_attn_uses_top_left_mask:
497
- causal = self.is_causal
498
- else:
499
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
500
- causal = self.is_causal and query_length != 1
501
-
502
- # Contains at least one padding token in the sequence
503
- if attention_mask is not None:
504
- batch_size = query_states.shape[0]
505
- (
506
- query_states,
507
- key_states,
508
- value_states,
509
- indices_q,
510
- cu_seq_lens,
511
- max_seq_lens,
512
- ) = self._upad_input(
513
- query_states, key_states, value_states, attention_mask, query_length
514
- )
515
-
516
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
517
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
518
-
519
- attn_output_unpad = flash_attn_varlen_func(
520
- query_states,
521
- key_states,
522
- value_states,
523
- cu_seqlens_q=cu_seqlens_q,
524
- cu_seqlens_k=cu_seqlens_k,
525
- max_seqlen_q=max_seqlen_in_batch_q,
526
- max_seqlen_k=max_seqlen_in_batch_k,
527
- dropout_p=dropout,
528
- softmax_scale=softmax_scale,
529
- causal=causal,
530
- )
531
-
532
- attn_output = pad_input(
533
- attn_output_unpad, indices_q, batch_size, query_length
534
- )
535
- else:
536
- attn_output = flash_attn_func(
537
- query_states,
538
- key_states,
539
- value_states,
540
- dropout,
541
- softmax_scale=softmax_scale,
542
- causal=causal,
543
- )
544
-
545
- return attn_output
546
-
547
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
548
- def _upad_input(
549
- self, query_layer, key_layer, value_layer, attention_mask, query_length
550
- ):
551
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
552
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
553
-
554
- key_layer = index_first_axis(
555
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
556
- indices_k,
557
- )
558
- value_layer = index_first_axis(
559
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
560
- indices_k,
561
- )
562
- if query_length == kv_seq_len:
563
- query_layer = index_first_axis(
564
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
565
- indices_k,
566
- )
567
- cu_seqlens_q = cu_seqlens_k
568
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
569
- indices_q = indices_k
570
- elif query_length == 1:
571
- max_seqlen_in_batch_q = 1
572
- cu_seqlens_q = torch.arange(
573
- batch_size + 1, dtype=torch.int32, device=query_layer.device
574
- ) # There is a memcpy here, that is very bad.
575
- indices_q = cu_seqlens_q[:-1]
576
- query_layer = query_layer.squeeze(1)
577
- else:
578
- # The -q_len: slice assumes left padding.
579
- attention_mask = attention_mask[:, -query_length:]
580
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
581
- query_layer, attention_mask
582
- )
583
-
584
- return (
585
- query_layer,
586
- key_layer,
587
- value_layer,
588
- indices_q,
589
- (cu_seqlens_q, cu_seqlens_k),
590
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
591
- )
592
-
593
-
594
- OPT_ATTENTION_CLASSES = {
595
- "eager": OPTAttention,
596
- "flash_attention_2": OptFlashAttention2,
597
- }
598
-
599
-
600
  class OPTDecoderLayer(nn.Module):
601
  def __init__(self, config: OPTConfig):
602
  super().__init__()
603
  self.embed_dim = config.hidden_size
604
-
605
- self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](
606
- config=config, is_decoder=True
 
 
607
  )
608
-
609
  self.do_layer_norm_before = config.do_layer_norm_before
610
- self.dropout = config.dropout
611
  self.activation_fn = ACT2FN[config.activation_function]
612
 
613
- self.self_attn_layer_norm = nn.LayerNorm(
614
- self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
615
- )
616
- self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
617
- self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
618
- self.final_layer_norm = nn.LayerNorm(
619
- self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
620
- )
621
 
622
  def forward(
623
  self,
624
  hidden_states: torch.Tensor,
625
  attention_mask: Optional[torch.Tensor] = None,
626
  layer_head_mask: Optional[torch.Tensor] = None,
627
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
628
  output_attentions: Optional[bool] = False,
629
  use_cache: Optional[bool] = False,
630
- ) -> Tuple[
631
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
632
- ]:
633
  """
634
  Args:
635
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -660,9 +333,7 @@ class OPTDecoderLayer(nn.Module):
660
  layer_head_mask=layer_head_mask,
661
  output_attentions=output_attentions,
662
  )
663
- hidden_states = nn.functional.dropout(
664
- hidden_states, p=self.dropout, training=self.training
665
- )
666
  hidden_states = residual + hidden_states
667
 
668
  # 350m applies layer norm AFTER attention
@@ -682,9 +353,7 @@ class OPTDecoderLayer(nn.Module):
682
  hidden_states = self.activation_fn(hidden_states)
683
 
684
  hidden_states = self.fc2(hidden_states)
685
- hidden_states = nn.functional.dropout(
686
- hidden_states, p=self.dropout, training=self.training
687
- )
688
 
689
  hidden_states = (residual + hidden_states).view(hidden_states_shape)
690
 
@@ -729,7 +398,7 @@ class OPTPreTrainedModel(PreTrainedModel):
729
  base_model_prefix = "model"
730
  supports_gradient_checkpointing = True
731
  _no_split_modules = ["OPTDecoderLayer"]
732
- _supports_flash_attn_2 = True
733
 
734
  def _init_weights(self, module):
735
  std = self.config.init_std
@@ -742,6 +411,10 @@ class OPTPreTrainedModel(PreTrainedModel):
742
  if module.padding_idx is not None:
743
  module.weight.data[module.padding_idx].zero_()
744
 
 
 
 
 
745
 
746
  OPT_INPUTS_DOCSTRING = r"""
747
  Args:
@@ -749,7 +422,7 @@ OPT_INPUTS_DOCSTRING = r"""
749
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
750
  it.
751
 
752
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
753
  [`PreTrainedTokenizer.__call__`] for details.
754
 
755
  [What are input IDs?](../glossary#input-ids)
@@ -761,7 +434,7 @@ OPT_INPUTS_DOCSTRING = r"""
761
 
762
  [What are attention masks?](../glossary#attention-mask)
763
 
764
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
765
  [`PreTrainedTokenizer.__call__`] for details.
766
 
767
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
@@ -821,25 +494,16 @@ class OPTDecoder(OPTPreTrainedModel):
821
  self.max_target_positions = config.max_position_embeddings
822
  self.vocab_size = config.vocab_size
823
 
824
- self.embed_tokens = nn.Embedding(
825
- config.vocab_size, config.word_embed_proj_dim, self.padding_idx
826
- )
827
- self._embed_positions = OPTLearnedPositionalEmbedding(
828
- config.max_position_embeddings, config.hidden_size
829
- )
830
- self.embed_positions = self._embed_positions.embeddings
831
 
832
  if config.word_embed_proj_dim != config.hidden_size:
833
- self.project_out = nn.Linear(
834
- config.hidden_size, config.word_embed_proj_dim, bias=False
835
- )
836
  else:
837
  self.project_out = None
838
 
839
  if config.word_embed_proj_dim != config.hidden_size:
840
- self.project_in = nn.Linear(
841
- config.word_embed_proj_dim, config.hidden_size, bias=False
842
- )
843
  else:
844
  self.project_in = None
845
 
@@ -847,17 +511,11 @@ class OPTDecoder(OPTPreTrainedModel):
847
  # with checkpoints that have been fine-tuned before transformers v4.20.1
848
  # see https://github.com/facebookresearch/metaseq/pull/164
849
  if config.do_layer_norm_before and not config._remove_final_layer_norm:
850
- self.final_layer_norm = nn.LayerNorm(
851
- config.hidden_size,
852
- elementwise_affine=config.layer_norm_elementwise_affine,
853
- )
854
  else:
855
  self.final_layer_norm = None
856
 
857
- self.layers = nn.ModuleList(
858
- [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]
859
- )
860
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
861
 
862
  self.gradient_checkpointing = False
863
  # Initialize weights and apply final processing
@@ -869,6 +527,29 @@ class OPTDecoder(OPTPreTrainedModel):
869
  def set_input_embeddings(self, value):
870
  self.embed_tokens = value
871
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
872
  def forward(
873
  self,
874
  input_ids: torch.LongTensor = None,
@@ -887,7 +568,7 @@ class OPTDecoder(OPTPreTrainedModel):
887
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
888
  provide it.
889
 
890
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
891
  [`PreTrainedTokenizer.__call__`] for details.
892
 
893
  [What are input IDs?](../glossary#input-ids)
@@ -928,89 +609,44 @@ class OPTDecoder(OPTPreTrainedModel):
928
  return_dict (`bool`, *optional*):
929
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
930
  """
931
- output_attentions = (
932
- output_attentions
933
- if output_attentions is not None
934
- else self.config.output_attentions
935
- )
936
  output_hidden_states = (
937
- output_hidden_states
938
- if output_hidden_states is not None
939
- else self.config.output_hidden_states
940
  )
941
  use_cache = use_cache if use_cache is not None else self.config.use_cache
942
 
943
- return_dict = (
944
- return_dict if return_dict is not None else self.config.use_return_dict
945
- )
946
 
947
  # retrieve input_ids and inputs_embeds
948
  if input_ids is not None and inputs_embeds is not None:
949
- raise ValueError(
950
- "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
951
- )
952
  elif input_ids is not None:
953
  input_shape = input_ids.size()
954
  input_ids = input_ids.view(-1, input_shape[-1])
955
  elif inputs_embeds is not None:
956
  input_shape = inputs_embeds.size()[:-1]
957
  else:
958
- raise ValueError(
959
- "You have to specify either decoder_input_ids or decoder_inputs_embeds"
960
- )
961
 
962
  if inputs_embeds is None:
963
  inputs_embeds = self.embed_tokens(input_ids)
964
 
965
- batch_size, seq_length = input_shape
966
- past_key_values_length = (
967
- past_key_values[0][0].shape[2] if past_key_values is not None else 0
968
- )
969
- # required mask seq length can be calculated via length of past
970
- mask_seq_length = past_key_values_length + seq_length
971
-
972
  # embed positions
973
- if self._use_flash_attention_2:
974
- # 2d mask is passed through the layers
975
- causal_attention_mask = (
976
- attention_mask
977
- if (attention_mask is not None and 0 in attention_mask)
978
- else None
979
- )
980
- attention_mask = (
981
- torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
982
- if attention_mask is None
983
- else attention_mask
984
- )
985
- else:
986
- # 4d mask is passed through the layers
987
- if attention_mask is None:
988
- attention_mask = torch.ones(
989
- batch_size, mask_seq_length, device=inputs_embeds.device
990
- )
991
- elif attention_mask.shape[1] != mask_seq_length:
992
- raise ValueError(
993
- f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
994
- f"{mask_seq_length} (sum of the lengths of current and past inputs)"
995
- )
996
- causal_attention_mask = _prepare_4d_causal_attention_mask(
997
- attention_mask, input_shape, inputs_embeds, past_key_values_length
998
- )
999
 
1000
- pos_embeds = self._embed_positions(attention_mask, past_key_values_length)
 
 
1001
 
1002
  if self.project_in is not None:
1003
  inputs_embeds = self.project_in(inputs_embeds)
1004
 
1005
  hidden_states = inputs_embeds + pos_embeds
1006
 
1007
- if self.gradient_checkpointing and self.training:
1008
- if use_cache:
1009
- logger.warning_once(
1010
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1011
- )
1012
- use_cache = False
1013
-
1014
  # decoder layers
1015
  all_hidden_states = () if output_hidden_states else None
1016
  all_self_attns = () if output_attentions else None
@@ -1030,29 +666,39 @@ class OPTDecoder(OPTPreTrainedModel):
1030
  if output_hidden_states:
1031
  all_hidden_states += (hidden_states,)
1032
 
1033
- if self.training:
1034
- dropout_probability = torch.rand([])
1035
- if dropout_probability < self.layerdrop:
1036
- continue
1037
 
1038
- past_key_value = (
1039
- past_key_values[idx] if past_key_values is not None else None
1040
- )
1041
 
1042
  if self.gradient_checkpointing and self.training:
1043
- layer_outputs = self._gradient_checkpointing_func(
1044
- decoder_layer.__call__,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1045
  hidden_states,
1046
- causal_attention_mask,
1047
  head_mask[idx] if head_mask is not None else None,
1048
  None,
1049
- output_attentions,
1050
- use_cache,
1051
  )
1052
  else:
 
1053
  layer_outputs = decoder_layer(
1054
  hidden_states,
1055
- attention_mask=causal_attention_mask,
1056
  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1057
  past_key_value=past_key_value,
1058
  output_attentions=output_attentions,
@@ -1067,6 +713,12 @@ class OPTDecoder(OPTPreTrainedModel):
1067
  if output_attentions:
1068
  all_self_attns += (layer_outputs[1],)
1069
 
 
 
 
 
 
 
1070
  if self.final_layer_norm is not None:
1071
  hidden_states = self.final_layer_norm(hidden_states)
1072
 
@@ -1079,11 +731,7 @@ class OPTDecoder(OPTPreTrainedModel):
1079
 
1080
  next_cache = next_decoder_cache if use_cache else None
1081
  if not return_dict:
1082
- return tuple(
1083
- v
1084
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1085
- if v is not None
1086
- )
1087
  return BaseModelOutputWithPast(
1088
  last_hidden_state=hidden_states,
1089
  past_key_values=next_cache,
@@ -1100,9 +748,46 @@ class OPTModel(OPTPreTrainedModel):
1100
  def __init__(self, config: OPTConfig):
1101
  super().__init__(config)
1102
  self.decoder = OPTDecoder(config)
 
 
 
 
 
 
1103
  # Initialize weights and apply final processing
1104
  self.post_init()
1105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1106
  def get_input_embeddings(self):
1107
  return self.decoder.embed_tokens
1108
 
@@ -1114,6 +799,7 @@ class OPTModel(OPTPreTrainedModel):
1114
 
1115
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1116
  @add_code_sample_docstrings(
 
1117
  checkpoint=_CHECKPOINT_FOR_DOC,
1118
  output_type=BaseModelOutputWithPast,
1119
  config_class=_CONFIG_FOR_DOC,
@@ -1131,20 +817,13 @@ class OPTModel(OPTPreTrainedModel):
1131
  output_hidden_states: Optional[bool] = None,
1132
  return_dict: Optional[bool] = None,
1133
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1134
- output_attentions = (
1135
- output_attentions
1136
- if output_attentions is not None
1137
- else self.config.output_attentions
1138
- )
1139
  output_hidden_states = (
1140
- output_hidden_states
1141
- if output_hidden_states is not None
1142
- else self.config.output_hidden_states
1143
  )
1144
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1145
- return_dict = (
1146
- return_dict if return_dict is not None else self.config.use_return_dict
1147
- )
1148
 
1149
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1150
  decoder_outputs = self.decoder(
@@ -1171,20 +850,40 @@ class OPTModel(OPTPreTrainedModel):
1171
 
1172
 
1173
  class OPTForCausalLM(OPTPreTrainedModel):
1174
- _tied_weights_keys = ["lm_head.weight"]
1175
 
1176
  def __init__(self, config):
1177
  super().__init__(config)
1178
  self.model = OPTModel(config)
1179
 
1180
  # the lm_head weight is automatically tied to the embed tokens weight
1181
- self.lm_head = nn.Linear(
1182
- config.word_embed_proj_dim, config.vocab_size, bias=False
1183
- )
 
 
1184
 
1185
  # Initialize weights and apply final processing
1186
  self.post_init()
1187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1188
  def get_input_embeddings(self):
1189
  return self.model.decoder.embed_tokens
1190
 
@@ -1203,9 +902,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
1203
  def get_decoder(self):
1204
  return self.model.decoder
1205
 
1206
- @replace_return_docstrings(
1207
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1208
- )
1209
  def forward(
1210
  self,
1211
  input_ids: torch.LongTensor = None,
@@ -1225,7 +922,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
1225
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1226
  provide it.
1227
 
1228
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1229
  [`PreTrainedTokenizer.__call__`] for details.
1230
 
1231
  [What are input IDs?](../glossary#input-ids)
@@ -1279,33 +976,25 @@ class OPTForCausalLM(OPTPreTrainedModel):
1279
  Example:
1280
 
1281
  ```python
1282
- >>> from transformers import AutoTokenizer, OPTForCausalLM
1283
 
1284
  >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
1285
- >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
1286
 
1287
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1288
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1289
 
1290
  >>> # Generate
1291
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1292
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1293
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
1294
  ```"""
1295
 
1296
- output_attentions = (
1297
- output_attentions
1298
- if output_attentions is not None
1299
- else self.config.output_attentions
1300
- )
1301
  output_hidden_states = (
1302
- output_hidden_states
1303
- if output_hidden_states is not None
1304
- else self.config.output_hidden_states
1305
- )
1306
- return_dict = (
1307
- return_dict if return_dict is not None else self.config.use_return_dict
1308
  )
 
1309
 
1310
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1311
  outputs = self.model.decoder(
@@ -1320,7 +1009,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
1320
  return_dict=return_dict,
1321
  )
1322
 
1323
- logits = self.lm_head(outputs[0]).contiguous()
 
 
 
 
1324
 
1325
  loss = None
1326
  if labels is not None:
@@ -1331,9 +1024,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
1331
  shift_labels = labels[..., 1:].contiguous()
1332
  # Flatten the tokens
1333
  loss_fct = CrossEntropyLoss()
1334
- loss = loss_fct(
1335
- shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
1336
- )
1337
 
1338
  if not return_dict:
1339
  output = (logits,) + outputs[1:]
@@ -1347,51 +1038,26 @@ class OPTForCausalLM(OPTPreTrainedModel):
1347
  attentions=outputs.attentions,
1348
  )
1349
 
1350
- def prepare_inputs_for_generation(
1351
- self,
1352
- input_ids,
1353
- past_key_values=None,
1354
- attention_mask=None,
1355
- inputs_embeds=None,
1356
- **kwargs,
1357
- ):
1358
- if past_key_values is not None:
1359
- past_length = past_key_values[0][0].shape[2]
1360
-
1361
- # Some generation methods already pass only the last input ID
1362
- if input_ids.shape[1] > past_length:
1363
- remove_prefix_length = past_length
1364
- else:
1365
- # Default to old behavior: keep only final ID
1366
- remove_prefix_length = input_ids.shape[1] - 1
1367
-
1368
- input_ids = input_ids[:, remove_prefix_length:]
1369
-
1370
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1371
- if inputs_embeds is not None and past_key_values is None:
1372
- model_inputs = {"inputs_embeds": inputs_embeds}
1373
- else:
1374
- model_inputs = {"input_ids": input_ids}
1375
-
1376
- model_inputs.update(
1377
- {
1378
- "past_key_values": past_key_values,
1379
- "use_cache": kwargs.get("use_cache"),
1380
- "attention_mask": attention_mask,
1381
- }
1382
- )
1383
- return model_inputs
1384
 
1385
  @staticmethod
1386
- def _reorder_cache(past_key_values, beam_idx):
1387
  reordered_past = ()
1388
- for layer_past in past_key_values:
1389
- reordered_past += (
1390
- tuple(
1391
- past_state.index_select(0, beam_idx.to(past_state.device))
1392
- for past_state in layer_past
1393
- ),
1394
- )
1395
  return reordered_past
1396
 
1397
 
@@ -1411,6 +1077,8 @@ class OPTForCausalLM(OPTPreTrainedModel):
1411
  OPT_START_DOCSTRING,
1412
  )
1413
  class OPTForSequenceClassification(OPTPreTrainedModel):
 
 
1414
  def __init__(self, config: OPTConfig):
1415
  super().__init__(config)
1416
  self.num_labels = config.num_labels
@@ -1422,6 +1090,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1422
 
1423
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1424
  @add_code_sample_docstrings(
 
1425
  checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1426
  output_type=SequenceClassifierOutputWithPast,
1427
  config_class=_CONFIG_FOR_DOC,
@@ -1447,9 +1116,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1447
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1448
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1449
  """
1450
- return_dict = (
1451
- return_dict if return_dict is not None else self.config.use_return_dict
1452
- )
1453
 
1454
  transformer_outputs = self.model(
1455
  input_ids,
@@ -1474,12 +1141,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1474
  sequence_lengths = -1
1475
  else:
1476
  if input_ids is not None:
1477
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1478
- sequence_lengths = (
1479
- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1480
- )
1481
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1482
- sequence_lengths = sequence_lengths.to(logits.device)
1483
  else:
1484
  sequence_lengths = -1
1485
  logger.warning(
@@ -1487,18 +1149,14 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1487
  "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1488
  )
1489
 
1490
- pooled_logits = logits[
1491
- torch.arange(batch_size, device=logits.device), sequence_lengths
1492
- ]
1493
 
1494
  loss = None
1495
  if labels is not None:
1496
  if self.config.problem_type is None:
1497
  if self.num_labels == 1:
1498
  self.config.problem_type = "regression"
1499
- elif self.num_labels > 1 and (
1500
- labels.dtype == torch.long or labels.dtype == torch.int
1501
- ):
1502
  self.config.problem_type = "single_label_classification"
1503
  else:
1504
  self.config.problem_type = "multi_label_classification"
@@ -1511,9 +1169,7 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1511
  loss = loss_fct(pooled_logits, labels)
1512
  elif self.config.problem_type == "single_label_classification":
1513
  loss_fct = CrossEntropyLoss()
1514
- loss = loss_fct(
1515
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1516
- )
1517
  elif self.config.problem_type == "multi_label_classification":
1518
  loss_fct = BCEWithLogitsLoss()
1519
  loss = loss_fct(pooled_logits, labels)
@@ -1544,6 +1200,8 @@ class OPTForSequenceClassification(OPTPreTrainedModel):
1544
  OPT_START_DOCSTRING,
1545
  )
1546
  class OPTForQuestionAnswering(OPTPreTrainedModel):
 
 
1547
  def __init__(self, config: OPTConfig):
1548
  super().__init__(config)
1549
  self.model = OPTModel(config)
@@ -1553,9 +1211,7 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
1553
  self.post_init()
1554
 
1555
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1556
- @replace_return_docstrings(
1557
- output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC
1558
- )
1559
  def forward(
1560
  self,
1561
  input_ids: Optional[torch.LongTensor] = None,
@@ -1585,11 +1241,11 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
1585
  Example:
1586
 
1587
  ```python
1588
- >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
1589
  >>> import torch
1590
 
1591
  >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
1592
- >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
1593
 
1594
  >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
1595
  >>> # so the head will be randomly initialized, hence the predictions will be random
@@ -1604,18 +1260,12 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
1604
  >>> answer_start_index = outputs.start_logits.argmax()
1605
  >>> answer_end_index = outputs.end_logits.argmax()
1606
 
1607
- >>> answer_offset = len(tokenizer(question)[0])
1608
-
1609
- >>> predict_answer_tokens = inputs.input_ids[
1610
- ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
1611
- ... ]
1612
  >>> predicted = tokenizer.decode(predict_answer_tokens)
1613
  >>> predicted
1614
- ' a nice puppet'
1615
  ```"""
1616
- return_dict = (
1617
- return_dict if return_dict is not None else self.config.use_return_dict
1618
- )
1619
 
1620
  transformer_outputs = self.model(
1621
  input_ids,
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  """ PyTorch OPT model."""
16
+ import random
17
  from typing import List, Optional, Tuple, Union
18
 
19
  import torch
 
20
  import torch.utils.checkpoint
21
  from torch import nn
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
 
24
+ from ...activations import ACT2FN
25
+ from ...modeling_outputs import (
 
26
  BaseModelOutputWithPast,
27
  CausalLMOutputWithPast,
28
  QuestionAnsweringModelOutput,
29
  SequenceClassifierOutputWithPast,
30
  )
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...utils import (
33
  add_code_sample_docstrings,
34
  add_start_docstrings,
35
  add_start_docstrings_to_model_forward,
 
 
36
  logging,
37
  replace_return_docstrings,
38
  )
39
  from .configuration_opt import OPTConfig
40
+ from ...utils.model_parallel_utils import assert_device_map, get_device_map
 
 
 
 
41
 
42
 
43
  logger = logging.get_logger(__name__)
44
 
45
  _CHECKPOINT_FOR_DOC = "facebook/opt-350m"
46
  _CONFIG_FOR_DOC = "OPTConfig"
47
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
48
 
49
  # Base model docstring
50
  _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
 
65
  # See all OPT models at https://huggingface.co/models?filter=opt
66
  ]
67
 
68
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
69
+ """
70
+ Make causal mask used for bi-directional self-attention.
71
+ """
72
+ bsz, tgt_len = input_ids_shape
73
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
74
+ mask_cond = torch.arange(mask.size(-1))
75
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
76
+ mask = mask.to(dtype)
77
 
78
+ if past_key_values_length > 0:
79
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
80
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
 
 
 
 
 
82
 
83
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
84
+ """
85
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
86
+ """
87
+ bsz, src_len = mask.size()
88
+ tgt_len = tgt_len if tgt_len is not None else src_len
89
 
90
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
 
91
 
92
+ inverted_mask = 1.0 - expanded_mask
 
93
 
94
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
95
 
96
 
97
+ class OPTLearnedPositionalEmbedding(nn.Embedding):
98
  """
99
  This module learns positional embeddings up to a fixed maximum size.
100
  """
 
102
  def __init__(self, num_embeddings: int, embedding_dim: int):
103
  # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
104
  # and adjust num_embeddings appropriately. Other models don't have this hack
 
105
  self.offset = 2
106
+ super().__init__(num_embeddings + self.offset, embedding_dim)
107
 
108
+ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
 
 
109
  """`input_ids_shape` is expected to be [bsz x seqlen]."""
110
  attention_mask = attention_mask.long()
111
 
112
  # create positions depending on attention_mask
113
+ positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
 
 
114
 
115
  # cut positions if `past_key_values_length` is > 0
116
  positions = positions[:, past_key_values_length:]
117
 
118
+ return super().forward(positions + self.offset)
119
 
120
 
121
  class OPTAttention(nn.Module):
 
123
 
124
  def __init__(
125
  self,
126
+ embed_dim: int,
127
+ num_heads: int,
128
+ dropout: float = 0.0,
129
  is_decoder: bool = False,
130
+ bias: bool = True,
131
  ):
132
  super().__init__()
133
+ self.embed_dim = embed_dim
134
+ self.num_heads = num_heads
135
+ # self.dropout = dropout
136
+ self.head_dim = embed_dim // num_heads
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ if (self.head_dim * num_heads) != self.embed_dim:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  raise ValueError(
140
  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
141
+ f" and `num_heads`: {num_heads})."
142
  )
143
  self.scaling = self.head_dim**-0.5
144
  self.is_decoder = is_decoder
145
 
146
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
147
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
148
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
149
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
150
+
151
+ self.softmax = nn.Softmax(dim=-1)
152
+ self.dropout = nn.Dropout(p=dropout)
153
 
154
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
155
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
 
 
 
156
 
157
  def forward(
158
  self,
 
222
  raise ValueError(
223
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
224
  )
225
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask.to(attn_weights.device)
226
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
 
 
 
 
 
 
 
227
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
228
 
229
  # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
230
+ breakpoint()
231
  if attn_weights.dtype == torch.float16:
232
+ attn_weights = self.softmax(attn_weights.float()).to(torch.float16)
 
 
233
  else:
234
+ attn_weights = self.softmax(attn_weights)
235
 
236
  if layer_head_mask is not None:
237
  if layer_head_mask.size() != (self.num_heads,):
 
239
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
240
  f" {layer_head_mask.size()}"
241
  )
242
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
 
 
243
  attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
244
 
245
  if output_attentions:
 
247
  # make sure that attn_weights keeps its gradient.
248
  # In order to do so, attn_weights have to be reshaped
249
  # twice and have to be reused in the following
250
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
251
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
 
 
 
 
252
  else:
253
  attn_weights_reshaped = None
254
 
255
+ attn_probs = self.dropout(attn_weights)
 
 
 
256
  attn_output = torch.bmm(attn_probs, value_states)
257
 
258
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
 
273
  return attn_output, attn_weights_reshaped, past_key_value
274
 
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  class OPTDecoderLayer(nn.Module):
277
  def __init__(self, config: OPTConfig):
278
  super().__init__()
279
  self.embed_dim = config.hidden_size
280
+ self.self_attn = OPTAttention(
281
+ embed_dim=self.embed_dim,
282
+ num_heads=config.num_attention_heads,
283
+ dropout=config.attention_dropout,
284
+ is_decoder=True,
285
  )
 
286
  self.do_layer_norm_before = config.do_layer_norm_before
287
+ # self.dropout = config.dropout
288
  self.activation_fn = ACT2FN[config.activation_function]
289
 
290
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
291
+ self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
292
+ self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
293
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
294
+
295
+ self.dropout = nn.Dropout(p=config.dropout)
 
 
296
 
297
  def forward(
298
  self,
299
  hidden_states: torch.Tensor,
300
  attention_mask: Optional[torch.Tensor] = None,
301
  layer_head_mask: Optional[torch.Tensor] = None,
 
302
  output_attentions: Optional[bool] = False,
303
  use_cache: Optional[bool] = False,
304
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
305
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
306
  """
307
  Args:
308
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
333
  layer_head_mask=layer_head_mask,
334
  output_attentions=output_attentions,
335
  )
336
+ hidden_states = self.dropout(hidden_states)
 
 
337
  hidden_states = residual + hidden_states
338
 
339
  # 350m applies layer norm AFTER attention
 
353
  hidden_states = self.activation_fn(hidden_states)
354
 
355
  hidden_states = self.fc2(hidden_states)
356
+ hidden_states = self.dropout(hidden_states)
 
 
357
 
358
  hidden_states = (residual + hidden_states).view(hidden_states_shape)
359
 
 
398
  base_model_prefix = "model"
399
  supports_gradient_checkpointing = True
400
  _no_split_modules = ["OPTDecoderLayer"]
401
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
402
 
403
  def _init_weights(self, module):
404
  std = self.config.init_std
 
411
  if module.padding_idx is not None:
412
  module.weight.data[module.padding_idx].zero_()
413
 
414
+ def _set_gradient_checkpointing(self, module, value=False):
415
+ if isinstance(module, (OPTDecoder)):
416
+ module.gradient_checkpointing = value
417
+
418
 
419
  OPT_INPUTS_DOCSTRING = r"""
420
  Args:
 
422
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
423
  it.
424
 
425
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
426
  [`PreTrainedTokenizer.__call__`] for details.
427
 
428
  [What are input IDs?](../glossary#input-ids)
 
434
 
435
  [What are attention masks?](../glossary#attention-mask)
436
 
437
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
438
  [`PreTrainedTokenizer.__call__`] for details.
439
 
440
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
 
494
  self.max_target_positions = config.max_position_embeddings
495
  self.vocab_size = config.vocab_size
496
 
497
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
498
+ self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
 
 
 
 
 
499
 
500
  if config.word_embed_proj_dim != config.hidden_size:
501
+ self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
 
 
502
  else:
503
  self.project_out = None
504
 
505
  if config.word_embed_proj_dim != config.hidden_size:
506
+ self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
 
 
507
  else:
508
  self.project_in = None
509
 
 
511
  # with checkpoints that have been fine-tuned before transformers v4.20.1
512
  # see https://github.com/facebookresearch/metaseq/pull/164
513
  if config.do_layer_norm_before and not config._remove_final_layer_norm:
514
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size)
 
 
 
515
  else:
516
  self.final_layer_norm = None
517
 
518
+ self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
 
519
 
520
  self.gradient_checkpointing = False
521
  # Initialize weights and apply final processing
 
527
  def set_input_embeddings(self, value):
528
  self.embed_tokens = value
529
 
530
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
531
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
532
+ # create causal mask
533
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
534
+ combined_attention_mask = None
535
+ if input_shape[-1] > 1:
536
+ combined_attention_mask = _make_causal_mask(
537
+ input_shape,
538
+ inputs_embeds.dtype,
539
+ past_key_values_length=past_key_values_length,
540
+ )
541
+
542
+ if attention_mask is not None:
543
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
544
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
545
+ inputs_embeds.device
546
+ )
547
+ combined_attention_mask = (
548
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask.to(expanded_attn_mask.device)
549
+ )
550
+
551
+ return combined_attention_mask
552
+
553
  def forward(
554
  self,
555
  input_ids: torch.LongTensor = None,
 
568
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
569
  provide it.
570
 
571
+ Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
572
  [`PreTrainedTokenizer.__call__`] for details.
573
 
574
  [What are input IDs?](../glossary#input-ids)
 
609
  return_dict (`bool`, *optional*):
610
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
611
  """
612
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
613
  output_hidden_states = (
614
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
615
  )
616
  use_cache = use_cache if use_cache is not None else self.config.use_cache
617
 
618
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
619
 
620
  # retrieve input_ids and inputs_embeds
621
  if input_ids is not None and inputs_embeds is not None:
622
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
 
 
623
  elif input_ids is not None:
624
  input_shape = input_ids.size()
625
  input_ids = input_ids.view(-1, input_shape[-1])
626
  elif inputs_embeds is not None:
627
  input_shape = inputs_embeds.size()[:-1]
628
  else:
629
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
630
+
631
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
632
 
633
  if inputs_embeds is None:
634
  inputs_embeds = self.embed_tokens(input_ids)
635
 
 
 
 
 
 
 
 
636
  # embed positions
637
+ if attention_mask is None:
638
+ attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
639
+ pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640
 
641
+ attention_mask = self._prepare_decoder_attention_mask(
642
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
643
+ )
644
 
645
  if self.project_in is not None:
646
  inputs_embeds = self.project_in(inputs_embeds)
647
 
648
  hidden_states = inputs_embeds + pos_embeds
649
 
 
 
 
 
 
 
 
650
  # decoder layers
651
  all_hidden_states = () if output_hidden_states else None
652
  all_self_attns = () if output_attentions else None
 
666
  if output_hidden_states:
667
  all_hidden_states += (hidden_states,)
668
 
669
+ dropout_probability = random.uniform(0, 1)
670
+ if self.training and (dropout_probability < self.layerdrop):
671
+ continue
 
672
 
673
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
 
 
674
 
675
  if self.gradient_checkpointing and self.training:
676
+
677
+ if use_cache:
678
+ logger.warning(
679
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
680
+ )
681
+ use_cache = False
682
+
683
+ def create_custom_forward(module):
684
+ def custom_forward(*inputs):
685
+ # None for past_key_value
686
+ return module(*inputs, output_attentions, None)
687
+
688
+ return custom_forward
689
+
690
+ layer_outputs = torch.utils.checkpoint.checkpoint(
691
+ create_custom_forward(decoder_layer),
692
  hidden_states,
693
+ attention_mask,
694
  head_mask[idx] if head_mask is not None else None,
695
  None,
 
 
696
  )
697
  else:
698
+
699
  layer_outputs = decoder_layer(
700
  hidden_states,
701
+ attention_mask=attention_mask,
702
  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
703
  past_key_value=past_key_value,
704
  output_attentions=output_attentions,
 
713
  if output_attentions:
714
  all_self_attns += (layer_outputs[1],)
715
 
716
+ # Model Parallel: If it's the last layer for that device, put things on the next device
717
+ if self.model_parallel:
718
+ for k, v in self.device_map.items():
719
+ if idx == v[-1] and "cuda:" + str(k) != self.last_device:
720
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
721
+
722
  if self.final_layer_norm is not None:
723
  hidden_states = self.final_layer_norm(hidden_states)
724
 
 
731
 
732
  next_cache = next_decoder_cache if use_cache else None
733
  if not return_dict:
734
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
735
  return BaseModelOutputWithPast(
736
  last_hidden_state=hidden_states,
737
  past_key_values=next_cache,
 
748
  def __init__(self, config: OPTConfig):
749
  super().__init__(config)
750
  self.decoder = OPTDecoder(config)
751
+
752
+ # Model parallel
753
+ self.decoder.model_parallel = False
754
+ self.decoder.device_map = None
755
+ self.decoder.gradient_checkpointing = False
756
+
757
  # Initialize weights and apply final processing
758
  self.post_init()
759
 
760
+ def parallelize(self, device_map=None):
761
+ # Check validity of device_map
762
+ self.decoder.device_map = (
763
+ get_device_map(len(self.decoder.layers), range(torch.cuda.device_count())) if device_map is None else device_map
764
+ )
765
+ assert_device_map(self.decoder.device_map, len(self.decoder.layers))
766
+ self.decoder.model_parallel = True
767
+ self.decoder.first_device = "cpu" if "cpu" in self.decoder.device_map.keys() else "cuda:" + str(min(self.decoder.device_map.keys()))
768
+ self.decoder.last_device = "cuda:" + str(max(self.decoder.device_map.keys()))
769
+ self.decoder.embed_tokens = self.decoder.embed_tokens.to(self.decoder.first_device)
770
+ self.decoder.embed_positions = self.decoder.embed_positions.to(self.decoder.first_device)
771
+ # Load onto devices
772
+ for k, v in self.decoder.device_map.items():
773
+ for block in v:
774
+ cuda_device = "cuda:" + str(k)
775
+ self.decoder.layers[block] = self.decoder.layers[block].to(cuda_device)
776
+ # final_layer_norm to last
777
+ self.decoder.final_layer_norm = self.decoder.final_layer_norm.to(self.decoder.last_device)
778
+
779
+ def deparallelize(self):
780
+ self.decoder.model_parallel = False
781
+ self.decoder.device_map = None
782
+ self.decoder.first_device = "cpu"
783
+ self.decoder.last_device = "cpu"
784
+ self.decoder.embed_tokens = self.decoder.embed_tokens.to("cpu")
785
+ self.decoder.embed_positions = self.decoder.embed_positions.to("cpu")
786
+ for index in range(len(self.decoder)):
787
+ self.decoder.layers[index] = self.decoder.layers[index].to("cpu")
788
+ self.decoder.final_layer_norm = self.decoder.final_layer_norm.to("cpu")
789
+ torch.cuda.empty_cache()
790
+
791
  def get_input_embeddings(self):
792
  return self.decoder.embed_tokens
793
 
 
799
 
800
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
801
  @add_code_sample_docstrings(
802
+ processor_class=_TOKENIZER_FOR_DOC,
803
  checkpoint=_CHECKPOINT_FOR_DOC,
804
  output_type=BaseModelOutputWithPast,
805
  config_class=_CONFIG_FOR_DOC,
 
817
  output_hidden_states: Optional[bool] = None,
818
  return_dict: Optional[bool] = None,
819
  ) -> Union[Tuple, BaseModelOutputWithPast]:
820
+
821
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
822
  output_hidden_states = (
823
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
824
  )
825
  use_cache = use_cache if use_cache is not None else self.config.use_cache
826
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
827
 
828
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
829
  decoder_outputs = self.decoder(
 
850
 
851
 
852
  class OPTForCausalLM(OPTPreTrainedModel):
853
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
854
 
855
  def __init__(self, config):
856
  super().__init__(config)
857
  self.model = OPTModel(config)
858
 
859
  # the lm_head weight is automatically tied to the embed tokens weight
860
+ self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
861
+
862
+ # Model parallel
863
+ self.model_parallel = False
864
+ self.device_map = None
865
 
866
  # Initialize weights and apply final processing
867
  self.post_init()
868
 
869
+ def parallelize(self, device_map=None):
870
+ self.model.decoder.device_map = (
871
+ get_device_map(len(self.model.decoder.layers), range(torch.cuda.device_count()))
872
+ if device_map is None
873
+ else device_map
874
+ )
875
+ assert_device_map(self.model.decoder.device_map, len(self.model.decoder.layers))
876
+ self.model.parallelize(self.model.decoder.device_map)
877
+ self.lm_head = self.lm_head.to(self.model.decoder.first_device)
878
+ self.model_parallel = True
879
+
880
+ def deparallelize(self):
881
+ self.model.deparallelize()
882
+ self.model = self.model.to("cpu")
883
+ self.lm_head = self.lm_head.to("cpu")
884
+ self.model_parallel = False
885
+ torch.cuda.empty_cache()
886
+
887
  def get_input_embeddings(self):
888
  return self.model.decoder.embed_tokens
889
 
 
902
  def get_decoder(self):
903
  return self.model.decoder
904
 
905
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
906
  def forward(
907
  self,
908
  input_ids: torch.LongTensor = None,
 
922
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
923
  provide it.
924
 
925
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
926
  [`PreTrainedTokenizer.__call__`] for details.
927
 
928
  [What are input IDs?](../glossary#input-ids)
 
976
  Example:
977
 
978
  ```python
979
+ >>> from transformers import GPT2Tokenizer, OPTForCausalLM
980
 
981
  >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
982
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
983
 
984
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
985
  >>> inputs = tokenizer(prompt, return_tensors="pt")
986
 
987
  >>> # Generate
988
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
989
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
990
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
991
  ```"""
992
 
993
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
994
  output_hidden_states = (
995
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
996
  )
997
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
998
 
999
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1000
  outputs = self.model.decoder(
 
1009
  return_dict=return_dict,
1010
  )
1011
 
1012
+ # Set device for model parallelism
1013
+ if self.model.decoder.model_parallel:
1014
+ torch.cuda.set_device(self.model.decoder.first_device)
1015
+
1016
+ logits = self.lm_head(outputs[0].to(self.lm_head.weight.device)).contiguous()
1017
 
1018
  loss = None
1019
  if labels is not None:
 
1024
  shift_labels = labels[..., 1:].contiguous()
1025
  # Flatten the tokens
1026
  loss_fct = CrossEntropyLoss()
1027
+ loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
 
 
1028
 
1029
  if not return_dict:
1030
  output = (logits,) + outputs[1:]
 
1038
  attentions=outputs.attentions,
1039
  )
1040
 
1041
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
1042
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1043
+ if attention_mask is None:
1044
+ attention_mask = input_ids.new_ones(input_ids.shape)
1045
+
1046
+ if past:
1047
+ input_ids = input_ids[:, -1:]
1048
+ # first step, decoder_cached_states are empty
1049
+ return {
1050
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
1051
+ "attention_mask": attention_mask,
1052
+ "past_key_values": past,
1053
+ "use_cache": use_cache,
1054
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1055
 
1056
  @staticmethod
1057
+ def _reorder_cache(past, beam_idx):
1058
  reordered_past = ()
1059
+ for layer_past in past:
1060
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
 
 
 
 
 
1061
  return reordered_past
1062
 
1063
 
 
1077
  OPT_START_DOCSTRING,
1078
  )
1079
  class OPTForSequenceClassification(OPTPreTrainedModel):
1080
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1081
+
1082
  def __init__(self, config: OPTConfig):
1083
  super().__init__(config)
1084
  self.num_labels = config.num_labels
 
1090
 
1091
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1092
  @add_code_sample_docstrings(
1093
+ processor_class=_TOKENIZER_FOR_DOC,
1094
  checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
1095
  output_type=SequenceClassifierOutputWithPast,
1096
  config_class=_CONFIG_FOR_DOC,
 
1116
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1117
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1118
  """
1119
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1120
 
1121
  transformer_outputs = self.model(
1122
  input_ids,
 
1141
  sequence_lengths = -1
1142
  else:
1143
  if input_ids is not None:
1144
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
 
 
 
 
 
1145
  else:
1146
  sequence_lengths = -1
1147
  logger.warning(
 
1149
  "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1150
  )
1151
 
1152
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1153
 
1154
  loss = None
1155
  if labels is not None:
1156
  if self.config.problem_type is None:
1157
  if self.num_labels == 1:
1158
  self.config.problem_type = "regression"
1159
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1160
  self.config.problem_type = "single_label_classification"
1161
  else:
1162
  self.config.problem_type = "multi_label_classification"
 
1169
  loss = loss_fct(pooled_logits, labels)
1170
  elif self.config.problem_type == "single_label_classification":
1171
  loss_fct = CrossEntropyLoss()
1172
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1173
  elif self.config.problem_type == "multi_label_classification":
1174
  loss_fct = BCEWithLogitsLoss()
1175
  loss = loss_fct(pooled_logits, labels)
 
1200
  OPT_START_DOCSTRING,
1201
  )
1202
  class OPTForQuestionAnswering(OPTPreTrainedModel):
1203
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
1204
+
1205
  def __init__(self, config: OPTConfig):
1206
  super().__init__(config)
1207
  self.model = OPTModel(config)
 
1211
  self.post_init()
1212
 
1213
  @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
1214
+ @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
 
 
1215
  def forward(
1216
  self,
1217
  input_ids: Optional[torch.LongTensor] = None,
 
1241
  Example:
1242
 
1243
  ```python
1244
+ >>> from transformers import GPT2Tokenizer, OPTForQuestionAnswering
1245
  >>> import torch
1246
 
1247
  >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
1248
+ >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
1249
 
1250
  >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
1251
  >>> # so the head will be randomly initialized, hence the predictions will be random
 
1260
  >>> answer_start_index = outputs.start_logits.argmax()
1261
  >>> answer_end_index = outputs.end_logits.argmax()
1262
 
1263
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
 
 
 
 
1264
  >>> predicted = tokenizer.decode(predict_answer_tokens)
1265
  >>> predicted
1266
+ ' Henson?'
1267
  ```"""
1268
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1269
 
1270
  transformer_outputs = self.model(
1271
  input_ids,