import os
import tempfile

import pytest
import torch
from torch.testing import assert_close
from transformers import AutoModelForCausalLM

from peft import PeftModel, get_peft_model
from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from peft.utils import infer_device
from peft.utils.other import WEIGHTS_NAME, prepare_model_for_kbit_training
from peft.utils.save_and_load import get_peft_model_state_dict


MODELS_TO_TEST = [
    "trl-internal-testing/tiny-random-LlamaForCausalLM",
]
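
# Note for the save/load assertions below: multitask prompt tuning (MPT) stores the
# adapter as three tensors (a shared prompt embedding plus per-task column and row
# factors that modulate it), which is why the tests expect exactly three entries in
# the adapter state dict. A rough sketch of the composition; tensor names follow
# PEFT's implementation, shapes are for illustration:
#
#     # prompt_embeddings: (num_virtual_tokens, token_dim)
#     # prefix_task_cols:  (num_tasks, num_virtual_tokens, 1)
#     # prefix_task_rows:  (num_tasks, 1, token_dim)
#     task_prompt = prompt_embeddings * (prefix_task_cols[task_id] @ prefix_task_rows[task_id])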


class TestMultiTaskPromptTuning:
    """
    Tests for the MultiTaskPromptTuning model.
    """

    transformers_class = AutoModelForCausalLM
    torch_device = infer_device()

    @pytest.fixture
    def config(self) -> MultitaskPromptTuningConfig:
        return MultitaskPromptTuningConfig(
            task_type="CAUSAL_LM",
            num_virtual_tokens=50,
            num_tasks=3,
            prompt_tuning_init_text=(
                "classify the following into either positive or negative, or entailment, neutral or contradiction:"
            ),
        )

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_prepare_for_training(self, model_id, config):
        model = self.transformers_class.from_pretrained(model_id)
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        # the base model stays frozen, so its embedding output should not require grad
        assert not dummy_output.requires_grad

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_prepare_for_int8_training(self, model_id, config):
        model = self.transformers_class.from_pretrained(model_id)
        model = prepare_model_for_kbit_training(model)
        model = model.to(self.torch_device)

        # kbit preparation freezes every base parameter
        for param in model.parameters():
            assert not param.requires_grad

        model = get_peft_model(model, config)

        # For backward compatibility: use `enable_input_require_grads` if the
        # transformers model exposes it, otherwise fall back to a forward hook
        # that marks the embedding output as requiring grad.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        # gradients must be able to flow through the (frozen) input embeddings,
        # e.g. for gradient checkpointing to work
        assert dummy_output.requires_grad

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_save_pretrained(self, model_id, config):
        seed = 420
        torch.manual_seed(seed)
        model = self.transformers_class.from_pretrained(model_id)
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            torch.manual_seed(seed)
            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            assert state_dict.keys() == state_dict_from_pretrained.keys()

            # MPT stores the shared prompt plus the per-task column and row factors
            assert len(state_dict) == 3

            # check if tensors are equal
            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # check if `adapter_model.safetensors` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))

            # check if `adapter_config.json` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))

            # check if no weights of the base model are saved
            assert not os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin"))

            # check if no config of the base model is saved
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
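
    # For reference, the checkpoint layout the test above expects with the
    # default safetensors serialization; only adapter files are written:
    #
    #     tmp_dirname/
    #     ├── adapter_config.json
    #     └── adapter_model.safetensors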

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_save_pretrained_regression(self, model_id, config):
        seed = 420
        torch.manual_seed(seed)
        model = self.transformers_class.from_pretrained(model_id)
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            # save in the legacy torch (pickle) format instead of safetensors
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            torch.manual_seed(seed)
            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            assert state_dict.keys() == state_dict_from_pretrained.keys()

            # MPT stores the shared prompt plus the per-task column and row factors
            assert len(state_dict) == 3

            # check if tensors are equal
            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # check if `adapter_model.bin` is present for regression
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))

            # check if `adapter_config.json` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))

            # check if no weights of the base model are saved
            assert not os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin"))

            # check if no config of the base model is saved
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_generate(self, model_id, config):
        model = self.transformers_class.from_pretrained(model_id)
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
        task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

        # check if `generate` works
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

        # check if `generate` works if positional arguments are passed
        _ = model.generate(input_ids, attention_mask=attention_mask, task_ids=task_ids)
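
    # Sketch of what `task_ids` selects, mirroring PEFT's MPT prompt encoder
    # (a simplified illustration, not executed here):
    #
    #     cols = torch.index_select(prefix_task_cols, 0, task_ids)
    #     rows = torch.index_select(prefix_task_rows, 0, task_ids)
    #     prompts = prompt_embeddings * torch.matmul(cols, rows)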

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_use_cache(self, model_id, config):
        """Test that MultiTaskPromptTuning works when the base Llama config sets use_cache=True."""
        torch.manual_seed(0)
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

        original = self.transformers_class.from_pretrained(model_id)
        mpt = get_peft_model(original, config)
        mpt = mpt.to(self.torch_device)

        expected = mpt.generate(input_ids=input_ids, max_length=8, task_ids=task_ids)

        # set use_cache = True and check that generation is unchanged
        mpt.base_model.config.use_cache = True
        actual = mpt.generate(input_ids=input_ids, max_length=8, task_ids=task_ids)
        assert_close(expected, actual, rtol=0, atol=0)

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_bf16_inference(self, model_id, config):
        """Test that MultiTaskPromptTuning works with a bfloat16 (half-precision) base model."""
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

        original = self.transformers_class.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        mpt = get_peft_model(original, config)
        mpt = mpt.to(self.torch_device)
        _ = mpt.generate(input_ids=input_ids, task_ids=task_ids)

    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_generate_text_with_random_init(self, model_id, config) -> None:
        torch.manual_seed(0)
        model = self.transformers_class.from_pretrained(model_id)
        config.prompt_tuning_init = MultitaskPromptTuningInit.RANDOM
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
        task_ids = torch.LongTensor([0]).to(self.torch_device)

        # check if `generate` works
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

        with pytest.raises(ValueError):
            # check if `generate` raises an error if task_ids are not passed
            _ = model.generate(input_ids, attention_mask=attention_mask)

    @pytest.mark.parametrize(
        "prompt_tuning_init",
        [
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
        ],
    )
    @pytest.mark.parametrize("model_id", MODELS_TO_TEST)
    def test_generate_text_with_other_init(self, prompt_tuning_init, model_id, config) -> None:
        # First save a source adapter, then initialize a new target adapter from
        # its weights using the given init strategy.
        torch.manual_seed(42)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model = self.transformers_class.from_pretrained(model_id)
            model = get_peft_model(model, config)
            # save in torch format so that WEIGHTS_NAME ("adapter_model.bin") exists
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            config = MultitaskPromptTuningConfig(
                task_type="CAUSAL_LM",
                num_virtual_tokens=50,
                num_tasks=1,
                prompt_tuning_init_text=(
                    "classify the following into either positive or negative, or entailment, neutral or contradiction:"
                ),
                prompt_tuning_init=prompt_tuning_init,
                prompt_tuning_init_state_dict_path=os.path.join(tmp_dirname, WEIGHTS_NAME),
            )
            model = self.transformers_class.from_pretrained(model_id)
            model = get_peft_model(model, config)
            model = model.to(self.torch_device)

            input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
            task_ids = torch.LongTensor([0]).to(self.torch_device)

            # check if `generate` works
            _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

            with pytest.raises(ValueError, match="task_ids cannot be None"):
                # check if `generate` raises an error if task_ids are not passed
                _ = model.generate(input_ids, attention_mask=attention_mask)
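

# For reference, a minimal end-to-end usage sketch of multitask prompt tuning,
# kept as a comment so it is not executed with the test suite; the training
# inputs and hyperparameters are illustrative assumptions, not part of the tests:
#
#     model = AutoModelForCausalLM.from_pretrained(MODELS_TO_TEST[0])
#     peft_config = MultitaskPromptTuningConfig(
#         task_type="CAUSAL_LM", num_virtual_tokens=50, num_tasks=3
#     )
#     model = get_peft_model(model, peft_config)
#     outputs = model(input_ids=input_ids, task_ids=task_ids, labels=input_ids)
#     outputs.loss.backward()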