# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
"""Split a PanguUltraMoE checkpoint into per-rank tensor-parallel shards.

For every rank handled by this node, a ``PanguUltraMoEForCausalLM`` is built
from the original model config and the new runner config. Its weights are then
filled with that rank's slices of the original weights, and the result is saved
under ``<output_path>/rank_<rank_id>`` together with the tokenizer files.
"""
import argparse
import logging
import os
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
import yaml
from torch import nn
from transformers import AutoModelForCausalLM

from model import PanguUltraMoEForCausalLM

root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
    level=logging.INFO,
)


def _to_parameter(data):
    """Wrap ``data`` in an ``nn.Parameter`` with ``requires_grad=False``."""
    return nn.Parameter(data, requires_grad=False)


def _shard_rows(tensor, rank, size):
    """Return rows ``[rank * size, (rank + 1) * size)`` of a 2-D tensor, made contiguous."""
    return tensor[rank * size : (rank + 1) * size, :].contiguous()


def _shard_cols(tensor, rank, size):
    """Return columns ``[rank * size, (rank + 1) * size)`` of a 2-D tensor, made contiguous."""
    return tensor[:, rank * size : (rank + 1) * size].contiguous()


def _fused_gate_up(src_mlp, local_rank, ffn_dim):
    """Build this rank's fused gate/up weight: the gate shard rows followed by the up shard rows.

    Args:
        src_mlp: source module exposing ``gate_proj`` and ``up_proj`` linear layers.
        local_rank: index of the shard to take.
        ffn_dim: number of rows per shard.

    Returns:
        A frozen ``nn.Parameter`` of ``torch.cat([gate_shard, up_shard], dim=0)``.
    """
    gate_weight = _shard_rows(src_mlp.gate_proj.weight.data, local_rank, ffn_dim)
    up_weight = _shard_rows(src_mlp.up_proj.weight.data, local_rank, ffn_dim)
    return _to_parameter(torch.cat([gate_weight, up_weight], dim=0))


def split_w_dense(block, dst_model, i, local_rank):
    """Copy this rank's shard of a dense-MLP layer ``block`` into ``dst_model`` layer ``i``."""
    dst_mlp = dst_model.model.layers[i].mlp
    ffn_dim = dst_mlp.intermediate_size_per_rank
    dst_mlp.merge_up_gate_proj.weight = _fused_gate_up(block.mlp, local_rank, ffn_dim)
    # down_proj is split along its input dimension (columns).
    dst_mlp.down_proj.weight.data = _shard_cols(
        block.mlp.down_proj.weight.data, local_rank, ffn_dim
    )


def split_w_moe(block, dst_model, i, local_rank):
    """Copy this rank's shard of an MoE layer ``block`` into ``dst_model`` layer ``i``.

    Shared experts are sharded like a dense MLP. The router (``gate``) weight is
    copied unsliced. Routed-expert weights are sharded per expert and packed into
    the grouped tensors ``group_w1_w3`` and ``group_w2``.
    """
    src_mlp = block.mlp
    dst_mlp = dst_model.model.layers[i].mlp

    shared_ffn_dim = dst_mlp.shared_experts.intermediate_size_per_rank
    dst_mlp.shared_experts.merge_up_gate_proj.weight = _fused_gate_up(
        src_mlp.shared_experts, local_rank, shared_ffn_dim
    )
    dst_mlp.shared_experts.down_proj.weight.data = _shard_cols(
        src_mlp.shared_experts.down_proj.weight.data, local_rank, shared_ffn_dim
    )

    # The router weight is not sharded: every rank gets the full tensor.
    dst_mlp.gate.weight.data = src_mlp.gate.weight.data

    expert_num = src_mlp.num_routed_experts
    ffn_dim = dst_mlp.experts.intermediate_size_per_rank
    gate_shards, up_shards, down_shards = [], [], []
    for src_expert in src_mlp.experts:
        gate_shards.append(
            _shard_rows(src_expert.gate_proj.weight.data, local_rank, ffn_dim)
        )
        up_shards.append(_shard_rows(src_expert.up_proj.weight.data, local_rank, ffn_dim))
        down_shards.append(
            _shard_cols(src_expert.down_proj.weight.data, local_rank, ffn_dim)
        )

    # group_w2 is viewed as (expert_num, -1, ffn_dim).
    dst_mlp.experts.group_w2.data = (
        torch.cat(down_shards, dim=0).view(expert_num, -1, ffn_dim).contiguous()
    )
    # Gate and up are each viewed as (expert_num, ffn_dim, -1), then stacked
    # gate-first along dim 1 into group_w1_w3.
    group_gate_proj = (
        torch.cat(gate_shards, dim=0).view(expert_num, ffn_dim, -1).contiguous()
    )
    group_up_proj = torch.cat(up_shards, dim=0).view(expert_num, ffn_dim, -1).contiguous()
    dst_mlp.experts.group_w1_w3.data = torch.cat([group_gate_proj, group_up_proj], dim=1)


