Jiaqi-hkust committed
Commit 78ed009 · verified · 1 Parent(s): 8aa41c0

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +50 -38
app.py CHANGED
@@ -28,7 +28,7 @@ def gpu_decorator(func):
 # Note: in a ZeroGPU environment, CUDA may not yet be available at startup
 # Whether to use flash-attn is decided at model load time, based on actual CUDA availability
 
-sys_prompt = """First output the the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
+sys_prompt = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
 and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
 then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
 and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
@@ -188,7 +188,6 @@ class ModelHandler:
 
 model_handler = None
 
-@gpu_decorator  # mark this function as requiring the GPU
 def get_model_handler():
     """Get model handler with lazy loading"""
     global model_handler
@@ -197,6 +196,53 @@ def get_model_handler():
         model_handler = ModelHandler(MODEL_PATH)
     return model_handler
 
+@gpu_decorator
+async def respond(user_msg, history, temp, tokens):
+    text = user_msg.get("text", "").strip()
+    files = user_msg.get("files", [])
+
+    # ### <<< Change 3: build the correct multimodal message format
+    # Don't append raw path strings; wrap each file as {"type": "image", "image": path}
+    user_content = []
+    for file_path in files:
+        user_content.append({"type": "image", "image": file_path})
+
+    if text:
+        user_content.append({"type": "text", "text": text})
+
+    # Build a user message that matches the type="messages" format
+    user_message = {"role": "user", "content": user_content}
+
+    history.append(user_message)
+    # Yield once here so the user sees their own input immediately
+    yield history, gr.MultimodalTextbox(value=None, interactive=False)
+
+    history.append({"role": "assistant", "content": ""})
+
+    try:
+        # Keep only earlier turns (exclude the current one to avoid duplication)
+        previous_history = history[:-2] if len(history) >= 2 else []
+
+        # Call the handler here; under @gpu_decorator we can access the GPU
+        handler = get_model_handler()
+
+        generated_text = ""
+        # Pass the raw user_msg dict to predict, or adapt predict's input as needed
+        # Note: your predict function's parsing logic must match this format
+        for chunk in handler.predict(user_msg, previous_history, temp, tokens):
+            generated_text = chunk
+            safe_text = generated_text.replace("<", "&lt;").replace(">", "&gt;")
+            history[-1]["content"] = safe_text
+            yield history, gr.MultimodalTextbox(interactive=False)
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        history[-1]["content"] = f"❌ Error: {str(e)}"
+        yield history, gr.MultimodalTextbox(interactive=True)
+
+    yield history, gr.MultimodalTextbox(value=None, interactive=True)
+
 def create_chat_ui():
     custom_css = """
     .gradio-container { font-family: 'Inter', sans-serif; }
@@ -214,7 +260,8 @@ def create_chat_ui():
         elem_id="chatbot",
         label="Chat",
         avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
-        height=650
+        height=650,
+        type="messages"
     )
 
     chat_input = gr.MultimodalTextbox(
@@ -264,41 +311,6 @@ def create_chat_ui():
     else:
         gr.Markdown("*No example images available, please manually upload images for testing*")
 
-    async def respond(user_msg, history, temp, tokens):
-        text = user_msg.get("text", "").strip()
-        files = user_msg.get("files", [])
-        user_content = list(files)
-        if text: user_content.append(text)
-
-        if not files and text: user_message = {"role": "user", "content": text}
-        else: user_message = {"role": "user", "content": user_content}
-
-        history.append(user_message)
-        yield history, gr.MultimodalTextbox(value=None, interactive=False)
-
-        history.append({"role": "assistant", "content": ""})
-
-        try:
-            previous_history = history[:-2] if len(history) >= 2 else []
-
-            handler = get_model_handler()
-            generated_text = ""
-            for chunk in handler.predict(user_msg, previous_history, temp, tokens):
-                generated_text = chunk
-
-                safe_text = generated_text.replace("<", "&lt;").replace(">", "&gt;")
-
-                history[-1]["content"] = safe_text
-                yield history, gr.MultimodalTextbox(interactive=False)
-
-        except Exception as e:
-            import traceback
-            traceback.print_exc()
-            history[-1]["content"] = f"❌ Inference error: {str(e)}"
-            yield history, gr.MultimodalTextbox(interactive=True)
-
-        yield history, gr.MultimodalTextbox(value=None, interactive=True)
-
     chat_input.submit(
         respond,
         inputs=[chat_input, chatbot, temperature, max_tokens],
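
Hunks 3 and 4 above work together: gr.Chatbot(type="messages") switches the chat history from tuple pairs to OpenAI-style role/content dicts, and the rewritten respond builds each user turn as a list of typed content parts instead of raw path strings. A minimal sketch of one turn as the new code constructs it (the upload path and question are hypothetical; the {"type": ...} part convention is the Qwen-VL-style format this commit adopts, which handler.predict must be able to parse):

    # Sketch of one conversation turn as built by the new respond()
    history = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "/tmp/uploads/noisy.png"},  # hypothetical upload path
                {"type": "text", "text": "What degradations are in this image?"},  # hypothetical prompt
            ],
        },
        {"role": "assistant", "content": ""},  # placeholder, filled chunk by chunk while streaming
    ]

The per-chunk replacement of < and > with &lt; and &gt; keeps the model's <TYPE> / <REASONING> style tags visible as literal text in the Chatbot instead of being interpreted as HTML.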
 
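The other substantive change is moving @gpu_decorator off get_model_handler and onto the new top-level respond, so GPU time is requested around each full inference call rather than around model construction. The decorator's definition sits outside this diff; a plausible minimal sketch for a ZeroGPU Space follows (spaces and spaces.GPU are the real ZeroGPU APIs, but this fallback shape is an assumption, not the code in app.py):

    # Hypothetical sketch; the actual gpu_decorator in app.py is defined outside this diff
    def gpu_decorator(func):
        try:
            import spaces  # available on Hugging Face ZeroGPU Spaces
            return spaces.GPU(func)  # attach a GPU for the duration of each call
        except ImportError:
            return func  # local / CPU-only run: leave the function undecorated

Decorating respond instead of get_model_handler also matches the startup comment in hunk 1: the model is loaded lazily inside the GPU-attached call, once CUDA is actually available.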