| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import pytest |
| import torch |
| from peft import LoraConfig, get_peft_model |
| from transformers import AutoModelForCausalLM, Qwen3Config |
|
|
| from verl.utils.fsdp_utils import normalize_peft_param_name |
|
|
|
|
def create_base_model():
    """Build a deliberately tiny Qwen3 causal LM for use in the tests.

    The model is constructed from a config only (no checkpoint is loaded),
    with two layers and small hidden sizes so that it is cheap to create.

    Returns:
        The model produced by ``AutoModelForCausalLM.from_config``.
    """
    tiny_config_kwargs = {
        "num_hidden_layers": 2,
        "num_attention_heads": 2,
        "num_key_value_heads": 2,
        "hidden_size": 128,
        "intermediate_size": 256,
    }
    return AutoModelForCausalLM.from_config(Qwen3Config(**tiny_config_kwargs))
|
|
|
|
def create_peft_model():
    """Wrap a fresh base model with LoRA adapters.

    The LoRA config targets ``"all-linear"`` modules with rank 8, alpha 16,
    no dropout and no bias adaptation.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    adapter_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules="all-linear",
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    backbone = create_base_model()
    return get_peft_model(backbone, adapter_config)
|
|
|
|
@pytest.fixture
def base_model():
    """Provide a newly built base model to each test that requests it."""
    model = create_base_model()
    return model
|
|
|
|
@pytest.fixture
def peft_model():
    """Provide a newly built LoRA-wrapped model to each test that requests it."""
    model = create_peft_model()
    return model
|
|
|
|
def test_normalize_peft_param_name_keys_match_base_model():
    """Check that normalizing a PEFT state dict yields exactly the base model's keys."""
    # Build both models directly (base first, then the PEFT-wrapped one).
    reference_state = create_base_model().state_dict()
    wrapped_state = create_peft_model().state_dict()

    normalized_state = normalize_peft_param_name(wrapped_state)

    # The names below appear verbatim in the debug output via the ``=`` specifier.
    base_keys = set(reference_state.keys())
    normalized_peft_keys = set(normalized_state.keys())
    print(f"{base_keys=}")
    print(f"{normalized_peft_keys=}")

    # Nothing from the base model may be lost ...
    missing_keys = base_keys.difference(normalized_peft_keys)
    assert not missing_keys, f"Missing keys from base model: {missing_keys}"

    # ... and nothing beyond the base model's keys may remain.
    extra_keys = normalized_peft_keys.difference(base_keys)
    assert not extra_keys, f"Extra keys not in base model: {extra_keys}"

    # Final equality check on the two key sets as a whole.
    assert base_keys == normalized_peft_keys, "Normalized PEFT keys should exactly match base model keys"
|
|
|
|
def test_normalize_peft_param_name_removes_lora_keys(peft_model):
    """Check that no key containing ``lora_`` survives normalization."""
    raw_state = peft_model.state_dict()

    # Precondition: the un-normalized state dict really contains LoRA entries.
    has_lora_before = any("lora_" in name for name in raw_state)
    assert has_lora_before, "PEFT model should have LoRA parameters"

    cleaned_state = normalize_peft_param_name(raw_state)
    lora_keys_after = [name for name in cleaned_state if "lora_" in name]
    assert not lora_keys_after, (
        f"Normalized state dict should not contain LoRA keys, but found: {lora_keys_after}"
    )
|
|
|
|
def test_normalize_peft_param_name_removes_base_model_prefix(peft_model):
    """Check that no key containing ``base_model`` survives normalization."""
    marker = "base_model"
    raw_state = peft_model.state_dict()

    # Precondition: the PEFT wrapper introduces the marker into key names.
    assert any(marker in name for name in raw_state), "PEFT model should have base_model prefix"

    cleaned_state = normalize_peft_param_name(raw_state)
    base_model_keys_after = [name for name in cleaned_state if marker in name]
    assert not base_model_keys_after, (
        f"Normalized keys should not contain base_model prefix, but found: {base_model_keys_after}"
    )
|
|
|
|
def test_normalize_peft_param_name_removes_base_layer_suffix(peft_model):
    """Check that no key containing ``.base_layer`` survives normalization."""
    marker = ".base_layer"
    raw_state = peft_model.state_dict()

    # Precondition: the PEFT wrapper introduces the marker into key names.
    assert any(marker in name for name in raw_state), "PEFT model should have .base_layer suffix"

    cleaned_state = normalize_peft_param_name(raw_state)
    base_layer_keys_after = [name for name in cleaned_state if marker in name]
    assert not base_layer_keys_after, (
        f"Normalized keys should not contain .base_layer suffix, but found: {base_layer_keys_after}"
    )
|
|
|
|
def test_normalize_peft_param_name_tensor_shapes_match(base_model, peft_model):
    """Check that every base-model tensor has a same-shaped counterpart after normalization."""
    reference_state = base_model.state_dict()
    normalized_state = normalize_peft_param_name(peft_model.state_dict())

    for name, reference_tensor in reference_state.items():
        # Presence is asserted first so a missing key fails with a clear message.
        assert name in normalized_state, f"Key {name} not found in normalized PEFT state dict"
        base_shape, peft_shape = reference_tensor.shape, normalized_state[name].shape
        assert base_shape == peft_shape, f"Shape mismatch for {name}: base={base_shape}, peft={peft_shape}"
|
|
|
|
def test_normalize_peft_param_name_empty_dict():
    """Check that an empty mapping normalizes to an empty mapping."""
    empty_input = {}
    assert normalize_peft_param_name(empty_input) == {}, "Empty dict should return empty dict"
|
|
|
|
@pytest.mark.parametrize(
    "lora_key_pattern",
    [
        "model.layers.0.self_attn.q_proj.lora_A.default.weight",
        "model.layers.0.self_attn.q_proj.lora_B.default.weight",
        "model.layers.0.adapter_layer.weight",
        "base_model.model.layers.0.lora_embedding_A",
    ],
)
def test_normalize_peft_param_name_filters_lora_patterns(lora_key_pattern):
    """Check that each adapter-style key is dropped while a regular weight is kept."""
    regular_key = "model.layers.0.weight"
    # One adapter entry (parametrized) next to one ordinary weight entry.
    state = {}
    state[lora_key_pattern] = torch.randn(10, 10)
    state[regular_key] = torch.randn(10, 10)

    normalized = normalize_peft_param_name(state)

    # The adapter key must be gone ...
    assert lora_key_pattern not in normalized, f"LoRA key {lora_key_pattern} should be filtered out"

    # ... leaving only the ordinary weight.
    assert len(normalized) == 1, "Should have exactly one key remaining"
    assert regular_key in normalized, "Regular weight should remain"
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly: invoke pytest on this file in verbose mode.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)
|
|