Safetensors
English
bert_hash
custom_code
davidmezzetti commited on
Commit
d19aa05
·
1 Parent(s): 5bcfeee

Add Transformers v5 support

Browse files
Files changed (1) hide show
  1. modeling_bert_hash.py +127 -214
modeling_bert_hash.py CHANGED
@@ -4,15 +4,18 @@ import torch
4
  from torch import nn
5
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
 
7
- from transformers.cache_utils import Cache
 
8
  from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertPreTrainedModel, BertOnlyMLMHead
9
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
10
  from transformers.modeling_outputs import (
11
  BaseModelOutputWithPoolingAndCrossAttentions,
12
  MaskedLMOutput,
13
  SequenceClassifierOutput,
14
  )
15
- from transformers.utils import auto_docstring, logging
 
 
 
16
 
17
  from .configuration_bert_hash import BertHashConfig
18
 
@@ -63,12 +66,9 @@ class BertHashEmbeddings(nn.Module):
63
  self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
64
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
65
 
66
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
67
- # any TensorFlow checkpoint file
68
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
69
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
70
  # position_ids (1, len position emb) is contiguous in memory and exported when serialized
71
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
72
  self.register_buffer(
73
  "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
74
  )
@@ -78,10 +78,10 @@ class BertHashEmbeddings(nn.Module):
78
 
79
  def forward(
80
  self,
81
- input_ids: Optional[torch.LongTensor] = None,
82
- token_type_ids: Optional[torch.LongTensor] = None,
83
- position_ids: Optional[torch.LongTensor] = None,
84
- inputs_embeds: Optional[torch.FloatTensor] = None,
85
  past_key_values_length: int = 0,
86
  ) -> torch.Tensor:
87
  if input_ids is not None:
@@ -89,30 +89,36 @@ class BertHashEmbeddings(nn.Module):
89
  else:
90
  input_shape = inputs_embeds.size()[:-1]
91
 
92
- seq_length = input_shape[1]
 
93
 
94
  if position_ids is None:
95
- position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
 
 
 
96
 
97
  # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
98
  # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
99
  # issue #5664
100
  if token_type_ids is None:
101
  if hasattr(self, "token_type_ids"):
102
- buffered_token_type_ids = self.token_type_ids[:, :seq_length]
103
- buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
104
- token_type_ids = buffered_token_type_ids_expanded
 
105
  else:
106
  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
107
 
108
  if inputs_embeds is None:
109
  inputs_embeds = self.word_embeddings(input_ids)
110
  token_type_embeddings = self.token_type_embeddings(token_type_ids)
111
-
112
  embeddings = inputs_embeds + token_type_embeddings
113
- if self.position_embedding_type == "absolute":
114
- position_embeddings = self.position_embeddings(position_ids)
115
- embeddings += position_embeddings
 
116
  embeddings = self.LayerNorm(embeddings)
117
  embeddings = self.dropout(embeddings)
118
  return embeddings
