nomid2 commited on
Commit
c425d75
·
verified ·
1 Parent(s): f2cbdd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -13
app.py CHANGED
@@ -172,6 +172,39 @@ def decode_base64_file(data_url: str) -> tuple[str, str, str]:
172
  logger.error(f"Failed to parse data URL: {e}")
173
  return None, None, None
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
176
  """
177
  从消息中提取文本内容、图片和文件
@@ -204,10 +237,16 @@ def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str
204
  try:
205
  if ";base64," in url:
206
  base64_data = url.split(";base64,")[1]
207
- images.append(base64_data)
 
 
208
  logger.info(f"Found base64 image, size: {len(base64_data)} chars")
 
 
209
  except Exception as e:
210
  logger.error(f"Error processing image: {e}")
 
 
211
 
212
  elif item_type == "file" or (item_type == "image_url" and not item.get("image_url", {}).get("url", "").startswith("data:image/")):
213
  # 处理文件上传
@@ -221,7 +260,8 @@ def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str
221
 
222
  if file_ext in SUPPORTED_IMAGE_EXTENSIONS and mime_type.startswith("image/"):
223
  # 图片文件
224
- images.append(file_content)
 
225
  logger.info(f"Found image file: {filename}")
226
  elif file_ext in SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith("text/"):
227
  # 文本文件
@@ -278,6 +318,7 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
278
  has_images = False
279
  has_files = False
280
  all_files = []
 
281
 
282
  for message in messages:
283
  if message.get("role") == "system":
@@ -297,6 +338,10 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
297
 
298
  if image_list:
299
  has_images = True
 
 
 
 
300
  if file_list:
301
  has_files = True
302
  all_files.extend(file_list)
@@ -359,15 +404,10 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
359
 
360
  replicate_input["prompt"] = prompt
361
 
362
- # 处理图片(只使用第一张图片)
363
- if has_images:
364
- # 找到最后一个包含图片的用户消息
365
- for msg in reversed(user_messages):
366
- if msg["role"] == "user" and msg["images"]:
367
- primary_image = msg["images"][0]
368
- replicate_input["image"] = f"data:image/jpeg;base64,{primary_image}"
369
- logger.info(f"Added primary image to request")
370
- break
371
 
372
  # 只在有 system_prompt 时才添加
373
  if system_prompt:
@@ -437,7 +477,8 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
437
  log_data = data.copy()
438
  if "input" in log_data:
439
  if "image" in log_data["input"]:
440
- log_data["input"]["image"] = f"[IMAGE_DATA_{len(log_data['input']['image'])}]"
 
441
  if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
442
  log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
443
  logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
@@ -522,7 +563,7 @@ async def root():
522
  "message": "Replicate API Proxy for LobeChat with Vision and File Support",
523
  "status": "running",
524
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
525
- "version": "1.1.0",
526
  "supported_models": list(MODEL_CONFIGS.keys()),
527
  "vision_support": True,
528
  "file_support": True,
 
172
  logger.error(f"Failed to parse data URL: {e}")
173
  return None, None, None
174
 
175
+ def format_image_for_replicate(base64_data: str) -> str:
176
+ """
177
+ 将 base64 图片数据格式化为 Replicate 期望的格式
178
+ """
179
+ # 检查 base64 数据是否已经包含 data URL 前缀
180
+ if base64_data.startswith("data:"):
181
+ return base64_data
182
+
183
+ # 如果没有前缀,添加默认的 JPEG data URL 前缀
184
+ # 但首先尝试检测实际的图片格式
185
+ try:
186
+ # 解码 base64 数据的前几个字节来检测格式
187
+ decoded_bytes = base64.b64decode(base64_data[:100])
188
+
189
+ if decoded_bytes.startswith(b'\xff\xd8\xff'):
190
+ # JPEG
191
+ return f"data:image/jpeg;base64,{base64_data}"
192
+ elif decoded_bytes.startswith(b'\x89PNG\r\n\x1a\n'):
193
+ # PNG
194
+ return f"data:image/png;base64,{base64_data}"
195
+ elif decoded_bytes.startswith(b'GIF87a') or decoded_bytes.startswith(b'GIF89a'):
196
+ # GIF
197
+ return f"data:image/gif;base64,{base64_data}"
198
+ elif decoded_bytes.startswith(b'RIFF') and b'WEBP' in decoded_bytes[:20]:
199
+ # WebP
200
+ return f"data:image/webp;base64,{base64_data}"
201
+ else:
202
+ # 默认使用 JPEG
203
+ return f"data:image/jpeg;base64,{base64_data}"
204
+ except Exception as e:
205
+ logger.warning(f"Failed to detect image format: {e}, using JPEG as default")
206
+ return f"data:image/jpeg;base64,{base64_data}"
207
+
208
  def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
209
  """
210
  从消息中提取文本内容、图片和文件
 
237
  try:
238
  if ";base64," in url:
239
  base64_data = url.split(";base64,")[1]
240
+ # 格式化为 Replicate 期望的格式
241
+ formatted_image = format_image_for_replicate(base64_data)
242
+ images.append(formatted_image)
243
  logger.info(f"Found base64 image, size: {len(base64_data)} chars")
244
+ else:
245
+ logger.warning(f"Image URL format not supported: {url[:100]}...")
246
  except Exception as e:
247
  logger.error(f"Error processing image: {e}")
248
+ else:
249
+ logger.warning(f"External image URLs not supported: {url}")
250
 
251
  elif item_type == "file" or (item_type == "image_url" and not item.get("image_url", {}).get("url", "").startswith("data:image/")):
252
  # 处理文件上传
 
260
 
261
  if file_ext in SUPPORTED_IMAGE_EXTENSIONS and mime_type.startswith("image/"):
262
  # 图片文件
263
+ formatted_image = format_image_for_replicate(file_content)
264
+ images.append(formatted_image)
265
  logger.info(f"Found image file: {filename}")
266
  elif file_ext in SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith("text/"):
267
  # 文本文件
 
318
  has_images = False
319
  has_files = False
320
  all_files = []
321
+ primary_image = None
322
 
323
  for message in messages:
324
  if message.get("role") == "system":
 
338
 
339
  if image_list:
340
  has_images = True
341
+ # 使用最后一个用户消息中的第一张图片作为主要图片
342
+ if message.get("role") == "user":
343
+ primary_image = image_list[0]
344
+
345
  if file_list:
346
  has_files = True
347
  all_files.extend(file_list)
 
404
 
405
  replicate_input["prompt"] = prompt
406
 
407
+ # 处理图片 - 使用正确的 data URL 格式
408
+ if has_images and primary_image:
409
+ replicate_input["image"] = primary_image
410
+ logger.info(f"Added primary image to request: {primary_image[:100]}...")
 
 
 
 
 
411
 
412
  # 只在有 system_prompt 时才添加
413
  if system_prompt:
 
477
  log_data = data.copy()
478
  if "input" in log_data:
479
  if "image" in log_data["input"]:
480
+ image_data = log_data["input"]["image"]
481
+ log_data["input"]["image"] = f"[IMAGE_DATA_{len(image_data)}]"
482
  if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
483
  log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
484
  logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
 
563
  "message": "Replicate API Proxy for LobeChat with Vision and File Support",
564
  "status": "running",
565
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
566
+ "version": "1.1.1",
567
  "supported_models": list(MODEL_CONFIGS.keys()),
568
  "vision_support": True,
569
  "file_support": True,