Allow loading via AutoModelForSequenceClassification
#1
by
tomaarsen
HF Staff
- opened
- bert_layers.py +9 -0
bert_layers.py
CHANGED
|
@@ -29,6 +29,7 @@ from transformers.modeling_outputs import (MaskedLMOutput,
|
|
| 29 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 30 |
|
| 31 |
from .blockdiag_linear import BlockdiagLinear
|
|
|
|
| 32 |
from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
|
| 33 |
|
| 34 |
logger = logging.getLogger(__name__)
|
|
@@ -475,6 +476,8 @@ class BertModel(BertPreTrainedModel):
|
|
| 475 |
```
|
| 476 |
"""
|
| 477 |
|
|
|
|
|
|
|
| 478 |
def __init__(self, config, add_pooling_layer=True):
|
| 479 |
super(BertModel, self).__init__(config)
|
| 480 |
self.embeddings = BertEmbeddings(config)
|
|
@@ -602,6 +605,8 @@ class BertOnlyNSPHead(nn.Module):
|
|
| 602 |
#######################
|
| 603 |
class BertForMaskedLM(BertPreTrainedModel):
|
| 604 |
|
|
|
|
|
|
|
| 605 |
def __init__(self, config):
|
| 606 |
super().__init__(config)
|
| 607 |
|
|
@@ -748,6 +753,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 748 |
e.g., GLUE tasks.
|
| 749 |
"""
|
| 750 |
|
|
|
|
|
|
|
| 751 |
def __init__(self, config):
|
| 752 |
super().__init__(config)
|
| 753 |
self.num_labels = config.num_labels
|
|
@@ -873,6 +880,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 873 |
|
| 874 |
class BertForTextEncoding(BertPreTrainedModel):
|
| 875 |
|
|
|
|
|
|
|
| 876 |
def __init__(self, config):
|
| 877 |
super().__init__(config)
|
| 878 |
|
|
|
|
| 29 |
from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
| 30 |
|
| 31 |
from .blockdiag_linear import BlockdiagLinear
|
| 32 |
+
from .configuration_bert import BertConfig
|
| 33 |
from .monarch_mixer_sequence_mixer import MonarchMixerSequenceMixing
|
| 34 |
|
| 35 |
logger = logging.getLogger(__name__)
|
|
|
|
| 476 |
```
|
| 477 |
"""
|
| 478 |
|
| 479 |
+
config_class = BertConfig
|
| 480 |
+
|
| 481 |
def __init__(self, config, add_pooling_layer=True):
|
| 482 |
super(BertModel, self).__init__(config)
|
| 483 |
self.embeddings = BertEmbeddings(config)
|
|
|
|
| 605 |
#######################
|
| 606 |
class BertForMaskedLM(BertPreTrainedModel):
|
| 607 |
|
| 608 |
+
config_class = BertConfig
|
| 609 |
+
|
| 610 |
def __init__(self, config):
|
| 611 |
super().__init__(config)
|
| 612 |
|
|
|
|
| 753 |
e.g., GLUE tasks.
|
| 754 |
"""
|
| 755 |
|
| 756 |
+
config_class = BertConfig
|
| 757 |
+
|
| 758 |
def __init__(self, config):
|
| 759 |
super().__init__(config)
|
| 760 |
self.num_labels = config.num_labels
|
|
|
|
| 880 |
|
| 881 |
class BertForTextEncoding(BertPreTrainedModel):
|
| 882 |
|
| 883 |
+
config_class = BertConfig
|
| 884 |
+
|
| 885 |
def __init__(self, config):
|
| 886 |
super().__init__(config)
|
| 887 |
|