| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import os |
| import tempfile |
| import unittest |
|
|
| from peft import LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig |
|
|
|
|
| class PeftConfigTestMixin: |
| all_config_classes = ( |
| LoraConfig, |
| PromptEncoderConfig, |
| PrefixTuningConfig, |
| PromptTuningConfig, |
| ) |
|
|
|
|
| class PeftConfigTester(unittest.TestCase, PeftConfigTestMixin): |
| def test_methods(self): |
| r""" |
| Test if all configs have the expected methods. Here we test |
| - to_dict |
| - save_pretrained |
| - from_pretrained |
| - from_json_file |
| """ |
| |
| for config_class in self.all_config_classes: |
| config = config_class() |
| self.assertTrue(hasattr(config, "to_dict")) |
| self.assertTrue(hasattr(config, "save_pretrained")) |
| self.assertTrue(hasattr(config, "from_pretrained")) |
| self.assertTrue(hasattr(config, "from_json_file")) |
|
|
| def test_task_type(self): |
| for config_class in self.all_config_classes: |
| |
| _ = config_class(task_type="test") |
|
|
| def test_save_pretrained(self): |
| r""" |
| Test if the config is correctly saved and loaded using |
| - save_pretrained |
| """ |
| for config_class in self.all_config_classes: |
| config = config_class() |
| with tempfile.TemporaryDirectory() as tmp_dirname: |
| config.save_pretrained(tmp_dirname) |
|
|
| config_from_pretrained = config_class.from_pretrained(tmp_dirname) |
| self.assertEqual(config.to_dict(), config_from_pretrained.to_dict()) |
|
|
| def test_from_json_file(self): |
| for config_class in self.all_config_classes: |
| config = config_class() |
| with tempfile.TemporaryDirectory() as tmp_dirname: |
| config.save_pretrained(tmp_dirname) |
|
|
| config_from_json = config_class.from_json_file(os.path.join(tmp_dirname, "adapter_config.json")) |
| self.assertEqual(config.to_dict(), config_from_json) |
|
|
| def test_to_dict(self): |
| r""" |
| Test if the config can be correctly converted to a dict using: |
| - to_dict |
| - __dict__ |
| """ |
| for config_class in self.all_config_classes: |
| config = config_class() |
| self.assertEqual(config.to_dict(), config.__dict__) |
| self.assertTrue(isinstance(config.to_dict(), dict)) |
|
|
| def test_set_attributes(self): |
| |
| for config_class in self.all_config_classes: |
| config = config_class(peft_type="test") |
|
|
| |
| with tempfile.TemporaryDirectory() as tmp_dirname: |
| config.save_pretrained(tmp_dirname) |
|
|
| config_from_pretrained = config_class.from_pretrained(tmp_dirname) |
| self.assertEqual(config.to_dict(), config_from_pretrained.to_dict()) |
|
|