# ZK-DeepSeek / inference / convert2.py
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import safe_open, save_file
from kernel import weight_dequant
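# Map HF checkpoint parameter names to the internal names used below; the
# second tuple element is the shard dimension inherited from the reference
# converter (kept for reference, unused in this script).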
mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate": ("gate", None),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"norm": ("norm", None),
"lm_head": ("head", 0),
"scale": ("scale", None),
}
EmbedsInOneFile = 256
EmbedsZKDir = "../zkdata/embeds/"
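# EmbedsInOneFile rows of the embedding matrix go into each raw .bin chunk
# written under EmbedsZKDir.
# The tables below are per-layer power-of-two exponents ("rescales") used to
# convert dequantized weights to int32 fixed point: w_int = round(w * 2**exp).
# The dense w1/w2/w3 tables cover only the first three layers, which appear to
# be the model's dense (non-MoE) layers; the MoE-related tables carry zeros in
# those three slots.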
wkv_b_1_rescales = [32, 34, 37, 36, 33, 32, 33, 33, 30, 32,
32, 30, 31, 30, 29, 30, 29, 30, 29, 29,
29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
29, 29, 29, 29, 29, 29, 29, 29, 30, 30,
29, 29, 30, 30, 30, 30, 29, 30, 30, 29, 30]
wkv_b_2_rescales = [31, 32, 32, 31, 32, 30, 30, 30, 30, 30,
30, 30, 30, 29, 29, 29, 29, 30, 29, 29,
29, 29, 29, 29, 30, 30, 30, 29, 29, 29,
29, 29, 30, 29, 30, 29, 30, 29, 29, 29,
30, 29, 29, 29, 29, 30, 29, 30, 30, 30,
29, 29, 29, 30, 30, 29, 29, 29, 30, 30, 30]
wo_rescales = [31, 32, 32, 32, 32, 31, 32, 31, 31, 31,
31, 31, 31, 31, 30, 31, 31, 32, 31, 31,
31, 30, 30, 30, 30, 30, 30, 30, 30, 30,
30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
30, 30, 30, 31, 30, 31, 30, 30, 31, 31,
31, 30, 31, 31, 31, 30, 31, 31, 31, 31, 32 ]
gate_rescales = [0, 0, 0, 33, 32, 32, 32, 31, 32, 31, 30,
32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
32, 31, 32, 31, 32, 32, 32, 32, 31, 32,
32, 31, 32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
32, 32, 32, 33, 33, 33, 33, 33, 32, 32 ]
w1_rescales = [32, 32, 32]
w2_rescales = [31, 32, 31]
w3_rescales = [32, 33, 32]
shared_w1_rescales = [0, 0, 0, 30, 30, 29, 29, 29, 28, 29,
29, 28, 29, 29, 29, 29, 29, 29, 29, 29,
29, 29, 29, 30, 30, 30, 30, 30, 30, 30,
30, 30, 30, 30, 29, 29, 30, 29, 29, 30,
29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29]
shared_w2_rescales = [0, 0, 0, 30, 30, 30, 30, 30, 29, 29,
30, 29, 29, 29, 30, 30, 30, 30, 30, 29,
29, 29, 29, 29, 29, 29, 29, 30, 30, 29,
29, 29, 29, 29, 29, 29, 29, 30, 29, 29,
29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30]
shared_w3_rescales = [0, 0, 0, 30, 30, 30, 30, 30, 29, 29,
30, 29, 29, 29, 30, 30, 30, 29, 30, 29,
29, 29, 29, 29, 29, 29, 30, 30, 30, 30,
29, 29, 29, 29, 29, 29, 29, 30, 30, 29,
30, 29, 29, 29, 29, 30, 29, 29, 30, 30,
29, 30, 30, 30, 29, 29, 30, 30, 30, 29, 28]
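# A minimal sketch (illustrative only; this helper is not called anywhere in
# the script) of the fixed-point round trip the tables above parameterize:
def _fixed_point_roundtrip_example():
    w = torch.tensor([0.123, -0.456], dtype=torch.float32)
    scale = 29  # e.g. wkv_b_1_rescales[some_layer]
    w_int = (w * 2 ** scale).round().to(torch.int32)   # encode
    w_back = w_int.to(torch.float64) / 2 ** scale      # decode downstream
    assert torch.allclose(w.double(), w_back, atol=1e-6)
# Staging buffers: layer_state_dict0 holds raw tensors as loaded from disk,
# layer_state_dict collects the converted integer tensors per layer, and
# experts accumulates per-expert tensors until all six (w1/w2/w3 weight plus
# scale) are present and the expert can be flushed to disk.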
layer_state_dict0 = [{} for _ in range(61)]
layer_state_dict = [{} for _ in range(61)]
experts = [ [{} for _j in range(256)] for _i in range(61)]
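# Debug helpers: render a value as sign*(1+m/2^p)*2^e from its raw bit
# pattern, for fp32, bf16, and a 1-4-3 fp8 layout respectively.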
def getF32PrintStr(ele):
v = int(ele.cpu().view(torch.uint32).item())
ex = str((v >> 23 & 0xFF) - 127)
r = '(1+' + str(v & 0x7FFFFF) + '/8388608)'
if v & 0x80000000:
vstr = '-' + r + '*2^' + ex
else:
vstr = r + '*2^' + ex
return vstr
def getBF16PrintStr(ele):
    v = int(ele.cpu().view(torch.uint16).item())
    ex = str((v >> 7 & 0xFF) - 127)
    r = '(1+' + str(v & 0x7F) + '/128)'
    if v & 0x8000:
        vstr = '-' + r + '*2^' + ex
    else:
        vstr = r + '*2^' + ex
    return vstr
def getBF8PrintStr(ele):
    v = int(ele.cpu().view(torch.uint8).item())
    ex = (v >> 3 & 0xF) - 7  # keep as int so the range check below works
    r = '(1+' + str(v & 0x7) + '/8)'
    if v & 0x80:
        vstr = '-' + r + '*2^' + str(ex)
    else:
        vstr = r + '*2^' + str(ex)
    if ex == -7 or ex == 8:  # flag exponent extremes (the original compared a str to an int, so this never fired)
        print(vstr)
    return vstr
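# Print CUDA memory statistics, tagged with i, for coarse memory debugging.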
def mem(i):
a = torch.cuda.memory_allocated()/1024**2
r = torch.cuda.memory_reserved()/1024**2
m = torch.cuda.max_memory_allocated()/1024**2
print(f"{i} allocated={a:.1f}MB, reserved={r:.1f}MB, max={m:.1f}MB", flush=True)
def handle_expert_w(layer_id, expert_id, idx, param_weight, weight_name, scale, typ, shape, experts_save_path):
global layer_state_dict0
global experts
    scale_name = weight_name.replace('weight', 'scale')
    param_scale = layer_state_dict0[layer_id][scale_name]
    weight = weight_dequant(param_weight.cuda(), param_scale.cuda())
    rescale = 2 ** scale
    param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
    weight_name2 = f'w{idx}.weight'
    scale_name2 = f'w{idx}.scale'
    experts[layer_id][expert_id][weight_name2] = param_int
    experts[layer_id][expert_id][scale_name2] = torch.tensor(scale, dtype=torch.int32)
    if len(experts[layer_id][expert_id]) == 6:  # w1, w2, w3 plus their corresponding scales
        save_file(experts[layer_id][expert_id], os.path.join(experts_save_path, f"{expert_id}.safetensors"))
        experts[layer_id][expert_id] = {}
    print(f'layer {layer_id} expert {expert_id} w{idx} type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
def saveTensor(fileName, t):
    # Write the tensor's raw bytes (C order) with no header; the original
    # opened the file twice, once in text mode, which only truncated it.
    t = t.detach()
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.contiguous()
    with open(fileName, "wb") as f:
        f.write(t.numpy().tobytes(order="C"))
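# The .bin files written by saveTensor carry no header, so a reader must know
# the dtype and shape. A minimal sketch (illustrative; not used by this
# script) for the int64 embedding chunks written in main(), assuming the
# 7168-wide hidden dimension used elsewhere in this file:
def _load_embed_chunk_example(fileName, rows=EmbedsInOneFile, dim=7168):
    with open(fileName, "rb") as f:
        buf = bytearray(f.read())  # bytearray: torch.frombuffer wants a writable buffer
    return torch.frombuffer(buf, dtype=torch.int64).view(rows, dim)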
def main(hf_ckpt_path, save_path, n_experts, mp):
    """
    Converts and saves model checkpoint files into a fixed-point integer format.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model (accepted but currently unused here).
        mp (int): Model parallelism factor (accepted but currently unused here).

    Returns:
        None
    """
torch.cuda.set_device(0)
    # Set PyTorch's default compute dtype; BF16 is used here.
    torch.set_default_dtype(torch.bfloat16)
    # Cap PyTorch CPU compute at 8 threads to avoid excessive thread contention.
    torch.set_num_threads(8)
    # Fix the random seed so different processes see identical random values at init.
    torch.manual_seed(965)
head_state_dict = {}
norm_state_dict = {}
embed_state_dict = {}
    def read_rescales(path):
        # One line per layer; each line holds whitespace-separated integer
        # exponents, indexed by expert id.
        with open(path, "r", encoding="utf-8") as f:
            return [[int(s) for s in line.split()] for line in f]
    experts_w1_rescales = read_rescales("w1.txt")
    experts_w2_rescales = read_rescales("w2.txt")
    experts_w3_rescales = read_rescales("w3.txt")
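    # Expected layout of w1.txt / w2.txt / w3.txt (as implied by the indexing
    # further below): one line per layer, one integer exponent per routed expert.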
    # Walk every *.safetensors shard in the checkpoint directory (glob), with a tqdm progress bar.
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
print('Opening ' + file_path, flush=True)
for name in f.keys():
# print('name 1: ', name)
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping, f"Key {key} not found in mapping"
# print('key::: ' + key)
new_key, dim = mapping[key]
# print('dim::: ' + str(dim))
name = name.replace(key, new_key)
ns = name.split(".")
comp = ns[0]
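                # head / norm / embed are stored as int64 fixed point with
                # exponents 43, 15 and 31 respectively.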
if comp == 'head':
name2 = name[len('head.'):]
print('head: ' + name2)
param_int = (param.to(torch.float32) * (2 ** 43)).round().to(torch.int64)
head_state_dict[name2] = param_int
elif comp == 'norm':
name2 = name[len('norm.'):]
print('norm: ' + name2)
param_int = (param.to(torch.float32) * (2 ** 15)).round().to(torch.int64)
norm_state_dict[name2] = param_int
elif comp == 'embed':
name2 = name[len('embed.'):]
print('embed: ' + name2)
param_int = (param.to(torch.float32) * (2 ** 31)).round().to(torch.int64)
embed_state_dict[name2] = param_int
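                    # Also dump the embedding table as raw .bin chunks of
                    # EmbedsInOneFile rows each (rows beyond the last full
                    # chunk, if any, are not written).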
os.makedirs(EmbedsZKDir, exist_ok=True)
fileCount = param_int.shape[0] // EmbedsInOneFile
for i in range(0, fileCount):
saveTensor(EmbedsZKDir + str(i) + '.bin', param_int[i * EmbedsInOneFile : (i+1) * EmbedsInOneFile].cpu())
elif comp == 'layers':
layer_id = int(ns[1])
name2 = '.'.join(ns[2:])
layer_state_dict0[layer_id][name2] = param
print('Finish loading state dict from disk! ++++++++++')
for layer_id in range(len(layer_state_dict0)):
os.makedirs(f'{save_path}/experts-{layer_id}', exist_ok=True)
states = layer_state_dict0[layer_id]
for name, param in states.items():
ns = name.split(".")
typ = param.type()
shape = param.shape
if ns[0] == 'attn_norm':
print(f'layer {layer_id} {name}, type: {typ}', flush=True)
if ns[1] == 'weight':
param_int = (param.to(torch.float32) * (2 ** 21)).round().to(torch.int32)
layer_state_dict[layer_id][name] = param_int
elif ns[0] == 'ffn_norm':
print(f'layer {layer_id} {name}, type: {typ}', flush=True)
if ns[1] == 'weight':
param_int2 = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
layer_state_dict[layer_id][name] = param_int2
elif ns[0] == 'ffn':
if len(ns) == 3:
                    if ns[1] in ('w1', 'w2', 'w3') and ns[2] == 'scale':
                        continue
                    elif ns[1] in ('w1', 'w2', 'w3') and ns[2] == 'weight':
                        # Dense-MLP weight: dequantize the fp8 block-scaled tensor,
                        # then re-encode as int32 fixed point with the per-layer exponent.
                        weight_name = name
                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]
                        weight = weight_dequant(param.cuda(), param_scale.cuda())
                        rescales = {'w1': w1_rescales, 'w2': w2_rescales, 'w3': w3_rescales}
                        scale = rescales[ns[1]][layer_id]
                        param_int = (weight.to(torch.float32) * 2 ** scale).round().to(torch.int32)
                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                        print(f'layer {layer_id} {ns[1]} weight, type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}', flush=True)
elif ns[1] == 'gate' and ns[2] == 'weight':
gate_rescale = 2 ** gate_rescales[layer_id]
gate_int = (param.to(torch.float32) * gate_rescale).round().to(torch.int32)
layer_state_dict[layer_id][name] = gate_int.cpu()
rescale_name = name.replace('weight', 'scale')
layer_state_dict[layer_id][rescale_name] = torch.tensor(gate_rescales[layer_id], dtype=torch.int32)
print(f'layer {layer_id}: gate_weight_name: {name}, gate_scale_name: {rescale_name}')
elif ns[1] == 'gate' and ns[2] == 'bias':
bias_int = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
layer_state_dict[layer_id][name] = bias_int.cpu()
print(f'layer {layer_id} bias: {name}')
else:
layer_state_dict[layer_id][name] = param
elif len(ns) == 4:
if ns[1] == 'shared_experts':
                        if ns[2] in ('w1', 'w2', 'w3') and ns[3] == 'scale':
                            continue
                        elif ns[2] in ('w1', 'w2', 'w3') and ns[3] == 'weight':
                            # Shared-expert weight: same dequantize-and-rescale treatment.
                            weight_name = name
                            scale_name = name.replace('weight', 'scale')
                            param_scale = states[scale_name]
                            weight = weight_dequant(param.cuda(), param_scale.cuda())
                            rescales = {'w1': shared_w1_rescales, 'w2': shared_w2_rescales, 'w3': shared_w3_rescales}
                            scale = rescales[ns[2]][layer_id]
                            param_int = (weight.to(torch.float32) * 2 ** scale).round().to(torch.int32)
                            layer_state_dict[layer_id][weight_name] = param_int.cpu()
                            layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                            print(f'layer {layer_id} shared_expert {ns[2]} type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
                        else:
                            layer_state_dict[layer_id][name] = param
else:
layer_state_dict[layer_id][name] = param
elif len(ns) == 5:
if ns[1] == 'experts':
expert_id = int(ns[2])
                        if ns[3] in ('w1', 'w2', 'w3') and ns[4] == 'scale':
                            continue
                        elif ns[3] in ('w1', 'w2', 'w3') and ns[4] == 'weight':
                            idx = int(ns[3][1])  # 1, 2 or 3
                            rescales = {'w1': experts_w1_rescales, 'w2': experts_w2_rescales, 'w3': experts_w3_rescales}
                            scale = rescales[ns[3]][layer_id][expert_id]
                            handle_expert_w(layer_id, expert_id, idx, param, name, scale, typ, shape, f'{save_path}/experts-{layer_id}')
else:
layer_state_dict[layer_id][name] = param
else:
layer_state_dict[layer_id][name] = param
elif ns[0] == 'attn':
if len(ns) == 3:
if ns[1] == 'wq_a' and ns[2] == 'scale':
continue
elif ns[1] == 'wq_a' and ns[2] == 'weight':
param_weight = param.cuda()
weight_name = name
scale_name = name.replace('weight', 'scale')
param_scale = states[scale_name]
weight = weight_dequant(param_weight, param_scale.cuda())
weight_int = (weight.to(torch.float32) * (2 ** 30)).round().to(torch.int32)
layer_state_dict[layer_id][weight_name] = weight_int.cpu()
print(f'layer {layer_id} wq_a weight, type: {typ}, shape: {shape}', flush=True)
elif ns[1] == 'q_norm':
print(f'layer {layer_id} q_norm, type: {typ}, shape: {shape}', flush=True)
param_int3 = (param.to(torch.float32) * (2 ** 19)).round().to(torch.int32)
layer_state_dict[layer_id][name] = param_int3
elif ns[1] == 'kv_norm':
print(f'layer {layer_id} kv_norm, type: {typ}, shape: {shape}', flush=True)
param_int4 = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
layer_state_dict[layer_id][name] = param_int4
elif ns[1] == 'wq_b' and ns[2] == 'scale':
continue
                    elif ns[1] == 'wq_b' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name
                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]
                        weight = weight_dequant(param_weight, param_scale.cuda())
                        weight_int = (weight.to(torch.float32) * (2 ** 30)).round().to(torch.int32)
                        # Per head (128 heads), wq_b produces 192 = 128 + 64 rows
                        # (the no-RoPE and RoPE query halves in DeepSeek-V3's MLA);
                        # split them into two separately stored matrices.
                        weight_int = weight_int.view(128, 192, 1536)
                        wq_b1, wq_b2 = torch.split(weight_int, [128, 64], dim=-2)
                        print(f'layer {layer_id} wq_b1 weight, shape: {wq_b1.shape}, wq_b2 weight, shape: {wq_b2.shape}', flush=True)
                        wq_b1 = wq_b1.reshape(128 * 128, 1536)
                        wq_b2 = wq_b2.reshape(128 * 64, 1536)
                        wq_b1_name = weight_name.replace('wq_b', 'wq_b1')
                        wq_b2_name = weight_name.replace('wq_b', 'wq_b2')
                        layer_state_dict[layer_id][wq_b1_name] = wq_b1.cpu()
                        layer_state_dict[layer_id][wq_b2_name] = wq_b2.cpu()
                        print(f'layer {layer_id} wq_b weight, type: {typ}, shape: {shape}', flush=True)
elif ns[1] == 'wkv_a' and ns[2] == 'scale':
continue
                    elif ns[1] == 'wkv_a' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name
                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]
                        weight = weight_dequant(param_weight, param_scale.cuda())
                        weight_int = (weight.to(torch.float32) * (2 ** 29)).round().to(torch.int32)
                        # Rows 0..511 are the compressed-KV projection, rows 512..575
                        # the RoPE key rows; store the two blocks separately.
                        weight_int = weight_int.view(576, 7168)
                        wkv_a1, wkv_a2 = torch.split(weight_int, [512, 64], dim=-2)
                        print(f'layer {layer_id} wkv_a1 weight, shape: {wkv_a1.shape}, wkv_a2 weight, shape: {wkv_a2.shape}', flush=True)
                        wkv_a1_name = weight_name.replace('wkv_a', 'wkv_a1')
                        wkv_a2_name = weight_name.replace('wkv_a', 'wkv_a2')
                        layer_state_dict[layer_id][wkv_a1_name] = wkv_a1.cpu()
                        layer_state_dict[layer_id][wkv_a2_name] = wkv_a2.cpu()
                        print(f'layer {layer_id} wkv_a weight, type: {typ}, shape: {shape}', flush=True)
elif ns[1] == 'wkv_b' and ns[2] == 'scale':
continue
elif ns[1] == 'wkv_b' and ns[2] == 'weight':
param_weight = param.cuda()
weight_name = name
scale_name = name.replace('weight', 'scale')
param_scale = states[scale_name]
weight = weight_dequant(param_weight, param_scale.cuda())
wkv_b = weight.view(128, 256, 512)
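                        # Per head, wkv_b stacks 128 no-RoPE key rows on top of
                        # 128 value rows (256 = 128 + 128); rescale each half
                        # with its own per-layer exponent.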
wkv_b_1 = wkv_b[:, :128]
wkv_b_1 = wkv_b_1.reshape(128 * 128, 512)
scale1 = wkv_b_1_rescales[layer_id]
wkv_b_1_rescale = 2 ** scale1
wkv_b_1_int = torch.round(wkv_b_1.to(torch.float32) * wkv_b_1_rescale).to(torch.int32)
wkv_b_2 = wkv_b[:, -128:]
wkv_b_2 = wkv_b_2.reshape(128 * 128, 512)
scale2 = wkv_b_2_rescales[layer_id]
wkv_b_2_rescale = 2 ** scale2
wkv_b_2_int = torch.round(wkv_b_2.to(torch.float32) * wkv_b_2_rescale).to(torch.int32)
wkv_b_1_name = weight_name.replace("wkv_b", "wkv_b_1")
wkv_b_1_scale_name = scale_name.replace("wkv_b", "wkv_b_1")
layer_state_dict[layer_id][wkv_b_1_name] = wkv_b_1_int.cpu()
layer_state_dict[layer_id][wkv_b_1_scale_name] = torch.tensor(scale1, dtype=torch.int32)
wkv_b_2_name = weight_name.replace("wkv_b", "wkv_b_2")
wkv_b_2_scale_name = scale_name.replace("wkv_b", "wkv_b_2")
layer_state_dict[layer_id][wkv_b_2_name] = wkv_b_2_int.cpu()
layer_state_dict[layer_id][wkv_b_2_scale_name] = torch.tensor(scale2, dtype=torch.int32)
print(f'layer {layer_id} wkv_b, type: {typ}, shape: {shape}, wkv_b_1 weight: {wkv_b_1_name}, wkv_b_1 scale: {wkv_b_1_scale_name}, wkv_b_2 weight: {wkv_b_2_name}, wkv_b_2 scale: {wkv_b_2_scale_name}', flush=True)
elif ns[1] == 'wo' and ns[2] == 'scale':
continue
elif ns[1] == 'wo' and ns[2] == 'weight':
param_weight = param.cuda()
weight_name = name
scale_name = name.replace('weight', 'scale')
param_scale = states[scale_name]
weight = weight_dequant(param_weight, param_scale.cuda())
scale = wo_rescales[layer_id]
rescale = 2 ** scale
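                        # Layer 58's wo contains a single extreme outlier: log its
                        # bf16 decomposition, zero it before rescaling, and mark its
                        # slot with the int32 sentinel -(2**31), presumably so the
                        # consumer can special-case it.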
if layer_id != 58:
param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
else:
wo_abs = weight.abs().cpu()
maxpos = wo_abs.argmax()
row, col = divmod(maxpos.item(), weight.size(1))
print(f'maxpos: {maxpos}, {row} {col}', flush=True)
vstr = getBF16PrintStr(weight[row][col])
print(f'weight[{row}][{col}]: {vstr}', flush=True)
weight[row][col] = 0
param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
param_int[row][col] = -(2 ** 31)
layer_state_dict[layer_id][weight_name] = param_int.cpu()
layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
print(f'layer {layer_id} wo weight, type: {typ}, shape: {shape}, weight: {weight_name}, scale: {scale_name}', flush=True)
else:
layer_state_dict[layer_id][name] = param
else:
layer_state_dict[layer_id][name] = param
else:
layer_state_dict[layer_id][name] = param
save_file(layer_state_dict[layer_id], os.path.join(save_path, f"layer-{layer_id}.safetensors"))
print(f'Finish saving layer {layer_id}', flush=True)
layer_state_dict0[layer_id] = {}
layer_state_dict[layer_id] = {}
    print('Finished converting all layers', flush=True)
    os.makedirs(save_path, exist_ok=True)
    save_file(head_state_dict, os.path.join(save_path, "head_int.safetensors"))
    save_file(norm_state_dict, os.path.join(save_path, "norm_int.safetensors"))
    save_file(embed_state_dict, os.path.join(save_path, "embed_int.safetensors"))
    # Copy tokenizer files alongside the converted weights.
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
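# Example invocation (paths illustrative):
#   python convert2.py --hf-ckpt-path /path/to/DeepSeek-V3 \
#       --save-path ./converted --n-experts 256 --model-parallel 1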
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)