aoiandroid committed on
Commit
af9db8f
·
verified ·
1 Parent(s): d7ac469

Upload llm-export/llm_export.py with huggingface_hub

Files changed (1)
  1. llm-export/llm_export.py +1467 -0
llm-export/llm_export.py ADDED
@@ -0,0 +1,1467 @@
+ import os
+ import base64
+ import glob
+ import shutil
+ import argparse
+ import torch
+ import numpy as np
+ from onnxslim import slim
+ import onnxruntime as ort
+ import sentencepiece as spm
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+ from peft import LoraConfig, TaskType, get_peft_model, PeftModel
+ try:
+     import _tools as MNNTools
+ except:
+     MNNTools = None
+ 
+ from llm_models.GOT.GOT_ocr_2_0 import GOTQwenForCausalLM
+ from llm_models.GOT.modeling_qwen2 import Qwen2MLP, Qwen2RMSNorm, Qwen2Attention
+ def onnx2mnn(onnx_path, mnn_dir, quant_bit = 4, asymmetric = True, external_data = False, bizCode: str = None):
+     model_name, model_extension = os.path.splitext(os.path.basename(onnx_path))
+     if model_extension != '.onnx':
+         return
+     mnn_name = model_name + '.mnn'
+     mnn_path = os.path.join(mnn_dir, mnn_name)
+     convert_args = [
+         '',
+         '-f',
+         'ONNX',
+         '--modelFile',
+         str(onnx_path),
+         '--MNNModel',
+         str(mnn_path),
+         '--weightQuantBits',
+         str(quant_bit),
+     ]
+     if asymmetric:
+         convert_args.append("--weightQuantAsymmetric")
+     if external_data:
+         convert_args.append("--saveExternalData")
+     if bizCode is not None:
+         convert_args.append("--bizCode")
+         convert_args.append(str(bizCode))
+     MNNTools.mnnconvert(convert_args)
+ 
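+ # Illustrative call (a sketch, assuming the MNN `_tools` wheel that provides
+ # `mnnconvert` is installed); it converts one exported ONNX file into a 4-bit,
+ # asymmetrically quantized MNN model next to the other exports:
+ #   onnx2mnn('./onnx/lm.onnx', './mnn', quant_bit=4, asymmetric=True)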
+ # some wrapper class for export
+ class Embedding(torch.nn.Module):
+     def __init__(self, embed, using_bf16: bool = False):
+         super().__init__()
+         self.bf16 = using_bf16
+         self.embed_dim = embed.weight.shape[-1]
+         if using_bf16:
+             # using bf16 embedding weight
+             self.embed = embed.bfloat16()
+         else:
+             self.embed = embed
+ 
+     def forward(self, input_ids):
+         res = self.embed(input_ids)
+         if self.bf16:
+             res = res.float()
+         return res.view(-1, 1, self.embed_dim)
+ 
+ class GOTEmbedding(torch.nn.Module):
+     def __init__(self, embed, using_bf16: bool = False):
+         super().__init__()
+         self.bf16 = using_bf16
+         self.embed_dim = embed.weight.shape[-1]
+         if using_bf16:
+             # using bf16 embedding weight
+             self.embed = embed.bfloat16()
+         else:
+             self.embed = embed
+ 
+     def forward(self, input_ids):
+         res = self.embed(input_ids)
+         if self.bf16:
+             res = res.float()
+         return res.view(1, -1, self.embed_dim)
+ 
+ class Lm(torch.nn.Module):
+     def __init__(self, lm):
+         super().__init__()
+         self.lm = lm
+ 
+     def forward(self, hidden_states):
+         m_logits = self.lm(hidden_states)
+         #token = torch.argmax(m_logits)
+         return m_logits
+ 
+ class LLM(torch.nn.Module):
+     '''
+     Base class for all llm models. Inherits from [`torch.nn.Module`].
+     '''
+ 
+     def __init__(self, args):
+         super().__init__()
+         self.quant_bit = 4
+         self.asymmetric = True
+         self.onnx_path = args.onnx_path
+         self.mnn_path = args.mnn_path
+         if not os.path.exists(self.onnx_path):
+             os.makedirs(self.onnx_path)
+         if not os.path.exists(self.mnn_path):
+             os.makedirs(self.mnn_path)
+         self.export_mnn = args.export_mnn
+         self.export_verbose = args.export_verbose
+         self.export_test = args.export_test
+         # default is False; only set True when using the command below:
+         # `python llm_export ../path --export --embed_bin` to export a single model without embedding
+         self.without_embed = False
+         self.embed_bin = args.embed_bin
+         if self.embed_bin:
+             self.embed_bf16 = True
+         else:
+             self.embed_bf16 = args.embed_bf16
+         self.skip_slim = args.skip_slim
+         tokenizer_model = os.path.join(args.path, 'tokenizer.model')
+         if os.path.exists(tokenizer_model):
+             self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
+         else:
+             self.sp_model = None
+         merge_file = os.path.join(args.path, 'merges.txt')
+         if os.path.exists(merge_file):
+             self.merge_txt = merge_file
+         else:
+             self.merge_txt = None
+         self.stop_ids = []
+         self.max_length = 1024
+         self.hidden_size = 4096
+         self.visual = None # default is not visual
+         self.lora_path = args.lora_path
+         self.load_hf(args.path)
+         self.load_model()
+ 
+     def load_hf(self, model_path: str):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+         self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+         self.config = self.model.config
+         if self.lora_path is not None:
+             adapter = PeftModel.from_pretrained(self.model, model_id=self.lora_path)
+             self.model = adapter.merge_and_unload(progressbar=True)
+ 
+     def load_model(self):
+         raise NotImplementedError
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         raise NotImplementedError
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         raise NotImplementedError
+ 
+     def export_vocab(self):
+         raise NotImplementedError
+ 
+     def visual_embed(self, input_ids):
+         raise NotImplementedError
+ 
+     def __embedding(self, input_ids):
+         if self.visual is not None and self.token_len == 0:
+             input_embeds = self.visual_embed(input_ids)
+         else:
+             input_embeds = self.embed(input_ids)
+         return input_embeds
+ 
+     def __decode(self, hidden_states, attention_mask, position_ids, past_key_values):
+         presents = []
+         for i in range(self.block_nums):
+             hidden_states, kv = self.blocks[i](hidden_states, attention_mask, position_ids, past_key_values[i])
+             presents.append(kv)
+         token_id = self.lm(hidden_states).view(1)
+         presents = torch.stack(presents)
+         self.seq_len += 1
+         self.token_len += 1
+         return token_id, presents
+ 
+     def forward(self, input_ids, attention_mask, position_ids, past_key_values):
+         if self.without_embed:
+             return self.__decode(input_ids, attention_mask, position_ids, past_key_values)
+         return self.__decode(self.__embedding(input_ids), attention_mask, position_ids, past_key_values)
+ 
+     # some test functions
+     def build_prompt(self, query):
+         if hasattr(self.tokenizer, 'build_prompt'):
+             prompt = self.tokenizer.build_prompt(query)
+         else:
+             prompt = query
+         return prompt
+ 
+     def str_to_ids(self, prompt):
+         input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
+         return input_ids
+ 
+     def id_to_str(self, token_id):
+         word = self.tokenizer._convert_id_to_token(int(token_id))
+         word = self.tokenizer.convert_tokens_to_string([word])
+         return word
+ 
+     def response(self, query):
+         prompt = self.build_prompt(query)
+         input_ids = self.str_to_ids(prompt)
+         self.seq_len = input_ids.numel()
+         self.context_len = self.seq_len - 2
+         self.token_len = 0
+         past_key_values = [None for i in range(self.block_nums)]
+         token_id = input_ids
+         while self.token_len < self.max_length:
+             attention_mask = self.get_attention_mask()
+             position_ids = self.get_position_ids()
+             token_id, past_key_values = self.forward(token_id, attention_mask, position_ids, past_key_values)
+             if token_id == self.stop_id or token_id in self.stop_ids:
+                 print("", end='\n')
+                 break
+             word = self.id_to_str(token_id)
+             print(word, end="", flush=True)
+ 
+     # some export functions
+     def assert_equal(self, torch_outs, onnx_outs):
+         if type(torch_outs) not in (list, tuple):
+             torch_outs = (torch_outs, )
+             onnx_outs = (onnx_outs, )
+         same = True
+         for orig, onnx in zip(torch_outs, onnx_outs):
+             orig = orig.detach().numpy()
+             if not np.allclose(orig, onnx, rtol=1e-3, atol=1e-3):
+                 print('Error: onnx outputs dont match original. [shape = {}] onnx: {}, original: {}'.format(onnx.shape, onnx, orig))
+                 same = False
+                 break
+         if same:
+             print('onnx test SUCCESS')
+ 
+     def export_lm(self):
+         model = self.lm
+         hidden_states = torch.randn(1, self.hidden_size)
+         onnx_model = f'./{self.onnx_path}/lm.onnx'
+         torch.onnx.export(model, (hidden_states),
+                           onnx_model,
+                           verbose=self.export_verbose,
+                           input_names=['hidden_states'],
+                           output_names=['token_id'],
+                           do_constant_folding=True,
+                           opset_version=15)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+         # test lm
+         if self.export_test:
+             original_outs = model(hidden_states)
+             ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider'])
+             inputs = {
+                 'hidden_states' : hidden_states.numpy(),
+             }
+             onnx_outs = ort_session.run(None, inputs)
+             self.assert_equal(original_outs, onnx_outs)
+         if self.export_mnn:
+             onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric)
+ 
+     def export_visual(self):
+         if self.visual is None:
+             return
+         input_images = torch.randn((1, 3, self.image_size, self.image_size))
+         model = self.visual
+         onnx_model = f'./{self.onnx_path}/visual.onnx'
+         torch.onnx.export(model, (input_images),
+                           onnx_model,
+                           verbose=self.export_verbose,
+                           input_names=['input_images'],
+                           output_names=['image_embeds'],
+                           dynamic_axes={"input_images": {
+                               0: "size"
+                           }},
+                           do_constant_folding=True,
+                           opset_version=15)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+         # test
+         if self.export_test:
+             original_outs = model(input_images)
+             ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider'])
+             inputs = {
+                 'input_images' : input_images.numpy(),
+             }
+             onnx_outs = ort_session.run(None, inputs)[0]
+             self.assert_equal(original_outs, onnx_outs)
+         if self.export_mnn:
+             onnx2mnn(onnx_model, self.mnn_path)
+ 
+     def export_embed(self):
+         model = self.embed
+         if self.embed_bin:
+             import ctypes
+             tensor_data = model.embed.weight.data
+             data_ptr = tensor_data.untyped_storage().data_ptr()
+             buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr)
+             with open(f'./{self.mnn_path}/embeddings_bf16.bin', 'wb') as f:
+                 f.write(buffer)
+             return
+         input_ids = torch.arange(3, dtype=torch.long)
+         onnx_model = f'./{self.onnx_path}/embedding.onnx'
+         torch.onnx.export(model, (input_ids),
+                           onnx_model,
+                           verbose=self.export_verbose,
+                           input_names=['input_ids'],
+                           output_names=['inputs_embeds'],
+                           dynamic_axes={"input_ids": {
+                               0: "length"
+                           }},
+                           do_constant_folding=True,
+                           opset_version=15)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+         # test
+         if self.export_test:
+             original_outs = model(input_ids)
+             ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider'])
+             inputs = {
+                 'input_ids' : input_ids.numpy(),
+             }
+             onnx_outs = ort_session.run(None, inputs)
+             self.assert_equal(original_outs, onnx_outs)
+         if self.export_mnn:
+             onnx2mnn(onnx_model, self.mnn_path)
+ 
+     def export_block(self, block_id: int):
+         self.seq_len = 3
+         self.token_len = 0
+         inputs_embeds = torch.randn((1, self.seq_len, self.hidden_size))
+         attention_mask = self.get_attention_mask()
+         position_ids = self.get_position_ids()
+         past_key_cache = torch.randn((1, self.num_key_value_heads, 0, self.hidden_size // self.num_key_value_heads)) # e.g. torch.Size([1, 16, 286, 64])
+         past_value_cache = torch.randn((1, self.num_key_value_heads, 0, self.hidden_size // self.num_key_value_heads))
+         model = self.blocks[block_id]
+         onnx_model = f'./{self.onnx_path}/block_{block_id}.onnx'
+         # every decode iteration carries its own past KV cache
+         torch.onnx.export(
+             model, (inputs_embeds, attention_mask, position_ids, past_key_cache, past_value_cache),
+             onnx_model,
+             verbose=self.export_verbose,
+             input_names=[
+                 'inputs_embeds', 'attention_mask', 'position_ids', 'past_key_cache', 'past_value_cache'
+             ],
+             output_names=['hidden_states', 'past_key_states', 'past_value_states'],
+             dynamic_axes= {
+                 "inputs_embeds" : { 1: "seq_len" },
+                 "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+                 "position_ids" : { 1: "seq_len" },
+                 "past_key_cache" : { 2: "seq_len" },
+                 "past_value_cache" : { 2: "seq_len" },
+                 "hidden_states" : { 1: "seq_len" },
+                 "past_key_states" : { 2: "seq_len" },
+                 "past_value_states" : { 2: "seq_len" },
+             },
+             opset_version=17)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+         if self.export_test:
+             original_outs = model(inputs_embeds, attention_mask, position_ids)
+             ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider'])
+             inputs = {
+                 'inputs_embeds' : inputs_embeds.detach().numpy(),
+                 'attention_mask' : attention_mask.numpy(),
+                 'position_ids' : position_ids.numpy(),
+             }
+             onnx_outs = ort_session.run(None, inputs)
+             self.assert_equal(original_outs, onnx_outs)
+         if self.export_mnn:
+             onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric)
+ 
+     def export_blocks(self):
+         for i in range(self.block_nums):
+             self.export_block(i)
+ 
+     def export(self):
+         model = self
+         self.seq_len = 3
+         self.token_len = 0
+         input_ids = torch.arange(3, dtype=torch.long)
+         attention_mask = self.get_attention_mask()
+         position_ids = self.get_position_ids()
+         past_key_values = torch.zeros(self.past_kv_shape)
+         onnx_model = f'./{self.onnx_path}/llm.onnx'
+         if self.embed_bin:
+             self.without_embed = True
+             input_ids = self.__embedding(input_ids)
+         print('export start ...')
+         torch.onnx.export(
+             model, (input_ids, attention_mask, position_ids, past_key_values),
+             onnx_model,
+             verbose=self.export_verbose,
+             input_names=[
+                 'input_ids', 'attention_mask', 'position_ids', 'past_key_values'
+             ],
+             output_names=['token_id', 'presents'],
+             dynamic_axes=self.model_dynamic_axes,
+             do_constant_folding=True,
+             opset_version=15)
+         print('export done!')
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+         if self.export_test:
+             # test
+             original_outs = model(input_ids, attention_mask, position_ids, past_key_values)
+             ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider'])
+             inputs = {
+                 'input_ids' : input_ids.detach().numpy(),
+                 'attention_mask' : attention_mask.numpy(),
+                 'position_ids' : position_ids.numpy(),
+                 'past_key_values' : past_key_values.numpy()
+             }
+             onnx_outs = ort_session.run(None, inputs)
+             self.assert_equal(original_outs, onnx_outs)
+         if self.export_mnn:
+             # single model is > 2G, using external_data
+             onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric, True)
+         if self.without_embed:
+             self.without_embed = False
+ 
+     def export_tokenizer(self):
+         file_path = os.path.join(self.onnx_path, "tokenizer.txt")
+         if self.sp_model is not None:
+             # sentencepiece
+             print('# sentencepiece tokenizer')
+             NORMAL = 1; UNKNOWN = 2; CONTROL = 3
+             USER_DEFINED = 4; UNUSED = 5; BYTE = 6
+             fp = open(file_path, "w", encoding="utf8")
+             for i in range(self.sp_model.GetPieceSize()):
+                 token = self.sp_model.IdToPiece(i)
+                 score = self.sp_model.GetScore(i)
+                 type = NORMAL
+                 if self.sp_model.IsUnknown(i):
+                     type = UNKNOWN
+                 elif self.sp_model.IsControl(i):
+                     type = CONTROL
+                 elif self.sp_model.IsUnused(i):
+                     type = UNUSED
+                 elif self.sp_model.IsByte(i):
+                     type = BYTE
+                 if self.model_name == 'Chatglm_6b':
+                     if '<n>' in token: token = '\n'
+                     if '<|tab|>' in token: token = '\t'
+                     if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')])
+                 if '▁' in token: token = token.replace('▁', ' ')
+                 token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8")
+                 fp.write(f'{token_encode} {score} {type}\n')
+             fp.close()
+         elif hasattr(self.tokenizer, 'mergeable_ranks'):
+             print('# tiktoken tokenizer')
+             # tiktoken
+             with open(file_path, "w", encoding="utf8") as fp:
+                 for k, v in self.tokenizer.mergeable_ranks.items():
+                     line = base64.b64encode(k).decode("utf8") + "\n"
+                     fp.write(line)
+                 if hasattr(self.tokenizer, 'special_tokens'):
+                     for k, v in self.tokenizer.special_tokens.items():
+                         line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n"
+                         fp.write(line)
+         elif self.merge_txt is not None:
+             # huggingface tokenizer
+             merge_list = []
+             vocab = self.tokenizer.get_vocab()
+             vocab_list = ['<unk>' for i in range(len(vocab))]
+             # load vocab
+             for k, v in vocab.items():
+                 vocab_list[int(v)] = k
+             # load merge
+             with open(self.merge_txt, 'rt') as merge:
+                 for line in merge.readlines():
+                     merge_list.append(line)
+             # write to tokenizer.txt
+             with open(file_path, "w", encoding="utf8") as fp:
+                 fp.write(f'{len(vocab_list)} {len(merge_list)}\n')
+                 for v in vocab_list:
+                     fp.write(v + '\n')
+                 for m in merge_list:
+                     fp.write(m)
+         else:
+             # huggingface tokenizer
+             def unicode_to_byte(u: int):
+                 if u >= 256 and u <= 288:
+                     return u - 256
+                 if u >= 289 and u <= 322:
+                     return u - 162
+                 if u == 323:
+                     return 173
+                 if u == 65372: # |
+                     return 124
+                 if u == 9601: # _
+                     return 95
+                 return u
+             with open(file_path, "w", encoding="utf8") as fp:
+                 vocab = self.tokenizer.get_vocab()
+                 vocab_list = ['<unk>' for i in range(len(vocab))]
+                 for k, v in vocab.items():
+                     try:
+                         vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]).decode('utf-8', errors='ignore')
+                     except:
+                         vocab_list[int(v)] = k
+                 for v in vocab_list:
+                     line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n"
+                     fp.write(line)
+ 
+ 
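+ # Note on the tokenizer.txt layout produced above: the sentencepiece branch writes one
+ # `base64(token) score type` triple per line; the tiktoken and plain huggingface branches
+ # write one base64-encoded token per line; the merges.txt branch writes a
+ # `<vocab_size> <merge_count>` header followed by the vocab and then the merge rules.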
+ # chatglm
+ class GLMBlock(torch.nn.Module):
+     def __init__(self, block, block_id, final_layernorm = None):
+         super().__init__()
+         self.block = block
+         self.block_id = block_id
+         self.final_layernorm = final_layernorm
+         self.hidden_size = 4096  # ChatGLM-6B hidden size, used when reshaping the last block's output
+ 
+     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+         hidden_states, presents = self.block(hidden_states,
+                                              position_ids,
+                                              attention_mask,
+                                              self.block_id,
+                                              past_kv,
+                                              use_cache=True)
+         if self.final_layernorm is not None:
+             hidden_states = self.final_layernorm(hidden_states)
+             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+         if isinstance(presents, tuple):
+             presents = torch.stack(presents)
+         return hidden_states, presents
+ 
+ class Chatglm_6b(LLM):
+     def __init__(self, args):
+         super().__init__(args)
+         self.model_name = 'Chatglm_6b'
+ 
+     def load_model(self):
+         transformer = self.model.transformer
+         self.lm_ = self.model.lm_head
+         self.embed_ = transformer.word_embeddings
+         self.blocks_ = transformer.layers
+         self.final_layernorm_ = transformer.final_layernorm
+         # some wrapper
+         self.stop_id = self.tokenizer._convert_token_to_id(self.tokenizer.eos_token)
+         self.block_nums = len(self.blocks_)
+         self.lm = Lm(self.lm_)
+         # chatglm embedding and lm using same param, copy embedding when using bf16
+         if self.embed_bf16:
+             import copy
+             embed_copy = copy.deepcopy(self.embed_)
+             self.embed = Embedding(embed_copy, self.embed_bf16)
+         else:
+             self.embed = Embedding(self.embed_, self.embed_bf16)
+         self.blocks = [GLMBlock(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+         # some config for export
+         self.past_kv_shape = [28, 2, 0, 1, 32, 128]
+         self.block_dynamic_axes = {
+             "inputs_embeds" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 2: "seq_len" },
+             "past_key_values" : { 1: "history_len" }
+         }
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 2: "seq_len" },
+             "past_key_values" : { 2: "history_len" }
+         }
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.zeros([1]).bool().reshape([1, 1, 1, 1])
+         attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool)
+         for i in range(self.seq_len):
+             attention_mask[i][-1] = True
+         attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len])
+         return attention_mask
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.tensor([1, self.seq_len - self.context_len]).reshape([1, 2, 1])
+         position_ids_0 = torch.arange(self.seq_len, dtype=torch.long)
+         position_ids_1 = torch.zeros(self.seq_len, dtype=torch.long)
+         position_ids_1[-1] = 1
+         position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1)
+         return position_ids
+ 
+ # chatglm2
+ class GLM2Block(torch.nn.Module):
+     def __init__(self, block, block_id, final_layernorm = None):
+         super().__init__()
+         self.block = block
+         self.block_id = block_id
+         self.final_layernorm = final_layernorm
+         self.hidden_size = 4096
+ 
+     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+         theta = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
+         position_ids = position_ids.float().reshape(-1, 1)
+         idx_theta = position_ids * theta
+         rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1).unsqueeze(0).contiguous()
+         hidden_states, presents = self.block(hidden_states,
+                                              attention_mask,
+                                              kv_cache=past_kv,
+                                              rotary_pos_emb=rotary_pos_emb)
+         if self.final_layernorm is not None:
+             hidden_states = self.final_layernorm(hidden_states)
+             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+         if isinstance(presents, tuple):
+             presents = torch.stack(presents)
+         return hidden_states, presents
+ 
+ class Chatglm2_6b(LLM):
+     def __init__(self, args):
+         super().__init__(args)
+         self.model_name = 'Chatglm2_6b'
+         if 'codegeex2-6b' in args.path:
+             self.model_name = 'Codegeex2_6b'
+ 
+     def load_model(self):
+         transformer = self.model.transformer
+         self.lm_ = transformer.output_layer
+         self.embed_ = transformer.embedding.word_embeddings
+         self.blocks_ = transformer.encoder.layers
+         self.final_layernorm_ = transformer.encoder.final_layernorm
+         # some wrapper
+         self.stop_id = self.tokenizer.eos_token_id
+         if self.stop_id is None:
+             # codegeex2-6b
+             self.stop_id = self.tokenizer.tokenizer.eos_id
+         self.block_nums = len(self.blocks_)
+         self.embed = Embedding(self.embed_, self.embed_bf16)
+         self.lm = Lm(self.lm_)
+         self.blocks = [GLM2Block(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+         # some config for export
+         self.past_kv_shape = [28, 2, 0, 1, 2, 128]
+         self.block_dynamic_axes = {
+             "inputs_embeds" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 1: "history_len" }
+         }
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 2: "history_len" }
+         }
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.zeros([1, 1, 1, 1]).bool()
+         attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool())
+         return attention_mask
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.tensor([self.token_len], dtype=torch.long)
+         return torch.arange(self.seq_len, dtype=torch.long)
+ 
+ # chatglm3
+ class Chatglm3_6b(Chatglm2_6b):
+     def __init__(self, args):
+         super().__init__(args)
+         self.model_name = 'Chatglm3_6b'
+ 
+     def build_prompt(self, query):
+         return f'<|user|>\n{query}\n<|assistant|>\n'
+ 
+ # qwen
+ class QWENBlock(torch.nn.Module):
+     def __init__(self, name, block, block_id, hidden_size, final_layernorm = None):
+         super().__init__()
+         self.name = name
+         self.block = block
+         self.block_id = block_id
+         self.final_layernorm = final_layernorm
+         self.hidden_size = hidden_size
+ 
+     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+         theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
+         position_ids = position_ids.float().reshape(-1, 1)
+         idx_theta = position_ids * theta
+         rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1)
+         rotary_pos_emb = rotary_pos_emb.unsqueeze(1).unsqueeze(0)
+         if self.name != 'Qwen-7B':
+             rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)])
+         hidden_states = hidden_states.view(1, -1, self.hidden_size)
+         hidden_states, presents = self.block(hidden_states=hidden_states,
+                                              layer_past=past_kv,
+                                              attention_mask=attention_mask,
+                                              rotary_pos_emb=rotary_pos_emb,
+                                              use_cache=True)
+         if self.final_layernorm is not None:
+             hidden_states = self.final_layernorm(hidden_states)
+             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+         if isinstance(presents, tuple):
+             presents = torch.stack(presents)
+         return hidden_states, presents
+ 
+ class QWEN18Block(torch.nn.Module):
+     def __init__(self, block, block_id, hidden_size, final_layernorm = None):
+         super().__init__()
+         self.block = block
+         self.block_id = block_id
+         self.final_layernorm = final_layernorm
+         self.hidden_size = hidden_size
+ 
+     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+         theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
+         position_ids = position_ids.float().reshape(-1, 1)
+         idx_theta = position_ids * theta
+         rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1).unsqueeze(1).unsqueeze(0)
+         rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)])
+         hidden_states = hidden_states.view(1, -1, self.hidden_size)
+         hidden_states, presents = self.block(hidden_states,
+                                              rotary_pos_emb,
+                                              past_kv,
+                                              attention_mask,
+                                              use_cache=True)
+         if self.final_layernorm is not None:
+             hidden_states = self.final_layernorm(hidden_states)
+             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+         if isinstance(presents, tuple):
+             presents = torch.stack(presents)
+         return hidden_states, presents
+ 
+ class Qwen_Chat(LLM):
+     def __init__(self, args):
+         super().__init__(args)
+ 
+     def load_model(self):
+         # Qwen models
+         self.model_name = 'Qwen-7B'
+         if '1_8' in model_path:  # relies on the module-level `model_path` set in __main__
+             self.model_name = 'Qwen-1_8b'
+         if 'VL' in model_path:
+             self.model_name = 'Qwen-VL'
+         transformer = self.model.transformer
+         self.lm_ = self.model.lm_head
+         self.embed_ = transformer.wte
+         self.blocks_ = transformer.h
+         self.final_layernorm_ = transformer.ln_f
+         if hasattr(transformer, 'visual'):
+             self.visual = transformer.visual
+             self.image_start_id = transformer.config.visual['image_start_id']
+             self.image_size = transformer.config.visual['image_size']
+         # some wrapper
+         self.stop_id = self.tokenizer.im_end_id
+         self.block_nums = len(self.blocks_)
+         self.hidden_size = transformer.embed_dim
+         self.embed = Embedding(self.embed_, self.embed_bf16)
+         self.lm = Lm(self.lm_)
+         self.blocks = [QWENBlock(self.model_name, self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+         if self.block_nums == 32:
+             # qwen-7b, qwen-vl
+             self.past_kv_shape = [32, 2, 1, 0, 32, 128]
+         elif self.block_nums == 24:
+             # qwen-1.8b
+             self.past_kv_shape = [24, 2, 1, 0, 16, 128]
+         # some config for export
+         self.block_dynamic_axes = {
+             "inputs_embeds" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 2: "history_len" }
+         }
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 3: "history_len" }
+         }
+ 
+     def build_prompt(self, query):
+         return f'\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         if self.model_name == 'Qwen-VL':
+             if self.token_len:
+                 return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
+             return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
+         if self.token_len:
+             return torch.ones([1, 1, 1, 1]).bool()
+         return torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool())
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.tensor([self.seq_len - 1], dtype=torch.long)
+         return torch.arange(self.seq_len, dtype=torch.long)
+ 
+     def visual_embed(self, input_ids):
+         if not torch.any(input_ids == self.image_start_id):
+             return self.embed(input_ids)
+         bos_pos = torch.where(input_ids == self.image_start_id)
+         eos_pos = torch.where(input_ids == self.image_start_id + 1)
+         img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
+         images = []
+         for i, a, b in img_pos:
+             image = input_ids[i][a + 1 : b - 1].tolist()
+             image = image[ : image.index(self.image_start_id + 2)]
+             images.append(bytes(image).decode('utf-8'))
+         images = self.visual.encode(images)
+         hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size)
+         for idx, (i, a, b) in enumerate(img_pos):
+             hidden_states[i][a + 1 : b] = images[idx]
+         return hidden_states.view(-1, 1, self.hidden_size)
+ 
+ 
+ class Qwen2DecoderLayer(torch.nn.Module):
+     def __init__(self, config, block, layer_idx: int):
+         super().__init__()
+         self.block = block
+         # self.hidden_size = config.hidden_size
+         self.self_attn = Qwen2Attention(config, layer_idx)
+         # load weights
+         self.self_attn.load_state_dict(block.self_attn.state_dict())
+         self.mlp = self.block.mlp
+         self.input_layernorm = self.block.input_layernorm
+         self.post_attention_layernorm = self.block.post_attention_layernorm
+         # self.mlp = Qwen2MLP(config)
+         # self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         # self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask,
+         position_ids,
+         past_key_cache=None,
+         past_value_cache=None
+     ):
+         residual = hidden_states
+ 
+         hidden_states = self.input_layernorm(hidden_states)
+ 
+         # Self Attention
+         hidden_states, past_key_states, past_value_states = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_cache=past_key_cache,
+             past_value_cache=past_value_cache
+         )
+         hidden_states = residual + hidden_states
+ 
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+ 
+         return hidden_states, past_key_states, past_value_states
+         #hidden_states = self.block(hidden_states, attention_mask, position_ids)
+         #return hidden_states
+ 
+ 
+ class QWEN2Block(torch.nn.Module):
+     def __init__(self, name, block, block_id, config, final_layernorm = None):
+         super().__init__()
+         self.name = name
+         self.block = block
+         self.block_id = block_id
+         self.final_layernorm = final_layernorm
+         self.hidden_size = config.hidden_size
+         self.head_dim = config.hidden_size // config.num_attention_heads
+         self.rope_theta = config.rope_theta
+ 
+     def forward(self, hidden_states, attention_mask, position_ids):
+         theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim))
+         position_ids = position_ids.float().reshape(-1, 1)
+         idx_theta = position_ids * theta
+         rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1)
+         rotary_pos_emb = rotary_pos_emb.unsqueeze(1).unsqueeze(0)
+         rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)])
+         hidden_states = hidden_states.view(1, -1, self.hidden_size)
+         hidden_states = self.block(hidden_states=hidden_states,
+                                    attention_mask=attention_mask,
+                                    #past_key_value=past_kv,
+                                    rotary_pos_emb=rotary_pos_emb,
+                                    #use_cache=True
+                                    )
+         if self.final_layernorm is not None:
+             hidden_states = self.final_layernorm(hidden_states[0])
+             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+         # print('###', presents.shape)
+         return hidden_states
+ 
+ class Qwen2_Chat(LLM):
+     def __init__(self, args):
+         super().__init__(args)
+ 
+     def load_model(self):
+         # Qwen2 models
+         self.model_name = 'Qwen2'
+         transformer = self.model.model
+         self.lm_ = self.model.lm_head
+         self.embed_ = transformer.embed_tokens
+         self.blocks_ = transformer.layers
+         self.final_layernorm_ = transformer.norm
+         # some wrapper
+         self.stop_id = self.tokenizer.eos_token_id
+         if hasattr(self.model, 'generation_config'):
+             self.stop_ids.append(self.stop_id)
+             for id in self.model.generation_config.eos_token_id:
+                 self.stop_ids.append(id)
+         self.block_nums = self.config.num_hidden_layers
+         self.hidden_size = self.config.hidden_size
+         self.num_heads = self.config.num_attention_heads
+         self.rope_theta = self.config.rope_theta
+         self.head_dim = self.hidden_size // self.num_heads
+         if self.embed_.weight is self.lm_.weight:
+             import copy
+             embed_copy = copy.deepcopy(self.embed_)
+             self.embed = Embedding(embed_copy, self.embed_bf16)
+         else:
+             self.embed = Embedding(self.embed_, self.embed_bf16)
+         self.lm = Lm(self.lm_)
+         self.past_kv_shape = [self.block_nums, 2, 1, 0, self.num_heads, self.head_dim]
+         self.blocks = [QWEN2Block(self.model_name, self.blocks_[i], i, self.config, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+         # some config for export
+         self.block_dynamic_axes = {
+             "inputs_embeds" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 1: "history_len" }
+         }
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 2: "history_len" }
+         }
+ 
+     def build_prompt(self, query):
+         return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
+         return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+         return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+ 
+     def visual_embed(self, input_ids):
+         if not torch.any(input_ids == self.image_start_id):
+             return self.embed(input_ids)
+         bos_pos = torch.where(input_ids == self.image_start_id)
+         eos_pos = torch.where(input_ids == self.image_start_id + 1)
+         img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
+         images = []
+         for i, a, b in img_pos:
+             image = input_ids[i][a + 1 : b - 1].tolist()
+             image = image[ : image.index(self.image_start_id + 2)]
+             images.append(bytes(image).decode('utf-8'))
+         images = self.visual.encode(images)
+         hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size)
+         for idx, (i, a, b) in enumerate(img_pos):
+             hidden_states[i][a + 1 : b] = images[idx]
+         return hidden_states.view(-1, 1, self.hidden_size)
+ 
+ # llama2
+ class LLAMA2Block(torch.nn.Module):
+     def __init__(self, block, block_id, hidden_size, final_layernorm = None):
+         super().__init__()
+         self.block = block
+         self.block_id = block_id
+         self.final_layernorm = final_layernorm
+         self.hidden_size = hidden_size
+ 
+     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+         hidden_states = hidden_states.view(1, -1, self.hidden_size)
+         hidden_states, presents = self.block(hidden_states,
+                                              attention_mask,
+                                              position_ids,
+                                              past_kv,
+                                              use_cache=True)
+         if self.final_layernorm is not None:
+             hidden_states = self.final_layernorm(hidden_states)
+             hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
+         if isinstance(presents, tuple):
+             presents = torch.stack(presents)
+         return hidden_states, presents
+ 
+ class Llama2_7b_Chat(LLM):
+     def __init__(self, args):
+         self.model_name = 'Llama2_7b'
+         if 'Baichuan2' in args.path:
+             self.model_name = 'Baichuan2_7B'
+         if 'internlm' in args.path:
+             self.model_name = 'Internlm_7b'
+         if 'TinyLlama' in args.path:
+             self.model_name = 'TinyLlama'
+         if 'Yi' in args.path:
+             self.model_name = 'Yi'
+         if 'deepseek' in args.path:
+             self.model_name = 'deepseek'
+         if 'Llama-3' in args.path:
+             self.model_name = 'Llama3_8B'
+         super().__init__(args)
+ 
+     def load_model(self):
+         self.config = self.model.config
+         transformer = self.model.model
+         self.lm_ = self.model.lm_head
+         self.embed_ = transformer.embed_tokens
+         self.blocks_ = transformer.layers
+         self.final_layernorm_ = transformer.norm
+         # some wrapper
+         self.hidden_size = self.embed_.weight.shape[-1]
+         self.stop_id = self.tokenizer.eos_token_id
+         if hasattr(self.model, 'generation_config'):
+             self.stop_ids.append(self.stop_id)
+             self.stop_ids.append(self.model.generation_config.eos_token_id)
+         if self.model_name == 'Llama3_8B':
+             self.stop_ids.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
+         self.block_nums = len(self.blocks_)
+         self.embed = Embedding(self.embed_, self.embed_bf16)
+         self.lm = Lm(self.lm_)
+         self.blocks = [LLAMA2Block(self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+         self.block_nums = self.config.num_hidden_layers
+         self.hidden_size = self.config.hidden_size
+         self.num_attention_heads = self.config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_attention_heads
+         self.num_key_value_heads = self.config.num_key_value_heads
+         self.past_kv_shape = [self.block_nums, 2, 1, self.num_key_value_heads, 0, self.head_dim]
+         self.block_dynamic_axes = {
+             "inputs_embeds" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 3: "history_len" }
+         }
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 4: "history_len" }
+         }
+ 
+     def build_prompt(self, query):
+         if 'Baichuan2' in self.model_name:
+             return f'<reserved_106>{query}<reserved_107>'
+         if 'Internlm_7b' in self.model_name:
+             return f'<|User|>:{query}<eoh>\n<|Bot|>:'
+         if 'TinyLlama' in self.model_name:
+             return f'<s><|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{query}</s>\n<|assistant|>\n'
+         if 'Yi' in self.model_name:
+             return f'<|im_start|> user\n{query}<|im_end|>\n<|im_start|> assistant\n'
+         if 'deepseek' in self.model_name:
+             return f'<|begin▁of▁sentence|>User: {query}\nAssistant:'
+         if 'Llama3' in self.model_name:
+             return f'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
+         return f'[INST]{query}[/INST]'
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32)
+         return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+         return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+ 
+ # phi-2
+ class PHI2Block(torch.nn.Module):
+     def __init__(self, block, block_id, hidden_size):
+         super().__init__()
+         self.block = block
+         self.block_id = block_id
+         self.hidden_size = hidden_size
+ 
+     def forward(self, hidden_states, attention_mask, position_ids, past_kv):
+         theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32))
+         position_ids = position_ids.float().reshape(-1, 1)
+         idx_theta = position_ids * theta
+         rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous()
+         hidden_states = hidden_states.view(1, -1, self.hidden_size)
+         hidden_states, presents = self.block(hidden_states,
+                                              past_kv,
+                                              rotary_pos_emb=rotary_pos_emb,
+                                              causal_mask=attention_mask
+                                              )
+         if self.block_id == 31:
+             hidden_states = hidden_states[:, -1, :]
+         return hidden_states, presents
+ 
+ class phi_2(LLM):
+     def __init__(self, args):
+         super().__init__(args)
+         self.model_name = 'phi-2'
+         self.asymmetric = False # TODO: some precision bug when using asymmetric
+ 
+     def load_model(self):
+         transformer = self.model.transformer
+         self.lm_ = self.model.lm_head
+         self.embed_ = transformer.embd.wte
+         self.hidden_size = self.embed_.weight.shape[-1]
+         self.blocks_ = transformer.h
+         # self.final_layernorm_ = transformer.final_layernorm
+         # some wrapper
+         self.stop_id = self.tokenizer.eos_token_id
+         self.block_nums = len(self.blocks_)
+         self.embed = Embedding(self.embed_, self.embed_bf16)
+         self.lm = Lm(self.lm_)
+         self.blocks = [PHI2Block(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)]
+         # some config for export
+         self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80]
+         self.block_dynamic_axes = {
+             "inputs_embeds" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 1: "history_len" }
+         }
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+             "past_key_values" : { 2: "history_len" }
+         }
+ 
+     def build_prompt(self, query):
+         return f'Instruct: {query}\nOutput:'
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.zeros([1, 1, 1, 1]).bool()
+         attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool())
+         return attention_mask
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         if self.token_len:
+             return torch.tensor([[self.seq_len - 1]], dtype=torch.long)
+         return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+ 
+ # BGE is an embedding model based on BERT
+ class BGEBlock(torch.nn.Module):
+     def __init__(self, block, block_id, hidden_size):
+         super().__init__()
+         self.block = block
+         self.block_id = block_id
+         self.hidden_size = hidden_size
+ 
+     def forward(self, hidden_states, attention_mask):
+         hidden_states = self.block(hidden_states, attention_mask)[0]
+         return hidden_states
+ 
+ class bge(LLM):
+     def __init__(self, args):
+         super().__init__(args)
+         self.model_name = 'bge-large-zh'
+ 
+     def forward(self, input_ids, position_ids, attention_mask):
+         input_ids = input_ids.view(1, -1)
+         token_type_ids = (1 - attention_mask).view(1, -1)
+         hidden_states = self.embed(input_ids, token_type_ids, position_ids)[0].unsqueeze(0)
+         for i in range(self.block_nums):
+             hidden_states = self.blocks[i](hidden_states, attention_mask)
+         # hidden_states = self.lm(hidden_states) # sentence_embeddings not need
+         sentence_embeddings = hidden_states[:, 0]
+         sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+         return sentence_embeddings
+ 
+     def response(self, query):
+         self.eval()
+         input_ids = self.tokenizer(query)['input_ids']
+         self.seq_len = len(input_ids)
+         input_ids = torch.tensor(input_ids)
+         position_ids = self.get_position_ids()
+         attention_mask = self.get_attention_mask()
+         res = self.forward(input_ids, position_ids, attention_mask)
+         return res
+ 
+     def load_model(self):
+         transformer = self.model.encoder
+         self.lm_ = self.model.pooler
+         self.embed_ = self.model.embeddings
+         self.hidden_size = self.embed_.word_embeddings.weight.shape[-1]
+         self.blocks_ = transformer.layer
+         # some wrapper
+         self.stop_id = self.tokenizer.eos_token_id
+         self.block_nums = len(self.blocks_)
+         self.embed = self.embed_
+         self.lm = self.lm_
+         self.blocks = [BGEBlock(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)]
+         # some config for export
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "position_ids" : { 1: "seq_len" },
+             "attention_mask" : { 3: "seq_len" }
+         }
+ 
+     def export(self):
+         model = self.eval()
+         self.seq_len = 3
+         input_ids = torch.arange(3, dtype=torch.long)
+         position_ids = self.get_position_ids()
+         attention_mask = self.get_attention_mask()
+         onnx_model = f'./{self.onnx_path}/bge.onnx'
+         torch.onnx.export(
+             model, (input_ids, position_ids, attention_mask),
+             onnx_model,
+             verbose=self.export_verbose,
+             input_names=[
+                 'input_ids',
+                 'position_ids',
+                 'attention_mask'
+             ],
+             output_names=['sentence_embeddings'],
+             dynamic_axes=self.model_dynamic_axes,
+             do_constant_folding=True,
+             opset_version=15)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+         if self.export_test:
+             self.seq_len = 4
+             position_ids = self.get_position_ids()
+             input_ids = torch.tensor([ 101, 872, 1962, 102 ], dtype=torch.long)
+             attention_mask = self.get_attention_mask()
+             # test
+             original_outs = model(input_ids, position_ids, attention_mask)
+             ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider'])
+             inputs = {
+                 'input_ids' : input_ids.detach().numpy(),
+                 'position_ids' : position_ids.detach().numpy(),
+                 'attention_mask' : attention_mask.detach().numpy()
+             }
+             onnx_outs = ort_session.run(None, inputs)[0]
+             self.assert_equal(original_outs, onnx_outs)
+ 
+         token_str = None
+         if False: # save tokenizer in mnn
+             self.export_tokenizer()
+             token_path = os.path.join(self.onnx_path, "tokenizer.txt")
+             token_str = open(token_path, 'rt').read()
+ 
+         if self.export_mnn:
+             onnx2mnn(onnx_model, self.mnn_path, 8, True, bizCode=token_str)
+ 
+     def get_position_ids(self) -> torch.Tensor:
+         return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0)
+ 
+     def get_attention_mask(self) -> torch.Tensor:
+         return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)
+ 
+ class LoraModule(torch.nn.Module):
+     def __init__(self, args):
+         super().__init__()
+         self.onnx_path = args.onnx_path
+         self.mnn_path = args.mnn_path
+         self.export_mnn = args.export_mnn
+         import peft
+         lora_weight = peft.load_peft_weights(args.path)
+         for k, v in lora_weight.items():
+             k = k.replace('.', '/')
+             self.register_buffer(k, v.cpu())
+ 
+     def forward(self, dummy):
+         return self._buffers
+ 
+     def export(self):
+         onnx_model = f'./{self.onnx_path}/lora.onnx'
+         torch.onnx.export(self.eval(), torch.tensor([]), onnx_model)
+         if self.export_mnn:
+             onnx2mnn(onnx_model, self.mnn_path)
+ 
+ class GOT(Qwen2_Chat):
+     def __init__(self, args):
+         super().__init__(args)
+     def load_hf(self, model_path: str):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+         self.model = GOTQwenForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval()
+         self.config = self.model.config
+         if self.lora_path is not None:
+             adapter = PeftModel.from_pretrained(self.model, model_id=self.lora_path)
+             self.model = adapter.merge_and_unload(progressbar=True)
+     def load_model(self):
+         # Qwen2 models
+         self.model_name = 'GOT'
+         transformer = self.model.model
+         self.lm_ = self.model.lm_head
+         self.embed_ = transformer.embed_tokens
+         self.blocks_ = transformer.layers
+         self.final_layernorm_ = transformer.norm
+         self.visual = transformer.vision_tower_high
+         self.mm_projector_vary = transformer.mm_projector_vary
+         # some wrapper
+         self.stop_id = self.tokenizer.eos_token_id
+         if hasattr(self.model, 'generation_config'):
+             #self.stop_ids.append(self.stop_id)
+             #for id in self.model.generation_config.eos_token_id:
+             self.stop_ids.append(self.model.generation_config.eos_token_id)
+         self.block_nums = self.config.num_hidden_layers
+         self.hidden_size = self.config.hidden_size
+         self.image_size = self.hidden_size
+         self.image_token_len = self.config.image_token_len
+         self.num_heads = self.config.num_attention_heads
+         self.num_key_value_heads = self.config.num_key_value_heads
+         self.rope_theta = self.config.rope_theta
+         self.head_dim = self.hidden_size // self.num_heads
+         if self.embed_.weight is self.lm_.weight:
+             import copy
+             embed_copy = copy.deepcopy(self.embed_)
+             self.embed = GOTEmbedding(embed_copy, self.embed_bf16)
+         else:
+             self.embed = GOTEmbedding(self.embed_, self.embed_bf16)
+         self.lm = Lm(self.lm_)
+         self.past_kv_shape = [self.block_nums, 2, 1, 0, self.num_heads, self.head_dim]
+         #self.blocks = [QWEN2Block(self.model_name, self.blocks_[i], i, self.config, None) for i in range(self.block_nums)]
+         self.blocks = [Qwen2DecoderLayer(self.config, self.blocks_[i], i) for i in range(self.block_nums)]
+         #self.blocks = self.blocks_
+         # some config for export
+         self.block_dynamic_axes = {
+             "inputs_embeds" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+         }
+         self.model_dynamic_axes = {
+             "input_ids" : { 0: "seq_len" },
+             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
+             "position_ids" : { 0: "seq_len" },
+         }
+     def export_lm(self):
+         model = self.lm
+         hidden_states = torch.randn(1, self.hidden_size)
+         onnx_model = f'./{self.onnx_path}/lm.onnx'
+         torch.onnx.export(model, (hidden_states),
+                           onnx_model,
+                           verbose=self.export_verbose,
+                           input_names=['hidden_states'],
+                           output_names=['token_id'],
+                           do_constant_folding=True,
+                           dynamic_axes={
+                               "hidden_states" : { 0: "seq_len" }
+                           },
+                           opset_version=15)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+     def export_norm(self):
+         model = self.final_layernorm_
+         hidden_states = torch.randn(1, self.image_token_len, self.hidden_size)
+         onnx_model = f'./{self.onnx_path}/norm.onnx'
+         torch.onnx.export(model, (hidden_states),
+                           onnx_model,
+                           verbose=self.export_verbose,
+                           input_names=['hidden_in'],
+                           output_names=['hidden_out'],
+                           do_constant_folding=True,
+                           dynamic_axes={
+                               "hidden_in" : { 1: "seq_len" },
+                               "hidden_out" : { 1: "seq_len" },
+                           },
+                           opset_version=15)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+     def export_projector_vary(self):
+         model = self.mm_projector_vary
+         hidden_states = torch.randn(1, self.image_token_len, self.hidden_size)
+         onnx_model = f'./{self.onnx_path}/mm_projector_vary.onnx'
+         torch.onnx.export(model, (hidden_states),
+                           onnx_model,
+                           verbose=self.export_verbose,
+                           input_names=['cnn_features'],
+                           output_names=['img_features'],
+                           do_constant_folding=True,
+                           opset_version=15)
+         if not self.skip_slim:
+             slim(onnx_model, output_model=onnx_model)
+ 
+ if __name__ == '__main__':
+     llm_models = {
+         'chatglm-6b': Chatglm_6b,
+         'chatglm2-6b': Chatglm2_6b,
+         'chatglm3-6b': Chatglm3_6b,
+         'codegeex2-6b': Chatglm2_6b,
+         'Qwen-7B-Chat': Qwen_Chat,
+         'Qwen-1_8B-Chat': Qwen_Chat,
+         'Qwen-1_8B': Qwen_Chat,
+         'Qwen-VL-Chat': Qwen_Chat,
+         'Qwen1_5-0_5B-Chat': Qwen2_Chat,
+         'Qwen1_5-1_8B-Chat': Qwen2_Chat,
+         'Qwen1_5-4B-Chat': Qwen2_Chat,
+         'Qwen1_5-7B-Chat': Qwen2_Chat,
+         'Baichuan2-7B-Chat': Llama2_7b_Chat,
+         'Llama-2-7b-chat-ms': Llama2_7b_Chat,
+         'Llama-3-8B-Instruct': Llama2_7b_Chat,
+         'internlm-chat-7b': Llama2_7b_Chat,
+         'TinyLlama-1_1B-Chat': Llama2_7b_Chat,
+         'Yi-6B-Chat': Llama2_7b_Chat,
+         'deepseek-llm-7b-chat': Llama2_7b_Chat,
+         'phi-2': phi_2,
+         'bge-large-zh': bge,
+         'lora': LoraModule,
+         'GOT': GOT
+     }
+     parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
+     parser.add_argument('--path', type=str, default=r'D:\LearningCodes\GithubRepo\shouxieAI\GOT-OCR2.0\GOT-OCR-2.0-master\GOT_weights',
+                         help='path(`str` or `os.PathLike`):\nCan be either:'
+                              '\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]'
+                              '\n\t- A path to a *directory* cloned from a repo like `../chatglm-6b`.')
+     parser.add_argument('--type', type=str, choices=llm_models.keys(), default="GOT",
+                         help='type(`str`, *optional*):'
+                              '\n\tThe pretrained llm model type.'
+                         )
+     parser.add_argument('--lora_path', type=str, default=None, help='lora path, default is `None`, meaning lora is not applied.')
+     parser.add_argument('--onnx_path', type=str, default='./onnx', help='export onnx model path, default is `./onnx`.')
+     parser.add_argument('--mnn_path', type=str, default='./mnn', help='export mnn model path, default is `./mnn`.')
+     parser.add_argument('--export_mnn', action='store_true', default=False, help='Whether or not to export mnn model after onnx.')
+     parser.add_argument('--export_verbose', action='store_true', default=False, help='Whether or not to export onnx with verbose.')
+     parser.add_argument('--export_test', action='store_true', help='Whether or not to export onnx with test using onnxruntime.')
+     parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
+     parser.add_argument('--export', action='store_true', help='export model to an `onnx` model.')
+     parser.add_argument('--export_split', default=True,
+                         help='export the model split into several `onnx` models:'
+                              '\n\t- embedding model.'
+                              '\n\t- block models.'
+                              '\n\t- lm_head model.'
+                         )
+     parser.add_argument('--export_token', action='store_true', help='export llm tokenizer to a txt file.')
+     parser.add_argument('--export_embed', action='store_true', help='export llm embedding to an `onnx` model.')
+     parser.add_argument('--export_visual', action='store_true', help='export llm visual model to an `onnx` model.')
+     parser.add_argument('--export_lm', action='store_true', help='export llm lm_head to an `onnx` model.')
+     parser.add_argument('--export_block', type=int, help='export llm block [id] to an `onnx` model.')
+     parser.add_argument('--export_blocks', action='store_true', help='export llm all blocks to `onnx` models.')
+     parser.add_argument('--embed_bin', action='store_true', help='export embedding weight as a bin file with dtype `bfloat16`.')
+     parser.add_argument('--embed_bf16', action='store_true', help='use `bfloat16` instead of `float32` in embedding.')
+     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
+ 
+ 
+     args = parser.parse_args()
+     model_path = args.path
+     model_type = args.type
+     # if model type is not specified, infer it from the path
+     if model_type is None:
+         for model in llm_models:
+             if model in model_path:
+                 model_type = model
+     if model_type is None:
+         raise RuntimeError('Please specify model type.')
+ 
+     # # copy modeling py file to pretrain model for export
+     # for file in glob.glob(f'./llm_models/{model_type}/*'):
+     #     shutil.copy2(file, model_path)
+ 
+     llm_exporter = llm_models[model_type](args)
+ 
+     # some actions
+     if args.test is not None:
+         llm_exporter.response(args.test)
+ 
+     if args.export:
+         llm_exporter.export()
+ 
+     if args.export_token:
+         llm_exporter.export_tokenizer()
+ 
+     if args.export_embed or args.export_split:
+         llm_exporter.export_embed()
+ 
+     if args.export_visual or args.export_split:
+         llm_exporter.export_visual()
+ 
+     if args.export_lm or args.export_split:
+         llm_exporter.export_lm()
+         llm_exporter.export_projector_vary()
+         llm_exporter.export_norm()
+ 
+     if args.export_blocks or args.export_split:
+         llm_exporter.export_blocks()
+ 
+     if args.export_block is not None:
+         llm_exporter.export_block(args.export_block)
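+ 
+ # Illustrative invocations (a sketch; the flag names come from the argparse definitions
+ # above, while the model paths are placeholders for a locally downloaded checkpoint):
+ #   python llm_export.py --path ./GOT_weights --type GOT --export_mnn
+ #   python llm_export.py --path ../Qwen1_5-1_8B-Chat --type Qwen1_5-1_8B-Chat --test "Hello"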