TalkUHulk committed on
Commit 432d085 · verified · 1 Parent(s): 25008f7

Upload 24 files
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00322597.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00322772.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00322846.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00322901.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00323167.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00326764.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00357438.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00358590.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00363692.jpg filter=lfs diff=lfs merge=lfs -text
+ example_images/objects365_v1_00367221.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,172 @@
+ import gradio as gr
+ from PIL import Image
+ import spaces
+ from tinymind import *
+
+ tokenizer_path = "./custom_tokenizer"
+ tokenizer = load_tokenizer(tokenizer_path)
+ preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)
+
+ special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)
+
+ vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
+ embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
+ llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)
+ freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
+
+ @spaces.GPU
+ def model_inference(
+     input_dict, history, decoding_strategy, temperature, max_new_tokens,
+     repetition_penalty, top_p
+ ):
+     # NOTE: decoding_strategy, temperature, repetition_penalty and top_p are
+     # accepted from the UI but not used yet: generate_autoregressive below
+     # decodes greedily (argmax).
+     text = input_dict["text"]
+
+     if len(input_dict["files"]) > 1:
+         images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
+     elif len(input_dict["files"]) == 1:
+         images = [Image.open(input_dict["files"][0]).convert("RGB")]
+     else:
+         images = []
+
+     # No image in the current turn: fall back to the most recent turn in the
+     # history that carried files.
+     if not images and history:
+         for turn in reversed(history):
+             files, _ = turn  # (user files, assistant text)
+             if isinstance(files, tuple) and len(files) > 0:
+                 images = [Image.open(image).convert("RGB") for image in files]
+                 break
+
+     if text == "" and not images:
+         raise gr.Error("Please input a query and optionally image(s).")
+
+     if text == "" and images:
+         raise gr.Error("Please input a text query along with the image(s).")
+
+     if not images:
+         # prepare_image_patches below needs at least one image
+         raise gr.Error("Please upload an image with your query.")
+
+     pixel_values, mask_positions = prepare_image_patches(images[0], preprocess, max_rows=4, max_cols=4)
+
+     # Build the prompt with image placeholders (assumes the tokenizer supports apply_chat_template)
+     messages = [
+         {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
+         {"role": "user", "content": text + construct_image_placeholders(special_tokens)}
+     ]
+     inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]
+
+     # prefill
+     seqlen = input_ids.shape[1]
+     prefill_out = prefill_llm(
+         vision_session=vision_session,
+         embed_tokens_session=embed_tokens_session,
+         llm_session=llm_session,
+         pixel_values=pixel_values,
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         freqs_cos=freqs_cos,
+         freqs_sin=freqs_sin,
+         special_tokens=special_tokens,
+         seqlen=seqlen
+     )
+
+     # start token id = argmax over the last prefill logit
+     start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
+
+     generation_args = {
+         "llm_session": llm_session,
+         "embed_tokens_session": embed_tokens_session,
+         "tokenizer": tokenizer,
+         "initial_present": {"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
+         "start_token_id": start_token_id,
+         "freqs_cos": freqs_cos,
+         "freqs_sin": freqs_sin,
+         "attention_mask": attention_mask.numpy(),
+         "max_new_tokens": max_new_tokens,
+         "eos_token_id": 2,
+         "start_pos": seqlen
+     }
+
+     # generate_autoregressive is a generator: iterate it and stream each
+     # partial transcript back to the chat UI. (Passing a generator function
+     # as a Thread target would only create the generator, never run it.)
+     for partial in generate_autoregressive(**generation_args):
+         yield partial
+
+
+ examples = [
+     [{"text": "描述下图片的内容",
+       "files": ["example_images/objects365_v1_00322846.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
+     [{"text": "请描述这张图片的内容,并检测其中的苹果",
+       "files": ["example_images/objects365_v1_00361740.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
+     [{"text": "图中是什么交通工具?",
+       "files": ["example_images/objects365_v1_00357438.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
+     [{"text": "图中有几只鸭子?",
+       "files": ["example_images/objects365_v1_00323167.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
+     [{"text": "这是在哪?",
+       "files": ["example_images/objects365_v1_00363692.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
+ ]
+ demo = gr.ChatInterface(
+     fn=model_inference,
+     title="SmolVLM2-256M-Married-Qwen3-0.6B: SmolVLM拥抱Qwen3,支持中文问答🤖",
+     description="[TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B](https://huggingface.co/TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B) 演示。请上传图片和文本,或尝试下方示例。",
+     examples=examples,
+     textbox=gr.MultimodalTextbox(
+         label="请输入查询文本(附带图片)",
+         file_types=["image"],
+         file_count="multiple"
+     ),
+     stop_btn="停止生成",
+     multimodal=True,
+     additional_inputs=[
+         gr.Radio(
+             ["Top P Sampling", "Greedy"],
+             value="Greedy",
+             label="解码策略",
+             info="选择生成文本的方式:采样更随机,贪心更确定。"
+         ),
+         gr.Slider(
+             minimum=0.0,
+             maximum=5.0,
+             value=0.4,
+             step=0.1,
+             interactive=True,
+             label="采样温度 (Temperature)",
+             info="数值越高,输出越多样化;越低则更保守。"
+         ),
+         gr.Slider(
+             minimum=8,
+             maximum=1024,
+             value=512,
+             step=1,
+             interactive=True,
+             label="最大生成 Token 数",
+         ),
+         gr.Slider(
+             minimum=0.01,
+             maximum=5.0,
+             value=1.2,
+             step=0.01,
+             interactive=True,
+             label="重复惩罚 (Repetition penalty)",
+             info="1.0 表示不做惩罚;数值越大越避免重复。"
+         ),
+         gr.Slider(
+             minimum=0.01,
+             maximum=0.99,
+             value=0.8,
+             step=0.01,
+             interactive=True,
+             label="Top P",
+             info="数值越高,表示会采样更多低概率的 token。"
+         ),
+     ],
+     cache_examples=False
+ )
+
+ demo.launch(debug=True)
custom_tokenizer/chat_template.jinja ADDED
@@ -0,0 +1,96 @@
+ {%- if tools %}
+     {{- '<|im_start|>system\n' }}
+     {%- if messages[0].role == 'system' %}
+         {{- messages[0].content + '\n\n' }}
+     {%- endif %}
+     {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+     {%- for tool in tools %}
+         {{- "\n" }}
+         {{- tool | tojson }}
+     {%- endfor %}
+     {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+ {%- else %}
+     {%- if messages[0]['role'] == 'system' -%}
+         {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
+     {%- else -%}
+         {{- '<|im_start|>system\n你是一个多模态AI助手,能够理解图片和文本信息。<|im_end|>\n' }}
+     {%- endif %}
+ {%- endif %}
+
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+ {%- for message in messages[::-1] %}
+     {%- set index = (messages|length - 1) - loop.index0 %}
+     {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+         {%- set ns.multi_step_tool = false %}
+         {%- set ns.last_query_index = index %}
+     {%- endif %}
+ {%- endfor %}
+
+ {%- for message in messages %}
+     {#- Handle message content: supports string, list (text + image), and other formats #}
+     {%- if message.content is string %}
+         {%- set content = message.content %}
+     {%- elif message.content is iterable %}
+         {#- Multi-part content (text + images) #}
+         {%- set content_parts = [] %}
+         {%- for part in message.content %}
+             {%- if part.type == 'text' %}
+                 {%- set _ = content_parts.append(part.text) %}
+             {%- elif part.type == 'image' %}
+                 {#- Image placeholder; the actual image data is handled by the processor #}
+                 {%- set _ = content_parts.append('<image>') %}
+             {%- endif %}
+         {%- endfor %}
+         {%- set content = content_parts | join('\n') %}
+     {%- else %}
+         {%- set content = '' %}
+     {%- endif %}
+
+     {#- User messages, or system messages that are not the leading one #}
+     {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+         {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
+
+     {#- Assistant messages #}
+     {%- elif message.role == "assistant" %}
+         {{- '<|im_start|>' + message.role + '\n' + content }}
+         {%- if message.tool_calls %}
+             {%- for tool_call in message.tool_calls %}
+                 {%- if (loop.first and content) or (not loop.first) %}
+                     {{- '\n' }}
+                 {%- endif %}
+                 {%- if tool_call.function %}
+                     {%- set tool_call = tool_call.function %}
+                 {%- endif %}
+                 {{- '<tool_call>\n{"name": "' }}
+                 {{- tool_call.name }}
+                 {{- '", "arguments": ' }}
+                 {%- if tool_call.arguments is string %}
+                     {{- tool_call.arguments }}
+                 {%- else %}
+                     {{- tool_call.arguments | tojson }}
+                 {%- endif %}
+                 {{- '}\n</tool_call>' }}
+             {%- endfor %}
+         {%- endif %}
+         {{- '<|im_end|>\n' }}
+
+     {#- Tool messages (consecutive ones are grouped into a single user turn) #}
+     {%- elif message.role == "tool" %}
+         {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+             {{- '<|im_start|>user' }}
+         {%- endif %}
+         {{- '\n<tool_response>\n' }}
+         {{- content }}
+         {{- '\n</tool_response>' }}
+         {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+             {{- '<|im_end|>\n' }}
+         {%- endif %}
+     {%- endif %}
+ {%- endfor %}
+
+ {%- if add_generation_prompt %}
+     {{- '<|im_start|>assistant\n' }}
+     {%- if enable_thinking is defined and enable_thinking is false %}
+         {{- '<think>\n\n</think>\n\n' }}
+     {%- endif %}
+ {%- endif %}
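
A minimal rendering check for this template (illustrative; assumes the custom_tokenizer files above are downloaded locally and that load_tokenizer from tinymind.py below is importable):

    from tinymind import load_tokenizer
    tok = load_tokenizer("./custom_tokenizer")
    msgs = [{"role": "user", "content": "你好"}]
    print(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
    # With no leading system message, the else-branch above emits the default
    # '<|im_start|>system\n你是一个多模态AI助手...' header, then the user turn,
    # then '<|im_start|>assistant\n' because add_generation_prompt=True.
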
custom_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<|im_start|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
custom_tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
custom_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,195 @@
+ {
+   "add_bos_token": false,
+   "add_eos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<fake_token_around_image>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "<global-img>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "5": {
+       "content": "<image>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "6": {
+       "content": "<row_1_col_1>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "7": {
+       "content": "<row_1_col_2>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "8": {
+       "content": "<row_1_col_3>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "9": {
+       "content": "<row_1_col_4>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "10": {
+       "content": "<row_2_col_1>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "11": {
+       "content": "<row_2_col_2>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "12": {
+       "content": "<row_2_col_3>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "13": {
+       "content": "<row_2_col_4>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "14": {
+       "content": "<row_3_col_1>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "15": {
+       "content": "<row_3_col_2>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "16": {
+       "content": "<row_3_col_3>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "17": {
+       "content": "<row_3_col_4>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "18": {
+       "content": "<row_4_col_1>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "19": {
+       "content": "<row_4_col_2>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "20": {
+       "content": "<row_4_col_3>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "21": {
+       "content": "<row_4_col_4>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [],
+   "bos_token": "<|im_start|>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "extra_special_tokens": {},
+   "legacy": true,
+   "model_max_length": 32768,
+   "pad_token": "<|endoftext|>",
+   "sp_model_kwargs": {},
+   "spaces_between_special_tokens": false,
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "unk_token": "<|endoftext|>"
+ }
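
The added tokens above pin the multimodal markers to the first 22 vocabulary ids, which tinymind.py's find_indices relies on (it hard-codes the patterns [3, 6..21] and [3, 4]). A quick sanity check (illustrative; assumes the files above are in ./custom_tokenizer):

    from tinymind import load_tokenizer
    tok = load_tokenizer("./custom_tokenizer")
    assert tok.convert_tokens_to_ids("<fake_token_around_image>") == 3
    assert tok.convert_tokens_to_ids("<global-img>") == 4
    assert tok.convert_tokens_to_ids("<row_1_col_1>") == 6
    assert tok.convert_tokens_to_ids("<row_4_col_4>") == 21
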
example_images/objects365_v1_00322597.jpg ADDED

Git LFS Details

  • SHA256: dfbc7e2998294d23000530a5ecd3cfdc956d41f80f036214b1485249b86c3304
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
example_images/objects365_v1_00322772.jpg ADDED

Git LFS Details

  • SHA256: bc5e3358c4d508544dd9cd70476d7fe84acfe17e9b2c472ddad05d232fe120e6
  • Pointer size: 131 Bytes
  • Size of remote file: 188 kB
example_images/objects365_v1_00322846.jpg ADDED

Git LFS Details

  • SHA256: ec0d97d15a3a0fb2cd1fe8027c0754c0fe2bdbe85a4d7ee23f970ca49e5fe624
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
example_images/objects365_v1_00322901.jpg ADDED

Git LFS Details

  • SHA256: 80a60ed10df005c29fd99c0eaa8f248e901dd96e3726c4949611f099e7b03e4e
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
example_images/objects365_v1_00323167.jpg ADDED

Git LFS Details

  • SHA256: c896af496b1128fc6917d58777d0b8a1301a8fe20528c92245b36ef001e146e2
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
example_images/objects365_v1_00324105.jpg ADDED
example_images/objects365_v1_00324441.jpg ADDED
example_images/objects365_v1_00326764.jpg ADDED

Git LFS Details

  • SHA256: 3f9bc79f5eb2582a57ac3f89b6af9d44d58cadcaa447b77c29138b7994b16717
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
example_images/objects365_v1_00336365.jpg ADDED
example_images/objects365_v1_00357438.jpg ADDED

Git LFS Details

  • SHA256: 421b704b0583140db781d2e2d1127daca907109be26d1fbd841d3be38c3f2a8f
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
example_images/objects365_v1_00358590.jpg ADDED

Git LFS Details

  • SHA256: 0e928a76a67d701a07177cace504fdd999d4b52e64844898a0659b2cd7e69610
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB
example_images/objects365_v1_00361740.jpg ADDED
example_images/objects365_v1_00363692.jpg ADDED

Git LFS Details

  • SHA256: bffd9d842a677d6c7680a212e4d1a69d1c14e66d8163e7157a360566d23cb4f4
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
example_images/objects365_v1_00367221.jpg ADDED

Git LFS Details

  • SHA256: cec012fc2d0fb5ebcca7aa3e47d5cc2228adab1e8475d5c03911690fe8165c5f
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
onnx_model/embed_tokens.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69098c878e4ff056b0724e327f139981b63116f425d8dd94b28b7ce79fccaf8c
+ size 13107407
onnx_model/llm.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3a50d292a275361d4e19d54a4bb59fd27efe55434958b0f5a007538b07d3bc43
+ size 103619420
onnx_model/vision_encoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:718953645f203de3d12396ca071f5fb1af2c54fa94ca21926235e797b965e807
+ size 254698566
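
The three graphs wire together as vision encoder -> token embedder -> decoder with a KV cache. A quick way to confirm the input/output names that tinymind.py binds against (illustrative, once the files are downloaded; the expected names below come from the run() calls in prefill_llm and generate_autoregressive):

    import onnxruntime
    sess = onnxruntime.InferenceSession("./onnx_model/llm.onnx")
    print([i.name for i in sess.get_inputs()])
    # expected: input_ids, attention_mask, cos_pe, sin_pe, past_keys, past_values
    print([o.name for o in sess.get_outputs()])
    # expected: logits, hidden_states, present_keys, present_values
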
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ accelerate
+ huggingface_hub
+ transformers==4.51.3
+ spaces
+ onnxruntime
+ torchvision
+ torch
+ numpy
+ Pillow
tinymind.py ADDED
@@ -0,0 +1,522 @@
+ import math
+ import random
+ from typing import Optional, Tuple, List, Dict, Any
+ import os
+ import numpy as np
+ import torch
+ from PIL import Image
+ import torchvision
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ import onnxruntime
+
+ # Uses transformers' AutoTokenizer (recommended)
+ from transformers import AutoTokenizer
+
+ # ---------------------------
+ # Config / default parameters
+ # ---------------------------
+
+ DEFAULT_IMAGE_SIZE = 224
+ DEFAULT_MAX_ROWS = 4
+ DEFAULT_MAX_COLS = 4
+ MIN_BLOCK_SIZE = 16
+
+
+ # ---------------------------
+ # Tokenizer / Preprocess
+ # ---------------------------
+
+ def load_tokenizer(tokenizer_path: str):
+     """Load the tokenizer (AutoTokenizer)."""
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+     if tokenizer.chat_template is None:
+         # Newer transformers versions load chat_template.jinja automatically;
+         # the training environment used 4.51.3, which supports this fallback.
+         with open(os.path.join(tokenizer_path, "chat_template.jinja"), "r") as f:
+             tokenizer.chat_template = f.read()
+     return tokenizer
+
+
+ def build_image_preprocess(image_size: int = DEFAULT_IMAGE_SIZE):
+     """Return a torchvision.transforms.Compose preprocessing callable."""
+     return Compose([
+         Resize(size=image_size, interpolation=torchvision.transforms.InterpolationMode.BICUBIC, max_size=None, antialias=True),
+         CenterCrop(size=(image_size, image_size)),
+         lambda img: img.convert("RGB"),
+         ToTensor(),
+         Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                   std=(0.26862954, 0.26130258, 0.27577711))
+     ])
+
+
+ # ---------------------------
+ # RoPE frequency precomputation (precompute_freqs_cis)
+ # ---------------------------
+
+ def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
+                          rope_scaling: Optional[dict] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Compute the RoPE cos and sin tables.
+
+     Returns:
+         freqs_cos: (end, dim)
+         freqs_sin: (end, dim)
+     """
+     freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     if rope_scaling is not None:
+         orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
+         factor = rope_scaling.get("factor", 4)
+         beta_fast = rope_scaling.get("beta_fast", 4.0)
+         beta_slow = rope_scaling.get("beta_slow", 1.0)
+
+         if end / orig_max > 1.0:
+             corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
+             power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
+             beta = beta_slow + (beta_fast - beta_slow) * power
+             scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim,
+                                 (beta * factor - beta + 1) / (beta * factor),
+                                 1.0 / factor)
+             freqs = freqs * scale
+
+     t = torch.arange(end, device=freqs.device)
+     freqs = torch.outer(t, freqs).float()  # (end, dim/2)
+     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
+     freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
+     return freqs_cos, freqs_sin
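+
+
+ def _rope_shape_check():
+     # Illustrative sanity check (not called anywhere): with dim=64 the
+     # half-spectrum holds 32 frequencies, and the torch.cat above duplicates
+     # it, so both tables come out as (end, dim) = (32768, 64) -- the shapes
+     # that app.py slices per position into cos_pe/sin_pe.
+     cos_t, sin_t = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
+     assert cos_t.shape == sin_t.shape == (32768, 64)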
+
+
+ # ---------------------------
+ # Adaptive image splitting (adaptive_square_split)
+ # ---------------------------
+
+ def calculate_optimal_split_with_fixed_max(width: int, height: int, max_rows: int, max_cols: int) -> Tuple[int, int, int]:
+     """
+     Compute the best split (returns rows, cols, block_size).
+     block_size is rounded down to a multiple of 16, with a floor of MIN_BLOCK_SIZE.
+     """
+     best_rows = 1
+     best_cols = 1
+     best_block_size = 0
+     best_coverage = 0.0
+
+     # Scheme 1: fix the row count at max_rows, adapt the column count
+     rows_fixed = max_rows
+     for cols in range(1, max_cols + 1):
+         block_width = width // cols
+         block_height = height // rows_fixed
+         square_size = min(block_width, block_height)
+         if square_size <= 0:
+             continue
+         coverage = (cols * square_size) * (rows_fixed * square_size) / (width * height)
+         if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
+             best_rows, best_cols, best_block_size, best_coverage = rows_fixed, cols, square_size, coverage
+
+     # Scheme 2: fix the column count at max_cols, adapt the row count
+     cols_fixed = max_cols
+     for rows in range(1, max_rows + 1):
+         block_width = width // cols_fixed
+         block_height = height // rows
+         square_size = min(block_width, block_height)
+         if square_size <= 0:
+             continue
+         coverage = (cols_fixed * square_size) * (rows * square_size) / (width * height)
+         if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
+             best_rows, best_cols, best_block_size, best_coverage = rows, cols_fixed, square_size, coverage
+
+     # Scheme 3: both dimensions at their maximum
+     block_width = width // max_cols
+     block_height = height // max_rows
+     square_size = min(block_width, block_height)
+     if square_size > 0:
+         coverage = (max_cols * square_size) * (max_rows * square_size) / (width * height)
+         if coverage > best_coverage or (coverage == best_coverage and square_size > best_block_size):
+             best_rows, best_cols, best_block_size, best_coverage = max_rows, max_cols, square_size, coverage
+
+     # Align down to a multiple of 16 and enforce the minimum
+     best_block_size = max(MIN_BLOCK_SIZE, (best_block_size // 16) * 16)
+     return best_rows, best_cols, best_block_size
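+
+
+ def _split_worked_example():
+     # Worked example (illustrative, not called anywhere): for a 640x480 image
+     # capped at a 4x4 grid, scheme 2 fixes cols=4 (block_width=160) and finds
+     # rows=3 (block_height=160), tiling the frame exactly (coverage 1.0);
+     # 160 is already a multiple of 16, so the alignment step leaves it as is.
+     assert calculate_optimal_split_with_fixed_max(640, 480, 4, 4) == (3, 4, 160)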
+
+
+ def adaptive_square_split(image: Image.Image, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
+                           ) -> Tuple[List[Image.Image], int, int, int]:
+     """
+     Adaptively split a PIL Image into square blocks; returns (blocks_list, rows, cols, block_size).
+     blocks_list holds the blocks in row-major order (possibly fewer than max_rows * max_cols).
+     """
+     width, height = image.size
+     rows, cols, block_size = calculate_optimal_split_with_fixed_max(width, height, max_rows, max_cols)
+
+     blocks = []
+     for r in range(rows):
+         for c in range(cols):
+             left = c * block_size
+             upper = r * block_size
+             right = left + block_size
+             lower = upper + block_size
+             blocks.append(image.crop((left, upper, right, lower)))
+
+     return blocks, rows, cols, block_size
+
+
+ # ---------------------------
+ # Special-token preparation
+ # ---------------------------
+
+ def prepare_special_tokens(tokenizer, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS) -> Dict[str, int]:
+     """
+     Return a dict of special token ids: <global-img>, <fake_token_around_image>, <image>, and every <row_i_col_j>.
+     """
+     special = {
+         "<global-img>": tokenizer.convert_tokens_to_ids("<global-img>"),
+         "<fake_token_around_image>": tokenizer.convert_tokens_to_ids("<fake_token_around_image>"),
+         "<image>": tokenizer.convert_tokens_to_ids("<image>"),
+     }
+     for i in range(max_rows):
+         for j in range(max_cols):
+             special[f"<row_{i + 1}_col_{j + 1}>"] = tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>")
+     return special
+
+
+ # ---------------------------
+ # Split, pad, and stack an image into model input tensors
+ # ---------------------------
+
+ def prepare_image_patches(image: Image.Image, preprocess_fn, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
+                           ) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
+     """
+     image: PIL.Image
+     preprocess_fn: callable that maps PIL.Image -> tensor (C, H, W)
+     Returns:
+         pixel_values: torch.Tensor, shape (num_patches + 1, C, H, W) -- the last entry is the full image
+         mask_token_id_list: list of (row, col) grid positions that were padded (not deduplicated)
+     """
+     blocks, rows, cols, block_size = adaptive_square_split(image, max_rows=max_rows, max_cols=max_cols)
+     patch_num = len(blocks)
+     pad_num = max_rows * max_cols - patch_num
+     mask_token_id_list = []
+     patch_tensors = []
+
+     if pad_num > 0:
+         # Pad in row-major order: any grid position beyond rows/cols gets a zero
+         # tensor, and its (row, col) position is recorded. The caller maps each
+         # position to the corresponding row_col token id, finds that special
+         # token in the text, and applies the attention-mask operation there.
+         for i in range(max_rows):
+             for j in range(max_cols):
+                 if i >= rows or j >= cols:
+                     patch_tensors.append(torch.zeros_like(preprocess_fn(image)))
+                     mask_token_id_list.append((i, j))
+                 else:
+                     patch_tensors.append(preprocess_fn(blocks[i * cols + j]))
+     else:
+         patch_tensors = [preprocess_fn(b) for b in blocks]
+
+     # Finally, append the full image's pixel_values (same as the original logic)
+     full_image_tensor = preprocess_fn(image)
+     pixel_values = torch.stack(patch_tensors + [full_image_tensor], dim=0)  # (N_patches+1, C, H, W)
+
+     return pixel_values, mask_token_id_list
+
+
+ # ---------------------------
+ # Build the image placeholder string inserted into the token stream
+ # ---------------------------
+
+ def construct_image_placeholders(special_tokens: Dict[str, int], max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS,
+                                  n_image_tokens_per_patch: int = 49) -> str:
+     """
+     Build the placeholder string that is concatenated into the prompt.
+     Returns a single string containing all placeholders.
+     (special_tokens is accepted for call-site symmetry but not used here.)
+     """
+     image_place_holder = random.choice(["图片如下:", "如下所示的图片:", "请见下面这张图:", "如下图显示:", "参考下方图片:", "图示如下:"])
+     for row in range(max_rows):
+         for col in range(max_cols):
+             image_place_holder += f"<fake_token_around_image><row_{row + 1}_col_{col + 1}>"
+             image_place_holder += "<image>" * n_image_tokens_per_patch
+     # Global image block (appended last)
+     image_place_holder += f"<fake_token_around_image><global-img>{'<image>' * n_image_tokens_per_patch}<fake_token_around_image>"
+     return image_place_holder
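+
+
+ def _placeholder_accounting():
+     # Illustrative accounting (not called anywhere): each of the 16 grid
+     # patches contributes <fake_token_around_image> + <row_i_col_j> + 49 x
+     # <image> = 51 tokens, and the global block appends 2 + 49 + 1 more, so
+     # the string carries 17 * 49 image tokens and 18 frame tokens in total.
+     s = construct_image_placeholders({})  # special_tokens is unused by the builder
+     assert s.count("<image>") == 17 * 49
+     assert s.count("<fake_token_around_image>") == 18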
+
+
+ # ---------------------------
+ # Locate image-marker positions in the token stream (for attention-mask edits)
+ # ---------------------------
+
+ def find_indices(tokens: torch.Tensor) -> Optional[Dict[int, Dict[int, List[Tuple[int, int]]]]]:
+     """
+     Input tokens: tensor of shape (B, T)
+     Returns:
+         results = { batch_index: { k: [(start_idx, end_idx), ...], ... }, ... }
+     where k indexes the image token pattern (the image_ids list preset below) and
+     start_idx/end_idx delimit the placeholder span within tokens (inclusive).
+     Note: this keeps the original matching scheme (patterns [<fake>, <row_i_col_j>]
+     and [<fake>, <global-img>]).
+     """
+     B, T = tokens.size()
+     # Token id sequences matching the original code (adjust here if your tokenizer differs)
+     image_ids = [[3, i] for i in range(6, 22)] + [[3, 4]]  # preset patterns
+     image_ids_tensor = torch.tensor(image_ids, device=tokens.device)
+     len_image_ids = image_ids_tensor.size(1)
+     if len_image_ids > tokens.size(1):
+         return None
+     tokens_view = tokens.unfold(1, len_image_ids, 1)  # (B, T - len_image_ids + 1, len_image_ids)
+     matches = []
+     for image_id_tensor in image_ids_tensor:
+         match = (tokens_view == image_id_tensor).all(dim=2)  # (B, T - len + 1)
+         matches.append(match)
+     results = {}
+     for b in range(B):
+         batch_res = {}
+         for k, m in enumerate(matches):
+             idxs = m[b].nonzero(as_tuple=True)[0]
+             if len(idxs) > 0:
+                 batch_res[k] = [(i.item() + 2, i.item() + 50) for i in idxs]
+         if batch_res:
+             results[b] = batch_res
+     return results or None
+
+
+ # ---------------------------
+ # ONNX Session helpers
+ # ---------------------------
+
+ def create_onnx_session(path: str, intra_threads: int = 1) -> onnxruntime.InferenceSession:
+     opts = onnxruntime.SessionOptions()
+     opts.intra_op_num_threads = intra_threads
+     opts.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
+     return onnxruntime.InferenceSession(path, sess_options=opts)
+
+
+ # ---------------------------
+ # Prefill stage (splice in the vision embeddings and run the LLM once)
+ # ---------------------------
+
+ def prefill_llm(vision_session: onnxruntime.InferenceSession,
+                 embed_tokens_session: onnxruntime.InferenceSession,
+                 llm_session: onnxruntime.InferenceSession,
+                 pixel_values: torch.Tensor,
+                 input_ids: torch.Tensor,
+                 attention_mask: torch.Tensor,
+                 freqs_cos: torch.Tensor,
+                 freqs_sin: torch.Tensor,
+                 special_tokens: Dict[str, int],
+                 seqlen: int,
+                 device: str = "cpu") -> Dict[str, Any]:
+     """
+     Run the prefill step:
+       1) obtain the vision embeddings (deepstack_embeds) from vision_session
+       2) obtain the token embeddings from embed_tokens_session
+       3) splice the vision embeddings into the hidden stream (replacing the placeholder spans)
+       4) call llm_session.run once to obtain logits, hidden_states, present_keys, present_values
+
+     Returns a dict:
+         {
+             "logits": np.ndarray,
+             "hidden_states": np.ndarray,
+             "present_keys": np.ndarray,
+             "present_values": np.ndarray
+         }
+     """
+     # 1) vision embed
+     ort_inputs_vis = {"inputs": pixel_values.numpy()}
+     deepstack_embeds = vision_session.run(["deepstack_embeds"], ort_inputs_vis)[0]  # e.g. (B, P, L_patch, D)
+
+     # 2) token embed
+     ort_inputs_emb = {"input_ids": input_ids.numpy()}
+     embed_tokens = embed_tokens_session.run(["embed_tokens"], ort_inputs_emb)[0]  # (B, T, D)
+
+     # 3) locate the image placeholders among the tokens and replace them
+     image_batch_indices = find_indices(input_ids)
+     B = input_ids.shape[0]
+     new_h = []
+
+     for i in range(B):
+         h_i = embed_tokens[i]  # np array (T, D)
+         image_indices = image_batch_indices.get(i, {}) if image_batch_indices else {}
+         # image_indices: {k: [(start, end), ...], ...}
+         # deepstack_embeds: assumed shape (B, P, L_patch, D), P = number_of_image_patches + global
+         for tki, index_list in image_indices.items():
+             # tki indexes the second axis of deepstack_embeds
+             vision_proj_i = deepstack_embeds[i][tki]  # (L_patch, D)
+             # take the first matched span
+             start_idx, end_idx = index_list[0]
+             # replace h_i[start_idx..end_idx] with vision_proj_i and truncate to seqlen
+             # (numpy concatenation, since h_i is a numpy array)
+             h_i = np.concatenate((h_i[:start_idx], vision_proj_i, h_i[end_idx + 1:]), axis=0)[:seqlen]
+         new_h.append(h_i)
+
+     hidden_states = np.stack(new_h, axis=0)  # (B, seqlen, D)
+
+     # 4) call llm.onnx once (prefill)
+     # past_keys/past_values start out empty (zero length along the sequence
+     # axis); their shapes must match what the model expects -- prepare them
+     # at the call site if your model requires something different
+     past_keys = np.zeros([8, 0, 2, 64], dtype=np.float32)
+     past_values = np.zeros([8, 0, 2, 64], dtype=np.float32)
+     cos_pe = freqs_cos[0: seqlen].numpy()
+     sin_pe = freqs_sin[0: seqlen].numpy()
+
+     ort_inputs_llm = {
+         "input_ids": hidden_states.astype(np.float32),
+         "attention_mask": attention_mask.numpy(),
+         "cos_pe": cos_pe.astype(np.float32),
+         "sin_pe": sin_pe.astype(np.float32),
+         "past_keys": past_keys,
+         "past_values": past_values
+     }
+
+     logits, hidden_states_out, present_keys, present_values = llm_session.run(
+         ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs_llm
+     )
+
+     return {
+         "logits": logits,
+         "hidden_states": hidden_states_out,
+         "present_keys": present_keys,
+         "present_values": present_values
+     }
+
+
+ # ---------------------------
+ # Autoregressive next-token generation (driven by the present keys/values)
+ # ---------------------------
+
+ def generate_autoregressive(llm_session: onnxruntime.InferenceSession,
+                             embed_tokens_session: onnxruntime.InferenceSession,
+                             tokenizer,
+                             initial_present: Dict[str, np.ndarray],
+                             start_token_id: int,
+                             freqs_cos: torch.Tensor,
+                             freqs_sin: torch.Tensor,
+                             attention_mask: np.ndarray,
+                             max_new_tokens: int = 128,
+                             eos_token_id: int = 2,
+                             start_pos: int = None):
+     """
+     Generate autoregressively from the present_keys/present_values returned by prefill.
+     Each step:
+       - embed the new token via embed_tokens_session
+       - feed the present keys/values to llm_session, obtaining new keys/values and logits
+       - pick the argmax logit as the next token (swap in a sampling strategy if desired)
+
+     Note: the names and shapes of the present keys/values are model-specific;
+     make sure they match your model.
+     """
+     present_keys = initial_present["present_keys"]
+     present_values = initial_present["present_values"]
+
+     token_id = int(start_token_id)
+     if start_pos is None:
+         start_pos = attention_mask.shape[1]
+
+     generated_ids = []
+     buffer = ""
+     for step in range(max_new_tokens):
+         # decode the text generated so far and stream it to the caller
+         decoded = tokenizer.decode(token_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+         buffer += decoded
+         yield buffer
+         # extend the attention mask by one position
+         attention_mask = np.concatenate([attention_mask, np.array([[1]], dtype=np.int64)], axis=1)
+
+         # embed the current token
+         embed_tokens = embed_tokens_session.run(["embed_tokens"], {"input_ids": np.array([[token_id]], dtype=np.int64)})[0]
+
+         cos_pe = freqs_cos[start_pos: start_pos + 1].numpy()
+         sin_pe = freqs_sin[start_pos: start_pos + 1].numpy()
+
+         ort_inputs = {
+             "input_ids": embed_tokens.astype(np.float32),
+             "attention_mask": attention_mask,
+             "cos_pe": cos_pe.astype(np.float32),
+             "sin_pe": sin_pe.astype(np.float32),
+             "past_keys": present_keys,
+             "past_values": present_values
+         }
+
+         logits, hidden_states, present_keys, present_values = llm_session.run(
+             ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs
+         )
+
+         token_id = int(np.argmax(logits[:, -1, :], axis=-1)[0])
+         generated_ids.append(token_id)
+
+         if token_id == eos_token_id:
+             break
+
+         start_pos += 1
+
+     return generated_ids
+
+
+ def main_example():
+
+     tokenizer_path = "./custom_tokenizer"
+     tokenizer = load_tokenizer(tokenizer_path)
+     preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)
+
+     # image
+     image_path = "/Users/hulk/Downloads/coco128/images/train2017/000000000165.jpg"
+     image = Image.open(image_path).convert("RGB")
+
+     # special tokens
+     special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)
+
+     pixel_values, mask_positions = prepare_image_patches(image, preprocess, max_rows=4, max_cols=4)
+
+     # build the prompt with image placeholders (assumes the tokenizer supports apply_chat_template)
+     query = "图片中的人在做什么。"
+     messages = [
+         {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
+         {"role": "user", "content": query + construct_image_placeholders(special_tokens)}
+     ]
+     inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]
+
+     # precompute RoPE
+     freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
+
+     # create onnx sessions
+     vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
+     embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
+     llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)
+
+     # prefill
+     seqlen = input_ids.shape[1]
+     prefill_out = prefill_llm(
+         vision_session=vision_session,
+         embed_tokens_session=embed_tokens_session,
+         llm_session=llm_session,
+         pixel_values=pixel_values,
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         freqs_cos=freqs_cos,
+         freqs_sin=freqs_sin,
+         special_tokens=special_tokens,
+         seqlen=seqlen
+     )
+
+     # start token id = argmax over the last prefill logit
+     start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
+
+     # generate_autoregressive is a generator: iterate it to actually run the
+     # decode loop and print the streamed text
+     for partial in generate_autoregressive(
+         llm_session=llm_session,
+         embed_tokens_session=embed_tokens_session,
+         tokenizer=tokenizer,
+         initial_present={"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
+         start_token_id=start_token_id,
+         freqs_cos=freqs_cos,
+         freqs_sin=freqs_sin,
+         attention_mask=attention_mask.numpy(),
+         max_new_tokens=128,
+         eos_token_id=2,
+         start_pos=seqlen
+     ):
+         print(partial, end="\r", flush=True)
+
+
+ if __name__ == "__main__":
+     # main_example()
+     pass