Jiaqi-hkust committed
Commit dfe5589 · verified · 1 Parent(s): 16db65e

Upload folder using huggingface_hub

Files changed (2):
  1. app.py +142 -221
  2. requirements.txt +1 -1
app.py CHANGED
@@ -3,8 +3,11 @@ import os
  import torch
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
  from qwen_vl_utils import process_vision_info
+ import copy

- # Import the spaces module for GPU detection
+ # ==========================================
+ # 1. Environment & Detection Setup
+ # ==========================================
  is_spaces = os.getenv("SPACE_ID") is not None
  spaces_available = False
  GPU = None
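
Aside: this hunk only shows the flags; the guarded import that sets them sits between the hunks, and the next hunk shows just its except branch. It presumably looks like the sketch below. The success-path details are assumptions not shown in the diff, though spaces.GPU is the real ZeroGPU decorator that gpu_decorator applies:

if is_spaces:
    try:
        import spaces                # Hugging Face ZeroGPU helper (assumed import site)
        GPU = spaces.GPU             # decorator that requests a GPU for the wrapped call
        spaces_available = True
    except ImportError:
        print("⚠️ spaces module not available, GPU detection may not work")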
@@ -16,39 +19,33 @@ if is_spaces:
      except ImportError:
          print("⚠️ spaces module not available, GPU detection may not work")

- # Create a conditional decorator
  def gpu_decorator(func):
      """Apply the GPU decorator conditionally"""
      if spaces_available and GPU is not None:
          return GPU(func)
      return func

+ # ==========================================
+ # 2. Constants & Configuration
+ # ==========================================
+ MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
+ PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
+
  # System prompt
- sys_prompt = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
+ SYS_PROMPT = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
  and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
  then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
  and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
  provides the user with the answer briefly in <ANSWER> <ANSWER_END>."""

- project_dir = os.path.dirname(os.path.abspath(__file__))
-
- if not is_spaces:
-     temp_dir = os.path.join(project_dir, ".gradio_temp")
-     os.makedirs(temp_dir, exist_ok=True)
-     os.environ["GRADIO_TEMP_DIR"] = temp_dir
-
- MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
-
- # Define CSS (moved to module scope for easier management)
+ # CSS styles
  CUSTOM_CSS = """
  .gradio-container { font-family: 'Inter', sans-serif; }
- #chatbot { height: 650px !important; overflow-y: auto; }
  """

- print(f"==========================================")
- print(f"Initializing application (Gradio {gr.__version__})...")
- print(f"==========================================")
-
+ # ==========================================
+ # 3. Model Handler
+ # ==========================================
  class ModelHandler:
      def __init__(self, model_path):
          self.model_path = model_path
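
For context: SYS_PROMPT (unchanged here apart from the rename) defines a tag protocol, with <TYPE>, <INFLUENCE>, <REASONING>, <CONCLUSION>, and <ANSWER> sections each closed by a matching *_END tag. The commit never parses these tags; a hypothetical helper for pulling one section out of the model output could look like this sketch:

import re

def extract_section(text: str, tag: str) -> str | None:
    """Hypothetical helper (not part of the commit): return the text between
    <TAG> and <TAG_END>, e.g. extract_section(output, "ANSWER")."""
    match = re.search(rf"<{tag}>(.*?)<{tag}_END>", text, re.DOTALL)
    return match.group(1).strip() if match else None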
@@ -58,8 +55,7 @@ class ModelHandler:

      def _load_model(self):
          try:
-             print(f"⏳ Loading model weights, this may take a few minutes...")
-
+             print(f"⏳ Loading model weights from {self.model_path}...")
              self.processor = AutoProcessor.from_pretrained(self.model_path)

              # Auto-detect Flash Attention support
@@ -69,8 +65,6 @@
              if device_capability[0] >= 8:
                  use_flash_attention = True
                  print(f"🔧 CUDA available with Ampere+, utilizing Flash Attention 2")
-             else:
-                 print(f"🔧 Using CPU or non-CUDA device")

              self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                  self.model_path,
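
Aside: the from_pretrained call is truncated at the hunk boundary. Given the use_flash_attention flag computed above, the remaining keyword arguments presumably select the attention backend along these lines; torch_dtype and device_map are assumptions, not shown in the diff:

# Sketch of the truncated load call (kwargs beyond model_path are assumed)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    self.model_path,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,  # assumed
    attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",  # assumed
    device_map="auto",  # assumed
)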
@@ -84,234 +78,161 @@ class ModelHandler:
              print(f"❌ Model loading failed: {e}")
              raise e

-     def predict(self, messages, temperature, max_tokens):
-         # Deep-copy the messages to avoid mutating the UI history
-         import copy
-         messages_payload = copy.deepcopy(messages)
-
-         # Append the system prompt
-         if messages_payload and messages_payload[-1]["role"] == "user":
-             content = messages_payload[-1]["content"]
-             sys_prompt_fmt = "\n" + " ".join(sys_prompt.split())
-
-             if isinstance(content, list):
-                 text_found = False
-                 for item in content:
-                     if item.get("type") == "text":
-                         item["text"] += sys_prompt_fmt
-                         text_found = True
-                         break
-                 if not text_found:
-                     content.append({"type": "text", "text": sys_prompt_fmt})
-             elif isinstance(content, str):
-                 messages_payload[-1]["content"] += sys_prompt_fmt
-
-         text_prompt = self.processor.apply_chat_template(
-             messages_payload, tokenize=False, add_generation_prompt=True
-         )
-         image_inputs, video_inputs = process_vision_info(messages_payload)
-
-         inputs = self.processor(
-             text=[text_prompt],
-             images=image_inputs,
-             videos=video_inputs,
-             padding=True,
-             return_tensors="pt"
-         )
-
-         inputs = inputs.to(self.model.device)
-
-         generation_kwargs = dict(
-             **inputs,
-             max_new_tokens=max_tokens,
-             temperature=temperature,
-             do_sample=True if temperature > 0 else False,
-         )
-
-         try:
-             print("Starting model generation...")
-             with torch.no_grad():
-                 generated_ids = self.model.generate(**generation_kwargs)
-
-             input_length = inputs['input_ids'].shape[1]
-             generated_ids = generated_ids[0][input_length:]
-
-             generated_text = self.processor.tokenizer.decode(
-                 generated_ids,
-                 skip_special_tokens=True
-             )
-             print(f"Generated text: {generated_text}")
-             if generated_text:
-                 yield generated_text
-             else:
-                 yield "⚠️ No output generated."
-
-         except Exception as e:
-             import traceback
-             error_details = traceback.format_exc()
-             print(f"Error in model.generate: {error_details}")
-             yield f"❌ Generation error: {str(e)}"
-             return
-
  model_handler = None

  def get_model_handler():
-     """Get model handler with lazy loading"""
+     """Lazily load the model handler"""
      global model_handler
      if model_handler is None:
-         print("🔄 Initializing model handler...")
          model_handler = ModelHandler(MODEL_PATH)
      return model_handler

+ # ==========================================
+ # 4. Chat Generation Function
+ # ==========================================
  @gpu_decorator
- def respond(user_msg, history, temp, tokens):
+ def respond(message, history, temperature, max_tokens):
      """
-     Response function for a Chatbot with type="messages"
+     Generation function matching the gr.ChatInterface contract
+     message: dict (when multimodal=True) -> {'text': str, 'files': list}
+     history: list of dicts -> OpenAI-style history
      """
+     handler = get_model_handler()

-     # 1. Build the current user's message content
-     user_content = []
+     # 1. Build the current user message (converted to OpenAI/Qwen format)
+     # message['files'] holds a list of file paths
+     current_user_content = []

-     files = user_msg.get("files", [])
-     for f in files:
-         user_content.append({"type": "image", "image": f})
+     # Handle images
+     if message.get("files"):
+         for file_path in message["files"]:
+             current_user_content.append({"type": "image", "image": file_path})
+
+     # Handle text
+     user_text = message.get("text", "")
+     if user_text:
+         current_user_content.append({"type": "text", "text": user_text})

-     text = user_msg.get("text", "")
-     if text:
-         user_content.append({"type": "text", "text": text})
-
-     if not user_content:
-         yield history, gr.MultimodalTextbox(value=None, interactive=True)
-         return
-
-     # 2. Append the user message to the history
-     history.append({
-         "role": "user",
-         "content": user_content
-     })
+     # 2. Build the full conversation (history + current message)
+     # Note: ChatInterface's history holds the earlier turns, not the current one
+     conversation = copy.deepcopy(history)
+     conversation.append({"role": "user", "content": current_user_content})
+
+     # 3. Inject the system prompt (appended to the last user message's text)
+     # Keeps the original logic: concatenate the prompt onto the last message
+     last_content = conversation[-1]["content"]
+     sys_prompt_fmt = "\n" + " ".join(SYS_PROMPT.split())

-     # Update the UI immediately
-     yield history, gr.MultimodalTextbox(value=None, interactive=False)
-
-     # 3. Call the model
+     text_injected = False
+     for item in last_content:
+         if item.get("type") == "text":
+             item["text"] += sys_prompt_fmt
+             text_injected = True
+             break
+     if not text_injected:
+         last_content.append({"type": "text", "text": sys_prompt_fmt})
+
+     # 4. Preprocess the inputs
+     text_prompt = handler.processor.apply_chat_template(
+         conversation, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(conversation)
+
+     inputs = handler.processor(
+         text=[text_prompt],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt"
+     )
+     inputs = inputs.to(handler.model.device)
+
+     # 5. Generation parameters
+     generation_kwargs = dict(
+         **inputs,
+         max_new_tokens=max_tokens,
+         temperature=temperature,
+         do_sample=True if temperature > 0 else False,
+     )
+
+     # 6. Streamed generation (yielding the response)
      try:
-         handler = get_model_handler()
+         input_length = inputs['input_ids'].shape[1]
+         # Note: for simplicity this calls the non-streaming generate and simulates streaming
+         # True token-level streaming would require TextIteratorStreamer,
+         # but the full result is fetched first to keep the original logic stable

-         history.append({"role": "assistant", "content": ""})
+         with torch.no_grad():
+             generated_ids = handler.model.generate(**generation_kwargs)

-         full_response = ""
-         # Pass history[:-1] so the empty assistant message doesn't break the template
-         for chunk in handler.predict(history[:-1], temp, tokens):
-             full_response += chunk
-             history[-1]["content"] = full_response
-             yield history, gr.MultimodalTextbox(interactive=False)
+         generated_ids = generated_ids[0][input_length:]
+         generated_text = handler.processor.tokenizer.decode(
+             generated_ids,
+             skip_special_tokens=True
+         )
+
+         # Simple simulated streaming (or just return the text directly)
+         yield generated_text

      except Exception as e:
          import traceback
          traceback.print_exc()
-         # If the error hit before the assistant message was added, append one
-         if not history or history[-1].get("role") != "assistant":
-             history.append({"role": "assistant", "content": ""})
-         history[-1]["content"] = f"❌ Error: {str(e)}"
-         yield history, gr.MultimodalTextbox(interactive=True)
-
-     # Re-enable the input box
-     yield history, gr.MultimodalTextbox(interactive=True)
-
- def create_chat_ui():
-     # [Fix 1]: do not pass the css argument here
-     with gr.Blocks(title="Robust-R1") as demo:
-
-         with gr.Row():
-             gr.Markdown("# 🤖 Robust-R1: Degradation-Aware Reasoning")
-
-         with gr.Row():
-             with gr.Column(scale=4):
-                 # Chatbot set to type="messages"
-                 chatbot = gr.Chatbot(
-                     elem_id="chatbot",
-                     label="Chat",
-                     avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
-                     height=650,
-                     type="messages"
-                 )
-
-                 chat_input = gr.MultimodalTextbox(
-                     interactive=True,
-                     file_types=["image"],
-                     placeholder="Enter your question or upload an image...",
-                     show_label=False
-                 )
-
-             with gr.Column(scale=1):
-                 with gr.Group():
-                     gr.Markdown("### ⚙️ Generation Config")
-                     temperature = gr.Slider(
-                         minimum=0.01, maximum=1.0, value=0.6, step=0.05,
-                         label="Temperature"
-                     )
-                     max_tokens = gr.Slider(
-                         minimum=128, maximum=4096, value=1024, step=128,
-                         label="Max New Tokens"
-                     )
-
-                     clear_btn = gr.Button("🗑️ Clear Context", variant="stop")
-
-                 gr.Markdown("---")
-                 gr.Markdown("### 📚 Examples")
-
-                 example_images_dir = os.path.join(project_dir, "assets")
-
-                 examples_config = [
-                     ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", os.path.join(example_images_dir, "1.jpg")),
-                     ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", os.path.join(example_images_dir, "2.jpg")),
-                 ]
-
-                 example_data = []
-                 for text, img_path in examples_config:
-                     if os.path.exists(img_path):
-                         example_data.append({"text": text, "files": [img_path]})
-
-                 if example_data:
-                     gr.Examples(
-                         examples=example_data,
-                         inputs=chat_input,
-                         label="",
-                         examples_per_page=3
-                     )
-                 else:
-                     gr.Markdown("*No example images available.*")
-
-         chat_input.submit(
-             respond,
-             inputs=[chat_input, chatbot, temperature, max_tokens],
-             outputs=[chatbot, chat_input]
-         )
-
-         clear_btn.click(lambda: ([], None), outputs=[chatbot, chat_input])
-
-     return demo
-
+         yield f"❌ Generation error: {str(e)}"
+
+ # ==========================================
+ # 5. Build the UI (ChatInterface)
+ # ==========================================
+
+ # Prepare the examples data
+ example_images_dir = os.path.join(PROJECT_DIR, "assets")
+ examples_data = []
+
+ # Define the example sources
+ raw_examples = [
+     ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", "1.jpg"),
+     ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", "2.jpg"),
+ ]
+
+ for text, filename in raw_examples:
+     path = os.path.join(example_images_dir, filename)
+     # ChatInterface multimodal example format: {"text": str, "files": [list]}
+     if os.path.exists(path):
+         examples_data.append({"text": text, "files": [path]})
+
+ # Define the additional input components (generation config)
+ additional_inputs = [
+     gr.Slider(minimum=0.01, maximum=1.0, value=0.6, step=0.05, label="Temperature"),
+     gr.Slider(minimum=128, maximum=4096, value=1024, step=128, label="Max New Tokens"),
+ ]
+
+ # Create the interface
+ demo = gr.ChatInterface(
+     fn=respond,
+     type="messages",  # use standard OpenAI-format history
+     multimodal=True,  # enable multimodal uploads
+     title="🤖 Robust-R1: Degradation-Aware Reasoning",
+     description="Upload an image and ask questions. The model considers image degradations during reasoning.",
+     additional_inputs=additional_inputs,  # config sliders
+     additional_inputs_accordion=gr.Accordion(label="⚙️ Generation Config", open=True),  # config panel
+     examples=examples_data,  # example prompts
+     cache_examples=False,  # enable or disable as needed
+     theme=gr.themes.Soft(),
+     css=CUSTOM_CSS
+ )
+
+ # ==========================================
+ # 6. Launch
+ # ==========================================
  if __name__ == "__main__":
-     demo = create_chat_ui()
-
      if is_spaces:
          print(f"🚀 Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
-         # [Fix 2]: pass the CSS in launch
          demo.launch(
-             theme=gr.themes.Soft(),
-             css=CUSTOM_CSS,
-             show_error=True,
-             allowed_paths=[project_dir] if project_dir else None
+             allowed_paths=[PROJECT_DIR]  # allow access to local image assets
          )
      else:
-         print(f"🚀 Service is starting, please visit: http://localhost:7860")
+         print(f"🚀 Service is starting at http://localhost:7860")
          demo.launch(
-             theme=gr.themes.Soft(),
-             css=CUSTOM_CSS,
              server_name="0.0.0.0",
              server_port=7860,
-             share=False,
-             show_error=True,
-             allowed_paths=[project_dir]
+             allowed_paths=[PROJECT_DIR]
          )
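
Note on streaming: the rewritten respond yields the finished text in a single chunk, and its comments point at TextIteratorStreamer for true token-level streaming. A minimal sketch of that variant, assuming the same handler and generation_kwargs as in the new code:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_tokens(handler, generation_kwargs):
    """Sketch: token-level streaming variant of the generation step."""
    streamer = TextIteratorStreamer(
        handler.processor.tokenizer,
        skip_prompt=True,            # do not re-emit the prompt tokens
        skip_special_tokens=True,
    )
    generation_kwargs["streamer"] = streamer
    # generate() blocks, so run it on a worker thread and drain the streamer
    thread = Thread(target=handler.model.generate, kwargs=generation_kwargs)
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial                # ChatInterface re-renders the growing message
    thread.join()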
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- gradio>=6.0.0
+ gradio>=6.1.0
  torch>=2.0.0
  torchvision>=0.15.0
  transformers>=4.37.0