Jiaqi-hkust commited on
Commit
685215b
·
verified ·
1 Parent(s): 46b4550

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +56 -99
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import os
3
  import torch
4
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
@@ -83,18 +84,19 @@ class ModelHandler:
83
 
84
 
85
  def predict(self, messages, temperature, max_tokens):
86
- # 【修复3】简化 predict 逻辑
87
- # 不要在函数内部重新解析 history,而是直接接收处理好的标准 messages 列表
88
 
89
- # system prompt 拼接到最后一条用户消息中
90
- if messages and messages[-1]["role"] == "user":
91
- content = messages[-1]["content"]
 
 
 
 
92
  sys_prompt_fmt = "\n" + " ".join(sys_prompt.split())
93
 
94
- if isinstance(content, str):
95
- messages[-1]["content"] += sys_prompt_fmt
96
- elif isinstance(content, list):
97
- # 如果是多模态列表,找到 text 部分追加
98
  text_found = False
99
  for item in content:
100
  if item.get("type") == "text":
@@ -103,14 +105,14 @@ class ModelHandler:
103
  break
104
  if not text_found:
105
  content.append({"type": "text", "text": sys_prompt_fmt})
 
 
106
 
107
- print(f"Total messages for model: {len(messages)}")
108
-
109
  text_prompt = self.processor.apply_chat_template(
110
- messages, tokenize=False, add_generation_prompt=True
111
  )
112
-
113
- image_inputs, video_inputs = process_vision_info(messages)
114
 
115
  inputs = self.processor(
116
  text=[text_prompt],
@@ -141,7 +143,7 @@ class ModelHandler:
141
  generated_ids,
142
  skip_special_tokens=True
143
  )
144
-
145
  if generated_text:
146
  yield generated_text
147
  else:
@@ -164,106 +166,59 @@ def get_model_handler():
164
  model_handler = ModelHandler(MODEL_PATH)
165
  return model_handler
166
 
