Jiaqi-hkust commited on
Commit
16db65e
·
verified ·
1 Parent(s): 685215b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +33 -49
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import gradio as gr
2
- print(f"当前使用的 Gradio 版本是: {gr.__version__}")
3
  import os
4
  import torch
5
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
6
  from qwen_vl_utils import process_vision_info
7
- import html
8
 
9
  # 导入 spaces 模块用于 GPU 检测
10
  is_spaces = os.getenv("SPACE_ID") is not None
@@ -25,10 +23,7 @@ def gpu_decorator(func):
25
  return GPU(func)
26
  return func
27
 
28
- # 条件安装 flash-attn(延迟到模型加载时,避免启动时 CUDA 检查)
29
- # 注意:在 ZeroGPU 环境中,启动时 CUDA 可能还不可用
30
- # flash-attn 将在模型加载时根据实际 CUDA 可用性决定是否使用
31
-
32
  sys_prompt = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
33
  and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
34
  then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
@@ -44,8 +39,14 @@ if not is_spaces:
44
 
45
  MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
46
 
 
 
 
 
 
 
47
  print(f"==========================================")
48
- print(f"Initializing application...")
49
  print(f"==========================================")
50
 
51
  class ModelHandler:
@@ -61,20 +62,21 @@ class ModelHandler:
61
 
62
  self.processor = AutoProcessor.from_pretrained(self.model_path)
63
 
 
 
64
  if torch.cuda.is_available():
65
  device_capability = torch.cuda.get_device_capability()
66
- use_flash_attention = device_capability[0] >= 8
67
- print(f"🔧 CUDA available, device capability: {device_capability}")
 
68
  else:
69
- use_flash_attention = False
70
  print(f"🔧 Using CPU or non-CUDA device")
71
 
72
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
73
  self.model_path,
74
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
75
  device_map="auto",
76
- # attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
77
- attn_implementation="sdpa",
78
  trust_remote_code=True
79
  )
80
  print("✅ Model loaded successfully!")
@@ -82,11 +84,8 @@ class ModelHandler:
82
  print(f"❌ Model loading failed: {e}")
83
  raise e
84
 
85
-
86
  def predict(self, messages, temperature, max_tokens):
87
- # 注意:这里接收到的 messages 已经是标准的 [{'role': 'user', 'content': [...]}, ...]
88
-
89
- # 我们需要做一个深拷贝,避免修改 UI 上的 history 显示 System Prompt
90
  import copy
91
  messages_payload = copy.deepcopy(messages)
92
 
@@ -95,7 +94,6 @@ class ModelHandler:
95
  content = messages_payload[-1]["content"]
96
  sys_prompt_fmt = "\n" + " ".join(sys_prompt.split())
97
 
98
- # 现在的 content 肯定是 list (因为我们上面的 respond 函数构建的是 list)
99
  if isinstance(content, list):
100
  text_found = False
101
  for item in content:
@@ -108,7 +106,6 @@ class ModelHandler:
108
  elif isinstance(content, str):
109
  messages_payload[-1]["content"] += sys_prompt_fmt
110
 
111
- # 后续逻辑保持不变 ...
112
  text_prompt = self.processor.apply_chat_template(
113
  messages_payload, tokenize=False, add_generation_prompt=True
114
  )
@@ -169,37 +166,31 @@ def get_model_handler():
169
  @gpu_decorator
170
  def respond(user_msg, history, temp, tokens):
171
  """
172
- 针对 type="messages" 的 Chatbot 重写的响应函数
173
- history 现在的格式直接是: [{'role': 'user', 'content': ...}, {'role': 'assistant', ...}]
174
  """
175
 
176
- # 1. 构建当前用户的消息内容 (OpenAI 多模态格式)
177
  user_content = []
178
 
179
- # 处理图片/文件
180
  files = user_msg.get("files", [])
181
  for f in files:
182
- # qwen_vl_utils 识别 "image" 字段作为本地路径
183
  user_content.append({"type": "image", "image": f})
184
 
185
- # 处理文本
186
  text = user_msg.get("text", "")
187
  if text:
188
  user_content.append({"type": "text", "text": text})
189
 
190
- # 如果既没图也没文字,直接返回
191
  if not user_content:
192
  yield history, gr.MultimodalTextbox(value=None, interactive=True)
193
  return
194
 
195
  # 2. 将用户消息加入历史
196
- # 注意:这里直接 append 一个 dict,而不是 tuple
197
  history.append({
198
  "role": "user",
199
  "content": user_content
200
  })
201
 
202
- # 立即更新 UI,让用户看到自己的输入(图文会在同一个气泡里)
203
  yield history, gr.MultimodalTextbox(value=None, interactive=False)
204
 
205
  # 3. 调用模型
@@ -209,8 +200,8 @@ def respond(user_msg, history, temp, tokens):
209
  history.append({"role": "assistant", "content": ""})
210
 
211
  full_response = ""
