Update app.py
Browse files
app.py
CHANGED
|
@@ -182,7 +182,7 @@ def decode_base64_file(data_url: str) -> tuple[str, str, str]:
|
|
| 182 |
logger.error(f"Failed to parse data URL: {e}")
|
| 183 |
return None, None, None
|
| 184 |
|
| 185 |
-
async def upload_image_to_imgbb(
|
| 186 |
"""
|
| 187 |
将 base64 图片上传到 imgbb
|
| 188 |
返回图片的 URL
|
|
@@ -203,21 +203,22 @@ async def upload_image_to_imgbb(session: aiohttp.ClientSession, base64_data: str
|
|
| 203 |
|
| 204 |
logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")
|
| 205 |
|
| 206 |
-
# 上传到 imgbb
|
| 207 |
-
async with
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
else:
|
| 215 |
-
|
|
|
|
| 216 |
return None
|
| 217 |
-
else:
|
| 218 |
-
error_text = await response.text()
|
| 219 |
-
logger.error(f"imgbb upload error: {response.status} - {error_text}")
|
| 220 |
-
return None
|
| 221 |
|
| 222 |
except asyncio.TimeoutError:
|
| 223 |
logger.error("Timeout uploading image to imgbb")
|
|
@@ -226,7 +227,7 @@ async def upload_image_to_imgbb(session: aiohttp.ClientSession, base64_data: str
|
|
| 226 |
logger.error(f"Failed to upload image to imgbb: {e}")
|
| 227 |
return None
|
| 228 |
|
| 229 |
-
async def format_image_for_model(
|
| 230 |
"""
|
| 231 |
根据模型配置格式化图片数据
|
| 232 |
"""
|
|
@@ -234,7 +235,7 @@ async def format_image_for_model(session: aiohttp.ClientSession, base64_data: st
|
|
| 234 |
|
| 235 |
if image_format == "url":
|
| 236 |
# 需要上传图片到 imgbb 并返回 URL
|
| 237 |
-
image_url = await upload_image_to_imgbb(
|
| 238 |
if image_url:
|
| 239 |
return image_url
|
| 240 |
else:
|
|
@@ -379,7 +380,7 @@ def format_files_for_prompt(files: List[Dict[str, str]]) -> str:
|
|
| 379 |
|
| 380 |
return "\n".join(file_sections)
|
| 381 |
|
| 382 |
-
async def transform_openai_to_replicate(
|
| 383 |
"""将OpenAI格式的请求转换为Replicate格式"""
|
| 384 |
try:
|
| 385 |
messages = openai_request.get("messages", [])
|
|
@@ -449,7 +450,7 @@ async def transform_openai_to_replicate(session: aiohttp.ClientSession, openai_r
|
|
| 449 |
formatted_image = None
|
| 450 |
if has_images and primary_image:
|
| 451 |
logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")
|
| 452 |
-
formatted_image = await format_image_for_model(
|
| 453 |
|
| 454 |
if not formatted_image:
|
| 455 |
logger.error("Failed to format image for model")
|
|
@@ -652,7 +653,7 @@ async def root():
|
|
| 652 |
"status": "running",
|
| 653 |
"replicate_token_configured": bool(REPLICATE_API_TOKEN),
|
| 654 |
"imgbb_token_configured": bool(IMGBB_API_KEY),
|
| 655 |
-
"version": "1.2.
|
| 656 |
"supported_models": list(MODEL_CONFIGS.keys()),
|
| 657 |
"vision_support": True,
|
| 658 |
"file_support": True,
|
|
@@ -703,13 +704,14 @@ async def chat_completions(request: Request):
|
|
| 703 |
logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
|
| 704 |
logger.info(f"Message count: {len(body.get('messages', []))}")
|
| 705 |
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
|
|
|
| 713 |
try:
|
| 714 |
# 创建预测
|
| 715 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
|
@@ -786,6 +788,7 @@ async def chat_completions(request: Request):
|
|
| 786 |
yield "data: [DONE]\n\n"
|
| 787 |
except Exception as e:
|
| 788 |
logger.error(f"Stream generation error: {e}")
|
|
|
|
| 789 |
error_response = {
|
| 790 |
"error": {
|
| 791 |
"message": str(e),
|
|
@@ -793,20 +796,21 @@ async def chat_completions(request: Request):
|
|
| 793 |
}
|
| 794 |
}
|
| 795 |
yield f"data: {json.dumps(error_response)}\n\n"
|
| 796 |
-
|
| 797 |
-
return StreamingResponse(
|
| 798 |
-
generate_stream(),
|
| 799 |
-
media_type="text/event-stream",
|
| 800 |
-
headers={
|
| 801 |
-
"Cache-Control": "no-cache",
|
| 802 |
-
"Connection": "keep-alive",
|
| 803 |
-
"Access-Control-Allow-Origin": "*",
|
| 804 |
-
"X-Accel-Buffering": "no",
|
| 805 |
-
}
|
| 806 |
-
)
|
| 807 |
|
| 808 |
-
|
| 809 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 810 |
# 创建预测
|
| 811 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
| 812 |
prediction_id = prediction.get('id')
|
|
|
|
| 182 |
logger.error(f"Failed to parse data URL: {e}")
|
| 183 |
return None, None, None
|
| 184 |
|
| 185 |
+
async def upload_image_to_imgbb(base64_data: str) -> str:
|
| 186 |
"""
|
| 187 |
将 base64 图片上传到 imgbb
|
| 188 |
返回图片的 URL
|
|
|
|
| 203 |
|
| 204 |
logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")
|
| 205 |
|
| 206 |
+
# 使用独立的 session 上传到 imgbb
|
| 207 |
+
async with aiohttp.ClientSession() as session:
|
| 208 |
+
async with session.post(IMGBB_API_URL, data=data, timeout=30) as response:
|
| 209 |
+
if response.status == 200:
|
| 210 |
+
result = await response.json()
|
| 211 |
+
if result.get('success'):
|
| 212 |
+
image_url = result['data']['url']
|
| 213 |
+
logger.info(f"Image uploaded successfully: {image_url}")
|
| 214 |
+
return image_url
|
| 215 |
+
else:
|
| 216 |
+
logger.error(f"imgbb upload failed: {result}")
|
| 217 |
+
return None
|
| 218 |
else:
|
| 219 |
+
error_text = await response.text()
|
| 220 |
+
logger.error(f"imgbb upload error: {response.status} - {error_text}")
|
| 221 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
except asyncio.TimeoutError:
|
| 224 |
logger.error("Timeout uploading image to imgbb")
|
|
|
|
| 227 |
logger.error(f"Failed to upload image to imgbb: {e}")
|
| 228 |
return None
|
| 229 |
|
| 230 |
+
async def format_image_for_model(base64_data: str, model_config: Dict[str, Any]) -> str:
|
| 231 |
"""
|
| 232 |
根据模型配置格式化图片数据
|
| 233 |
"""
|
|
|
|
| 235 |
|
| 236 |
if image_format == "url":
|
| 237 |
# 需要上传图片到 imgbb 并返回 URL
|
| 238 |
+
image_url = await upload_image_to_imgbb(base64_data)
|
| 239 |
if image_url:
|
| 240 |
return image_url
|
| 241 |
else:
|
|
|
|
| 380 |
|
| 381 |
return "\n".join(file_sections)
|
| 382 |
|
| 383 |
+
async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
|
| 384 |
"""将OpenAI格式的请求转换为Replicate格式"""
|
| 385 |
try:
|
| 386 |
messages = openai_request.get("messages", [])
|
|
|
|
| 450 |
formatted_image = None
|
| 451 |
if has_images and primary_image:
|
| 452 |
logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")
|
| 453 |
+
formatted_image = await format_image_for_model(primary_image, model_config)
|
| 454 |
|
| 455 |
if not formatted_image:
|
| 456 |
logger.error("Failed to format image for model")
|
|
|
|
| 653 |
"status": "running",
|
| 654 |
"replicate_token_configured": bool(REPLICATE_API_TOKEN),
|
| 655 |
"imgbb_token_configured": bool(IMGBB_API_KEY),
|
| 656 |
+
"version": "1.2.1",
|
| 657 |
"supported_models": list(MODEL_CONFIGS.keys()),
|
| 658 |
"vision_support": True,
|
| 659 |
"file_support": True,
|
|
|
|
| 704 |
logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
|
| 705 |
logger.info(f"Message count: {len(body.get('messages', []))}")
|
| 706 |
|
| 707 |
+
# 转换请求格式(不依赖 session)
|
| 708 |
+
replicate_data, model = await transform_openai_to_replicate(body)
|
| 709 |
+
|
| 710 |
+
if body.get("stream", False):
|
| 711 |
+
# 流式响应
|
| 712 |
+
async def generate_stream():
|
| 713 |
+
# 在生成器内部创建独立的 session
|
| 714 |
+
async with aiohttp.ClientSession() as session:
|
| 715 |
try:
|
| 716 |
# 创建预测
|
| 717 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
|
|
|
| 788 |
yield "data: [DONE]\n\n"
|
| 789 |
except Exception as e:
|
| 790 |
logger.error(f"Stream generation error: {e}")
|
| 791 |
+
logger.error(f"Traceback: {traceback.format_exc()}")
|
| 792 |
error_response = {
|
| 793 |
"error": {
|
| 794 |
"message": str(e),
|
|
|
|
| 796 |
}
|
| 797 |
}
|
| 798 |
yield f"data: {json.dumps(error_response)}\n\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
|
| 800 |
+
return StreamingResponse(
|
| 801 |
+
generate_stream(),
|
| 802 |
+
media_type="text/event-stream",
|
| 803 |
+
headers={
|
| 804 |
+
"Cache-Control": "no-cache",
|
| 805 |
+
"Connection": "keep-alive",
|
| 806 |
+
"Access-Control-Allow-Origin": "*",
|
| 807 |
+
"X-Accel-Buffering": "no",
|
| 808 |
+
}
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
else:
|
| 812 |
+
# 非流式响应
|
| 813 |
+
async with aiohttp.ClientSession() as session:
|
| 814 |
# 创建预测
|
| 815 |
prediction = await create_replicate_prediction(session, model, replicate_data)
|
| 816 |
prediction_id = prediction.get('id')
|