def split_w_attn(block, dst_model, i, local_rank):
    """Copy this rank's attention weights and the layer norms into ``dst_model`` layer ``i``.

    Query projections are split by rows and ``o_proj`` by columns. The low-rank
    ``q_a``/``kv_a`` projections and all layer norms are copied unsliced.
    """
    src_attn = block.self_attn
    dst_layer = dst_model.model.layers[i]
    dst_attn = dst_layer.self_attn
    q_dim = dst_attn.num_heads_per_rank * dst_attn.q_head_dim
    o_dim = dst_attn.num_heads_per_rank * dst_attn.attention_v_dim

    if dst_attn.attention_q_lora_dim is None:
        dst_attn.q_proj.weight.data = _shard_rows(
            src_attn.q_proj.weight.data, local_rank, q_dim
        )
    else:
        dst_attn.q_a_proj.weight.data = src_attn.q_a_proj.weight.data
        dst_attn.q_a_layernorm.weight.data = src_attn.q_a_layernorm.weight.data
        dst_attn.q_b_proj.weight.data = _shard_rows(
            src_attn.q_b_proj.weight.data, local_rank, q_dim
        )
    dst_attn.kv_a_proj_with_mqa.weight.data = src_attn.kv_a_proj_with_mqa.weight.data
    dst_attn.kv_a_layernorm.weight.data = src_attn.kv_a_layernorm.weight.data
    dst_attn.o_proj.weight.data = _shard_cols(src_attn.o_proj.weight.data, local_rank, o_dim)

    for norm_name in (
        "input_layernorm",
        "post_attention_layernorm",
        "pre_mlp_layernorm",
        "post_mlp_layernorm",
    ):
        getattr(dst_layer, norm_name).weight.data = getattr(block, norm_name).weight.data


def kv_low_rank_split(block, dst_model, i, local_rank):
    """Split this rank's ``kv_b_proj`` shard into per-head K and V weights.

    Results:
        ``kv_b_proj_w_k``: shape (num_heads_per_rank, attention_qk_dim, attention_kv_lora_dim).
        ``kv_b_proj_w_v``: shape (num_heads_per_rank, attention_kv_lora_dim, attention_v_dim).
    """
    dst_attn = dst_model.model.layers[i].self_attn
    num_heads_per_rank = dst_attn.num_heads_per_rank
    attention_qk_dim = dst_attn.attention_qk_dim
    attention_v_dim = dst_attn.attention_v_dim
    attention_kv_lora_dim = dst_attn.attention_kv_lora_dim
    rows_per_head = attention_qk_dim + attention_v_dim

    kv_b_proj_weight = _shard_rows(
        block.self_attn.kv_b_proj.weight.data, local_rank, num_heads_per_rank * rows_per_head
    )
    # Within the shard, each head owns `rows_per_head` consecutive rows: first
    # the attention_qk_dim K rows, then the attention_v_dim V rows.
    per_head = kv_b_proj_weight.view(num_heads_per_rank, rows_per_head, attention_kv_lora_dim)
    dst_attn.kv_b_proj_w_k.data = per_head[:, :attention_qk_dim, :].contiguous()
    dst_attn.kv_b_proj_w_v.data = (
        per_head[:, attention_qk_dim:, :].transpose(1, 2).contiguous()
    )


