Spaces:
Sleeping
Sleeping
Upload 24 files
Browse files- .gitattributes +10 -0
- app.py +172 -0
- custom_tokenizer/chat_template.jinja +96 -0
- custom_tokenizer/special_tokens_map.json +30 -0
- custom_tokenizer/tokenizer.json +0 -0
- custom_tokenizer/tokenizer_config.json +195 -0
- example_images/objects365_v1_00322597.jpg +3 -0
- example_images/objects365_v1_00322772.jpg +3 -0
- example_images/objects365_v1_00322846.jpg +3 -0
- example_images/objects365_v1_00322901.jpg +3 -0
- example_images/objects365_v1_00323167.jpg +3 -0
- example_images/objects365_v1_00324105.jpg +0 -0
- example_images/objects365_v1_00324441.jpg +0 -0
- example_images/objects365_v1_00326764.jpg +3 -0
- example_images/objects365_v1_00336365.jpg +0 -0
- example_images/objects365_v1_00357438.jpg +3 -0
- example_images/objects365_v1_00358590.jpg +3 -0
- example_images/objects365_v1_00361740.jpg +0 -0
- example_images/objects365_v1_00363692.jpg +3 -0
- example_images/objects365_v1_00367221.jpg +3 -0
- onnx_model/embed_tokens.onnx +3 -0
- onnx_model/llm.onnx +3 -0
- onnx_model/vision_encoder.onnx +3 -0
- requirements.txt +9 -0
- tinymind.py +522 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
example_images/objects365_v1_00322597.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
example_images/objects365_v1_00322772.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
example_images/objects365_v1_00322846.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
example_images/objects365_v1_00322901.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
example_images/objects365_v1_00323167.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
example_images/objects365_v1_00326764.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
example_images/objects365_v1_00357438.jpg filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
example_images/objects365_v1_00358590.jpg filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
example_images/objects365_v1_00363692.jpg filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
example_images/objects365_v1_00367221.jpg filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer
|
| 3 |
+
from threading import Thread
|
| 4 |
+
import re
|
| 5 |
+
import time
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torch
|
| 8 |
+
import spaces
|
| 9 |
+
from tinymind import *
|
| 10 |
+
|
| 11 |
+
# --- Module-level setup: runs once at Space startup. ---
tokenizer_path = "./custom_tokenizer"
tokenizer = load_tokenizer(tokenizer_path)
# Image preprocessing pipeline (resize / crop / normalize) at the default size.
preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)

# Placeholder tokens for up to a 4x4 grid of image patches
# (matches the <row_R_col_C> tokens declared in custom_tokenizer).
special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)

