Jiaqi-hkust committed on
Commit
b4fe2c0
·
verified ·
1 Parent(s): 610318e

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +35 -53
app.py CHANGED
@@ -20,7 +20,6 @@ if is_spaces:
20
  print("⚠️ spaces module not available, GPU detection may not work")
21
 
22
  def gpu_decorator(func):
23
- """条件应用 GPU 装饰器"""
24
  if spaces_available and GPU is not None:
25
  return GPU(func)
26
  return func
@@ -31,14 +30,12 @@ def gpu_decorator(func):
31
  MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
32
  PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
33
 
34
- # 系统提示词
35
  SYS_PROMPT = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
36
  and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
37
  then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
38
  and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
39
  provides the user with the answer briefly in <ANSWER> <ANSWER_END>."""
40
 
41
- # CSS 样式
42
  CUSTOM_CSS = """
43
  .gradio-container { font-family: 'Inter', sans-serif; }
44
  """
@@ -58,7 +55,6 @@ class ModelHandler:
58
  print(f"⏳ Loading model weights from {self.model_path}...")
59
  self.processor = AutoProcessor.from_pretrained(self.model_path)
60
 
61
- # 智能判断 Flash Attention
62
  use_flash_attention = False
63
  if torch.cuda.is_available():
64
  device_capability = torch.cuda.get_device_capability()
@@ -81,7 +77,6 @@ class ModelHandler:
81
  model_handler = None
82
 
83
  def get_model_handler():
84
- """懒加载模型句柄"""
85
  global model_handler
86
  if model_handler is None:
87
  model_handler = ModelHandler(MODEL_PATH)
@@ -93,33 +88,26 @@ def get_model_handler():
93
  @gpu_decorator
94
  def respond(message, history, temperature, max_tokens):
95
  """
96
- 符合 gr.ChatInterface 标准的生成函数
97
- message: dict (multimodal=True时) -> {'text': str, 'files': list} [cite: 140]
98
  history: list of dicts -> OpenAI 风格历史记录
