# Hugging Face Spaces page-status text captured by the scrape ("Spaces: Sleeping");
# kept here as a comment so the file remains valid Python.
| import gradio as gr | |
| from transformers import AutoProcessor, Idefics3ForConditionalGeneration, TextIteratorStreamer | |
| from threading import Thread | |
| from queue import Queue | |
| import re | |
| import time | |
| from PIL import Image | |
| import torch | |
| import spaces | |
| from tinymind import * | |
# --- Model assets: loaded once at import time, shared by every request ---

# Tokenizer with the project's custom vocabulary.
tokenizer_path = "./custom_tokenizer"
tokenizer = load_tokenizer(tokenizer_path)

# Image preprocessing callable for the model's expected input size.
preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)

# Placeholder special tokens for a 4x4 grid of image patches.
special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)

# ONNX Runtime sessions for the three inference stages, created in the
# same order as before: vision encoder, token embedder, language model.
vision_session, embed_tokens_session, llm_session = (
    create_onnx_session(onnx_path, intra_threads=2)
    for onnx_path in (
        "./onnx_model/vision_encoder.onnx",
        "./onnx_model/embed_tokens.onnx",
        "./onnx_model/llm.onnx",
    )
)

# Rotary position-embedding tables (head dim 64, up to 32768 positions).
freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
def model_inference(input_dict, history):
    """Run multimodal inference on uploaded image(s) plus a text prompt.

    Args:
        input_dict: Gradio ``MultimodalTextbox`` payload with keys ``"text"``
            (the prompt) and ``"files"`` (list of uploaded image paths).
        history: Gradio chat history. When the current message carries no
            image, it is searched newest-first for a turn that did.

    Yields:
        The full generated text so far (for streaming display), or a
        user-facing error message when validation fails.
    """
    from queue import Empty  # local import: the file-level import only brings in Queue

    text = input_dict.get("text", "")
    files = input_dict.get("files", [])

    # Open every uploaded image. (The original duplicated this logic for the
    # single-file case; one comprehension covers 0, 1, or many files.)
    images = [Image.open(f).convert("RGB") for f in files]

    # No image in this turn: fall back to the most recent history turn that
    # carried files. Gradio may store the file container as a list OR a
    # tuple, so accept both (the original's list-only check could miss
    # tuple-shaped entries).
    if not images and history:
        for turn in reversed(history):
            files_in_history, _ = turn
            if isinstance(files_in_history, (list, tuple)) and len(files_in_history) > 0:
                images = [Image.open(f).convert("RGB") for f in files_in_history]
                break

    # Input validation.
    if not images:
        yield "❌ 错误:请上传图片"
        return
    if text.strip() == "":
        yield "❌ 错误:请输入提示词"
        return

    # Only the first image is fed to the model; split it into a 4x4 patch grid.
    pixel_values, mask_positions = prepare_image_patches(
        images[0], preprocess, max_rows=4, max_cols=4
    )

    # Build the chat prompt with image-placeholder tokens appended to the
    # user message.
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": text + construct_image_placeholders(special_tokens)},
    ]
    inputs_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Prefill: run prompt + image through the model once to build the KV cache.
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen,
    )

    # First generated token = greedy argmax over the last prefill logit.
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])

    # Queue carries decoded text chunks from the worker thread; a None item
    # is the end-of-generation sentinel.
    output_queue = Queue()
    generation_args = {
        "llm_session": llm_session,
        "embed_tokens_session": embed_tokens_session,
        "tokenizer": tokenizer,
        "initial_present": {
            "present_keys": prefill_out["present_keys"],
            "present_values": prefill_out["present_values"],
        },
        "start_token_id": start_token_id,
        "freqs_cos": freqs_cos,
        "freqs_sin": freqs_sin,
        "attention_mask": attention_mask.numpy(),
        "max_new_tokens": 128,
        "eos_token_id": 2,
        "start_pos": seqlen,
        "output_queue": output_queue,
    }

    # Run autoregressive decoding in a background thread so we can stream.
    thread = Thread(target=generate_autoregressive, kwargs=generation_args)
    thread.start()

    yield "🤔 正在生成中..."
    buffer = ""
    while True:
        # Bounded wait instead of a blocking get(): if the worker dies
        # without posting the None sentinel, the original blocking get()
        # would hang this generator (and the UI) forever.
        try:
            text_chunk = output_queue.get(timeout=1.0)
        except Empty:
            if thread.is_alive():
                continue  # still generating, just slow
            break  # worker exited without a sentinel (e.g. it raised)
        if text_chunk is None:  # normal end-of-generation signal
            break
        buffer += text_chunk
        time.sleep(0.01)  # small pause smooths the streamed updates
        yield buffer

    # Make sure the worker is fully finished before returning.
    thread.join()
# Example prompt/image pairs rendered below the chat box; each entry is the
# single-message payload shape gr.ChatInterface expects for examples.
examples = [
    [{"text": prompt, "files": [image_path]}]
    for prompt, image_path in (
        ("描述图片的内容", "example_images/objects365_v1_00322846.jpg"),
        ("简要概括这张图片的核心内容", "example_images/objects365_v1_00361740.jpg"),
        ("你从图片中看到了什么", "example_images/objects365_v1_00357438.jpg"),
        ("描述这张图片中的内容", "example_images/objects365_v1_00323167.jpg"),
    )
]
# Page-styling CSS (gradient title, boxed input/output sections).
# NOTE(review): currently unused — the `css=custom_css` argument in the
# gr.ChatInterface call below is commented out.
custom_css = """
.container {
    max-width: 1200px;
    margin: 0 auto;
}
.title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 0.5em;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
}
.description {
    text-align: center;
    font-size: 1.1em;
    color: #666;
    margin-bottom: 2em;
}
.input-section {
    background: #f8f9fa;
    padding: 2em;
    border-radius: 12px;
    border: 1px solid #e0e0e0;
    margin-bottom: 2em;
}
.output-section {
    background: #ffffff;
    padding: 2em;
    border-radius: 12px;
    border: 2px solid #667eea;
    min-height: 200px;
}
.example-text {
    font-size: 0.95em;
    color: #555;
}
"""
# Multimodal input widget: text prompt plus one or more image uploads.
prompt_box = gr.MultimodalTextbox(
    label="📝 输入提示词(建议以图片描述为主)",
    file_types=["image"],
    file_count="multiple",
    placeholder="例如:描述图片内容 / 这是什么? / 图中有哪些物体?",
)

# Chat UI wired to the streaming inference generator.
demo = gr.ChatInterface(
    fn=model_inference,
    title="🎨 TinyMind 多模态AI助手",
    description=(
        "基于ONNX的高效多模态模型,支持图像理解与文本生成。上传图片并输入提示词,获得智能回答。"
        "模型只有90M,效果仅供娱乐(优化学习中)[详见知乎](https://zhuanlan.zhihu.com/p/1982031267370389720)"
    ),
    examples=examples,
    textbox=prompt_box,
    stop_btn="⏹️ 停止生成",
    multimodal=True,
    cache_examples=False,
    # css=custom_css,  # left disabled, as in the original
)

demo.launch(debug=True, share=False)