nomid2 committed on
Commit
dafb0bc
·
verified ·
1 Parent(s): d609f98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -39
app.py CHANGED
@@ -182,7 +182,7 @@ def decode_base64_file(data_url: str) -> tuple[str, str, str]:
182
  logger.error(f"Failed to parse data URL: {e}")
183
  return None, None, None
184
 
185
- async def upload_image_to_imgbb(session: aiohttp.ClientSession, base64_data: str) -> str:
186
  """
187
  将 base64 图片上传到 imgbb
188
  返回图片的 URL
@@ -203,21 +203,22 @@ async def upload_image_to_imgbb(session: aiohttp.ClientSession, base64_data: str
203
 
204
  logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")
205
 
206
- # 上传到 imgbb
207
- async with session.post(IMGBB_API_URL, data=data, timeout=30) as response:
208
- if response.status == 200:
209
- result = await response.json()
210
- if result.get('success'):
211
- image_url = result['data']['url']
212
- logger.info(f"Image uploaded successfully: {image_url}")
213
- return image_url
 
 
 
 
214
  else:
215
- logger.error(f"imgbb upload failed: {result}")
 
216
  return None
217
- else:
218
- error_text = await response.text()
219
- logger.error(f"imgbb upload error: {response.status} - {error_text}")
220
- return None
221
 
222
  except asyncio.TimeoutError:
223
  logger.error("Timeout uploading image to imgbb")
@@ -226,7 +227,7 @@ async def upload_image_to_imgbb(session: aiohttp.ClientSession, base64_data: str
226
  logger.error(f"Failed to upload image to imgbb: {e}")
227
  return None
228
 
229
- async def format_image_for_model(session: aiohttp.ClientSession, base64_data: str, model_config: Dict[str, Any]) -> str:
230
  """
231
  根据模型配置格式化图片数据
232
  """
@@ -234,7 +235,7 @@ async def format_image_for_model(session: aiohttp.ClientSession, base64_data: st
234
 
235
  if image_format == "url":
236
  # 需要上传图片到 imgbb 并返回 URL
237
- image_url = await upload_image_to_imgbb(session, base64_data)
238
  if image_url:
239
  return image_url
240
  else:
@@ -379,7 +380,7 @@ def format_files_for_prompt(files: List[Dict[str, str]]) -> str:
379
 
380
  return "\n".join(file_sections)
381
 
382
- async def transform_openai_to_replicate(session: aiohttp.ClientSession, openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
383
  """将OpenAI格式的请求转换为Replicate格式"""
384
  try:
385
  messages = openai_request.get("messages", [])
@@ -449,7 +450,7 @@ async def transform_openai_to_replicate(session: aiohttp.ClientSession, openai_r
449
  formatted_image = None
450
  if has_images and primary_image:
451
  logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")
452
- formatted_image = await format_image_for_model(session, primary_image, model_config)
453
 
454
  if not formatted_image:
455
  logger.error("Failed to format image for model")
@@ -652,7 +653,7 @@ async def root():
652
  "status": "running",
653
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
654
  "imgbb_token_configured": bool(IMGBB_API_KEY),
655
- "version": "1.2.0",
656
  "supported_models": list(MODEL_CONFIGS.keys()),
657
  "vision_support": True,
658
  "file_support": True,
@@ -703,13 +704,14 @@ async def chat_completions(request: Request):
703
  logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
704
  logger.info(f"Message count: {len(body.get('messages', []))}")
705
 
706
- async with aiohttp.ClientSession() as session:
707
- # 转换请求格式
708
- replicate_data, model = await transform_openai_to_replicate(session, body)
709
-
710
- if body.get("stream", False):
711
- # 流式响应
712
- async def generate_stream():
 
713
  try:
714
  # 创建预测
715
  prediction = await create_replicate_prediction(session, model, replicate_data)
@@ -786,6 +788,7 @@ async def chat_completions(request: Request):
786
  yield "data: [DONE]\n\n"
787
  except Exception as e:
788
  logger.error(f"Stream generation error: {e}")
 
789
  error_response = {
790
  "error": {
791
  "message": str(e),
@@ -793,20 +796,21 @@ async def chat_completions(request: Request):
793
  }
794
  }
795
  yield f"data: {json.dumps(error_response)}\n\n"
796
-
797
- return StreamingResponse(
798
- generate_stream(),
799
- media_type="text/event-stream",
800
- headers={
801
- "Cache-Control": "no-cache",
802
- "Connection": "keep-alive",
803
- "Access-Control-Allow-Origin": "*",
804
- "X-Accel-Buffering": "no",
805
- }
806
- )
807
 
808
- else:
809
- # 非流式响应
 
 
 
 
 
 
 
 
 
 
 
 
810
  # 创建预测
811
  prediction = await create_replicate_prediction(session, model, replicate_data)
812
  prediction_id = prediction.get('id')
 
182
  logger.error(f"Failed to parse data URL: {e}")
183
  return None, None, None
184
 
185
+ async def upload_image_to_imgbb(base64_data: str) -> str:
186
  """
