"""Fake-quantize the 2-D weight matrices of a PyTorch checkpoint (ternary / int / int2)."""
import os
import shutil

import torch
import torch.nn as nn

try:
    from tqdm import tqdm
except ImportError:  # The progress bar is purely cosmetic; fall back to plain iteration.
    def tqdm(iterable, *args, **kwargs):
        """Minimal stand-in for ``tqdm.tqdm`` used when tqdm is not installed."""
        return iterable

try:
    # ``import safetensors`` alone does not expose ``load_file``/``save_file``;
    # the PyTorch helpers live in the ``safetensors.torch`` submodule.
    import safetensors.torch
except ImportError:  # Only needed for .safetensors inputs/outputs.
    safetensors = None


def _reshape_into_groups(x, group_size):
    """Return ``x`` reshaped to a 2-D tensor whose rows are quantization groups.

    Args:
        x: Tensor to quantize.
        group_size: ``> 0`` splits the tensor into rows of ``group_size``
            elements (the last dimension must be divisible by it); ``-1`` makes
            one group per slice along the last dimension; any other value
            leaves ``x`` as-is.

    Raises:
        AssertionError: if the last dimension is not divisible by
            ``group_size`` or the grouped tensor is not 2-D.
    """
    if group_size > 0:
        assert x.shape[-1] % group_size == 0, (
            f"last dim {x.shape[-1]} is not divisible by group_size {group_size}"
        )
        x = x.reshape(-1, group_size)
    elif group_size == -1:
        x = x.reshape(-1, x.shape[-1])
    assert x.dim() == 2, f"expected a 2-D grouped tensor, got {x.dim()}-D"
    return x


class SteTernaryQuantizer(nn.Module):
    """Fake-quantize a tensor to ``{-s, 0, +s}`` per group.

    ``s`` is the group's mean absolute value (clamped to at least 1e-5). The
    result has the input's shape and is returned in floating point
    (quantize followed by de-quantize).
    """

    def __init__(self, group_size):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        groups = _reshape_into_groups(x, self.group_size)
        # Multiplying by 1 / mean|w| maps weights of about average magnitude to +-1.
        scales = 1.0 / groups.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        x_q = torch.clamp(torch.round(groups * scales), -1, 1) / scales
        assert not torch.isnan(x_q).any()
        return x_q.reshape(x.shape)


class SteIntQuantizer(nn.Module):
    """Symmetric per-group ``bit``-bit integer fake quantization.

    Each group is divided by ``max|w| / (2**(bit-1) - 1)``, rounded, clamped
    to ``[-2**(bit-1), 2**(bit-1) - 1]`` and multiplied back.

    Raises:
        ValueError: if ``bit < 2``.
    """

    def __init__(self, bit, group_size):
        super().__init__()
        # bit < 2 gives max_int <= 0, i.e. infinite/negative scales and NaN
        # output in forward(); reject it before any work is done.
        if bit < 2:
            raise ValueError(f"bit must be >= 2 for integer quantization, got {bit}")
        self.bit = bit
        self.group_size = group_size

    def forward(self, x):
        groups = _reshape_into_groups(x, self.group_size)
        max_int = 2 ** (self.bit - 1) - 1
        min_int = -(2 ** (self.bit - 1))
        scales = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / max_int
        assert not torch.isnan(scales).any()
        x_q = torch.clamp(torch.round(groups / scales), min_int, max_int) * scales
        assert not torch.isnan(x_q).any()
        return x_q.reshape(x.shape)


class SteInt2Quantizer(nn.Module):
    """Fake-quantize to the asymmetric 2-bit grid ``{-2, -1, 0, 1} * s`` per group.

    ``s`` is the group's mean absolute value (clamped to at least 1e-5), the
    same scale ``SteTernaryQuantizer`` uses.
    """

    def __init__(self, group_size):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        groups = _reshape_into_groups(x, self.group_size)
        scales = 1.0 / groups.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        x_q = torch.clamp(torch.round(groups * scales), -2, 1) / scales
        assert not torch.isnan(x_q).any()
        return x_q.reshape(x.shape)


def _build_quantizer(quant_type, bit, group_size):
    """Map ``quant_type`` to a quantizer module; raise ValueError if unknown."""
    if quant_type == "ternary":
        return SteTernaryQuantizer(group_size=group_size)
    if quant_type == "int":
        return SteIntQuantizer(bit=bit, group_size=group_size)
    if quant_type == "int2":
        return SteInt2Quantizer(group_size=group_size)
    raise ValueError(f"不支持的量化类型: {quant_type}")


def _require_safetensors():
    """Return the ``safetensors.torch`` module or raise a helpful ImportError."""
    if safetensors is None:
        raise ImportError("safetensors is required to read or write .safetensors files")
    return safetensors.torch


def _load_state_dict(input_bin_path, device):
    """Load a state dict from a .bin/.safetensors file or a model directory onto ``device``."""
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # trusted checkpoints (``weights_only=True`` is safer on torch >= 1.13).
    if input_bin_path.endswith(".bin"):
        return torch.load(input_bin_path, map_location=device)
    if input_bin_path.endswith(".safetensors"):
        return _require_safetensors().load_file(input_bin_path, device=device)
    if os.path.isdir(input_bin_path):
        bin_path = os.path.join(input_bin_path, "pytorch_model.bin")
        if os.path.exists(bin_path):
            return torch.load(bin_path, map_location=device)
        st_path = os.path.join(input_bin_path, "model.safetensors")
        if os.path.exists(st_path):
            return _require_safetensors().load_file(st_path, device=device)
    raise ValueError(f"不支持的模型文件类型: {input_bin_path}")


