# BitCPM-CANN-1B-unquantized / qat-convert.py
# Uploaded by guanwenyu1995 with huggingface_hub (commit 689ba6e, verified).
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import safetensors
class SteTernaryQuantizer(nn.Module):
    """Fake-quantize a tensor to ternary levels {-1, 0, +1} times a per-group step.

    Grouping is controlled by ``group_size``:
      * ``> 0``  -- the last dimension is split into chunks of ``group_size``
        elements (it must divide the last dimension);
      * ``-1``   -- each row along the last dimension is one group;
      * other    -- the tensor is used as-is and must already be 2-D.

    Each group's step is its mean absolute value (floored at 1e-5 so an
    all-zero group does not divide by zero). Values are divided by the step,
    rounded, clamped to [-1, 1] and multiplied back. The returned tensor has
    the same shape as the input.

    NOTE(review): despite the ``Ste`` prefix, ``forward`` returns the quantized
    values directly; no straight-through gradient term is applied here.
    """

    def __init__(self, group_size):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        in_shape = x.shape
        if self.group_size > 0:
            assert x.shape[-1] % self.group_size == 0
            groups = x.reshape(-1, self.group_size)
        elif self.group_size == -1:
            groups = x.reshape(-1, x.shape[-1])
        else:
            groups = x
        assert groups.dim() == 2

        step = groups.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        inv_step = 1.0 / step
        codes = torch.round(groups * inv_step).clamp(min=-1, max=1)
        out = codes / inv_step
        assert torch.isnan(out).sum() == 0
        return out.reshape(in_shape)
class SteIntQuantizer(nn.Module):
    """Symmetric per-group integer fake-quantizer using abs-max scaling.

    Args:
        bit: bit width. Codes are clamped to ``[-2**(bit-1), 2**(bit-1) - 1]``.
            Must be >= 2: with ``bit == 1`` the largest positive code is 0 and
            the scale computation divides by zero.
        group_size: ``> 0`` splits the last dimension into chunks of this size
            (it must divide the last dimension); ``-1`` uses each last-dimension
            row as one group; any other value uses the tensor as-is, which must
            then already be 2-D.

    Raises:
        ValueError: if ``bit < 2``.

    NOTE(review): despite the ``Ste`` prefix, ``forward`` returns the quantized
    values directly; no straight-through gradient term is applied here.
    """

    def __init__(self, bit, group_size):
        super().__init__()
        if bit < 2:
            raise ValueError(f"bit must be >= 2 for integer quantization, got {bit}")
        self.bit = bit
        self.group_size = group_size

    def forward(self, x):
        """Return ``x`` fake-quantized to ``self.bit`` bits, with the same shape as ``x``."""
        org_w_shape = x.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            x = x.reshape(-1, self.group_size)
        elif self.group_size == -1:
            x = x.reshape(-1, x.shape[-1])
        assert x.dim() == 2

        abs_max_val = x.abs().amax(dim=1, keepdim=True)
        max_int = 2 ** (self.bit - 1) - 1
        min_int = -(2 ** (self.bit - 1))
        # Map each group's abs-max onto the largest positive code; the 1e-5
        # floor keeps all-zero groups from producing 0/0.
        scales = abs_max_val.clamp(min=1e-5) / max_int
        assert torch.isnan(scales).sum() == 0
        x_q = torch.clamp(torch.round(x / scales), min_int, max_int) * scales
        assert torch.isnan(x_q).sum() == 0
        return x_q.reshape(org_w_shape)
class SteInt2Quantizer(nn.Module):
    """Asymmetric 2-bit fake-quantizer with levels {-2, -1, 0, +1} per group.

    ``group_size > 0`` splits the last dimension into chunks of that size (it
    must divide the last dimension), ``-1`` treats every last-dimension row as
    one group, and any other value uses the tensor unchanged (it must be 2-D).

    Each group's step is its mean absolute value, floored at 1e-5. Values are
    divided by the step, rounded, clamped to [-2, 1] and multiplied back. The
    result has the input's shape.

    NOTE(review): despite the ``Ste`` prefix, ``forward`` returns the quantized
    values directly; no straight-through gradient term is applied here.
    """

    def __init__(self, group_size):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        shape_before = x.shape
        if self.group_size > 0:
            assert x.shape[-1] % self.group_size == 0
            chunks = x.reshape(-1, self.group_size)
        elif self.group_size == -1:
            chunks = x.reshape(-1, x.shape[-1])
        else:
            chunks = x
        assert chunks.dim() == 2

        mean_abs = chunks.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        multiplier = 1.0 / mean_abs
        levels = torch.round(chunks * multiplier).clamp(min=-2, max=1)
        dequantized = levels / multiplier
        assert torch.isnan(dequantized).sum() == 0
        return dequantized.reshape(shape_before)
