Jiaqi-hkust commited on
Commit
79138c6
·
verified ·
1 Parent(s): b44b983

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +142 -95
app.py CHANGED
@@ -81,26 +81,56 @@ class ModelHandler:
81
  print(f"❌ Model loading failed: {e}")
82
  raise e
83
 
84
- def predict(self, messages, temperature, max_tokens):
85
- # 这里的 messages 已经是处理好的标准 OpenAI 格式列表
 
 
 
86
 
87
- # 将 sys_prompt 注入到最后一条用户消息中
88
- if messages and messages[-1]["role"] == "user":
89
- content = messages[-1]["content"]
90
- sys_prompt_fmt = "\n" + " ".join(sys_prompt.split())
91
-
92
- if isinstance(content, str):
93
- messages[-1]["content"] += sys_prompt_fmt
94
- elif isinstance(content, list):
95
- # 查找文本部分并追加,如果没有则添加
96
- text_found = False
97
- for item in content:
98
- if item.get("type") == "text":
99
- item["text"] += sys_prompt_fmt
100
- text_found = True
101
- break
102
- if not text_found:
103
- content.append({"type": "text", "text": sys_prompt_fmt})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  text_prompt = self.processor.apply_chat_template(
106
  messages, tokenize=False, add_generation_prompt=True
@@ -125,17 +155,37 @@ class ModelHandler:
125
  do_sample=True if temperature > 0 else False,
126
  )
127
 
128
- with torch.no_grad():
129
- generated_ids = self.model.generate(**generation_kwargs)
130
-
131
- input_length = inputs['input_ids'].shape[1]
132
- generated_ids = generated_ids[0][input_length:]
133
-
134
- generated_text = self.processor.tokenizer.decode(
135
- generated_ids, skip_special_tokens=True
136
- )
137
-
138
- yield generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  model_handler = None
141
 
@@ -147,116 +197,113 @@ def get_model_handler():
147
  model_handler = ModelHandler(MODEL_PATH)
148
  return model_handler
149
 
150
- def _history_to_messages(history):
151
- """
152
- 将 Gradio 的 Tuple 历史 [[user, bot], ...] 转换为 OpenAI 格式的消息列表。
153
- 以便发送给模型。
154
- """
155
  messages = []
156
-
157
  for pair in history:
158
- user_msg, bot_msg = pair
159
-
160
- # --- 处理用户消息 ---
161
- if user_msg:
162
- # 检查是否是文件路径(图片)
163
- # Gradio 中图片通常是临时路径,或者 http 链接
164
- is_image = False
165
  if isinstance(user_msg, str):
166
- if os.path.exists(user_msg) or user_msg.startswith("http"):
167
- # 简单的判断:如果是现有路径或者是URL,且看起来像图片
168
- lower_msg = user_msg.lower()
169
- if any(lower_msg.endswith(ext) for ext in ['.jpg', '.png', '.jpeg', '.webp', '.bmp']):
170
- is_image = True
 
 
 
 
 
 
 
171
 
172
- # 构建 User Content
173
- if is_image:
174
- # 这是一个独立的图片消息
175
- # 注意:为了模型效果,最好将图片和紧接着的文本合并。
176
- # 但为了代码简单,我们先作为独立消息,大多数 VLM 也能处理。
177
- messages.append({
178
- "role": "user",
179
- "content": [{"type": "image", "image": user_msg}]
180
- })
181
- else:
182
- # 这是一个文本消息
183
- # 如果上一条也是 user 且是 image,尝试合并(可选,这里简单起见直接 append)
184
- messages.append({
185
- "role": "user",
186
- "content": [{"type": "text", "text": str(user_msg)}]
187
- })
188
-
189
- # --- 处理机器人消息 ---
190
- if bot_msg:
191
- messages.append({
192
- "role": "assistant",
193
- "content": [{"type": "text", "text": str(bot_msg)}]
194
- })
195
 
 
 
 
 
 
 
196
  return messages
197
 
198
- # ==========================================
 
 
 
 
 
 
 
 
 
 
 
199
  @gpu_decorator
200
- def respond(user_input, history, temp, tokens):
201
  """
202
- user_input: MultimodalTextbox 返回的字典 {'text': '...', 'files': [...]}
203
- history: 这里的 history Gradio 自动维护的 [{'role': 'user', ...}, ...] 列表
204
  """
205
 
206
- # 1. 构建当前用户的消息对象
207
  user_content = []
208
 
209
- # 处理图片
210
- files = user_input.get("files", [])
211
  for file_path in files:
212
  user_content.append({"type": "image", "image": file_path})
213
 
214
  # 处理文本
215
- text = user_input.get("text", "")
216
  if text:
217
  user_content.append({"type": "text", "text": text})
218
 
219
- # 如果没有内容,直接返回
220
  if not user_content:
221
  yield history, gr.MultimodalTextbox(interactive=True)
222
  return