99
  """
100
  handler = get_model_handler()
101
 
102
- # 1. 构建当前用户消息 (转换为 OpenAI/Qwen 格式)
103
- # message['files'] 包含文件路径列表
104
  current_user_content = []
105
-
106
- # 处理图片
107
  if message.get("files"):
108
  for file_path in message["files"]:
109
  current_user_content.append({"type": "image", "image": file_path})
110
 
111
- # 处理文本
112
  user_text = message.get("text", "")
113
  if user_text:
114
  current_user_content.append({"type": "text", "text": user_text})
115
 
116
- # 2. 构建完整的对话列表 (History + Current Message)
117
- # 注意:ChatInterface 的 history 包含之前的内容,不包含当前这一条
118
  conversation = copy.deepcopy(history)
119
  conversation.append({"role": "user", "content": current_user_content})
120
 
121
- # 3. 注入 System Prompt (添加到最后一条用户消息的文本中)
122
- # 保持您原有的逻辑:将 prompt 拼接到最后一条消息
123
  last_content = conversation[-1]["content"]
124
  sys_prompt_fmt = "\n" + " ".join(SYS_PROMPT.split())
125
 
@@ -132,7 +120,7 @@ def respond(message, history, temperature, max_tokens):
132
  if not text_injected:
133
  last_content.append({"type": "text", "text": sys_prompt_fmt})
134
 
135
- # 4. 预处理输入
136
  text_prompt = handler.processor.apply_chat_template(
137
  conversation, tokenize=False, add_generation_prompt=True
138
  )
@@ -147,7 +135,6 @@ def respond(message, history, temperature, max_tokens):
147
  )
148
  inputs = inputs.to(handler.model.device)
149
 
150
- # 5. 生成参数
151
  generation_kwargs = dict(
152
  **inputs,
153
  max_new_tokens=max_tokens,
@@ -155,80 +142,75 @@ def respond(message, history, temperature, max_tokens):
155
  do_sample=True if temperature > 0 else False,
156
  )
157
 
158
- # 6. 流式生成 (Yielding response) [cite: 85]
159
  try:
160
- input_length = inputs['input_ids'].shape[1]
161
- # 注意:这里为了简化演示,使用非流式的 generate,然后模拟流式输出
162
- # 如果需要真正的 token 级流式,需要使用 TextIteratorStreamer
163
- # 但为了保持您原有逻辑的稳定性,我们先获取结果再 yield
164
-
165
  with torch.no_grad():
166
  generated_ids = handler.model.generate(**generation_kwargs)
167
 
 
168
  generated_ids = generated_ids[0][input_length:]
169
  generated_text = handler.processor.tokenizer.decode(
170
  generated_ids,
171
  skip_special_tokens=True
172
  )
173
 
174
- # 简单模拟流式效果(或直接返回)
175
  yield generated_text
176
 
177
  except Exception as e:
178
  import traceback
179
  traceback.print_exc()
180
- yield f"❌ Generation error: {str(e)}"
181
 
182
  # ==========================================
183
- # 5. 构建 UI (ChatInterface)
184
  # ==========================================
185
 
186
- # 准备 Examples 数据
187
  example_images_dir = os.path.join(PROJECT_DIR, "assets")
188
  examples_data = []
189
 
190
- # 定义示例数据源
191
- raw_examples = [
192
- ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", "1.jpg"),
193
- ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", "2.jpg"),
194
- ]
195
-
196
- for text, filename in raw_examples:
197
- path = os.path.join(example_images_dir, filename)
198
- # ChatInterface multimodal examples 格式: {"text": str, "files": [list]}
199
- if os.path.exists(path):
200
- examples_data.append({"text": text, "files": [path]})
201
-
202
- # 定义额外输入组件 (Generation Config)
 
 
 
203
  additional_inputs = [
204
  gr.Slider(minimum=0.01, maximum=1.0, value=0.6, step=0.05, label="Temperature"),
205
  gr.Slider(minimum=128, maximum=4096, value=1024, step=128, label="Max New Tokens"),
206
  ]
207
 
 
208
  chatbot_component = gr.Chatbot(
209
  label="Robust-R1 Chat",
210
  avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
211
  height=650
212
  )
213
 
214
- # 创建 Interface
215
  demo = gr.ChatInterface(
216
  fn=respond,
217
- chatbot=chatbot_component, # <--- 这里传入自定义的 chatbot
218
- multimodal=True, # 启用多模态上传
219
  title="🤖 Robust-R1: Degradation-Aware Reasoning",
220
- description="Upload an image and ask questions. The model considers image degradations during reasoning.",
221
- additional_inputs=additional_inputs, # 添加配置滑块
222
- additional_inputs_accordion=gr.Accordion(label="⚙️ Generation Config", open=True), # 设置配置区域
223
- examples=examples_data, # 添加示例
224
- cache_examples=False, # 根据需要开启或关闭
225
  )
226
 
227
- # ==========================================
228
- # 6. 启动 Launch
229
- # ==========================================
230
  if __name__ == "__main__":
231
- aunch_kwargs = {
232
  "theme": gr.themes.Soft(),
233
  "css": CUSTOM_CSS,
234
  "allowed_paths": [PROJECT_DIR]
 
20
  print("⚠️ spaces module not available, GPU detection may not work")
21
 
22
  def gpu_decorator(func):
 
23
  if spaces_available and GPU is not None:
24
  return GPU(func)
25
  return func
 
30
  MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
31
  PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
32
 
 
33
  SYS_PROMPT = """First output the types of degradations in image briefly in <TYPE> <TYPE_END> tags,
34
  and then output what effects do these degradation have on the image in <INFLUENCE> <INFLUENCE_END> tags,
35
  then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
36
  and then summarize the content of reasoning and the give the answer in <CONCLUSION> <CONCLUSION_END> tags,
37
  provides the user with the answer briefly in <ANSWER> <ANSWER_END>."""
38
 
 
39
  CUSTOM_CSS = """
40
  .gradio-container { font-family: 'Inter', sans-serif; }
41
  """
 
55
  print(f"⏳ Loading model weights from {self.model_path}...")
56
  self.processor = AutoProcessor.from_pretrained(self.model_path)
57
 
 
58
  use_flash_attention = False
59
  if torch.cuda.is_available():
60
  device_capability = torch.cuda.get_device_capability()
 
77
  model_handler = None
78
 
79
  def get_model_handler():
 
80
  global model_handler
81
  if model_handler is None:
82
  model_handler = ModelHandler(MODEL_PATH)
 
88
  @gpu_decorator
89
  def respond(message, history, temperature, max_tokens):
90
  """
91
+ message: dict -> {'text': str, 'files': list}
 