@@ -142,15 +148,13 @@ class BertHashModel(BertPreTrainedModel):
142
  """
143
  super().__init__(config)
144
  self.config = config
 
145
 
146
  self.embeddings = BertHashEmbeddings(config)
147
  self.encoder = BertEncoder(config)
148
 
149
  self.pooler = BertPooler(config) if add_pooling_layer else None
150
 
151
- self.attn_implementation = config._attn_implementation
152
- self.position_embedding_type = config.position_embedding_type
153
-
154
  # Initialize weights and apply final processing
155
  self.post_init()
156
 
@@ -158,73 +162,40 @@ class BertHashModel(BertPreTrainedModel):
158
  return self.embeddings.word_embeddings.embeddings
159
 
160
  def set_input_embeddings(self, value):
161
- self.embeddings.word_embeddings.embeddings = value
162
-
163
- def _prune_heads(self, heads_to_prune):
164
- """
165
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
166
- class PreTrainedModel
167
- """
168
- for layer, heads in heads_to_prune.items():
169
- self.encoder.layer[layer].attention.prune_heads(heads)
170
 
 
 
171
  @auto_docstring
172
  def forward(
173
  self,
174
- input_ids: Optional[torch.Tensor] = None,
175
- attention_mask: Optional[torch.Tensor] = None,
176
- token_type_ids: Optional[torch.Tensor] = None,
177
- position_ids: Optional[torch.Tensor] = None,
178
- head_mask: Optional[torch.Tensor] = None,
179
- inputs_embeds: Optional[torch.Tensor] = None,
180
- encoder_hidden_states: Optional[torch.Tensor] = None,
181
- encoder_attention_mask: Optional[torch.Tensor] = None,
182
- past_key_values: Optional[list[torch.FloatTensor]] = None,
183
- use_cache: Optional[bool] = None,
184
- output_attentions: Optional[bool] = None,
185
- output_hidden_states: Optional[bool] = None,
186
- return_dict: Optional[bool] = None,
187
- cache_position: Optional[torch.Tensor] = None,
188
- ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
189
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
190
- output_hidden_states = (
191
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
192
- )
193
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
194
 
195
  if self.config.is_decoder:
196
  use_cache = use_cache if use_cache is not None else self.config.use_cache
197
  else:
198
  use_cache = False
199
 
200
- if input_ids is not None and inputs_embeds is not None:
201
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
202
- elif input_ids is not None:
203
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
204
- input_shape = input_ids.size()
205
- elif inputs_embeds is not None:
206
- input_shape = inputs_embeds.size()[:-1]
207
- else:
208
- raise ValueError("You have to specify either input_ids or inputs_embeds")
209
-
210
- batch_size, seq_length = input_shape
211
- device = input_ids.device if input_ids is not None else inputs_embeds.device
212
-
213
- past_key_values_length = 0
214
- if past_key_values is not None:
215
- past_key_values_length = (
216
- past_key_values[0][0].shape[-2]
217
- if not isinstance(past_key_values, Cache)
218
- else past_key_values.get_seq_length()
219
  )
220
 
221
- if token_type_ids is None:
222
- if hasattr(self.embeddings, "token_type_ids"):
223
- buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
224
- buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
225
- token_type_ids = buffered_token_type_ids_expanded
226
- else:
227
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
228
 
229
  embedding_output = self.embeddings(
230
  input_ids=input_ids,
@@ -234,94 +205,72 @@ class BertHashModel(BertPreTrainedModel):
234
  past_key_values_length=past_key_values_length,
235
  )
236
 
237
- if attention_mask is None:
238
- attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
239
-
240
- use_sdpa_attention_masks = (
241
- self.attn_implementation == "sdpa"
242
- and self.position_embedding_type == "absolute"
243
- and head_mask is None
244
- and not output_attentions
245
  )
246
 
247
- # Expand the attention mask
248
- if use_sdpa_attention_masks and attention_mask.dim() == 2:
249
- # Expand the attention mask for SDPA.
250
- # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
251
- if self.config.is_decoder:
252
- extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
253
- attention_mask,
254
- input_shape,
255
- embedding_output,
256
- past_key_values_length,
257
- )
258
- else:
259
- extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
260
- attention_mask, embedding_output.dtype, tgt_len=seq_length
261
- )
262
- else:
263
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
264
- # ourselves in which case we just need to make it broadcastable to all heads.
265
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
266
-
267
- # If a 2D or 3D attention mask is provided for the cross-attention
268
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
269
- if self.config.is_decoder and encoder_hidden_states is not None:
270
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
271
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
272
- if encoder_attention_mask is None:
273
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
274
-
275
- if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
276
- # Expand the attention mask for SDPA.
277
- # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
278
- encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
279
- encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
280
- )
281
- else:
282
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
283
- else:
284
- encoder_extended_attention_mask = None
285
-
286
- # Prepare head mask if needed
287
- # 1.0 in head_mask indicate we keep the head
288
- # attention_probs has shape bsz x n_heads x N x N
289
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
290
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
291
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
292
-
293
  encoder_outputs = self.encoder(
294
  embedding_output,
295
- attention_mask=extended_attention_mask,
296
- head_mask=head_mask,
297
  encoder_hidden_states=encoder_hidden_states,
298
- encoder_attention_mask=encoder_extended_attention_mask,
299
  past_key_values=past_key_values,
300
  use_cache=use_cache,
301
- output_attentions=output_attentions,
302
- output_hidden_states=output_hidden_states,
303
- return_dict=return_dict,
304
- cache_position=cache_position,
305
  )
306
- sequence_output = encoder_outputs[0]
307
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
308
 
309
- if not return_dict:
310
- return (sequence_output, pooled_output) + encoder_outputs[1:]
311
-
312
  return BaseModelOutputWithPoolingAndCrossAttentions(
313
  last_hidden_state=sequence_output,
314
  pooler_output=pooled_output,
315
  past_key_values=encoder_outputs.past_key_values,
316
- hidden_states=encoder_outputs.hidden_states,
317
- attentions=encoder_outputs.attentions,
318
- cross_attentions=encoder_outputs.cross_attentions,
319
  )
320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  @auto_docstring
323
- class BertHashForMaskedLM(BertPreTrainedModel):
324
- _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
 
 
325
  config_class = BertHashConfig
326
 
327
  def __init__(self, config):
@@ -339,43 +288,43 @@ class BertHashForMaskedLM(BertPreTrainedModel):
339
  # Initialize weights and apply final processing
340
  self.post_init()
341
 
 
 
 
 
 
 
 
 
342
  @auto_docstring
343
  def forward(
344
  self,
345
- input_ids: Optional[torch.Tensor] = None,
346
- attention_mask: Optional[torch.Tensor] = None,
347
- token_type_ids: Optional[torch.Tensor] = None,
348
- position_ids: Optional[torch.Tensor] = None,
349
- head_mask: Optional[torch.Tensor] = None,
350
- inputs_embeds: Optional[torch.Tensor] = None,
351
- encoder_hidden_states: Optional[torch.Tensor] = None,
352
- encoder_attention_mask: Optional[torch.Tensor] = None,
353
- labels: Optional[torch.Tensor] = None,
354
- output_attentions: Optional[bool] = None,
355
- output_hidden_states: Optional[bool] = None,
356
- return_dict: Optional[bool] = None,
357
- ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
358
  r"""
