harness / diffs /36305.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py
index cc0295c25b55..1835f55aaec4 100644
--- a/src/transformers/models/modernbert/configuration_modernbert.py
+++ b/src/transformers/models/modernbert/configuration_modernbert.py
@@ -214,5 +214,10 @@ def __init__(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
)
+ def to_dict(self):
+ output = super().to_dict()
+ output.pop("reference_compile", None)
+ return output
+
__all__ = ["ModernBertConfig"]
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index 0901662f66a7..934931da3bfa 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -248,6 +248,11 @@ def __init__(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
)
+ def to_dict(self):
+ output = super().to_dict()
+ output.pop("reference_compile", None)
+ return output
+
def _unpad_modernbert_input(
inputs: torch.Tensor,
diff --git a/tests/models/modernbert/test_modeling_modernbert.py b/tests/models/modernbert/test_modeling_modernbert.py
index 14882b0879c2..82a0f8505273 100644
--- a/tests/models/modernbert/test_modeling_modernbert.py
+++ b/tests/models/modernbert/test_modeling_modernbert.py
@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import os
+import tempfile
import unittest
import pytest
@@ -366,6 +368,14 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
def test_flash_attn_2_conversion(self):
self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.")
+ def test_saved_config_excludes_reference_compile(self):
+ config = ModernBertConfig(reference_compile=True)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ config.save_pretrained(tmpdirname)
+ with open(os.path.join(tmpdirname, "config.json"), "r") as f:
+ config_dict = json.load(f)
+ self.assertNotIn("reference_compile", config_dict)
+
@require_torch
class ModernBertModelIntegrationTest(unittest.TestCase):