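"""Split a PanguUltraMoE checkpoint into per-rank tensor-parallel shards.

For every rank in ``world_size``, this script slices attention, dense-MLP,
shared-expert, routed-expert, and embedding weights according to the
parallel sizes in the new YAML config, then saves one model per rank under
``<output_path>/rank_<id>``.

Example invocation (script name and paths are placeholders):

    python split_weight.py \
        --model_path /path/to/checkpoint \
        --output_path /path/to/output \
        --origin_yaml_file_path origin.yaml \
        --new_yaml_file_path new.yaml \
        --world_size 8
"""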

import argparse
import logging
import os
import shutil
import sys
from threading import Thread

import numpy as np
import torch
import yaml
from torch import nn
from transformers import AutoModelForCausalLM

from model import PanguUltraMoEForCausalLM

# Reset any handlers installed by imported libraries so basicConfig applies.
root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
    level=logging.INFO,
)


def _to_parameter(data):
    """Wrap a tensor as a frozen (non-trainable) parameter."""
    return nn.Parameter(data, requires_grad=False)


def split_w_dense(block, dst_model, i, local_rank):
    """Slice the dense-MLP weights of layer ``i`` for this TP rank."""
    up_weight_list = []
    ffn_dim = dst_model.model.layers[i].mlp.intermediate_size_per_rank
    gate_weight = block.mlp.gate_proj.weight[
        local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
    ].contiguous()
    up_weight = block.mlp.up_proj.weight[
        local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
    ].contiguous()
    up_weight_list.append(_to_parameter(torch.cat([gate_weight, up_weight], dim=0)))

    if len(up_weight_list) == 1:
        dst_model.model.layers[i].mlp.merge_up_gate_proj.weight = up_weight_list[0]
    else:
        dst_model.model.layers[i].mlp.merge_up_gate_proj.weight = _to_parameter(
            torch.cat(up_weight_list, dim=0)
        )
    dst_model.model.layers[i].mlp.down_proj.weight.data = (
        block.mlp.down_proj.weight.data[
            :, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
        ].contiguous()
    )
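
# Worked example with hypothetical sizes: if the dense intermediate size is
# 4096 and moe_tp_size is 4, intermediate_size_per_rank is 1024, so rank 2
# keeps rows 2048:3072 of gate_proj/up_proj (stacked into one
# merge_up_gate_proj matrix) and columns 2048:3072 of down_proj.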


def split_w_moe(block, dst_model, i, local_rank):
    """Slice the shared-expert, router, and routed-expert weights of layer ``i``."""
    shared_up_weight_list = []
    ffn_dim = dst_model.model.layers[i].mlp.shared_experts.intermediate_size_per_rank
    gate_weight = block.mlp.shared_experts.gate_proj.weight[
        local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
    ].contiguous()
    up_weight = block.mlp.shared_experts.up_proj.weight[
        local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
    ].contiguous()
    shared_up_weight_list.append(
        _to_parameter(torch.cat([gate_weight, up_weight], dim=0))
    )
    if len(shared_up_weight_list) == 1:
        dst_model.model.layers[i].mlp.shared_experts.merge_up_gate_proj.weight = (
            shared_up_weight_list[0]
        )
    else:
        dst_model.model.layers[i].mlp.shared_experts.merge_up_gate_proj.weight = (
            _to_parameter(torch.cat(shared_up_weight_list, dim=0))
        )
    dst_model.model.layers[i].mlp.shared_experts.down_proj.weight.data = (
        block.mlp.shared_experts.down_proj.weight.data[
            :, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
        ].contiguous()
    )
    # The router (gate) is replicated on every rank.
    dst_model.model.layers[i].mlp.gate.weight.data = block.mlp.gate.weight.data

    # Slice each routed expert, then pack the slices into grouped tensors.
    expert_num = block.mlp.num_routed_experts
    ffn_dim = dst_model.model.layers[i].mlp.experts.intermediate_size_per_rank
    gate_proj_list, down_proj_list, up_proj_list = [], [], []
    for src_expert in block.mlp.experts:
        gate_proj_list.append(
            src_expert.gate_proj.weight.data[
                local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
            ].contiguous()
        )
        up_proj_list.append(
            src_expert.up_proj.weight.data[
                local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
            ].contiguous()
        )
        down_proj_list.append(
            src_expert.down_proj.weight.data[
                :, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
            ].contiguous()
        )

    dst_model.model.layers[i].mlp.experts.group_w2.data = (
        torch.cat(down_proj_list, dim=0).view(expert_num, -1, ffn_dim).contiguous()
    )
    group_gate_proj = (
        torch.cat(gate_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
    )
    group_up_proj = (
        torch.cat(up_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
    )
    dst_model.model.layers[i].mlp.experts.group_w1_w3.data = torch.cat(
        [group_gate_proj, group_up_proj], dim=1
    )
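
# After packing, group_w1_w3 has shape [num_experts, 2 * ffn_dim, hidden]
# (per-expert gate rows stacked above up rows) and group_w2 has shape
# [num_experts, hidden, ffn_dim], a layout suited to grouped matrix
# multiplication over experts.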


def split_w_attn(block, dst_model, i, local_rank):
    """Slice the attention weights of layer ``i``; layernorms are replicated."""
    q_dim = (
        dst_model.model.layers[0].self_attn.num_heads_per_rank
        * dst_model.model.layers[0].self_attn.q_head_dim
    )
    o_dim = (
        dst_model.model.layers[0].self_attn.num_heads_per_rank
        * dst_model.model.layers[0].self_attn.attention_v_dim
    )

    # Without q LoRA there is a single q_proj; with it, only the up-projection
    # (q_b_proj) is sliced per rank while q_a_proj and its layernorm stay whole.
    if dst_model.model.layers[i].self_attn.attention_q_lora_dim is None:
        dst_model.model.layers[i].self_attn.q_proj.weight.data = (
            block.self_attn.q_proj.weight.data[
                local_rank * q_dim : (local_rank + 1) * q_dim, :
            ].contiguous()
        )
    else:
        dst_model.model.layers[i].self_attn.q_a_proj.weight.data = (
            block.self_attn.q_a_proj.weight.data
        )
        dst_model.model.layers[i].self_attn.q_a_layernorm.weight.data = (
            block.self_attn.q_a_layernorm.weight.data
        )
        dst_model.model.layers[i].self_attn.q_b_proj.weight.data = (
            block.self_attn.q_b_proj.weight.data[
                local_rank * q_dim : (local_rank + 1) * q_dim, :
            ].contiguous()
        )

    dst_model.model.layers[i].self_attn.kv_a_proj_with_mqa.weight.data = (
        block.self_attn.kv_a_proj_with_mqa.weight.data
    )
    dst_model.model.layers[i].self_attn.kv_a_layernorm.weight.data = (
        block.self_attn.kv_a_layernorm.weight.data
    )
    dst_model.model.layers[i].self_attn.o_proj.weight.data = (
        block.self_attn.o_proj.weight.data[
            :, local_rank * o_dim : (local_rank + 1) * o_dim
        ].contiguous()
    )
    dst_model.model.layers[i].input_layernorm.weight.data = (
        block.input_layernorm.weight.data
    )
    dst_model.model.layers[i].post_attention_layernorm.weight.data = (
        block.post_attention_layernorm.weight.data
    )
    dst_model.model.layers[i].pre_mlp_layernorm.weight.data = (
        block.pre_mlp_layernorm.weight.data
    )
    dst_model.model.layers[i].post_mlp_layernorm.weight.data = (
        block.post_mlp_layernorm.weight.data
    )
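
# Head-parallel example with hypothetical sizes: with 32 heads, attn_tp_size 4,
# q_head_dim 192, and attention_v_dim 128, each rank owns 8 heads, so
# q_dim = 8 * 192 = 1536 rows of q_proj (or q_b_proj) and
# o_dim = 8 * 128 = 1024 columns of o_proj.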


def kv_low_rank_split(block, dst_model, i, local_rank):
    """Split kv_b_proj into separate per-head K and V projection tensors."""
    k_dim = dst_model.model.layers[0].self_attn.num_heads_per_rank * (
        dst_model.model.layers[0].self_attn.attention_qk_dim
        + dst_model.model.layers[0].self_attn.attention_v_dim
    )
    kv_b_proj_weight_data = block.self_attn.kv_b_proj.weight.data[
        local_rank * k_dim : (local_rank + 1) * k_dim, :
    ].contiguous()
    attention_qk_dim = dst_model.model.layers[i].self_attn.attention_qk_dim
    num_heads_per_rank = dst_model.model.layers[i].self_attn.num_heads_per_rank
    attention_kv_lora_dim = dst_model.model.layers[i].self_attn.attention_kv_lora_dim
    attention_v_dim = dst_model.model.layers[i].self_attn.attention_v_dim

    # Rows of kv_b_proj are laid out per head as [qk rows | v rows]; gather the
    # qk rows of every head to form the K projection.
    index_tensor = torch.arange(attention_qk_dim).repeat(
        num_heads_per_rank
    ) + torch.arange(num_heads_per_rank).repeat_interleave(attention_qk_dim) * (
        attention_qk_dim + attention_v_dim
    )
    kv_b_proj_w_k = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
    dst_model.model.layers[i].self_attn.kv_b_proj_w_k.data = kv_b_proj_w_k.view(
        num_heads_per_rank, attention_qk_dim, attention_kv_lora_dim
    ).contiguous()
    # Gather the v rows of every head to form the V projection, transposed so
    # the lora dimension comes before the value dimension.
    index_tensor = torch.arange(
        attention_qk_dim, attention_qk_dim + attention_v_dim
    ).repeat(num_heads_per_rank) + torch.arange(num_heads_per_rank).repeat_interleave(
        attention_v_dim
    ) * (attention_qk_dim + attention_v_dim)
    kv_b_proj_w_v = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
    dst_model.model.layers[i].self_attn.kv_b_proj_w_v.data = (
        kv_b_proj_w_v.view(num_heads_per_rank, attention_v_dim, attention_kv_lora_dim)
        .transpose(1, 2)
        .contiguous()
    )
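
# Index example with hypothetical dims: if attention_qk_dim = 2 and
# attention_v_dim = 1, kv_b_proj rows per head are [k0, k1, v0], so for two
# heads the K index tensor is [0, 1, 3, 4] and the V index tensor is [2, 5].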


def split_layer(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size):
    """Split one decoder layer; dense layers precede the MoE layers."""
    local_rank_tp_attn = local_rank % attn_tp_size
    split_w_attn(block, dst_model, i, local_rank_tp_attn)
    kv_low_rank_split(block, dst_model, i, local_rank_tp_attn)

    local_rank_tp_moe = local_rank % moe_tp_size
    if i >= dst_model.config.num_dense_layers:
        split_w_moe(block, dst_model, i, local_rank_tp_moe)
    else:
        split_w_dense(block, dst_model, i, local_rank_tp_moe)


def split_w(src_model, dst_model, local_rank, runner_config):
    """Copy or slice all weights from src_model into the per-rank dst_model."""
    attn_tp_size = runner_config.get("parallel_config").get("attn_tp_size")
    moe_tp_size = runner_config.get("parallel_config").get("moe_tp_size")
    embed_tp_size = runner_config.get("parallel_config").get("embed_tp_size")

    # Embeddings and lm_head are sliced along the vocabulary dimension.
    vocab_size = src_model.model.vocab_size // embed_tp_size
    embed_tp_rank = local_rank % embed_tp_size

    dst_model.lm_head.weight.data = src_model.lm_head.weight.data[
        embed_tp_rank * vocab_size : (embed_tp_rank + 1) * vocab_size, :
    ]
    dst_model.model.embed_tokens.weight.data = src_model.model.embed_tokens.weight.data[
        embed_tp_rank * vocab_size : (embed_tp_rank + 1) * vocab_size, :
    ]

    dst_model.model.norm.weight.data = src_model.model.norm.weight.data

    layer_num = len(src_model.model.layers)

    # Split layers concurrently; each thread writes to a distinct layer of
    # dst_model, so no synchronization beyond join() is required.
    all_threads = []
    for i in range(layer_num):
        block = src_model.model.layers[i]
        thread = Thread(
            target=split_layer,
            args=(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size),
        )
        all_threads.append(thread)
        thread.start()
    for thread in all_threads:
        thread.join()
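
# Note on the thread-per-layer design: despite the GIL, PyTorch releases it
# during large tensor operations, so slicing layers in parallel threads can
# overlap the memory-bound copy work here.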


def copy_files_with_prefix(src_dir, dst_dir, prefix):
    """Copy every file in src_dir whose name starts with prefix into dst_dir."""
    for file in os.listdir(src_dir):
        if file.startswith(prefix):
            src_file = os.path.join(src_dir, file)
            dst_file = os.path.join(dst_dir, file)
            shutil.copy2(src_file, dst_file)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Split weight parameters with tensor parallel"
    )
    parser.add_argument("--model_path", type=str, help="Path of model weights")
    parser.add_argument(
        "--output_path",
        type=str,
        help="The output directory where the results are saved",
    )
    parser.add_argument(
        "--origin_yaml_file_path",
        type=str,
        help="Path of the original inference configuration YAML",
    )
    parser.add_argument(
        "--new_yaml_file_path",
        type=str,
        help="Path of the inference configuration YAML for the split model",
    )
    parser.add_argument(
        "--world_size", type=int, default=8, help="The parallel rank size of model"
    )
    parser.add_argument("--node_num", type=int, default=1, help="The parallel node num")
    parser.add_argument(
        "--node_rank", type=int, default=0, help="The parallel node rank"
    )
    return parser.parse_args()


def show_model_states(origin_model, model_name="src_model"):
    """Log every parameter's name, shape, and dtype plus the total element count."""
    src_param_size = 0
    for name, params in origin_model.named_parameters():
        size_per_param = np.prod(params.size())
        src_param_size += size_per_param
        logging.info(
            "Param of %s tensor parallel: %s, %s, %s",
            model_name,
            name,
            params.size(),
            params.dtype,
        )
    logging.info(
        "Total param size of %s tensor parallel: %s", model_name, src_param_size
    )


def read_yaml(yaml_file_path):
    """Load a YAML file, returning None if it is missing or malformed."""
    data = None
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
    except FileNotFoundError:
        logging.error("No such yaml file: %s", yaml_file_path)
    except yaml.YAMLError as e:
        logging.error("Load yaml file failed: %s", e)
    return data


def check_vars(world_size, runner_config):
    """Abort unless world_size is divisible by every tensor-parallel size."""
    attn_tp_size = runner_config.get("parallel_config").get("attn_tp_size")
    moe_tp_size = runner_config.get("parallel_config").get("moe_tp_size")
    embed_tp_size = runner_config.get("parallel_config").get("embed_tp_size")
    if world_size % attn_tp_size != 0:
        logging.error(
            "world_size %s mod attn_tp_size %s must be 0", world_size, attn_tp_size
        )
        sys.exit(1)
    if world_size % moe_tp_size != 0:
        logging.error(
            "world_size %s mod moe_tp_size %s must be 0", world_size, moe_tp_size
        )
        sys.exit(1)
    if world_size % embed_tp_size != 0:
        logging.error(
            "world_size %s mod embed_tp_size %s must be 0", world_size, embed_tp_size
        )
        sys.exit(1)


if __name__ == "__main__":
    logging.info("Start to split weight...")
    args = parse_args()
    output_path = args.output_path

    old_runner_config = read_yaml(args.origin_yaml_file_path)
    new_runner_config = read_yaml(args.new_yaml_file_path)
    world_size = args.world_size
    check_vars(world_size, new_runner_config)

    os.makedirs(output_path, exist_ok=True)
    origin_model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        local_files_only=True,
        ignore_mismatched_sizes=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    show_model_states(origin_model, "origin_model")

    # Each node handles a contiguous block of ranks.
    node_rank_id = args.node_rank
    rank_num_per_node = world_size // args.node_num
    start_rank = rank_num_per_node * node_rank_id
    end_rank = rank_num_per_node * (node_rank_id + 1)

    for rank_id in range(start_rank, end_rank):
        logging.info("rank_id=%s / rank_size=%s", rank_id, world_size)
        os.environ["LOCAL_RANK"] = str(rank_id)

        save_path = os.path.join(output_path, f"rank_{rank_id}")
        logging.info(
            "Split weight for rank %s start, save path is: %s", rank_id, save_path
        )

        config = origin_model.config
        part_model = PanguUltraMoEForCausalLM(config, new_runner_config)

        split_w(origin_model, part_model, rank_id, new_runner_config)

        show_model_states(part_model, "dst_model")

        part_model.save_pretrained(save_path)
        copy_files_with_prefix(args.model_path, save_path, "tokenizer")
        copy_files_with_prefix(args.model_path, save_path, "tokenization")
        logging.info(
            "Split weight for rank %s finished, save path is: %s", rank_id, save_path
        )

        # Free the shard before building the next rank's model.
        del part_model