92
  history: list of dicts -> OpenAI 风格历史记录
93
  """
94
  handler = get_model_handler()
95
 
96
+ # 1. 转换当前消息
 
97
  current_user_content = []
 
 
98
  if message.get("files"):
99
  for file_path in message["files"]:
100
  current_user_content.append({"type": "image", "image": file_path})
101
 
 
102
  user_text = message.get("text", "")
103
  if user_text:
104
  current_user_content.append({"type": "text", "text": user_text})
105
 
106
+ # 2. 构建完整对话 (History + Current)
 
107
  conversation = copy.deepcopy(history)
108
  conversation.append({"role": "user", "content": current_user_content})
109
 
110
+ # 3. 注入 System Prompt
 
111
  last_content = conversation[-1]["content"]
112
  sys_prompt_fmt = "\n" + " ".join(SYS_PROMPT.split())
113
 
 
120
  if not text_injected:
121
  last_content.append({"type": "text", "text": sys_prompt_fmt})
122
 
123
+ # 4. 推理
124
  text_prompt = handler.processor.apply_chat_template(
125
  conversation, tokenize=False, add_generation_prompt=True
126
  )
 
135
  )
136
  inputs = inputs.to(handler.model.device)
137
 
 
138
  generation_kwargs = dict(
139
  **inputs,
140
  max_new_tokens=max_tokens,
 
142
  do_sample=True if temperature > 0 else False,
143
  )
144
 
 
145
  try:
 
 
 
 
 
146
  with torch.no_grad():
147
  generated_ids = handler.model.generate(**generation_kwargs)
148
 
149
+ input_length = inputs['input_ids'].shape[1]
150
  generated_ids = generated_ids[0][input_length:]
151
  generated_text = handler.processor.tokenizer.decode(
152
  generated_ids,
153
  skip_special_tokens=True
154
  )
155
 
 
156
  yield generated_text
157
 
158
  except Exception as e:
159
  import traceback
160
  traceback.print_exc()
161
+ yield f"❌ Error: {str(e)}"
162
 
163
  # ==========================================
164
+ # 5. 构建 UI
165
  # ==========================================
166
 
167
+ # 【关键修复点】:Examples 格式必须包含 Additional Inputs 的值
168
  example_images_dir = os.path.join(PROJECT_DIR, "assets")
169
  examples_data = []
170
 
171
+ if os.path.exists(example_images_dir):
172
+ raw_examples = [
173
+ ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", "1.jpg"),
174
+ ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", "2.jpg"),
175
+ ]
176
+ for text, filename in raw_examples:
177
+ path = os.path.join(example_images_dir, filename)
178
+ if os.path.exists(path):
179
+ # 格式必须是: [MessageDict, TemperatureValue, MaxTokensValue]
180
+ examples_data.append([
181
+ {"text": text, "files": [path]}, # 1. 消息对象
182
+ 0.6, # 2. Temperature (对应 additional_inputs[0])
183
+ 1024 # 3. Max Tokens (对应 additional_inputs[1])
184
+ ])
185
+
186
+ # 定义额外输入
187
  additional_inputs = [
188
  gr.Slider(minimum=0.01, maximum=1.0, value=0.6, step=0.05, label="Temperature"),
189
  gr.Slider(minimum=128, maximum=4096, value=1024, step=128, label="Max New Tokens"),
190
  ]
191
 
192
+ # 自定义 Chatbot 组件
193
  chatbot_component = gr.Chatbot(
194
  label="Robust-R1 Chat",
195
  avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
196
  height=650
197
  )
198
 
199
+ # ChatInterface
200
  demo = gr.ChatInterface(
201
  fn=respond,
202
+ chatbot=chatbot_component,
203
+ multimodal=True,
204
  title="🤖 Robust-R1: Degradation-Aware Reasoning",
205
+ description="Upload an image and ask questions.",
206
+ additional_inputs=additional_inputs,
207
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Generation Config", open=True),
208
+ examples=examples_data, # 现在这里的格式是 [[msg, 0.6, 1024], ...]
209
+ cache_examples=False
210
  )
211
 
 
 
 
212
  if __name__ == "__main__":
213
+ launch_kwargs = {
214
  "theme": gr.themes.Soft(),
215
  "css": CUSTOM_CSS,
216
  "allowed_paths": [PROJECT_DIR]