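# Replicate / app.py
# Last commit fcdaffb ("Update app.py", verified) by nomid2 — 16.3 kB.
# (Header recovered from a repository file-viewer page; its "raw", "history"
# and "blame" navigation links are omitted.)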
import asyncio
import json
import logging
import os
import time
import traceback
from typing import Any, AsyncGenerator, Dict, Optional, Tuple

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
# Process-wide logging: INFO and above, each record prefixed with a timestamp,
# the logger name and the level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# The FastAPI application exposing an OpenAI-compatible surface over Replicate.
app = FastAPI(
    version="1.0.0",
    title="Replicate API Proxy for LobeChat",
    description="A proxy service to forward Replicate API requests in OpenAI-compatible format",
)

# Accept cross-origin requests from any site so browser-based chat UIs can
# call the proxy directly.
# NOTE(review): every origin, method and header is allowed, and the chat
# endpoint performs no client authentication of its own. Anyone who can reach
# this service spends the configured Replicate token. Consider restricting
# origins or requiring an API key.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Replicate credentials are read from the environment. A missing token is
# only logged here; the service still starts.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    logger.error("REPLICATE_API_TOKEN not found in environment variables")

# Replicate REST API root, and the model slug used when a request names none.
REPLICATE_BASE_URL = "https://api.replicate.com/v1"
DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log any uncaught exception and answer with a JSON 500.

    The body uses an OpenAI-style error envelope (``{"error": {...}}``) whose
    message embeds ``str(exc)``.
    NOTE(review): echoing ``str(exc)`` exposes internal details to clients.
    """
    detail = str(exc)
    logger.error(f"Global exception: {detail}")
    logger.error(f"Traceback: {traceback.format_exc()}")
    error_body = {
        "error": {
            "message": f"Internal server error: {detail}",
            "type": "internal_error",
        }
    }
    return JSONResponse(content=error_body, status_code=500)
def _message_text(content: Any) -> str:
    """Flatten an OpenAI chat message ``content`` value into plain text.

    OpenAI-compatible clients send ``content`` either as a string or as a list
    of content parts (e.g. ``{"type": "text", "text": "..."}``). Text parts are
    joined with newlines. Non-text parts such as images are dropped, because the
    prompt sent to Replicate is plain text. ``None`` becomes ``""``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


def transform_openai_to_replicate(
    openai_request: Dict[str, Any], model_override: Optional[str] = None
) -> Tuple[Dict[str, Any], str]:
    """Convert an OpenAI chat-completions request into a Replicate prediction request.

    Non-empty system messages are joined (blank-line separated) into
    ``system_prompt``. It defaults to "You are a helpful assistant" when there
    are none. User/assistant turns become a "Human: ... / Assistant: ..."
    transcript that ends with an open "Assistant:" turn.

    Args:
        openai_request: Parsed OpenAI-style request body. Reads ``messages``,
            ``model``, ``stream``, ``max_tokens`` (or ``max_completion_tokens``)
            and ``temperature``.
        model_override: If truthy, used instead of ``openai_request["model"]``.

    Returns:
        ``(replicate_request, model)``: the JSON body for Replicate's
        predictions endpoint and the resolved ``anthropic/...`` model slug.

    Raises:
        HTTPException: 400 if the request cannot be transformed.
    """
    try:
        messages = openai_request.get("messages") or []

        system_parts = []
        prompt_parts = []
        for message in messages:
            if not isinstance(message, dict):
                raise ValueError("each message must be a JSON object")
            role = message.get("role")
            text = _message_text(message.get("content"))
            if role == "system":
                if text:
                    system_parts.append(text)
            elif role == "user":
                prompt_parts.append(f"Human: {text}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {text}")

        system_prompt = "\n\n".join(system_parts) if system_parts else "You are a helpful assistant"

        prompt = "\n\n".join(prompt_parts)
        # Leave an open assistant turn for the model to complete.
        if prompt_parts and not prompt.endswith("\n\nAssistant:"):
            prompt += "\n\nAssistant:"

        # Resolve the model; a missing/null/non-string model falls back to the default.
        model = model_override or openai_request.get("model") or DEFAULT_MODEL
        if not isinstance(model, str):
            model = DEFAULT_MODEL
        model_mapping = {
            "claude-4-sonnet": "anthropic/claude-4-sonnet",
            "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
            "claude-3-sonnet": "anthropic/claude-3-sonnet",
            "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
            "claude-3-haiku": "anthropic/claude-3-haiku",
        }
        if model in model_mapping:
            model = model_mapping[model]
        elif not model.startswith("anthropic/"):
            model = DEFAULT_MODEL

        # Explicit JSON nulls fall back to the defaults instead of being forwarded.
        max_tokens = openai_request.get("max_tokens")
        if max_tokens is None:
            max_tokens = openai_request.get("max_completion_tokens")
        if max_tokens is None:
            max_tokens = 4000
        temperature = openai_request.get("temperature")
        if temperature is None:
            temperature = 0.7

        replicate_request = {
            "stream": bool(openai_request.get("stream", False)),
            "input": {
                "prompt": prompt,
                "system_prompt": system_prompt,
                "max_tokens": max_tokens,
                "temperature": temperature
            }
        }
        logger.info(f"Transformed request for model: {model}")
        return replicate_request, model
    except Exception as e:
        logger.error(f"Error transforming request: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Request transformation error: {str(e)}") from e
async def create_replicate_prediction(session: aiohttp.ClientSession, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a prediction for ``model`` on Replicate and return the decoded response.

    Args:
        session: Open aiohttp session used for the request.
        model: Replicate ``owner/name`` model slug.
        data: JSON-serializable prediction request body.

    Returns:
        The prediction object Replicate returned (parsed JSON).

    Raises:
        HTTPException: with Replicate's own status code when it answers with
            anything other than 201; 504 if the call exceeds 30 s; 500 for any
            other failure (network error, invalid JSON, ...).
    """
    url = f"{REPLICATE_BASE_URL}/models/{model}/predictions"
    headers = {
        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
        "Content-Type": "application/json"
    }
    logger.info(f"Creating prediction for model: {model}")
    try:
        timeout = aiohttp.ClientTimeout(total=30)
        async with session.post(url, headers=headers, json=data, timeout=timeout) as response:
            response_text = await response.text()
            logger.info(f"Replicate response status: {response.status}")
            if response.status != 201:
                logger.error(f"Replicate API error: {response.status} - {response_text}")
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Replicate API error: {response_text}"
                )
            return json.loads(response_text)
    except HTTPException:
        # Already carries the upstream status; don't let the generic handler
        # below rewrap it as a 500.
        raise
    except asyncio.TimeoutError as e:
        logger.error("Timeout creating Replicate prediction")
        raise HTTPException(status_code=504, detail="Timeout creating prediction") from e
    except Exception as e:
        logger.error(f"Error creating prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}") from e
def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
    """Convert one ``data: {...}`` event line into OpenAI ``chat.completion.chunk`` SSE text.

    ``event_data`` must start with ``"data: "`` followed by a JSON object with
    an ``event`` key. For ``"output"`` events the text under ``data`` becomes the
    chunk's ``delta.content``. A ``"done"`` event yields a final chunk with
    ``finish_reason: "stop"`` followed by ``data: [DONE]``. The optional ``id``
    key becomes ``chatcmpl-<id>``.

    Returns:
        The SSE frame(s) to send, or ``""`` when the line is not a data line,
        is not a JSON object, or carries any other event type.
    """
    if not event_data.startswith("data: "):
        return ""
    try:
        payload = json.loads(event_data[6:])  # drop the "data: " prefix
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse event data: {event_data}, error: {e}")
        return ""
    if not isinstance(payload, dict):
        return ""

    event = payload.get("event")
    if event == "output":
        delta, finish_reason, trailer = {"content": payload.get("data", "")}, None, ""
    elif event == "done":
        delta, finish_reason, trailer = {}, "stop", "data: [DONE]\n\n"
    else:
        return ""

    openai_response = {
        "id": f"chatcmpl-{payload.get('id', 'unknown')}",
        "object": "chat.completion.chunk",
        # OpenAI's "created" is a Unix timestamp in seconds, not a monotonic clock.
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason
        }]
    }
    return f"data: {json.dumps(openai_response)}\n\n{trailer}"
@app.get("/")
async def root():
    """Liveness endpoint: service name, running status, version, and token presence."""
    token_configured = bool(REPLICATE_API_TOKEN)
    return dict(
        message="Replicate API Proxy for LobeChat",
        status="running",
        replicate_token_configured=token_configured,
        version="1.0.0",
    )
@app.get("/health")
async def health():
    """Report service health and whether a Replicate token is configured.

    ``timestamp`` is the current Unix time in seconds (float). The previous
    value came from the event loop's monotonic clock, which is not wall time.
    """
    return {
        "status": "healthy",
        "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
        "timestamp": time.time()
    }
@app.get("/v1/models")
async def list_models():
    """List the model ids this proxy advertises, in OpenAI ``/v1/models`` format."""
    advertised_ids = (
        "claude-4-sonnet",
        "claude-3.5-sonnet",
        "claude-3.5-haiku",
        "claude-3-sonnet",
        "claude-3-haiku",
    )
    entries = [
        {"id": model_id, "object": "model", "created": 1677610602, "owned_by": "anthropic"}
        for model_id in advertised_ids
    ]
    return {"object": "list", "data": entries}
def _sse_error(message: str) -> str:
    """Format ``message`` as an OpenAI-style error object in a single SSE ``data:`` frame."""
    error_response = {
        "error": {
            "message": message,
            "type": "stream_error"
        }
    }
    return f"data: {json.dumps(error_response)}\n\n"


def _replicate_error_detail(data: str) -> str:
    """Extract a readable message from the payload of a Replicate ``error`` event.

    Presumably JSON like ``{"detail": "..."}`` (per Replicate's streaming
    docs). Anything else is returned verbatim.
    """
    try:
        payload = json.loads(data)
    except json.JSONDecodeError:
        return data
    if isinstance(payload, dict) and payload.get("detail"):
        return str(payload["detail"])
    return data


async def _iter_sse_events(response: aiohttp.ClientResponse) -> AsyncGenerator[Tuple[str, str], None]:
    """Parse a ``text/event-stream`` body into ``(event_type, data)`` pairs.

    Handles ``event:`` and ``data:`` fields, ``:`` comment lines, and dispatch
    on a blank line. Consecutive ``data:`` lines are joined with ``"\\n"``, so
    tokens containing newlines survive. Only the single space after the colon
    is removed, so leading/trailing whitespace inside tokens is preserved.
    """
    event_type = "message"
    data_lines = []
    async for raw_line in response.content:
        line = raw_line.decode("utf-8").rstrip("\r\n")
        if not line:
            # A blank line terminates the current event.
            if data_lines:
                yield event_type, "\n".join(data_lines)
            event_type, data_lines = "message", []
            continue
        if line.startswith(":"):
            continue  # comment / keep-alive
        field, _, value = line.partition(":")
        if value.startswith(" "):
            value = value[1:]
        if field == "event":
            event_type = value
        elif field == "data":
            data_lines.append(value)
    # Be lenient if the stream closed without a final blank line.
    if data_lines:
        yield event_type, "\n".join(data_lines)


def _to_openai_chunk(event_type: str, data: str, prediction_id: Any, model: str) -> str:
    """Render one parsed Replicate event as OpenAI chunk SSE text.

    Re-wraps the event in the ``data: {"event", "id", "data"}`` line format that
    ``transform_replicate_to_openai_stream`` parses. The prediction id is used
    so every chunk of one completion shares the same ``chatcmpl-`` id.
    """
    envelope = json.dumps({"event": event_type, "id": prediction_id, "data": data})
    return transform_replicate_to_openai_stream(f"data: {envelope}", model)


async def _stream_chat_completion(model: str, replicate_data: Dict[str, Any]) -> AsyncGenerator[str, None]:
    """Create a streaming prediction and relay it as OpenAI chunk frames.

    Once the response has started the HTTP status can no longer change, so
    failures are reported in-band as an ``{"error": ...}`` frame.
    """
    async with aiohttp.ClientSession() as session:
        try:
            prediction = await create_replicate_prediction(session, model, replicate_data)
            prediction_id = prediction.get("id") or "unknown"
            logger.info(f"Created prediction: {prediction_id}")

            stream_url = (prediction.get("urls") or {}).get("stream")
            if not stream_url:
                yield _sse_error("Stream URL not available")
                return

            logger.info(f"Starting stream from: {stream_url}")
            headers = {
                "Accept": "text/event-stream",
                "Cache-Control": "no-store"
            }
            timeout = aiohttp.ClientTimeout(total=300)
            async with session.get(stream_url, headers=headers, timeout=timeout) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Stream error: {response.status} - {error_text}")
                    yield _sse_error(f"Stream error: {error_text}")
                    return

                finished = False
                events = _iter_sse_events(response)
                try:
                    async for event_type, data in events:
                        if event_type == "error":
                            yield _sse_error(_replicate_error_detail(data))
                            return
                        if event_type not in ("output", "done"):
                            continue  # e.g. "logs" events carry no completion text
                        chunk = _to_openai_chunk(event_type, data, prediction_id, model)
                        if chunk:
                            yield chunk
                        if event_type == "done":
                            finished = True
                            break
                finally:
                    await events.aclose()

                if not finished:
                    # Upstream closed without "done": still terminate the OpenAI
                    # stream so clients stop waiting for more chunks.
                    yield _to_openai_chunk("done", "{}", prediction_id, model)
        except Exception as e:
            logger.error(f"Stream generation error: {e}")
            yield _sse_error(str(e))


async def _poll_chat_completion(
    model: str,
    replicate_data: Dict[str, Any],
    max_attempts: int = 60,
    poll_interval: float = 1.0,
) -> Dict[str, Any]:
    """Create a prediction, poll until it finishes, and return a ``chat.completion`` object.

    Polls at most ``max_attempts`` times and sleeps ``poll_interval`` seconds
    after each unfinished poll. The total wait is therefore about
    ``max_attempts * poll_interval`` seconds plus request latency.

    Raises:
        HTTPException: the upstream status for a non-429 4xx poll response;
            500 if the prediction failed or was canceled; 502 if Replicate
            returned no prediction id; 504 if it did not finish in time.
    """
    async with aiohttp.ClientSession() as session:
        prediction = await create_replicate_prediction(session, model, replicate_data)
        prediction_id = prediction.get("id")
        logger.info(f"Created prediction: {prediction_id}")
        if not prediction_id:
            raise HTTPException(status_code=502, detail="Replicate did not return a prediction id")

        prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
        headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}

        for _ in range(max_attempts):
            async with session.get(prediction_url, headers=headers) as response:
                if response.status == 200:
                    result = await response.json()
                elif 400 <= response.status < 500 and response.status != 429:
                    error_text = await response.text()
                    logger.error(f"Replicate polling error: {response.status} - {error_text}")
                    raise HTTPException(status_code=response.status, detail=f"Replicate API error: {error_text}")
                else:
                    # Rate limiting or a transient upstream failure: poll again.
                    logger.warning(f"Polling prediction {prediction_id} returned HTTP {response.status}; retrying")
                    result = {}
            if not isinstance(result, dict):
                result = {}

            status = result.get("status")
            logger.info(f"Prediction {prediction_id} status: {status}")

            if status == "succeeded":
                output = result.get("output")
                if isinstance(output, list):
                    content = "".join(str(piece) for piece in output)
                elif output is None:
                    content = ""
                else:
                    content = str(output)
                # Rough usage: whitespace-separated words, not real tokens.
                completion_tokens = len(content.split())
                return {
                    "id": f"chatcmpl-{prediction_id}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": content
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 0,
                        "completion_tokens": completion_tokens,
                        "total_tokens": completion_tokens
                    }
                }
            if status == "failed":
                error_msg = result.get("error", "Unknown error")
                logger.error(f"Prediction failed: {error_msg}")
                raise HTTPException(status_code=500, detail=f"Prediction failed: {error_msg}")
            if status in ("canceled", "cancelled"):
                raise HTTPException(status_code=500, detail="Prediction was canceled")

            await asyncio.sleep(poll_interval)

        raise HTTPException(status_code=504, detail="Prediction timeout")


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Handle an OpenAI-compatible chat completion request by proxying to Replicate.

    With ``"stream": true`` the response is an SSE stream of
    ``chat.completion.chunk`` frames ending in ``data: [DONE]``. Otherwise the
    prediction is polled to completion and a ``chat.completion`` object is
    returned.

    Raises:
        HTTPException: 500 if no Replicate token is configured; 400 for a body
            that is not a JSON object; otherwise errors propagated from the
            Replicate calls.
    """
    if not REPLICATE_API_TOKEN:
        logger.error("REPLICATE_API_TOKEN not configured")
        raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not configured")

    try:
        try:
            body = await request.json()
        except ValueError as e:  # includes json.JSONDecodeError / UnicodeDecodeError
            raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")

        logger.info("Received chat completion request")
        replicate_data, model = transform_openai_to_replicate(body)

        if body.get("stream", False):
            return StreamingResponse(
                _stream_chat_completion(model, replicate_data),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                }
            )
        return await _poll_chat_completion(model, replicate_data)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error processing request: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
if __name__ == "__main__":
    # Listen on every interface; the port comes from $PORT (default 7860).
    server_port = int(os.environ.get("PORT", "7860"))
    logger.info(f"Starting server on port {server_port}")
    uvicorn.run(app, log_level="info", host="0.0.0.0", port=server_port)