guanwenyu1995 commited on
Commit
2efbb40
·
verified ·
1 Parent(s): 68a215a

Upload qat-convert.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. qat-convert.py +176 -0
qat-convert.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from tqdm import tqdm
4
+ import os
5
+ import safetensors
6
+
7
+ class SteTernaryQuantizer(nn.Module):
8
+ def __init__(self, group_size):
9
+ super().__init__()
10
+ self.group_size = group_size
11
+
12
+ def forward(self, x):
13
+ org_w_shape = x.shape
14
+ if self.group_size > 0:
15
+ assert x.shape[-1] % self.group_size == 0
16
+ x = x.reshape(-1, self.group_size)
17
+ elif self.group_size == -1:
18
+ x = x.reshape(-1, x.shape[-1])
19
+ assert x.dim() == 2
20
+ scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5))
21
+ x_q = (torch.clamp(torch.round(x * scales),-1,1) / scales)
22
+ assert torch.isnan(x_q).sum() == 0
23
+ x = x.reshape(org_w_shape)
24
+ x_q = x_q.reshape(org_w_shape)
25
+ return x_q
26
+
27
+ class SteIntQuantizer(nn.Module):
28
+ def __init__(self, bit, group_size):
29
+ super().__init__()
30
+ self.bit = bit
31
+ self.group_size = group_size
32
+
33
+ def forward(self, x):
34
+ org_w_shape = x.shape
35
+ if self.group_size > 0:
36
+ assert org_w_shape[-1] % self.group_size == 0
37
+ x = x.reshape(-1, self.group_size)
38
+ elif self.group_size == -1:
39
+ x = x.reshape(-1, x.shape[-1])
40
+
41
+ assert x.dim() == 2
42
+
43
+ abs_max_val = x.abs().amax(dim=1, keepdim=True)
44
+ max_int = 2 ** (self.bit - 1) - 1
45
+ min_int = - (2 ** (self.bit - 1))
46
+ scales = abs_max_val.clamp(min=1e-5) / max_int
47
+
48
+ assert torch.isnan(scales).sum() == 0
49
+
50
+ x_q = (torch.clamp(torch.round(x / scales), min_int, max_int)) * scales
51
+
52
+ assert torch.isnan(x_q).sum() == 0
53
+
54
+ x = x.reshape(org_w_shape)
55
+ x_q = x_q.reshape(org_w_shape)
56
+
57
+ return x_q
58
+
59
+ class SteInt2Quantizer(nn.Module):
60
+ def __init__(self, group_size):
61
+ super().__init__()
62
+ self.group_size = group_size
63
+
64
+ def forward(self, x):
65
+ org_w_shape = x.shape
66
+ if self.group_size > 0:
67
+ assert x.shape[-1] % self.group_size == 0
68
+ x = x.reshape(-1, self.group_size)
69
+ elif self.group_size == -1:
70
+ x = x.reshape(-1, x.shape[-1])
71
+
72
+ assert x.dim() == 2
73
+
74
+ scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5) * 1)
75
+ x_q = (torch.clamp(torch.round(x * scales),-2,1) / scales)
76
+
77
+ assert torch.isnan(x_q).sum() == 0
78
+
79
+ x = x.reshape(org_w_shape)
80
+ x_q = x_q.reshape(org_w_shape)
81
+
82
+ return x_q
83
+
84
+ def quantize_model_bin(input_bin_path, output_bin_path, quant_type="ternary", bit=2, group_size=128, device="cuda" if torch.cuda.is_available() else "cpu"):
85
+ """
86
+ 直接对PyTorch模型bin文件进行量化。
87
+
88
+ Args:
89
+ input_bin_path: 输入模型bin文件路径
90
+ output_bin_path: 输出量化后的模型bin文件路径
91
+ quant_type: 量化类型 ("ternary" 或 "int")
92
+ bit: 整数量化的位数 (仅在 quant_type="int" 时使用)
93
+ group_size: 量化分组大小
94
+ device: 运行设备
95
+ """
96
+ print(f"加载模型文件: {input_bin_path}...")
97
+ if input_bin_path.endswith(".bin"):
98
+ state_dict = torch.load(input_bin_path, map_location=device)
99
+ elif input_bin_path.endswith(".safetensors"):
100
+ state_dict = safetensors.load_file(input_bin_path)
101
+ elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "pytorch_model.bin")):
102
+ state_dict = torch.load(os.path.join(input_bin_path, "pytorch_model.bin"), map_location=device)
103
+ elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "model.safetensors")):
104
+ state_dict = safetensors.load_file(os.path.join(input_bin_path, "model.safetensors"))
105
+ else:
106
+ raise ValueError(f"不支持的模型文件类型: {input_bin_path}")
107
+
108
+ print(f"应用 {quant_type} 量化...")
109
+ if quant_type == "ternary":
110
+ quantizer = SteTernaryQuantizer(group_size=group_size)
111
+ elif quant_type == "int":
112
+ quantizer = SteIntQuantizer(bit=bit, group_size=group_size)
113
+ elif quant_type == "int2":
114
+ quantizer = SteInt2Quantizer(group_size=group_size)
115
+ else:
116
+ raise ValueError(f"不支持的量化类型: {quant_type}")
117
+
118
+ # 统计需要量化的参数数量
119
+ total_params = sum(1 for k, v in state_dict.items() if ("weight" in k and "layer" in k) or ("fc" in k))
120
+
121
+ # 应用量化
122
+ with torch.no_grad():
123
+ for name, param in tqdm(state_dict.items(), total=total_params, desc="量化中"):
124
+ if (("weight" in name and "layer" in name and param.dim() == 2) or ("fc" in name and param.dim() == 2)):
125
+ # 对权重进行量化
126
+ original_weight = param.data.clone()
127
+ quantized_weight = quantizer(original_weight)
128
+ state_dict[name] = quantized_weight
129
+
130
+ # 打印前几个层的统计信息
131
+ if total_params > 0:
132
+ total_params -= 1
133
+ if total_params > total_params - 5:
134
+ print(f"层: {name}")
135
+ print(f" 原始范围: {original_weight.min():.4f} 到 {original_weight.max():.4f}")
136
+ print(f" 量化后范围: {quantized_weight.min():.4f} 到 {quantized_weight.max():.4f}")
137
+ print(f" 均方误差: {((original_weight - quantized_weight)**2).mean():.8f}")
138
+
139
+ # 保存量化后的模型
140
+ print(f"保存量化后的模型到: {output_bin_path}...")
141
+ if output_bin_path.endswith(".bin"):
142
+ torch.save(state_dict, output_bin_path)
143
+ elif output_bin_path.endswith(".safetensors"):
144
+ safetensors.save_file(state_dict, output_bin_path)
145
+ else:
146
+ os.makedirs(os.path.dirname(output_bin_path), exist_ok=True)
147
+ output_bin_path = os.path.join(output_bin_path, "pytorch_model.bin")
148
+ torch.save(state_dict, output_bin_path)
149
+ print("完成!")
150
+
151
+ def main():
152
+ import argparse
153
+ parser = argparse.ArgumentParser(description="量化PyTorch模型bin文件")
154
+ parser.add_argument("--input_bin", type=str, required=True, help="输入模型bin文件路径")
155
+ parser.add_argument("--output", type=str, required=True, help="输出量化后的模型bin文件路径")
156
+ parser.add_argument("--quant_type", type=str, default="ternary", choices=["ternary", "int", "int2"], help="量化类型")
157
+ parser.add_argument("--bit", type=int, default=2, help="整数量化的位数")
158
+ parser.add_argument("--group_size", type=int, default=-1, help="量化分组大小")
159
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
160
+ parser.add_argument("--config", type=str, default="", help="model config file")
161
+
162
+ args = parser.parse_args()
163
+ os.makedirs(args.output, exist_ok=True)
164
+ quantize_model_bin(
165
+ input_bin_path=args.input_bin,
166
+ output_bin_path=os.path.join(args.output, "pytorch_model.bin"),
167
+ quant_type=args.quant_type,
168
+ bit=args.bit,
169
+ group_size=args.group_size,
170
+ device=args.device
171
+ )
172
+ if args.config:
173
+ os.system(f"cp {args.config}/* {args.output}")
174
+ print(f"复制{args.config}文件到{args.output}")
175
+ if __name__ == "__main__":
176
+ main()