import gradio as gr
from threading import Thread
from queue import Queue
import time

import numpy as np
from PIL import Image
import spaces

from tinymind import *

tokenizer_path = "./custom_tokenizer"
tokenizer = load_tokenizer(tokenizer_path)
preprocess = build_image_preprocess(DEFAULT_IMAGE_SIZE)

special_tokens = prepare_special_tokens(tokenizer, max_rows=4, max_cols=4)

vision_session = create_onnx_session("./onnx_model/vision_encoder.onnx", intra_threads=2)
embed_tokens_session = create_onnx_session("./onnx_model/embed_tokens.onnx", intra_threads=2)
llm_session = create_onnx_session("./onnx_model/llm.onnx", intra_threads=2)
freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=32768, rope_base=1e6)
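# freqs_cos / freqs_sin are the precomputed rotary position embedding (RoPE) tables
# (head dim 64, up to 32768 positions, base 1e6); built once here and reused by both
# the prefill pass and the autoregressive decode loop.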

@spaces.GPU
def model_inference(input_dict, history):
    """简化版推理函数,只需图片和提示词"""
    text = input_dict.get("text", "")
    files = input_dict.get("files", [])
   
    # Collect the uploaded images
    if len(files) > 1:
        images = [Image.open(f).convert("RGB") for f in files]
    elif len(files) == 1:
        images = [Image.open(files[0]).convert("RGB")]
    else:
        images = []

    # No new images: fall back to the most recent images in the chat history
    if not images and history:
        for turn in reversed(history):
            files_in_history, _ = turn
            if isinstance(files_in_history, list) and len(files_in_history) > 0:
                images = [Image.open(f).convert("RGB") for f in files_in_history]
                break

    # Input validation
    if not images:
        yield "❌ 错误:请上传图片"
        return

    if text.strip() == "":
        yield "❌ 错误:请输入提示词"
        return

    # Preprocess the first image into patches (only the first uploaded image is used)
    pixel_values, mask_positions = prepare_image_patches(images[0], preprocess, max_rows=4, max_cols=4)

    # Build the chat prompt and append the image placeholder tokens
    messages = [
        {"role": "system", "content": "你是一个多模态AI助手,能够理解图片和文本信息."},
        {"role": "user", "content": text + construct_image_placeholders(special_tokens)}
    ]
    inputs_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inputs_text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Prefill: run the vision encoder and the LLM over the full prompt once to get the initial logits and KV cache
    seqlen = input_ids.shape[1]
    prefill_out = prefill_llm(
        vision_session=vision_session,
        embed_tokens_session=embed_tokens_session,
        llm_session=llm_session,
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        special_tokens=special_tokens,
        seqlen=seqlen
    )

    # First generated token: greedy argmax over the logits at the last prefill position
    start_token_id = int(np.argmax(prefill_out["logits"][:, -1, :], axis=-1)[0])
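    # Decoding continues from this token in the worker thread below, reusing the key/value
    # cache returned by prefill instead of re-running the full prompt.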

    # Output queue used to pass generated text from the worker thread back to this generator
    output_queue = Queue()
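    # Thread protocol: generate_autoregressive pushes decoded text chunks onto output_queue
    # and puts a final None as the "generation finished" sentinel (consumed in the loop below).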
    
    generation_args = {
        "llm_session" : llm_session,
        "embed_tokens_session": embed_tokens_session,
        "tokenizer": tokenizer,
        "initial_present" :{"present_keys": prefill_out["present_keys"], "present_values": prefill_out["present_values"]},
        "start_token_id": start_token_id,
        "freqs_cos": freqs_cos,
        "freqs_sin": freqs_sin,
        "attention_mask": attention_mask.numpy(),
        "max_new_tokens": 128,
        "eos_token_id": 2,
        "start_pos": seqlen,
        "output_queue": output_queue
    }
    
    # Start generation in a background thread
    thread = Thread(target=generate_autoregressive, kwargs=generation_args)
    thread.start()
    
    # Read generated text from the queue and yield the growing buffer
    yield "🤔 正在生成中..."
    buffer = ""
    
    while True:
        text_chunk = output_queue.get()  # blocks until the worker produces data
        
        if text_chunk is None:  # sentinel: generation finished
            break
        
        buffer += text_chunk
        time.sleep(0.01)  # brief pause between streamed UI updates
        yield buffer
    
    # Wait for the worker thread to finish
    thread.join()


examples = [
    [{"text": "描述图片的内容",
      "files": ["example_images/objects365_v1_00322846.jpg"]}],
    [{"text": "简要概括这张图片的核心内容",
      "files": ["example_images/objects365_v1_00361740.jpg"]}],
    [{"text": "你从图片中看到了什么",
      "files": ["example_images/objects365_v1_00357438.jpg"]}],
    [{"text": "描述这张图片中的内容",
      "files": ["example_images/objects365_v1_00323167.jpg"]}]
]
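
# A minimal sketch (not wired into the UI) of driving model_inference() directly, e.g. as a
# local smoke test using the first entry in `examples` above. It assumes the example image
# exists on disk and that the @spaces.GPU decorator is a no-op outside a Hugging Face Space.
def _smoke_test():
    for partial in model_inference(examples[0][0], history=[]):
        print(partial)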

# Custom CSS (currently not applied; see the commented-out css= argument below)
custom_css = """
.container {
    max-width: 1200px;
    margin: 0 auto;
}

.title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 0.5em;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
}

.description {
    text-align: center;
    font-size: 1.1em;
    color: #666;
    margin-bottom: 2em;
}

.input-section {
    background: #f8f9fa;
    padding: 2em;
    border-radius: 12px;
    border: 1px solid #e0e0e0;
    margin-bottom: 2em;
}

.output-section {
    background: #ffffff;
    padding: 2em;
    border-radius: 12px;
    border: 2px solid #667eea;
    min-height: 200px;
}

.example-text {
    font-size: 0.95em;
    color: #555;
}
"""

demo = gr.ChatInterface(
    fn=model_inference,
    title="🎨 TinyMind 多模态AI助手",
    description="基于ONNX的高效多模态模型,支持图像理解与文本生成。上传图片并输入提示词,获得智能回答。模型只有90M,效果仅供娱乐(优化学习中)[详见知乎](https://zhuanlan.zhihu.com/p/1982031267370389720)",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="📝 输入提示词(建议以图片描述为主)",
        file_types=["image"],
        file_count="multiple",
        placeholder="例如:描述图片内容 / 这是什么? / 图中有哪些物体?"
    ),
    stop_btn="⏹️ 停止生成",
    multimodal=True,
    cache_examples=False,
    #css=custom_css,
)

demo.launch(debug=True, share=False)