wkplhc committed on
Commit
b63da4e
·
verified ·
1 Parent(s): bf18031

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -56
app.py CHANGED
@@ -4,81 +4,83 @@ import os
4
  from optimum.intel import OVModelForCausalLM
5
  from transformers import AutoTokenizer, TextIteratorStreamer
6
  from threading import Thread
7
- import gc
8
 
9
- # --- 配置区 ---
10
- # 8B 主模型
11
  MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
12
- # 0.5B 助手模型 (投机采样核心)
13
  DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"
14
 
15
- print("🚀 正在初始化极速推理引擎...")
16
 
17
- # --- 1. 加载模型 ---
18
  try:
19
  tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
20
 
21
- print(f"Loading Main Model: {MAIN_MODEL_ID}...")
22
  model = OVModelForCausalLM.from_pretrained(
23
  MAIN_MODEL_ID,
24
  ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
25
  )
26
 
27
- print(f"Loading Draft Model: {DRAFT_MODEL_ID}...")
28
  try:
29
  draft_model = OVModelForCausalLM.from_pretrained(
30
  DRAFT_MODEL_ID,
31
  ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""},
32
  )
33
- print("✅ 投机采样加速已激活 (Main + Draft)")
34
  except Exception as e:
35
- print(f"⚠️ 助手模型加载失败,降级为普通推理: {e}")
36
  draft_model = None
37
 
38
  except Exception as e:
39
- print(f"❌ 致命错误: {e}")
40
  model = None
41
  tokenizer = None
42
 
43
- # --- 2. 辅助函数 ---
44
  def parse_system_prompt(mode, text_content, json_file):
45
- if mode == "文本模式 (Text)":
46
  return text_content
47
- elif mode == "JSON模式 (File)":
48
  if json_file is None:
49
  return "You are a helpful assistant."
50
  try:
51
  with open(json_file, 'r', encoding='utf-8') as f:
52
  data = json.load(f)
 
53
  if isinstance(data, str): return data
54
  return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
55
  except:
56
- return "Error parsing JSON"
57
  return "You are a helpful assistant."
58
 
59
- # --- 3. 核心生成逻辑 (适配 Tuple 历史格式) ---
60
- def generate_response(history, mode, prompt_text, prompt_json):
61
  if model is None:
62
- yield history + [["", "模型加载失败"]]
 
63
  return
64
 
65
- # 1. 提取当前问题和历史
66
- # Gradio Tuple 格式: [[q1, a1], [q2, a2], [curr_q, None]]
67
- user_message = history[-1][0]
68
- past_history = history[:-1]
69
-
70
- # 2. 构建 Prompt
71
- system_prompt = parse_system_prompt(mode, prompt_text, prompt_json)
72
- messages = [{"role": "system", "content": system_prompt}]
 
 
 
 
 
73
 
74
- for user_msg, bot_msg in past_history:
75
- messages.append({"role": "user", "content": user_msg})
76
- messages.append({"role": "assistant", "content": bot_msg})
77
- messages.append({"role": "user", "content": user_message})
78
-
79
  # 3. 准备推理
80
- text_input = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
81
- inputs = tokenizer(text_input, return_tensors="pt")
82
 
83
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
84
 
@@ -90,58 +92,68 @@ def generate_response(history, mode, prompt_text, prompt_json):
90
  do_sample=True,
91
  top_p=0.9,
92
  )
93
-
94
- # 投机采样注入
95
  if draft_model is not None:
96
  gen_kwargs["assistant_model"] = draft_model
97
 
 
98
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
99
  thread.start()
100
 
101
- # 4. 流式更新
 
 
 
102
  partial_text = ""
103
  for new_text in streamer:
104
  partial_text += new_text
105
- # 更新 history 的最后一条记录
106
- history[-1][1] = partial_text
107
  yield history
108
 
109
- # --- 4. 构建 UI (修复版) ---
110
- # 移除 theme 参数放在 Blocks 初始化里,部分旧版本不支持
111
- with gr.Blocks(title="Qwen Turbo CPU") as demo:
112
  gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")
113
- gr.Markdown("OpenVINO INT4 量化 + 投机采样 (Draft Model) 加速版")
114
 
115
  with gr.Row():
116
  with gr.Column(scale=1):
