YixuanWeng committed
Commit 4c80210
1 Parent(s): 996025f

Upload modeling_cmpt.py

Files changed (1)
  1. modeling_cmpt.py +1836 -0
modeling_cmpt.py ADDED
@@ -0,0 +1,1836 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch BART model."""
16
+ import copy
17
+ import math
18
+ import random
19
+ import warnings
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutput,
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ Seq2SeqLMOutput,
33
+ Seq2SeqModelOutput,
34
+ Seq2SeqQuestionAnsweringModelOutput,
35
+ Seq2SeqSequenceClassifierOutput,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_code_sample_docstrings,
40
+ add_end_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ replace_return_docstrings,
44
+ logging
45
+ )
46
+ from transformers import BartConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CHECKPOINT_FOR_DOC = "facebook/bart-base"
52
+ _CONFIG_FOR_DOC = "BartConfig"
53
+ _TOKENIZER_FOR_DOC = "BartTokenizer"
54
+
55
+ # Base model docstring
56
+ _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
57
+
58
+ # SequenceClassification docstring
59
+ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2"
60
+ _SEQ_CLASS_EXPECTED_LOSS = 0.0
61
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"
62
+
63
+ # QuestionAnswering docstring
64
+ _CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1"
65
+ _QA_EXPECTED_LOSS = 0.59
66
+ _QA_EXPECTED_OUTPUT = "' nice puppet'"
67
+
68
+
69
+ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
70
+ "facebook/bart-large",
71
+ # see all BART models at https://huggingface.co/models?filter=bart
72
+ ]
73
+
74
+ deepnet_gain = {
75
+ "encoder": {
76
+ "alpha": lambda config: 0.81
77
+ * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
78
+ "beta": lambda config: 0.87
79
+ * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
80
+ },
81
+ "decoder": {
82
+ "alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
83
+ "beta": lambda config: (12 * config.decoder_layers) ** -0.25,
84
+ },
85
+ }
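+ # These gain formulas appear to follow the DeepNet scaling rules for encoder-decoder
+ # Transformers (Wang et al., 2022): `alpha` scales the residual branch of each sub-layer
+ # (see `res_gain` in the layer classes below), while `beta` is the matching initialization
+ # gain prescribed by that scheme for the sub-layer weights.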
86
+
87
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
88
+ """
89
+ Shift input ids one token to the right.
90
+ """
91
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
92
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
93
+ shifted_input_ids[:, 0] = decoder_start_token_id
94
+
95
+ if pad_token_id is None:
96
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
97
+ # replace possible -100 values in labels by `pad_token_id`
98
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
99
+
100
+ return shifted_input_ids
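+ # Illustrative behaviour (assuming pad_token_id=1 and decoder_start_token_id=2):
+ # input_ids [[5, 6, 7]] becomes [[2, 5, 6]] -- every sequence is prefixed with the decoder
+ # start token, and any -100 label padding is replaced by pad_token_id.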
101
+
102
+
103
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
104
+ """
105
+ Make causal mask used for uni-directional (decoder) self-attention.
106
+ """
107
+ bsz, tgt_len = input_ids_shape
108
+ mask = torch.full((tgt_len, tgt_len), float("-inf"))
109
+ mask_cond = torch.arange(mask.size(-1))
110
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
111
+ mask = mask.to(dtype)
112
+
113
+ if past_key_values_length > 0:
114
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
115
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
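+ # For example, tgt_len=3 yields an additive mask with -inf strictly above the diagonal
+ # (future positions) and 0 elsewhere, expanded to (bsz, 1, 3, 3 + past_key_values_length).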
116
+
117
+
118
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
119
+ """
120
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
121
+ """
122
+ bsz, src_len = mask.size()
123
+ tgt_len = tgt_len if tgt_len is not None else src_len
124
+
125
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
126
+
127
+ inverted_mask = 1.0 - expanded_mask
128
+
129
+ return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
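+ # Positions with mask value 1 map to 0.0 (attend) and positions with value 0 map to the most
+ # negative representable value of `dtype`, so the result can be added directly to attention scores.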
130
+
131
+
132
+ class BartLearnedPositionalEmbedding(nn.Embedding):
133
+ """
134
+ This module learns positional embeddings up to a fixed maximum size.
135
+ """
136
+
137
+ def __init__(self, num_embeddings: int, embedding_dim: int):
138
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
139
+ # and adjust num_embeddings appropriately. Other models don't have this hack
140
+ self.offset = 2
141
+ super().__init__(num_embeddings + self.offset, embedding_dim)
142
+
143
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
144
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
145
+ bsz, seq_len = input_ids_shape[:2]
146
+ positions = torch.arange(
147
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
148
+ )
149
+ return super().forward(positions + self.offset)
150
+
151
+
152
+ class BartAttention(nn.Module):
153
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
154
+
155
+ def __init__(
156
+ self,
157
+ embed_dim: int,
158
+ num_heads: int,
159
+ dropout: float = 0.0,
160
+ is_decoder: bool = False,
161
+ bias: bool = True,
162
+ ):
163
+ super().__init__()
164
+ self.embed_dim = embed_dim
165
+ self.num_heads = num_heads
166
+ self.dropout = dropout
167
+ self.head_dim = embed_dim // num_heads
168
+
169
+ if (self.head_dim * num_heads) != self.embed_dim:
170
+ raise ValueError(
171
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
172
+ f" and `num_heads`: {num_heads})."
173
+ )
174
+ self.scaling = self.head_dim**-0.5
175
+ self.is_decoder = is_decoder
176
+
177
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
178
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
179
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
180
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
181
+
182
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
183
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
184
+
185
+ def forward(
186
+ self,
187
+ hidden_states: torch.Tensor,
188
+ key_value_states: Optional[torch.Tensor] = None,
189
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
190
+ attention_mask: Optional[torch.Tensor] = None,
191
+ layer_head_mask: Optional[torch.Tensor] = None,
192
+ output_attentions: bool = False,
193
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
194
+ """Input shape: Batch x Time x Channel"""
195
+
196
+ # if key_value_states are provided this layer is used as a cross-attention layer
197
+ # for the decoder
198
+ is_cross_attention = key_value_states is not None
199
+
200
+ bsz, tgt_len, _ = hidden_states.size()
201
+
202
+ # get query proj
203
+ query_states = self.q_proj(hidden_states) * self.scaling
204
+ # get key, value proj
205
+ if is_cross_attention and past_key_value is not None:
206
+ # reuse k,v, cross_attentions
207
+ key_states = past_key_value[0]
208
+ value_states = past_key_value[1]
209
+ elif is_cross_attention:
210
+ # cross_attentions
211
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
212
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
213
+ elif past_key_value is not None:
214
+ # reuse k, v, self_attention
215
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
216
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
217
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
218
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
219
+ else:
220
+ # self_attention
221
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
222
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
223
+
224
+ if self.is_decoder:
225
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
226
+ # Further calls to cross_attention layer can then reuse all cross-attention
227
+ # key/value_states (first "if" case)
228
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
229
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
230
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
231
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
232
+ past_key_value = (key_states, value_states)
233
+
234
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
235
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
236
+ key_states = key_states.view(*proj_shape)
237
+ value_states = value_states.view(*proj_shape)
238
+
239
+ src_len = key_states.size(1)
240
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
241
+
242
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
243
+ raise ValueError(
244
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
245
+ )
246
+
247
+ if attention_mask is not None:
248
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
249
+ raise ValueError(
250
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
251
+ )
252
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
253
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
254
+
255
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
256
+
257
+ if layer_head_mask is not None:
258
+ if layer_head_mask.size() != (self.num_heads,):
259
+ raise ValueError(
260
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
261
+ )
262
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
263
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
264
+
265
+ if output_attentions:
266
+ # this operation is a bit awkward, but it's required to
267
+ # make sure that attn_weights keeps its gradient.
268
+ # In order to do so, attn_weights have to be reshaped
269
+ # twice and have to be reused in the following
270
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
271
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
272
+ else:
273
+ attn_weights_reshaped = None
274
+
275
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
276
+
277
+ attn_output = torch.bmm(attn_probs, value_states)
278
+
279
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
280
+ raise ValueError(
281
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
282
+ )
283
+
284
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
285
+ attn_output = attn_output.transpose(1, 2)
286
+
287
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
288
+ # partitioned across GPUs when using tensor-parallelism.
289
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
290
+
291
+ attn_output = self.out_proj(attn_output)
292
+
293
+ return attn_output, attn_weights_reshaped, past_key_value
294
+
295
+
296
+ class BartEncoderLayer(nn.Module):
297
+ def __init__(self, config: BartConfig):
298
+ super().__init__()
299
+ self.embed_dim = config.d_model
300
+ self.self_attn = BartAttention(
301
+ embed_dim=self.embed_dim,
302
+ num_heads=config.encoder_attention_heads,
303
+ dropout=config.attention_dropout,
304
+ )
305
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
306
+ self.dropout = config.dropout
307
+ self.activation_fn = ACT2FN[config.activation_function]
308
+ self.activation_dropout = config.activation_dropout
309
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
310
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
311
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
312
+ self.res_gain = deepnet_gain["encoder"]["alpha"](config)
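+ # `res_gain` (the DeepNet-style alpha above) multiplies the residual stream before each
+ # post-attention / post-FFN LayerNorm in `forward`, i.e. x = LayerNorm(alpha * x + sublayer(x)).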
313
+
314
+ def forward(
315
+ self,
316
+ hidden_states: torch.FloatTensor,
317
+ attention_mask: torch.FloatTensor,
318
+ layer_head_mask: torch.FloatTensor,
319
+ output_attentions: Optional[bool] = False,
320
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
321
+ """
322
+ Args:
323
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
324
+ attention_mask (`torch.FloatTensor`): attention mask of size
325
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
326
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
327
+ `(encoder_attention_heads,)`.
328
+ output_attentions (`bool`, *optional*):
329
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
330
+ returned tensors for more detail.
331
+ """
332
+ residual = hidden_states
333
+ hidden_states, attn_weights, _ = self.self_attn(
334
+ hidden_states=hidden_states,
335
+ attention_mask=attention_mask,
336
+ layer_head_mask=layer_head_mask,
337
+ output_attentions=output_attentions,
338
+ )
339
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
340
+ hidden_states = residual * self.res_gain + hidden_states
341
+ hidden_states = self.self_attn_layer_norm(hidden_states)
342
+
343
+ residual = hidden_states
344
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
345
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
346
+ hidden_states = self.fc2(hidden_states)
347
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
348
+ hidden_states = residual * self.res_gain + hidden_states
349
+ hidden_states = self.final_layer_norm(hidden_states)
350
+
351
+ if hidden_states.dtype == torch.float16 and (
352
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
353
+ ):
354
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
355
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
356
+
357
+ outputs = (hidden_states,)
358
+
359
+ if output_attentions:
360
+ outputs += (attn_weights,)
361
+
362
+ return outputs
363
+
364
+
365
+ class BartDecoderLayer(nn.Module):
366
+ def __init__(self, config: BartConfig):
367
+ super().__init__()
368
+ self.embed_dim = config.d_model
369
+
370
+ self.self_attn = BartAttention(
371
+ embed_dim=self.embed_dim,
372
+ num_heads=config.decoder_attention_heads,
373
+ dropout=config.attention_dropout,
374
+ is_decoder=True,
375
+ )
376
+ self.dropout = config.dropout
377
+ self.activation_fn = ACT2FN[config.activation_function]
378
+ self.activation_dropout = config.activation_dropout
379
+
380
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
381
+ self.encoder_attn = BartAttention(
382
+ self.embed_dim,
383
+ config.decoder_attention_heads,
384
+ dropout=config.attention_dropout,
385
+ is_decoder=True,
386
+ )
387
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
388
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
389
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
390
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
391
+ self.res_gain = deepnet_gain["decoder"]["alpha"](config)
392
+
393
+ def forward(
394
+ self,
395
+ hidden_states: torch.Tensor,
396
+ attention_mask: Optional[torch.Tensor] = None,
397
+ encoder_hidden_states: Optional[torch.Tensor] = None,
398
+ encoder_attention_mask: Optional[torch.Tensor] = None,
399
+ layer_head_mask: Optional[torch.Tensor] = None,
400
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
401
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
402
+ output_attentions: Optional[bool] = False,
403
+ use_cache: Optional[bool] = True,
404
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
405
+ """
406
+ Args:
407
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
408
+ attention_mask (`torch.FloatTensor`): attention mask of size
409
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
410
+ encoder_hidden_states (`torch.FloatTensor`):
411
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
412
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
413
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
414
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
415
+ `(decoder_attention_heads,)`.
416
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
417
+ size `(decoder_attention_heads,)`.
418
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
419
+ output_attentions (`bool`, *optional*):
420
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
421
+ returned tensors for more detail.
422
+ """
423
+ residual = hidden_states
424
+
425
+ # Self Attention
426
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
427
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
428
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
429
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
430
+ hidden_states=hidden_states,
431
+ past_key_value=self_attn_past_key_value,
432
+ attention_mask=attention_mask,
433
+ layer_head_mask=layer_head_mask,
434
+ output_attentions=output_attentions,
435
+ )
436
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
437
+ hidden_states = residual * self.res_gain + hidden_states
438
+ hidden_states = self.self_attn_layer_norm(hidden_states)
439
+
440
+ # Cross-Attention Block
441
+ cross_attn_present_key_value = None
442
+ cross_attn_weights = None
443
+ if encoder_hidden_states is not None:
444
+ residual = hidden_states
445
+
446
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
447
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
448
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
449
+ hidden_states=hidden_states,
450
+ key_value_states=encoder_hidden_states,
451
+ attention_mask=encoder_attention_mask,
452
+ layer_head_mask=cross_attn_layer_head_mask,
453
+ past_key_value=cross_attn_past_key_value,
454
+ output_attentions=output_attentions,
455
+ )
456
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
457
+ hidden_states = residual * self.res_gain + hidden_states
458
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
459
+
460
+ # add cross-attn to positions 3,4 of present_key_value tuple
461
+ present_key_value = present_key_value + cross_attn_present_key_value
462
+
463
+ # Fully Connected
464
+ residual = hidden_states
465
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
466
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
467
+ hidden_states = self.fc2(hidden_states)
468
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
469
+ hidden_states = residual * self.res_gain + hidden_states
470
+ hidden_states = self.final_layer_norm(hidden_states)
471
+
472
+ outputs = (hidden_states,)
473
+
474
+ if output_attentions:
475
+ outputs += (self_attn_weights, cross_attn_weights)
476
+
477
+ if use_cache:
478
+ outputs += (present_key_value,)
479
+
480
+ return outputs
481
+
482
+
483
+ class BartClassificationHead(nn.Module):
484
+ """Head for sentence-level classification tasks."""
485
+
486
+ def __init__(
487
+ self,
488
+ input_dim: int,
489
+ inner_dim: int,
490
+ num_classes: int,
491
+ pooler_dropout: float,
492
+ ):
493
+ super().__init__()
494
+ self.dense = nn.Linear(input_dim, inner_dim)
495
+ self.dropout = nn.Dropout(p=pooler_dropout)
496
+ self.out_proj = nn.Linear(inner_dim, num_classes)
497
+
498
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
499
+ hidden_states = self.dropout(hidden_states)
500
+ hidden_states = self.dense(hidden_states)
501
+ hidden_states = torch.tanh(hidden_states)
502
+ hidden_states = self.dropout(hidden_states)
503
+ hidden_states = self.out_proj(hidden_states)
504
+ return hidden_states
505
+
506
+
507
+ class BartPretrainedModel(PreTrainedModel):
508
+ config_class = BartConfig
509
+ base_model_prefix = "model"
510
+ supports_gradient_checkpointing = True
511
+ _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
512
+
513
+ def _init_weights(self, module):
514
+ std = self.config.init_std
515
+ if isinstance(module, nn.Linear):
516
+ module.weight.data.normal_(mean=0.0, std=std)
517
+ if module.bias is not None:
518
+ module.bias.data.zero_()
519
+ elif isinstance(module, nn.Embedding):
520
+ module.weight.data.normal_(mean=0.0, std=std)
521
+ if module.padding_idx is not None:
522
+ module.weight.data[module.padding_idx].zero_()
523
+
524
+ def _set_gradient_checkpointing(self, module, value=False):
525
+ if isinstance(module, (BartDecoder, BartEncoder)):
526
+ module.gradient_checkpointing = value
527
+
528
+ @property
529
+ def dummy_inputs(self):
530
+ pad_token = self.config.pad_token_id
531
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
532
+ dummy_inputs = {
533
+ "attention_mask": input_ids.ne(pad_token),
534
+ "input_ids": input_ids,
535
+ }
536
+ return dummy_inputs
537
+
538
+
539
+ class PretrainedBartModel(BartPretrainedModel):
540
+ def __init_subclass__(self):
541
+ warnings.warn(
542
+ "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
543
+ FutureWarning,
544
+ )
545
+
546
+
547
+ BART_START_DOCSTRING = r"""
548
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
549
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
550
+ etc.)
551
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
552
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
553
+ and behavior.
554
+ Parameters:
555
+ config ([`BartConfig`]):
556
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
557
+ load the weights associated with the model, only the configuration. Check out the
558
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
559
+ """
560
+
561
+ BART_GENERATION_EXAMPLE = r"""
562
+ Summarization example:
563
+ ```python
564
+ >>> from transformers import BartTokenizer, BartForConditionalGeneration
565
+ >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
566
+ >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
567
+ >>> ARTICLE_TO_SUMMARIZE = (
568
+ ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
569
+ ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
570
+ ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
571
+ ... )
572
+ >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
573
+ >>> # Generate Summary
574
+ >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, max_length=20)
575
+ >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
576
+ 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
577
+ ```
578
+ Mask filling example:
579
+ ```python
580
+ >>> from transformers import BartTokenizer, BartForConditionalGeneration
581
+ >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
582
+ >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
583
+ >>> TXT = "My friends are <mask> but they eat too many carbs."
584
+ >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
585
+ >>> logits = model(input_ids).logits
586
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
587
+ >>> probs = logits[0, masked_index].softmax(dim=0)
588
+ >>> values, predictions = probs.topk(5)
589
+ >>> tokenizer.decode(predictions).split()
590
+ ['not', 'good', 'healthy', 'great', 'very']
591
+ ```
592
+ """
593
+
594
+ BART_INPUTS_DOCSTRING = r"""
595
+ Args:
596
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
597
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
598
+ it.
599
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
600
+ [`PreTrainedTokenizer.__call__`] for details.
601
+ [What are input IDs?](../glossary#input-ids)
602
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
603
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
604
+ - 1 for tokens that are **not masked**,
605
+ - 0 for tokens that are **masked**.
606
+ [What are attention masks?](../glossary#attention-mask)
607
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
608
+ Indices of decoder input sequence tokens in the vocabulary.
609
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
610
+ [`PreTrainedTokenizer.__call__`] for details.
611
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
612
+ Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
613
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
614
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
615
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
616
+ for denoising pre-training following the paper.
617
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
618
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
619
+ be used by default.
620
+ If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and
621
+ modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information
622
+ on the default strategy.
623
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
624
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
625
+ - 1 indicates the head is **not masked**,
626
+ - 0 indicates the head is **masked**.
627
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
628
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
629
+ - 1 indicates the head is **not masked**,
630
+ - 0 indicates the head is **masked**.
631
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
632
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
633
+ 1]`:
634
+ - 1 indicates the head is **not masked**,
635
+ - 0 indicates the head is **masked**.
636
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
637
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
638
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
639
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
640
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
641
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
642
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
643
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
644
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
645
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
646
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
647
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
648
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
649
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
650
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful
651
+ if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
652
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
653
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
654
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
655
+ input (see `past_key_values`). This is useful if you want more control over how to convert
656
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
657
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
658
+ of `inputs_embeds`.
659
+ use_cache (`bool`, *optional*):
660
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
661
+ `past_key_values`).
662
+ output_attentions (`bool`, *optional*):
663
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
664
+ tensors for more detail.
665
+ output_hidden_states (`bool`, *optional*):
666
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
667
+ more detail.
668
+ return_dict (`bool`, *optional*):
669
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
670
+ """
671
+
672
+
673
+ class BartEncoder(BartPretrainedModel):
674
+ """
675
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
676
+ [`BartEncoderLayer`].
677
+ Args:
678
+ config: BartConfig
679
+ embed_tokens (nn.Embedding): output embedding
680
+ """
681
+
682
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
683
+ super().__init__(config)
684
+
685
+ self.dropout = config.dropout
686
+ self.layerdrop = config.encoder_layerdrop
687
+
688
+ embed_dim = config.d_model
689
+ self.padding_idx = config.pad_token_id
690
+ self.max_source_positions = config.max_position_embeddings
691
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
692
+
693
+ if embed_tokens is not None:
694
+ self.embed_tokens = embed_tokens
695
+ else:
696
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
697
+
698
+ self.embed_positions = BartLearnedPositionalEmbedding(
699
+ config.max_position_embeddings,
700
+ embed_dim,
701
+ )
702
+ self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
703
+ self.layernorm_embedding = nn.LayerNorm(embed_dim)
704
+
705
+ self.gradient_checkpointing = False
706
+ # Initialize weights and apply final processing
707
+ self.post_init()
708
+
709
+ def get_input_embeddings(self):
710
+ return self.embed_tokens
711
+
712
+ def set_input_embeddings(self, value):
713
+ self.embed_tokens = value
714
+
715
+ def forward(
716
+ self,
717
+ input_ids: torch.LongTensor = None,
718
+ attention_mask: Optional[torch.Tensor] = None,
719
+ head_mask: Optional[torch.Tensor] = None,
720
+ inputs_embeds: Optional[torch.FloatTensor] = None,
721
+ output_attentions: Optional[bool] = None,
722
+ output_hidden_states: Optional[bool] = None,
723
+ return_dict: Optional[bool] = None,
724
+ ) -> Union[Tuple, BaseModelOutput]:
725
+ r"""
726
+ Args:
727
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
728
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
729
+ provide it.
730
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
731
+ [`PreTrainedTokenizer.__call__`] for details.
732
+ [What are input IDs?](../glossary#input-ids)
733
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
734
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
735
+ - 1 for tokens that are **not masked**,
736
+ - 0 for tokens that are **masked**.
737
+ [What are attention masks?](../glossary#attention-mask)
738
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
739
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
740
+ - 1 indicates the head is **not masked**,
741
+ - 0 indicates the head is **masked**.
742
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
743
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
744
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
745
+ than the model's internal embedding lookup matrix.
746
+ output_attentions (`bool`, *optional*):
747
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
748
+ returned tensors for more detail.
749
+ output_hidden_states (`bool`, *optional*):
750
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
751
+ for more detail.
752
+ return_dict (`bool`, *optional*):
753
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
754
+ """
755
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
756
+ output_hidden_states = (
757
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
758
+ )
759
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
760
+
761
+ # retrieve input_ids and inputs_embeds
762
+ if input_ids is not None and inputs_embeds is not None:
763
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
764
+ elif input_ids is not None:
765
+ input_shape = input_ids.size()
766
+ input_ids = input_ids.view(-1, input_shape[-1])
767
+ elif inputs_embeds is not None:
768
+ input_shape = inputs_embeds.size()[:-1]
769
+ else:
770
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
771
+
772
+ if inputs_embeds is None:
773
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
774
+
775
+ embed_pos = self.embed_positions(input_shape)
776
+
777
+ hidden_states = inputs_embeds + embed_pos
778
+ hidden_states = self.layernorm_embedding(hidden_states)
779
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
780
+
781
+ # expand attention_mask
782
+ if attention_mask is not None:
783
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
784
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
785
+
786
+ encoder_states = () if output_hidden_states else None
787
+ all_attentions = () if output_attentions else None
788
+
789
+ # check if head_mask has a correct number of layers specified if desired
790
+ if head_mask is not None:
791
+ if head_mask.size()[0] != (len(self.layers)):
792
+ raise ValueError(
793
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
794
+ )
795
+
796
+ for idx, encoder_layer in enumerate(self.layers):
797
+ if output_hidden_states:
798
+ encoder_states = encoder_states + (hidden_states,)
799
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
800
+ dropout_probability = random.uniform(0, 1)
801
+ if self.training and (dropout_probability < self.layerdrop): # skip the layer
802
+ layer_outputs = (None, None)
803
+ else:
804
+ if self.gradient_checkpointing and self.training:
805
+
806
+ def create_custom_forward(module):
807
+ def custom_forward(*inputs):
808
+ return module(*inputs, output_attentions)
809
+
810
+ return custom_forward
811
+
812
+ layer_outputs = torch.utils.checkpoint.checkpoint(
813
+ create_custom_forward(encoder_layer),
814
+ hidden_states,
815
+ attention_mask,
816
+ (head_mask[idx] if head_mask is not None else None),
817
+ )
818
+ else:
819
+ layer_outputs = encoder_layer(
820
+ hidden_states,
821
+ attention_mask,
822
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
823
+ output_attentions=output_attentions,
824
+ )
825
+
826
+ hidden_states = layer_outputs[0]
827
+
828
+ if output_attentions:
829
+ all_attentions = all_attentions + (layer_outputs[1],)
830
+
831
+ if output_hidden_states:
832
+ encoder_states = encoder_states + (hidden_states,)
833
+
834
+ if not return_dict:
835
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
836
+ return BaseModelOutput(
837
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
838
+ )
839
+
840
+
841
+ class BartDecoder(BartPretrainedModel):
842
+ """
843
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
844
+ Args:
845
+ config: BartConfig
846
+ embed_tokens (nn.Embedding): output embedding
847
+ """
848
+
849
+ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
850
+ super().__init__(config)
851
+ self.dropout = config.dropout
852
+ self.layerdrop = config.decoder_layerdrop
853
+ self.padding_idx = config.pad_token_id
854
+ self.max_target_positions = config.max_position_embeddings
855
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
856
+
857
+ if embed_tokens is not None:
858
+ self.embed_tokens = embed_tokens
859
+ else:
860
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
861
+
862
+ self.embed_positions = BartLearnedPositionalEmbedding(
863
+ config.max_position_embeddings,
864
+ config.d_model,
865
+ )
866
+ self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
867
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
868
+
869
+ self.gradient_checkpointing = False
870
+ # Initialize weights and apply final processing
871
+ self.post_init()
872
+
873
+ def get_input_embeddings(self):
874
+ return self.embed_tokens
875
+
876
+ def set_input_embeddings(self, value):
877
+ self.embed_tokens = value
878
+
879
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
880
+ # create causal mask
881
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
882
+ combined_attention_mask = None
883
+ if input_shape[-1] > 1:
884
+ combined_attention_mask = _make_causal_mask(
885
+ input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
886
+ ).to(self.device)
887
+
888
+ if attention_mask is not None:
889
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
890
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
891
+ combined_attention_mask = (
892
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
893
+ )
894
+
895
+ return combined_attention_mask
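+ # The causal mask and the expanded padding mask are both additive (-inf at disallowed positions),
+ # so when both are present they are simply summed before being added to the attention scores.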
896
+
897
+ def forward(
898
+ self,
899
+ input_ids: torch.LongTensor = None,
900
+ attention_mask: Optional[torch.Tensor] = None,
901
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
902
+ encoder_attention_mask: Optional[torch.LongTensor] = None,
903
+ head_mask: Optional[torch.Tensor] = None,
904
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
905
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
906
+ inputs_embeds: Optional[torch.FloatTensor] = None,
907
+ use_cache: Optional[bool] = None,
908
+ output_attentions: Optional[bool] = None,
909
+ output_hidden_states: Optional[bool] = None,
910
+ return_dict: Optional[bool] = None,
911
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
912
+ r"""
913
+ Args:
914
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
915
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
916
+ provide it.
917
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
918
+ [`PreTrainedTokenizer.__call__`] for details.
919
+ [What are input IDs?](../glossary#input-ids)
920
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
921
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
922
+ - 1 for tokens that are **not masked**,
923
+ - 0 for tokens that are **masked**.
924
+ [What are attention masks?](../glossary#attention-mask)
925
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
926
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
927
+ of the decoder.
928
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
929
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
930
+ selected in `[0, 1]`:
931
+ - 1 for tokens that are **not masked**,
932
+ - 0 for tokens that are **masked**.
933
+ [What are attention masks?](../glossary#attention-mask)
934
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
935
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
936
+ - 1 indicates the head is **not masked**,
937
+ - 0 indicates the head is **masked**.
938
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
939
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
940
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
941
+ - 1 indicates the head is **not masked**,
942
+ - 0 indicates the head is **masked**.
943
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
944
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
945
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
946
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
947
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
948
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
949
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
950
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
951
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
952
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
953
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
954
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
955
+ than the model's internal embedding lookup matrix.
956
+ output_attentions (`bool`, *optional*):
957
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
958
+ returned tensors for more detail.
959
+ output_hidden_states (`bool`, *optional*):
960
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
961
+ for more detail.
962
+ return_dict (`bool`, *optional*):
963
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
964
+ """
965
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
966
+ output_hidden_states = (
967
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
968
+ )
969
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
970
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
971
+
972
+ # retrieve input_ids and inputs_embeds
973
+ if input_ids is not None and inputs_embeds is not None:
974
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
975
+ elif input_ids is not None:
976
+ input_shape = input_ids.size()
977
+ input_ids = input_ids.view(-1, input_shape[-1])
978
+ elif inputs_embeds is not None:
979
+ input_shape = inputs_embeds.size()[:-1]
980
+ else:
981
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
982
+
983
+ # past_key_values_length
984
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
985
+
986
+ if inputs_embeds is None:
987
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
988
+
989
+ attention_mask = self._prepare_decoder_attention_mask(
990
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
991
+ )
992
+
993
+ # expand encoder attention mask
994
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
995
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
996
+ encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
997
+
998
+ # embed positions
999
+ positions = self.embed_positions(input_shape, past_key_values_length)
1000
+
1001
+ hidden_states = inputs_embeds + positions
1002
+ hidden_states = self.layernorm_embedding(hidden_states)
1003
+
1004
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1005
+
1006
+ # decoder layers
1007
+ all_hidden_states = () if output_hidden_states else None
1008
+ all_self_attns = () if output_attentions else None
1009
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1010
+ next_decoder_cache = () if use_cache else None
1011
+
1012
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1013
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1014
+ if attn_mask is not None:
1015
+ if attn_mask.size()[0] != (len(self.layers)):
1016
+ raise ValueError(
1017
+ "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
1018
+ )
1019
+
1020
+ for idx, decoder_layer in enumerate(self.layers):
1021
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1022
+ if output_hidden_states:
1023
+ all_hidden_states += (hidden_states,)
1024
+ dropout_probability = random.uniform(0, 1)
1025
+ if self.training and (dropout_probability < self.layerdrop):
1026
+ continue
1027
+
1028
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
1029
+
1030
+ if self.gradient_checkpointing and self.training:
1031
+
1032
+ if use_cache:
1033
+ logger.warning(
1034
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1035
+ )
1036
+ use_cache = False
1037
+
1038
+ def create_custom_forward(module):
1039
+ def custom_forward(*inputs):
1040
+ # None for past_key_value
1041
+ return module(*inputs, output_attentions, use_cache)
1042
+
1043
+ return custom_forward
1044
+
1045
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1046
+ create_custom_forward(decoder_layer),
1047
+ hidden_states,
1048
+ attention_mask,
1049
+ encoder_hidden_states,
1050
+ encoder_attention_mask,
1051
+ head_mask[idx] if head_mask is not None else None,
1052
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
1053
+ None,
1054
+ )
1055
+ else:
1056
+
1057
+ layer_outputs = decoder_layer(
1058
+ hidden_states,
1059
+ attention_mask=attention_mask,
1060
+ encoder_hidden_states=encoder_hidden_states,
1061
+ encoder_attention_mask=encoder_attention_mask,
1062
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1063
+ cross_attn_layer_head_mask=(
1064
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1065
+ ),
1066
+ past_key_value=past_key_value,
1067
+ output_attentions=output_attentions,
1068
+ use_cache=use_cache,
1069
+ )
1070
+ hidden_states = layer_outputs[0]
1071
+
1072
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+ @add_start_docstrings(
+ "The bare BART Model outputting raw hidden-states without any specific head on top.",
+ BART_START_DOCSTRING,
+ )
+ class BartModel(BartPretrainedModel):
+ def __init__(self, config: BartConfig):
+ super().__init__(config)
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+ self.encoder = BartEncoder(config, self.shared)
+ self.decoder = BartDecoder(config, self.shared)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Seq2SeqModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqModelOutput]:
+
+ # different from other models, Bart automatically creates decoder_input_ids from
+ # input_ids if no decoder_input_ids are provided
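+ # illustrative example: input_ids [[bos, t1, t2, eos]] become decoder_input_ids
+ # [[decoder_start_token_id, bos, t1, t2]] via shift_tokens_right below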
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ if input_ids is None:
+ raise ValueError(
+ "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+ "passed, `input_ids` cannot be `None`. Please pass either "
+ "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+ )
+
+ decoder_input_ids = shift_tokens_right(
+ input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
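+
+ # Minimal usage sketch (illustrative only; assumes a compatible tokenizer such as
+ # BartTokenizer and a checkpoint matching this architecture, e.g. "facebook/bart-base"):
+ # tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+ # model = BartModel.from_pretrained("facebook/bart-base")
+ # inputs = tokenizer("Hello world", return_tensors="pt")
+ # outputs = model(**inputs)
+ # outputs.last_hidden_state  # shape (batch_size, sequence_length, d_model)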
+
+
+ @add_start_docstrings(
+ "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
+ )
+ class BartForConditionalGeneration(BartPretrainedModel):
+ base_model_prefix = "model"
+ _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
+
+ def __init__(self, config: BartConfig):
+ super().__init__(config)
+ self.model = BartModel(config)
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
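+ # final_logits_bias is a non-trainable per-token bias added to the LM logits in
+ # forward(); it is registered as a buffer so it is saved with the checkpoint and
+ # resized together with the embeddings (see _resize_final_logits_bias below).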
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens)
+ self._resize_final_logits_bias(new_num_tokens)
+ return new_embeddings
+
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+ old_num_tokens = self.final_logits_bias.shape[-1]
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ @add_end_docstrings(BART_GENERATION_EXAMPLE)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ Returns:
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if use_cache:
+ logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+ use_cache = False
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs
+ ):
+ # cut decoder_input_ids if past is used
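+ # with cached past key/values the earlier positions are already encoded, so only
+ # the most recent decoder token needs to be passed on each generation step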
+ if past is not None:
+ decoder_input_ids = decoder_input_ids[:, -1:]
+
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
+ }
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+ )
+ return reordered_past
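+
+ # Minimal summarization sketch (illustrative only; assumes a BART tokenizer and a
+ # pretrained checkpoint compatible with this class, e.g. "facebook/bart-base";
+ # ARTICLE_TEXT is a placeholder string):
+ # tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+ # model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
+ # inputs = tokenizer(ARTICLE_TEXT, return_tensors="pt", truncation=True)
+ # summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=60)
+ # tokenizer.batch_decode(summary_ids, skip_special_tokens=True)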
+
+
+ @add_start_docstrings(
+ """
+ Bart model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
+ tasks.
+ """,
+ BART_START_DOCSTRING,
+ )
+ class BartForSequenceClassification(BartPretrainedModel):
+ def __init__(self, config: BartConfig, **kwargs):
+ super().__init__(config, **kwargs)
+ self.model = BartModel(config)
+ self.classification_head = BartClassificationHead(
+ config.d_model,
+ config.d_model,
+ config.num_labels,
+ config.classifier_dropout,
+ )
+ self.model._init_weights(self.classification_head.dense)
+ self.model._init_weights(self.classification_head.out_proj)
+
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
+ output_type=Seq2SeqSequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+ expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+ )
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ if input_ids is None and inputs_embeds is not None:
+ raise NotImplementedError(
+ f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+ )
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = outputs[0] # last hidden state
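+ # BART has no [CLS] token; the decoder hidden state at the final <eos> position of
+ # each sequence is used as the pooled sentence representation for classification.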
+
+ eos_mask = input_ids.eq(self.config.eos_token_id)
+
+ if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+ raise ValueError("All examples must have the same number of <eos> tokens.")
+ sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+ :, -1, :
+ ]
+ logits = self.classification_head(sentence_representation)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.config.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.config.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return Seq2SeqSequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+ @add_start_docstrings(
+ """
+ BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ BART_START_DOCSTRING,
+ )
+ class BartForQuestionAnswering(BartPretrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ config.num_labels = 2
+ self.num_labels = config.num_labels
+
+ self.model = BartModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.model._init_weights(self.qa_outputs)
+
+ @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ processor_class=_TOKENIZER_FOR_DOC,
+ checkpoint=_CHECKPOINT_FOR_QA,
+ output_type=Seq2SeqQuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_loss=_QA_EXPECTED_LOSS,
+ expected_output=_QA_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ input_ids: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if start_positions is not None and end_positions is not None:
+ use_cache = False
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ encoder_outputs=encoder_outputs,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
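+ # positions clamped to `ignored_index` (one past the last valid index) are skipped by
+ # CrossEntropyLoss via ignore_index, so out-of-sequence labels contribute no loss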
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (
+ start_logits,
+ end_logits,
+ ) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return Seq2SeqQuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+ class BartDecoderWrapper(BartPretrainedModel):
+ """
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+ used in combination with the [`EncoderDecoderModel`] framework.
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.decoder = BartDecoder(config)
+
+ def forward(self, *args, **kwargs):
+ return self.decoder(*args, **kwargs)
+
+
+ class BartForCausalLM(BartPretrainedModel):
+ def __init__(self, config):
+ config = copy.deepcopy(config)
+ config.is_decoder = True
+ config.is_encoder_decoder = False
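+ # the config is deep-copied and flipped to decoder-only mode so this class can act
+ # as a standalone causal LM without mutating the caller's config object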
+ super().__init__(config)
+ self.model = BartDecoderWrapper(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ if the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+ in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ Returns:
+ Example:
+ ```python
+ >>> from transformers import BartTokenizer, BartForCausalLM
+ >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+ >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+ >>> list(logits.shape) == expected_shape
+ True
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ head_mask=head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = self.lm_head(outputs[0])
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ if past:
+ input_ids = input_ids[:, -1:]
+ # on the first generation step `past` is empty, so the full `input_ids` are passed through
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+ @staticmethod
+ def _reorder_cache(past, beam_idx):
+ reordered_past = ()
+ for layer_past in past:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past