Jiaqi-hkust committed on
Commit 7f7273d · verified · 1 Parent(s): b4fe2c0

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +250 -169
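Per the commit message, the folder was pushed with huggingface_hub. A minimal sketch of such an upload — the folder path and Space id below are illustrative assumptions, not values taken from this commit:

    from huggingface_hub import HfApi

    api = HfApi()  # assumes you are authenticated, e.g. via `huggingface-cli login`
    api.upload_folder(
        folder_path=".",                  # hypothetical local folder containing app.py
        repo_id="Jiaqi-hkust/Robust-R1",  # hypothetical Space id
        repo_type="space",
        commit_message="Upload folder using huggingface_hub",
    )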
app.py CHANGED
@@ -3,46 +3,28 @@ import os
  import torch
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
  from qwen_vl_utils import process_vision_info
- import copy

- # ==========================================
- # 1. Environment & device detection setup
- # ==========================================
- is_spaces = os.getenv("SPACE_ID") is not None
- spaces_available = False
- GPU = None
-
- if is_spaces:
-     try:
-         from spaces import GPU
-         spaces_available = True
-     except ImportError:
-         print("⚠️ spaces module not available, GPU detection may not work")
-
- def gpu_decorator(func):
-     if spaces_available and GPU is not None:
-         return GPU(func)
-     return func
-
- # ==========================================
- # 2. Constants & configuration
- # ==========================================
- MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
- PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
-
- SYS_PROMPT = """First output the types of degradations in the image briefly in <TYPE> <TYPE_END> tags,
  and then output what effects these degradations have on the image in <INFLUENCE> <INFLUENCE_END> tags,
  then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
  and then summarize the content of the reasoning and give the answer in <CONCLUSION> <CONCLUSION_END> tags,
  provide the user with the answer briefly in <ANSWER> <ANSWER_END>."""

- CUSTOM_CSS = """
- .gradio-container { font-family: 'Inter', sans-serif; }
- """

- # ==========================================
- # 3. Model handler class
- # ==========================================
  class ModelHandler:
      def __init__(self, model_path):
          self.model_path = model_path
@@ -52,21 +34,24 @@ class ModelHandler:

      def _load_model(self):
          try:
-             print(f"⏳ Loading model weights from {self.model_path}...")
              self.processor = AutoProcessor.from_pretrained(self.model_path)

-             use_flash_attention = False
              if torch.cuda.is_available():
                  device_capability = torch.cuda.get_device_capability()
-                 if device_capability[0] >= 8:
-                     use_flash_attention = True
-                     print(f"🔧 CUDA available with Ampere+, utilizing Flash Attention 2")

              self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                  self.model_path,
                  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                  device_map="auto",
-                 attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
                  trust_remote_code=True
              )
              print("✅ Model loaded successfully!")
@@ -74,155 +59,251 @@ class ModelHandler:
              print(f"❌ Model loading failed: {e}")
              raise e

  model_handler = None

  def get_model_handler():
      global model_handler
      if model_handler is None:
          model_handler = ModelHandler(MODEL_PATH)
      return model_handler

- # ==========================================
- # 4. Chat generation function
- # ==========================================
- @gpu_decorator
- def respond(message, history, temperature, max_tokens):
      """
-     message: dict -> {'text': str, 'files': list}
-     history: list of dicts -> OpenAI-style history
-     """
-     handler = get_model_handler()
-
-     # 1. Convert the current message
-     current_user_content = []
-     if message.get("files"):
-         for file_path in message["files"]:
-             current_user_content.append({"type": "image", "image": file_path})
-
-     user_text = message.get("text", "")
-     if user_text:
-         current_user_content.append({"type": "text", "text": user_text})

-     # 2. Build the full conversation (history + current message)
-     conversation = copy.deepcopy(history)
-     conversation.append({"role": "user", "content": current_user_content})

-     # 3. Inject the system prompt
-     last_content = conversation[-1]["content"]
-     sys_prompt_fmt = "\n" + " ".join(SYS_PROMPT.split())
-
-     text_injected = False
-     for item in last_content:
-         if item.get("type") == "text":
-             item["text"] += sys_prompt_fmt
-             text_injected = True
-             break
-     if not text_injected:
-         last_content.append({"type": "text", "text": sys_prompt_fmt})
-
-     # 4. Inference
-     text_prompt = handler.processor.apply_chat_template(
-         conversation, tokenize=False, add_generation_prompt=True
-     )
-     image_inputs, video_inputs = process_vision_info(conversation)
-
-     inputs = handler.processor(
-         text=[text_prompt],
-         images=image_inputs,
-         videos=video_inputs,
-         padding=True,
-         return_tensors="pt"
-     )
-     inputs = inputs.to(handler.model.device)
-
-     generation_kwargs = dict(
-         **inputs,
-         max_new_tokens=max_tokens,
-         temperature=temperature,
-         do_sample=True if temperature > 0 else False,
-     )
-
-     try:
-         with torch.no_grad():
-             generated_ids = handler.model.generate(**generation_kwargs)

-         input_length = inputs['input_ids'].shape[1]
-         generated_ids = generated_ids[0][input_length:]
-         generated_text = handler.processor.tokenizer.decode(
-             generated_ids,
-             skip_special_tokens=True
-         )

-         yield generated_text

-     except Exception as e:
-         import traceback
-         traceback.print_exc()
-         yield f"❌ Error: {str(e)}"
-
- # ==========================================
- # 5. Build the UI
- # ==========================================
-
- # [Key fix]: the Examples format must include the values of the Additional Inputs
- example_images_dir = os.path.join(PROJECT_DIR, "assets")
- examples_data = []
-
- if os.path.exists(example_images_dir):
-     raw_examples = [
-         ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", "1.jpg"),
-         ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", "2.jpg"),
-     ]
-     for text, filename in raw_examples:
-         path = os.path.join(example_images_dir, filename)
-         if os.path.exists(path):
-             # The format must be: [MessageDict, TemperatureValue, MaxTokensValue]
-             examples_data.append([
-                 {"text": text, "files": [path]},  # 1. message object
-                 0.6,   # 2. Temperature (maps to additional_inputs[0])
-                 1024   # 3. Max Tokens (maps to additional_inputs[1])
-             ])
-
- # Define the additional inputs
- additional_inputs = [
-     gr.Slider(minimum=0.01, maximum=1.0, value=0.6, step=0.05, label="Temperature"),
-     gr.Slider(minimum=128, maximum=4096, value=1024, step=128, label="Max New Tokens"),
- ]
-
- # Custom Chatbot component
- chatbot_component = gr.Chatbot(
-     label="Robust-R1 Chat",
-     avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
-     height=650
- )
-
- # ChatInterface
- demo = gr.ChatInterface(
-     fn=respond,
-     chatbot=chatbot_component,
-     multimodal=True,
-     title="🤖 Robust-R1: Degradation-Aware Reasoning",
-     description="Upload an image and ask questions.",
-     additional_inputs=additional_inputs,
-     additional_inputs_accordion=gr.Accordion(label="⚙️ Generation Config", open=True),
-     examples=examples_data,  # the format here is now [[msg, 0.6, 1024], ...]
-     cache_examples=False
- )

  if __name__ == "__main__":
-     launch_kwargs = {
-         "theme": gr.themes.Soft(),
-         "css": CUSTOM_CSS,
-         "allowed_paths": [PROJECT_DIR]
-     }

      if is_spaces:
-         print(f"🚀 Running on Hugging Face Spaces")
-         demo.launch(**launch_kwargs)
      else:
-         print(f"🚀 Running locally")
          demo.launch(
              server_name="0.0.0.0",
              server_port=7860,
-             **launch_kwargs
-         )
  import torch
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
  from qwen_vl_utils import process_vision_info
+ import html

+ sys_prompt = """First output the types of degradations in the image briefly in <TYPE> <TYPE_END> tags,
  and then output what effects these degradations have on the image in <INFLUENCE> <INFLUENCE_END> tags,
  then based on the strength of degradation, output an APPROPRIATE length for the reasoning process in <REASONING> <REASONING_END> tags,
  and then summarize the content of the reasoning and give the answer in <CONCLUSION> <CONCLUSION_END> tags,
  provide the user with the answer briefly in <ANSWER> <ANSWER_END>."""

+ project_dir = os.path.dirname(os.path.abspath(__file__))
+
+ is_spaces = os.getenv("SPACE_ID") is not None
+ if not is_spaces:
+     temp_dir = os.path.join(project_dir, ".gradio_temp")
+     os.makedirs(temp_dir, exist_ok=True)
+     os.environ["GRADIO_TEMP_DIR"] = temp_dir
+
+ MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
+
+ print(f"==========================================")
+ print(f"Initializing application...")
+ print(f"==========================================")

  class ModelHandler:
      def __init__(self, model_path):
          self.model_path = model_path

      def _load_model(self):
          try:
+             print(f"⏳ Loading model weights, this may take a few minutes...")
+
              self.processor = AutoProcessor.from_pretrained(self.model_path)

              if torch.cuda.is_available():
                  device_capability = torch.cuda.get_device_capability()
+                 use_flash_attention = device_capability[0] >= 8
+                 print(f"🔧 CUDA available, device capability: {device_capability}")
+             else:
+                 use_flash_attention = False
+                 print(f"🔧 Using CPU or non-CUDA device")

              self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                  self.model_path,
                  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                  device_map="auto",
+                 # attn_implementation="flash_attention_2" if use_flash_attention else "eager",
+                 attn_implementation="sdpa",
                  trust_remote_code=True
              )
              print("✅ Model loaded successfully!")

              print(f"❌ Model loading failed: {e}")
              raise e

+     def predict(self, message_dict, history, temperature, max_tokens):
+         text = message_dict.get("text", "")
+         files = message_dict.get("files", [])
+
+         messages = []
+
+         if history:
+             print(f"Processing {len(history)} previous messages from history")
+             for msg in history:
+                 role = msg.get("role", "")
+                 content = msg.get("content", "")
+
+                 if role == "user":
+                     user_content = []
+
+                     if isinstance(content, list):
+                         for item in content:
+                             if isinstance(item, str):
+                                 if os.path.exists(item) or any(item.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']):
+                                     user_content.append({"type": "image", "image": item})
+                                 else:
+                                     user_content.append({"type": "text", "text": item})
+                             elif isinstance(item, dict):
+                                 user_content.append(item)
+                     elif isinstance(content, str):
+                         if content:
+                             user_content.append({"type": "text", "text": content})
+
+                     if user_content:
+                         messages.append({"role": "user", "content": user_content})
+
+                 elif role == "assistant":
+                     if isinstance(content, str) and content:
+                         messages.append({"role": "assistant", "content": content})
+
+         current_content = []
+         if files:
+             for file_path in files:
+                 current_content.append({"type": "image", "image": file_path})
+
+         if text:
+             sys_prompt_formatted = " ".join(sys_prompt.split())
+             full_text = f"{text}\n{sys_prompt_formatted}"
+             current_content.append({"type": "text", "text": full_text})
+
+         if current_content:
+             messages.append({"role": "user", "content": current_content})
+
+         print(f"Total messages for model: {len(messages)}")
+         print(f"Message roles: {[m['role'] for m in messages]}")
+
+         text_prompt = self.processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+
+         image_inputs, video_inputs = process_vision_info(messages)
+
+         inputs = self.processor(
+             text=[text_prompt],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt"
+         )
+
+         inputs = inputs.to(self.model.device)
+
+         generation_kwargs = dict(
+             **inputs,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             do_sample=True if temperature > 0 else False,
+         )
+
+         try:
+             print("Starting model generation...")
+             with torch.no_grad():
+                 generated_ids = self.model.generate(**generation_kwargs)
+
+             input_length = inputs['input_ids'].shape[1]
+             generated_ids = generated_ids[0][input_length:]
+
+             print(f"Input length: {input_length}, Generated token count: {len(generated_ids)}")
+
+             generated_text = self.processor.tokenizer.decode(
+                 generated_ids,
+                 skip_special_tokens=True
+             )
+
+             print(f"Generation completed. Output length: {len(generated_text)}, Content preview: {repr(generated_text[:200])}")
+
+             if generated_text and generated_text.strip():
+                 print(f"Yielding generated text: {generated_text[:100]}...")
+                 yield generated_text
+             else:
+                 warning_msg = "⚠️ No output generated. The model may not have produced any response."
+                 print(warning_msg)
+                 yield warning_msg
+
+         except Exception as e:
+             import traceback
+             error_details = traceback.format_exc()
+             print(f"Error in model.generate: {error_details}")
+             yield f"❌ Generation error: {str(e)}"
+             return
+
  model_handler = None

  def get_model_handler():
+     """Get model handler with lazy loading"""
      global model_handler
      if model_handler is None:
+         print("🔄 Initializing model handler...")
          model_handler = ModelHandler(MODEL_PATH)
      return model_handler

+ def create_chat_ui():
+     custom_css = """
+     .gradio-container { font-family: 'Inter', sans-serif; }
+     #chatbot { height: 650px !important; overflow-y: auto; }
      """
+
+     with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="Robust-R1") as demo:
+
+         with gr.Row():
+             gr.Markdown("# 🤖 Robust-R1: Degradation-Aware Reasoning for Robust Visual Understanding")
+
+         with gr.Row():
+             with gr.Column(scale=4):
+                 chatbot = gr.Chatbot(
+                     elem_id="chatbot",
+                     label="Chat",
+                     type="messages",
+                     avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=Qwen"),
+                     height=650
+                 )
+
+                 chat_input = gr.MultimodalTextbox(
+                     interactive=True,
+                     file_types=["image"],
+                     placeholder="Enter your question or upload an image...",
+                     show_label=False
+                 )
+
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     gr.Markdown("### ⚙️ Generation Config")
+                     temperature = gr.Slider(
+                         minimum=0.01, maximum=1.0, value=0.6, step=0.05,
+                         label="Temperature"
+                     )
+                     max_tokens = gr.Slider(
+                         minimum=128, maximum=4096, value=1024, step=128,
+                         label="Max New Tokens"
+                     )
+
+                 clear_btn = gr.Button("🗑️ Clear Context", variant="stop")
+
+                 gr.Markdown("---")
+                 gr.Markdown("### 📚 Examples")
+                 gr.Markdown("Click the examples below to quickly fill the input box and start a conversation")
+
+                 example_images_dir = os.path.join(project_dir, "assets")
+
+                 examples_config = [
+                     ("What type of vehicles are the people riding?\n0. trucks\n1. wagons\n2. jeeps\n3. cars\n", os.path.join(example_images_dir, "1.jpg")),
+                     ("What is the giant fish in the air?\n0. blimp\n1. balloon\n2. kite\n3. sculpture\n", os.path.join(example_images_dir, "2.jpg")),
+                 ]
+
+                 example_data = []
+                 for text, img_path in examples_config:
+                     if os.path.exists(img_path):
+                         example_data.append({"text": text, "files": [img_path]})
+
+                 if example_data:
+                     gr.Examples(
+                         examples=example_data,
+                         inputs=chat_input,
+                         label="",
+                         examples_per_page=3
+                     )
+                 else:
+                     gr.Markdown("*No example images available, please upload images manually for testing*")
+
+         async def respond(user_msg, history, temp, tokens):
+             text = user_msg.get("text", "").strip()
+             files = user_msg.get("files", [])
+             user_content = list(files)
+             if text: user_content.append(text)
+
+             if not files and text: user_message = {"role": "user", "content": text}
+             else: user_message = {"role": "user", "content": user_content}
+
+             history.append(user_message)
+             yield history, gr.MultimodalTextbox(value=None, interactive=False)
+
+             history.append({"role": "assistant", "content": ""})
+
+             try:
+                 previous_history = history[:-2] if len(history) >= 2 else []
+
+                 handler = get_model_handler()
+                 generated_text = ""
+                 for chunk in handler.predict(user_msg, previous_history, temp, tokens):
+                     generated_text = chunk
+
+                     safe_text = generated_text.replace("<", "&lt;").replace(">", "&gt;")
+
+                     history[-1]["content"] = safe_text
+                     yield history, gr.MultimodalTextbox(interactive=False)
+
+             except Exception as e:
+                 import traceback
+                 traceback.print_exc()
+                 history[-1]["content"] = f"❌ Inference error: {str(e)}"
+                 yield history, gr.MultimodalTextbox(interactive=True)
+
+             yield history, gr.MultimodalTextbox(value=None, interactive=True)
+
+         chat_input.submit(
+             respond,
+             inputs=[chat_input, chatbot, temperature, max_tokens],
+             outputs=[chatbot, chat_input]
+         )
+
+         def clear_history(): return [], None
+         clear_btn.click(clear_history, outputs=[chatbot, chat_input])
+
+     return demo

  if __name__ == "__main__":
+     demo = create_chat_ui()

      if is_spaces:
+         print(f"🚀 Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
+         demo.launch(
+             show_error=True,
+             allowed_paths=[project_dir] if project_dir else None
+         )
      else:
+         print(f"🚀 Service is starting, please visit: http://localhost:7860")
          demo.launch(
              server_name="0.0.0.0",
              server_port=7860,
+             share=False,
+             show_error=True,
+             allowed_paths=[project_dir]
+         )
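
A side note on the new respond() handler: the file adds "import html", yet the handler escapes angle brackets by hand so that the model's <TYPE>-style tags render literally in the Chatbot instead of being parsed as HTML. A roughly equivalent stdlib call, shown here only as a sketch (note that html.escape also escapes "&", which the manual replace does not):

    import html

    raw = "<TYPE> motion blur <TYPE_END>"
    manual = raw.replace("<", "&lt;").replace(">", "&gt;")  # as in respond()
    escaped = html.escape(raw, quote=False)                 # stdlib; additionally escapes "&"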