# Uploaded by wangrongsheng via the upload-large-folder tool (commit a86f2f6, verified).
# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
import argparse
import logging
import os
import shutil
from threading import Thread
import numpy as np
import torch
import yaml
from torch import nn
from transformers import AutoModelForCausalLM
from model import PanguUltraMoEForCausalLM
# Send all log output through one root handler with a uniform
# "[LLM](file:line)" prefix. Handlers already attached to the root logger are
# removed first; logging.basicConfig does nothing when the root logger
# already has handlers.
root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
)
def _to_parameter(data):
return nn.Parameter(data, requires_grad=False)
def split_w_dense(block, dst_model, i, local_rank):
    """Shard the dense MLP of layer ``i`` for one MoE tensor-parallel rank.

    ``gate_proj`` and ``up_proj`` are sliced along dim 0 over rows
    ``[local_rank * ffn_dim, (local_rank + 1) * ffn_dim)`` and stacked as
    ``[gate; up]`` into the destination's fused ``merge_up_gate_proj``.
    ``down_proj`` is sliced along dim 1 over the same range.

    Args:
        block: source decoder layer from the unsharded model.
        dst_model: per-rank model whose layer ``i`` receives the weights.
        i: layer index.
        local_rank: rank index inside the MoE tensor-parallel group.
    """
    dst_mlp = dst_model.model.layers[i].mlp
    ffn_dim = dst_mlp.intermediate_size_per_rank
    rows = slice(local_rank * ffn_dim, (local_rank + 1) * ffn_dim)

    # Slice through ``.data`` so autograd does not record the copies.
    gate_weight = block.mlp.gate_proj.weight.data[rows, :].contiguous()
    up_weight = block.mlp.up_proj.weight.data[rows, :].contiguous()
    # Gate rows precede up rows in the fused weight; presumably the destination
    # module splits it in the same order — confirm against model.py.
    dst_mlp.merge_up_gate_proj.weight = _to_parameter(
        torch.cat([gate_weight, up_weight], dim=0)
    )
    dst_mlp.down_proj.weight.data = (
        block.mlp.down_proj.weight.data[:, rows].contiguous()
    )
def split_w_moe(block, dst_model, i, local_rank):
    """Shard the MoE MLP of layer ``i`` for one MoE tensor-parallel rank.

    * Shared experts: this rank's rows of ``gate_proj``/``up_proj`` are fused
      into ``merge_up_gate_proj`` as ``[gate; up]``; ``down_proj`` columns are
      sliced over the same range.
    * Router ``gate``: copied whole.
    * Routed experts: every expert is sliced the same way. The slices are
      stacked into ``group_w1_w3`` with shape
      ``(num_routed_experts, 2 * ffn_dim, K)`` ([gate; up] along dim 1,
      K = source gate/up column count) and ``group_w2`` with shape
      ``(num_routed_experts, R, ffn_dim)`` (R = source down_proj row count).

    Args:
        block: source decoder layer from the unsharded model.
        dst_model: per-rank model whose layer ``i`` receives the weights.
        i: layer index.
        local_rank: rank index inside the MoE tensor-parallel group.
    """
    src_mlp = block.mlp
    dst_mlp = dst_model.model.layers[i].mlp

    # Shared experts.
    shared_ffn_dim = dst_mlp.shared_experts.intermediate_size_per_rank
    shared_rows = slice(
        local_rank * shared_ffn_dim, (local_rank + 1) * shared_ffn_dim
    )
    shared_gate = (
        src_mlp.shared_experts.gate_proj.weight.data[shared_rows, :].contiguous()
    )
    shared_up = (
        src_mlp.shared_experts.up_proj.weight.data[shared_rows, :].contiguous()
    )
    dst_mlp.shared_experts.merge_up_gate_proj.weight = _to_parameter(
        torch.cat([shared_gate, shared_up], dim=0)
    )
    dst_mlp.shared_experts.down_proj.weight.data = (
        src_mlp.shared_experts.down_proj.weight.data[:, shared_rows].contiguous()
    )

    # Router weights are replicated on every rank.
    dst_mlp.gate.weight.data = src_mlp.gate.weight.data

    # Routed experts: the per-rank width is identical for all experts, so it
    # is looked up once outside the loop.
    expert_num = src_mlp.num_routed_experts
    ffn_dim = dst_mlp.experts.intermediate_size_per_rank
    rows = slice(local_rank * ffn_dim, (local_rank + 1) * ffn_dim)
    gate_slices, up_slices, down_slices = [], [], []
    for src_expert in src_mlp.experts:
        gate_slices.append(src_expert.gate_proj.weight.data[rows, :].contiguous())
        up_slices.append(src_expert.up_proj.weight.data[rows, :].contiguous())
        down_slices.append(src_expert.down_proj.weight.data[:, rows].contiguous())

    dst_mlp.experts.group_w2.data = (
        torch.cat(down_slices, dim=0).view(expert_num, -1, ffn_dim).contiguous()
    )
    group_gate = torch.cat(gate_slices, dim=0).view(expert_num, ffn_dim, -1)
    group_up = torch.cat(up_slices, dim=0).view(expert_num, ffn_dim, -1)
    dst_mlp.experts.group_w1_w3.data = torch.cat([group_gate, group_up], dim=1)
