BoruiXu committed on
Commit
fc3d3de
·
verified ·
1 Parent(s): 077d2cb

Update run_awq.py

Browse files
Files changed (1) hide show
  1. run_awq.py +195 -195
run_awq.py CHANGED
@@ -1,195 +1,195 @@
1
- #
2
- # Copyright © 2023 Advanced Micro Devices, Inc. All rights reserved.
3
- #
4
-
5
- import torch
6
- import logging
7
- import time
8
- import argparse
9
- import os
10
- import psutil
11
- from transformers import set_seed
12
- from transformers import LlamaTokenizer,AutoTokenizer
13
-
14
- import qlinear
15
- from utils import Utils
16
- from model_utils import (
17
- warmup,
18
- decode_prompt,
19
- decode_prompts,
20
- get_wikitext2,
21
- perplexity,
22
- )
23
- from profiler import ProfileAIE
24
- import gc
25
-
26
-
27
- from phi3_mini.modeling_phi3 import Phi3ForCausalLM
28
-
29
- from pre_quant import run_awq, apply_awq
30
- from quantizer import real_quantize_model_weight
31
- from qmodule import WQLinear
32
-
33
- set_seed(123)
34
-
35
-
36
- def load_model(args):
37
-
38
- # tokenizer = LlamaTokenizer.from_pretrained("./Phi-3-mini-4k-instruct-AWQ")
39
- tokenizer = AutoTokenizer.from_pretrained("./phi3_mini")
40
- if args.awq == "none":
41
- model = Phi3ForCausalLM.from_pretrained("./phi3_mini", torch_dtype=torch.bfloat16)
42
-
43
- else:
44
- # ckpt = "pytorch_phi3_mini_w_bit_{}_awq{}_{}amd.pt".format(args.w_bit, "_fa" if args.flash_attention else "", "lm_" if args.lm_head else "")
45
- ckpt = "./phi3_mini_awq_4bit_no_flash_attention.pt"
46
- if args.task == "quantize":
47
- model = Phi3ForCausalLM.from_pretrained("./phi3_mini", torch_dtype=torch.bfloat16)
48
- print(model)
49
-
50
- Utils.print_model_size(model)
51
-
52
- q_config = {
53
- "zero_point": True,
54
- "q_group_size": 128, } # whether to use group quantization
55
-
56
- if args.awq == 'load':
57
- print("Loading pre-computed AWQ results from", os.getenv("AWQ_CACHE"))
58
- awq_results = torch.load( "./phi-3-chat-w4-g128_awq.pt", map_location="cpu")
59
- apply_awq(model, awq_results)
60
- print("Quantization config:", q_config)
61
- real_quantize_model_weight(
62
- model, w_bit=args.w_bit, q_config=q_config
63
- )
64
-
65
- Utils.print_model_size(model)
66
-
67
- #for n, m in model.named_modules():
68
- # if isinstance(m, WQLinear):
69
- # print(f"AWQ Model load : {n} : {m.qweight.data.min()} {m.qweight.data.max()} {m.qweight.data.shape} {m.scales.shape} qzeros: {m.qzeros.shape} {m.qzeros.min()} {m.qzeros.max()}")
70
-
71
- elif args.awq == 'run':
72
- awq_results = run_awq(
73
- model, tokenizer,
74
- w_bit=args.w_bit, q_config=q_config,
75
- n_samples=128, seqlen=512,
76
- )
77
- torch.save(awq_results, "./phi3-mini-w%d-g128-generated.pt"%args.w_bit)
78
- print(model)
79
- print("Saved AWQ results in ./phi3-mini-w%d-g128-generated.pt"%args.w_bit)
80
- raise SystemExit
81
-
82
-
83
- Utils.replace_node( model,
84
- WQLinear,
85
- qlinear.QLinearPerGrp,
86
- (), {'device':'cpu', 'w_bit':args.w_bit, 'group_size':128} )
87
- print(model)
88
- gc.collect()
89
-
90
- Utils.print_model_size(model)
91
- if args.lm_head: # Quantize lm_head
92
- Utils.replace_node( model,
93
- torch.nn.Linear,
94
- qlinear.QLinearPerGrp,
95
- (), {'device':'cpu', 'w_bit':args.w_bit, 'group_size':32} )
96
- print(model)
97
- gc.collect()
98
-
99
- torch.save(model, ckpt)
100
- print(f"Quantized and saved model: {ckpt}")
101
- raise SystemExit
102
- else:
103
- print(f"Loading from ckpt: {ckpt}")
104
- if not os.path.exists(ckpt):
105
- print(f"\n\n ***** Run --task quantize (with/without lm_head) first to save quantized model ...!!! \n\n")
106
- raise SystemExit
107
- model = torch.load(ckpt)
108
-
109
- Utils.print_model_size(model)
110
- _ = gc.collect()
111
- model.eval()
112
- model = model.to(torch.bfloat16)
113
- print(model)
114
- return model, tokenizer
115
-
116
-
117
- if __name__ == "__main__":
118
- parser = argparse.ArgumentParser()
119
- parser.add_argument('--dataset', help="Dataset - wikitext2-raw-v1, wikitext2-v1", type=str, default="raw", choices=["non-raw", "raw"])
120
- parser.add_argument('--w_bit', help="weight bit size", type=int, default=3, choices=[3, 4])
121
- parser.add_argument('--awq', help="load awq scales, clips from pt or run awq", type=str, default="load", choices=["load", "run", "none"])
122
- parser.add_argument("--target", help="cpu, aie, aie_emu", type=str, default="cpu", choices=["cpu", "aie_emu", "aie"])
123
- parser.add_argument('--task', help="quantize: Apply AWQ and save ckpt; perplexity: Measure perplexity on wikitext2 dataset; benchmark: Benchmark latency w.r.t prompt length; benchmark_long: Benchmark long sequences (compare with flash attn); decode: Decode set of prompts;", type=str, default="decode", choices=["quantize", "decode", "benchmark", "benchmark_long", "perplexity"] )
124
- parser.add_argument('--flash_attention', help="Enable flash attention", action='store_true')
125
- parser.add_argument('--lm_head', help="Enable PerGrp quantization of lm_head layer", action='store_true')
126
- parser.add_argument('--num_torch_threads', help="Number of torch threads", type=int, default=8, choices=[1, 2, 3, 4, 5, 6, 7, 8])
127
- args = parser.parse_args()
128
- print(f"{args}")
129
- dev = os.getenv("DEVICE")
130
- print(f'DEVICE varibale is {dev}')
131
-
132
- if dev == "stx":
133
- p = psutil.Process()
134
- p.cpu_affinity([0, 1, 2, 3])
135
- torch.set_num_threads(args.num_torch_threads)
136
-
137
- log_dir = "./logs_awq_phi3_chat"
138
- if not os.path.exists(log_dir):
139
- os.makedirs(log_dir)
140
- log_file = log_dir + "/log_awq.log"
141
-
142
- logging.basicConfig(filename=log_file,
143
- filemode='w',
144
- format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
145
- datefmt='%H:%M:%S',
146
- level=logging.CRITICAL)
147
-
148
- model, tokenizer = load_model(args)
149
-
150
- if args.awq != "none":
151
- for n, m in model.named_modules():
152
- print(n)
153
- if isinstance(m, qlinear.QLinearPerGrp):
154
- print(f"Preparing weights of layer : {n}")
155
- m.device = "aie"
156
- m.quantize_weights()
157
-
158
- print(model)
159
- Utils.print_model_size(model)
160
-
161
- warmup(model, tokenizer)
162
-
163
- if (args.task == "decode"):
164
- decode_prompts(model, tokenizer, max_new_tokens=11)
165
- logging.shutdown()
166
- out_file = log_file.replace(".log", "_profile.csv")
167
- out_file = open(out_file, "w")
168
- ProfileAIE.analyze_profiling(False, True, log_file, out_file)
169
- out_file.close()
170
-
171
- elif (args.task == "benchmark") or (args.task == "benchmark_long"):
172
- #print(model.config.max_position_embeddings) # 2048
173
- trainloader, testenc = get_wikitext2(tokenizer, nsamples=2, seqlen=4096)
174
- if (args.task == "benchmark"):
175
- seqlens = [1,2,3,4,5,6,7, 8,9,10,60,61,62,63,64,65,510,512,513,514,515]
176
- else:
177
- seqlens = [512, 1024, 1536]
178
- input_ids = next(iter(trainloader))[0][:, :4096]
179
- for seqlen in seqlens:
180
- logging.critical("*"*40)
181
- print("*"*40)
182
- print(f"Benchmarking for {seqlen} tokens ...")
183
- input_ids_test = input_ids[:, :seqlen]
184
- decode_prompt(model, tokenizer, prompt=None, input_ids = input_ids_test, max_new_tokens=11)
185
-
186
- logging.shutdown()
187
- out_file = log_file.replace(".log", "_profile.csv")
188
- out_file = open(out_file, "w")
189
- ProfileAIE.analyze_profiling(False, True, log_file, out_file)
190
- out_file.close()
191
-
192
- elif (args.task == "perplexity"):
193
- start = time.time()
194
- perplexity(model, tokenizer, dataset=args.dataset)
195
- print(f"Time taken to measure ppl on RyzenAI: {time.time() - start}s")
 
