Update app.py
Browse files
app.py
CHANGED
|
@@ -172,6 +172,39 @@ def decode_base64_file(data_url: str) -> tuple[str, str, str]:
|
|
| 172 |
logger.error(f"Failed to parse data URL: {e}")
|
| 173 |
return None, None, None
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
|
| 176 |
"""
|
| 177 |
从消息中提取文本内容、图片和文件
|
|
@@ -204,10 +237,16 @@ def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str
|
|
| 204 |
try:
|
| 205 |
if ";base64," in url:
|
| 206 |
base64_data = url.split(";base64,")[1]
|
| 207 |
-
|
|
|
|
|
|
|
| 208 |
logger.info(f"Found base64 image, size: {len(base64_data)} chars")
|
|
|
|
|
|
|
| 209 |
except Exception as e:
|
| 210 |
logger.error(f"Error processing image: {e}")
|
|
|
|
|
|
|
| 211 |
|
| 212 |
elif item_type == "file" or (item_type == "image_url" and not item.get("image_url", {}).get("url", "").startswith("data:image/")):
|
| 213 |
# 处理文件上传
|
|
@@ -221,7 +260,8 @@ def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str
|
|
| 221 |
|
| 222 |
if file_ext in SUPPORTED_IMAGE_EXTENSIONS and mime_type.startswith("image/"):
|
| 223 |
# 图片文件
|
| 224 |
-
|
|
|
|
| 225 |
logger.info(f"Found image file: {filename}")
|
| 226 |
elif file_ext in SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith("text/"):
|
| 227 |
# 文本文件
|
|
@@ -278,6 +318,7 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
|
|
| 278 |
has_images = False
|
| 279 |
has_files = False
|
| 280 |
all_files = []
|
|
|
|
| 281 |
|
| 282 |
for message in messages:
|
| 283 |
if message.get("role") == "system":
|
|
@@ -297,6 +338,10 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
|
|
| 297 |
|
| 298 |
if image_list:
|
| 299 |
has_images = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
if file_list:
|
| 301 |
has_files = True
|
| 302 |
all_files.extend(file_list)
|
|
@@ -359,15 +404,10 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
|
|
| 359 |
|
| 360 |
replicate_input["prompt"] = prompt
|
| 361 |
|
| 362 |
-
#
|
| 363 |
-
if has_images:
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
if msg["role"] == "user" and msg["images"]:
|
| 367 |
-
primary_image = msg["images"][0]
|
| 368 |
-
replicate_input["image"] = f"data:image/jpeg;base64,{primary_image}"
|
| 369 |
-
logger.info(f"Added primary image to request")
|
| 370 |
-
break
|
| 371 |
|
| 372 |
# 只在有 system_prompt 时才添加
|
| 373 |
if system_prompt:
|
|
@@ -437,7 +477,8 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
|
|
| 437 |
log_data = data.copy()
|
| 438 |
if "input" in log_data:
|
| 439 |
if "image" in log_data["input"]:
|
| 440 |
-
log_data["input"]["image"]
|
|
|
|
| 441 |
if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
|
| 442 |
log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
|
| 443 |
logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
|
|
@@ -522,7 +563,7 @@ async def root():
|
|
| 522 |
"message": "Replicate API Proxy for LobeChat with Vision and File Support",
|
| 523 |
"status": "running",
|
| 524 |
"replicate_token_configured": bool(REPLICATE_API_TOKEN),
|
| 525 |
-
"version": "1.1.
|
| 526 |
"supported_models": list(MODEL_CONFIGS.keys()),
|
| 527 |
"vision_support": True,
|
| 528 |
"file_support": True,
|
|
|
|
| 172 |
logger.error(f"Failed to parse data URL: {e}")
|
| 173 |
return None, None, None
|
| 174 |
|
| 175 |
+
def format_image_for_replicate(base64_data: str) -> str:
    """
    Format base64 image data as the data URL that is sent to Replicate.

    Input that already starts with ``data:`` is returned unchanged.
    Otherwise the first decoded bytes are compared against the JPEG, PNG,
    GIF and WebP file signatures and the matching ``data:image/...;base64,``
    prefix is prepended. The payload itself is never modified.

    Args:
        base64_data: Base64-encoded image bytes, or a complete data URL.

    Returns:
        A data URL string. ``image/jpeg`` is used when the format is not
        recognised or the leading characters cannot be decoded.
    """
    # Already a data URL: nothing to add.
    if base64_data.startswith("data:"):
        return base64_data

    mime_type = "image/jpeg"  # fallback when no signature matches
    try:
        # Only the header is needed. Drop whitespace first (base64 is often
        # wrapped at 76 columns) and trim to a multiple of 4 characters so a
        # line break or missing '=' padding does not make b64decode fail.
        head = "".join(base64_data[:128].split())[:64]
        head = head[: len(head) - len(head) % 4]
        header = base64.b64decode(head)
    except ValueError as e:
        # binascii.Error is a ValueError subclass; non-ASCII input also
        # raises ValueError.
        logger.warning(f"Failed to detect image format: {e}, using JPEG as default")
    else:
        signatures = (
            (b"\xff\xd8\xff", "image/jpeg"),
            (b"\x89PNG\r\n\x1a\n", "image/png"),
            (b"GIF87a", "image/gif"),
            (b"GIF89a", "image/gif"),
        )
        for magic, candidate in signatures:
            if header.startswith(magic):
                mime_type = candidate
                break
        else:
            # RIFF container: the form type sits at bytes 8-11.
            if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
                mime_type = "image/webp"

    return f"data:{mime_type};base64,{base64_data}"
|
| 207 |
+
|
| 208 |
def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
|
| 209 |
"""
|
| 210 |
从消息中提取文本内容、图片和文件
|
|
|
|
| 237 |
try:
|
| 238 |
if ";base64," in url:
|
| 239 |
base64_data = url.split(";base64,")[1]
|
| 240 |
+
# 格式化为 Replicate 期望的格式
|
| 241 |
+
formatted_image = format_image_for_replicate(base64_data)
|
| 242 |
+
images.append(formatted_image)
|
| 243 |
logger.info(f"Found base64 image, size: {len(base64_data)} chars")
|
| 244 |
+
else:
|
| 245 |
+
logger.warning(f"Image URL format not supported: {url[:100]}...")
|
| 246 |
except Exception as e:
|
| 247 |
logger.error(f"Error processing image: {e}")
|
| 248 |
+
else:
|
| 249 |
+
logger.warning(f"External image URLs not supported: {url}")
|
| 250 |
|
| 251 |
elif item_type == "file" or (item_type == "image_url" and not item.get("image_url", {}).get("url", "").startswith("data:image/")):
|
| 252 |
# 处理文件上传
|
|
|
|
| 260 |
|
| 261 |
if file_ext in SUPPORTED_IMAGE_EXTENSIONS and mime_type.startswith("image/"):
|
| 262 |
# 图片文件
|
| 263 |
+
formatted_image = format_image_for_replicate(file_content)
|
| 264 |
+
images.append(formatted_image)
|
| 265 |
logger.info(f"Found image file: {filename}")
|
| 266 |
elif file_ext in SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith("text/"):
|
| 267 |
# 文本文件
|
|
|
|
| 318 |
has_images = False
|
| 319 |
has_files = False
|
| 320 |
all_files = []
|
| 321 |
+
primary_image = None
|
| 322 |
|
| 323 |
for message in messages:
|
| 324 |
if message.get("role") == "system":
|
|
|
|
| 338 |
|
| 339 |
if image_list:
|
| 340 |
has_images = True
|
| 341 |
+
# 使用最后一个用户消息中的第一张图片作为主要图片
|
| 342 |
+
if message.get("role") == "user":
|
| 343 |
+
primary_image = image_list[0]
|
| 344 |
+
|
| 345 |
if file_list:
|
| 346 |
has_files = True
|
| 347 |
all_files.extend(file_list)
|
|
|
|
| 404 |
|
| 405 |
replicate_input["prompt"] = prompt
|
| 406 |
|
| 407 |
+
# 处理图片 - 使用正确的 data URL 格式
|
| 408 |
+
if has_images and primary_image:
|
| 409 |
+
replicate_input["image"] = primary_image
|
| 410 |
+
logger.info(f"Added primary image to request: {primary_image[:100]}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
# 只在有 system_prompt 时才添加
|
| 413 |
if system_prompt:
|
|
|
|
| 477 |
log_data = data.copy()
|
| 478 |
if "input" in log_data:
|
| 479 |
if "image" in log_data["input"]:
|
| 480 |
+
image_data = log_data["input"]["image"]
|
| 481 |
+
log_data["input"]["image"] = f"[IMAGE_DATA_{len(image_data)}]"
|
| 482 |
if "prompt" in log_data["input"] and len(log_data["input"]["prompt"]) > 1000:
|
| 483 |
log_data["input"]["prompt"] = log_data["input"]["prompt"][:1000] + "...[TRUNCATED]"
|
| 484 |
logger.info(f"Request data: {json.dumps(log_data, indent=2)}")
|
|
|
|
| 563 |
"message": "Replicate API Proxy for LobeChat with Vision and File Support",
|
| 564 |
"status": "running",
|
| 565 |
"replicate_token_configured": bool(REPLICATE_API_TOKEN),
|
| 566 |
+
"version": "1.1.1",
|
| 567 |
"supported_models": list(MODEL_CONFIGS.keys()),
|
| 568 |
"vision_support": True,
|
| 569 |
"file_support": True,
|