def split_layer(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size):
    """Fill ``dst_model`` layer ``i`` with this rank's attention and MLP/MoE shards."""
    local_rank_tp_attn = local_rank % attn_tp_size
    split_w_attn(block, dst_model, i, local_rank_tp_attn)
    kv_low_rank_split(block, dst_model, i, local_rank_tp_attn)

    local_rank_tp_moe = local_rank % moe_tp_size
    if i >= dst_model.config.num_dense_layers:
        split_w_moe(block, dst_model, i, local_rank_tp_moe)
    else:
        split_w_dense(block, dst_model, i, local_rank_tp_moe)


def split_w(src_model, dst_model, local_rank, runner_config):
    """Fill ``dst_model`` with rank ``local_rank``'s shard of every weight in ``src_model``.

    Embedding and LM-head rows are split by ``embed_tp_size``. Each decoder layer
    is split in a worker thread. Any exception raised by a worker is re-raised
    here, so a partially filled model is never returned.
    """
    parallel_config = runner_config.get("parallel_config")
    attn_tp_size = parallel_config.get("attn_tp_size")
    moe_tp_size = parallel_config.get("moe_tp_size")
    embed_tp_size = parallel_config.get("embed_tp_size")

    full_vocab_size = src_model.model.vocab_size
    if full_vocab_size % embed_tp_size != 0:
        logging.warning(
            "vocab_size %s is not divisible by embed_tp_size %s; trailing rows are dropped",
            full_vocab_size,
            embed_tp_size,
        )
    vocab_size = full_vocab_size // embed_tp_size
    embed_tp_rank = local_rank % embed_tp_size
    dst_model.lm_head.weight.data = _shard_rows(
        src_model.lm_head.weight.data, embed_tp_rank, vocab_size
    )
    dst_model.model.embed_tokens.weight.data = _shard_rows(
        src_model.model.embed_tokens.weight.data, embed_tp_rank, vocab_size
    )
    dst_model.model.norm.weight.data = src_model.model.norm.weight.data

    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(
                split_layer, block, dst_model, i, local_rank, attn_tp_size, moe_tp_size
            )
            for i, block in enumerate(src_model.model.layers)
        ]
        # result() re-raises any worker exception instead of silently dropping it.
        for future in futures:
            future.result()


def copy_files_with_prefix(src_dir, dst_dir, prefix):
    """Copy (with metadata) every regular file in ``src_dir`` whose name starts with ``prefix``."""
    for file in os.listdir(src_dir):
        src_file = os.path.join(src_dir, file)
        if file.startswith(prefix) and os.path.isfile(src_file):
            shutil.copy2(src_file, os.path.join(dst_dir, file))


def parse_args():
    """Parse the command-line arguments for the weight-splitting script."""
    parser = argparse.ArgumentParser(
        description="Split weight parameters with tensor parallel"
    )
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path of model weights"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="The output directory where the results are saved",
    )
    parser.add_argument(
        "--origin_yaml_file_path",
        type=str,
        required=True,
        help="inference configurations",
    )
    parser.add_argument(
        "--new_yaml_file_path", type=str, required=True, help="inference configurations"
    )
    parser.add_argument(
        "--world_size", type=int, default=8, help="The parallel rank size of model"
    )
    parser.add_argument("--node_num", type=int, default=1, help="The parallel node num")
    parser.add_argument(
        "--node_rank", type=int, default=0, help="The parallel node rank"
    )
    return parser.parse_args()