117
- with gr.Accordion("🛠️ 提示词设置", open=True):
118
- mode_radio = gr.Radio(["文本模式 (Text)", "JSON模式 (File)"], label="模式", value="文本模式 (Text)")
119
  sys_text = gr.Textbox(label="System Prompt", value="You are a helpful assistant.", lines=3)
120
  sys_json = gr.File(label="JSON Config", file_types=[".json"], visible=False)
121
 
122
  def update_vis(m):
123
- return {sys_text: gr.update(visible=(m=="文本模式 (Text)")), sys_json: gr.update(visible=(m!="文本模式 (Text)"))}
124
  mode_radio.change(update_vis, [mode_radio], [sys_text, sys_json])
125
 
126
  with gr.Column(scale=3):
127
- # 关键修复:移除 type="messages",使用默认的 Tuple 格式
128
- chatbot = gr.Chatbot(height=600, label="Qwen2.5-7B (Accel)")
129
  msg = gr.Textbox(label="输入消息", placeholder="Enter 发送...")
 
130
  with gr.Row():
131
  submit_btn = gr.Button("发送", variant="primary")
132
  clear_btn = gr.ClearButton([msg, chatbot])
133
 
134
- # 事件处理 logic (适配 Tuple)
135
- def user_fn(user_message, history):
136
- # 用户发言时,追加 [msg, None] 到历史
137
- return "", history + [[user_message, None]]
 
138
 
139
- # 绑定回车和点击
140
- msg.submit(user_fn, [msg, chatbot], [msg, chatbot], queue=False).then(
141
- generate_response, [chatbot, mode_radio, sys_text, sys_json], [chatbot]
 
 
 
 
142
  )
143
- submit_btn.click(user_fn, [msg, chatbot], [msg, chatbot], queue=False).then(
144
- generate_response, [chatbot, mode_radio, sys_text, sys_json], [chatbot]
 
 
 
145
  )
146
 
147
  if __name__ == "__main__":
 
4
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer
from threading import Thread
import time

# --- Model configuration ---
# Main model: INT4-quantized Qwen2.5-7B exported for OpenVINO.
MAIN_MODEL_ID = "OpenVINO/Qwen2.5-7B-Instruct-int4-ov"
# Small 0.5B draft model used as the assistant for speculative decoding.
DRAFT_MODEL_ID = "hsuwill000/Qwen2.5-0.5B-Instruct-openvino-4bit"

print("🚀 初始化引擎中...")

# Shared OpenVINO compile options: optimize for latency on a single stream.
_OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}

# --- 1. Load models (OpenVINO + speculative decoding) ---
try:
    tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)

    print(f"Loading Main: {MAIN_MODEL_ID}...")
    model = OVModelForCausalLM.from_pretrained(
        MAIN_MODEL_ID,
        ov_config=dict(_OV_CONFIG),
    )

    print(f"Loading Draft: {DRAFT_MODEL_ID}...")
    try:
        draft_model = OVModelForCausalLM.from_pretrained(
            DRAFT_MODEL_ID,
            ov_config=dict(_OV_CONFIG),
        )
        print("✅ 投机采样 (Speculative Decoding) 已激活")
    except Exception as load_err:
        # Draft model is optional: fall back to plain (non-speculative) decoding.
        print(f"⚠️ 助手模型加载失败,将使用普通模式: {load_err}")
        draft_model = None

except Exception as load_err:
    # Without the main model the app cannot serve requests; downstream code
    # checks `model is None` and reports the failure to the user.
    print(f"❌ 模型加载严重错误: {load_err}")
    model = None
    tokenizer = None
42
 
43
# --- 2. Helper: resolve the system prompt ---
def parse_system_prompt(mode, text_content, json_file):
    """Resolve the system prompt from the UI controls.

    Args:
        mode: Radio selection; "文本模式" uses *text_content* directly,
            "JSON模式" reads the prompt from the uploaded *json_file*.
        text_content: Raw prompt text from the System Prompt textbox.
        json_file: Path to an uploaded JSON config file, or None.

    Returns:
        The system prompt string. Falls back to a generic assistant prompt
        when no file is provided or the mode is unrecognized, and to an
        error string when the file cannot be read or parsed.
    """
    if mode == "文本模式":
        return text_content
    elif mode == "JSON模式":
        if json_file is None:
            return "You are a helpful assistant."
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Accept several JSON layouts: a bare string, or a dict with
            # one of the common prompt keys.
            if isinstance(data, str):
                return data
            if isinstance(data, dict):
                return data.get("system_prompt") or data.get("system") or data.get("prompt") or str(data)
            # Lists/numbers used to raise AttributeError on .get() and were
            # swallowed by a bare except; stringify them instead.
            return str(data)
        except (OSError, ValueError):  # ValueError covers json.JSONDecodeError
            return "Error parsing JSON file."
    return "You are a helpful assistant."