359
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
360
  Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
361
  config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
362
  loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
363
  """
364
-
365
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
-
367
  outputs = self.bert(
368
  input_ids,
369
  attention_mask=attention_mask,
370
  token_type_ids=token_type_ids,
371
  position_ids=position_ids,
372
- head_mask=head_mask,
373
  inputs_embeds=inputs_embeds,
374
  encoder_hidden_states=encoder_hidden_states,
375
  encoder_attention_mask=encoder_attention_mask,
376
- output_attentions=output_attentions,
377
- output_hidden_states=output_hidden_states,
378
- return_dict=return_dict,
379
  )
380
 
381
  sequence_output = outputs[0]
@@ -386,10 +335,6 @@ class BertHashForMaskedLM(BertPreTrainedModel):
386
  loss_fct = CrossEntropyLoss() # -100 index = padding token
387
  masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
388
 
389
- if not return_dict:
390
- output = (prediction_scores,) + outputs[2:]
391
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
392
-
393
  return MaskedLMOutput(
394
  loss=masked_lm_loss,
395
  logits=prediction_scores,
@@ -397,29 +342,6 @@ class BertHashForMaskedLM(BertPreTrainedModel):
397
  attentions=outputs.attentions,
398
  )
399
 
400
- def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
401
- input_shape = input_ids.shape
402
- effective_batch_size = input_shape[0]
403
-
404
- # add a dummy token
405
- if self.config.pad_token_id is None:
406
- raise ValueError("The PAD token should be defined for generation")
407
-
408
- attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
409
- dummy_token = torch.full(
410
- (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
411
- )
412
- input_ids = torch.cat([input_ids, dummy_token], dim=1)
413
-
414
- return {"input_ids": input_ids, "attention_mask": attention_mask}
415
-
416
- @classmethod
417
- def can_generate(cls) -> bool:
418
- """
419
- Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
420
- `prepare_inputs_for_generation` method.
421
- """
422
- return False
423
 
424
 
425
  @auto_docstring(
@@ -446,38 +368,32 @@ class BertHashForSequenceClassification(BertPreTrainedModel):
446
  # Initialize weights and apply final processing
447
  self.post_init()
448
 
 
449
  @auto_docstring
450
  def forward(
451
  self,
452
- input_ids: Optional[torch.Tensor] = None,
453
- attention_mask: Optional[torch.Tensor] = None,
454
- token_type_ids: Optional[torch.Tensor] = None,
455
- position_ids: Optional[torch.Tensor] = None,
456
- head_mask: Optional[torch.Tensor] = None,
457
- inputs_embeds: Optional[torch.Tensor] = None,
458
- labels: Optional[torch.Tensor] = None,
459
- output_attentions: Optional[bool] = None,
460
- output_hidden_states: Optional[bool] = None,
461
- return_dict: Optional[bool] = None,
462
- ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
463
  r"""