187
  将 base64 图片上传到 imgbb
188
  返回图片的 URL
 
203
 
204
  logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")
205
 
206
+ # 使用独立的 session 上传到 imgbb
207
+ async with aiohttp.ClientSession() as session:
208
+ async with session.post(IMGBB_API_URL, data=data, timeout=30) as response:
209
+ if response.status == 200:
210
+ result = await response.json()
211
+ if result.get('success'):
212
+ image_url = result['data']['url']
213
+ logger.info(f"Image uploaded successfully: {image_url}")
214
+ return image_url
215
+ else:
216
+ logger.error(f"imgbb upload failed: {result}")
217
+ return None
218
  else:
219
+ error_text = await response.text()
220
+ logger.error(f"imgbb upload error: {response.status} - {error_text}")
221
  return None
 
 
 
 
222
 
223
  except asyncio.TimeoutError:
224
  logger.error("Timeout uploading image to imgbb")
 
227
  logger.error(f"Failed to upload image to imgbb: {e}")
228
  return None
229
 
230
+ async def format_image_for_model(base64_data: str, model_config: Dict[str, Any]) -> str:
231
  """
232
  根据模型配置格式化图片数据
233
  """
 
235
 
236
  if image_format == "url":
237
  # 需要上传图片到 imgbb 并返回 URL
238
+ image_url = await upload_image_to_imgbb(base64_data)
239
  if image_url:
240
  return image_url
241
  else:
 
380
 
381
  return "\n".join(file_sections)
382
 
383
+ async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
384
  """将OpenAI格式的请求转换为Replicate格式"""
385
  try:
386
  messages = openai_request.get("messages", [])
 
450
  formatted_image = None
451
  if has_images and primary_image:
452
  logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")
453
+ formatted_image = await format_image_for_model(primary_image, model_config)
454
 
455
  if not formatted_image:
456
  logger.error("Failed to format image for model")
 
653
  "status": "running",
654
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
655
  "imgbb_token_configured": bool(IMGBB_API_KEY),
656
+ "version": "1.2.1",
657
  "supported_models": list(MODEL_CONFIGS.keys()),
658
  "vision_support": True,
659
  "file_support": True,
 
704
  logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
705
  logger.info(f"Message count: {len(body.get('messages', []))}")
706
 
707
+ # 转换请求格式(不依赖 session)
708
+ replicate_data, model = await transform_openai_to_replicate(body)
709
+
710
+ if body.get("stream", False):
711
+ # 流式响应
712
+ async def generate_stream():
713
+ # 在生成器内部创建独立的 session
714
+ async with aiohttp.ClientSession() as session:
715
  try:
716
  # 创建预测
717
  prediction = await create_replicate_prediction(session, model, replicate_data)
 
788
  yield "data: [DONE]\n\n"
789
  except Exception as e:
790
  logger.error(f"Stream generation error: {e}")
791
+ logger.error(f"Traceback: {traceback.format_exc()}")
792
  error_response = {
793
  "error": {
794
  "message": str(e),
 
796
  }
797
  }
798
  yield f"data: {json.dumps(error_response)}\n\n"
 
 
 
 
 
 
 
 
 
 
 
799
 
800
+ return StreamingResponse(
801
+ generate_stream(),
802
+ media_type="text/event-stream",
803
+ headers={
804
+ "Cache-Control": "no-cache",
805
+ "Connection": "keep-alive",
806
+ "Access-Control-Allow-Origin": "*",
807
+ "X-Accel-Buffering": "no",
808
+ }
809
+ )
810
+
811
+ else:
812
+ # 非流式响应
813
+ async with aiohttp.ClientSession() as session:
814
  # 创建预测
815
  prediction = await create_replicate_prediction(session, model, replicate_data)
816
  prediction_id = prediction.get('id')