# TalkUHulk — Update app.py (commit c68da56, verified)
import gradio as gr
from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer
from threading import Thread
from queue import Queue
import re
import time
from PIL import Image
import torch
import spaces
from tinymind import *
# --- Module-level model setup (runs once at import time) ---
# Tokenizer extended with the custom special tokens used as image placeholders.
tokenizer_path = "./custom_tokenizer"
tokenizer = load_tokenizer(tokenizer_path)
# Image preprocessing pipeline; DEFAULT_IMAGE_SIZE comes from tinymind.
preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)
# Special tokens for a 4x4 grid of image patches (must match
# prepare_image_patches(max_rows=4, max_cols=4) in model_inference).
special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)
# Three ONNX Runtime sessions: vision encoder, token embedding, and the LLM.
vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)
# Precomputed RoPE frequency tables (head dim 64, max positions 32768).
freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
@spaces.GPU
def model_inference(input_dict, history):
    """Run ONNX multimodal inference and stream the generated text.

    Args:
        input_dict: Gradio MultimodalTextbox payload with keys "text"
            (the prompt string) and "files" (list of uploaded image paths).
        history: Gradio chat history; scanned newest-first for previously
            uploaded images when the current message carries none.

    Yields:
        str: first a status message, then the accumulated generated text
        (full buffer so far) so the chat UI renders a streaming response.
    """
    text = input_dict.get("text", "")
    files = input_dict.get("files", [])

    # One comprehension covers every case: many files, one file, or none
    # (empty list). NOTE(review): opened images are never closed — a PIL
    # resource leak; left as-is since only images[0] is consumed below.
    images = [Image.open(f).convert("RGB") for f in files]

    # No image in the current message: fall back to the most recent
    # image(s) found in the chat history.
    if not images and history:
        for turn in reversed(history):
            # NOTE(review): assumes each history turn unpacks as
            # (files, text) — confirm against the ChatInterface
            # multimodal history format in use.
            files_in_history, _ = turn
            if isinstance(files_in_history, list) and len(files_in_history) > 0:
                images = [Image.open(f).convert("RGB") for f in files_in_history]
                break

    # Input validation (user-facing messages intentionally in Chinese).
    if not images:
        yield "❌ 错误:请上传图片"
        return
    if not text.strip():
        yield "❌ 错误:请输入提示词"
        return

    # Only the first image is fed to the model; grid size must match the
    # special tokens prepared at module level.
    pixel_values, mask_positions = prepare_image_patches(images[0], preprocess, max_rows=4, max_cols=4)

    # Build the chat prompt with image placeholder tokens appended to the
    # user's text.
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": text + construct_image_placeholders(special_tokens)}
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Prefill pass: runs the vision encoder and LLM over the whole prompt,
    # producing the KV cache plus logits for the first generated token.
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen
    )
    # Greedy pick of the first token from the last-position logits.
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])

    # Queue carrying decoded text chunks from the worker thread; None is
    # the completion sentinel.
    output_queue = Queue()
    generation_args = {
        "llm_session": llm_session,
        "embed_tokens_session": embed_tokens_session,
        "tokenizer": tokenizer,
        "initial_present": {"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        "start_token_id": start_token_id,
        "freqs_cos": freqs_cos,
        "freqs_sin": freqs_sin,
        "attention_mask": attention_mask.numpy(),
        "max_new_tokens": 128,
        "eos_token_id": 2,
        "start_pos": seqlen,
        "output_queue": output_queue
    }

    # Run autoregressive generation in a background thread so this
    # generator can stream chunks to the UI as they arrive.
    thread = Thread(target=generate_autoregressive, kwargs=generation_args)
    thread.start()

    yield "🤔 正在生成中..."
    buffer = ""
    while True:
        chunk = output_queue.get()  # blocks until the worker produces data
        if chunk is None:  # sentinel: generation finished
            break
        buffer += chunk
        time.sleep(0.01)  # small pause to smooth the streaming UI
        yield buffer

    # Make sure the worker has fully exited before returning.
    thread.join()
# Example prompts shown in the chat UI; each entry is a MultimodalTextbox
# payload (text prompt + one sample image from example_images/).
examples = [
    [{"text": "描述图片的内容",
      "files": ["example_images/objects365_v1_00322846.jpg"]}],
    [{"text": "简要概括这张图片的核心内容",
      "files": ["example_images/objects365_v1_00361740.jpg"]}],
    [{"text": "你从图片中看到了什么",
      "files": ["example_images/objects365_v1_00357438.jpg"]}],
    [{"text": "描述这张图片中的内容",
      "files": ["example_images/objects365_v1_00323167.jpg"]}]
]
# 美化 CSS
custom_css = """
.container {
max-width: 1200px;
margin: 0 auto;
}
.title {
text-align: center;
font-size: 2.5em;
font-weight: bold;
margin-bottom: 0.5em;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.description {
text-align: center;
font-size: 1.1em;
color: #666;
margin-bottom: 2em;
}
.input-section {
background: #f8f9fa;
padding: 2em;
border-radius: 12px;
border: 1px solid #e0e0e0;
margin-bottom: 2em;
}
.output-section {
background: #ffffff;
padding: 2em;
border-radius: 12px;
border: 2px solid #667eea;
min-height: 200px;
}
.example-text {
font-size: 0.95em;
color: #555;
}
"""
# Chat UI wiring: model_inference is called with the multimodal textbox
# payload and the chat history, and streams its yields to the chat panel.
demo = gr.ChatInterface(
    fn=model_inference,
    title="🎨 TinyMind 多模态AI助手",
    description="基于ONNX的高效多模态模型,支持图像理解与文本生成。上传图片并输入提示词,获得智能回答。模型只有90M,效果仅供娱乐(优化学习中)[详见知乎](https://zhuanlan.zhihu.com/p/1982031267370389720)",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="📝 输入提示词(建议以图片描述为主)",
        file_types=["image"],
        file_count="multiple",
        placeholder="例如:描述图片内容 / 这是什么? / 图中有哪些物体?"
    ),
    stop_btn="⏹️ 停止生成",
    multimodal=True,
    # Examples run the model; caching them at startup is skipped.
    cache_examples=False,
    #css=custom_css,
)
# debug=True surfaces worker errors in the console; share=False keeps the
# app local to the Space.
demo.launch(debug=True, share=False)