def split_w_attn(block, dst_model, i, local_rank):
    """Copy layer ``i``'s attention and norm weights into ``dst_model`` for one rank.

    ``q_proj`` (or ``q_b_proj`` when ``attention_q_lora_dim`` is set) is sliced
    along dim 0 and ``o_proj`` along dim 1, using this rank's head range.
    ``q_a_proj``, ``q_a_layernorm``, ``kv_a_proj_with_mqa``, ``kv_a_layernorm``
    and the four per-layer norms are copied whole.

    Args:
        block: source decoder layer from the unsharded model.
        dst_model: per-rank model whose layer ``i`` receives the weights.
        i: layer index.
        local_rank: rank index inside the attention tensor-parallel group.
    """
    src_attn = block.self_attn
    dst_layer = dst_model.model.layers[i]
    dst_attn = dst_layer.self_attn

    # Per-rank widths come from the layer actually being split.
    q_dim = dst_attn.num_heads_per_rank * dst_attn.q_head_dim
    o_dim = dst_attn.num_heads_per_rank * dst_attn.attention_v_dim
    q_rows = slice(local_rank * q_dim, (local_rank + 1) * q_dim)
    o_cols = slice(local_rank * o_dim, (local_rank + 1) * o_dim)

    if dst_attn.attention_q_lora_dim is None:
        dst_attn.q_proj.weight.data = (
            src_attn.q_proj.weight.data[q_rows, :].contiguous()
        )
    else:
        dst_attn.q_a_proj.weight.data = src_attn.q_a_proj.weight.data
        dst_attn.q_a_layernorm.weight.data = src_attn.q_a_layernorm.weight.data
        dst_attn.q_b_proj.weight.data = (
            src_attn.q_b_proj.weight.data[q_rows, :].contiguous()
        )
    dst_attn.kv_a_proj_with_mqa.weight.data = (
        src_attn.kv_a_proj_with_mqa.weight.data
    )
    dst_attn.kv_a_layernorm.weight.data = src_attn.kv_a_layernorm.weight.data
    dst_attn.o_proj.weight.data = src_attn.o_proj.weight.data[:, o_cols].contiguous()

    for norm_name in (
        "input_layernorm",
        "post_attention_layernorm",
        "pre_mlp_layernorm",
        "post_mlp_layernorm",
    ):
        getattr(dst_layer, norm_name).weight.data = (
            getattr(block, norm_name).weight.data
        )
def kv_low_rank_split(block, dst_model, i, local_rank):
    """Split this rank's rows of layer ``i``'s ``kv_b_proj`` into K and V parts.

    The rank's slice of ``kv_b_proj`` is viewed as
    ``(num_heads_per_rank, attention_qk_dim + attention_v_dim, attention_kv_lora_dim)``.
    For every head, the first ``attention_qk_dim`` rows go to ``kv_b_proj_w_k``
    with shape ``(heads, qk_dim, kv_lora_dim)``. The remaining
    ``attention_v_dim`` rows go to ``kv_b_proj_w_v``, stored transposed as
    ``(heads, kv_lora_dim, v_dim)``.

    Args:
        block: source decoder layer from the unsharded model.
        dst_model: per-rank model whose layer ``i`` receives the weights.
        i: layer index.
        local_rank: rank index inside the attention tensor-parallel group.
    """
    dst_attn = dst_model.model.layers[i].self_attn
    num_heads_per_rank = dst_attn.num_heads_per_rank
    attention_qk_dim = dst_attn.attention_qk_dim
    attention_v_dim = dst_attn.attention_v_dim
    attention_kv_lora_dim = dst_attn.attention_kv_lora_dim

    # All dims, including the rank's row span, come from layer ``i`` itself.
    head_rows = attention_qk_dim + attention_v_dim
    k_dim = num_heads_per_rank * head_rows
    kv_b_proj_weight = block.self_attn.kv_b_proj.weight.data[
        local_rank * k_dim : (local_rank + 1) * k_dim, :
    ]
    per_head = kv_b_proj_weight.reshape(
        num_heads_per_rank, head_rows, attention_kv_lora_dim
    )
    dst_attn.kv_b_proj_w_k.data = per_head[:, :attention_qk_dim, :].contiguous()
    dst_attn.kv_b_proj_w_v.data = (
        per_head[:, attention_qk_dim:, :].transpose(1, 2).contiguous()
    )
