|
|
import os |
|
|
import shutil |
|
|
from argparse import ArgumentParser |
|
|
from glob import glob |
|
|
from tqdm import tqdm, trange |
|
|
|
|
|
import torch |
|
|
import ctypes |
|
|
from safetensors.torch import safe_open, save_file |
|
|
from kernel import weight_dequant |
|
|
|
|
|
|
|
|
# HF checkpoint sub-name -> (internal name, model-parallel shard dim) table.
# The second element marks the dimension a tensor would be partitioned along
# (None = replicated). NOTE(review): the `dim` value is looked up in main()
# but never used afterwards — presumably legacy from the sharding converter.
mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate": ("gate", None),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "norm": ("norm", None),
    "lm_head": ("head", 0),
    "scale": ("scale", None),
}
|
|
|
|
|
# Number of embedding-table rows written into each raw ".bin" chunk file.
EmbedsInOneFile = 256
# Output directory (relative to the CWD) for the chunked raw embedding dumps.
EmbedsZKDir = "../zkdata/embeds/"
|
|
|
|
|
# Per-layer fixed-point exponents: a weight w is stored as round(w * 2**s)
# in int32, with s taken from these tables (one entry per transformer layer,
# 61 layers total). Values were presumably tuned offline per layer so the
# scaled weights fit int32 — confirm against the tool that produced them.

# Exponents for the two halves of the split wkv_b projection.
wkv_b_1_rescales = [32, 34, 37, 36, 33, 32, 33, 33, 30, 32,
                    32, 30, 31, 30, 29, 30, 29, 30, 29, 29,
                    29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
                    29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
                    29, 29, 29, 29, 29, 29, 29, 29, 30, 30,
                    29, 29, 30, 30, 30, 30, 29, 30, 30, 29, 30]

wkv_b_2_rescales = [31, 32, 32, 31, 32, 30, 30, 30, 30, 30,
                    30, 30, 30, 29, 29, 29, 29, 30, 29, 29,
                    29, 29, 29, 29, 30, 30, 30, 29, 29, 29,
                    29, 29, 30, 29, 30, 29, 30, 29, 29, 29,
                    30, 29, 29, 29, 29, 30, 29, 30, 30, 30,
                    29, 29, 29, 30, 30, 29, 29, 29, 30, 30, 30]

# Exponents for the attention output projection (wo).
wo_rescales = [31, 32, 32, 32, 32, 31, 32, 31, 31, 31,
               31, 31, 31, 31, 30, 31, 31, 32, 31, 31,
               31, 30, 30, 30, 30, 30, 30, 30, 30, 30,
               30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
               30, 30, 30, 31, 30, 31, 30, 30, 31, 31,
               31, 30, 31, 31, 31, 30, 31, 31, 31, 31, 32 ]

# Exponents for the MoE router gate. The leading zeros correspond to the
# first three layers, which have a dense FFN and therefore no gate.
gate_rescales = [0, 0, 0, 33, 32, 32, 32, 31, 32, 31, 30,
                 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
                 32, 31, 32, 31, 32, 32, 32, 32, 31, 32,
                 32, 31, 32, 32, 32, 32, 32, 32, 32, 32,
                 32, 32, 32, 32, 32, 32, 32, 32, 32, 32,
                 32, 32, 32, 33, 33, 33, 33, 33, 32, 32 ]

# Exponents for the dense FFN (only the first three layers are dense).
w1_rescales = [32, 32, 32]
w2_rescales = [31, 32, 31]
w3_rescales = [32, 33, 32]

# Exponents for the shared experts; zeros again pad the three dense layers.
shared_w1_rescales = [0, 0, 0, 30, 30, 29, 29, 29, 28, 29,
                      29, 28, 29, 29, 29, 29, 29, 29, 29, 29,
                      29, 29, 29, 30, 30, 30, 30, 30, 30, 30,
                      30, 30, 30, 30, 29, 29, 30, 29, 29, 30,
                      29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
                      29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29]

shared_w2_rescales = [0, 0, 0, 30, 30, 30, 30, 30, 29, 29,
                      30, 29, 29, 29, 30, 30, 30, 30, 30, 29,
                      29, 29, 29, 29, 29, 29, 29, 30, 30, 29,
                      29, 29, 29, 29, 29, 29, 29, 30, 29, 29,
                      29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
                      29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30]

