Jiaqi-hkust committed
Commit edc87ef · verified · 1 Parent(s): 29ab1d7

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +4 -54
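
The commit message indicates this snapshot was pushed with huggingface_hub's folder-upload API. As a minimal sketch of that kind of upload (the local folder path and the target Space id below are placeholders, not taken from this commit):

    # Hypothetical sketch of a folder upload with huggingface_hub;
    # folder_path and repo_id are placeholders, not from this commit.
    from huggingface_hub import HfApi

    api = HfApi()  # reads the access token from HF_TOKEN / `huggingface-cli login`
    api.upload_folder(
        folder_path="./robust-r1-demo",         # local folder containing app.py
        repo_id="Jiaqi-hkust/your-space-name",  # placeholder Space id
        repo_type="space",                      # push to a Space rather than a model repo
        commit_message="Upload folder using huggingface_hub",
    )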
app.py CHANGED
@@ -5,7 +5,6 @@ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import html
 
-# Import the spaces module for GPU detection
 is_spaces = os.getenv("SPACE_ID") is not None
 spaces_available = False
 GPU = None
@@ -15,11 +14,9 @@ if is_spaces:
         from spaces import GPU
         spaces_available = True
     except ImportError:
-        print("⚠️ spaces module not available, GPU detection may not work")
+        pass
 
-# Create a conditional decorator
 def gpu_decorator(func):
-    """Conditionally apply the GPU decorator"""
     if spaces_available and GPU is not None:
         return GPU(func)
     return func
@@ -40,42 +37,27 @@ if not is_spaces:
 
 MODEL_PATH = os.getenv("MODEL_PATH", "Jiaqi-hkust/Robust-R1-RL")
 
-print(f"==========================================")
-print(f"Initializing application...")
-print(f"==========================================")
-
 class ModelHandler:
     def __init__(self, model_path):
         self.model_path = model_path
         self.model = None
         self.processor = None
-        # Do not load the model in __init__; defer loading until it is actually used
 
     def _load_model(self):
-        """Lazily load the model; called inside the GPU-decorated function"""
         if self.model is not None:
-            return  # already loaded
+            return
 
         try:
-            print(f"⏳ Loading model weights, this may take a few minutes...")
-
             self.processor = AutoProcessor.from_pretrained(self.model_path)
 
-            # In the ZeroGPU environment, avoid checking CUDA too early;
-            # let device_map="auto" handle device placement
             try:
                 cuda_available = torch.cuda.is_available()
                 if cuda_available:
-                    device_capability = torch.cuda.get_device_capability()
-                    print(f"🔧 CUDA available, device capability: {device_capability}")
                     torch_dtype = torch.bfloat16
                 else:
-                    print(f"🔧 Using CPU or non-CUDA device")
                     torch_dtype = torch.float32
             except RuntimeError:
-                # CUDA may be temporarily unqueryable in the ZeroGPU environment
-                print(f"🔧 CUDA check skipped (ZeroGPU environment)")
-                torch_dtype = torch.bfloat16  # assume a GPU and let device_map handle it
+                torch_dtype = torch.bfloat16
 
             self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                 self.model_path,
@@ -84,13 +66,10 @@ class ModelHandler:
                 attn_implementation="sdpa",
                 trust_remote_code=True
             )
-            print("✅ Model loaded successfully!")
         except Exception as e:
-            print(f"❌ Model loading failed: {e}")
             raise e
 
     def predict(self, message_dict, history, temperature, max_tokens):
-        # Make sure the model has been loaded
         if self.model is None:
             self._load_model()
 
@@ -100,7 +79,6 @@ class ModelHandler:
         messages = []
 
         if history:
-            print(f"Processing {len(history)} previous messages from history")
             for msg in history:
                 role = msg.get("role", "")
                 content = msg.get("content", "")
@@ -140,9 +118,6 @@ class ModelHandler:
 
         if current_content:
             messages.append({"role": "user", "content": current_content})
-
-        print(f"Total messages for model: {len(messages)}")
-        print(f"Message roles: {[m['role'] for m in messages]}")
 
         text_prompt = self.processor.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
@@ -168,44 +143,32 @@ class ModelHandler:
         )
 
         try:
-            print("Starting model generation...")
             with torch.no_grad():
                 generated_ids = self.model.generate(**generation_kwargs)
 
             input_length = inputs['input_ids'].shape[1]
             generated_ids = generated_ids[0][input_length:]
 
-            print(f"Input length: {input_length}, Generated token count: {len(generated_ids)}")
-
             generated_text = self.processor.tokenizer.decode(
                 generated_ids,
                 skip_special_tokens=True
             )
 
-            print(f"Generation completed. Output length: {len(generated_text)}, Content preview: {repr(generated_text[:200])}")
-
             if generated_text and generated_text.strip():
-                print(f"Yielding generated text: {generated_text[:100]}...")
                 yield generated_text
             else:
                 warning_msg = "⚠️ No output generated. The model may not have produced any response."
-                print(warning_msg)
                 yield warning_msg
 
         except Exception as e:
-            import traceback
-            error_details = traceback.format_exc()
-            print(f"Error in model.generate: {error_details}")
             yield f"❌ Generation error: {str(e)}"
             return
 
 model_handler = None
 
 def get_model_handler():
-    """Get model handler with lazy loading"""
     global model_handler
     if model_handler is None:
-        print("🔄 Initializing model handler...")
         model_handler = ModelHandler(MODEL_PATH)
     return model_handler
 
@@ -216,30 +179,21 @@ custom_css = """
 
 @gpu_decorator
 def respond(user_msg, history, temp, tokens):
-    print("user_msg:")
-    print(user_msg)
     text = user_msg.get("text", "").strip()
     files = user_msg.get("files", [])
-    # Follow the reference code's format: files use {"path": x}, text is passed as a plain string
     user_message = {"role": "user", "content": []}
 
-    # Add image files in the {"path": file_path} format
     for file_path in files:
         if file_path:
-            # Make sure an absolute path is used
            abs_path = os.path.abspath(file_path) if not os.path.isabs(file_path) else file_path
            user_message["content"].append({"path": abs_path})
 
-    # Add the text as a plain string
     if text:
         user_message["content"].append(text)
 
-    # With text only and no files, content stays a list; with neither, content is an empty list
-    # With text only, a plain string also works (common Gradio usage)
     if not files and text:
         user_message["content"] = text
-    print("user_message:")
-    print(user_message)
+
     history.append(user_message)
     yield history, gr.MultimodalTextbox(value=None, interactive=False)
 
@@ -259,8 +213,6 @@ def respond(user_msg, history, temp, tokens):
             yield history, gr.MultimodalTextbox(interactive=False)
 
     except Exception as e:
-        import traceback
-        traceback.print_exc()
         history[-1]["content"] = f"❌ Inference error: {str(e)}"
         yield history, gr.MultimodalTextbox(interactive=True)
 
@@ -345,7 +297,6 @@ if __name__ == "__main__":
     demo = create_chat_ui()
 
     if is_spaces:
-        print(f"🚀 Running on Hugging Face Spaces: {os.getenv('SPACE_ID')}")
         allowed_paths = [project_dir] if project_dir else None
         demo.launch(
             theme=gr.themes.Soft(),
@@ -354,7 +305,6 @@ if __name__ == "__main__":
             allowed_paths=allowed_paths
         )
     else:
-        print(f"🚀 Service is starting, please visit: http://localhost:7860")
         demo.launch(
             theme=gr.themes.Soft(),
             css=custom_css,
 