def split_layer(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size):
    """Split decoder layer ``i`` for ``local_rank``.

    Attention weights use the rank's index inside the attention TP group
    (``local_rank % attn_tp_size``). The MLP uses its index inside the MoE TP
    group (``local_rank % moe_tp_size``). Layers with index
    ``>= config.num_dense_layers`` go to the MoE splitter; earlier layers go
    to the dense splitter.
    """
    attn_rank = local_rank % attn_tp_size
    for split_attn_part in (split_w_attn, kv_low_rank_split):
        split_attn_part(block, dst_model, i, attn_rank)

    moe_rank = local_rank % moe_tp_size
    is_moe_layer = i >= dst_model.config.num_dense_layers
    split_mlp = split_w_moe if is_moe_layer else split_w_dense
    split_mlp(block, dst_model, i, moe_rank)
def split_w(src_model, dst_model, local_rank, runner_config):
    """Populate ``dst_model`` with the shard of ``src_model`` owned by ``local_rank``.

    The ``embed_tokens`` and ``lm_head`` rows are split evenly across
    ``embed_tp_size`` ranks, and the final norm is copied whole. Every decoder
    layer is then split by ``split_layer`` on its own worker thread.

    Args:
        src_model: the unsharded model.
        dst_model: per-rank model, modified in place.
        local_rank: global rank id of the shard being produced.
        runner_config: parsed runner YAML whose ``parallel_config`` section
            holds ``attn_tp_size``, ``moe_tp_size`` and ``embed_tp_size``.

    Raises:
        Whatever ``split_layer`` raises for any layer. The exception is
        propagated so a partially filled ``dst_model`` is never returned.
    """
    from concurrent.futures import ThreadPoolExecutor

    parallel_config = runner_config.get("parallel_config")
    attn_tp_size = parallel_config.get("attn_tp_size")
    moe_tp_size = parallel_config.get("moe_tp_size")
    embed_tp_size = parallel_config.get("embed_tp_size")

    full_vocab_size = src_model.model.vocab_size
    if full_vocab_size % embed_tp_size != 0:
        logging.warning(
            "vocab_size %s is not divisible by embed_tp_size %s; the last %s "
            "rows of embed_tokens/lm_head are not assigned to any rank",
            full_vocab_size,
            embed_tp_size,
            full_vocab_size % embed_tp_size,
        )
    vocab_size = full_vocab_size // embed_tp_size
    embed_tp_rank = local_rank % embed_tp_size
    vocab_rows = slice(embed_tp_rank * vocab_size, (embed_tp_rank + 1) * vocab_size)
    dst_model.lm_head.weight.data = src_model.lm_head.weight.data[vocab_rows, :]
    dst_model.model.embed_tokens.weight.data = (
        src_model.model.embed_tokens.weight.data[vocab_rows, :]
    )
    dst_model.model.norm.weight.data = src_model.model.norm.weight.data

    layers = src_model.model.layers
    # One worker per layer. Unlike a bare threading.Thread, whose exceptions
    # are only printed, Future.result() re-raises the worker's exception here.
    with ThreadPoolExecutor(max_workers=max(1, len(layers))) as pool:
        futures = [
            pool.submit(
                split_layer,
                block,
                dst_model,
                layer_idx,
                local_rank,
                attn_tp_size,
                moe_tp_size,
            )
            for layer_idx, block in enumerate(layers)
        ]
        for future in futures:
            future.result()