def show_model_states(origin_model, model_name="src_model"):
    """Log the name, shape and dtype of every parameter, then the total element count."""
    src_param_size = 0
    for name, params in origin_model.named_parameters():
        src_param_size += np.prod(params.size())
        logging.info(
            "Param of %s tensor parallel: %s, %s, %s",
            model_name,
            name,
            params.size(),
            params.dtype,
        )
    logging.info(
        "Total param size of %s tensor parallel: %s", model_name, src_param_size
    )


def read_yaml(yaml_file_path):
    """Load a YAML file with ``yaml.safe_load``.

    Raises:
        FileNotFoundError: if the file does not exist (logged first).
        yaml.YAMLError: if the file cannot be parsed (logged first).
    """
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        logging.error("No such yaml file: %s", yaml_file_path)
        raise
    except yaml.YAMLError as e:
        logging.error("Load yaml file failed: %s", e)
        raise


def check_vars(world_size, runner_config):
    """Exit with status 1 unless every TP size in ``parallel_config`` evenly divides ``world_size``."""
    parallel_config = (runner_config or {}).get("parallel_config")
    if not isinstance(parallel_config, dict):
        logging.error("runner config must contain a 'parallel_config' mapping")
        sys.exit(1)
    for key in ("attn_tp_size", "moe_tp_size", "embed_tp_size"):
        tp_size = parallel_config.get(key)
        # A missing or zero size would otherwise crash with an unclear error below.
        if not tp_size or world_size % tp_size != 0:
            logging.error(
                "world_size %s mod %s %s must be 0", world_size, key, tp_size
            )
            sys.exit(1)


def main():
    """Split the checkpoint for every rank assigned to this node and save each shard."""
    logging.info("Start to split weight...")
    args = parse_args()
    output_path = args.output_path
    # The original runner config is loaded (so a bad path or file fails early),
    # but the split itself is driven only by the new runner config.
    read_yaml(args.origin_yaml_file_path)
    new_runner_config = read_yaml(args.new_yaml_file_path)
    world_size = args.world_size

    # Validate the parallel layout before the expensive model load.
    check_vars(world_size, new_runner_config)
    if args.node_num <= 0 or world_size % args.node_num != 0:
        logging.error(
            "world_size %s mod node_num %s must be 0", world_size, args.node_num
        )
        sys.exit(1)
    node_rank_id = args.node_rank
    if not 0 <= node_rank_id < args.node_num:
        logging.error("node_rank %s must be in [0, %s)", node_rank_id, args.node_num)
        sys.exit(1)

    os.makedirs(output_path, exist_ok=True)

    origin_model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        local_files_only=True,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    show_model_states(origin_model, "origin_model")
    config = origin_model.config

    rank_num_per_node = world_size // args.node_num
    start_rank = rank_num_per_node * node_rank_id
    end_rank = start_rank + rank_num_per_node
    for rank_id in range(start_rank, end_rank):
        logging.info("rank_id=%s / rank_size=%s", rank_id, world_size)
        # NOTE(review): presumably read by the model code when
        # PanguUltraMoEForCausalLM is constructed below — confirm.
        os.environ["LOCAL_RANK"] = str(rank_id)
        save_path = os.path.join(output_path, f"rank_{rank_id}")
        logging.info(
            "Split weight for rank %s start, save path is: %s", rank_id, save_path
        )
        part_model = PanguUltraMoEForCausalLM(config, new_runner_config)
        split_w(origin_model, part_model, rank_id, new_runner_config)
        show_model_states(part_model, "dst_model")
        part_model.save_pretrained(save_path)
        copy_files_with_prefix(args.model_path, save_path, "tokenizer")
        copy_files_with_prefix(args.model_path, save_path, "tokenization")
        logging.info(
            "Split weight for rank %s finished, save path is: %s", rank_id, save_path
        )
        # Release this rank's shard before building the next one.
        del part_model


if __name__ == "__main__":
    main()