# Replicate / app.py
# (Page header captured with the source: "nomid2's picture", "Update app.py",
#  commit e4c9bed verified, raw, history blame, 15.8 kB.)
import asyncio
import json
import logging
import os
import time
import traceback
from typing import Dict, Any, AsyncGenerator, Optional, Tuple

import aiohttp
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
# Configure root logging once at import time with a verbose, timestamped format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function and endpoint in this file.
logger = logging.getLogger(__name__)
# FastAPI application exposing the OpenAI-compatible endpoints.
app = FastAPI(
    version="1.0.0",
    title="Replicate API Proxy for LobeChat",
    description="A proxy service to forward Replicate API requests in OpenAI-compatible format",
)

# Allow browser clients on any origin to call the proxy.
# NOTE(review): with allow_credentials=True and a wildcard origin, Starlette's
# CORSMiddleware reflects the caller's Origin instead of sending "*"; since the
# proxy uses no cookies, allow_credentials=False may be safer — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# Configuration read from the environment. A missing token is only logged here;
# the service still starts, and its status endpoints report the token state.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    logger.error("REPLICATE_API_TOKEN not found in environment variables")

# Replicate API settings.
REPLICATE_BASE_URL = "https://api.replicate.com/v1"
# Replicate model used when a request does not name one.
DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
# Global exception handler
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler for unhandled exceptions.

    Logs the exception with its traceback and returns an OpenAI-style error
    body with HTTP 500.

    NOTE(review): the exception text is echoed back to the client; consider a
    generic message if internal details should not be exposed.
    """
    logger.error(f"Global exception: {str(exc)}")
    # Format the traceback from ``exc`` itself rather than traceback.format_exc(),
    # which only works while an exception is actively being handled in this frame.
    tb_text = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    logger.error(f"Traceback: {tb_text}")
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "message": f"Internal server error: {str(exc)}",
                "type": "internal_error"
            }
        }
    )
def _message_text(content: Any) -> str:
    """Flatten an OpenAI message ``content`` (plain string or list of parts) to text.

    Bare strings and ``{"type": "text", "text": ...}`` parts are kept (joined by
    newlines). Image and other non-text parts are dropped because this proxy only
    builds a text prompt. ``None`` becomes an empty string.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


def transform_openai_to_replicate(
    openai_request: Dict[str, Any], model_override: Optional[str] = None
) -> Tuple[Dict[str, Any], str]:
    """Convert an OpenAI chat-completions body into a Replicate prediction request.

    System messages set ``system_prompt`` (the last one wins). User and
    assistant messages are rendered as a ``Human:`` / ``Assistant:`` transcript
    ending in an ``Assistant:`` cue. Other roles are ignored.

    Args:
        openai_request: Decoded OpenAI request body.
        model_override: Model name that takes precedence over ``openai_request["model"]``.

    Returns:
        ``(replicate_request, model)``, where ``model`` is the Replicate
        ``owner/name`` to post the prediction to.

    Raises:
        HTTPException: 400 if the body cannot be transformed.
    """
    try:
        system_prompt = "You are a helpful assistant"
        prompt_parts = []
        for message in openai_request.get("messages") or []:
            role = message.get("role")
            if role == "system":
                content = message.get("content")
                system_prompt = (
                    "You are a helpful assistant" if content is None else _message_text(content)
                )
            elif role == "user":
                prompt_parts.append(f"Human: {_message_text(message.get('content', ''))}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {_message_text(message.get('content', ''))}")

        prompt = "\n\n".join(prompt_parts)
        if prompt_parts and not prompt.endswith("\n\nAssistant:"):
            # Cue the model to answer as the assistant.
            prompt += "\n\nAssistant:"

        # `or` also covers an explicit JSON null for "model".
        model = model_override or openai_request.get("model") or DEFAULT_MODEL
        # Short OpenAI-style names -> Replicate model identifiers.
        model_mapping = {
            "claude-4-sonnet": "anthropic/claude-4-sonnet",
            "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
            "claude-3-sonnet": "anthropic/claude-3-sonnet",
            "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
            "claude-3-haiku": "anthropic/claude-3-haiku",
        }
        if model in model_mapping:
            model = model_mapping[model]
        elif not model.startswith("anthropic/"):
            # Unknown non-Anthropic names fall back to the default model.
            model = DEFAULT_MODEL

        # Explicit nulls fall back to the defaults; compare with None because
        # temperature 0 is a legitimate value.
        max_tokens = openai_request.get("max_tokens")
        temperature = openai_request.get("temperature")
        replicate_request = {
            "stream": bool(openai_request.get("stream", False)),
            "input": {
                "prompt": prompt,
                "system_prompt": system_prompt,
                "max_tokens": 4000 if max_tokens is None else max_tokens,
                "temperature": 0.7 if temperature is None else temperature,
            }
        }
        logger.info(f"Transformed request for model: {model}")
        return replicate_request, model
    except Exception as e:
        logger.error(f"Error transforming request: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Request transformation error: {str(e)}") from e