212
- # 调用你的 handler.predict (注意:你需要稍微调整 handler.predict 里sys_prompt 处理逻辑,见下文建议)
213
- for chunk in handler.predict(history[:-1], temp, tokens): # 传入除最后一条空回复外的历史
214
  full_response += chunk
215
  history[-1]["content"] = full_response
216
  yield history, gr.MultimodalTextbox(interactive=False)
@@ -218,32 +209,31 @@ def respond(user_msg, history, temp, tokens):
218
  except Exception as e:
219
  import traceback
220
  traceback.print_exc()
221
- history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
 
 
 
222
  yield history, gr.MultimodalTextbox(interactive=True)
223
 
224
  # 恢复输入框
225
  yield history, gr.MultimodalTextbox(interactive=True)
226
 
227
  def create_chat_ui():
228
- custom_css = """
229
- .gradio-container { font-family: 'Inter', sans-serif; }
230
- #chatbot { height: 650px !important; overflow-y: auto; }
231
- """
232
-
233
- with gr.Blocks(title="Robust-R1", css=custom_css) as demo:
234
 
235
  with gr.Row():
236
  gr.Markdown("# 🤖 Robust-R1: Degradation-Aware Reasoning")
237
 
238
  with gr.Row():
239
  with gr.Column(scale=4):
240
- # 【关键修改】添加 type="messages"
241
  chatbot = gr.Chatbot(
242
  elem_id="chatbot",
243
  label="Chat",
244
  avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
245
  height=650,
246
- type="messages" # <--- 这里是重点!
247
  )
248
 
249
  chat_input = gr.MultimodalTextbox(
@@ -269,7 +259,6 @@ def create_chat_ui():
269
 
270
  gr.Markdown("---")
271
  gr.Markdown("### 📚 Examples")
272
- gr.Markdown("Click the examples below to quickly fill the input box and start a conversation")
273
 
274
  example_images_dir = os.path.join(project_dir, "assets")
275
 
@@ -291,7 +280,7 @@ def create_chat_ui():
291
  examples_per_page=3
292
  )
293
  else:
294
- gr.Markdown("*No example images available, please manually upload images for testing*")
295
 
296
  chat_input.submit(
297
  respond,
@@ -299,7 +288,6 @@ def create_chat_ui():
299
  outputs=[chatbot, chat_input]
300
  )
301
 
302
- # 清空历史只需要返回空列表 []
303
  clear_btn.click(lambda: ([], None), outputs=[chatbot, chat_input])
304
 
305
  return demo
@@ -307,16 +295,12 @@ def create_chat_ui():
307
  if __name__ == "__main__":
308
  demo = create_chat_ui()
309
 
310
- custom_css = """
311
- .gradio-container { font-family: 'Inter', sans-serif; }
312
- #chatbot { height: 650px !important; overflow-y: auto; }
313
- """
314
-
315
  if is_spaces:
316
  print(f"🚀 Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
 
317
  demo.launch(
318
  theme=gr.themes.Soft(),
319
- css=custom_css,
320
  show_error=True,
321
  allowed_paths=[project_dir] if project_dir else None
322
  )
@@ -324,10 +308,10 @@ if __name__ == "__main__":
324
  print(f"🚀 Service is starting, please visit: http://localhost:7860")
325
  demo.launch(
326
  theme=gr.themes.Soft(),
327
- css=custom_css,
328
  server_name="0.0.0.0",
329
  server_port=7860,
330
  share=False,
331
  show_error=True,
332
  allowed_paths=[project_dir]
333
- )
 
1
  import gradio as gr
 
2
  import os
3
  import torch
4
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
5
  from qwen_vl_utils import process_vision_info
 
6
 
7
  # 导入 spaces 模块用于 GPU 检测
8
  is_spaces = os.getenv("SPACE_ID") is not None
 
23
  return GPU(func)
24
  return func
25
 
26
+ # 系统提示词
 
 
 
27
  sys_prompt = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
28
  and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
29
  then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
 
39
 
40
  MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
41
 
42
+ # 定义 CSS (移到全局,方便管理)
43
+ CUSTOM_CSS = """
44
+ .gradio-container { font-family: 'Inter', sans-serif; }
45
+ #chatbot { height: 650px !important; overflow-y: auto; }
46
+ """
47
+
48
  print(f"==========================================")
49
+ print(f"Initializing application (Gradio {gr.__version__})...")
50
  print(f"==========================================")
51
 
52
  class ModelHandler:
 
62
 
63
  self.processor = AutoProcessor.from_pretrained(self.model_path)
64
 
65
+ # 智能判断 Flash Attention
66
+ use_flash_attention = False
67
  if torch.cuda.is_available():
68
  device_capability = torch.cuda.get_device_capability()
69
+ if device_capability[0] >= 8:
70
+ use_flash_attention = True
71
+ print(f"🔧 CUDA available with Ampere+, utilizing Flash Attention 2")
72
  else:
 
73
  print(f"🔧 Using CPU or non-CUDA device")
74
 
75
  self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
76
  self.model_path,
77
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
78
  device_map="auto",
79
+ attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
 
80
  trust_remote_code=True
81
  )
