Update app.py
Browse files
app.py
CHANGED
|
@@ -4,81 +4,83 @@ import os
|
|
| 4 |
from optimum.intel import OVModelForCausalLM
|
| 5 |
from transformers import AutoTokenizer, TextIteratorStreamer
|
| 6 |
from threading import Thread
|
| 7 |
-
import
|
| 8 |
|
| 9 |
-
# ---
|
| 10 |
-
# 8B 主模型
|
| 11 |
MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
|
| 12 |
-
# 0.5B 助手模型 (
|
| 13 |
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"
|
| 14 |
|
| 15 |
-
print("🚀
|
| 16 |
|
| 17 |
-
# --- 1. 加载模型 ---
|
| 18 |
try:
|
| 19 |
tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
|
| 20 |
|
| 21 |
-
print(f"Loading Main
|
| 22 |
model = OVModelForCausalLM.from_pretrained(
|
| 23 |
MAIN_MODEL_ID,
|
| 24 |
ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
|
| 25 |
)
|
| 26 |
|
| 27 |
-
print(f"Loading Draft
|
| 28 |
try:
|
| 29 |
draft_model = OVModelForCausalLM.from_pretrained(
|
| 30 |
DRAFT_MODEL_ID,
|
| 31 |
ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
|
| 32 |
)
|
| 33 |
-
print("✅
|
| 34 |
except Exception as e:
|
| 35 |
-
print(f"⚠️
|
| 36 |
draft_model = None
|
| 37 |
|
| 38 |
except Exception as e:
|
| 39 |
-
print(f"❌
|
| 40 |
model = None
|
| 41 |
tokenizer = None
|
| 42 |
|
| 43 |
-
# --- 2.
|
| 44 |
def parse_system_prompt(mode, text_content, json_file):
|
| 45 |
-
if mode == "文本模式
|
| 46 |
return text_content
|
| 47 |
-
elif mode == "JSON模式
|
| 48 |
if json_file is None:
|
| 49 |
return "You are a helpful assistant."
|
| 50 |
try:
|
| 51 |
with open(json_file, 'r', encoding='utf-8') as f:
|
| 52 |
data = json.load(f)
|
|
|
|
| 53 |
if isinstance(data, str): return data
|
| 54 |
return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
|
| 55 |
except:
|
| 56 |
-
return "Error parsing JSON"
|
| 57 |
return "You are a helpful assistant."
|
| 58 |
|
| 59 |
-
# --- 3. 核心生成逻辑 (适配
|
| 60 |
-
def
|
| 61 |
if model is None:
|
| 62 |
-
|
|
|
|
| 63 |
return
|
| 64 |
|
| 65 |
-
#
|
| 66 |
-
#
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
#
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
for user_msg, bot_msg in past_history:
|
| 75 |
-
messages.append({"role": "user", "content": user_msg})
|
| 76 |
-
messages.append({"role": "assistant", "content": bot_msg})
|
| 77 |
-
messages.append({"role": "user", "content": user_message})
|
| 78 |
-
|
| 79 |
# 3. 准备推理
|
| 80 |
-
|
| 81 |
-
inputs = tokenizer(
|
| 82 |
|
| 83 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 84 |
|
|
@@ -90,58 +92,68 @@ def generate_response(history, mode, prompt_text, prompt_json):
|
|
| 90 |
do_sample=True,
|
| 91 |
top_p=0.9,
|
| 92 |
)
|
| 93 |
-
|
| 94 |
-
# 投机采样注入
|
| 95 |
if draft_model is not None:
|
| 96 |
gen_kwargs["assistant_model"] = draft_model
|
| 97 |
|
|
|
|
| 98 |
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
| 99 |
thread.start()
|
| 100 |
|
| 101 |
-
#
|
|
|
|
|
|
|
|
|
|
| 102 |
partial_text = ""
|
| 103 |
for new_text in streamer:
|
| 104 |
partial_text += new_text
|
| 105 |
-
# 更新 history
|
| 106 |
-
history[-1][
|
| 107 |
yield history
|
| 108 |
|
| 109 |
-
# --- 4.
|
| 110 |
-
|
| 111 |
-
with gr.Blocks(title="Qwen Turbo CPU") as demo:
|
| 112 |
gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")
|
| 113 |
-
gr.Markdown("OpenVINO INT4 量化 + 投机采样 (Draft Model) 加速版")
|
| 114 |
|
| 115 |
with gr.Row():
|
| 116 |
with gr.Column(scale=1):
|
| 117 |
-
with gr.Accordion("🛠️
|
| 118 |
-
mode_radio = gr.Radio(["文本模式
|
| 119 |
sys_text = gr.Textbox(label="System Prompt", value="You are a helpful assistant.", lines=3)
|
| 120 |
sys_json = gr.File(label="JSON Config", file_types=[".json"], visible=False)
|
| 121 |
|
| 122 |
def update_vis(m):
|
| 123 |
-
return {sys_text: gr.update(visible=(m=="文本模式
|
| 124 |
mode_radio.change(update_vis, [mode_radio], [sys_text, sys_json])
|
| 125 |
|
| 126 |
with gr.Column(scale=3):
|
| 127 |
-
#
|
| 128 |
-
chatbot = gr.Chatbot(height=600, label="Qwen2.5-7B (Accel)")
|
| 129 |
msg = gr.Textbox(label="输入消息", placeholder="Enter 发送...")
|
|
|
|
| 130 |
with gr.Row():
|
| 131 |
submit_btn = gr.Button("发送", variant="primary")
|
| 132 |
clear_btn = gr.ClearButton([msg, chatbot])
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
| 138 |
|
| 139 |
-
#
|
| 140 |
-
|
| 141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
)
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
| 145 |
)
|
| 146 |
|
| 147 |
if __name__ == "__main__":
|
|
|
|
import time
from threading import Thread

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer

# --- Model configuration ---
# Main model: Qwen2.5-7B, INT4-quantized export for OpenVINO CPU inference.
MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
# Draft model: Qwen2.5-0.5B, used as the assistant model for speculative decoding.
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"

print("🚀 初始化引擎中...")
# --- 1. Load models (OpenVINO + speculative decoding) ---
# Pre-initialize so later `draft_model is not None` checks can never raise a
# NameError, even when the outer try fails before the draft model is attempted.
draft_model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)

    print(f"Loading Main: {MAIN_MODEL_ID}...")
    model = OVModelForCausalLM.from_pretrained(
        MAIN_MODEL_ID,
        ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
    )

    print(f"Loading Draft: {DRAFT_MODEL_ID}...")
    try:
        draft_model = OVModelForCausalLM.from_pretrained(
            DRAFT_MODEL_ID,
            ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
        )
        print("✅ 投机采样 (Speculative Decoding) 已激活")
    except Exception as e:
        # The draft model is optional: on failure fall back to plain decoding.
        print(f"⚠️ 助手模型加载失败,将使用普通模式: {e}")
        draft_model = None

except Exception as e:
    # Fatal: leave model/tokenizer as None so the chat handler can report
    # the failure to the user instead of crashing at import time.
    print(f"❌ 模型加载严重错误: {e}")
    model = None
    tokenizer = None
# --- 2. Helper: resolve the system prompt from the UI settings ---
def parse_system_prompt(mode, text_content, json_file):
    """Return the system prompt selected in the UI.

    Args:
        mode: UI radio value — "文本模式" (plain text) or "JSON模式" (JSON file).
        text_content: the prompt text used in text mode.
        json_file: path to an uploaded JSON config, or None.

    Returns:
        The system prompt string; a safe default when nothing usable is given,
        or an error message when the JSON file cannot be parsed.
    """
    if mode == "文本模式":
        return text_content
    if mode == "JSON模式":
        if json_file is None:
            return "You are a helpful assistant."
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Unreadable file or invalid JSON (json.JSONDecodeError is a
            # ValueError) — report instead of crashing the UI callback.
            return "Error parsing JSON file."
        # Accept several JSON layouts: a bare string, or a dict keyed by
        # system_prompt / system / prompt; anything else is stringified.
        if isinstance(data, str):
            return data
        if isinstance(data, dict):
            return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
        return str(data)
    return "You are a helpful assistant."
| 59 |
|
| 60 |
+
# --- 3. 核心生成逻辑 (适配 Messages 格式) ---
|
| 61 |
+
def chat_response(history, mode, prompt_text, prompt_json):
|
| 62 |
if model is None:
|
| 63 |
+
history.append({"role": "assistant", "content": "模型加载失败,请检查 Logs。"})
|
| 64 |
+
yield history
|
| 65 |
return
|
| 66 |
|
| 67 |
+
# history 现在的格式是:
|
| 68 |
+
# [{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'content': '...'}]
|
| 69 |
+
|
| 70 |
+
# 1. 获取用户最新的输入 (最后一条 user 消息)
|
| 71 |
+
# Gradio 的 type="messages" 会自动把用户输入加到 history 里传进来
|
| 72 |
+
# 所以我们不需要手动 history.append(user_input)
|
| 73 |
+
|
| 74 |
+
# 2. 构建推理用的 Prompt (在最前面插入 System Prompt)
|
| 75 |
+
system_prompt_content = parse_system_prompt(mode, prompt_text, prompt_json)
|
| 76 |
+
|
| 77 |
+
# 构建给模型看的 messages (临时列表,不影响 UI 显示)
|
| 78 |
+
model_messages = [{"role": "system", "content": system_prompt_content}]
|
| 79 |
+
model_messages.extend(history)
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# 3. 准备推理
|
| 82 |
+
input_text = tokenizer.apply_chat_template(model_messages, tokenize=False, add_generation_prompt=True)
|
| 83 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
| 84 |
|
| 85 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 86 |
|
|
|
|
| 92 |
do_sample=True,
|
| 93 |
top_p=0.9,
|
| 94 |
)
|
| 95 |
+
|
|
|
|
| 96 |
if draft_model is not None:
|
| 97 |
gen_kwargs["assistant_model"] = draft_model
|
| 98 |
|
| 99 |
+
# 4. 启动生成
|
| 100 |
thread = Thread(target=model.generate, kwargs=gen_kwargs)
|
| 101 |
thread.start()
|
| 102 |
|
| 103 |
+
# 5. UI 更新 (流式)
|
| 104 |
+
# 先添加一个空的 assistant 消息占位
|
| 105 |
+
history.append({"role": "assistant", "content": ""})
|
| 106 |
+
|
| 107 |
partial_text = ""
|
| 108 |
for new_text in streamer:
|
| 109 |
partial_text += new_text
|
| 110 |
+
# 更新 history 的最后一条消息
|
| 111 |
+
history[-1]['content'] = partial_text
|
| 112 |
yield history
|
| 113 |
|
# --- 4. Build the Gradio interface ---
with gr.Blocks(title="Qwen Turbo") as demo:
    gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("🛠️ 设置提示词", open=True):
                mode_radio = gr.Radio(["文本模式", "JSON模式"], label="模式", value="文本模式")
                sys_text = gr.Textbox(label="System Prompt", value="You are a helpful assistant.", lines=3)
                sys_json = gr.File(label="JSON Config", file_types=[".json"], visible=False)

            # Toggle which prompt input is visible for the chosen mode.
            def update_vis(m):
                is_text = (m == "文本模式")
                return {
                    sys_text: gr.update(visible=is_text),
                    sys_json: gr.update(visible=not is_text),
                }

            mode_radio.change(update_vis, [mode_radio], [sys_text, sys_json])

        with gr.Column(scale=3):
            # Explicit type="messages": history is a list of role/content dicts,
            # matching what chat_response expects.
            chatbot = gr.Chatbot(height=600, type="messages", label="Qwen2.5-7B (Accel)")
            msg = gr.Textbox(label="输入消息", placeholder="Enter 发送...")

            with gr.Row():
                submit_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.ClearButton([msg, chatbot])

    # --- Event bindings ---
    # Step 1: append the user's message to the history and clear the textbox;
    # step 2: stream the assistant reply via chat_response (a generator that
    # yields the updated history).
    def user_turn(user_message, history):
        return "", history + [{"role": "user", "content": user_message}]

    msg.submit(
        user_turn, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        chat_response, [chatbot, mode_radio, sys_text, sys_json], [chatbot]
    )

    submit_btn.click(
        user_turn, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        chat_response, [chatbot, mode_radio, sys_text, sys_json], [chatbot]
    )
|
| 158 |
|
| 159 |
if __name__ == "__main__":
|