| |
| |
| |
| |
| @@ -214,5 +214,10 @@ def __init__( |
| f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' |
| ) |
| |
| + def to_dict(self): |
| + output = super().to_dict() |
| + output.pop("reference_compile", None) |
| + return output |
| + |
| |
| __all__ = ["ModernBertConfig"] |
| |
| |
| |
| |
| @@ -248,6 +248,11 @@ def __init__( |
| f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' |
| ) |
| |
| + def to_dict(self): |
| + output = super().to_dict() |
| + output.pop("reference_compile", None) |
| + return output |
| + |
| |
| def _unpad_modernbert_input( |
| inputs: torch.Tensor, |
| |
| |
| |
| |
| @@ -12,7 +12,9 @@ |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| +import json |
| import os |
| +import tempfile |
| import unittest |
| |
| import pytest |
| @@ -366,6 +368,14 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): |
| def test_flash_attn_2_conversion(self): |
| self.skipTest(reason="ModernBert doesn't use the ModernBertFlashAttention2 class method.") |
| |
| + def test_saved_config_excludes_reference_compile(self): |
| + config = ModernBertConfig(reference_compile=True) |
| + with tempfile.TemporaryDirectory() as tmpdirname: |
| + config.save_pretrained(tmpdirname) |
| + with open(os.path.join(tmpdirname, "config.json"), "r") as f: |
| + config_dict = json.load(f) |
| + self.assertNotIn("reference_compile", config_dict) |
| + |
| |
| @require_torch |
| class ModernBertModelIntegrationTest(unittest.TestCase): |
|
|