wkplhc committed
Commit a354de5 · verified · 1 Parent(s): b63da4e

Update app.py

Files changed (1)
  1. app.py +65 -78
app.py CHANGED
@@ -6,81 +6,73 @@ from transformers import AutoTokenizer, TextIteratorStreamer
from threading import Thread
import time

-# --- Model configuration ---
-# 7B main model (INT4-quantized)
+# --- Model configuration (unchanged; the logs show these load successfully) ---
MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
-# 0.5B draft model (speeds up decoding via speculative sampling)
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"

-print("🚀 Initializing engine...")
+print("🚀 Starting engine...")

-# --- 1. Load models (OpenVINO + speculative decoding) ---
+# --- 1. Load models ---
try:
    tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
-
-    print(f"Loading Main: {MAIN_MODEL_ID}...")
    model = OVModelForCausalLM.from_pretrained(
        MAIN_MODEL_ID,
        ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
    )

-    print(f"Loading Draft: {DRAFT_MODEL_ID}...")
    try:
        draft_model = OVModelForCausalLM.from_pretrained(
            DRAFT_MODEL_ID,
            ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
        )
        print("✅ Speculative decoding enabled")
-    except Exception as e:
-        print(f"⚠️ Draft model failed to load; falling back to plain decoding: {e}")
+    except:
        draft_model = None
+        print("⚠️ Running with the main model only")

except Exception as e:
-    print(f"❌ Fatal error while loading models: {e}")
+    print(f"❌ Loading failed: {e}")
    model = None
-    tokenizer = None

-# --- 2. Helper: parse the system prompt ---
+# --- 2. Helper functions ---
def parse_system_prompt(mode, text_content, json_file):
    if mode == "Text Mode":
        return text_content
    elif mode == "JSON Mode":
-        if json_file is None:
-            return "You are a helpful assistant."
+        if json_file is None: return "You are a helpful assistant."
        try:
-            with open(json_file, 'r', encoding='utf-8') as f:
+            with open(json_file.name, 'r', encoding='utf-8') as f:
                data = json.load(f)
-            # Accept several JSON layouts
            if isinstance(data, str): return data
            return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
        except:
-            return "Error parsing JSON file."
+            return "Error parsing JSON."
    return "You are a helpful assistant."

-# --- 3. Core generation logic (messages format) ---
-def chat_response(history, mode, prompt_text, prompt_json):
+# --- 3. Core logic (compatible with the legacy Gradio tuple format) ---
+def predict(message, history, mode, prompt_text, prompt_json):
+    # history format: [[User1, Bot1], [User2, Bot2]]
+    # message: the current user input (str)
+
    if model is None:
-        history.append({"role": "assistant", "content": "Model failed to load; check the Logs."})
-        yield history
+        yield history + [[message, "Model failed to load"]]
        return

-    # history now has the format:
-    # [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': '...'}]
-
-    # 1. Get the user's latest input (the last user message)
-    # Gradio type="messages" appends the user input to history automatically,
-    # so there is no need for a manual history.append(user_input)
-
-    # 2. Build the inference prompt (insert the system prompt at the front)
-    system_prompt_content = parse_system_prompt(mode, prompt_text, prompt_json)
-
-    # Build the messages the model sees (a temporary list; does not affect the UI)
-    model_messages = [{"role": "system", "content": system_prompt_content}]
-    model_messages.extend(history)
-
-    # 3. Prepare inference
-    input_text = tokenizer.apply_chat_template(model_messages, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer(input_text, return_tensors="pt")
+    # 1. Resolve the system prompt
+    sys_prompt = parse_system_prompt(mode, prompt_text, prompt_json)
+
+    # 2. Convert the tuple history into the list of dicts the model expects
+    model_inputs = [{"role": "system", "content": sys_prompt}]
+
+    for user_msg, bot_msg in history:
+        model_inputs.append({"role": "user", "content": user_msg})
+        model_inputs.append({"role": "assistant", "content": bot_msg})
+
+    model_inputs.append({"role": "user", "content": message})
+
+    # 3. Build the model input
+    text = tokenizer.apply_chat_template(model_inputs, tokenize=False, add_generation_prompt=True)
+    inputs = tokenizer(text, return_tensors="pt")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

@@ -96,65 +88,60 @@ def chat_response(history, mode, prompt_text, prompt_json):
    if draft_model is not None:
        gen_kwargs["assistant_model"] = draft_model

-    # 4. Start generation
-    thread = Thread(target=model.generate, kwargs=gen_kwargs)
-    thread.start()
+    # 4. Generate on a background thread
+    t = Thread(target=model.generate, kwargs=gen_kwargs)
+    t.start()

-    # 5. UI update (streaming)
-    # First append an empty assistant message as a placeholder
-    history.append({"role": "assistant", "content": ""})
-
+    # 5. Stream the output in the Chatbot's format
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
-        # Update the last message in history
-        history[-1]['content'] = partial_text
-        yield history
+        # The yielded value must be the full history list, i.e.:
+        # [[old_u, old_b], ..., [current_u, current_partial_b]]
+        yield history + [[message, partial_text]]

-# --- 4. Build the interface ---
-with gr.Blocks(title="Qwen Turbo") as demo:
+# --- 4. Build the UI ---
+with gr.Blocks(title="Qwen Extreme") as demo:
    gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")

    with gr.Row():
        with gr.Column(scale=1):
-            with gr.Accordion("🛠️ Prompt settings", open=True):
-                mode_radio = gr.Radio(["Text Mode", "JSON Mode"], label="Mode", value="Text Mode")
-                sys_text = gr.Textbox(label="System Prompt", value="You are a helpful assistant.", lines=3)
-                sys_json = gr.File(label="JSON Config", file_types=[".json"], visible=False)
+            with gr.Accordion("Settings", open=True):
+                mode = gr.Radio(["Text Mode", "JSON Mode"], value="Text Mode", label="Prompt mode")
+                p_text = gr.Textbox(value="You are a helpful assistant.", lines=3, label="System Prompt")
+                p_json = gr.File(label="JSON file", file_types=[".json"], visible=False)

-            def update_vis(m):
-                return {sys_text: gr.update(visible=(m=="Text Mode")), sys_json: gr.update(visible=(m!="Text Mode"))}
-            mode_radio.change(update_vis, [mode_radio], [sys_text, sys_json])
+            def toggle(m):
+                return {p_text: gr.update(visible=m=="Text Mode"), p_json: gr.update(visible=m=="JSON Mode")}
+            mode.change(toggle, mode, [p_text, p_json])

        with gr.Column(scale=3):
-            # Key point: type="messages" is set explicitly here
-            chatbot = gr.Chatbot(height=600, type="messages", label="Qwen2.5-7B (Accel)")
-            msg = gr.Textbox(label="Message", placeholder="Press Enter to send...")
-
+            # Key change: type="messages" removed; the default tuple format is safe here
+            chatbot = gr.Chatbot(height=600, label="Qwen2.5-7B")
+            msg = gr.Textbox(label="Input")
            with gr.Row():
-                submit_btn = gr.Button("Send", variant="primary")
-                clear_btn = gr.ClearButton([msg, chatbot])
-
-    # --- Event bindings (core fix) ---
-
-    # 1. User input: append the user message to history and clear the textbox
-    def user_turn(user_message, history):
-        return "", history + [{"role": "user", "content": user_message}]
-
-    # 2. Bot reply: call the generation function
-    # Note: chat_response yields the updated history
+                btn = gr.Button("Send", variant="primary")
+                clear = gr.ClearButton([msg, chatbot])
+
+    # Event bindings (quick-and-dirty version)
+    # On send:
+    # 1. Call predict with msg and chatbot (i.e. the history)
+    # 2. Write predict's output (the new history) back to chatbot
+    # 3. Clear msg

-    msg.submit(
-        user_turn, [msg, chatbot], [msg, chatbot], queue=False
-    ).then(
-        chat_response, [chatbot, mode_radio, sys_text, sys_json], [chatbot]
+    submit_event = msg.submit(
+        predict,
+        inputs=[msg, chatbot, mode, p_text, p_json],
+        outputs=[chatbot]
    )
+    msg.submit(lambda: "", None, msg)  # clear the textbox

-    submit_btn.click(
-        user_turn, [msg, chatbot], [msg, chatbot], queue=False
-    ).then(
-        chat_response, [chatbot, mode_radio, sys_text, sys_json], [chatbot]
+    btn_event = btn.click(
+        predict,
+        inputs=[msg, chatbot, mode, p_text, p_json],
+        outputs=[chatbot]
    )
+    btn.click(lambda: "", None, msg)  # clear the textbox

if __name__ == "__main__":
    demo.queue().launch()
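
For reference, the speed-up path this commit keeps relies on transformers' assisted generation: a small draft model is passed to `generate()` via the `assistant_model` keyword, exactly as `app.py` does with `gen_kwargs["assistant_model"] = draft_model`. A minimal standalone sketch, assuming the same two model IDs and an environment with `optimum-intel` installed:

```python
# Sketch of assisted (speculative) decoding with the commit's model pair.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"

tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
model = OVModelForCausalLM.from_pretrained(MAIN_MODEL_ID)
draft_model = OVModelForCausalLM.from_pretrained(DRAFT_MODEL_ID)

inputs = tokenizer("Explain speculative decoding in one sentence.", return_tensors="pt")

# The draft model cheaply proposes several tokens per step; the main model
# verifies them in a single forward pass, so the output matches what the
# main model would produce on its own, only faster on average.
output = model.generate(**inputs, assistant_model=draft_model, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```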
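
The streaming side is the standard `TextIteratorStreamer` pattern: `generate()` blocks until completion, so it runs on a worker thread while the caller iterates over the streamer. A condensed sketch, reusing `model`, `tokenizer`, and `inputs` from the snippet above:

```python
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=256)

# generate() blocks, so it runs in the background; the streamer then yields
# decoded text chunks on this thread as tokens are produced.
Thread(target=model.generate, kwargs=gen_kwargs).start()

partial_text = ""
for new_text in streamer:
    partial_text += new_text  # in app.py, each step is yielded to the Chatbot
print(partial_text)
```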
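
The substance of the change is the chat-history format: with `type="messages"` removed, Gradio hands `predict` a list of `[user, bot]` pairs, which must be rebuilt into the role/content dicts that `apply_chat_template` expects. A standalone sketch of that conversion (`to_messages` is a hypothetical name, not part of the commit):

```python
# Hypothetical helper mirroring the tuple-to-messages conversion in predict().
def to_messages(history, message, sys_prompt="You are a helpful assistant."):
    msgs = [{"role": "system", "content": sys_prompt}]
    for user_msg, bot_msg in history:  # history: [[user1, bot1], [user2, bot2], ...]
        msgs.append({"role": "user", "content": user_msg})
        msgs.append({"role": "assistant", "content": bot_msg})
    msgs.append({"role": "user", "content": message})  # current turn, no reply yet
    return msgs

print(to_messages([["Hi", "Hello!"]], "How are you?"))
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Hi'},
#  {'role': 'assistant', 'content': 'Hello!'},
#  {'role': 'user', 'content': 'How are you?'}]
```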