def _load_state_dict(input_bin_path, device):
    """Resolve *input_bin_path* to a checkpoint file and load it as a state dict.

    Accepts a ``.bin`` / ``.safetensors`` file, or a directory that contains
    ``pytorch_model.bin`` (checked first) or ``model.safetensors``.

    Raises:
        ValueError: if no supported checkpoint can be resolved.
    """
    dir_bin = os.path.join(input_bin_path, "pytorch_model.bin")
    dir_safetensors = os.path.join(input_bin_path, "model.safetensors")
    if input_bin_path.endswith((".bin", ".safetensors")):
        checkpoint = input_bin_path
    elif os.path.isdir(input_bin_path) and os.path.exists(dir_bin):
        checkpoint = dir_bin
    elif os.path.isdir(input_bin_path) and os.path.exists(dir_safetensors):
        checkpoint = dir_safetensors
    else:
        raise ValueError(f"不支持的模型文件类型: {input_bin_path}")

    if checkpoint.endswith(".safetensors"):
        # load_file/save_file live in the framework submodule; the top-level
        # `safetensors` package does not provide them.
        from safetensors.torch import load_file
        return load_file(checkpoint, device=device)
    # NOTE(security): torch.load unpickles the file -- only use trusted checkpoints.
    return torch.load(checkpoint, map_location=device)


def _build_quantizer(quant_type, bit, group_size):
    """Return the quantizer module for *quant_type* ("ternary", "int" or "int2")."""
    if quant_type == "ternary":
        return SteTernaryQuantizer(group_size=group_size)
    if quant_type == "int":
        return SteIntQuantizer(bit=bit, group_size=group_size)
    if quant_type == "int2":
        return SteInt2Quantizer(group_size=group_size)
    raise ValueError(f"不支持的量化类型: {quant_type}")


def _should_quantize(name, param):
    """True for 2-D tensors keyed like layer weights ("weight" + "layer") or containing "fc"."""
    if not isinstance(param, torch.Tensor) or param.dim() != 2:
        return False
    return ("weight" in name and "layer" in name) or "fc" in name