async def create_replicate_prediction(session: aiohttp.ClientSession, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a prediction on Replicate's model endpoint and return the decoded JSON.

    Args:
        session: Open aiohttp session used for the request.
        model: Replicate ``owner/name`` of the model.
        data: Request body (``input`` dict plus ``stream`` flag).

    Returns:
        The prediction object decoded from Replicate's response body.

    Raises:
        HTTPException: with Replicate's own status code when it rejects the
            request, 504 on timeout, 500 for any other failure.
    """
    url = f"{REPLICATE_BASE_URL}/models/{model}/predictions"
    headers = {
        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
        "Content-Type": "application/json"
    }
    try:
        logger.info(f"Creating prediction for model: {model}")
        logger.info(f"Request URL: {url}")
        async with session.post(
            url, headers=headers, json=data, timeout=aiohttp.ClientTimeout(total=30)
        ) as response:
            response_text = await response.text()
            logger.info(f"Replicate response status: {response.status}")
            logger.info(f"Replicate response: {response_text}")
            # 201 Created is the normal answer; 200 is accepted as success too.
            if response.status not in (200, 201):
                logger.error(f"Replicate API error: {response.status} - {response_text}")
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Replicate API error: {response_text}"
                )
            return json.loads(response_text)
    except HTTPException:
        # Already carries Replicate's status code; don't re-wrap it as a 500.
        raise
    except asyncio.TimeoutError as e:
        logger.error("Timeout creating Replicate prediction")
        raise HTTPException(status_code=504, detail="Timeout creating prediction") from e
    except Exception as e:
        logger.error(f"Error creating prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}") from e
def _sse_event_line(event_type: str, event_id: Optional[str], data_lines: list) -> str:
    """Encode one parsed SSE event as a single ``data: {json}`` line.

    The JSON object has ``event``, ``data`` (multi-line data joined with
    newlines) and, when the event carried one, ``id``.
    """
    envelope: Dict[str, Any] = {"event": event_type, "data": "\n".join(data_lines)}
    if event_id is not None:
        envelope["id"] = event_id
    return f"data: {json.dumps(envelope)}"


async def stream_replicate_response(session: aiohttp.ClientSession, stream_url: str) -> AsyncGenerator[str, None]:
    """Read a Replicate server-sent-event stream and yield one line per event.

    The stream is parsed as standard SSE: ``event:``/``id:``/``data:`` fields,
    ``:`` comments, and events terminated by a blank line. In that format
    ``data`` is raw text, not JSON. Each complete event is re-emitted as
    ``data: {"event": ..., "id": ..., "data": ...}``, the JSON-envelope line
    shape that transform_replicate_to_openai_stream decodes.

    Only line terminators are removed, so whitespace and newlines inside
    generated text survive.

    Args:
        session: Open aiohttp session used for the request.
        stream_url: The prediction's stream URL.

    Yields:
        One normalized ``data: {json}`` line per SSE event.

    Raises:
        HTTPException: if the stream endpoint answers with a non-200 status.
        Other errors (network, timeout, decode) are logged and re-raised.
    """
    try:
        headers = {
            "Accept": "text/event-stream",
            "Cache-Control": "no-store"
        }
        logger.info(f"Starting stream from: {stream_url}")
        async with session.get(
            stream_url, headers=headers, timeout=aiohttp.ClientTimeout(total=300)
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"Stream error: {response.status} - {error_text}")
                raise HTTPException(status_code=response.status, detail=f"Stream error: {error_text}")

            event_type = "message"  # SSE default event type, reset after each event
            event_id = None         # last-seen id persists across events (SSE rule)
            data_lines = []
            async for raw_line in response.content:
                line = raw_line.decode("utf-8", errors="replace")
                # Drop only the line terminator (LF or CRLF).
                if line.endswith("\n"):
                    line = line[:-1]
                if line.endswith("\r"):
                    line = line[:-1]

                if not line:
                    # Blank line: the current event is complete.
                    if data_lines:
                        yield _sse_event_line(event_type, event_id, data_lines)
                    event_type, data_lines = "message", []
                    continue
                if line.startswith(":"):
                    continue  # SSE comment / keep-alive

                field, sep, value = line.partition(":")
                if sep and value.startswith(" "):
                    value = value[1:]  # one optional space after the colon
                if field == "event":
                    event_type = value
                elif field == "data":
                    data_lines.append(value)
                elif field == "id":
                    event_id = value
                # "retry" and unknown fields are ignored.

            if data_lines:
                # Lenient: flush an event left unterminated when the connection closed.
                yield _sse_event_line(event_type, event_id, data_lines)
    except Exception as e:
        logger.error(f"Stream error: {str(e)}")
        raise
def _stream_error_message(raw: Any) -> str:
    """Best-effort human-readable message from a stream ``error`` event payload."""
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            # Not JSON: the payload text itself is the message.
            return raw or "Unknown stream error"
    if isinstance(raw, dict):
        # NOTE(review): assumes error payloads look like {"detail": "..."}; confirm.
        message = raw.get("detail") or raw.get("error") or raw.get("message")
        if message:
            return str(message)
    return str(raw) if raw else "Unknown stream error"


def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
    """Convert one Replicate stream line into OpenAI-format SSE text.

    ``event_data`` is expected to be ``data: {...}``, where the JSON object has
    an ``event`` key (``"output"``, ``"done"`` or ``"error"``), an optional
    ``id``, and a ``data`` field (generated text for ``output``).

    Args:
        event_data: One line from the Replicate stream reader.
        model: Model name echoed back in each chunk.

    Returns:
        * For ``output``: one ``chat.completion.chunk`` frame carrying the text.
        * For ``done``: a final ``finish_reason="stop"`` chunk followed by ``data: [DONE]``.
        * For ``error``: an ``{"error": ...}`` frame followed by ``data: [DONE]``.
        * ``""`` for anything else, including non-``data:`` lines and unparsable
          or non-object payloads.
    """
    if not event_data.startswith("data:"):
        return ""
    try:
        # SSE allows one optional space after "data:"; json.loads skips it anyway.
        data = json.loads(event_data[len("data:"):])
    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse event data: {event_data}, error: {e}")
        return ""
    if not isinstance(data, dict):
        # Valid JSON but not an event envelope (e.g. a bare number): nothing to forward.
        return ""

    event = data.get("event")
    if event == "error":
        error_response = {
            "error": {
                "message": _stream_error_message(data.get("data")),
                "type": "stream_error"
            }
        }
        # Terminate so OpenAI-style clients stop waiting for more chunks.
        return f"data: {json.dumps(error_response)}\n\ndata: [DONE]\n\n"
    if event not in ("output", "done"):
        return ""

    is_done = event == "done"
    chunk = {
        "id": f"chatcmpl-{data.get('id', 'unknown')}",
        "object": "chat.completion.chunk",
        # OpenAI's "created" is a Unix timestamp in seconds (wall clock).
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "delta": {} if is_done else {"content": data.get("data", "")},
            "finish_reason": "stop" if is_done else None
        }]
    }
    frame = f"data: {json.dumps(chunk)}\n\n"
    return (frame + "data: [DONE]\n\n") if is_done else frame
@app.get("/")
async def root():
    """Liveness probe: service banner, status, and whether the Replicate token is set."""
    token_configured = bool(REPLICATE_API_TOKEN)
    return dict(
        message="Replicate API Proxy for LobeChat",
        status="running",
        replicate_token_configured=token_configured,
        version="1.0.0",
    )
@app.get("/health")
async def health():
    """Detailed health check.

    Returns:
        ``status`` (always ``"healthy"`` when the handler runs), whether the
        Replicate token is ``"configured"`` or ``"missing"``, and ``timestamp``,
        the current Unix time in seconds as a float.
    """
    return {
        "status": "healthy",
        "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
        # Wall-clock epoch seconds; the event loop's clock is monotonic and not a timestamp.
        "timestamp": time.time()
    }
@app.get("/v1/models")
async def list_models():
    """List the model ids this proxy accepts, in OpenAI ``/v1/models`` format."""
    advertised_ids = (
        "claude-4-sonnet",
        "claude-3.5-sonnet",
        "claude-3.5-haiku",
        "claude-3-sonnet",
        "claude-3-haiku",
    )
    entries = [
        {"id": model_id, "object": "model", "created": 1677610602, "owned_by": "anthropic"}
        for model_id in advertised_ids
    ]
    return {"object": "list", "data": entries}
def _build_chat_completion(prediction_id: Any, model: str, content: str) -> Dict[str, Any]:
    """Wrap finished output ``content`` in an OpenAI ``chat.completion`` object."""
    # Usage is approximate: completion "tokens" are whitespace-separated words,
    # and prompt tokens are not counted.
    completion_tokens = len(content.split())
    return {
        "id": f"chatcmpl-{prediction_id}",
        "object": "chat.completion",
        # OpenAI's "created" is a Unix timestamp in seconds.
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": content
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": completion_tokens,
            "total_tokens": completion_tokens
        }
    }


async def _poll_prediction(
    session: aiohttp.ClientSession,
    prediction_id: Any,
    model: str,
    max_attempts: int = 60,
    poll_interval: float = 1.0,
) -> Dict[str, Any]:
    """Poll a Replicate prediction until it settles and return the OpenAI response.

    Raises:
        HTTPException: the upstream status on a non-200 poll response, 500 if the
            prediction failed or was canceled, 504 after ``max_attempts`` polls.
    """
    prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}
    for _ in range(max_attempts):
        async with session.get(prediction_url, headers=headers) as response:
            if response.status != 200:
                # An error body has no "status" field; without this check the
                # loop would spin until the timeout.
                error_text = await response.text()
                logger.error(f"Replicate API error: {response.status} - {error_text}")
                raise HTTPException(status_code=response.status, detail=f"Replicate API error: {error_text}")
            result = await response.json()

        status = result.get("status")
        logger.info(f"Prediction {prediction_id} status: {status}")
        if status == "succeeded":
            output = result.get("output")
            if output is None:
                content = ""
            elif isinstance(output, list):
                # List output is a sequence of text pieces; concatenate them.
                content = "".join(str(piece) for piece in output)
            else:
                content = str(output)
            return _build_chat_completion(prediction_id, model, content)
        if status == "failed":
            error_msg = result.get('error', 'Unknown error')
            logger.error(f"Prediction failed: {error_msg}")
            raise HTTPException(status_code=500, detail=f"Prediction failed: {error_msg}")
        if status in ("canceled", "cancelled"):
            raise HTTPException(status_code=500, detail="Prediction was canceled")
        await asyncio.sleep(poll_interval)
    raise HTTPException(status_code=504, detail="Prediction timeout")


async def _relay_openai_stream(stream_url: str, model: str) -> AsyncGenerator[str, None]:
    """Forward a Replicate prediction stream to the client as OpenAI-format SSE frames.

    Owns its own ClientSession: StreamingResponse consumes this generator only
    after the request handler has returned, by which time any session the
    handler opened with ``async with`` is already closed.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async for event in stream_replicate_response(session, stream_url):
                openai_event = transform_replicate_to_openai_stream(event, model)
                if openai_event:
                    yield openai_event
    except Exception as e:
        # Response headers are already sent, so report the failure in-band.
        logger.error(f"Stream generation error: {e}")
        error_response = {
            "error": {
                "message": str(e),
                "type": "stream_error"
            }
        }
        yield f"data: {json.dumps(error_response)}\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Handle an OpenAI-compatible chat completion request by proxying to Replicate.

    The body is converted into a Replicate prediction. With ``"stream": true``
    the prediction's stream is relayed as OpenAI ``chat.completion.chunk`` SSE
    frames. Otherwise the prediction is polled until it finishes, and a single
    ``chat.completion`` object is returned.

    Raises:
        HTTPException:
            * 500 if the token is missing.
            * 400 for a malformed body.
            * Replicate's status code for upstream errors.
            * 504 on poll timeout.
            * 500 for anything unexpected.
    """
    if not REPLICATE_API_TOKEN:
        logger.error("REPLICATE_API_TOKEN not configured")
        raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not configured")
    try:
        try:
            body = await request.json()
        except ValueError as e:
            # Malformed JSON (or undecodable bytes) is a client error, not a 500.
            raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e
        logger.info("Received chat completion request")
        if logger.isEnabledFor(logging.DEBUG):
            # The body holds the whole conversation, so only dump it at DEBUG.
            logger.debug(f"Request body: {json.dumps(body, indent=2)}")

        replicate_data, model = transform_openai_to_replicate(body)
        async with aiohttp.ClientSession() as session:
            prediction = await create_replicate_prediction(session, model, replicate_data)
            prediction_id = prediction.get('id')
            logger.info(f"Created prediction: {prediction_id}")
            if not body.get("stream", False):
                return await _poll_prediction(session, prediction_id, model)

        # Streaming: the session above is closed now; the relay opens its own.
        stream_url = (prediction.get("urls") or {}).get("stream")
        if not stream_url:
            raise HTTPException(status_code=500, detail="Stream URL not available")
        return StreamingResponse(
            _relay_openai_stream(stream_url, model),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error processing request: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
if __name__ == "__main__":
    # Listen on every interface; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")