464
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
465
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
466
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
467
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
468
  """
469
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
470
-
471
  outputs = self.bert(
472
  input_ids,
473
  attention_mask=attention_mask,
474
  token_type_ids=token_type_ids,
475
  position_ids=position_ids,
476
- head_mask=head_mask,
477
  inputs_embeds=inputs_embeds,
478
- output_attentions=output_attentions,
479
- output_hidden_states=output_hidden_states,
480
- return_dict=return_dict,
481
  )
482
 
483
  pooled_output = outputs[1]
@@ -507,9 +423,6 @@ class BertHashForSequenceClassification(BertPreTrainedModel):
507
  elif self.config.problem_type == "multi_label_classification":
508
  loss_fct = BCEWithLogitsLoss()
509
  loss = loss_fct(logits, labels)
510
- if not return_dict:
511
- output = (logits,) + outputs[2:]
512
- return ((loss,) + output) if loss is not None else output
513
 
514
  return SequenceClassifierOutput(
515
  loss=loss,
 
4
  from torch import nn
5
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
 
7
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
8
+ from transformers.masking_utils import create_bidirectional_mask, create_causal_mask
9
  from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertPreTrainedModel, BertOnlyMLMHead
 
10
  from transformers.modeling_outputs import (
11
  BaseModelOutputWithPoolingAndCrossAttentions,
12
  MaskedLMOutput,
13
  SequenceClassifierOutput,
14
  )
15
+ from transformers.processing_utils import Unpack
16
+ from transformers.utils import TransformersKwargs, auto_docstring, logging
17
+ from transformers.utils.generic import can_return_tuple, merge_with_config_defaults
18
+ from transformers.utils.output_capturing import capture_outputs
19
 
20
  from .configuration_bert_hash import BertHashConfig
21
 
 
66
  self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
67
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
68
 
 
 
69
  self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
70
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
71
  # position_ids (1, len position emb) is contiguous in memory and exported when serialized
 
72
  self.register_buffer(
73
  "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
74
  )
 
78
 
79
  def forward(
80
  self,
81
+ input_ids: torch.LongTensor | None = None,
82
+ token_type_ids: torch.LongTensor | None = None,
83
+ position_ids: torch.LongTensor | None = None,
84
+ inputs_embeds: torch.FloatTensor | None = None,
85
  past_key_values_length: int = 0,
86
  ) -> torch.Tensor:
87
  if input_ids is not None:
 
89
  else:
90
  input_shape = inputs_embeds.size()[:-1]
91
 
92
+ batch_size, seq_length = input_shape
93
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
94
 
95
  if position_ids is None:
96
+ position_ids = (
97
+ torch.arange(seq_length, dtype=torch.long, device=device)
98
+ .unsqueeze(0)
99
+ .expand(batch_size, seq_length)
100
+ )
101
 
102
  # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
103
  # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
104
  # issue #5664
105
  if token_type_ids is None:
106
  if hasattr(self, "token_type_ids"):
107
+ # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
108
+ buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
109
+ buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
110
+ token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
111
  else:
112
  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
113
 
114
  if inputs_embeds is None:
115
  inputs_embeds = self.word_embeddings(input_ids)
116
  token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
117
  embeddings = inputs_embeds + token_type_embeddings
118
+
119
+ position_embeddings = self.position_embeddings(position_ids)
120
+ embeddings = embeddings + position_embeddings
121
+
122
  embeddings = self.LayerNorm(embeddings)
123
  embeddings = self.dropout(embeddings)
124
  return embeddings
 
148
  """