223
 
224
- # 2. 将用户消息加入历史
225
- # 注意:history 必须是 [{'role': 'user', 'content': [...]}, ...] 格式
226
- history.append({"role": "user", "content": user_content})
227
 
228
- # 立即更新 UI,清空输入框
229
  yield history, gr.MultimodalTextbox(value=None, interactive=False)
230
 
 
231
  try:
232
  handler = get_model_handler()
233
 
234
- # 3. 预先加入一个空的助手消息,用于流式显示
235
  history.append({"role": "assistant", "content": ""})
236
 
237
- # 4. 调用模型
238
- # 我们传入 history 的副本(不包含刚才加的空 assistant),以免污染
239
- input_messages = copy.deepcopy(history[:-1])
 
 
240
 
241
- full_response = ""
242
- for chunk in handler.predict(input_messages, temp, tokens):
243
- full_response += chunk
244
  # 实时更新最后一条消息的内容
245
- history[-1]["content"] = full_response
246
  yield history, gr.MultimodalTextbox(interactive=False)
247
 
248
  except Exception as e:
249
  import traceback
250
  traceback.print_exc()
251
- # 出错时显示错误信息
252
  if history and history[-1]["role"] == "assistant":
253
  history[-1]["content"] += f"\n❌ Error: {str(e)}"
254
  else:
255
  history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
256
-
257
  yield history, gr.MultimodalTextbox(interactive=True)
258
 
259
- # 最后恢复输入框
260
  yield history, gr.MultimodalTextbox(interactive=True)
261
 
262
  def create_chat_ui():
 
81
  print(f"❌ Model loading failed: {e}")
82
  raise e
83
 
84
+ def predict(self, message_dict, history, temperature, max_tokens):
85
+ text = message_dict.get("text", "")
86
+ files = message_dict.get("files", [])
87
+
88
+ messages = []
89
 