167
- def _history_to_messages(history):
168
- """
169
- 【新增辅助函数】学习参考代码的逻辑:
170
- 将 Gradio 的 Tuple 历史 [[user_msg, bot_msg], ...]
171
- 转换为 Model 需要的 OpenAI 格式 [{'role': 'user', ...}, ...]
172
- """
173
- messages = []
174
-
175
- for pair in history:
176
- user_msg, bot_msg = pair
177
-
178
- # --- 1. 处理用户消息 ---
179
- if user_msg:
180
- # 判断是否为图片路径(简单判断)
181
- is_image = False
182
- if isinstance(user_msg, str):
183
- if os.path.exists(user_msg) or any(user_msg.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.webp']):
184
- is_image = True
185
-
186
- if is_image:
187
- # 图片消息
188
- messages.append({
189
- "role": "user",
190
- "content": [{"type": "image", "image": user_msg}]
191
- })
192
- else:
193
- # 文本消息
194
- messages.append({
195
- "role": "user",
196
- "content": [{"type": "text", "text": str(user_msg)}]
197
- })
198
-
199
- # --- 2. 处理机器人消息 ---
200
- if bot_msg:
201
- messages.append({
202
- "role": "assistant",
203
- "content": [{"type": "text", "text": str(bot_msg)}]
204
- })
205
-
206
- return messages
207
-
208
  @gpu_decorator
209
  def respond(user_msg, history, temp, tokens):
210
  """
211
- 【修复4】完全重写 respond 函数
212
- 解决 UI (Tuple) Model (Dict) 格式不兼容的问题
213
  """
214
 
215
- # 1. 解析 Gradio MultimodalTextbox 的输入
216
- files = user_msg.get("files", [])
217
- text = user_msg.get("text", "")
218
 
219
- # 2. 更新 UI 历史 (Tuple 格式)
220
- # -------------------------------------------------
221
- # 先把图片加进历史
222
  for f in files:
223
- history.append([f, None])
224
-
225
- # 再把文本加进历史
 
 
226
  if text:
227
- history.append([text, None])
228
- elif not text and files:
229
- # 如果只有图没字,也需要触发回复,不做操作,history 已经有了图片项
230
- pass
231
 
232
- # 立即 yield,让用户先看到自己的输入
233
- yield history, gr.MultimodalTextbox(value=None, interactive=False)
 
 
 
 
 
 
 
 
 
234
 
235
- # 3. 准备模型输入 (Dict 格式)
236
- # -------------------------------------------------
 
 
237
  try:
238
  handler = get_model_handler()
239
 
240
- # 将当前的 Tuple 历史转换为 Model 能读懂的 Messages 列表
241
- messages = _history_to_messages(history)
242
 
243
- # 在 UI 历史上预留一个空位给机器人回复
244
- # 如果最后一条只有 User 内容 (例如 [text, None]),我们填充它
245
- if history and history[-1][1] is None:
246
- history[-1][1] = ""
247
- else:
248
- history.append([None, ""]) # 万一没对应上,追加一个空回复
249
-
250
- # 4. 调用模型流式生成
251
  full_response = ""
252
- for chunk in handler.predict(messages, temp, tokens):
 
253
  full_response += chunk
254
- # 实时更新 UI 历史的最后一条
255
- history[-1][1] = full_response
256
  yield history, gr.MultimodalTextbox(interactive=False)
257
 
258
  except Exception as e:
259
  import traceback
260
  traceback.print_exc()
261
- # 错误处理
262
- err_msg = f"❌ Error: {str(e)}"
263
- if history and history[-1][1] is None:
264
- history[-1][1] = err_msg
265
- else:
266
- history.append([None, err_msg])
267
  yield history, gr.MultimodalTextbox(interactive=True)
268
 
269
  # 恢复输入框
@@ -275,18 +230,20 @@ def create_chat_ui():
275
  #chatbot { height: 650px !important; overflow-y: auto; }
276
  """
277
 
278
- with gr.Blocks(title="Robust-R1") as demo:
279
 
280
  with gr.Row():
281
- gr.Markdown("# 🤖Robust-R1:Degradation-Aware Reasoning for Robust Visual Understanding")
282
 
283
  with gr.Row():
284
  with gr.Column(scale=4):
 
285
  chatbot = gr.Chatbot(
286
  elem_id="chatbot",
287
  label="Chat",
288
  avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
289
- height=650
 
290
  )
291
 
292
  chat_input = gr.MultimodalTextbox(
@@ -342,8 +299,8 @@ def create_chat_ui():
342
  outputs=[chatbot, chat_input]
343
  )
344
 
345
- def clear_history(): return [], None
346
- clear_btn.click(clear_history, outputs=[chatbot, chat_input])
347
 
348
  return demo
349
 
 
1
  import gradio as gr
2
+ print(f"当前使用的 Gradio 版本是: {gr.__version__}")
3
  import os
4
  import torch
5
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 
84
 
85
 
86
  def predict(self, messages, temperature, max_tokens):
87
+ # 注意:这里接收到的 messages 已经是标准的 [{'role': 'user', 'content': [...]}, ...]
 
88
 
89
+ # 我们需要做一个深拷贝,避免修改 UI 上的 history 显示 System Prompt
90
+ import copy
91
+ messages_payload = copy.deepcopy(messages)
92
+
93
+ # 拼接 System Prompt
94
+ if messages_payload and messages_payload[-1]["role"] == "user":
95
+ content = messages_payload[-1]["content"]
96
  sys_prompt_fmt = "\n" + " ".join(sys_prompt.split())
97
 
98
+ # 现在的 content 肯定是 list (因为我们上面的 respond 函数构建的是 list)
99
+ if isinstance(content, list):
 
 
100
  text_found = False
101
  for item in content:
102
  if item.get("type") == "text":
 
105
  break
106
  if not text_found:
107
  content.append({"type": "text", "text": sys_prompt_fmt})
108
+ elif isinstance(content, str):
109
+ messages_payload[-1]["content"] += sys_prompt_fmt
110
 
111
+ # 后续逻辑保持不变 ...
 
112
  text_prompt = self.processor.apply_chat_template(
113
+ messages_payload, tokenize=False, add_generation_prompt=True
114
  )
115
+ image_inputs, video_inputs = process_vision_info(messages_payload)
 
116
 
117
  inputs = self.processor(
118
  text=[text_prompt],
 
143
  generated_ids,
144
  skip_special_tokens=True
145
  )
146
+ print(f"Generated text: {generated_text}")
147
  if generated_text:
148
  yield generated_text
149
  else:
 
166
  model_handler = ModelHandler(MODEL_PATH)
167
  return model_handler
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  @gpu_decorator
170
  def respond(user_msg, history, temp, tokens):
171
  """
172
+ 针对 type="messages" 的 Chatbot 重写的响应函数
173
+ history 现在的格式直接是: [{'role': 'user', 'content': ...}, {'role': 'assistant', ...}]
174
  """
175
 
176
+ # 1. 构建当前用户的消息内容 (OpenAI 多模态格式)
177
+ user_content = []
 
178
 
179
+ # 处理图片/文件
180
+ files = user_msg.get("files", [])
 
181
  for f in files:
182
+ # qwen_vl_utils 识别 "image" 字段作为本地路径
183
+ user_content.append({"type": "image", "image": f})
184
+
185
+ # 处理文本
186
+ text = user_msg.get("text", "")
187
  if text:
188
+ user_content.append({"type": "text", "text": text})
 
 
 
189
 
190
+ # 如果既没图也没字,直接返回
191
+ if not user_content:
192
+ yield history, gr.MultimodalTextbox(value=None, interactive=True)
193
+ return
194
+
195
+ # 2. 将用户消息加入历史
196
+ # 注意:这里直接 append 一个 dict,而不是 tuple
197
+ history.append({
198
+ "role": "user",
199
+ "content": user_content
200
+ })
201
 
202
+ # 立即更新 UI,让用户看到自己的输入(图文会在同一个气泡里)
203
+ yield history, gr.MultimodalTextbox(value=None, interactive=False)
204
+
205
+ # 3. 调用模型
206
  try:
207
  handler = get_model_handler()
208
 
209
+ history.append({"role": "assistant", "content": ""})
 
210
 
 
 
 
 
 
 
 
 
211
  full_response = ""
212
+ # 调用你的 handler.predict (注意:你需要稍微调整 handler.predict 里的 sys_prompt 处理逻辑,见下文建议)
213
+ for chunk in handler.predict(history[:-1], temp, tokens): # 传入除最后一条空回复外的历史
214
  full_response += chunk
215
+ history[-1]["content"] = full_response
 
216
  yield history, gr.MultimodalTextbox(interactive=False)
217
 
218
  except Exception as e:
219
  import traceback
220
  traceback.print_exc()
221
+ history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
 
 
 
 
 
222
  yield history, gr.MultimodalTextbox(interactive=True)
223
 
224
  # 恢复输入框
 
230
  #chatbot { height: 650px !important; overflow-y: auto; }
231
  """
232
 
233
+ with gr.Blocks(title="Robust-R1", css=custom_css) as demo:
234
 
235
  with gr.Row():
236
+ gr.Markdown("# 🤖 Robust-R1: Degradation-Aware Reasoning")
237
 
238
  with gr.Row():
239
  with gr.Column(scale=4):
240
+ # 【关键修改】添加 type="messages"
241
  chatbot = gr.Chatbot(
242
  elem_id="chatbot",
243
  label="Chat",
244
  avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
245
+ height=650,
246
+ type="messages" # <--- 这里是重点!
247
  )
248
 
249
  chat_input = gr.MultimodalTextbox(
 
299
  outputs=[chatbot, chat_input]
300
  )
301
 
302
+ # 清空历史只需要返回空列表 []
303
+ clear_btn.click(lambda: ([], None), outputs=[chatbot, chat_input])
304
 
305
  return demo
306