59
 
60
+ # --- 3. 核心生成逻辑 (适配 Messages 格式) ---
61
+ def chat_response(history, mode, prompt_text, prompt_json):
62
  if model is None:
63
+ history.append({"role": "assistant", "content": "模型加载失败,请检查 Logs。"})
64
+ yield history
65
  return
66
 
67
+ # history 现在的格式是:
68
+ # [{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'content': '...'}]
69
+
70
+ # 1. 获取用户最新的输入 (最后一条 user 消息)
71
+ # Gradio 的 type="messages" 会自动把用户输入加到 history 里传进来
72
+ # 所以我们不需要手动 history.append(user_input)
73
+
74
+ # 2. 构建推理用的 Prompt (在最前面插入 System Prompt)
75
+ system_prompt_content = parse_system_prompt(mode, prompt_text, prompt_json)
76
+
77
+ # 构建给模型看的 messages (临时列表,不影响 UI 显示)
78
+ model_messages = [{"role": "system", "content": system_prompt_content}]
79
+ model_messages.extend(history)
80
 
 
 
 
 
 
81
  # 3. 准备推理
82
+ input_text = tokenizer.apply_chat_template(model_messages, tokenize=False, add_generation_prompt=True)
83
+ inputs = tokenizer(input_text, return_tensors="pt")
84
 
85
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
86
 
 
92
  do_sample=True,
93
  top_p=0.9,
94
  )
95
+
 
96
  if draft_model is not None:
97
  gen_kwargs["assistant_model"] = draft_model
98
 
99
+ # 4. 启动生成
100
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
101
  thread.start()
102
 
103
+ # 5. UI 更新 (流式)
104
+ # 先添加一个空的 assistant 消息占位
105
+ history.append({"role": "assistant", "content": ""})
106
+
107
  partial_text = ""
108
  for new_text in streamer:
109
  partial_text += new_text
110
+ # 更新 history 的最后一条消息
111
+ history[-1]['content'] = partial_text
112
  yield history
113
 
114
# --- 4. Build the Gradio interface ---
with gr.Blocks(title="Qwen Turbo") as demo:
    gr.Markdown("## ⚡ Qwen OpenVINO + Speculative Decoding")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Accordion("🛠️ 设置提示词", open=True):
                mode_radio = gr.Radio(["文本模式", "JSON模式"], label="模式", value="文本模式")
                sys_text = gr.Textbox(label="System Prompt", value="You are a helpful assistant.", lines=3)
                sys_json = gr.File(label="JSON Config", file_types=[".json"], visible=False)

                def update_vis(selected_mode):
                    # Show the textbox in text mode, the file picker otherwise.
                    text_on = selected_mode == "文本模式"
                    return {
                        sys_text: gr.update(visible=text_on),
                        sys_json: gr.update(visible=not text_on),
                    }

                mode_radio.change(update_vis, [mode_radio], [sys_text, sys_json])

        with gr.Column(scale=3):
            # The chatbot explicitly uses the "messages" history format
            # ([{"role": ..., "content": ...}, ...]).
            chatbot = gr.Chatbot(height=600, type="messages", label="Qwen2.5-7B (Accel)")
            msg = gr.Textbox(label="输入消息", placeholder="Enter 发送...")

            with gr.Row():
                submit_btn = gr.Button("发送", variant="primary")
                clear_btn = gr.ClearButton([msg, chatbot])

    # --- Event wiring ---

    def user_turn(user_message, history):
        # Record the user's message in the history and clear the input box.
        return "", history + [{"role": "user", "content": user_message}]

    # The Enter key and the send button share the same two-step pipeline:
    # first record the user turn, then stream the assistant reply.
    for trigger in (msg.submit, submit_btn.click):
        trigger(
            user_turn, [msg, chatbot], [msg, chatbot], queue=False
        ).then(
            chat_response, [chatbot, mode_radio, sys_text, sys_json], [chatbot]
        )
158
 
159
  if __name__ == "__main__":