import re
import time
from queue import Queue
from threading import Thread

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer

from tinymind import *
# --- One-time model setup (all helpers come from the `tinymind` star import) ---
tokenizer_path = "./custom_tokenizer"
# Tokenizer with the project's custom vocabulary / chat template.
tokenizer = load_tokenizer(tokenizer_path)
# Image preprocessing callable; presumably resizes/normalizes to DEFAULT_IMAGE_SIZE — confirm in tinymind.
preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)
# Placeholder special tokens for up to a 4x4 grid of image patches.
special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)
# Three ONNX Runtime sessions: vision encoder, token-embedding lookup, and the LLM decoder.
vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)
# Precomputed RoPE cos/sin tables; NOTE(review): assumes head_dim=64 and max context 32768 — verify against tinymind.
freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
@spaces.GPU
def model_inference(input_dict, history):
    """Streamed multimodal inference: prompt + image(s) in, growing answer text out.

    Args:
        input_dict: Gradio MultimodalTextbox payload with keys ``"text"``
            (the prompt) and ``"files"`` (list of uploaded image paths).
        history: Gradio chat history; scanned newest-first for an image when
            the current turn has none.

    Yields:
        str: an error message, then a progress placeholder, then the
        cumulative generated text after each streamed chunk.
    """
    text = input_dict.get("text", "")
    files = input_dict.get("files", [])
    # Load every uploaded image; a single comprehension covers the 0/1/many cases
    # (the original's three-way branch on len(files) was redundant).
    images = [Image.open(f).convert("RGB") for f in files]
    # No image this turn: fall back to the most recent image in the history.
    if not images and history:
        for turn in reversed(history):
            # NOTE(review): assumes tuple-style history entries (files, text) —
            # confirm against the ChatInterface message format in use.
            files_in_history, _ = turn
            if isinstance(files_in_history, list) and files_in_history:
                images = [Image.open(f).convert("RGB") for f in files_in_history]
                break
    # Input validation.
    if not images:
        yield "❌ 错误:请上传图片"
        return
    if text.strip() == "":
        yield "❌ 错误:请输入提示词"
        return
    # Only the first image is fed to the vision encoder (4x4 patch grid).
    pixel_values, mask_positions = prepare_image_patches(images[0], preprocess, max_rows=4, max_cols=4)
    # Build the chat prompt with image placeholder tokens appended to the user turn.
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": text + construct_image_placeholders(special_tokens)}
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    # Prefill: run vision encoder + LLM over the whole prompt to build the KV cache.
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen
    )
    # First generated token: greedy argmax over the final prefill logit.
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
    # Queue hands decoded text chunks from the worker thread back to this generator.
    output_queue = Queue()
    generation_args = {
        "llm_session": llm_session,
        "embed_tokens_session": embed_tokens_session,
        "tokenizer": tokenizer,
        "initial_present": {"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        "start_token_id": start_token_id,
        "freqs_cos": freqs_cos,
        "freqs_sin": freqs_sin,
        "attention_mask": attention_mask.numpy(),
        "max_new_tokens": 128,
        "eos_token_id": 2,
        "start_pos": seqlen,
        "output_queue": output_queue
    }
    # Run autoregressive decoding in the background so we can stream from here.
    thread = Thread(target=generate_autoregressive, kwargs=generation_args)
    thread.start()
    yield "🤔 正在生成中..."
    buffer = ""
    while True:
        text_chunk = output_queue.get()  # blocks until the worker produces data
        if text_chunk is None:  # sentinel: generation finished
            break
        buffer += text_chunk
        # No artificial sleep here — the blocking queue already paces the stream,
        # and the original time.sleep(0.01) only added per-chunk latency.
        yield buffer
    thread.join()
# Prebuilt example inputs shown under the chat box: each entry is one
# MultimodalTextbox payload (prompt text + sample image path).
examples = [
    [{"text": "描述图片的内容",
      "files": ["example_images/objects365_v1_00322846.jpg"]}],
    [{"text": "简要概括这张图片的核心内容",
      "files": ["example_images/objects365_v1_00361740.jpg"]}],
    [{"text": "你从图片中看到了什么",
      "files": ["example_images/objects365_v1_00357438.jpg"]}],
    [{"text": "描述这张图片中的内容",
      "files": ["example_images/objects365_v1_00323167.jpg"]}]
]
# 美化 CSS
custom_css = """
.container {
max-width: 1200px;
margin: 0 auto;
}
.title {
text-align: center;
font-size: 2.5em;
font-weight: bold;
margin-bottom: 0.5em;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.description {
text-align: center;
font-size: 1.1em;
color: #666;
margin-bottom: 2em;
}
.input-section {
background: #f8f9fa;
padding: 2em;
border-radius: 12px;
border: 1px solid #e0e0e0;
margin-bottom: 2em;
}
.output-section {
background: #ffffff;
padding: 2em;
border-radius: 12px;
border: 2px solid #667eea;
min-height: 200px;
}
.example-text {
font-size: 0.95em;
color: #555;
}
"""
# Multimodal chat UI; Gradio streams the generator output of model_inference.
demo = gr.ChatInterface(
    fn=model_inference,
    title="🎨 TinyMind 多模态AI助手",
    description="基于ONNX的高效多模态模型,支持图像理解与文本生成。上传图片并输入提示词,获得智能回答。模型只有90M,效果仅供娱乐(优化学习中)[详见知乎](https://zhuanlan.zhihu.com/p/1982031267370389720)",
    examples=examples,
    # Combined text + file input widget, restricted to images.
    textbox=gr.MultimodalTextbox(
        label="📝 输入提示词(建议以图片描述为主)",
        file_types=["image"],
        file_count="multiple",
        placeholder="例如:描述图片内容 / 这是什么? / 图中有哪些物体?"
    ),
    stop_btn="⏹️ 停止生成",
    multimodal=True,
    # Examples run the model; caching them at build time is deliberately disabled.
    cache_examples=False,
    #css=custom_css,
)
# debug=True surfaces tracebacks in the console; share=False keeps it local/Space-only.
demo.launch(debug=True, share=False)