| |
| |
|
|
| import argparse |
| import logging |
| import os |
| import shutil |
| from threading import Thread |
|
|
| import numpy as np |
| import torch |
| import yaml |
| from torch import nn |
| from transformers import AutoModelForCausalLM |
|
|
| from model import PanguUltraMoEForCausalLM |
|
|
# Remove any handlers already attached to the root logger; basicConfig only
# installs its handler when the root logger has none.
root_logger = logging.getLogger()
del root_logger.handlers[:]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
)
|
|
|
|
| def _to_parameter(data): |
| return nn.Parameter(data, requires_grad=False) |
|
|
|
|
def split_w_dense(block, dst_model, i, local_rank):
    """Copy one rank's shard of a dense MLP layer into ``dst_model``.

    The shard width is ``dst_model.model.layers[i].mlp.intermediate_size_per_rank``.
    Rows ``[local_rank * width, (local_rank + 1) * width)`` of the source
    ``gate_proj`` and ``up_proj`` weights are concatenated along dim 0 (gate
    first, then up) and stored as ``merge_up_gate_proj.weight``; the same range
    of columns of ``down_proj`` is stored as the destination ``down_proj``.

    Args:
        block: Source decoder layer exposing ``mlp.gate_proj``, ``mlp.up_proj``
            and ``mlp.down_proj``.
        dst_model: Destination (per-rank) model that is modified in place.
        i: Index of the layer inside ``dst_model.model.layers``.
        local_rank: Rank index used to pick the slice.
    """
    dst_mlp = dst_model.model.layers[i].mlp
    ffn_dim = dst_mlp.intermediate_size_per_rank
    start = local_rank * ffn_dim
    stop = start + ffn_dim

    gate_weight = block.mlp.gate_proj.weight[start:stop, :].contiguous()
    up_weight = block.mlp.up_proj.weight[start:stop, :].contiguous()

    # The original built a one-element list and branched on its length; the
    # multi-element branch could never run, so the merged weight is assigned
    # directly.
    dst_mlp.merge_up_gate_proj.weight = _to_parameter(
        torch.cat([gate_weight, up_weight], dim=0)
    )
    dst_mlp.down_proj.weight.data = block.mlp.down_proj.weight.data[
        :, start:stop
    ].contiguous()
