# Copyright 2026 Amazon.com Inc and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
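"""CPU-only tests for ``verl.utils.fsdp_utils.normalize_peft_param_name``.

``normalize_peft_param_name`` is expected to map the state-dict keys of a
PEFT (LoRA) wrapped model back to the base model's naming scheme and to drop
adapter-only parameters, so the result can be compared key-for-key against a
plain ``AutoModelForCausalLM`` state dict.
"""
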
import pytest
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Qwen3Config

from verl.utils.fsdp_utils import normalize_peft_param_name


def create_base_model():
"""Create a simple base model for testing."""
config = Qwen3Config(
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
hidden_size=128,
intermediate_size=256,
)
model = AutoModelForCausalLM.from_config(config)
    return model


def create_peft_model():
    """Wrap the base model with LoRA adapters on all linear layers."""
lora_config = LoraConfig(
r=8, lora_alpha=16, target_modules="all-linear", lora_dropout=0.0, bias="none", task_type="CAUSAL_LM"
)
model = create_base_model()
model = get_peft_model(model, lora_config)
    return model


@pytest.fixture
def base_model():
"""Create a simple base model for testing."""
    return create_base_model()


@pytest.fixture
def peft_model():
"""Create a PEFT model with LoRA adapters."""
return create_peft_model()
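

# For reference, a LoRA-wrapped causal LM state dict typically contains three
# kinds of keys (illustrative; the exact prefixes depend on the peft version):
#   base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight     (frozen base weight)
#   base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight (adapter weight)
#   base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight (adapter weight)
# normalize_peft_param_name should strip the "base_model" prefix and the
# ".base_layer" suffix and drop the adapter-only entries, yielding keys such as
#   model.layers.0.self_attn.q_proj.weight
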
def test_normalize_peft_param_name_keys_match_base_model(base_model, peft_model):
    """Test that normalized PEFT model keys exactly match the base model keys."""
    # Get state dicts from the fixtures
    base_state_dict = base_model.state_dict()
    peft_state_dict = peft_model.state_dict()
# Normalize PEFT model keys
normalized_peft_state_dict = normalize_peft_param_name(peft_state_dict)
# Get key sets
base_keys = set(base_state_dict.keys())
normalized_peft_keys = set(normalized_peft_state_dict.keys())
print(f"{base_keys=}")
print(f"{normalized_peft_keys=}")
# Verify that all base model keys are in the normalized PEFT keys
missing_keys = base_keys - normalized_peft_keys
assert len(missing_keys) == 0, f"Missing keys from base model: {missing_keys}"
# Verify that all normalized PEFT keys are in the base model
extra_keys = normalized_peft_keys - base_keys
assert len(extra_keys) == 0, f"Extra keys not in base model: {extra_keys}"
# Verify exact match
    assert base_keys == normalized_peft_keys, "Normalized PEFT keys should exactly match base model keys"


def test_normalize_peft_param_name_removes_lora_keys(peft_model):
"""Test that LoRA-specific parameters are removed after normalization."""
peft_state_dict = peft_model.state_dict()
# Before normalization, should have lora_A and lora_B keys
lora_keys_before = [k for k in peft_state_dict.keys() if "lora_" in k]
assert len(lora_keys_before) > 0, "PEFT model should have LoRA parameters"
# After normalization, should not have any lora keys
normalized_state_dict = normalize_peft_param_name(peft_state_dict)
lora_keys_after = [k for k in normalized_state_dict.keys() if "lora_" in k]
assert len(lora_keys_after) == 0, (
f"Normalized state dict should not contain LoRA keys, but found: {lora_keys_after}"
    )


def test_normalize_peft_param_name_removes_base_model_prefix(peft_model):
"""Test that base_model prefix is removed from parameter names."""
peft_state_dict = peft_model.state_dict()
# Before normalization, should have base_model prefix
base_model_keys = [k for k in peft_state_dict.keys() if "base_model" in k]
assert len(base_model_keys) > 0, "PEFT model should have base_model prefix"
# After normalization, should not have base_model prefix
normalized_state_dict = normalize_peft_param_name(peft_state_dict)
base_model_keys_after = [k for k in normalized_state_dict.keys() if "base_model" in k]
assert len(base_model_keys_after) == 0, (
f"Normalized keys should not contain base_model prefix, but found: {base_model_keys_after}"
    )


def test_normalize_peft_param_name_removes_base_layer_suffix(peft_model):
"""Test that .base_layer suffix is removed from parameter names."""
peft_state_dict = peft_model.state_dict()
# Before normalization, should have .base_layer suffix
base_layer_keys = [k for k in peft_state_dict.keys() if ".base_layer" in k]
assert len(base_layer_keys) > 0, "PEFT model should have .base_layer suffix"
# After normalization, should not have .base_layer suffix
normalized_state_dict = normalize_peft_param_name(peft_state_dict)
base_layer_keys_after = [k for k in normalized_state_dict.keys() if ".base_layer" in k]
assert len(base_layer_keys_after) == 0, (
f"Normalized keys should not contain .base_layer suffix, but found: {base_layer_keys_after}"
    )


def test_normalize_peft_param_name_tensor_shapes_match(base_model, peft_model):
"""Test that tensor shapes match between base model and normalized PEFT model."""
base_state_dict = base_model.state_dict()
peft_state_dict = peft_model.state_dict()
# Normalize PEFT model keys
normalized_peft_state_dict = normalize_peft_param_name(peft_state_dict)
# Check that shapes match for all common keys
for key in base_state_dict.keys():
assert key in normalized_peft_state_dict, f"Key {key} not found in normalized PEFT state dict"
base_shape = base_state_dict[key].shape
peft_shape = normalized_peft_state_dict[key].shape
        assert base_shape == peft_shape, f"Shape mismatch for {key}: base={base_shape}, peft={peft_shape}"


def test_normalize_peft_param_name_empty_dict():
"""Test that normalize_peft_param_name handles empty dict."""
result = normalize_peft_param_name({})
    assert result == {}, "Empty dict should return empty dict"


@pytest.mark.parametrize(
"lora_key_pattern",
[
"model.layers.0.self_attn.q_proj.lora_A.default.weight",
"model.layers.0.self_attn.q_proj.lora_B.default.weight",
"model.layers.0.adapter_layer.weight",
"base_model.model.layers.0.lora_embedding_A",
],
)
def test_normalize_peft_param_name_filters_lora_patterns(lora_key_pattern):
"""Test that various LoRA key patterns are filtered out."""
test_dict = {
lora_key_pattern: torch.randn(10, 10),
"model.layers.0.weight": torch.randn(10, 10),
}
normalized = normalize_peft_param_name(test_dict)
# LoRA key should be filtered out
assert lora_key_pattern not in normalized, f"LoRA key {lora_key_pattern} should be filtered out"
# Regular key should remain
assert len(normalized) == 1, "Should have exactly one key remaining"
assert "model.layers.0.weight" in normalized, "Regular weight should remain"


if __name__ == "__main__":
pytest.main([__file__, "-v"])