| """Module for testing models utils file.""" |
|
|
| from unittest.mock import MagicMock, patch |
|
|
| import pytest |
| from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase |
| from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
| from transformers.utils.import_utils import is_torch_mps_available |
|
|
| from axolotl.utils.dict import DictDefault |
| from axolotl.utils.models import ModelLoader, load_model |
|
|
|
|
class TestModelsUtils:
    """Testing module for models utils."""

    def setup_method(self) -> None:
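        # load config: a tiny Llama checkpoint with a LoRA adapter and
        # 8-bit loading keeps the test lightweight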
        self.cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "model_type": "LlamaForCausalLM",
                "tokenizer_type": "LlamaTokenizer",
                "load_in_8bit": True,
                "load_in_4bit": False,
                "adapter": "lora",
                "flash_attention": False,
                "sample_packing": True,
                "device_map": "auto",
            }
        )
        self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
        self.inference = False
        self.reference_model = True

        # init the ModelLoader under test
        self.model_loader = ModelLoader(
            cfg=self.cfg,
            tokenizer=self.tokenizer,
            inference=self.inference,
            reference_model=self.reference_model,
        )

    def test_set_device_map_config(self):
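        # check device_map: MPS hosts should resolve to "mps", and DeepSpeed
        # ZeRO-3 must not receive a device_map at all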
        device_map = self.cfg.device_map
        if is_torch_mps_available():
            device_map = "mps"
        self.model_loader.set_device_map_config()
        if is_deepspeed_zero3_enabled():
            assert "device_map" not in self.model_loader.model_kwargs
        else:
            assert device_map in self.model_loader.model_kwargs["device_map"]

        # check torch_dtype
        assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]

    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
        cfg = DictDefault(
            {
                "s2_attention": True,
                "sample_packing": True,
                "base_model": "",
                "model_type": "LlamaForCausalLM",
            }
        )

        # mock out the call to HF Hub so the test needs no network access
        with patch(
            "axolotl.utils.models.load_model_config"
        ) as mocked_load_model_config:
            mocked_load_model_config.return_value = {}
            with pytest.raises(ValueError) as exc:
                # should raise before the tokenizer is ever touched, so an
                # empty string stands in for it
                load_model(cfg, tokenizer="")
            assert (
                "shifted-sparse attention does not currently support sample packing"
                in str(exc.value)
            )

    @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
    @pytest.mark.parametrize("load_in_8bit", [True, False])
    @pytest.mark.parametrize("load_in_4bit", [True, False])
    @pytest.mark.parametrize("gptq", [True, False])
    def test_set_quantization_config(
        self,
        adapter,
        load_in_8bit,
        load_in_4bit,
        gptq,
    ):
        # init cfg with the parametrized quantization settings
        self.cfg.load_in_8bit = load_in_8bit
        self.cfg.load_in_4bit = load_in_4bit
        self.cfg.gptq = gptq
        self.cfg.adapter = adapter

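        # apply quantization settings and inspect the resulting model_kwargs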
        self.model_loader.set_quantization_config()
        if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
            # once a quantization_config (or GPTQ) is in play, the raw
            # load_in_*bit flags must not both remain in model_kwargs;
            # model_kwargs is a dict, so check membership with `in`
            assert not (
                "load_in_8bit" in self.model_loader.model_kwargs
                and "load_in_4bit" in self.model_loader.model_kwargs
            )
        elif load_in_8bit and self.cfg.adapter is not None:
            assert self.model_loader.model_kwargs["load_in_8bit"]
        elif load_in_4bit and self.cfg.adapter is not None:
            assert self.model_loader.model_kwargs["load_in_4bit"]

        if (self.cfg.adapter == "qlora" and load_in_4bit) or (
            self.cfg.adapter == "lora" and load_in_8bit
        ):
            # these adapter/bit combinations should produce a real
            # BitsAndBytesConfig rather than falling back to a default
            assert isinstance(
                self.model_loader.model_kwargs.get("quantization_config"),
                BitsAndBytesConfig,
            )

    def test_message_property_mapping(self):
        """Test message property mapping configuration validation."""
        from axolotl.utils.schemas.datasets import SFTDataset

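        # legacy message_field_role/message_field_content are translated
        # into message_property_mappings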
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_field_content="content_field",
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

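        # explicit message_property_mappings pass through unchanged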
        dataset = SFTDataset(
            path="test_path",
            message_property_mappings={
                "role": "custom_role",
                "content": "custom_content",
            },
        )
        assert dataset.message_property_mappings == {
            "role": "custom_role",
            "content": "custom_content",
        }

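        # a legacy field that agrees with the explicit mapping is accepted;
        # "content" keeps its default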
        dataset = SFTDataset(
            path="test_path",
            message_field_role="same_role",
            message_property_mappings={"role": "same_role"},
        )
        assert dataset.message_property_mappings == {
            "role": "same_role",
            "content": "content",
        }

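        # a legacy role field combines with an explicit content mapping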
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_property_mappings={"content": "content_field"},
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

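        # with nothing specified, both properties default to identity mappings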
        dataset = SFTDataset(
            path="test_path",
        )
        assert dataset.message_property_mappings == {
            "role": "role",
            "content": "content",
        }

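        # a legacy role field that disagrees with the explicit mapping raises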
        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_role="legacy_role",
                message_property_mappings={"role": "different_role"},
            )
        assert "Conflicting message role fields" in str(exc_info.value)

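        # likewise for a conflicting content field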
        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_content="legacy_content",
                message_property_mappings={"content": "different_content"},
            )
        assert "Conflicting message content fields" in str(exc_info.value)