# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from types import SimpleNamespace

import torch
import torch.nn as nn

from transformers import PretrainedConfig
from transformers.conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
from transformers.core_model_loading import (
Chunk,
Concatenate,
ErnieFuseAndSplitTextVisionExperts,
MergeModulelist,
PermuteForRope,
WeightConverter,
WeightRenaming,
build_glob_alternation,
convert_and_load_state_dict_in_model,
rename_source_key,
revert_weight_conversion,
)
from transformers.utils.import_utils import is_triton_available
from ..test_modeling_common import compare_state_dicts
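

# These tests exercise the glob-matching and weight-conversion machinery used when
# loading checkpoints. `build_glob_alternation` compiles a list of weight globs into
# a single alternation regex with one named group per glob; matching a checkpoint key
# and reading `match.lastgroup` recovers which glob won. A minimal sketch of that
# flow (assuming the current `(regex, group-name -> glob mapping, ...)` return
# signature, as used by `_match_glob` below):
#
#     alt, mapping, _ = build_glob_alternation(["model.layers.*.mlp.weight"])
#     match = alt.search("model.layers.7.mlp.weight")
#     assert mapping[match.lastgroup] == "model.layers.*.mlp.weight"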
class TestWeightGlobMatching(unittest.TestCase):
def setUp(self):
self.weight_globs_digits = [
"model.layers.*.mlp.gate_up_proj.weight",
"model.layers.*.self_attn.q_proj.weight",
"embed_tokens.weight",
]
self.alt_digits, self.map_digits, _ = build_glob_alternation(self.weight_globs_digits)
        # The same globs again; the tests below check that `*` accepts digit and
        # non-digit segments alike.
        self.weight_globs_any = [
"model.layers.*.mlp.gate_up_proj.weight",
"model.layers.*.self_attn.q_proj.weight",
"embed_tokens.weight",
]
self.alt_any, self.map_any, _ = build_glob_alternation(self.weight_globs_any)
@staticmethod
def _match_glob(key, alt, mapping):
matched = alt.search(key)
return mapping.get(matched.lastgroup) if matched else None
def test_exact_match(self):
self.assertEqual(
self._match_glob("embed_tokens.weight", self.alt_digits, self.map_digits), "embed_tokens.weight"
)
def test_digits_only_star_accepts_digits(self):
self.assertEqual(
self._match_glob("model.layers.0.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits),
"model.layers.*.mlp.gate_up_proj.weight",
)
self.assertEqual(
self._match_glob("model.layers.12.self_attn.q_proj.weight", self.alt_digits, self.map_digits),
"model.layers.*.self_attn.q_proj.weight",
)
def test_anychar_star_accepts_nondigits(self):
self.assertEqual(
self._match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_any, self.map_any),
"model.layers.*.mlp.gate_up_proj.weight",
)
self.assertEqual(
self._match_glob("model.layers.00x.mlp.gate_up_proj.weight", self.alt_any, self.map_any),
"model.layers.*.mlp.gate_up_proj.weight",
)
def test_no_match(self):
self.assertIsNone(self._match_glob("model.layers.0.mlp.up_proj.weight", self.alt_digits, self.map_digits))
def test_leftmost_alternative_wins_for_overlapping_patterns(self):
# Overlapping patterns: both could match; ensure leftmost wins
globs = [
"model.layers.*.mlp.*.weight", # broader (first)
"model.layers.0.mlp.gate_up_proj.weight", # more specific (second)
]
alt, mapping, _ = build_glob_alternation(globs)
# Both branches match; Python's regex picks the leftmost alternative → index 0
self.assertEqual(
self._match_glob("model.layers.0.mlp.gate_up_proj.weight", alt, mapping), "model.layers.*.mlp.*.weight"
)
def test_multiple_patterns_same_prefix(self):
globs = [
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
]
        alt, mapping, _ = build_glob_alternation(globs)
self.assertEqual(
self._match_glob("model.layers.3.self_attn.q_proj.weight", alt, mapping),
"model.layers.*.self_attn.q_proj.weight",
)
self.assertEqual(
self._match_glob("model.layers.3.self_attn.k_proj.weight", alt, mapping),
"model.layers.*.self_attn.k_proj.weight",
)
self.assertEqual(
self._match_glob("model.layers.3.self_attn.v_proj.weight", alt, mapping),
"model.layers.*.self_attn.v_proj.weight",
)
    def test_anchor_full_match_only(self):
        # `_match_glob` uses `search`, so the alternation is not anchored at the end:
        # a key with a trailing suffix past the glob still matches.
        self.assertIsNotNone(
            self._match_glob("model.layers.0.mlp.gate_up_proj.weight.bar", self.alt_any, self.map_any)
        )
def test_large_batch_performance_smoke(self):
# Not a perf benchmark, but ensures building and matching a larger alternation is OK
globs = [f"model.layers.*.mlp.block{i}.weight" for i in range(200)]
alt, mapping, _ = build_glob_alternation(globs)
key = "model.layers.123.mlp.block57.weight"
self.assertEqual(self._match_glob(key, alt, mapping), "model.layers.*.mlp.block57.weight")
def test_sub_key_rewrites_targets(self):
renamings = [
WeightRenaming("block_sparse_moe.experts.*.w1.weight", "mlp.experts.gate_up_proj"),
WeightRenaming("block_sparse_moe.experts.*.w2.weight", "mlp.experts.down_proj"),
WeightRenaming("model.language_model.*", "language_model"),
]
self.assertEqual(
rename_source_key("foo.block_sparse_moe.experts.3.w1.weight", renamings, [])[0],
"foo.mlp.experts.gate_up_proj",
)
self.assertEqual(
rename_source_key("foo.block_sparse_moe.experts.3.w2.weight", renamings, [])[0],
"foo.mlp.experts.down_proj",
)
self.assertEqual(rename_source_key("model.language_model.lm_head.weight", renamings, [])[0], "language_model")
def test_sub_key_no_match_returns_original(self):
renamings = [
WeightRenaming("block_sparse_moe.experts.*.w1.weight", "*.mlp.experts.gate_up_proj"),
]
key = "unrelated.key"
renamed_key, _ = rename_source_key(key, renamings, [])
self.assertEqual(renamed_key, key)
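

# Minimal dummy modules mirroring the parameter layout the converters target:
# per-layer q/k/v attention projections, stacked MoE expert weights (with an
# optional second expert bank for the Ernie fuse-and-split tests), and a
# top-level MLP, so convert_and_load_state_dict_in_model has real nn.Parameter
# destinations to fill.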
class DummyParamModule(nn.Module):
def __init__(self, shape):
super().__init__()
self.weight = nn.Parameter(torch.zeros(shape))
class DummySelfAttn(nn.Module):
def __init__(self):
super().__init__()
self.q_proj = DummyParamModule((1, 2))
self.k_proj = DummyParamModule((1, 2))
self.v_proj = DummyParamModule((1, 2))
class DummyExperts(nn.Module):
def __init__(self):
super().__init__()
self.gate_up_proj = DummyParamModule((2, 4, 2))
self.down_proj = DummyParamModule((2, 2, 2))
class DummyLayer(nn.Module):
def __init__(self, add_extra_moe=False):
super().__init__()
self.self_attn = DummySelfAttn()
self.experts = DummyExperts()
if add_extra_moe:
self.extra_experts = DummyExperts()
class DummyTopModel(nn.Module):
def __init__(self, add_extra_moe=False):
super().__init__()
self.layers = nn.ModuleList([DummyLayer(add_extra_moe), DummyLayer(add_extra_moe)])
class DummyMLP(nn.Module):
def __init__(self):
super().__init__()
self.down_proj = DummyParamModule((2, 2))
class DummyRoot(nn.Module):
base_model_prefix = "model"
def __init__(self, add_extra_moe=False):
super().__init__()
self.model = DummyTopModel(add_extra_moe)
self.mlp = DummyMLP()
class TestConvertAndLoadStateDict(unittest.TestCase):
def test_moe_and_qkv_conversion(self):
model = DummyRoot()
model.config = PretrainedConfig()
raw_tensors = {
"model.layers.0.experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"model.layers.0.experts.1.w1.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]),
"model.layers.0.experts.0.w3.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
"model.layers.0.experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]),
"model.layers.0.experts.0.w2.weight": torch.tensor([[20.0, 21.0], [22.0, 23.0]]),
"model.layers.0.experts.1.w2.weight": torch.tensor([[24.0, 25.0], [26.0, 27.0]]),
"model.layers.1.experts.0.w1.weight": torch.tensor([[30.0, 31.0], [32.0, 33.0]]),
"model.layers.1.experts.1.w1.weight": torch.tensor([[34.0, 35.0], [36.0, 37.0]]),
"model.layers.1.experts.0.w3.weight": torch.tensor([[38.0, 39.0], [40.0, 41.0]]),
"model.layers.1.experts.1.w3.weight": torch.tensor([[42.0, 43.0], [44.0, 45.0]]),
"model.layers.1.experts.0.w2.weight": torch.tensor([[46.0, 47.0], [48.0, 49.0]]),
"model.layers.1.experts.1.w2.weight": torch.tensor([[50.0, 51.0], [52.0, 53.0]]),
"model.layers.0.self_attn.qkv_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
"model.layers.1.self_attn.qkv_proj.weight": torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]),
"mlp.w2.weight": torch.tensor([[60.0, 61.0], [62.0, 63.0]]),
}
state_dict = {k: v.clone() for k, v in raw_tensors.items()}
weight_mapping = [
WeightConverter(
["experts.*.w1.weight", "experts.*.w3.weight"],
"experts.gate_up_proj.weight",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
"experts.*.w2.weight",
"experts.down_proj.weight",
operations=[MergeModulelist(dim=0)],
),
WeightConverter(
"model.layers.0.self_attn.qkv_proj.weight",
[
"model.layers.0.self_attn.q_proj.weight",
"model.layers.0.self_attn.k_proj.weight",
"model.layers.0.self_attn.v_proj.weight",
],
operations=[Chunk(dim=0)],
),
WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight"),
]
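        # The qkv converter above is pinned to layer 0, so layer 1's fused qkv key has
        # no destination (reported as unexpected) and layer 1's q/k/v parameters stay
        # unloaded (reported as missing).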
missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
)
self.assertEqual(
missing,
{
"model.layers.1.self_attn.k_proj.weight",
"model.layers.1.self_attn.v_proj.weight",
"model.layers.1.self_attn.q_proj.weight",
},
)
self.assertEqual(unexpected, {"model.layers.1.self_attn.qkv_proj.weight"})
self.assertEqual(mismatch, set())
self.assertEqual(misc, {})
model_state = model.state_dict()
def cat_gate(layer_prefix: str) -> torch.Tensor:
w1 = [
raw_tensors[f"{layer_prefix}.experts.0.w1.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w1.weight"],
]
w3 = [
raw_tensors[f"{layer_prefix}.experts.0.w3.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w3.weight"],
]
return torch.cat([torch.stack(w1, dim=0), torch.stack(w3, dim=0)], dim=1)
torch.testing.assert_close(
model_state["model.layers.0.experts.gate_up_proj.weight"], cat_gate("model.layers.0")
)
torch.testing.assert_close(
model_state["model.layers.1.experts.gate_up_proj.weight"], cat_gate("model.layers.1")
)
def stack_down(layer_prefix: str) -> torch.Tensor:
return torch.stack(
[
raw_tensors[f"{layer_prefix}.experts.0.w2.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w2.weight"],
],
dim=0,
)
torch.testing.assert_close(
model_state["model.layers.0.experts.down_proj.weight"], stack_down("model.layers.0")
)
torch.testing.assert_close(
model_state["model.layers.1.experts.down_proj.weight"], stack_down("model.layers.1")
)
        for layer_idx in range(2):
            if layer_idx == 1:
                # Layer 1's q/k/v were reported missing and were never loaded
                continue
            key = f"model.layers.{layer_idx}.self_attn.qkv_proj.weight"
            expected_q, expected_k, expected_v = torch.chunk(raw_tensors[key], chunks=3, dim=0)
            prefix = f"model.layers.{layer_idx}.self_attn"
torch.testing.assert_close(model_state[f"{prefix}.q_proj.weight"], expected_q)
torch.testing.assert_close(model_state[f"{prefix}.k_proj.weight"], expected_k)
torch.testing.assert_close(model_state[f"{prefix}.v_proj.weight"], expected_v)
torch.testing.assert_close(model_state["mlp.down_proj.weight"], raw_tensors["mlp.w2.weight"])
def test_moe_and_qkv_conversion_reversed(self):
model = DummyRoot()
model.config = PretrainedConfig()
raw_tensors = {
"model.layers.0.experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"model.layers.0.experts.1.w1.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]),
"model.layers.0.experts.0.w3.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
"model.layers.0.experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]),
"model.layers.0.experts.0.w2.weight": torch.tensor([[20.0, 21.0], [22.0, 23.0]]),
"model.layers.0.experts.1.w2.weight": torch.tensor([[24.0, 25.0], [26.0, 27.0]]),
"model.layers.1.experts.0.w1.weight": torch.tensor([[30.0, 31.0], [32.0, 33.0]]),
"model.layers.1.experts.1.w1.weight": torch.tensor([[34.0, 35.0], [36.0, 37.0]]),
"model.layers.1.experts.0.w3.weight": torch.tensor([[38.0, 39.0], [40.0, 41.0]]),
"model.layers.1.experts.1.w3.weight": torch.tensor([[42.0, 43.0], [44.0, 45.0]]),
"model.layers.1.experts.0.w2.weight": torch.tensor([[46.0, 47.0], [48.0, 49.0]]),
"model.layers.1.experts.1.w2.weight": torch.tensor([[50.0, 51.0], [52.0, 53.0]]),
"model.layers.0.self_attn.qkv_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
"model.layers.1.self_attn.qkv_proj.weight": torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]),
"mlp.w2.weight": torch.tensor([[60.0, 61.0], [62.0, 63.0]]),
}
state_dict = {k: v.clone() for k, v in raw_tensors.items()}
weight_mapping = [
WeightConverter(
["experts.*.w1.weight", "experts.*.w3.weight"],
"experts.gate_up_proj.weight",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
"experts.*.w2.weight",
"experts.down_proj.weight",
operations=[MergeModulelist(dim=0)],
),
WeightConverter(
"self_attn.qkv_proj.weight",
[
"self_attn.q_proj.weight",
"self_attn.k_proj.weight",
"self_attn.v_proj.weight",
],
operations=[Chunk(dim=0)],
),
WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight"),
]
# Use the mapping to load
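        # Unlike test_moe_and_qkv_conversion, the qkv glob here is not pinned to a
        # single layer, so every layer converts cleanly and the mapping round-trips.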
missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
)
        self.assertEqual(missing, set())
        self.assertEqual(unexpected, set())
        self.assertEqual(mismatch, set())
        self.assertEqual(misc, {})
# Try to revert the mapping
reversed_state_dict = revert_weight_conversion(model, model.state_dict())
# Make sure both saved state_dict are identical
self.assertTrue(compare_state_dicts(reversed_state_dict, state_dict))
    def test_qkv_chunk_rope_permute_with_fp8_quantization(self):
        if not is_triton_available():
            self.skipTest("Fine-grained FP8 integration tests require Triton to be installed.")
        from transformers.integrations.finegrained_fp8 import Fp8Dequantize

n_heads = 2
head_dim = 4
in_dim = 4
out_dim = n_heads * head_dim
block_size = (4, 4)
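        # FP8 block quantization keeps one inverse scale per (4, 4) weight block, so a
        # quantized (out_dim, in_dim) weight carries a scale tensor of shape
        # (out_dim // 4, in_dim // 4).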
class RopeProjector(nn.Module):
def __init__(self, *, with_scale: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.zeros(out_dim, in_dim))
if with_scale:
scale_shape = (out_dim // block_size[0], in_dim // block_size[1])
self.weight_scale_inv = nn.Parameter(torch.ones(scale_shape))
class RopeSelfAttn(nn.Module):
def __init__(self):
super().__init__()
self.q_proj = RopeProjector(with_scale=True)
self.k_proj = RopeProjector()
self.v_proj = RopeProjector()
class RopeLayer(nn.Module):
def __init__(self):
super().__init__()
self.self_attn = RopeSelfAttn()
class RopeModel(nn.Module):
base_model_prefix = "model"
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([RopeLayer()])
model = RopeModel()
model.config = PretrainedConfig()
model.config.num_attention_heads = n_heads
raw_q = torch.tensor(
[
[1.0, -1.0, 1.0, -1.0],
[0.5, -0.5, 0.5, -0.5],
[-1.0, 1.0, -1.0, 1.0],
[-0.5, 0.5, -0.5, 0.5],
[1.0, 1.0, -1.0, -1.0],
[0.5, 0.5, -0.5, -0.5],
[-1.0, -1.0, 1.0, 1.0],
[-0.5, -0.5, 0.5, 0.5],
],
dtype=torch.float32,
)
raw_k = torch.arange(out_dim * in_dim, dtype=torch.float32).reshape(out_dim, in_dim)
raw_v = torch.arange(out_dim * in_dim, dtype=torch.float32).reshape(out_dim, in_dim) + 100.0
raw_qkv = torch.cat([raw_q, raw_k, raw_v], dim=0)
state_dict = {"model.layers.0.self_attn.qkv_proj.weight": raw_qkv.clone()}
quantizer_cls = type(
"FineGrainedFP8HfQuantizer",
(),
{
"__init__": lambda self, bs=block_size: setattr(
self, "quantization_config", SimpleNamespace(weight_block_size=bs)
),
"param_needs_quantization": lambda self, _model, param_name: param_name.endswith("q_proj.weight"),
"pre_quantized": False,
},
)
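        # A minimal stand-in for the real FineGrainedFP8HfQuantizer: only q_proj
        # weights are flagged for quantization, so k_proj/v_proj load in full precision.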
quantizer = quantizer_cls()
weight_mapping = [
WeightConverter(
"model.layers.*.self_attn.qkv_proj.weight",
[
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
],
operations=[Chunk(dim=0), PermuteForRope()],
)
]
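        # Operations apply in order: the fused qkv weight is first split into three
        # equal chunks along dim 0, then each chunk is permuted into the interleaved
        # RoPE layout.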
missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=quantizer
)
self.assertEqual(missing, set())
self.assertEqual(unexpected, set())
self.assertEqual(mismatch, set())
self.assertEqual(misc, {})
permute_op = PermuteForRope()
permute_op.config = model.config
expected_q = permute_op._apply(raw_q)
expected_k = permute_op._apply(raw_k)
expected_v = permute_op._apply(raw_v)
model_state = model.state_dict()
self.assertFalse(torch.allclose(raw_k, expected_k))
torch.testing.assert_close(model_state["model.layers.0.self_attn.k_proj.weight"], expected_k)
torch.testing.assert_close(model_state["model.layers.0.self_attn.v_proj.weight"], expected_v)
q_weight_key = "model.layers.0.self_attn.q_proj.weight"
scale_key = "model.layers.0.self_attn.q_proj.weight_scale_inv"
self.assertIn(scale_key, model_state)
expected_dtype = torch.float8_e4m3fn if hasattr(torch, "float8_e4m3fn") else torch.int8
self.assertEqual(model_state[q_weight_key].dtype, expected_dtype)
self.assertEqual(model_state[q_weight_key].shape, torch.Size((out_dim, in_dim)))
self.assertEqual(model_state[scale_key].dtype, torch.float32)
self.assertEqual(
model_state[scale_key].shape,
torch.Size((out_dim // block_size[0], in_dim // block_size[1])),
)
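        # Round-trip check: dequantizing the packed FP8 weight with its per-block
        # scales should recover the RoPE-permuted q up to quantization error, hence
        # the loose tolerances below.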
dequant = Fp8Dequantize(block_size=block_size)
dequantized_q = dequant.convert(
[model_state[q_weight_key], model_state[scale_key]],
context={"quantization_config": quantizer.quantization_config},
)
torch.testing.assert_close(dequantized_q, expected_q, rtol=1e-2, atol=1e-2)
def test_ernie4_5_vl_moe_conversion(self):
model = DummyRoot(add_extra_moe=True)
model.config = PretrainedConfig()
raw_tensors = {
"model.layers.0.experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"model.layers.0.experts.1.w1.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]),
"model.layers.0.experts.2.w1.weight": torch.tensor([[11.0, 12.0], [13.0, 14.0]]),
"model.layers.0.experts.3.w1.weight": torch.tensor([[12.0, 13.0], [14.0, 15.0]]),
"model.layers.0.experts.0.w3.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
"model.layers.0.experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]),
"model.layers.0.experts.2.w3.weight": torch.tensor([[15.0, 16.0], [17.0, 18.0]]),
"model.layers.0.experts.3.w3.weight": torch.tensor([[16.0, 17.0], [18.0, 19.0]]),
"model.layers.0.experts.0.w2.weight": torch.tensor([[20.0, 21.0], [22.0, 23.0]]),
"model.layers.0.experts.1.w2.weight": torch.tensor([[24.0, 25.0], [26.0, 27.0]]),
"model.layers.0.experts.2.w2.weight": torch.tensor([[25.0, 26.0], [27.0, 28.0]]),
"model.layers.0.experts.3.w2.weight": torch.tensor([[26.0, 27.0], [28.0, 29.0]]),
"model.layers.1.experts.0.w1.weight": torch.tensor([[30.0, 31.0], [32.0, 33.0]]),
"model.layers.1.experts.1.w1.weight": torch.tensor([[34.0, 35.0], [36.0, 37.0]]),
"model.layers.1.experts.2.w1.weight": torch.tensor([[35.0, 36.0], [37.0, 38.0]]),
"model.layers.1.experts.3.w1.weight": torch.tensor([[36.0, 37.0], [38.0, 39.0]]),
"model.layers.1.experts.0.w3.weight": torch.tensor([[38.0, 39.0], [40.0, 41.0]]),
"model.layers.1.experts.1.w3.weight": torch.tensor([[42.0, 43.0], [44.0, 45.0]]),
"model.layers.1.experts.2.w3.weight": torch.tensor([[43.0, 44.0], [45.0, 46.0]]),
"model.layers.1.experts.3.w3.weight": torch.tensor([[44.0, 45.0], [46.0, 47.0]]),
"model.layers.1.experts.0.w2.weight": torch.tensor([[46.0, 47.0], [48.0, 49.0]]),
"model.layers.1.experts.1.w2.weight": torch.tensor([[50.0, 51.0], [52.0, 53.0]]),
"model.layers.1.experts.2.w2.weight": torch.tensor([[51.0, 52.0], [53.0, 54.0]]),
"model.layers.1.experts.3.w2.weight": torch.tensor([[52.0, 53.0], [54.0, 55.0]]),
"model.layers.0.self_attn.qkv_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
"model.layers.1.self_attn.qkv_proj.weight": torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]),
"mlp.w2.weight": torch.tensor([[60.0, 61.0], [62.0, 63.0]]),
}
state_dict = {k: v.clone() for k, v in raw_tensors.items()}
weight_mapping = [
WeightConverter(
["experts.*.w1.weight", "experts.*.w3.weight"],
["experts.gate_up_proj.weight", "extra_experts.gate_up_proj.weight"],
operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)],
),
WeightConverter(
"experts.*.w2.weight",
["experts.down_proj.weight", "extra_experts.down_proj.weight"],
operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)],
),
WeightConverter(
"self_attn.qkv_proj.weight",
[
"self_attn.q_proj.weight",
"self_attn.k_proj.weight",
"self_attn.v_proj.weight",
],
operations=[Chunk(dim=0)],
),
WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight"),
]
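        # ErnieFuseAndSplitTextVisionExperts stacks the per-expert tensors and routes
        # the first half of the expert list (experts 0-1 here) to `experts.*` and the
        # second half (experts 2-3) to `extra_experts.*`, mirroring the text/vision
        # expert split in Ernie 4.5 VL.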
missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
)
self.assertEqual(missing, set())
self.assertEqual(unexpected, set())
self.assertEqual(mismatch, set())
self.assertEqual(misc, {})
model_state = model.state_dict()
        def cat_gate(layer_prefix: str) -> tuple[torch.Tensor, torch.Tensor]:
moe_1_w1 = [
raw_tensors[f"{layer_prefix}.experts.0.w1.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w1.weight"],
]
moe_2_w1 = [
raw_tensors[f"{layer_prefix}.experts.2.w1.weight"],
raw_tensors[f"{layer_prefix}.experts.3.w1.weight"],
]
moe_1_w3 = [
raw_tensors[f"{layer_prefix}.experts.0.w3.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w3.weight"],
]
moe_2_w3 = [
raw_tensors[f"{layer_prefix}.experts.2.w3.weight"],
raw_tensors[f"{layer_prefix}.experts.3.w3.weight"],
]
moe_1 = torch.cat([torch.stack(moe_1_w1, dim=0), torch.stack(moe_1_w3, dim=0)], dim=1)
moe_2 = torch.cat([torch.stack(moe_2_w1, dim=0), torch.stack(moe_2_w3, dim=0)], dim=1)
return moe_1, moe_2
moe_1, moe_2 = cat_gate("model.layers.0")
torch.testing.assert_close(model_state["model.layers.0.experts.gate_up_proj.weight"], moe_1)
torch.testing.assert_close(model_state["model.layers.0.extra_experts.gate_up_proj.weight"], moe_2)
moe_1, moe_2 = cat_gate("model.layers.1")
torch.testing.assert_close(model_state["model.layers.1.experts.gate_up_proj.weight"], moe_1)
torch.testing.assert_close(model_state["model.layers.1.extra_experts.gate_up_proj.weight"], moe_2)
        def stack_down(layer_prefix: str) -> tuple[torch.Tensor, torch.Tensor]:
moe_1 = torch.stack(
[
raw_tensors[f"{layer_prefix}.experts.0.w2.weight"],
raw_tensors[f"{layer_prefix}.experts.1.w2.weight"],
],
dim=0,
)
moe_2 = torch.stack(
[
raw_tensors[f"{layer_prefix}.experts.2.w2.weight"],
raw_tensors[f"{layer_prefix}.experts.3.w2.weight"],
],
dim=0,
)
return moe_1, moe_2
moe_1, moe_2 = stack_down("model.layers.0")
torch.testing.assert_close(model_state["model.layers.0.experts.down_proj.weight"], moe_1)
torch.testing.assert_close(model_state["model.layers.0.extra_experts.down_proj.weight"], moe_2)
moe_1, moe_2 = stack_down("model.layers.1")
torch.testing.assert_close(model_state["model.layers.1.experts.down_proj.weight"], moe_1)
torch.testing.assert_close(model_state["model.layers.1.extra_experts.down_proj.weight"], moe_2)
def test_ernie4_5_vl_moe_conversion_reversed(self):
model = DummyRoot(add_extra_moe=True)
model.config = PretrainedConfig()
raw_tensors = {
"model.layers.0.experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]),
"model.layers.0.experts.1.w1.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]),
"model.layers.0.experts.2.w1.weight": torch.tensor([[11.0, 12.0], [13.0, 14.0]]),
"model.layers.0.experts.3.w1.weight": torch.tensor([[12.0, 13.0], [14.0, 15.0]]),
"model.layers.0.experts.0.w3.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]),
"model.layers.0.experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]),
"model.layers.0.experts.2.w3.weight": torch.tensor([[15.0, 16.0], [17.0, 18.0]]),
"model.layers.0.experts.3.w3.weight": torch.tensor([[16.0, 17.0], [18.0, 19.0]]),
"model.layers.0.experts.0.w2.weight": torch.tensor([[20.0, 21.0], [22.0, 23.0]]),
"model.layers.0.experts.1.w2.weight": torch.tensor([[24.0, 25.0], [26.0, 27.0]]),
"model.layers.0.experts.2.w2.weight": torch.tensor([[25.0, 26.0], [27.0, 28.0]]),
"model.layers.0.experts.3.w2.weight": torch.tensor([[26.0, 27.0], [28.0, 29.0]]),
"model.layers.1.experts.0.w1.weight": torch.tensor([[30.0, 31.0], [32.0, 33.0]]),
"model.layers.1.experts.1.w1.weight": torch.tensor([[34.0, 35.0], [36.0, 37.0]]),
"model.layers.1.experts.2.w1.weight": torch.tensor([[35.0, 36.0], [37.0, 38.0]]),
"model.layers.1.experts.3.w1.weight": torch.tensor([[36.0, 37.0], [38.0, 39.0]]),
"model.layers.1.experts.0.w3.weight": torch.tensor([[38.0, 39.0], [40.0, 41.0]]),
"model.layers.1.experts.1.w3.weight": torch.tensor([[42.0, 43.0], [44.0, 45.0]]),
"model.layers.1.experts.2.w3.weight": torch.tensor([[43.0, 44.0], [45.0, 46.0]]),
"model.layers.1.experts.3.w3.weight": torch.tensor([[44.0, 45.0], [46.0, 47.0]]),
"model.layers.1.experts.0.w2.weight": torch.tensor([[46.0, 47.0], [48.0, 49.0]]),
"model.layers.1.experts.1.w2.weight": torch.tensor([[50.0, 51.0], [52.0, 53.0]]),
"model.layers.1.experts.2.w2.weight": torch.tensor([[51.0, 52.0], [53.0, 54.0]]),
"model.layers.1.experts.3.w2.weight": torch.tensor([[52.0, 53.0], [54.0, 55.0]]),
"model.layers.0.self_attn.qkv_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
"model.layers.1.self_attn.qkv_proj.weight": torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]),
"mlp.w2.weight": torch.tensor([[60.0, 61.0], [62.0, 63.0]]),
}
state_dict = {k: v.clone() for k, v in raw_tensors.items()}
weight_mapping = [
WeightConverter(
["experts.*.w1.weight", "experts.*.w3.weight"],
["experts.gate_up_proj.weight", "extra_experts.gate_up_proj.weight"],
operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)],
),
WeightConverter(
"experts.*.w2.weight",
["experts.down_proj.weight", "extra_experts.down_proj.weight"],
operations=[ErnieFuseAndSplitTextVisionExperts(stack_dim=0, concat_dim=1)],
),
WeightConverter(
"self_attn.qkv_proj.weight",
[
"self_attn.q_proj.weight",
"self_attn.k_proj.weight",
"self_attn.v_proj.weight",
],
operations=[Chunk(dim=0)],
),
WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight"),
]
# Use the mapping to load
missing, unexpected, mismatch, _, misc = convert_and_load_state_dict_in_model(
model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None
)
        self.assertEqual(missing, set())
        self.assertEqual(unexpected, set())
        self.assertEqual(mismatch, set())
        self.assertEqual(misc, {})
# Try to revert the mapping
reversed_state_dict = revert_weight_conversion(model, model.state_dict())
# Make sure both saved state_dict are identical
self.assertTrue(compare_state_dicts(reversed_state_dict, state_dict))
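

# Checkpoint conversion mappings live in a global registry keyed by name;
# registering an existing key raises ValueError unless overwrite=True is passed.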
class TestConversionMapping(unittest.TestCase):
def test_register_checkpoint_conversion_mapping(self):
register_checkpoint_conversion_mapping(
"foobar",
[
WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
],
)
self.assertEqual(len(get_checkpoint_conversion_mapping("foobar")), 1)
def test_register_checkpoint_conversion_mapping_overwrites(self):
register_checkpoint_conversion_mapping(
"foobarbaz",
[
WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
],
)
with self.assertRaises(ValueError):
register_checkpoint_conversion_mapping(
"foobarbaz",
[
WeightRenaming(".block_sparse_moe.foo", ".mlp.foo"),
WeightRenaming(".block_sparse_moe.bar", ".mlp.bar"),
],
)
register_checkpoint_conversion_mapping(
"foobarbaz",
[
WeightRenaming(".block_sparse_moe.foo", ".mlp.foo"),
WeightRenaming(".block_sparse_moe.bar", ".mlp.bar"),
],
overwrite=True,
)
self.assertEqual(len(get_checkpoint_conversion_mapping("foobarbaz")), 2)
if __name__ == "__main__":
unittest.main()