shared_w3_rescales = [0, 0, 0, 30, 30, 30, 30, 30, 29, 29,
                      30, 29, 29, 29, 30, 30, 30, 29, 30, 29,
                      29, 29, 29, 29, 29, 29, 30, 30, 30, 30,
                      29, 29, 29, 29, 29, 29, 29, 30, 30, 29,
                      30, 29, 29, 29, 29, 30, 29, 29, 30, 30,
                      29, 30, 30, 30, 29, 29, 30, 30, 30, 29, 28]

# Mutable module-level staging: raw tensors as loaded from disk (one dict per
# layer) and the quantized outputs about to be saved.
layer_state_dict0 = [{} for _ in range(61)]
layer_state_dict = [{} for _ in range(61)]

# Per-layer, per-expert staging dicts (61 layers x 256 routed experts) used
# by handle_expert_w to collect w1/w2/w3 before flushing each expert file.
experts = [ [{} for _j in range(256)] for _i in range(61)]
|
|
|
|
|
def getF32PrintStr(ele): |
|
|
v = int(ele.cpu().view(torch.uint32).item()) |
|
|
ex = str((v >> 23 & 0xFF) - 127) |
|
|
r = '(1+' + str(v & 0x7FFFFF) + '/8388608)' |
|
|
if v & 0x80000000: |
|
|
vstr = '-' + r + '*2^' + ex |
|
|
else: |
|
|
vstr = r + '*2^' + ex |
|
|
return vstr |
|
|
|
|
|
def getBF16PrintStr(ele): |
|
|
v = int(ele.cpu().view(torch.uint16).item()) |
|
|
ex = v >> 7 & 0xFF |
|
|
r = '(1+' + str(v & 0x7F) + '/128)' |
|
|
rraw = v & 0x7F |
|
|
|
|
|
if v & 0x8000: |
|
|
vstr = '-' + r + '*2^' + str(ex - 127) |
|
|
else: |
|
|
vstr = r + '*2^' + str(ex - 127) |
|
|
return vstr |
|
|
|
|
|
def getBF8PrintStr(ele): |
|
|
v = int(ele.cpu().view(torch.uint8).item()) |
|
|
ex = str((v >> 3 & 0xF) - 7) |
|
|
r = '(1+' + str(v & 0x7) + '/8)' |
|
|
|
|
|
if v & 0x80: |
|
|
vstr = '-' + r + '*2^' + ex |
|
|
else: |
|
|
vstr = r + '*2^' + ex |
|
|
|
|
|
if ex == -7 or ex == 8: |
|
|
print(vstr) |
|
|
return vstr |
|
|
|
|
|
def mem(i):
    """Print current/reserved/peak CUDA memory usage in MB, tagged with *i*."""
    mb = 1024 ** 2
    allocated = torch.cuda.memory_allocated() / mb
    reserved = torch.cuda.memory_reserved() / mb
    peak = torch.cuda.max_memory_allocated() / mb
    print(f"{i} allocated={allocated:.1f}MB, reserved={reserved:.1f}MB, max={peak:.1f}MB", flush=True)
|
|
|
|
|
def handle_expert_w(layer_id, expert_id, idx, param_weight, weight_name, scale, typ, shape, experts_save_path):
    """Dequantize one routed-expert FP8 weight, fixed-point encode it, and
    stage it; once an expert has all three (weight, scale) pairs, flush it
    to '<experts_save_path>/<expert_id>.safetensors'.

    Args:
        layer_id: transformer layer index into the module-level state.
        expert_id: routed-expert index within the layer.
        idx: which projection this is (1/2/3 -> w1/w2/w3).
        param_weight: raw FP8 weight tensor from the checkpoint.
        weight_name: its key in layer_state_dict0 (used to find the scale).
        scale: fixed-point exponent s; values stored as round(w * 2**s).
        typ, shape: original tensor type/shape, for logging only.
        experts_save_path: directory receiving the per-expert files.
    """
    global layer_state_dict0
    global experts

    # The per-block dequant scales live under the matching '...scale' key.
    scale_key = weight_name.replace('weight', 'scale')
    block_scales = layer_state_dict0[layer_id][scale_key]

    dequantized = weight_dequant(param_weight.cuda(), block_scales.cuda())

    # Fixed-point encode at 2**scale into int32.
    quantized = (dequantized.to(torch.float32) * (2 ** scale)).round().to(torch.int32)

    entry = experts[layer_id][expert_id]
    entry[f'w{idx}.weight'] = quantized
    entry[f'w{idx}.scale'] = torch.tensor(scale, dtype=torch.int32)

    # Three projections x (weight + scale) = 6 entries -> expert complete.
    if len(entry) == 6:
        save_file(entry, os.path.join(experts_save_path, f"{expert_id}.safetensors"))
        experts[layer_id][expert_id] = {}

    print(f'layer {layer_id} expert {expert_id} w{idx} type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_key}')
