update for transformers >= 4.29.1
modeling_lsg_bart.py  CHANGED  (+23 -3)
@@ -643,6 +643,11 @@ class LSGBartEncoderLayer(BartEncoderLayer):
 class LSGBartPretrainedModel(BartPretrainedModel):
 
     config_class = LSGBartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
+    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
 
     def _set_gradient_checkpointing(self, module, value=False):
 
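Background on what these new attributes do, for readers following along (not part of the commit itself): `supports_gradient_checkpointing` lets `gradient_checkpointing_enable()` work, `_no_split_modules` tells accelerate's `device_map="auto"` which blocks must stay whole on one device, and `_skip_keys_device_placement` keeps `past_key_values` out of device-placement logic. A minimal loading sketch, assuming some repository ships this modeling file; the checkpoint name below is a placeholder, not one named in this commit:

import torch
from transformers import AutoModelForSeq2SeqLM

# "some-user/lsg-bart-checkpoint" is hypothetical; any checkpoint that
# bundles modeling_lsg_bart.py behaves the same way.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "some-user/lsg-bart-checkpoint",
    trust_remote_code=True,  # the LSGBart* classes live in the repo itself
    device_map="auto",       # sharding honors _no_split_modules
)
model.gradient_checkpointing_enable()  # works because supports_gradient_checkpointing = True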
@@ -836,8 +841,13 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
-            dropout_probability = random.uniform(0, 1)
-            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
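The rewritten LayerDrop block follows the pattern transformers adopted in 4.29: the Bernoulli draw comes from torch's RNG (`torch.rand([])`) rather than Python's `random.uniform(0, 1)`, so among other things it is reproducible under `torch.manual_seed`. A self-contained sketch of the same logic (illustrative only, not the model's actual forward):

import torch

def maybe_skip_layer(layer, hidden_states, layerdrop=0.5, training=True):
    # Same structure as the diff above: decide first, then branch.
    to_drop = False
    if training:
        dropout_probability = torch.rand([])  # 0-dim tensor in [0, 1)
        if dropout_probability < layerdrop:   # skip the layer
            to_drop = True
    if to_drop:
        return hidden_states                  # identity: layer dropped
    return layer(hidden_states)

torch.manual_seed(0)                          # the draw is now seed-reproducible
layer = torch.nn.Linear(8, 8)
out = maybe_skip_layer(layer, torch.randn(2, 8))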
@@ -879,6 +889,8 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
 
 class LSGBartModel(LSGBartPretrainedModel, BartModel):
 
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config):
 
         LSGBartPretrainedModel.__init__(self, config)
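`_tied_weights_keys`, introduced in transformers 4.29, is the new way to declare shared parameters (the classes below repeat the same declaration): the listed keys are treated as aliases of one underlying tensor, so saving stores it once (notably with safetensors) and loading does not warn about the aliases. The tying itself is ordinary parameter sharing, roughly as in this sketch (placeholder shapes, not the LSG code):

import torch.nn as nn

# BART-style models tie both embedding tables (and the LM head, in the
# generation model) to one shared matrix.
vocab_size, d_model = 50265, 768
shared = nn.Embedding(vocab_size, d_model)
encoder_embed_tokens = nn.Embedding(vocab_size, d_model)
decoder_embed_tokens = nn.Embedding(vocab_size, d_model)

encoder_embed_tokens.weight = shared.weight  # "encoder.embed_tokens.weight"
decoder_embed_tokens.weight = shared.weight  # "decoder.embed_tokens.weight"
assert encoder_embed_tokens.weight.data_ptr() == decoder_embed_tokens.weight.data_ptr()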
@@ -984,7 +996,8 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
 class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
 
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
     def __init__(self, config):
 
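`final_logits_bias` stays ignorable-on-load because BART keeps it as a non-learned buffer of zeros, added to the LM logits and resized along with the vocabulary, so checkpoints may legitimately omit it. A sketch of that arrangement (simplified shapes, not the LSG code):

import torch
import torch.nn as nn

class TinyGenerationHead(nn.Module):
    def __init__(self, d_model=16, vocab_size=32):
        super().__init__()
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # A buffer, not a Parameter: all zeros unless the vocab is resized,
        # hence safe to ignore when missing from a checkpoint.
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))

    def forward(self, hidden_states):
        return self.lm_head(hidden_states) + self.final_logits_bias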
@@ -999,6 +1012,8 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditional
 
 class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
 
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: LSGBartConfig, **kwargs):
 
         LSGBartPretrainedModel.__init__(self, config, **kwargs)
@@ -1015,6 +1030,8 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceCl
 
 class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
 
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: LSGBartConfig):
 
         LSGBartPretrainedModel.__init__(self, config)
@@ -1030,6 +1047,9 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnsweri
 
 class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
 
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.weight"]
+
     def __init__(self, config: LSGBartConfig):
 
         LSGBartPretrainedModel.__init__(self, config)