|
|
|
|
def split_w_moe(block, dst_model, i, local_rank):
    """Copy one rank's shard of a MoE layer (shared experts, router, routed experts).

    * Shared experts: rows of ``gate_proj``/``up_proj`` and columns of
      ``down_proj`` are sliced with the destination shared-expert
      ``intermediate_size_per_rank``; gate and up are concatenated along dim 0
      into ``merge_up_gate_proj.weight``.
    * Router: ``mlp.gate.weight`` is assigned unsliced.
    * Routed experts: every expert in ``block.mlp.experts`` is sliced with the
      destination routed-expert ``intermediate_size_per_rank``. The down
      projections are packed into ``group_w2`` and the gate/up projections into
      ``group_w1_w3`` (gate rows first, then up rows, along dim 1).

    Args:
        block: Source decoder layer whose ``mlp`` holds ``shared_experts``,
            ``gate``, ``experts`` and ``num_routed_experts``.
        dst_model: Destination (per-rank) model that is modified in place.
        i: Index of the layer inside ``dst_model.model.layers``.
        local_rank: Rank index used to pick the slices.
    """
    src_mlp = block.mlp
    dst_mlp = dst_model.model.layers[i].mlp

    # ---- shared experts -------------------------------------------------
    src_shared = src_mlp.shared_experts
    dst_shared = dst_mlp.shared_experts
    shared_dim = dst_shared.intermediate_size_per_rank
    shared_lo = local_rank * shared_dim
    shared_hi = shared_lo + shared_dim

    gate_weight = src_shared.gate_proj.weight[shared_lo:shared_hi, :].contiguous()
    up_weight = src_shared.up_proj.weight[shared_lo:shared_hi, :].contiguous()
    # The original wrapped this in a one-element list and branched on its
    # length; only the single-element path was reachable.
    dst_shared.merge_up_gate_proj.weight = _to_parameter(
        torch.cat([gate_weight, up_weight], dim=0)
    )
    dst_shared.down_proj.weight.data = src_shared.down_proj.weight.data[
        :, shared_lo:shared_hi
    ].contiguous()

    # ---- router ---------------------------------------------------------
    dst_mlp.gate.weight.data = src_mlp.gate.weight.data

    # ---- routed experts -------------------------------------------------
    expert_num = src_mlp.num_routed_experts
    dst_experts = dst_mlp.experts
    # Looked up once and under its own name: the original rebound ``ffn_dim``
    # on every loop iteration, silently replacing the shared-expert width.
    ffn_dim = dst_experts.intermediate_size_per_rank
    lo = local_rank * ffn_dim
    hi = lo + ffn_dim

    gate_proj_list, up_proj_list, down_proj_list = [], [], []
    for src_expert in src_mlp.experts:
        gate_proj_list.append(src_expert.gate_proj.weight.data[lo:hi, :].contiguous())
        up_proj_list.append(src_expert.up_proj.weight.data[lo:hi, :].contiguous())
        down_proj_list.append(src_expert.down_proj.weight.data[:, lo:hi].contiguous())

    dst_experts.group_w2.data = (
        torch.cat(down_proj_list, dim=0).view(expert_num, -1, ffn_dim).contiguous()
    )
    group_gate_proj = (
        torch.cat(gate_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
    )
    group_up_proj = (
        torch.cat(up_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
    )
    dst_experts.group_w1_w3.data = torch.cat([group_gate_proj, group_up_proj], dim=1)
|
|
|
|
def split_w_attn(block, dst_model, i, local_rank):
    """Copy one rank's attention projections and the layer norms of layer ``i``.

    The query projection (``q_proj``, or ``q_b_proj`` when the destination
    layer has a non-``None`` ``attention_q_lora_dim``) is sliced by rows and
    ``o_proj`` by columns. ``q_a_proj``, ``q_a_layernorm``,
    ``kv_a_proj_with_mqa``, ``kv_a_layernorm`` and the four layer-level norms
    are assigned unsliced.

    Args:
        block: Source decoder layer.
        dst_model: Destination (per-rank) model that is modified in place.
        i: Index of the layer inside ``dst_model.model.layers``.
        local_rank: Rank index used to pick the slices.
    """
    # Slice widths are taken from layer 0 of the destination model.
    reference_attn = dst_model.model.layers[0].self_attn
    heads_per_rank = reference_attn.num_heads_per_rank
    q_dim = heads_per_rank * reference_attn.q_head_dim
    o_dim = heads_per_rank * reference_attn.attention_v_dim
    q_rows = slice(local_rank * q_dim, (local_rank + 1) * q_dim)
    o_cols = slice(local_rank * o_dim, (local_rank + 1) * o_dim)

    dst_layer = dst_model.model.layers[i]
    dst_attn = dst_layer.self_attn
    src_attn = block.self_attn

    if dst_attn.attention_q_lora_dim is not None:
        # Low-rank query path: the down-projection and its norm are copied
        # whole, only the up-projection is sliced.
        dst_attn.q_a_proj.weight.data = src_attn.q_a_proj.weight.data
        dst_attn.q_a_layernorm.weight.data = src_attn.q_a_layernorm.weight.data
        dst_attn.q_b_proj.weight.data = src_attn.q_b_proj.weight.data[
            q_rows, :
        ].contiguous()
    else:
        dst_attn.q_proj.weight.data = src_attn.q_proj.weight.data[
            q_rows, :
        ].contiguous()

    for attn_name in ("kv_a_proj_with_mqa", "kv_a_layernorm"):
        getattr(dst_attn, attn_name).weight.data = getattr(
            src_attn, attn_name
        ).weight.data

    dst_attn.o_proj.weight.data = src_attn.o_proj.weight.data[:, o_cols].contiguous()

    for norm_name in (
        "input_layernorm",
        "post_attention_layernorm",
        "pre_mlp_layernorm",
        "post_mlp_layernorm",
    ):
        getattr(dst_layer, norm_name).weight.data = getattr(
            block, norm_name
        ).weight.data
|
|
|
|
def kv_low_rank_split(block, dst_model, i, local_rank):
    """Split this rank's ``kv_b_proj`` rows into per-head K and V tensors.

    The rows ``[local_rank * k_dim, (local_rank + 1) * k_dim)`` of the source
    ``kv_b_proj`` weight are treated as ``num_heads_per_rank`` consecutive
    groups of ``attention_qk_dim + attention_v_dim`` rows. Within each group
    the first ``attention_qk_dim`` rows go to ``kv_b_proj_w_k`` and the
    remaining ``attention_v_dim`` rows go to ``kv_b_proj_w_v``.

    Results written to ``dst_model.model.layers[i].self_attn``:
        * ``kv_b_proj_w_k.data``: ``(num_heads_per_rank, attention_qk_dim,
          attention_kv_lora_dim)``
        * ``kv_b_proj_w_v.data``: ``(num_heads_per_rank, attention_kv_lora_dim,
          attention_v_dim)`` (last two axes transposed)

    Args:
        block: Source decoder layer exposing ``self_attn.kv_b_proj``.
        dst_model: Destination (per-rank) model that is modified in place.
        i: Index of the layer inside ``dst_model.model.layers``.
        local_rank: Rank index used to pick the row slice.
    """
    # All dimensions come from layer ``i``; the original took ``k_dim`` from
    # layer 0 and everything else from layer ``i``.
    dst_attn = dst_model.model.layers[i].self_attn
    num_heads_per_rank = dst_attn.num_heads_per_rank
    attention_qk_dim = dst_attn.attention_qk_dim
    attention_v_dim = dst_attn.attention_v_dim
    attention_kv_lora_dim = dst_attn.attention_kv_lora_dim

    rows_per_head = attention_qk_dim + attention_v_dim
    k_dim = num_heads_per_rank * rows_per_head
    kv_b_proj_weight_data = block.self_attn.kv_b_proj.weight.data[
        local_rank * k_dim : (local_rank + 1) * k_dim, :
    ].contiguous()

    # A (heads, qk + v, lora) view selects the same rows as the hand-built
    # index tensors did: row h * (qk + v) + r belongs to head h at offset r.
    per_head = kv_b_proj_weight_data.view(
        num_heads_per_rank, rows_per_head, attention_kv_lora_dim
    )

    # clone() always allocates, so the destination never aliases the source
    # model's storage (the previous index_select also returned fresh tensors).
    dst_attn.kv_b_proj_w_k.data = per_head[:, :attention_qk_dim, :].clone(
        memory_format=torch.contiguous_format
    )
    dst_attn.kv_b_proj_w_v.data = (
        per_head[:, attention_qk_dim:, :]
        .transpose(1, 2)
        .clone(memory_format=torch.contiguous_format)
    )
|
|
|
|
def split_layer(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size):
    """Split one decoder layer: attention weights first, then the MLP weights.

    Attention splitting uses ``local_rank % attn_tp_size`` as the rank index;
    MLP splitting uses ``local_rank % moe_tp_size``. Layers whose index is
    below ``dst_model.config.num_dense_layers`` go through ``split_w_dense``;
    all others go through ``split_w_moe``.

    Args:
        block: Source decoder layer.
        dst_model: Destination (per-rank) model that is modified in place.
        i: Index of the layer inside ``dst_model.model.layers``.
        local_rank: Rank the shard is produced for.
        attn_tp_size: Divisor used to derive the attention rank index.
        moe_tp_size: Divisor used to derive the MLP rank index.
    """
    # Attention part.
    attn_rank = local_rank % attn_tp_size
    for attn_splitter in (split_w_attn, kv_low_rank_split):
        attn_splitter(block, dst_model, i, attn_rank)

    # MLP part: dense for the leading layers, MoE afterwards.
    mlp_rank = local_rank % moe_tp_size
    is_dense_layer = i < dst_model.config.num_dense_layers
    mlp_splitter = split_w_dense if is_dense_layer else split_w_moe
    mlp_splitter(block, dst_model, i, mlp_rank)
|
|
|
|
def split_w(src_model, dst_model, local_rank, runner_config):
    """Fill ``dst_model`` with the weight shard that belongs to ``local_rank``.

    ``lm_head`` and ``embed_tokens`` are sliced by rows using
    ``embed_tp_size``; the final ``norm`` weight is assigned unsliced; every
    decoder layer is then handed to ``split_layer`` on its own thread.

    Args:
        src_model: Full source model.
        dst_model: Destination (per-rank) model that is modified in place.
        local_rank: Rank the shard is produced for.
        runner_config: Mapping whose ``"parallel_config"`` entry provides
            ``attn_tp_size``, ``moe_tp_size`` and ``embed_tp_size``.

    Raises:
        RuntimeError: If splitting any layer failed; the first failing layer's
            exception is chained as the cause.
    """
    parallel_config = runner_config.get("parallel_config")
    attn_tp_size = parallel_config.get("attn_tp_size")
    moe_tp_size = parallel_config.get("moe_tp_size")
    embed_tp_size = parallel_config.get("embed_tp_size")

    # NOTE(review): floor division drops trailing rows when vocab_size is not
    # a multiple of embed_tp_size — confirm the vocabulary is always divisible.
    vocab_size = src_model.model.vocab_size // embed_tp_size
    embed_tp_rank = local_rank % embed_tp_size
    vocab_rows = slice(embed_tp_rank * vocab_size, (embed_tp_rank + 1) * vocab_size)

    dst_model.lm_head.weight.data = src_model.lm_head.weight.data[vocab_rows, :]
    dst_model.model.embed_tokens.weight.data = (
        src_model.model.embed_tokens.weight.data[vocab_rows, :]
    )
    dst_model.model.norm.weight.data = src_model.model.norm.weight.data

    # An exception raised inside a Thread target is not re-raised by join(),
    # so without this bookkeeping a failed layer would leave dst_model
    # partially filled while the caller carries on as if nothing happened.
    failures = []

    def _split_one(layer_idx, layer_block):
        try:
            split_layer(
                layer_block, dst_model, layer_idx, local_rank, attn_tp_size, moe_tp_size
            )
        except Exception as exc:  # collected here, re-raised after join()
            failures.append((layer_idx, exc))

    workers = [
        Thread(target=_split_one, args=(layer_idx, layer_block))
        for layer_idx, layer_block in enumerate(src_model.model.layers)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    if failures:
        failures.sort(key=lambda item: item[0])
        failed_idx, cause = failures[0]
        raise RuntimeError(
            f"Splitting layer {failed_idx} failed for rank {local_rank} "
            f"({len(failures)} layer(s) failed in total)"
        ) from cause
|
|
|
|
def copy_files_with_prefix(src_dir, dst_dir, prefix):
    """Copy every entry of ``src_dir`` whose name starts with ``prefix`` into ``dst_dir``.

    Each match is copied with ``shutil.copy2`` under the same file name.

    Args:
        src_dir: Directory to scan (non-recursive).
        dst_dir: Existing directory that receives the copies.
        prefix: File-name prefix to match.
    """
    matching_names = (
        entry for entry in os.listdir(src_dir) if entry.startswith(prefix)
    )
    for entry in matching_names:
        shutil.copy2(os.path.join(src_dir, entry), os.path.join(dst_dir, entry))
|
|
|
|
def parse_args():
    """Parse the command-line options of the weight-splitting script.

    Returns:
        argparse.Namespace with ``model_path``, ``output_path``,
        ``origin_yaml_file_path``, ``new_yaml_file_path``, ``world_size``
        (default 8), ``node_num`` (default 1) and ``node_rank`` (default 0).
    """
    option_specs = (
        ("--model_path", {"type": str, "help": "Path of model weights"}),
        (
            "--output_path",
            {
                "type": str,
                "help": "The output directory where the results are saved",
            },
        ),
        ("--origin_yaml_file_path", {"type": str, "help": "inference configurations"}),
        ("--new_yaml_file_path", {"type": str, "help": "inference configurations"}),
        (
            "--world_size",
            {"type": int, "default": 8, "help": "The parallel rank size of model"},
        ),
        ("--node_num", {"type": int, "default": 1, "help": "The parallel node num"}),
        ("--node_rank", {"type": int, "default": 0, "help": "The parallel node rank"}),
    )

    cli = argparse.ArgumentParser(
        description="Split weight parameters with tensor parallel"
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
def show_model_states(origin_model, model_name="src_model"):
    """Log every named parameter of a model and the total element count.

    One INFO record is emitted per parameter (name, size, dtype), followed by
    one INFO record with the summed number of elements.

    Args:
        origin_model: Object exposing ``named_parameters()``.
        model_name: Label inserted into the log messages.
    """
    total_param_size = 0
    for name, params in origin_model.named_parameters():
        # numel() is always an int; np.prod of an empty size (0-dim parameter)
        # is the float 1.0, which turned the logged total into a float.
        total_param_size += params.numel()
        logging.info(
            "Param of %s tensor parallel: %s, %s, %s",
            model_name,
            name,
            params.size(),
            params.dtype,
        )
    logging.info(
        "Total param size of %s tensor parallel: %s", model_name, total_param_size
    )
|
|
|
|
def read_yaml(yaml_file_path):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed content.

    Args:
        yaml_file_path: Path of the YAML file to read (opened as UTF-8).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.

    Raises:
        FileNotFoundError: If the file does not exist (logged, then re-raised).
        yaml.YAMLError: If the content cannot be parsed (logged, then re-raised).
    """
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        logging.error("No such yaml file: %s", yaml_file_path)
        # Re-raise: the original fell through to ``return data`` with ``data``
        # unbound, hiding the real problem behind an UnboundLocalError.
        raise
    except yaml.YAMLError as e:  # was ``yaml.YAMLERROR``, which does not exist
        logging.error("Load yaml file failed: %s", e)
        raise
|
|
|
|
def check_vars(world_size, runner_config):
    """Validate the tensor-parallel sizes of ``runner_config`` against ``world_size``.

    ``attn_tp_size``, ``moe_tp_size`` and ``embed_tp_size`` are read from
    ``runner_config["parallel_config"]`` and checked in that order. Each must
    be a positive integer that divides ``world_size``; the first violation is
    logged and the process exits with status 1.

    Args:
        world_size: Total number of ranks.
        runner_config: Mapping containing a ``"parallel_config"`` mapping.

    Raises:
        SystemExit: With code 1 when a size is missing, not a positive
            integer, or does not divide ``world_size``.
    """
    parallel_config = runner_config.get("parallel_config") or {}
    for key in ("attn_tp_size", "moe_tp_size", "embed_tp_size"):
        tp_size = parallel_config.get(key)
        # Report missing/zero/non-integer sizes explicitly instead of failing
        # with a TypeError or ZeroDivisionError in the modulo below.
        if not isinstance(tp_size, int) or tp_size <= 0:
            logging.error("%s must be a positive integer, got %r", key, tp_size)
            raise SystemExit(1)
        if world_size % tp_size != 0:
            logging.error(
                "world_size %s mod %s %s must be 0", world_size, key, tp_size
            )
            # SystemExit is raised directly; the bare ``exit`` helper is
            # injected by the ``site`` module and is meant for interactive use.
            raise SystemExit(1)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load the full checkpoint once, then build, fill and
    # save one per-rank model for every rank that belongs to this node.
    logging.info("Start to split weight...")
    args = parse_args()
    output_path = args.output_path

    # NOTE(review): old_runner_config is loaded but not used below — confirm
    # whether it is still needed.
    old_runner_config = read_yaml(args.origin_yaml_file_path)
    new_runner_config = read_yaml(args.new_yaml_file_path)
    world_size = args.world_size

    # Validate the parallel layout before the expensive checkpoint load;
    # check_vars exits with status 1 on an invalid configuration.
    check_vars(world_size, new_runner_config)
    if args.node_num <= 0 or world_size % args.node_num != 0:
        # Otherwise the floor division below silently drops the trailing ranks.
        logging.error(
            "world_size %s mod node_num %s must be 0", world_size, args.node_num
        )
        raise SystemExit(1)

    os.makedirs(output_path, exist_ok=True)
    origin_model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        local_files_only=True,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    show_model_states(origin_model, "origin_model")

    # Ranks handled by this node: [start_rank, end_rank).
    node_rank_id = args.node_rank
    rank_num_per_node = world_size // args.node_num
    start_rank = rank_num_per_node * node_rank_id
    end_rank = rank_num_per_node * (node_rank_id + 1)

    for rank_id in range(start_rank, end_rank):
        logging.info("rank_id=%s / rank_size=%s", rank_id, world_size)
        # Set before the per-rank model is constructed.
        # NOTE(review): presumably read by the model code — confirm.
        os.environ["LOCAL_RANK"] = str(rank_id)

        save_path = os.path.join(output_path, f"rank_{rank_id}")
        logging.info(
            "Split weight for rank %s start, save path is: %s", rank_id, save_path
        )

        config = origin_model.config
        part_model = PanguUltraMoEForCausalLM(config, new_runner_config)

        split_w(origin_model, part_model, rank_id, new_runner_config)

        show_model_states(part_model, "dst_model")

        part_model.save_pretrained(save_path)
        copy_files_with_prefix(args.model_path, save_path, "tokenizer")
        copy_files_with_prefix(args.model_path, save_path, "tokenization")
        logging.info(
            "Split weight for rank %s finished, save path is: %s", rank_id, save_path
        )

        # Drop the reference to this rank's model before the next one is built.
        del part_model
|
|