File size: 2,601 Bytes
dfefe0b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 | diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py
index cc0295c25b55..1835f55aaec4 100644
--- a/src/transformers/models/modernbert/configuration_modernbert.py
+++ b/src/transformers/models/modernbert/configuration_modernbert.py
@@ -214,5 +214,10 @@ def __init__(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
)
+ def to_dict(self):
+ output = super().to_dict()
+ output.pop("reference_compile", None)
+ return output
+
__all__ = ["ModernBertConfig"]
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 0901662f66a7..934931da3bfa 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -248,6 +248,11 @@ def __init__(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
)
+ def to_dict(self):
+ output = super().to_dict()
+ output.pop("reference_compile", None)
+ return output
+
def _unpad_modernbert_input(
inputs: torch.Tensor,
diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py
index 14882b0879c2..82a0f8505273 100644
--- a/tests/models/modernbert/test_modeling_modernbert.py
+++ b/tests/models/modernbert/test_modeling_modernbert.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import os
+import tempfile
import unittest
import pytest
@@ -366,6 +368,14 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
def test_flash_attn_2_conversion(self):
self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
+ def test_saved_config_excludes_reference_compile(self):
+ config = ModernBertConfig(reference_compile=True)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ config.save_pretrained(tmpdirname)
+ with open(os.path.join(tmpdirname, "config.json"), "r") as f:
+ config_dict = json.load(f)
+ self.assertNotIn("reference_compile", config_dict)
+
@require_torch
class ModernBertModelIntegrationTest(unittest.TestCase):
|