| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from transformers import AutoModelForCausalLM |
|
|
| from peft import LoraConfig, get_peft_model |
| from peft.helpers import check_if_peft_model |
|
|
|
|
class TestCheckIsPeftModel:
    """Tests for ``peft.helpers.check_if_peft_model``.

    Covers three Hub repo ids: a LoRA adapter repo, a plain base-model repo,
    and a repo id that does not exist. Also covers local directories that are
    either produced by ``save_pretrained`` or written by hand.

    NOTE(review): the Hub cases and every ``from_pretrained("gpt2")`` call
    presumably need network access (or a warm HF cache) — TODO confirm and
    mark them accordingly.
    """

    def test_valid_hub_model(self):
        # A LoRA adapter repo on the Hub must be recognized as a PEFT model.
        result = check_if_peft_model("peft-internal-testing/gpt2-lora-random")
        assert result is True

    def test_invalid_hub_model(self):
        # A plain transformers model repo is not a PEFT model.
        result = check_if_peft_model("gpt2")
        assert result is False

    def test_nonexisting_hub_model(self):
        # An unknown repo id must yield False rather than raise.
        result = check_if_peft_model("peft-internal-testing/non-existing-model")
        assert result is False

    def test_local_model_valid(self, tmp_path):
        # A directory written by a PEFT-wrapped model's save_pretrained is
        # detected as a PEFT model.
        save_dir = tmp_path / "peft-gpt2-valid"
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        config = LoraConfig()
        model = get_peft_model(model, config)
        model.save_pretrained(save_dir)
        result = check_if_peft_model(save_dir)
        assert result is True

    def test_local_model_invalid(self, tmp_path):
        # A directory written by a bare transformers model (no adapter) is
        # not a PEFT model.
        save_dir = tmp_path / "peft-gpt2-invalid"
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        model.save_pretrained(save_dir)
        result = check_if_peft_model(save_dir)
        assert result is False

    def test_local_model_broken_config(self, tmp_path):
        # An adapter_config.json that exists but does not hold a valid PEFT
        # config must make the helper return False instead of raising.
        config_path = tmp_path / "adapter_config.json"
        # Explicit encoding so the written bytes do not depend on the locale.
        config_path.write_text('{"foo": "bar"}', encoding="utf-8")

        result = check_if_peft_model(tmp_path)
        assert result is False

    def test_local_model_non_default_name(self, tmp_path):
        # An adapter registered under a non-default name ("other") is
        # expected by the assertions below to land in a subdirectory named
        # after the adapter, not in the save directory root.
        save_dir = tmp_path / "peft-gpt2-other"
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        config = LoraConfig()
        model = get_peft_model(model, config, adapter_name="other")
        model.save_pretrained(save_dir)

        # The save directory root holds no adapter config of its own ...
        result = check_if_peft_model(save_dir)
        assert result is False

        # ... while the adapter-named subdirectory is a valid PEFT model.
        result = check_if_peft_model(save_dir / "other")
        assert result is True
|
|