davda54 committed
Commit 8e47ceb
1 Parent(s): 7f719ea

Update modeling_ltgbert.py

Files changed (1)
  1. modeling_ltgbert.py +210 -56
modeling_ltgbert.py CHANGED
@@ -1,4 +1,20 @@
1
- from __future__ import absolute_import, division, print_function, unicode_literals
2
 
3
  import math
4
  from typing import List, Optional, Tuple, Union
@@ -6,10 +22,9 @@ from typing import List, Optional, Tuple, Union
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
- from torch import _softmax_backward_data as _softmax_backward_data
10
  from torch.utils import checkpoint
11
 
12
- from configuration_ltgbert import LTGBertConfig
13
  from transformers.modeling_utils import PreTrainedModel
14
  from transformers.activations import gelu_new
15
  from transformers.modeling_outputs import (
@@ -20,6 +35,36 @@ from transformers.modeling_outputs import (
20
  TokenClassifierOutput,
21
  BaseModelOutput
22
  )
23
 
24
 
25
  class Encoder(nn.Module):
@@ -130,8 +175,8 @@ class MaskedSoftmax(torch.autograd.Function):
130
  @staticmethod
131
  def backward(self, grad_output):
132
  output, = self.saved_tensors
133
- inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
134
- return inputGrad, None, None
135
 
136
 
137
  class Attention(nn.Module):
@@ -195,25 +240,21 @@ class Attention(nn.Module):
195
  hidden_states = self.pre_layer_norm(hidden_states)
196
 
197
  query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2) # shape: [T, B, D]
198
- key = key * self.scale
199
  value = self.in_proj_v(hidden_states) # shape: [T, B, D]
200
 
201
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
202
  key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
203
  value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
204
 
205
- attention_scores = torch.bmm(query, key.transpose(1, 2))
206
 
207
  pos = self.in_proj_qk(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
208
- pos = pos.view(-1, self.num_heads, 2*self.head_size)
209
- query_pos, key_pos = pos.chunk(2, dim=2)
210
- key_pos = key_pos * self.scale
211
-
212
  query = query.view(batch_size, self.num_heads, query_len, self.head_size)
213
- key = key.view(batch_size, self.num_heads, key_len, self.head_size)
214
 
215
- attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos)
216
- attention_p_c = torch.einsum("bhkd,qhd->bhqk", key, query_pos)
217
 
218
  position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
219
  attention_c_p = attention_c_p.gather(3, position_indices)
@@ -223,7 +264,7 @@ class Attention(nn.Module):
223
  attention_scores.add_(attention_c_p)
224
  attention_scores.add_(attention_p_c)
225
 
226
- return attention_scores, attention_c_p, attention_p_c, value
227
 
228
  def compute_output(self, attention_probs, value):
229
  attention_probs = self.dropout(attention_probs)
@@ -269,20 +310,65 @@ class Embedding(nn.Module):
269
  # HuggingFace wrappers
270
  #
271
 
272
- class LTGBertPreTrainedModel(PreTrainedModel):
273
- config_class = LTGBertConfig
274
- base_model_prefix = "LTG-BERT"
275
  supports_gradient_checkpointing = True
276
 
277
  def _set_gradient_checkpointing(self, module, value=False):
278
  if isinstance(module, Encoder):
279
  module.activation_checkpointing = value
280
 
281
- def _init_weights(self, module):
282
  pass # everything is already initialized
283
 
284
 
285
- class LTGBertModel(LTGBertPreTrainedModel):
286
  def __init__(self, config, add_mlm_layer=False):
287
  super().__init__(config)
288
  self.config = config
@@ -326,31 +412,40 @@ class LTGBertModel(LTGBertPreTrainedModel):
326
  ]
327
  return last_layer, contextualized_embeddings, attention_probs
328
 
 
329
  def forward(
330
  self,
331
  input_ids: Optional[torch.Tensor] = None,
332
  attention_mask: Optional[torch.Tensor] = None,
333
- token_type_ids: Optional[torch.Tensor] = None,
334
- position_ids: Optional[torch.Tensor] = None,
335
  output_hidden_states: Optional[bool] = None,
336
  output_attentions: Optional[bool] = None,
337
  return_dict: Optional[bool] = None,
338
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
339
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
340
 
341
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
342
 
343
  if not return_dict:
344
- return sequence_output, contextualized_embeddings, attention_probs
345
 
346
  return BaseModelOutput(
347
  last_hidden_state=sequence_output,
348
- hidden_states=contextualized_embeddings,
349
- attentions=attention_probs
350
  )
351
 
352
 
353
- class LTGBertForMaskedLM(LTGBertModel):
 
354
  _keys_to_ignore_on_load_unexpected = ["head"]
355
 
356
  def __init__(self, config):
@@ -362,36 +457,44 @@ class LTGBertForMaskedLM(LTGBertModel):
362
  def set_output_embeddings(self, new_embeddings):
363
  self.classifier.nonlinearity[-1].weight = new_embeddings
364
 
 
365
  def forward(
366
  self,
367
  input_ids: Optional[torch.Tensor] = None,
368
  attention_mask: Optional[torch.Tensor] = None,
369
- token_type_ids: Optional[torch.Tensor] = None,
370
- position_ids: Optional[torch.Tensor] = None,
371
  output_hidden_states: Optional[bool] = None,
372
  output_attentions: Optional[bool] = None,
373
  return_dict: Optional[bool] = None,
374
  labels: Optional[torch.LongTensor] = None,
375
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
376
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
377
 
378
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
379
  subword_prediction = self.classifier(sequence_output)
380
- subword_prediction[:, :, :106+1] = float("-inf")
381
 
382
  masked_lm_loss = None
383
  if labels is not None:
384
  masked_lm_loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten())
385
 
386
  if not return_dict:
387
- output = (subword_prediction, contextualized_embeddings, attention_probs)
388
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
389
 
390
  return MaskedLMOutput(
391
  loss=masked_lm_loss,
392
  logits=subword_prediction,
393
- hidden_states=contextualized_embeddings,
394
- attentions=attention_probs
395
  )
396
 
397
 
@@ -399,8 +502,7 @@ class Classifier(nn.Module):
399
  def __init__(self, config, num_labels: int):
400
  super().__init__()
401
 
402
- drop_out = getattr(config, "cls_dropout", None)
403
- drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
404
 
405
  self.nonlinearity = nn.Sequential(
406
  nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
@@ -424,7 +526,14 @@ class Classifier(nn.Module):
424
  return x
425
 
426
 
427
- class LTGBertForSequenceClassification(LTGBertModel):
428
  _keys_to_ignore_on_load_unexpected = ["classifier"]
429
  _keys_to_ignore_on_load_missing = ["head"]
430
 
@@ -434,17 +543,22 @@ class LTGBertForSequenceClassification(LTGBertModel):
434
  self.num_labels = config.num_labels
435
  self.head = Classifier(config, self.num_labels)
436
 
 
437
  def forward(
438
  self,
439
  input_ids: Optional[torch.Tensor] = None,
440
  attention_mask: Optional[torch.Tensor] = None,
441
- token_type_ids: Optional[torch.Tensor] = None,
442
- position_ids: Optional[torch.Tensor] = None,
443
  output_attentions: Optional[bool] = None,
444
  output_hidden_states: Optional[bool] = None,
445
  return_dict: Optional[bool] = None,
446
  labels: Optional[torch.LongTensor] = None,
447
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
448
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
449
 
450
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
@@ -474,18 +588,29 @@ class LTGBertForSequenceClassification(LTGBertModel):
474
  loss = loss_fct(logits, labels)
475
 
476
  if not return_dict:
477
- output = (logits, contextualized_embeddings, attention_probs)
478
  return ((loss,) + output) if loss is not None else output
479
 
480
  return SequenceClassifierOutput(
481
  loss=loss,
482
  logits=logits,
483
- hidden_states=contextualized_embeddings,
484
- attentions=attention_probs
485
  )
486
 
487
 
488
- class LTGBertForTokenClassification(LTGBertModel):
489
  _keys_to_ignore_on_load_unexpected = ["classifier"]
490
  _keys_to_ignore_on_load_missing = ["head"]
491
 
@@ -495,6 +620,7 @@ class LTGBertForTokenClassification(LTGBertModel):
495
  self.num_labels = config.num_labels
496
  self.head = Classifier(config, self.num_labels)
497
 
 
498
  def forward(
499
  self,
500
  input_ids: Optional[torch.Tensor] = None,
@@ -517,18 +643,29 @@ class LTGBertForTokenClassification(LTGBertModel):
517
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
518
 
519
  if not return_dict:
520
- output = (logits, contextualized_embeddings, attention_probs)
521
  return ((loss,) + output) if loss is not None else output
522
 
523
  return TokenClassifierOutput(
524
  loss=loss,
525
  logits=logits,
526
- hidden_states=contextualized_embeddings,
527
- attentions=attention_probs
528
  )
529
 
530
 
531
- class LTGBertForQuestionAnswering(LTGBertModel):
532
  _keys_to_ignore_on_load_unexpected = ["classifier"]
533
  _keys_to_ignore_on_load_missing = ["head"]
534
 
@@ -538,6 +675,7 @@ class LTGBertForQuestionAnswering(LTGBertModel):
538
  self.num_labels = config.num_labels
539
  self.head = Classifier(config, self.num_labels)
540
 
 
541
  def forward(
542
  self,
543
  input_ids: Optional[torch.Tensor] = None,
@@ -578,19 +716,31 @@ class LTGBertForQuestionAnswering(LTGBertModel):
578
  total_loss = (start_loss + end_loss) / 2
579
 
580
  if not return_dict:
581
- output = start_logits, end_logits, contextualized_embeddings, attention_probs
582
  return ((total_loss,) + output) if total_loss is not None else output
583
 
584
  return QuestionAnsweringModelOutput(
585
  loss=total_loss,
586
  start_logits=start_logits,
587
  end_logits=end_logits,
588
- hidden_states=contextualized_embeddings,
589
- attentions=attention_probs,
590
  )
591
 
592
 
593
- class LTGBertForMultipleChoice(LTGBertModel):
594
  _keys_to_ignore_on_load_unexpected = ["classifier"]
595
  _keys_to_ignore_on_load_missing = ["head"]
596
 
@@ -600,6 +750,7 @@ class LTGBertForMultipleChoice(LTGBertModel):
600
  self.num_labels = getattr(config, "num_labels", 2)
601
  self.head = Classifier(config, self.num_labels)
602
 
 
603
  def forward(
604
  self,
605
  input_ids: Optional[torch.Tensor] = None,
@@ -607,9 +758,9 @@ class LTGBertForMultipleChoice(LTGBertModel):
607
  token_type_ids: Optional[torch.Tensor] = None,
608
  position_ids: Optional[torch.Tensor] = None,
609
  labels: Optional[torch.Tensor] = None,
610
- return_dict: Optional[bool] = None,
611
- start_positions: Optional[torch.Tensor] = None,
612
- end_positions: Optional[torch.Tensor] = None
613
  ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
614
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
615
  num_choices = input_ids.shape[1]
@@ -627,13 +778,16 @@ class LTGBertForMultipleChoice(LTGBertModel):
627
  loss = loss_fct(reshaped_logits, labels)
628
 
629
  if not return_dict:
630
- output = (reshaped_logits, contextualized_embeddings, attention_probs)
631
  return ((loss,) + output) if loss is not None else output
632
 
633
  return MultipleChoiceModelOutput(
634
  loss=loss,
635
  logits=reshaped_logits,
636
- hidden_states=contextualized_embeddings,
637
- attentions=attention_probs,
638
  )
639
-
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Language Technology Group from University of Oslo and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch LTG-BERT model."""
17
+
18
 
19
  import math
20
  from typing import List, Optional, Tuple, Union
 
22
  import torch
23
  import torch.nn as nn
24
  import torch.nn.functional as F
 
25
  from torch.utils import checkpoint
26
 
27
+ from configuration_ltgbert import LtgBertConfig
28
  from transformers.modeling_utils import PreTrainedModel
29
  from transformers.activations import gelu_new
30
  from transformers.modeling_outputs import (
 
35
  TokenClassifierOutput,
36
  BaseModelOutput
37
  )
38
+ from transformers.pytorch_utils import softmax_backward_data
39
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
40
+
41
+
42
+ _CHECKPOINT_FOR_DOC = "ltg/bnc-bert-span"
43
+ _CONFIG_FOR_DOC = "LtgBertConfig"
44
+
45
+
46
+ LTG_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
47
+ "bnc-bert-span",
48
+ "bnc-bert-span-2x",
49
+ "bnc-bert-span-0.5x",
50
+ "bnc-bert-span-0.25x",
51
+ "bnc-bert-span-order",
52
+ "bnc-bert-span-document",
53
+ "bnc-bert-span-word",
54
+ "bnc-bert-span-subword",
55
+
56
+ "norbert3-xs",
57
+ "norbert3-small",
58
+ "norbert3-base",
59
+ "norbert3-large",
60
+
61
+ "norbert3-oversampled-base",
62
+ "norbert3-ncc-base",
63
+ "norbert3-nak-base",
64
+ "norbert3-nb-base",
65
+ "norbert3-wiki-base",
66
+ "norbert3-c4-base"
67
+ ]
68
 
69
 
70
  class Encoder(nn.Module):
 
175
  @staticmethod
176
  def backward(self, grad_output):
177
  output, = self.saved_tensors
178
+ input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
179
+ return input_grad, None, None
180
 
181
 
182
  class Attention(nn.Module):
 
240
  hidden_states = self.pre_layer_norm(hidden_states)
241
 
242
  query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2) # shape: [T, B, D]
 
243
  value = self.in_proj_v(hidden_states) # shape: [T, B, D]
244
 
245
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
246
  key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
247
  value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
248
 
249
+ attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
250
 
251
  pos = self.in_proj_qk(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
252
+ query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)
253
  query = query.view(batch_size, self.num_heads, query_len, self.head_size)
254
+ key = key.view(batch_size, self.num_heads, query_len, self.head_size)
255
 
256
+ attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale)
257
+ attention_p_c = torch.einsum("bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1))
258
 
259
  position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
260
  attention_c_p = attention_c_p.gather(3, position_indices)
 
264
  attention_scores.add_(attention_c_p)
265
  attention_scores.add_(attention_p_c)
266
 
267
+ return attention_scores, value
268
 
269
  def compute_output(self, attention_probs, value):
270
  attention_probs = self.dropout(attention_probs)
 
310
  # HuggingFace wrappers
311
  #
312
 
313
+ class LtgBertPreTrainedModel(PreTrainedModel):
314
+ """
315
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
316
+ models.
317
+ """
318
+
319
+ config_class = LtgBertConfig
320
+ base_model_prefix = "bnc-bert"
321
  supports_gradient_checkpointing = True
322
 
323
  def _set_gradient_checkpointing(self, module, value=False):
324
  if isinstance(module, Encoder):
325
  module.activation_checkpointing = value
326
 
327
+ def _init_weights(self, _):
328
  pass # everything is already initialized
329
 
330
 
331
+ LTG_BERT_START_DOCSTRING = r"""
332
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
333
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
334
+ etc.)
335
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
336
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
337
+ and behavior.
338
+ Parameters:
339
+ config ([`LtgBertConfig`]): Model configuration class with all the parameters of the model.
340
+ Initializing with a config file does not load the weights associated with the model, only the
341
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
342
+ """
343
+
344
+ LTG_BERT_INPUTS_DOCSTRING = r"""
345
+ Args:
346
+ input_ids (`torch.LongTensor` of shape `({0})`):
347
+ Indices of input sequence tokens in the vocabulary.
348
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
349
+ [`PreTrainedTokenizer.__call__`] for details.
350
+ [What are input IDs?](../glossary#input-ids)
351
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
352
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
353
+ - 1 for tokens that are **not masked**,
354
+ - 0 for tokens that are **masked**.
355
+ [What are attention masks?](../glossary#attention-mask)
356
+ output_hidden_states (`bool`, *optional*):
357
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
358
+ more detail.
359
+ output_attentions (`bool`, *optional*):
360
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
361
+ tensors for more detail.
362
+ return_dict (`bool`, *optional*):
363
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
364
+ """
365
+
366
+
367
+ @add_start_docstrings(
368
+ "The bare LTG-BERT transformer outputting raw hidden-states without any specific head on top.",
369
+ LTG_BERT_START_DOCSTRING,
370
+ )
371
+ class LtgBertModel(LtgBertPreTrainedModel):
372
  def __init__(self, config, add_mlm_layer=False):
373
  super().__init__(config)
374
  self.config = config
 
412
  ]
413
  return last_layer, contextualized_embeddings, attention_probs
414
 
415
+ @add_start_docstrings_to_model_forward(LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
416
  def forward(
417
  self,
418
  input_ids: Optional[torch.Tensor] = None,
419
  attention_mask: Optional[torch.Tensor] = None,
420
  output_hidden_states: Optional[bool] = None,
421
  output_attentions: Optional[bool] = None,
422
  return_dict: Optional[bool] = None,
423
  ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
424
+
425
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
426
+ output_hidden_states = (
427
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
428
+ )
429
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
430
 
431
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
432
 
433
  if not return_dict:
434
+ return (
435
+ sequence_output,
436
+ *([contextualized_embeddings] if output_hidden_states else []),
437
+ *([attention_probs] if output_attentions else [])
438
+ )
439
 
440
  return BaseModelOutput(
441
  last_hidden_state=sequence_output,
442
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
443
+ attentions=attention_probs if output_attentions else None
444
  )
445
 
446
 
447
+ @add_start_docstrings("""LTG-BERT model with a `language modeling` head on top.""", LTG_BERT_START_DOCSTRING)
448
+ class LtgBertForMaskedLM(LtgBertModel):
449
  _keys_to_ignore_on_load_unexpected = ["head"]
450
 
451
  def __init__(self, config):
 
457
  def set_output_embeddings(self, new_embeddings):
458
  self.classifier.nonlinearity[-1].weight = new_embeddings
459
 
460
+ @add_start_docstrings_to_model_forward(LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
461
  def forward(
462
  self,
463
  input_ids: Optional[torch.Tensor] = None,
464
  attention_mask: Optional[torch.Tensor] = None,
465
  output_hidden_states: Optional[bool] = None,
466
  output_attentions: Optional[bool] = None,
467
  return_dict: Optional[bool] = None,
468
  labels: Optional[torch.LongTensor] = None,
469
  ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
470
+ r"""
471
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
472
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
473
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
474
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
475
+ """
476
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
477
 
478
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
479
  subword_prediction = self.classifier(sequence_output)
 
480
 
481
  masked_lm_loss = None
482
  if labels is not None:
483
  masked_lm_loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten())
484
 
485
  if not return_dict:
486
+ output = (
487
+ subword_prediction,
488
+ *([contextualized_embeddings] if output_hidden_states else []),
489
+ *([attention_probs] if output_attentions else [])
490
+ )
491
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
492
 
493
  return MaskedLMOutput(
494
  loss=masked_lm_loss,
495
  logits=subword_prediction,
496
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
497
+ attentions=attention_probs if output_attentions else None
498
  )
499
 
500
 
 
502
  def __init__(self, config, num_labels: int):
503
  super().__init__()
504
 
505
+ drop_out = getattr(config, "classifier_dropout", config.hidden_dropout_prob)
 
506
 
507
  self.nonlinearity = nn.Sequential(
508
  nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
 
526
  return x
527
 
528
 
529
+ @add_start_docstrings(
530
+ """
531
+ LTG-BERT model with a sequence classification/regression head on top (a linear layer on top of the pooled
532
+ output) e.g. for GLUE tasks.
533
+ """,
534
+ LTG_BERT_START_DOCSTRING,
535
+ )
536
+ class LtgBertForSequenceClassification(LtgBertModel):
537
  _keys_to_ignore_on_load_unexpected = ["classifier"]
538
  _keys_to_ignore_on_load_missing = ["head"]
539
 
 
543
  self.num_labels = config.num_labels
544
  self.head = Classifier(config, self.num_labels)
545
 
546
+ @add_start_docstrings_to_model_forward(LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
547
  def forward(
548
  self,
549
  input_ids: Optional[torch.Tensor] = None,
550
  attention_mask: Optional[torch.Tensor] = None,
551
  output_attentions: Optional[bool] = None,
552
  output_hidden_states: Optional[bool] = None,
553
  return_dict: Optional[bool] = None,
554
  labels: Optional[torch.LongTensor] = None,
555
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
556
+ r"""
557
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
558
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
559
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
560
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
561
+ """
562
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
563
 
564
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
 
588
  loss = loss_fct(logits, labels)
589
 
590
  if not return_dict:
591
+ output = (
592
+ logits,
593
+ *([contextualized_embeddings] if output_hidden_states else []),
594
+ *([attention_probs] if output_attentions else [])
595
+ )
596
  return ((loss,) + output) if loss is not None else output
597
 
598
  return SequenceClassifierOutput(
599
  loss=loss,
600
  logits=logits,
601
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
602
+ attentions=attention_probs if output_attentions else None
603
  )
604
 
605
 
606
+ @add_start_docstrings(
607
+ """
608
+ LTG-BERT model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
609
+ Named-Entity-Recognition (NER) tasks.
610
+ """,
611
+ LTG_BERT_START_DOCSTRING,
612
+ )
613
+ class LtgBertForTokenClassification(LtgBertModel):
614
  _keys_to_ignore_on_load_unexpected = ["classifier"]
615
  _keys_to_ignore_on_load_missing = ["head"]
616
 
 
620
  self.num_labels = config.num_labels
621
  self.head = Classifier(config, self.num_labels)
622
 
623
+ @add_start_docstrings_to_model_forward(LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
624
  def forward(
625
  self,
626
  input_ids: Optional[torch.Tensor] = None,
 
643
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
644
 
645
  if not return_dict:
646
+ output = (
647
+ logits,
648
+ *([contextualized_embeddings] if output_hidden_states else []),
649
+ *([attention_probs] if output_attentions else [])
650
+ )
651
  return ((loss,) + output) if loss is not None else output
652
 
653
  return TokenClassifierOutput(
654
  loss=loss,
655
  logits=logits,
656
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
657
+ attentions=attention_probs if output_attentions else None
658
  )
659
 
660
 
661
+ @add_start_docstrings(
662
+ """
663
+ LTG-BERT model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
664
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
665
+ """,
666
+ LTG_BERT_START_DOCSTRING,
667
+ )
668
+ class LtgBertForQuestionAnswering(LtgBertModel):
669
  _keys_to_ignore_on_load_unexpected = ["classifier"]
670
  _keys_to_ignore_on_load_missing = ["head"]
671
 
 
675
  self.num_labels = config.num_labels
676
  self.head = Classifier(config, self.num_labels)
677
 
678
+ @add_start_docstrings_to_model_forward(LTG_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
679
  def forward(
680
  self,
681
  input_ids: Optional[torch.Tensor] = None,
 
716
  total_loss = (start_loss + end_loss) / 2
717
 
718
  if not return_dict:
719
+ output = (
720
+ start_logits,
721
+ end_logits,
722
+ *([contextualized_embeddings] if output_hidden_states else []),
723
+ *([attention_probs] if output_attentions else [])
724
+ )
725
  return ((total_loss,) + output) if total_loss is not None else output
726
 
727
  return QuestionAnsweringModelOutput(
728
  loss=total_loss,
729
  start_logits=start_logits,
730
  end_logits=end_logits,
731
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
732
+ attentions=attention_probs if output_attentions else None
733
  )
734
 
735
 
736
+ @add_start_docstrings(
737
+ """
738
+ LTG-BERT model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
739
+ softmax) e.g. for RocStories/SWAG tasks.
740
+ """,
741
+ LTG_BERT_START_DOCSTRING,
742
+ )
743
+ class LtgBertForMultipleChoice(LtgBertModel):
744
  _keys_to_ignore_on_load_unexpected = ["classifier"]
745
  _keys_to_ignore_on_load_missing = ["head"]
746
 
 
750
  self.num_labels = getattr(config, "num_labels", 2)
751
  self.head = Classifier(config, self.num_labels)
752
 
753
+ @add_start_docstrings_to_model_forward(LTG_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
754
  def forward(
755
  self,
756
  input_ids: Optional[torch.Tensor] = None,
 
758
  token_type_ids: Optional[torch.Tensor] = None,
759
  position_ids: Optional[torch.Tensor] = None,
760
  labels: Optional[torch.Tensor] = None,
761
+ output_attentions: Optional[bool] = None,
762
+ output_hidden_states: Optional[bool] = None,
763
+ return_dict: Optional[bool] = None
764
  ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
765
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
766
  num_choices = input_ids.shape[1]
 
778
  loss = loss_fct(reshaped_logits, labels)
779
 
780
  if not return_dict:
781
+ output = (
782
+ reshaped_logits,
783
+ *([contextualized_embeddings] if output_hidden_states else []),
784
+ *([attention_probs] if output_attentions else [])
785
+ )
786
  return ((loss,) + output) if loss is not None else output
787
 
788
  return MultipleChoiceModelOutput(
789
  loss=loss,
790
  logits=reshaped_logits,
791
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
792
+ attentions=attention_probs if output_attentions else None
793
  )
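
Usage note: the updated `forward()` signatures drop `token_type_ids`/`position_ids` and only populate `hidden_states`/`attentions` when the corresponding `output_*` flags are set. Below is a minimal loading sketch, assuming the `ltg/bnc-bert-span` checkpoint referenced by `_CHECKPOINT_FOR_DOC` exposes this modeling file as remote code; the example sentence and the printed shapes are illustrative, not taken from this commit.

```python
# Minimal usage sketch; assumes ltg/bnc-bert-span ships this file via its auto_map.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("ltg/bnc-bert-span")
model = AutoModelForMaskedLM.from_pretrained("ltg/bnc-bert-span", trust_remote_code=True)
model.eval()

enc = tokenizer("This modeling file was just updated.", return_tensors="pt")

with torch.no_grad():
    # The new forward() takes only input_ids/attention_mask plus the output_* flags;
    # token_type_ids and position_ids were removed in this commit.
    outputs = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        output_hidden_states=True,
        output_attentions=True,
    )

print(outputs.logits.shape)          # [batch, seq_len, vocab_size]
print(len(outputs.hidden_states))    # per-layer hidden states (only set when requested)
print(len(outputs.attentions))       # per-layer attention probabilities
```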