diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py
index cc0295c25b55..1835f55aaec4 100644
--- a/src/transformers/models/modernbert/configuration_modernbert.py
+++ b/src/transformers/models/modernbert/configuration_modernbert.py
@@ -214,5 +214,10 @@ def __init__(
                 f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
             )
 
+    def to_dict(self):
+        output = super().to_dict()
+        output.pop("reference_compile", None)
+        return output
+
 
 __all__ = ["ModernBertConfig"]
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 0901662f66a7..934931da3bfa 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -248,6 +248,11 @@ def __init__(
                 f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
             )
 
+    def to_dict(self):
+        output = super().to_dict()
+        output.pop("reference_compile", None)
+        return output
+
 
 def _unpad_modernbert_input(
     inputs: torch.Tensor,
diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py
index 14882b0879c2..82a0f8505273 100644
--- a/tests/models/modernbert/test_modeling_modernbert.py
+++ b/tests/models/modernbert/test_modeling_modernbert.py
@@ -12,7 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
+import tempfile
 import unittest
 
 import pytest
@@ -366,6 +368,14 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
     def test_flash_attn_2_conversion(self):
         self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
 
+    def test_saved_config_excludes_reference_compile(self):
+        config = ModernBertConfig(reference_compile=True)
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            config.save_pretrained(tmpdirname)
+            with open(os.path.join(tmpdirname, "config.json"), "r") as f:
+                config_dict = json.load(f)
+            self.assertNotIn("reference_compile", config_dict)
+
 
 @require_torch
 class ModernBertModelIntegrationTest(unittest.TestCase):