def copy_files_with_prefix(src_dir, dst_dir, prefix):
    """Copy every regular file in ``src_dir`` whose name starts with ``prefix``.

    Files are copied with ``shutil.copy2``, which preserves metadata, in
    sorted order. Sub-directories and other non-file entries that match
    ``prefix`` are skipped, because ``shutil.copy2`` cannot copy a directory.
    """
    for name in sorted(os.listdir(src_dir)):
        if not name.startswith(prefix):
            continue
        src_file = os.path.join(src_dir, name)
        if not os.path.isfile(src_file):
            continue
        shutil.copy2(src_file, os.path.join(dst_dir, name))
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with string options ``model_path``,
        ``output_path``, ``origin_yaml_file_path`` and ``new_yaml_file_path``
        (all default ``None``), and integer options ``world_size``
        (default 8), ``node_num`` (default 1) and ``node_rank`` (default 0).
    """
    cli = argparse.ArgumentParser(
        description="Split weight parameters with tensor parallel"
    )
    string_options = (
        ("--model_path", "Path of model weights"),
        ("--output_path", "The output directory where the results are saved"),
        ("--origin_yaml_file_path", "inference configurations"),
        ("--new_yaml_file_path", "inference configurations"),
    )
    for flag, help_text in string_options:
        cli.add_argument(flag, type=str, help=help_text)
    int_options = (
        ("--world_size", 8, "The parallel rank size of model"),
        ("--node_num", 1, "The parallel node num"),
        ("--node_rank", 0, "The parallel node rank"),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    return cli.parse_args()
def show_model_states(origin_model, model_name="src_model"):
    """Log the name, shape and dtype of each parameter, then the total element count.

    Args:
        origin_model: object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        model_name: label inserted into every log line.
    """
    total = 0
    for param_name, param in origin_model.named_parameters():
        total += np.prod(param.size())
        logging.info(
            "Param of %s tensor parallel: %s, %s, %s",
            model_name,
            param_name,
            param.size(),
            param.dtype,
        )
    logging.info("Total param size of %s tensor parallel: %s", model_name, total)
def read_yaml(yaml_file_path):
    """Load ``yaml_file_path`` with ``yaml.safe_load`` and return the result.

    The result may be ``None`` for an empty document.

    Raises:
        FileNotFoundError: the file does not exist (logged before re-raising).
        yaml.YAMLError: the file is not valid YAML (logged before re-raising).
    """
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        logging.error("No such yaml file: %s", yaml_file_path)
        raise
    except yaml.YAMLError as e:
        logging.error("Load yaml file failed: %s", e)
        raise
def check_vars(world_size, runner_config):
    """Exit with status 1 unless ``world_size`` is a multiple of every TP size.

    The sizes checked are ``attn_tp_size``, ``moe_tp_size`` and
    ``embed_tp_size`` from ``runner_config["parallel_config"]``. A size that
    is missing, not an integer, or not positive is reported and also causes
    exit status 1.

    Raises:
        SystemExit: with code 1 on the first failed check.
    """
    parallel_config = (runner_config or {}).get("parallel_config") or {}
    for size_name in ("attn_tp_size", "moe_tp_size", "embed_tp_size"):
        tp_size = parallel_config.get(size_name)
        if not isinstance(tp_size, int) or tp_size <= 0:
            logging.error("%s must be a positive integer, got %r", size_name, tp_size)
            raise SystemExit(1)
        if world_size % tp_size != 0:
            logging.error(
                "world_size %s mod %s %s must be 0", world_size, size_name, tp_size
            )
            raise SystemExit(1)
if __name__ == "__main__":
    logging.info("Start to split weight...")
    args = parse_args()
    output_path = args.output_path
    # NOTE(review): the original runner config is read here but not used below.
    old_runner_config = read_yaml(args.origin_yaml_file_path)
    new_runner_config = read_yaml(args.new_yaml_file_path)
    world_size = args.world_size

    # Validate the parallel layout before the expensive model load.
    check_vars(world_size, new_runner_config)
    if world_size <= 0 or args.node_num <= 0 or world_size % args.node_num != 0:
        logging.error(
            "world_size %s must be a positive multiple of node_num %s",
            world_size,
            args.node_num,
        )
        raise SystemExit(1)
    if not 0 <= args.node_rank < args.node_num:
        logging.error(
            "node_rank %s must be in [0, %s)", args.node_rank, args.node_num
        )
        raise SystemExit(1)

    os.makedirs(output_path, exist_ok=True)

    origin_model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        local_files_only=True,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    show_model_states(origin_model, "origin_model")
    config = origin_model.config

    # Each node produces a contiguous block of rank ids.
    node_rank_id = args.node_rank
    rank_num_per_node = world_size // args.node_num
    start_rank = rank_num_per_node * node_rank_id
    end_rank = rank_num_per_node * (node_rank_id + 1)
    for rank_id in range(start_rank, end_rank):
        logging.info("rank_id=%s / rank_size=%s", rank_id, world_size)
        # Presumably read by PanguUltraMoEForCausalLM to pick its rank —
        # confirm in model.py.
        os.environ["LOCAL_RANK"] = str(rank_id)
        save_path = os.path.join(output_path, f"rank_{rank_id}")
        logging.info(
            "Split weight for rank %s start, save path is: %s", rank_id, save_path
        )
        part_model = PanguUltraMoEForCausalLM(config, new_runner_config)
        split_w(origin_model, part_model, rank_id, new_runner_config)
        show_model_states(part_model, "dst_model")
        part_model.save_pretrained(save_path)
        copy_files_with_prefix(args.model_path, save_path, "tokenizer")
        copy_files_with_prefix(args.model_path, save_path, "tokenization")
        logging.info(
            "Split weight for rank %s finished, save path is: %s", rank_id, save_path
        )
        # Release this rank's shard before building the next one.
        del part_model