82
  print("✅ Model loaded successfully!")
 
84
  print(f"❌ Model loading failed: {e}")
85
  raise e
86
 
 
87
  def predict(self, messages, temperature, max_tokens):
88
+ # 深拷贝消息,避免修改 UI 历史
 
 
89
  import copy
90
  messages_payload = copy.deepcopy(messages)
91
 
 
94
  content = messages_payload[-1]["content"]
95
  sys_prompt_fmt = "\n" + " ".join(sys_prompt.split())
96
 
 
97
  if isinstance(content, list):
98
  text_found = False
99
  for item in content:
 
106
  elif isinstance(content, str):
107
  messages_payload[-1]["content"] += sys_prompt_fmt
108
 
 
109
  text_prompt = self.processor.apply_chat_template(
110
  messages_payload, tokenize=False, add_generation_prompt=True
111
  )
 
166
  @gpu_decorator
167
  def respond(user_msg, history, temp, tokens):
168
  """
169
+ 针对 type="messages" 的 Chatbot 响应函数
 
170
  """
171
 
172
+ # 1. 构建当前用户的消息内容
173
  user_content = []
174
 
 
175
  files = user_msg.get("files", [])
176
  for f in files:
 
177
  user_content.append({"type": "image", "image": f})
178
 
 
179
  text = user_msg.get("text", "")
180
  if text:
181
  user_content.append({"type": "text", "text": text})
182
 
 
183
  if not user_content:
184
  yield history, gr.MultimodalTextbox(value=None, interactive=True)
185
  return
186
 
187
  # 2. 将用户消息加入历史
 
188
  history.append({
189
  "role": "user",
190
  "content": user_content
191
  })
192
 
193
+ # 立即更新 UI
194
  yield history, gr.MultimodalTextbox(value=None, interactive=False)
195
 
196
  # 3. 调用模型
 
200
  history.append({"role": "assistant", "content": ""})
201
 
202
  full_response = ""
203
+ # 传入 history[:-1] 避免传入空assistant 消息导致模板报错
204
+ for chunk in handler.predict(history[:-1], temp, tokens):
205
  full_response += chunk
206
  history[-1]["content"] = full_response
207
  yield history, gr.MultimodalTextbox(interactive=False)
 
209
  except Exception as e:
210
  import traceback
211
  traceback.print_exc()
212
+ # 如果还没加 assistant 消息就报错了,补一个
213
+ if not history or history[-1].get("role") != "assistant":
214
+ history.append({"role": "assistant", "content": ""})
215
+ history[-1]["content"] = f"❌ Error: {str(e)}"
216
  yield history, gr.MultimodalTextbox(interactive=True)
217
 
218
  # 恢复输入框
219
  yield history, gr.MultimodalTextbox(interactive=True)
220
 
221
  def create_chat_ui():
222
+ # 【修复点 1】: 这里不要传 css 参数
223
+ with gr.Blocks(title="Robust-R1") as demo:
 
 
 
 
224
 
225
  with gr.Row():
226
  gr.Markdown("# 🤖 Robust-R1: Degradation-Aware Reasoning")
227
 
228
  with gr.Row():
229
  with gr.Column(scale=4):
230
+ # Chatbot 设置 type="messages"
231
  chatbot = gr.Chatbot(
232
  elem_id="chatbot",
233
  label="Chat",
234
  avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
235
  height=650,
236
+ type="messages"
237
  )
238
 
239
  chat_input = gr.MultimodalTextbox(
 
259
 
260
  gr.Markdown("---")
261
  gr.Markdown("### 📚 Examples")
 
262
 
263
  example_images_dir = os.path.join(project_dir, "assets")
264
 
 
280
  examples_per_page=3
281
  )
282
  else:
283
+ gr.Markdown("*No example images available.*")
284
 
285
  chat_input.submit(
286
  respond,
 
288
  outputs=[chatbot, chat_input]
289
  )
290
 
 
291
  clear_btn.click(lambda: ([], None), outputs=[chatbot, chat_input])
292
 
293
  return demo
 
295
  if __name__ == "__main__":
296
  demo = create_chat_ui()
297
 
 
 
 
 
 
298
  if is_spaces:
299
  print(f"🚀 Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
300
+ # 【修复点 2】: CSS 放在 launch 里
301
  demo.launch(
302
  theme=gr.themes.Soft(),
303
+ css=CUSTOM_CSS,
304
  show_error=True,
305
  allowed_paths=[project_dir] if project_dir else None
306
  )
 
308
  print(f"🚀 Service is starting, please visit: http://localhost:7860")
309
  demo.launch(
310
  theme=gr.themes.Soft(),
311
+ css=CUSTOM_CSS,
312
  server_name="0.0.0.0",
313
  server_port=7860,
314
  share=False,
315
  show_error=True,
316
  allowed_paths=[project_dir]
317
+ )