def _should_quantize(name, value):
    """True for 2-D tensors whose key names a layer weight or contains "fc"."""
    if not isinstance(value, torch.Tensor) or value.dim() != 2:
        return False
    return ("weight" in name and "layer" in name) or "fc" in name


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _save_state_dict(state_dict, output_bin_path):
    """Write ``state_dict`` and return the file path actually written."""
    if output_bin_path.endswith(".bin"):
        _ensure_parent_dir(output_bin_path)
        torch.save(state_dict, output_bin_path)
    elif output_bin_path.endswith(".safetensors"):
        _ensure_parent_dir(output_bin_path)
        _require_safetensors().save_file(state_dict, output_bin_path)
    else:
        # Any other path is an output directory: create it (not just its
        # parent) and write pytorch_model.bin inside.
        os.makedirs(output_bin_path, exist_ok=True)
        output_bin_path = os.path.join(output_bin_path, "pytorch_model.bin")
        torch.save(state_dict, output_bin_path)
    return output_bin_path


def quantize_model_bin(input_bin_path, output_bin_path, quant_type="ternary", bit=2, group_size=128,
                       device="cuda" if torch.cuda.is_available() else "cpu", num_log_layers=5):
    """Fake-quantize the selected 2-D weights of a checkpoint and save the result.

    Selected tensors are 2-D tensors whose key contains both "weight" and
    "layer", or contains "fc"; every other entry is saved unchanged.

    Args:
        input_bin_path: A ``.bin`` (torch.load) or ``.safetensors`` file, or a
            directory containing ``pytorch_model.bin`` or ``model.safetensors``
            (checked in that order).
        output_bin_path: Paths ending in ``.bin`` / ``.safetensors`` are written
            directly (missing parent directories are created); any other path is
            treated as a directory and ``pytorch_model.bin`` is written inside it.
        quant_type: "ternary", "int" or "int2".
        bit: Bit width for ``quant_type="int"`` (must be >= 2); ignored otherwise.
        group_size: Elements per quantization group; -1 means one group per
            slice along the last dimension.
        device: Device the checkpoint tensors are loaded onto and quantized on.
        num_log_layers: Number of leading quantized tensors whose value range
            and MSE are printed.

    Returns:
        The path the quantized state dict was written to.

    Raises:
        ValueError: unsupported ``quant_type`` or input path, or ``bit < 2``
            with ``quant_type="int"``.
        ImportError: a ``.safetensors`` path is used but safetensors is missing.
    """
    # Validate the quantizer configuration before loading a potentially large checkpoint.
    quantizer = _build_quantizer(quant_type, bit, group_size)

    print(f"加载模型文件: {input_bin_path}...")
    state_dict = _load_state_dict(input_bin_path, device)

    print(f"应用 {quant_type} 量化...")
    # Select targets up front: the progress-bar total is then exact and the
    # dict is not mutated while it is being iterated.
    targets = [name for name, value in state_dict.items() if _should_quantize(name, value)]

    with torch.no_grad():
        for index, name in enumerate(tqdm(targets, desc="量化中")):
            # The quantizers never modify their input, so no defensive clone is needed.
            original_weight = state_dict[name]
            quantized_weight = quantizer(original_weight)
            state_dict[name] = quantized_weight

            # Print statistics only for the first few quantized tensors.
            if index < num_log_layers:
                print(f"层: {name}")
                print(f" 原始范围: {original_weight.min():.4f} 到 {original_weight.max():.4f}")
                print(f" 量化后范围: {quantized_weight.min():.4f} 到 {quantized_weight.max():.4f}")
                print(f" 均方误差: {((original_weight - quantized_weight) ** 2).mean():.8f}")

    print(f"保存量化后的模型到: {output_bin_path}...")
    saved_path = _save_state_dict(state_dict, output_bin_path)
    print("完成!")
    return saved_path


def _copy_config(config_path, output_dir):
    """Copy model config file(s) into ``output_dir``.

    A directory contributes its top-level, non-hidden regular files (what the
    previous ``cp config/* output`` shell command copied); a single file is
    copied as-is.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
    """
    if os.path.isdir(config_path):
        for entry in sorted(os.listdir(config_path)):
            if entry.startswith("."):  # a shell `*` glob skips dotfiles
                continue
            src = os.path.join(config_path, entry)
            if os.path.isfile(src):
                shutil.copy2(src, output_dir)
    elif os.path.isfile(config_path):
        shutil.copy2(config_path, output_dir)
    else:
        raise FileNotFoundError(f"config path not found: {config_path}")


def main():
    """CLI entry point: quantize ``--input_bin`` into ``--output``/pytorch_model.bin."""
    import argparse

    parser = argparse.ArgumentParser(description="量化PyTorch模型bin文件")
    parser.add_argument("--input_bin", type=str, required=True, help="输入模型bin文件路径")
    parser.add_argument("--output", type=str, required=True, help="输出量化后的模型bin文件路径")
    parser.add_argument("--quant_type", type=str, default="ternary", choices=["ternary", "int", "int2"], help="量化类型")
    parser.add_argument("--bit", type=int, default=2, help="整数量化的位数")
    parser.add_argument("--group_size", type=int, default=-1, help="量化分组大小")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
    parser.add_argument("--config", type=str, default="", help="model config file")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)
    quantize_model_bin(
        input_bin_path=args.input_bin,
        output_bin_path=os.path.join(args.output, "pytorch_model.bin"),
        quant_type=args.quant_type,
        bit=args.bit,
        group_size=args.group_size,
        device=args.device,
    )
    if args.config:
        # shutil instead of `os.system("cp ...")`: portable, safe with spaces or
        # shell metacharacters in paths, and failures are not silently ignored.
        _copy_config(args.config, args.output)
        print(f"复制{args.config}文件到{args.output}")


if __name__ == "__main__":
    main()