# ONNX Runtime sessions for the three exported model stages.
vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)
# Precomputed RoPE cos/sin tables: head dim 64, up to 32K positions, base 1e6.
freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
|
| 21 |
+
|
| 22 |
+
@spaces.GPU
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    """Handle one multimodal chat turn for gr.ChatInterface.

    Validates the input, prefills the ONNX LLM with the image + prompt,
    then launches autoregressive generation in a worker thread.

    Args:
        input_dict: MultimodalTextbox payload with "text" and "files" keys.
        history: previous chat turns supplied by gr.ChatInterface.
        decoding_strategy, temperature, repetition_penalty, top_p:
            sampling controls from the UI sliders.
            TODO(review): these are not forwarded to generate_autoregressive
            below — confirm the intended plumbing.
        max_new_tokens: generation length limit from the UI slider.
    """
    text = input_dict["text"]

    # Images attached to the current message.
    if len(input_dict["files"]) > 1:
        images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
    elif len(input_dict["files"]) == 1:
        images = [Image.open(input_dict["files"][0]).convert("RGB")]
    else:
        images = []

    # No image in this turn: fall back to the most recent image(s) in history.
    if not images and history:
        for turn in reversed(history):
            files, _ = turn  # (user content, assistant text)
            if isinstance(files, tuple) and len(files) > 0:
                images = [Image.open(image).convert("RGB") for image in files]
                break

    # BUG FIX: gr.Error was previously constructed but never *raised*, so
    # invalid inputs silently fell through and images[0] below crashed with
    # an IndexError when no image was available.
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and images:
        raise gr.Error("Please input a text query along the image(s).")
    if not images:
        raise gr.Error("Please provide an image along with your text query.")

    # Only the first image is used; it is tiled into up to 4x4 square patches.
    pixel_values, mask_positions = prepare_image_patches(images[0], preprocess, max_rows=4, max_cols=4)

    # Build the prompt with image placeholders (tokenizer supplies the chat template).
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": text + construct_image_placeholders(special_tokens)}
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Prefill: run the vision encoder + prompt through the LLM once to seed the KV cache.
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen
    )

    # First generated token: greedy argmax over the last prefill logit.
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])

    generation_args = {
        "llm_session": llm_session,
        "embed_tokens_session": embed_tokens_session,
        "tokenizer": tokenizer,
        "initial_present": {"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        "start_token_id": start_token_id,
        "freqs_cos": freqs_cos,
        "freqs_sin": freqs_sin,
        "attention_mask": attention_mask.numpy(),
        # BUG FIX: honor the UI slider instead of a hardcoded 128.
        "max_new_tokens": max_new_tokens,
        "eos_token_id": 2,  # <|im_end|> per custom_tokenizer config
        "start_pos": seqlen
    }

    # NOTE(review): generation runs on a worker thread but nothing is yielded
    # back to the ChatInterface here — confirm how generate_autoregressive
    # streams its output to the UI.
    thread = Thread(target=generate_autoregressive, kwargs=generation_args)
    thread.start()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# One example per row: [message payload, decoding strategy, temperature,
# max new tokens, repetition penalty, top_p] — same order as additional_inputs.
examples = [
    [{"text": "描述下图片的内容",
      "files": ["example_images/objects365_v1_00322846.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "请描述这张图片的内容,并检测其中的苹果",
      "files": ["example_images/objects365_v1_00361740.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "图中是什么交通工具?",
      "files": ["example_images/objects365_v1_00357438.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "图中有几只鸭子?",
      "files": ["example_images/objects365_v1_00323167.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "这是在哪?",
      "files": ["example_images/objects365_v1_00363692.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
]
# Chat UI: multimodal textbox (text + image uploads) plus sampling controls.
demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM2-256M-Married-Qwen3-0.6B: SmolVLM拥抱Qwen3,支持中文问答🤖",
    description="[TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B](https://huggingface.co/TalkUHulk/SmolVLM2-256M-Married-Qwen3-0.6B) 演示。请上传图片和文本,或尝试下方示例。",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="请输入查询文本(附带图片)",
        file_types=["image"],
        file_count="multiple"
    ),
    stop_btn="停止生成",
    multimodal=True,
    # Extra controls forwarded to model_inference after (input, history).
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="解码策略",
            info="选择生成文本的方式:采样更随机,贪心更确定。"
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="采样温度 (Temperature)",
            info="数值越高,输出越多样化;越低则更保守。"
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="最大生成 Token 数",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="重复惩罚 (Repetition penalty)",
            info="1.0 表示不做惩罚;数值越大越避免重复。"
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="数值越高,表示会采样更多低概率的 token。"
        ),
    ],
    cache_examples=False
)

demo.launch(debug=True)
|
custom_tokenizer/chat_template.jinja
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{#- Qwen3-style chat template extended with multimodal (image) content parts. #}
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' -%}
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
{%- else -%}
{{- '<|im_start|>system\n你是一个多模态AI助手,能够理解图片和文本信息。<|im_end|>\n' }}
{%- endif %}
{%- endif %}

{#- Find the index of the last genuine user query (ignoring tool responses). #}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}

{%- for message in messages %}
{#- Normalize message content: supports plain strings, lists of parts, and image entries. #}
{%- if message.content is string %}
{%- set content = message.content %}
{%- elif message.content is iterable %}
{#- Multi-part content (text + images). #}
{%- set content_parts = [] %}
{%- for part in message.content %}
{%- if part.type == 'text' %}
{%- set _ = content_parts.append(part.text) %}
{%- elif part.type == 'image' %}
{#- Image placeholder; the actual pixel data is handled by the processor. #}
{%- set _ = content_parts.append('<image>') %}
{%- endif %}
{%- endfor %}
{%- set content = content_parts | join('\n') %}
{%- else %}
{%- set content = '' %}
{%- endif %}

{#- User message, or a system message that is not the leading one. #}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}

{#- Assistant message. #}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{\"name\": \"' }}
{{- tool_call.name }}
{{- '\", \"arguments\": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}

{#- Tool message: consecutive tool results are folded into one user turn. #}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}
|
custom_tokenizer/special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<|im_start|>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "<|im_end|>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<|endoftext|>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<|endoftext|>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
custom_tokenizer/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
custom_tokenizer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": false,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<|endoftext|>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<|im_start|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "<|im_end|>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
},
|
| 30 |
+
"3": {
|
| 31 |
+
"content": "<fake_token_around_image>",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false,
|
| 36 |
+
"special": true
|
| 37 |
+
},
|
| 38 |
+
"4": {
|
| 39 |
+
"content": "<global-img>",
|
| 40 |
+
"lstrip": false,
|
| 41 |
+
"normalized": false,
|
| 42 |
+
"rstrip": false,
|
| 43 |
+
"single_word": false,
|
| 44 |
+
"special": true
|
| 45 |
+
},
|
| 46 |
+
"5": {
|
| 47 |
+
"content": "<image>",
|
| 48 |
+
"lstrip": false,
|
| 49 |
+
"normalized": false,
|
| 50 |
+
"rstrip": false,
|
| 51 |
+
"single_word": false,
|
| 52 |
+
"special": true
|
| 53 |
+
},
|
| 54 |
+
"6": {
|
| 55 |
+
"content": "<row_1_col_1>",
|
| 56 |
+
"lstrip": false,
|
| 57 |
+
"normalized": false,
|
| 58 |
+
"rstrip": false,
|
| 59 |
+
"single_word": false,
|
| 60 |
+
"special": true
|
| 61 |
+
},
|
| 62 |
+
"7": {
|
| 63 |
+
"content": "<row_1_col_2>",
|
| 64 |
+
"lstrip": false,
|
| 65 |
+
"normalized": false,
|
| 66 |
+
"rstrip": false,
|
| 67 |
+
"single_word": false,
|
| 68 |
+
"special": true
|
| 69 |
+
},
|
| 70 |
+
"8": {
|
| 71 |
+
"content": "<row_1_col_3>",
|
| 72 |
+
"lstrip": false,
|
| 73 |
+
"normalized": false,
|
| 74 |
+
"rstrip": false,
|
| 75 |
+
"single_word": false,
|
| 76 |
+
"special": true
|
| 77 |
+
},
|
| 78 |
+
"9": {
|
| 79 |
+
"content": "<row_1_col_4>",
|
| 80 |
+
"lstrip": false,
|
| 81 |
+
"normalized": false,
|
| 82 |
+
"rstrip": false,
|
| 83 |
+
"single_word": false,
|
| 84 |
+
"special": true
|
| 85 |
+
},
|
| 86 |
+
"10": {
|
| 87 |
+
"content": "<row_2_col_1>",
|
| 88 |
+
"lstrip": false,
|
| 89 |
+
"normalized": false,
|
| 90 |
+
"rstrip": false,
|
| 91 |
+
"single_word": false,
|
| 92 |
+
"special": true
|
| 93 |
+
},
|
| 94 |
+
"11": {
|
| 95 |
+
"content": "<row_2_col_2>",
|
| 96 |
+
"lstrip": false,
|
| 97 |
+
"normalized": false,
|
| 98 |
+
"rstrip": false,
|
| 99 |
+
"single_word": false,
|
| 100 |
+
"special": true
|
| 101 |
+
},
|
| 102 |
+
"12": {
|
| 103 |
+
"content": "<row_2_col_3>",
|
| 104 |
+
"lstrip": false,
|
| 105 |
+
"normalized": false,
|
| 106 |
+
"rstrip": false,
|
| 107 |
+
"single_word": false,
|
| 108 |
+
"special": true
|
| 109 |
+
},
|
| 110 |
+
"13": {
|
| 111 |
+
"content": "<row_2_col_4>",
|
| 112 |
+
"lstrip": false,
|
| 113 |
+
"normalized": false,
|
| 114 |
+
"rstrip": false,
|
| 115 |
+
"single_word": false,
|
| 116 |
+
"special": true
|
| 117 |
+
},
|
| 118 |
+
"14": {
|
| 119 |
+
"content": "<row_3_col_1>",
|
| 120 |
+
"lstrip": false,
|
| 121 |
+
"normalized": false,
|
| 122 |
+
"rstrip": false,
|
| 123 |
+
"single_word": false,
|
| 124 |
+
"special": true
|
| 125 |
+
},
|
| 126 |
+
"15": {
|
| 127 |
+
"content": "<row_3_col_2>",
|
| 128 |
+
"lstrip": false,
|
| 129 |
+
"normalized": false,
|
| 130 |
+
"rstrip": false,
|
| 131 |
+
"single_word": false,
|
| 132 |
+
"special": true
|
| 133 |
+
},
|
| 134 |
+
"16": {
|
| 135 |
+
"content": "<row_3_col_3>",
|
| 136 |
+
"lstrip": false,
|
| 137 |
+
"normalized": false,
|
| 138 |
+
"rstrip": false,
|
| 139 |
+
"single_word": false,
|
| 140 |
+
"special": true
|
| 141 |
+
},
|
| 142 |
+
"17": {
|
| 143 |
+
"content": "<row_3_col_4>",
|
| 144 |
+
"lstrip": false,
|
| 145 |
+
"normalized": false,
|
| 146 |
+
"rstrip": false,
|
| 147 |
+
"single_word": false,
|
| 148 |
+
"special": true
|
| 149 |
+
},
|
| 150 |
+
"18": {
|
| 151 |
+
"content": "<row_4_col_1>",
|
| 152 |
+
"lstrip": false,
|
| 153 |
+
"normalized": false,
|
| 154 |
+
"rstrip": false,
|
| 155 |
+
"single_word": false,
|
| 156 |
+
"special": true
|
| 157 |
+
},
|
| 158 |
+
"19": {
|
| 159 |
+
"content": "<row_4_col_2>",
|
| 160 |
+
"lstrip": false,
|
| 161 |
+
"normalized": false,
|
| 162 |
+
"rstrip": false,
|
| 163 |
+
"single_word": false,
|
| 164 |
+
"special": true
|
| 165 |
+
},
|
| 166 |
+
"20": {
|
| 167 |
+
"content": "<row_4_col_3>",
|
| 168 |
+
"lstrip": false,
|
| 169 |
+
"normalized": false,
|
| 170 |
+
"rstrip": false,
|
| 171 |
+
"single_word": false,
|
| 172 |
+
"special": true
|
| 173 |
+
},
|
| 174 |
+
"21": {
|
| 175 |
+
"content": "<row_4_col_4>",
|
| 176 |
+
"lstrip": false,
|
| 177 |
+
"normalized": false,
|
| 178 |
+
"rstrip": false,
|
| 179 |
+
"single_word": false,
|
| 180 |
+
"special": true
|
| 181 |
+
}
|
| 182 |
+
},
|
| 183 |
+
"additional_special_tokens": [],
|
| 184 |
+
"bos_token": "<|im_start|>",
|
| 185 |
+
"clean_up_tokenization_spaces": false,
|
| 186 |
+
"eos_token": "<|im_end|>",
|
| 187 |
+
"extra_special_tokens": {},
|
| 188 |
+
"legacy": true,
|
| 189 |
+
"model_max_length": 32768,
|
| 190 |
+
"pad_token": "<|endoftext|>",
|
| 191 |
+
"sp_model_kwargs": {},
|
| 192 |
+
"spaces_between_special_tokens": false,
|
| 193 |
+
"tokenizer_class": "PreTrainedTokenizerFast",
|
| 194 |
+
"unk_token": "<|endoftext|>"
|
| 195 |
+
}
|
example_images/objects365_v1_00322597.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00322772.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00322846.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00322901.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00323167.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00324105.jpg
ADDED
|
example_images/objects365_v1_00324441.jpg
ADDED
|
example_images/objects365_v1_00326764.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00336365.jpg
ADDED
|
example_images/objects365_v1_00357438.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00358590.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00361740.jpg
ADDED
|
example_images/objects365_v1_00363692.jpg
ADDED
|
Git LFS Details
|
example_images/objects365_v1_00367221.jpg
ADDED
|
Git LFS Details
|
onnx_model/embed_tokens.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:69098c878e4ff056b0724e327f139981b63116f425d8dd94b28b7ce79fccaf8c
|
| 3 |
+
size 13107407
|
onnx_model/llm.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3a50d292a275361d4e19d54a4bb59fd27efe55434958b0f5a007538b07d3bc43
|
| 3 |
+
size 103619420
|
onnx_model/vision_encoder.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:718953645f203de3d12396ca071f5fb1af2c54fa94ca21926235e797b965e807
|
| 3 |
+
size 254698566
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
huggingface_hub
|
| 3 |
+
transformers == 4.51.3
|
| 4 |
+
spaces
|
| 5 |
+
onnxruntime
|
| 6 |
+
torchvision
|
| 7 |
+
torch
|
| 8 |
+
numpy
|
| 9 |
+
Pillow
|
tinymind.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torchvision
|
| 9 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 10 |
+
import onnxruntime
|
| 11 |
+
|
| 12 |
+
# 如果你用 transformers 的 AutoTokenizer(推荐)
|
| 13 |
+
from transformers import AutoTokenizer
|
| 14 |
+
|
| 15 |
+
# ---------------------------
|
| 16 |
+
# Config / 默认参数
|
| 17 |
+
# ---------------------------
|
| 18 |
+
|
| 19 |
+
DEFAULT_IMAGE_SIZE = 224
|
| 20 |
+
DEFAULT_MAX_ROWS = 4
|
| 21 |
+
DEFAULT_MAX_COLS = 4
|
| 22 |
+
MIN_BLOCK_SIZE = 16
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------
|
| 26 |
+
# Tokenizer / Preprocess
|
| 27 |
+
# ---------------------------
|
| 28 |
+
|
| 29 |
+
def load_tokenizer(tokenizer_path: str):
    """Load the tokenizer and make sure a chat template is attached.

    Args:
        tokenizer_path: directory containing the saved tokenizer files.

    Returns:
        The loaded tokenizer, with ``chat_template`` populated from
        ``chat_template.jinja`` when the saved config did not provide one.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.chat_template is None:
        # Newer transformers versions load the template file automatically;
        # the training environment (4.51.3) needs this manual fallback.
        # BUG FIX: read as explicit UTF-8 — the template contains Chinese
        # text, which mis-decodes under a non-UTF-8 platform default codec.
        template_file = os.path.join(tokenizer_path, "chat_template.jinja")
        with open(template_file, "r", encoding="utf-8") as f:
            tokenizer.chat_template = f.read()
    return tokenizer
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_image_preprocess(image_size: int = DEFAULT_IMAGE_SIZE):
    """Build the image preprocessing callable.

    Pipeline: bicubic resize -> center crop to a square -> force RGB ->
    tensor conversion -> CLIP-style normalization.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(
            size=image_size,
            interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
            max_size=None,
            antialias=True,
        ),
        CenterCrop(size=(image_size, image_size)),
        lambda img: img.convert("RGB"),
        ToTensor(),
        Normalize(mean=clip_mean, std=clip_std),
    ]
    return Compose(steps)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ---------------------------
|
| 52 |
+
# RoPE频率预计算(precompute_freqs_cis)
|
| 53 |
+
# ---------------------------
|
| 54 |
+
|
| 55 |
+
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute RoPE cosine/sine tables.

    Args:
        dim: rotary dimension; dim//2 distinct frequencies are computed and
            duplicated to fill width ``dim``.
        end: number of positions to precompute.
        rope_base: base of the geometric frequency progression.
        rope_scaling: optional scaling config with keys
            ``original_max_position_embeddings``, ``factor``,
            ``beta_fast``, ``beta_slow``.

    Returns:
        Tuple ``(freqs_cos, freqs_sin)``, each of shape ``(end, dim)``.
    """
    half = dim // 2
    inv_freq = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[:half].float() / dim))

    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 4)
        beta_fast = rope_scaling.get("beta_fast", 4.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)

        if end / orig_max > 1.0:
            # Index of the first frequency whose full rotation period exceeds
            # the original context length; dims below it are scaled less.
            boundary = next((i for i in range(half) if 2 * math.pi / inv_freq[i] > orig_max), half)
            ramp = torch.arange(0, half, device=inv_freq.device).float() / max(half - 1, 1)
            beta = beta_slow + (beta_fast - beta_slow) * ramp
            per_dim_scale = torch.where(
                torch.arange(half, device=inv_freq.device) < boundary,
                (beta * factor - beta + 1) / (beta * factor),
                1.0 / factor,
            )
            inv_freq = inv_freq * per_dim_scale

    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()  # (end, dim/2)
    cos_table = torch.cat([torch.cos(angles), torch.cos(angles)], dim=-1)
    sin_table = torch.cat([torch.sin(angles), torch.sin(angles)], dim=-1)
    return cos_table, sin_table
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------
|
| 88 |
+
# 图像自适应切分(adaptive_square_split)
|
| 89 |
+
# ---------------------------
|
| 90 |
+
|
| 91 |
+
def calculate_optimal_split_with_fixed_max(width: int, height: int, max_rows: int, max_cols: int) -> Tuple[int, int, int]:
    """
    Choose the best (rows, cols, block_size) grid for splitting a width x height
    image into squares.

    Candidate grids come from three schemes, evaluated in order:
      1) rows fixed at max_rows, cols swept 1..max_cols
      2) cols fixed at max_cols, rows swept 1..max_rows
      3) rows = max_rows and cols = max_cols
    The grid with the highest area coverage wins; ties prefer the larger square.
    The returned block_size is rounded DOWN to a multiple of 16 and clamped to
    at least MIN_BLOCK_SIZE.
    """
    candidates = [(max_rows, c) for c in range(1, max_cols + 1)]
    candidates += [(r, max_cols) for r in range(1, max_rows + 1)]
    candidates.append((max_rows, max_cols))

    best_rows, best_cols, best_side = 1, 1, 0
    best_coverage = 0.0
    total_area = width * height

    for rows, cols in candidates:
        side = min(width // cols, height // rows)
        if side <= 0:
            continue
        coverage = (cols * side) * (rows * side) / total_area
        if coverage > best_coverage or (coverage == best_coverage and side > best_side):
            best_rows, best_cols, best_side = rows, cols, side
            best_coverage = coverage

    # Snap down to a multiple of 16, but never below the configured minimum.
    best_side = max(MIN_BLOCK_SIZE, (best_side // 16) * 16)
    return best_rows, best_cols, best_side
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def adaptive_square_split(image: Image.Image, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[List[Image.Image], int, int, int]:
    """
    Adaptively split a PIL image into square tiles.

    Returns (tiles, rows, cols, block_size) with the tiles in row-major order;
    the tile count may be smaller than max_rows * max_cols.
    """
    w, h = image.size
    rows, cols, side = calculate_optimal_split_with_fixed_max(w, h, max_rows, max_cols)

    tiles = [
        image.crop((c * side, r * side, (c + 1) * side, (r + 1) * side))
        for r in range(rows)
        for c in range(cols)
    ]
    return tiles, rows, cols, side
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# ---------------------------
|
| 161 |
+
# 特殊 token 准备
|
| 162 |
+
# ---------------------------
|
| 163 |
+
|
| 164 |
+
def prepare_special_tokens(tokenizer, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS) -> Dict[str, int]:
    """
    Build a mapping from image-related special tokens to their ids.

    Covers <global-img>, <fake_token_around_image>, <image>, plus every
    positional marker <row_i_col_j> for the max_rows x max_cols grid (1-based).
    """
    names = ["<global-img>", "<fake_token_around_image>", "<image>"]
    names += [
        f"<row_{r + 1}_col_{c + 1}>"
        for r in range(max_rows)
        for c in range(max_cols)
    ]
    return {name: tokenizer.convert_tokens_to_ids(name) for name in names}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ---------------------------
|
| 180 |
+
# 将图像切块、填充、stack 为模型输入张量
|
| 181 |
+
# ---------------------------
|
| 182 |
+
|
| 183 |
+
def prepare_image_patches(image: Image.Image, preprocess_fn, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS
                          ) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """
    Split an image into patches, pad the grid to max_rows x max_cols, and stack
    everything (plus the full image) into one model-input tensor.

    Args:
        image: source PIL.Image.
        preprocess_fn: callable mapping a PIL.Image to a tensor of shape (C, H, W).
        max_rows / max_cols: maximum grid size; missing slots are zero-padded.

    Returns:
        pixel_values: tensor of shape (N + 1, C, H, W) where N is
            max_rows*max_cols when padding occurs (otherwise the actual patch
            count); the LAST entry is the preprocessed full image.
        mask_positions: list of (row, col) 0-based grid coordinates of the
            padded (empty) slots. The caller maps each to its <row_i_col_j>
            special token and masks that span in the attention mask.
    """
    blocks, rows, cols, block_size = adaptive_square_split(image, max_rows=max_rows, max_cols=max_cols)
    patch_num = len(blocks)
    pad_num = max_rows * max_cols - patch_num
    mask_positions: List[Tuple[int, int]] = []

    # Preprocess the full image ONCE: it is both the final "global" entry and
    # the shape template for padded slots. (Previously it was re-preprocessed
    # for every empty slot just to call zeros_like on the result.)
    full_image_tensor = preprocess_fn(image)

    if pad_num > 0:
        zero_patch = torch.zeros_like(full_image_tensor)
        patch_tensors = []
        # Row-major fill: positions beyond the computed grid get a zero tensor
        # and their grid coordinates are recorded for attention masking.
        for i in range(max_rows):
            for j in range(max_cols):
                if i >= rows or j >= cols:
                    patch_tensors.append(zero_patch)
                    mask_positions.append((i, j))
                else:
                    patch_tensors.append(preprocess_fn(blocks[i * cols + j]))
    else:
        patch_tensors = [preprocess_fn(b) for b in blocks]

    # Append the full image last, matching the placeholder layout where the
    # <global-img> segment comes after all positional patches.
    pixel_values = torch.stack(patch_tensors + [full_image_tensor], dim=0)  # (N+1, C, H, W)
    return pixel_values, mask_positions
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
# ---------------------------
|
| 220 |
+
# 在 token 流中构建 image placeholder(原始的占位 token 串)
|
| 221 |
+
# ---------------------------
|
| 222 |
+
|
| 223 |
+
def construct_image_placeholders(special_tokens: Dict[str, int], max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX_COLS,
                                 n_image_tokens_per_patch: int = 49) -> str:
    """
    Build the image-placeholder string to splice into a prompt.

    Layout: a randomly chosen lead-in phrase, then for every grid cell a
    <fake_token_around_image><row_i_col_j> marker followed by
    n_image_tokens_per_patch <image> tokens, and finally the <global-img>
    segment wrapped in <fake_token_around_image> markers.
    """
    lead_in = random.choice(["图片如下:", "如下所示的图片:", "请见下面这张图:", "如下图显示:", "参考下方图片:", "图示如下:"])
    patch_tokens = "<image>" * n_image_tokens_per_patch

    parts = [lead_in]
    for row in range(1, max_rows + 1):
        for col in range(1, max_cols + 1):
            parts.append(f"<fake_token_around_image><row_{row}_col_{col}>{patch_tokens}")
    # Trailing global-image segment.
    parts.append(f"<fake_token_around_image><global-img>{patch_tokens}<fake_token_around_image>")
    return "".join(parts)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# ---------------------------
|
| 240 |
+
# 寻找 token 序列中 image 标记出现的位置(用于 attention mask 修改)
|
| 241 |
+
# ---------------------------
|
| 242 |
+
|
| 243 |
+
def find_indices(tokens: torch.Tensor) -> Optional[Dict[int, Dict[int, List[Tuple[int, int]]]]]:
    """
    Locate image-placeholder segments in a (B, T) token tensor.

    Scans for the fixed two-token patterns [3, 6] .. [3, 21] (the
    <fake_token_around_image><row_i_col_j> pairs) and [3, 4]
    (<fake_token_around_image><global-img>). For every match at position p,
    the recorded span is (p + 2, p + 50) inclusive — the 49 <image> tokens
    that follow the marker pair.

    Returns:
        {batch_index: {pattern_index: [(start, end), ...]}} or None when no
        pattern matches (or the sequence is shorter than a pattern).
    """
    batch_size, _ = tokens.size()
    # Hard-coded token-id patterns — adjust here if the tokenizer differs.
    patterns = torch.tensor([[3, i] for i in range(6, 22)] + [[3, 4]], device=tokens.device)
    pat_len = patterns.size(1)
    if pat_len > tokens.size(1):
        return None

    windows = tokens.unfold(1, pat_len, 1)  # (B, T - pat_len + 1, pat_len)
    found: Dict[int, Dict[int, List[Tuple[int, int]]]] = {}
    for b in range(batch_size):
        per_batch: Dict[int, List[Tuple[int, int]]] = {}
        for k in range(patterns.size(0)):
            hits = (windows[b] == patterns[k]).all(dim=1)
            starts = hits.nonzero(as_tuple=True)[0]
            if len(starts) > 0:
                per_batch[k] = [(int(s) + 2, int(s) + 50) for s in starts]
        if per_batch:
            found[b] = per_batch
    return found or None
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# ---------------------------
|
| 276 |
+
# ONNX Session helpers
|
| 277 |
+
# ---------------------------
|
| 278 |
+
|
| 279 |
+
def create_onnx_session(path: str, intra_threads: int = 1) -> onnxruntime.InferenceSession:
    """
    Open an ONNX Runtime inference session for the model at *path*.

    The session runs operators sequentially; intra_threads caps intra-op
    parallelism (keep small when several sessions share the CPU).
    """
    session_options = onnxruntime.SessionOptions()
    session_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
    session_options.intra_op_num_threads = intra_threads
    return onnxruntime.InferenceSession(path, sess_options=session_options)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# ---------------------------
|
| 287 |
+
# Prefill 阶段(将视觉嵌入插入并运行一次 LLM)
|
| 288 |
+
# ---------------------------
|
| 289 |
+
|
| 290 |
+
def prefill_llm(vision_session: onnxruntime.InferenceSession,
                embed_tokens_session: onnxruntime.InferenceSession,
                llm_session: onnxruntime.InferenceSession,
                pixel_values: torch.Tensor,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
                special_tokens: Dict[str, int],
                seqlen: int,
                device: str = "cpu") -> Dict[str, Any]:
    """
    Run the prefill pass of the multimodal pipeline:
      1) vision_session  -> "deepstack_embeds" visual embeddings for all patches
      2) embed_tokens_session -> token embeddings for input_ids
      3) splice the visual embeddings over the image-placeholder token spans
      4) one llm_session forward -> logits, hidden_states, present_keys/values

    Args:
        pixel_values: patch tensor from prepare_image_patches.
        input_ids / attention_mask: (B, T) prompt tensors.
        freqs_cos / freqs_sin: precomputed RoPE tables; the first seqlen rows
            are fed to the model.
        special_tokens: special-token id map (currently unused here; kept for
            interface stability).
        seqlen: hidden-state length after splicing (spliced streams are
            truncated to this length).
        device: unused placeholder — all compute runs through ONNX Runtime.

    Returns:
        dict with keys "logits", "hidden_states", "present_keys",
        "present_values" (numpy arrays straight from llm_session).
    """
    # 1) vision embeddings
    deepstack_embeds = vision_session.run(
        ["deepstack_embeds"], {"inputs": pixel_values.numpy()}
    )[0]  # assumed (B, P, L_patch, D) — P = positional patches + global

    # 2) token embeddings
    embed_tokens = embed_tokens_session.run(
        ["embed_tokens"], {"input_ids": input_ids.numpy()}
    )[0]  # (B, T, D)

    # 3) find placeholder spans and splice in the visual embeddings
    image_batch_indices = find_indices(input_ids)
    batch_size = input_ids.shape[0]
    new_hidden = []
    for i in range(batch_size):
        h_i = embed_tokens[i]  # numpy (T, D)
        image_indices = image_batch_indices.get(i, {}) if image_batch_indices else {}
        for patch_idx, spans in image_indices.items():
            # patch_idx indexes the second dim of deepstack_embeds; only the
            # first matched span per pattern is replaced.
            vision_proj = deepstack_embeds[i][patch_idx]  # (L_patch, D)
            start_idx, end_idx = spans[0]
            # NOTE(review): splicing shifts later spans; this assumes each
            # vision segment has the same length as the span it replaces
            # (end_idx - start_idx + 1) — confirm against the exported model.
            h_i = np.concatenate((h_i[:start_idx], vision_proj, h_i[end_idx + 1:]), axis=0)[:seqlen]
        new_hidden.append(h_i)

    hidden_states = np.stack(new_hidden, axis=0)  # (B, seqlen, D)

    # 4) prefill forward with an empty (zero-length) KV cache. The cache
    # shape [8, 0, 2, 64] must match what the exported llm.onnx expects.
    past_keys = np.zeros([8, 0, 2, 64], dtype=np.float32)
    past_values = np.zeros([8, 0, 2, 64], dtype=np.float32)

    ort_inputs_llm = {
        "input_ids": hidden_states.astype(np.float32),  # model takes embeddings, not ids
        "attention_mask": attention_mask.numpy(),
        "cos_pe": freqs_cos[:seqlen].numpy().astype(np.float32),
        "sin_pe": freqs_sin[:seqlen].numpy().astype(np.float32),
        "past_keys": past_keys,
        "past_values": past_values,
    }

    logits, hidden_states_out, present_keys, present_values = llm_session.run(
        ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs_llm
    )

    return {
        "logits": logits,
        "hidden_states": hidden_states_out,
        "present_keys": present_keys,
        "present_values": present_values,
    }
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# ---------------------------
|
| 378 |
+
# Next-token 自回归生成(基于 present keys/values)
|
| 379 |
+
# ---------------------------
|
| 380 |
+
|
| 381 |
+
def generate_autoregressive(llm_session: onnxruntime.InferenceSession,
                            embed_tokens_session: onnxruntime.InferenceSession,
                            tokenizer,
                            initial_present: Dict[str, np.ndarray],
                            start_token_id: int,
                            freqs_cos: torch.Tensor,
                            freqs_sin: torch.Tensor,
                            attention_mask: np.ndarray,
                            max_new_tokens: int = 128,
                            eos_token_id: int = 2,
                            start_pos: Optional[int] = None):
    """
    Greedy autoregressive decoding seeded by the prefill KV cache.

    This is a GENERATOR: it yields the accumulated decoded text after every
    step (the first yield contains just the decoded start token), so callers
    must iterate it for any decoding to happen. Each step:
      - embeds the current token via embed_tokens_session
      - runs llm_session with the running present_keys/present_values cache
      - takes the argmax logit as the next token (swap in sampling if needed)

    Stops after max_new_tokens steps or as soon as eos_token_id is produced.
    The KV-cache input/output names and shapes must match the exported model.

    Args:
        initial_present: {"present_keys": ..., "present_values": ...} from prefill.
        start_token_id: first token to feed (argmax of the last prefill logit).
        attention_mask: (1, T) prompt mask; extended by one per generated token.
        start_pos: RoPE position of the first step; defaults to the prompt length.

    Returns:
        generated_ids (list[int]) via StopIteration.value, for callers that
        drive the generator manually.
    """
    present_keys = initial_present["present_keys"]
    present_values = initial_present["present_values"]

    token_id = int(start_token_id)
    if start_pos is None:
        # Default to the number of already-attended (prompt) positions.
        start_pos = attention_mask.shape[1]

    generated_ids = []
    buffer = ""
    for _ in range(max_new_tokens):
        decoded = tokenizer.decode(token_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        buffer += decoded
        yield buffer

        # Grow the attention mask by one position for the new token.
        attention_mask = np.concatenate([attention_mask, np.array([[1]], dtype=np.int64)], axis=1)

        # Embed the current token.
        embed_tokens = embed_tokens_session.run(
            ["embed_tokens"], {"input_ids": np.array([[token_id]], dtype=np.int64)}
        )[0]

        ort_inputs = {
            "input_ids": embed_tokens.astype(np.float32),
            "attention_mask": attention_mask,
            "cos_pe": freqs_cos[start_pos: start_pos + 1].numpy().astype(np.float32),
            "sin_pe": freqs_sin[start_pos: start_pos + 1].numpy().astype(np.float32),
            "past_keys": present_keys,
            "past_values": present_values,
        }

        logits, _hidden_states, present_keys, present_values = llm_session.run(
            ["logits", "hidden_states", "present_keys", "present_values"], ort_inputs
        )

        # Greedy selection of the next token.
        token_id = int(np.argmax(logits[:, -1, :], axis=-1)[0])
        generated_ids.append(token_id)

        if token_id == eos_token_id:
            break

        start_pos += 1

    return generated_ids
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def main_example():
    """
    End-to-end demo: tokenize an image+text prompt, prefill the LLM, then
    stream greedy decoding to stdout.

    Requires the local tokenizer, the three ONNX models under ./onnx_model,
    and a sample image (adjust image_path for your machine).
    """
    tokenizer_path = "./custom_tokenizer"
    tokenizer = load_tokenizer(tokenizer_path)
    preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)

    # image
    image_path = "/Users/hulk/Downloads/coco128/images/train2017/000000000165.jpg"
    image = Image.open(image_path).convert("RGB")

    # special tokens for the 4x4 patch grid
    special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)

    pixel_values, mask_positions = prepare_image_patches(image, preprocess, max_rows=4, max_cols=4)

    # Build the prompt with image placeholders (tokenizer must support apply_chat_template).
    query = "图片中的人在做什么。"
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": query + construct_image_placeholders(special_tokens)}
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # precompute RoPE tables
    freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)

    # create onnx sessions
    vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
    embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
    llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)

    # prefill
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen
    )

    # first generated token = argmax of the last prefill logit
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])

    # generate_autoregressive is a generator: it must be iterated for decoding
    # to run at all (previously the generator object was created but never
    # consumed, so no tokens were ever generated).
    text = ""
    for text in generate_autoregressive(
        llm_session=llm_session,
        embed_tokens_session=embed_tokens_session,
        tokenizer=tokenizer,
        initial_present={"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        start_token_id=start_token_id,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        attention_mask=attention_mask.numpy(),
        max_new_tokens=128,
        eos_token_id=2,
        start_pos=seqlen
    ):
        pass
    print(text)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
if __name__ == "__main__":
    # main_example()  # local-only demo: needs ./onnx_model/*.onnx, ./custom_tokenizer and a sample image
    pass
|