def _save_state_dict(state_dict, output_bin_path):
    """Save *state_dict* and return the file path actually written.

    ``.bin`` paths use ``torch.save`` and ``.safetensors`` paths use safetensors.
    Any other path is treated as a directory (created if needed) that receives
    ``pytorch_model.bin``. Parent directories of file paths are created too.
    """
    if output_bin_path.endswith((".bin", ".safetensors")):
        parent = os.path.dirname(output_bin_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        target = output_bin_path
    else:
        os.makedirs(output_bin_path, exist_ok=True)
        target = os.path.join(output_bin_path, "pytorch_model.bin")

    if target.endswith(".safetensors"):
        from safetensors.torch import save_file
        save_file(state_dict, target)
    else:
        torch.save(state_dict, target)
    return target


def quantize_model_bin(input_bin_path, output_bin_path, quant_type="ternary", bit=2, group_size=128,
                       device="cuda" if torch.cuda.is_available() else "cpu", report_layers=5):
    """Fake-quantize the selected 2-D tensors of a checkpoint and save the result.

    A tensor is quantized when it is 2-D and its key contains both ``"weight"``
    and ``"layer"``, or contains ``"fc"``. All other entries are saved unchanged.
    Quantized tensors keep their shape but hold only the quantizer's grid values.

    Args:
        input_bin_path: ``.bin`` / ``.safetensors`` file, or a directory containing
            ``pytorch_model.bin`` or ``model.safetensors``.
        output_bin_path: ``.bin`` / ``.safetensors`` output file. Any other path
            is treated as a directory, and ``pytorch_model.bin`` is written inside it.
        quant_type: ``"ternary"``, ``"int"`` or ``"int2"``.
        bit: bit width, used only when ``quant_type == "int"``.
        group_size: quantization group size (``-1`` = one group per last-dim row).
        device: device the checkpoint is loaded onto and quantized on.
        report_layers: number of quantized tensors whose error statistics are printed.

    Raises:
        ValueError: unsupported ``quant_type`` or unresolvable input path.
    """
    # Build the quantizer first so a bad quant_type fails before loading a large model.
    quantizer = _build_quantizer(quant_type, bit, group_size)

    print(f"加载模型文件: {input_bin_path}...")
    state_dict = _load_state_dict(input_bin_path, device)

    print(f"应用 {quant_type} 量化...")
    reported = 0
    with torch.no_grad():
        # Iterate over a snapshot: values are replaced during the loop, and the
        # progress bar length then matches the number of items actually visited.
        for name, param in tqdm(list(state_dict.items()), desc="量化中"):
            if not _should_quantize(name, param):
                continue
            original_weight = param.detach().clone()
            quantized_weight = quantizer(original_weight)
            state_dict[name] = quantized_weight

            # Print statistics only for the first few quantized tensors.
            if reported < report_layers:
                reported += 1
                print(f"层: {name}")
                print(f" 原始范围: {original_weight.min():.4f} ~ {original_weight.max():.4f}")
                print(f" 量化后范围: {quantized_weight.min():.4f} ~ {quantized_weight.max():.4f}")
                print(f" 均方误差: {((original_weight - quantized_weight) ** 2).mean():.8f}")

    print(f"保存量化后的模型到: {output_bin_path}...")
    _save_state_dict(state_dict, output_bin_path)
    print("完成!")
def _copy_config_files(config_path, output_dir, skip_names=()):
    """Copy a config file, or the regular files of a config directory, into *output_dir*.

    For a directory, only non-hidden regular files directly inside it are
    copied (the set a shell ``*`` glob would match, minus subdirectories).
    Files whose basename is in *skip_names* are never copied.

    Raises:
        FileNotFoundError: if *config_path* is neither a file nor a directory.
    """
    import shutil

    if os.path.isfile(config_path):
        sources = [config_path]
    elif os.path.isdir(config_path):
        sources = [
            os.path.join(config_path, entry)
            for entry in sorted(os.listdir(config_path))
            if not entry.startswith(".")
        ]
        sources = [path for path in sources if os.path.isfile(path)]
    else:
        raise FileNotFoundError(f"config path not found: {config_path}")

    for src in sources:
        if os.path.basename(src) in skip_names:
            continue
        shutil.copy2(src, output_dir)


def main():
    """Command-line entry point: quantize a checkpoint into the ``--output`` directory.

    Quantized weights are written to ``<output>/pytorch_model.bin``. If
    ``--config`` is given (a file or a directory), its file(s) are copied into
    ``<output>`` afterwards, but never over the quantized weights file.
    """
    import argparse

    parser = argparse.ArgumentParser(description="量化PyTorch模型bin文件")
    parser.add_argument("--input_bin", type=str, required=True, help="输入模型bin文件路径")
    parser.add_argument("--output", type=str, required=True, help="输出量化后的模型bin文件路径")
    parser.add_argument("--quant_type", type=str, default="ternary", choices=["ternary", "int", "int2"], help="量化类型")
    parser.add_argument("--bit", type=int, default=2, help="整数量化的位数")
    parser.add_argument("--group_size", type=int, default=-1, help="量化分组大小")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
    parser.add_argument("--config", type=str, default="", help="model config file")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)
    weights_path = os.path.join(args.output, "pytorch_model.bin")
    quantize_model_bin(
        input_bin_path=args.input_bin,
        output_bin_path=weights_path,
        quant_type=args.quant_type,
        bit=args.bit,
        group_size=args.group_size,
        device=args.device,
    )
    if args.config:
        # Copy with shutil instead of a shell `cp` so paths are never
        # shell-interpreted and failures raise. Skip the weights file so that
        # passing the original model directory as --config cannot overwrite
        # the quantized checkpoint just written.
        _copy_config_files(args.config, args.output, skip_names={os.path.basename(weights_path)})
        print(f"复制{args.config}文件到{args.output}")


if __name__ == "__main__":
    main()