1
+ #
2
+ # Copyright © 2023 Advanced Micro Devices, Inc. All rights reserved.
3
+ #
4
+
5
+ import torch
6
+ import logging
7
+ import time
8
+ import argparse
9
+ import os
10
+ import psutil
11
+ from transformers import set_seed
12
+ from transformers import LlamaTokenizer,AutoTokenizer
13
+
14
+ import qlinear
15
+ from utils import Utils
16
+ from model_utils import (
17
+ warmup,
18
+ decode_prompt,
19
+ decode_prompts,
20
+ get_wikitext2,
21
+ perplexity,
22
+ )
23
+ from profiler import ProfileAIE
24
+ import gc
25
+
26
+ #need to modify to the phi3 folder
27
+ from phi3_mini.modeling_phi3 import Phi3ForCausalLM
28
+
29
+ from pre_quant import run_awq, apply_awq
30
+ from quantizer import real_quantize_model_weight
31
+ from qmodule import WQLinear
32
+
33
+ set_seed(123)
34
+
35
+
36
def load_model(args):
    """Load the Phi-3-mini tokenizer and model according to the CLI args.

    Behavior by args.awq:
      - "none": load the plain bfloat16 model from ./phi3_mini.
      - otherwise, with args.task == "quantize": apply AWQ (pre-computed or
        freshly run), quantize weights, replace linear layers with
        QLinearPerGrp nodes, save the checkpoint, then exit via SystemExit.
      - otherwise: load the previously saved quantized checkpoint.

    Returns:
        (model, tokenizer) — model is in eval mode, cast to bfloat16.

    Raises:
        SystemExit: after saving AWQ results / a quantized checkpoint, or
        when the expected checkpoint file is missing.
    """

    # tokenizer = LlamaTokenizer.from_pretrained("./Phi-3-mini-4k-instruct-AWQ")
    tokenizer = AutoTokenizer.from_pretrained("./phi3_mini")
    if args.awq == "none":
        # Unquantized baseline model.
        model = Phi3ForCausalLM.from_pretrained("./phi3_mini", torch_dtype=torch.bfloat16)

    else:
        # ckpt = "pytorch_phi3_mini_w_bit_{}_awq{}_{}amd.pt".format(args.w_bit, "_fa" if args.flash_attention else "", "lm_" if args.lm_head else "")
        # NOTE(review): checkpoint name is hard-coded to 4-bit / no flash
        # attention even though --w_bit and --flash_attention are accepted —
        # confirm this is intentional.
        ckpt = "./phi3_mini_awq_4bit_no_flash_attention.pt"
        if args.task == "quantize":
            model = Phi3ForCausalLM.from_pretrained("./phi3_mini", torch_dtype=torch.bfloat16)
            print(model)

            Utils.print_model_size(model)

            q_config = {
                "zero_point": True,
                "q_group_size": 128, }  # whether to use group quantization

            if args.awq == 'load':
                # Pre-computed AWQ scales/clips. The printed AWQ_CACHE env var
                # is informational only — the path below is hard-coded.
                print("Loading pre-computed AWQ results from", os.getenv("AWQ_CACHE"))
                awq_results = torch.load("./phi-3-chat-w4-g128_awq.pt", map_location="cpu")
                apply_awq(model, awq_results)
                print("Quantization config:", q_config)
                # In-place weight quantization to args.w_bit bits.
                real_quantize_model_weight(
                    model, w_bit=args.w_bit, q_config=q_config
                )

                Utils.print_model_size(model)

                # for n, m in model.named_modules():
                #     if isinstance(m, WQLinear):
                #         print(f"AWQ Model load : {n} : {m.qweight.data.min()} {m.qweight.data.max()} {m.qweight.data.shape} {m.scales.shape} qzeros: {m.qzeros.shape} {m.qzeros.min()} {m.qzeros.max()}")

            elif args.awq == 'run':
                # Run the AWQ search from scratch, save the results, and exit;
                # a later invocation with --awq load consumes the saved file.
                awq_results = run_awq(
                    model, tokenizer,
                    w_bit=args.w_bit, q_config=q_config,
                    n_samples=128, seqlen=512,
                )
                torch.save(awq_results, "./phi3-mini-w%d-g128-generated.pt" % args.w_bit)
                print(model)
                print("Saved AWQ results in ./phi3-mini-w%d-g128-generated.pt" % args.w_bit)
                raise SystemExit

            # Swap AWQ's WQLinear modules for per-group quantized linear nodes
            # (created on CPU; moved to AIE later by the caller).
            Utils.replace_node(model,
                               WQLinear,
                               qlinear.QLinearPerGrp,
                               (), {'device': 'cpu', 'w_bit': args.w_bit, 'group_size': 128})
            print(model)
            gc.collect()

            Utils.print_model_size(model)
            if args.lm_head:  # Quantize lm_head
                # Remaining torch.nn.Linear at this point is the lm_head;
                # it gets a smaller group size (32) than the body (128).
                Utils.replace_node(model,
                                   torch.nn.Linear,
                                   qlinear.QLinearPerGrp,
                                   (), {'device': 'cpu', 'w_bit': args.w_bit, 'group_size': 32})
                print(model)
                gc.collect()

            torch.save(model, ckpt)
            print(f"Quantized and saved model: {ckpt}")
            raise SystemExit
        else:
            print(f"Loading from ckpt: {ckpt}")
            if not os.path.exists(ckpt):
                print(f"\n\n ***** Run --task quantize (with/without lm_head) first to save quantized model ...!!! \n\n")
                raise SystemExit
            # NOTE(review): full-model torch.load unpickles arbitrary code —
            # only load checkpoints produced by this script.
            model = torch.load(ckpt)

    Utils.print_model_size(model)
    _ = gc.collect()
    model.eval()
    model = model.to(torch.bfloat16)
    print(model)
    return model, tokenizer
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument('--dataset', help="Dataset - wikitext2-raw-v1, wikitext2-v1", type=str, default="raw", choices=["non-raw", "raw"])
120
+ parser.add_argument('--w_bit', help="weight bit size", type=int, default=3, choices=[3, 4])
121
+ parser.add_argument('--awq', help="load awq scales, clips from pt or run awq", type=str, default="load", choices=["load", "run", "none"])
122
+ parser.add_argument("--target", help="cpu, aie, aie_emu", type=str, default="cpu", choices=["cpu", "aie_emu", "aie"])
123
+ parser.add_argument('--task', help="quantize: Apply AWQ and save ckpt; perplexity: Measure perplexity on wikitext2 dataset; benchmark: Benchmark latency w.r.t prompt length; benchmark_long: Benchmark long sequences (compare with flash attn); decode: Decode set of prompts;", type=str, default="decode", choices=["quantize", "decode", "benchmark", "benchmark_long", "perplexity"] )
124
+ parser.add_argument('--flash_attention', help="Enable flash attention", action='store_true')
125
+ parser.add_argument('--lm_head', help="Enable PerGrp quantization of lm_head layer", action='store_true')
126
+ parser.add_argument('--num_torch_threads', help="Number of torch threads", type=int, default=8, choices=[1, 2, 3, 4, 5, 6, 7, 8])
127
+ args = parser.parse_args()
128
+ print(f"{args}")
129
+ dev = os.getenv("DEVICE")
130
+ print(f'DEVICE varibale is {dev}')
131
+
132
+ if dev == "stx":
133
+ p = psutil.Process()
134
+ p.cpu_affinity([0, 1, 2, 3])
135
+ torch.set_num_threads(args.num_torch_threads)
136
+
137
+ log_dir = "./logs_awq_phi3_chat"
138
+ if not os.path.exists(log_dir):
139
+ os.makedirs(log_dir)
140
+ log_file = log_dir + "/log_awq.log"
141
+
142
+ logging.basicConfig(filename=log_file,
143
+ filemode='w',
144
+ format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
145
+ datefmt='%H:%M:%S',
146
+ level=logging.CRITICAL)
147
+
148
+ model, tokenizer = load_model(args)
149
+
150
+ if args.awq != "none":
151
+ for n, m in model.named_modules():
152
+ print(n)
153
+ if isinstance(m, qlinear.QLinearPerGrp):
154
+ print(f"Preparing weights of layer : {n}")
155
+ m.device = "aie"
156
+ m.quantize_weights()
157
+
158
+ print(model)
159
+ Utils.print_model_size(model)
160
+
161
+ warmup(model, tokenizer)
162
+
163
+ if (args.task == "decode"):
164
+ decode_prompts(model, tokenizer, max_new_tokens=11)
165
+ logging.shutdown()
166
+ out_file = log_file.replace(".log", "_profile.csv")
167
+ out_file = open(out_file, "w")
168
+ ProfileAIE.analyze_profiling(False, True, log_file, out_file)
169
+ out_file.close()
170
+
171
+ elif (args.task == "benchmark") or (args.task == "benchmark_long"):
172
+ #print(model.config.max_position_embeddings) # 2048
173
+ trainloader, testenc = get_wikitext2(tokenizer, nsamples=2, seqlen=4096)
174
+ if (args.task == "benchmark"):
175
+ seqlens = [1,2,3,4,5,6,7, 8,9,10,60,61,62,63,64,65,510,512,513,514,515]
176
+ else:
177
+ seqlens = [512, 1024, 1536]
178
+ input_ids = next(iter(trainloader))[0][:, :4096]
179
+ for seqlen in seqlens:
180
+ logging.critical("*"*40)
181
+ print("*"*40)
182
+ print(f"Benchmarking for {seqlen} tokens ...")
183
+ input_ids_test = input_ids[:, :seqlen]
184
+ decode_prompt(model, tokenizer, prompt=None, input_ids = input_ids_test, max_new_tokens=11)
185
+
186
+ logging.shutdown()
187
+ out_file = log_file.replace(".log", "_profile.csv")
188
+ out_file = open(out_file, "w")
189
+ ProfileAIE.analyze_profiling(False, True, log_file, out_file)
190
+ out_file.close()
191
+
192
+ elif (args.task == "perplexity"):
193
+ start = time.time()
194
+ perplexity(model, tokenizer, dataset=args.dataset)
195
+ print(f"Time taken to measure ppl on RyzenAI: {time.time() - start}s")