90
+ if history:
91
+ print(f"Processing {len(history)} previous messages from history")
92
+ for msg in history:
93
+ role = msg.get("role", "")
94
+ content = msg.get("content", "")
95
+
96
+ if role == "user":
97
+ user_content = []
98
+
99
+ if isinstance(content, list):
100
+ for item in content:
101
+ if isinstance(item, str):
102
+ if os.path.exists(item) or any(item.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']):
103
+ user_content.append({"type": "image", "image": item})
104
+ else:
105
+ user_content.append({"type": "text", "text": item})
106
+ elif isinstance(item, dict):
107
+ user_content.append(item)
108
+ elif isinstance(content, str):
109
+ if content:
110
+ user_content.append({"type": "text", "text": content})
111
+
112
+ if user_content:
113
+ messages.append({"role": "user", "content": user_content})
114
+
115
+ elif role == "assistant":
116
+ if isinstance(content, str) and content:
117
+ messages.append({"role": "assistant", "content": content})
118
+
119
+ current_content = []
120
+ if files:
121
+ for file_path in files:
122
+ current_content.append({"type": "image", "image": file_path})
123
+
124
+ if text:
125
+ sys_prompt_formatted = " ".join(sys_prompt.split())
126
+ full_text = f"{text}\n{sys_prompt_formatted}"
127
+ current_content.append({"type": "text", "text": full_text})
128
+
129
+ if current_content:
130
+ messages.append({"role": "user", "content": current_content})
131
+
132
+ print(f"Total messages for model: {len(messages)}")
133
+ print(f"Message roles: {[m['role'] for m in messages]}")
134
 
135
  text_prompt = self.processor.apply_chat_template(
136
  messages, tokenize=False, add_generation_prompt=True
 
155
  do_sample=True if temperature > 0 else False,
156
  )
157
 
158
+ try:
159
+ print("Starting model generation...")
160
+ with torch.no_grad():
161
+ generated_ids = self.model.generate(**generation_kwargs)
162
+
163
+ input_length = inputs['input_ids'].shape[1]
164
+ generated_ids = generated_ids[0][input_length:]
165
+
166
+ print(f"Input length: {input_length}, Generated token count: {len(generated_ids)}")
167
+
168
+ generated_text = self.processor.tokenizer.decode(
169
+ generated_ids,
170
+ skip_special_tokens=True
171
+ )
172
+
173
+ print(f"Generation completed. Output length: {len(generated_text)}, Content preview: {repr(generated_text[:200])}")
174
+
175
+ if generated_text and generated_text.strip():
176
+ print(f"Yielding generated text: {generated_text[:100]}...")
177
+ yield generated_text
178
+ else:
179
+ warning_msg = "⚠️ No output generated. The model may not have produced any response."
180
+ print(warning_msg)
181
+ yield warning_msg
182
+
183
+ except Exception as e:
184
+ import traceback
185
+ error_details = traceback.format_exc()
186
+ print(f"Error in model.generate: {error_details}")
187
+ yield f"❌ Generation error: {str(e)}"
188
+ return
189
 
190
  model_handler = None
191
 
 
197
  model_handler = ModelHandler(MODEL_PATH)
198
  return model_handler
199
 
200
+ def _convert_history_to_messages_format(history):
201
+ """将旧格式的 Chatbot history 转换为新格式的 messages"""
 
 
 
202
  messages = []
 
203
  for pair in history:
204
+ if isinstance(pair, list) and len(pair) >= 2:
205
+ user_msg = pair[0]
206
+ assistant_msg = pair[1] if len(pair) > 1 else ""
207
+
208
+ # 处理用户消息
209
+ user_content = []
 
210
  if isinstance(user_msg, str):
211
+ user_content.append({"type": "text", "text": user_msg})
212
+ elif isinstance(user_msg, tuple):
213
+ # 旧格式可能是 (text, image) 或 (image, text)
214
+ for item in user_msg:
215
+ if isinstance(item, str):
216
+ if os.path.exists(item) or any(item.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']):
217
+ user_content.append({"type": "image", "image": item})
218
+ else:
219
+ user_content.append({"type": "text", "text": item})
220
+ elif isinstance(user_msg, list):
221
+ # 可能是新格式的内容列表
222
+ user_content = user_msg
223
 
224
+ if user_content:
225
+ messages.append({"role": "user", "content": user_content})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ # 处理助手消息
228
+ if assistant_msg and isinstance(assistant_msg, str):
229
+ messages.append({"role": "assistant", "content": assistant_msg})
230
+ elif isinstance(pair, dict):
231
+ # 如果已经是新格式,直接使用
232
+ messages.append(pair)
233
  return messages
234
 
235
+ def _format_user_input_for_chatbot(text, files):
236
+ """格式化用户输入为 Chatbot 可显示的格式"""
237
+ if files and text:
238
+ # 有图片和文本,返回元组格式
239
+ return (text, *files)
240
+ elif files:
241
+ # 只有图片
242
+ return files[0] if len(files) == 1 else tuple(files)
243
+ else:
244
+ # 只有文本
245
+ return text
246
+
247
  @gpu_decorator
248
+ def respond(user_msg, history, temp, tokens):
249
  """
250
+ user_msg: Gradio MultimodalTextbox 返回的字典 {'text': '...', 'files': ['...']}
251
+ history: Gradio Chatbot (type='messages') 维护的列表 [{'role': 'user', 'content': ...}, ...]
252
  """
253
 
254
+ # 1. 解析用户输入,构建符合 OpenAI 格式的 User Message
255
  user_content = []
256
 
257
+ # 处理图片文件
258
+ files = user_msg.get("files", [])
259
  for file_path in files:
260
  user_content.append({"type": "image", "image": file_path})
261
 
262
  # 处理文本
263
+ text = user_msg.get("text", "")
264
  if text:
265
  user_content.append({"type": "text", "text": text})
266
 
 
267
  if not user_content:
268
  yield history, gr.MultimodalTextbox(interactive=True)
269
  return
270
 
271
+ # 2. 将当前用户消息加入历史记录
272
+ new_message = {"role": "user", "content": user_content}
273
+ history.append(new_message)
274
 
275
+ # 3. 立即 yield 更新 UI(显示用户消息),同时清空输入框
276
  yield history, gr.MultimodalTextbox(value=None, interactive=False)
277
 
278
+ # 4. 准备调用模型
279
  try:
280
  handler = get_model_handler()
281
 
282
+ # 预先加入一个空的 Assistant 消息用于流式填充
283
  history.append({"role": "assistant", "content": ""})
284
 
285
+ # 调用 predict (注意:这里直接传 history 即可,因为 history 已经是完整的上下文)
286
+ # 我们传入 history 的深拷贝以防修改影响 UI,或者直接传引用如果 predict 内部做了处理
287
+ # 这里为了安全,我们只把 history 传进去,predict 内部会处理 sys_prompt 的追加逻辑
288
+ import copy
289
+ messages_payload = copy.deepcopy(history[:-1]) # 去掉刚才加的空 assistant
290
 
291
+ for chunk in handler.predict(messages_payload, temp, tokens):
 
 
292
  # 实时更新最后一条消息的内容
293
+ history[-1]["content"] = chunk
294
  yield history, gr.MultimodalTextbox(interactive=False)
295
 
296
  except Exception as e:
297
  import traceback
298
  traceback.print_exc()
299
+ # 发生错误时,追加错误信息
300
  if history and history[-1]["role"] == "assistant":
301
  history[-1]["content"] += f"\n❌ Error: {str(e)}"
302
  else:
303
  history.append({"role": "assistant", "content": f"❌ Error: {str(e)}"})
 
304
  yield history, gr.MultimodalTextbox(interactive=True)
305
 
306
+ # 最后恢复输入框可交互
307
  yield history, gr.MultimodalTextbox(interactive=True)
308
 
309
  def create_chat_ui():