149
  super().__init__(config)
150
  self.config = config
151
+ self.gradient_checkpointing = False
152
 
153
  self.embeddings = BertHashEmbeddings(config)
154
  self.encoder = BertEncoder(config)
155
 
156
  self.pooler = BertPooler(config) if add_pooling_layer else None
157
 
 
 
 
158
  # Initialize weights and apply final processing
159
  self.post_init()
160
 
 
162
  return self.embeddings.word_embeddings.embeddings
163
 
164
  def set_input_embeddings(self, value):
165
+ self.embeddings.word_embeddings = value
 
 
 
 
 
 
 
 
166
 
167
+ @merge_with_config_defaults
168
+ @capture_outputs
169
  @auto_docstring
170
  def forward(
171
  self,
172
+ input_ids: torch.Tensor | None = None,
173
+ attention_mask: torch.Tensor | None = None,
174
+ token_type_ids: torch.Tensor | None = None,
175
+ position_ids: torch.Tensor | None = None,
176
+ inputs_embeds: torch.Tensor | None = None,
177
+ encoder_hidden_states: torch.Tensor | None = None,
178
+ encoder_attention_mask: torch.Tensor | None = None,
179
+ past_key_values: Cache | None = None,
180
+ use_cache: bool | None = None,
181
+ **kwargs: Unpack[TransformersKwargs],
182
+ ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
183
+ if (input_ids is None) ^ (inputs_embeds is not None):
184
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
 
 
 
 
 
 
185
 
186
  if self.config.is_decoder:
187
  use_cache = use_cache if use_cache is not None else self.config.use_cache
188
  else:
189
  use_cache = False
190
 
191
+ if use_cache and past_key_values is None:
192
+ past_key_values = (
193
+ EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
194
+ if encoder_hidden_states is not None or self.config.is_encoder_decoder
195
+ else DynamicCache(config=self.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  )
197
 
198
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
 
 
 
 
 
 
199
 
200
  embedding_output = self.embeddings(
201
  input_ids=input_ids,
 
205
  past_key_values_length=past_key_values_length,
206
  )
207
 
208
+ attention_mask, encoder_attention_mask = self._create_attention_masks(
209
+ attention_mask=attention_mask,
210
+ encoder_attention_mask=encoder_attention_mask,
211
+ embedding_output=embedding_output,
212
+ encoder_hidden_states=encoder_hidden_states,
213
+ past_key_values=past_key_values,
 
 
214
  )
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  encoder_outputs = self.encoder(
217
  embedding_output,
218
+ attention_mask=attention_mask,
 
219
  encoder_hidden_states=encoder_hidden_states,
220
+ encoder_attention_mask=encoder_attention_mask,
221
  past_key_values=past_key_values,
222
  use_cache=use_cache,
223
+ position_ids=position_ids,
224
+ **kwargs,
 
 
225
  )
226
+ sequence_output = encoder_outputs.last_hidden_state
227
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
228
 
 
 
 
229
  return BaseModelOutputWithPoolingAndCrossAttentions(
230
  last_hidden_state=sequence_output,
231
  pooler_output=pooled_output,
232
  past_key_values=encoder_outputs.past_key_values,
 
 
 
233
  )
234
 
235
+ def _create_attention_masks(
236
+ self,
237
+ attention_mask,
238
+ encoder_attention_mask,
239
+ embedding_output,
240
+ encoder_hidden_states,
241
+ past_key_values,
242
+ ):
243
+ if self.config.is_decoder:
244
+ attention_mask = create_causal_mask(
245
+ config=self.config,
246
+ inputs_embeds=embedding_output,
247
+ attention_mask=attention_mask,
248
+ past_key_values=past_key_values,
249
+ )
250
+ else:
251
+ attention_mask = create_bidirectional_mask(
252
+ config=self.config,
253
+ inputs_embeds=embedding_output,
254
+ attention_mask=attention_mask,
255
+ )
256
+
257
+ if encoder_attention_mask is not None:
258
+ encoder_attention_mask = create_bidirectional_mask(
259
+ config=self.config,
260
+ inputs_embeds=embedding_output,
261
+ attention_mask=encoder_attention_mask,
262
+ encoder_hidden_states=encoder_hidden_states,
263
+ )
264
+
265
+ return attention_mask, encoder_attention_mask
266
+
267
 
268
  @auto_docstring
269
+ class BertForMaskedLM(BertPreTrainedModel):
270
+ _tied_weights_keys = {
271
+ "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
272
+ "cls.predictions.decoder.bias": "cls.predictions.bias",
273
+ }
274
  config_class = BertHashConfig
275
 
276
  def __init__(self, config):
 
288
  # Initialize weights and apply final processing
289
  self.post_init()
290
 
291
+ def get_output_embeddings(self):
292
+ return self.cls.predictions.decoder
293
+
294
+ def set_output_embeddings(self, new_embeddings):
295
+ self.cls.predictions.decoder = new_embeddings
296
+ self.cls.predictions.bias = new_embeddings.bias
297
+
298
+ @can_return_tuple
299
  @auto_docstring
300
  def forward(
301
  self,
302
+ input_ids: torch.Tensor | None = None,
303
+ attention_mask: torch.Tensor | None = None,
304
+ token_type_ids: torch.Tensor | None = None,
305
+ position_ids: torch.Tensor | None = None,
306
+ inputs_embeds: torch.Tensor | None = None,
307
+ encoder_hidden_states: torch.Tensor | None = None,
308
+ encoder_attention_mask: torch.Tensor | None = None,
309
+ labels: torch.Tensor | None = None,
310
+ **kwargs: Unpack[TransformersKwargs],
311
+ ) -> tuple[torch.Tensor] | MaskedLMOutput:
 
 
 
312
  r"""
313
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
314
  Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
315
  config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
316
  loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
317
  """
 
 
 
318
  outputs = self.bert(
319
  input_ids,
320
  attention_mask=attention_mask,
321
  token_type_ids=token_type_ids,
322
  position_ids=position_ids,
 
323
  inputs_embeds=inputs_embeds,
324
  encoder_hidden_states=encoder_hidden_states,
325
  encoder_attention_mask=encoder_attention_mask,
326
+ return_dict=True,
327
+ **kwargs,
 
328
  )
329
 
330
  sequence_output = outputs[0]
 
335
  loss_fct = CrossEntropyLoss() # -100 index = padding token
336
  masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
337
 
 
 
 
 
338
  return MaskedLMOutput(
339
  loss=masked_lm_loss,
340
  logits=prediction_scores,
 
342
  attentions=outputs.attentions,
343
  )
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
 
347
  @auto_docstring(
 
368
  # Initialize weights and apply final processing
369
  self.post_init()
370
 
371
+ @can_return_tuple
372
  @auto_docstring
373
  def forward(
374
  self,
375
+ input_ids: torch.Tensor | None = None,
376
+ attention_mask: torch.Tensor | None = None,
377
+ token_type_ids: torch.Tensor | None = None,
378
+ position_ids: torch.Tensor | None = None,
379
+ inputs_embeds: torch.Tensor | None = None,
380
+ labels: torch.Tensor | None = None,
381
+ **kwargs: Unpack[TransformersKwargs],
382
+ ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
 
 
 
383
  r"""
384
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
385
  Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
386
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
387
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
388
  """
 
 
389
  outputs = self.bert(
390
  input_ids,
391
  attention_mask=attention_mask,
392
  token_type_ids=token_type_ids,
393
  position_ids=position_ids,
 
394
  inputs_embeds=inputs_embeds,
395
+ return_dict=True,
396
+ **kwargs,
 
397
  )
398
 
399
  pooled_output = outputs[1]
 
423
  elif self.config.problem_type == "multi_label_classification":
424
  loss_fct = BCEWithLogitsLoss()
425
  loss = loss_fct(logits, labels)
 
 
 
426
 
427
  return SequenceClassifierOutput(
428
  loss=loss,