|
|
|
|
|
def saveTensor(fileName, t):
    """Write tensor *t* to *fileName* as raw element bytes in C (row-major) order.

    The tensor is detached, moved to CPU if necessary, and made contiguous
    before serialization; the file holds only the raw bytes (no header).
    """
    t = t.detach()
    if t.device.type != "cpu":
        t = t.cpu()
    t = t.contiguous()
    # BUG FIX: the original opened the file twice — first in text mode ("w"),
    # whose handle was immediately shadowed by the inner binary open and never
    # used. A single binary open is sufficient and avoids the stray handle.
    with open(fileName, "wb") as f:
        f.write(t.numpy().tobytes(order="C"))
|
|
|
|
|
def main(hf_ckpt_path, save_path, n_experts, mp):
    """
    Converts and saves model checkpoint files into a specified format.

    Two passes: (1) read every *.safetensors shard and bucket tensors into
    head/norm/embed dicts or per-layer staging dicts; (2) dequantize each
    layer's FP8 tensors, fixed-point encode them to integers using the
    module-level per-layer rescale tables, and save one safetensors file per
    layer (plus per-expert files via handle_expert_w). Expects "w1.txt",
    "w2.txt", "w3.txt" (per-layer, per-expert rescale exponents) in the CWD.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.

    Returns:
        None

    NOTE(review): ``n_experts`` and ``mp`` are never used in the body —
    presumably kept for CLI compatibility; confirm before removing.
    """
    torch.cuda.set_device(0)

    torch.set_default_dtype(torch.bfloat16)

    torch.set_num_threads(8)

    torch.manual_seed(965)

    # Accumulators for the three non-layer tensor groups.
    head_state_dict = {}
    norm_state_dict = {}
    embed_state_dict = {}

    # Routed-expert rescale exponents: one row of ints per layer, read from
    # whitespace-separated text files in the current working directory.
    experts_w1_rescales = []
    experts_w2_rescales = []
    experts_w3_rescales = []

    with open("w1.txt", "r", encoding="utf-8") as f1:
        for line in f1:
            layer_line = line.strip().split()
            int_list = [int(s) for s in layer_line]
            experts_w1_rescales.append(int_list)

    with open("w2.txt", "r", encoding="utf-8") as f2:
        for line in f2:
            layer_line = line.strip().split()
            int_list = [int(s) for s in layer_line]
            experts_w2_rescales.append(int_list)

    with open("w3.txt", "r", encoding="utf-8") as f3:
        for line in f3:
            layer_line = line.strip().split()
            int_list = [int(s) for s in layer_line]
            experts_w3_rescales.append(int_list)

    # ---- Pass 1: load every shard and bucket tensors by component. ----
    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            print('Opening ' + file_path, flush=True)
            for name in f.keys():

                # Skip layer 61 entirely (presumably the MTP/draft layer of
                # the checkpoint, not part of the 61 decoder layers — confirm).
                if "model.layers.61" in name:
                    continue

                param: torch.Tensor = f.get_tensor(name)
                if name.startswith("model."):
                    name = name[len("model."):]
                # Canonicalize HF names to the internal naming scheme.
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                key = name.split(".")[-2]
                assert key in mapping, f"Key {key} not found in mapping"

                new_key, dim = mapping[key]

                name = name.replace(key, new_key)

                ns = name.split(".")
                comp = ns[0]
                if comp == 'head':
                    name2 = name[len('head.'):]
                    print('head: ' + name2)

                    # Output head kept at high precision: fixed-point 2**43, int64.
                    param_int = (param.to(torch.float32) * (2 ** 43)).round().to(torch.int64)
                    head_state_dict[name2] = param_int
                elif comp == 'norm':
                    name2 = name[len('norm.'):]
                    print('norm: ' + name2)

                    # Final RMSNorm weight: fixed-point 2**15, int64.
                    param_int = (param.to(torch.float32) * (2 ** 15)).round().to(torch.int64)
                    norm_state_dict[name2] = param_int
                elif comp == 'embed':
                    name2 = name[len('embed.'):]
                    print('embed: ' + name2)

                    # Embedding table: fixed-point 2**31, int64.
                    param_int = (param.to(torch.float32) * (2 ** 31)).round().to(torch.int64)
                    embed_state_dict[name2] = param_int

                    # Also dump the table as raw binaries in chunks of
                    # EmbedsInOneFile rows for the ZK pipeline. NOTE(review):
                    # a trailing partial chunk (< EmbedsInOneFile rows) is
                    # dropped by the integer division — confirm intentional.
                    os.makedirs(EmbedsZKDir, exist_ok=True)
                    fileCount = param_int.shape[0] // EmbedsInOneFile
                    for i in range(0, fileCount):
                        saveTensor(EmbedsZKDir + str(i) + '.bin', param_int[i * EmbedsInOneFile : (i+1) * EmbedsInOneFile].cpu())
                elif comp == 'layers':
                    # Stage raw layer tensors keyed by the name minus the
                    # "layers.<id>." prefix; processed in pass 2.
                    layer_id = int(ns[1])
                    name2 = '.'.join(ns[2:])
                    layer_state_dict0[layer_id][name2] = param

    print('Finish loading state dict from disk! ++++++++++')

    # ---- Pass 2: quantize each layer and save one file per layer. ----
    for layer_id in range(len(layer_state_dict0)):
        os.makedirs(f'{save_path}/experts-{layer_id}', exist_ok=True)

        states = layer_state_dict0[layer_id]

        for name, param in states.items():
            ns = name.split(".")
            typ = param.type()
            shape = param.shape

            if ns[0] == 'attn_norm':
                print(f'layer {layer_id} {name}, type: {typ}', flush=True)
                # Pre-attention norm weight: fixed-point 2**21, int32.
                # (Non-'weight' entries under attn_norm are silently dropped.)
                if ns[1] == 'weight':
                    param_int = (param.to(torch.float32) * (2 ** 21)).round().to(torch.int32)
                    layer_state_dict[layer_id][name] = param_int
            elif ns[0] == 'ffn_norm':
                print(f'layer {layer_id} {name}, type: {typ}', flush=True)
                # Pre-FFN norm weight: fixed-point 2**23, int32.
                if ns[1] == 'weight':
                    param_int2 = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
                    layer_state_dict[layer_id][name] = param_int2
            elif ns[0] == 'ffn':
                if len(ns) == 3:
                    # Dense-FFN / router tensors: ffn.<w1|w2|w3|gate>.<...>.
                    # FP8 '...scale' entries are consumed alongside their
                    # weight, so the scale keys themselves are skipped.
                    if ns[1] == 'w1' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'w1' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())
                        scale = w1_rescales[layer_id]
                        rescale = 2 ** scale
                        param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        # NOTE(review): prints {name} (the weight key) as
                        # scale_name — log-only inconsistency.
                        print(f'layer {layer_id} w1 weight, type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {name}', flush=True)
                    elif ns[1] == 'w2' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'w2' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())
                        scale = w2_rescales[layer_id]
                        rescale = 2 ** scale
                        param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        print(f'layer {layer_id} w2 weight, type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {name}', flush=True)
                    elif ns[1] == 'w3' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'w3' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())
                        scale = w3_rescales[layer_id]
                        rescale = 2 ** scale
                        param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        print(f'layer {layer_id} w3 weight, type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {name}', flush=True)

                    elif ns[1] == 'gate' and ns[2] == 'weight':
                        # Router gate weight: fixed-point 2**gate_rescales[layer], int32.
                        gate_rescale = 2 ** gate_rescales[layer_id]
                        gate_int = (param.to(torch.float32) * gate_rescale).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = gate_int.cpu()
                        rescale_name = name.replace('weight', 'scale')
                        layer_state_dict[layer_id][rescale_name] = torch.tensor(gate_rescales[layer_id], dtype=torch.int32)
                        print(f'layer {layer_id}: gate_weight_name: {name}, gate_scale_name: {rescale_name}')
                    elif ns[1] == 'gate' and ns[2] == 'bias':
                        # Router correction bias: fixed-point 2**23, int32.
                        bias_int = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = bias_int.cpu()
                        print(f'layer {layer_id} bias: {name}')
                    else:
                        layer_state_dict[layer_id][name] = param
                elif len(ns) == 4:
                    # Shared-expert tensors: ffn.shared_experts.<w1|w2|w3>.<...>.
                    if ns[1] == 'shared_experts':
                        if (ns[2] == 'w1' or ns[2] == 'w2' or ns[2] == 'w3') and ns[3] == 'scale':
                            continue
                        elif ns[2] == 'w1' and ns[3] == 'weight':
                            param_weight = param.cuda()
                            weight_name = name

                            scale_name = name.replace('weight', 'scale')
                            param_scale = states[scale_name]

                            weight = weight_dequant(param_weight, param_scale.cuda())
                            scale = shared_w1_rescales[layer_id]
                            rescale = 2 ** scale
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            layer_state_dict[layer_id][weight_name] = param_int.cpu()
                            layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                            print(f'layer {layer_id} shared_expert w1 type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
                        elif ns[2] == 'w2' and ns[3] == 'weight':
                            param_weight = param.cuda()
                            weight_name = name

                            scale_name = name.replace('weight', 'scale')
                            param_scale = states[scale_name]

                            weight = weight_dequant(param_weight, param_scale.cuda())
                            scale = shared_w2_rescales[layer_id]
                            rescale = 2 ** scale
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            layer_state_dict[layer_id][weight_name] = param_int.cpu()
                            layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                            print(f'layer {layer_id} shared_expert w2 type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
                        elif ns[2] == 'w3' and ns[3] == 'weight':
                            param_weight = param.cuda()
                            weight_name = name

                            scale_name = name.replace('weight', 'scale')
                            param_scale = states[scale_name]

                            weight = weight_dequant(param_weight, param_scale.cuda())
                            scale = shared_w3_rescales[layer_id]
                            rescale = 2 ** scale
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            layer_state_dict[layer_id][weight_name] = param_int.cpu()
                            layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)
                            print(f'layer {layer_id} shared_expert w3 type: {typ}, shape: {shape}, weight_name: {weight_name}, scale_name: {scale_name}')
                        else:
                            layer_state_dict[layer_id][name] = param
                    else:
                        layer_state_dict[layer_id][name] = param
                elif len(ns) == 5:
                    # Routed-expert tensors: ffn.experts.<id>.<w1|w2|w3>.<...>.
                    if ns[1] == 'experts':
                        expert_id = int(ns[2])
                        if (ns[3] == 'w1' or ns[3] == 'w2' or ns[3] == 'w3') and ns[4] == 'scale':
                            continue
                        elif ns[3] == 'w1' and ns[4] == 'weight':
                            scale = experts_w1_rescales[layer_id][expert_id]
                            handle_expert_w(layer_id, expert_id, 1, param, name, scale, typ, shape, f'{save_path}/experts-{layer_id}')
                        elif ns[3] == 'w2' and ns[4] == 'weight':
                            scale = experts_w2_rescales[layer_id][expert_id]
                            handle_expert_w(layer_id, expert_id, 2, param, name, scale, typ, shape, f'{save_path}/experts-{layer_id}')
                        elif ns[3] == 'w3' and ns[4] == 'weight':
                            scale = experts_w3_rescales[layer_id][expert_id]
                            handle_expert_w(layer_id, expert_id, 3, param, name, scale, typ, shape, f'{save_path}/experts-{layer_id}')
                        else:
                            layer_state_dict[layer_id][name] = param
                    else:
                        layer_state_dict[layer_id][name] = param
            elif ns[0] == 'attn':
                if len(ns) == 3:
                    if ns[1] == 'wq_a' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wq_a' and ns[2] == 'weight':
                        # Query down-projection: dequant, fixed-point 2**30, int32.
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        weight_int = (weight.to(torch.float32) * (2 ** 30)).round().to(torch.int32)

                        layer_state_dict[layer_id][weight_name] = weight_int.cpu()

                        print(f'layer {layer_id} wq_a weight, type: {typ}, shape: {shape}', flush=True)
                    elif ns[1] == 'q_norm':
                        print(f'layer {layer_id} q_norm, type: {typ}, shape: {shape}', flush=True)

                        # Query LoRA norm: fixed-point 2**19, int32.
                        param_int3 = (param.to(torch.float32) * (2 ** 19)).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = param_int3
                    elif ns[1] == 'kv_norm':
                        print(f'layer {layer_id} kv_norm, type: {typ}, shape: {shape}', flush=True)

                        # KV LoRA norm: fixed-point 2**23, int32.
                        param_int4 = (param.to(torch.float32) * (2 ** 23)).round().to(torch.int32)
                        layer_state_dict[layer_id][name] = param_int4
                    elif ns[1] == 'wq_b' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wq_b' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        weight_int = (weight.to(torch.float32) * (2 ** 30)).round().to(torch.int32)

                        # Split per-head q projection rows — presumably
                        # 128 heads x (128 nope + 64 rope) x 1536; confirm
                        # against the model config.
                        weight_int = weight_int.view(128, 192, 1536)
                        wq_b1, wq_b2 = torch.split(weight_int, [128, 64], dim=-2)

                        print(f'layer {layer_id} wq_b1 weight, shape: {wq_b1.shape}, wq_b2 weight, shape: {wq_b2.shape}', flush=True)

                        wq_b1 = wq_b1.reshape(128 * 128, 1536)
                        wq_b2 = wq_b2.reshape(128 * 64, 1536)
                        wq_b1_name = weight_name.replace('wq_b', 'wq_b1')
                        wq_b2_name = weight_name.replace('wq_b', 'wq_b2')

                        layer_state_dict[layer_id][wq_b1_name] = wq_b1.cpu()
                        layer_state_dict[layer_id][wq_b2_name] = wq_b2.cpu()

                        print(f'layer {layer_id} wq_b weight, type: {typ}, shape: {shape}', flush=True)
                    elif ns[1] == 'wkv_a' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wkv_a' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        weight_int = (weight.to(torch.float32) * (2 ** 29)).round().to(torch.int32)

                        # Split rows — presumably 512 KV-LoRA rank + 64 rope
                        # dims over 7168 model dims; confirm against config.
                        weight_int = weight_int.view(576, 7168)
                        wkv_a1, wkv_a2 = torch.split(weight_int, [512, 64], dim=-2)

                        print(f'layer {layer_id} wkv_a1 weight, shape: {wkv_a1.shape}, wkv_a2 weight, shape: {wkv_a2.shape}', flush=True)

                        wkv_a1_name = weight_name.replace('wkv_a', 'wkv_a1')
                        wkv_a2_name = weight_name.replace('wkv_a', 'wkv_a2')

                        layer_state_dict[layer_id][wkv_a1_name] = wkv_a1.cpu()
                        layer_state_dict[layer_id][wkv_a2_name] = wkv_a2.cpu()

                        print(f'layer {layer_id} wkv_a weight, type: {typ}, shape: {shape}', flush=True)
                    elif ns[1] == 'wkv_b' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wkv_b' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        # Per-head split — presumably 128 heads x (128 k_nope
                        # + 128 v) x 512 KV rank; the two halves get separate
                        # per-layer rescale exponents.
                        wkv_b = weight.view(128, 256, 512)

                        wkv_b_1 = wkv_b[:, :128]
                        wkv_b_1 = wkv_b_1.reshape(128 * 128, 512)
                        scale1 = wkv_b_1_rescales[layer_id]
                        wkv_b_1_rescale = 2 ** scale1
                        wkv_b_1_int = torch.round(wkv_b_1.to(torch.float32) * wkv_b_1_rescale).to(torch.int32)

                        wkv_b_2 = wkv_b[:, -128:]
                        wkv_b_2 = wkv_b_2.reshape(128 * 128, 512)
                        scale2 = wkv_b_2_rescales[layer_id]
                        wkv_b_2_rescale = 2 ** scale2
                        wkv_b_2_int = torch.round(wkv_b_2.to(torch.float32) * wkv_b_2_rescale).to(torch.int32)

                        wkv_b_1_name = weight_name.replace("wkv_b", "wkv_b_1")
                        wkv_b_1_scale_name = scale_name.replace("wkv_b", "wkv_b_1")
                        layer_state_dict[layer_id][wkv_b_1_name] = wkv_b_1_int.cpu()
                        layer_state_dict[layer_id][wkv_b_1_scale_name] = torch.tensor(scale1, dtype=torch.int32)

                        wkv_b_2_name = weight_name.replace("wkv_b", "wkv_b_2")
                        wkv_b_2_scale_name = scale_name.replace("wkv_b", "wkv_b_2")
                        layer_state_dict[layer_id][wkv_b_2_name] = wkv_b_2_int.cpu()
                        layer_state_dict[layer_id][wkv_b_2_scale_name] = torch.tensor(scale2, dtype=torch.int32)

                        print(f'layer {layer_id} wkv_b, type: {typ}, shape: {shape}, wkv_b_1 weight: {wkv_b_1_name}, wkv_b_1 scale: {wkv_b_1_scale_name}, wkv_b_2 weight: {wkv_b_2_name}, wkv_b_2 scale: {wkv_b_2_scale_name}', flush=True)
                    elif ns[1] == 'wo' and ns[2] == 'scale':
                        continue
                    elif ns[1] == 'wo' and ns[2] == 'weight':
                        param_weight = param.cuda()
                        weight_name = name

                        scale_name = name.replace('weight', 'scale')
                        param_scale = states[scale_name]

                        weight = weight_dequant(param_weight, param_scale.cuda())

                        scale = wo_rescales[layer_id]
                        rescale = 2 ** scale

                        if layer_id != 58:
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                        else:
                            # Layer 58's wo has a single extreme-magnitude
                            # outlier: locate it, log its exact bit pattern,
                            # and replace it with the sentinel INT32_MIN —
                            # downstream presumably special-cases this value;
                            # confirm.
                            wo_abs = weight.abs().cpu()
                            maxpos = wo_abs.argmax()
                            row, col = divmod(maxpos.item(), weight.size(1))
                            print(f'maxpos: {maxpos}, {row} {col}', flush=True)

                            vstr = getBF16PrintStr(weight[row][col])
                            print(f'weight[{row}][{col}]: {vstr}', flush=True)
                            weight[row][col] = 0
                            param_int = (weight.to(torch.float32) * rescale).round().to(torch.int32)
                            param_int[row][col] = -(2 ** 31)

                        layer_state_dict[layer_id][weight_name] = param_int.cpu()
                        layer_state_dict[layer_id][scale_name] = torch.tensor(scale, dtype=torch.int32)

                        print(f'layer {layer_id} wo weight, type: {typ}, shape: {shape}, weight: {weight_name}, scale: {scale_name}', flush=True)
                    else:
                        layer_state_dict[layer_id][name] = param
                else:
                    layer_state_dict[layer_id][name] = param
            else:
                # Anything unrecognized is passed through unmodified.
                layer_state_dict[layer_id][name] = param

        save_file(layer_state_dict[layer_id], os.path.join(save_path, f"layer-{layer_id}.safetensors"))
        print(f'Finish saving layer {layer_id}', flush=True)
        # Free both staging dicts for this layer to bound peak memory.
        layer_state_dict0[layer_id] = {}
        layer_state_dict[layer_id] = {}

    print('Finish opening')

    # NOTE(review): created only after layer files were already written into
    # save_path above — works only if save_path already exists; confirm.
    os.makedirs(save_path, exist_ok=True)

    # Both should print as emptied/leftover staging state at this point.
    print(layer_state_dict)
    print(experts)

    save_file(head_state_dict, os.path.join(save_path, f"head_int.safetensors"))
    save_file(norm_state_dict, os.path.join(save_path, f"norm_int.safetensors"))
    save_file(embed_state_dict, os.path.join(save_path, f"embed_int.safetensors"))

    # Copy tokenizer files alongside the converted weights.
    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        new_file_path = os.path.join(save_path, os.path.basename(file_path))
        shutil.copyfile(file_path, new_file_path)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: all four flags are mandatory.
    cli = ArgumentParser()
    for flag, flag_type in (
        ("--hf-ckpt-path", str),
        ("--save-path", str),
        ("--n-experts", int),
        ("--model-parallel", int),
    ):
        cli.add_argument(flag, type=flag_type, required=True)
    args = cli.parse_args()
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
|
|
|