Fix error loading model with AutoModel
Browse filesFixes the following error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "snip/.venv/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 558, in from_pretrained
cls.register(config.__class__, model_class, exist_ok=True)
File "snip/.venv/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 584, in register
raise ValueError(
ValueError: The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has <class 'transformers.models.bert.configuration_bert.BertConfig'> and you passed <class 'transformers_modules.zhihan1996.DNABERT-S.1cdf84d992ace6f3e75c7356774b4da088c8dc7c.configuration_bert.BertConfig'>. Fix one of those so they match!
- bert_layers.py +2 -0
|
@@ -23,6 +23,7 @@ from transformers.models.bert.modeling_bert import BertPreTrainedModel
|
|
| 23 |
from .bert_padding import (index_first_axis,
|
| 24 |
index_put_first_axis, pad_input,
|
| 25 |
unpad_input, unpad_input_only)
|
|
|
|
| 26 |
|
| 27 |
try:
|
| 28 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
|
@@ -563,6 +564,7 @@ class BertModel(BertPreTrainedModel):
|
|
| 563 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
| 564 |
```
|
| 565 |
"""
|
|
|
|
| 566 |
|
| 567 |
def __init__(self, config, add_pooling_layer=True):
|
| 568 |
super(BertModel, self).__init__(config)
|
|
|
|
| 23 |
from .bert_padding import (index_first_axis,
|
| 24 |
index_put_first_axis, pad_input,
|
| 25 |
unpad_input, unpad_input_only)
|
| 26 |
+
from .configuration_bert import BertConfig
|
| 27 |
|
| 28 |
try:
|
| 29 |
from .flash_attn_triton import flash_attn_qkvpacked_func
|
|
|
|
| 564 |
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
| 565 |
```
|
| 566 |
"""
|
| 567 |
+
config_class = BertConfig
|
| 568 |
|
| 569 |
def __init__(self, config, add_pooling_layer=True):
|
| 570 |
super(BertModel, self).__init__(config)
|