import asyncio
import json
import logging
import os
import time
import traceback
from typing import Any, AsyncGenerator, Dict, Optional, Tuple

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
|
|
| |
# Configure root logging once at import time; all module loggers inherit
# this level and format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
# FastAPI application exposing an OpenAI-compatible surface backed by
# Replicate's prediction API.
app = FastAPI(
    title="Replicate API Proxy for LobeChat",
    description="A proxy service to forward Replicate API requests in OpenAI-compatible format",
    version="1.0.0"
)
|
|
| |
# Allow browser clients (e.g. LobeChat) to call this proxy from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# very permissive CORS policy — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Replicate credentials, read once at import. A missing token is only logged
# here; requests fail later with HTTP 500 in the chat completions endpoint.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    logger.error("REPLICATE_API_TOKEN not found in environment variables")


# Replicate REST base URL and the model used when a request names no model
# (or names one that cannot be mapped).
REPLICATE_BASE_URL = "https://api.replicate.com/v1"
DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
|
|
| |
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log any unhandled error and return a JSON 500 envelope."""
    # Lazy %-style args avoid string formatting when the record is filtered out.
    logger.error("Global exception: %s", exc)
    logger.error("Traceback: %s", traceback.format_exc())
    error_body = {
        "error": {
            "message": f"Internal server error: {str(exc)}",
            "type": "internal_error",
        }
    }
    return JSONResponse(status_code=500, content=error_body)
|
|
def transform_openai_to_replicate(
    openai_request: Dict[str, Any], model_override: Optional[str] = None
) -> Tuple[Dict[str, Any], str]:
    """Convert an OpenAI-format chat request into a Replicate prediction request.

    Args:
        openai_request: OpenAI chat-completions payload; reads "messages",
            "model", "stream", "max_tokens" and "temperature".
        model_override: Replicate model id to use instead of the request's.

    Returns:
        A ``(replicate_request, model)`` tuple: the Replicate prediction
        payload and the resolved ``owner/name`` model identifier.
        (Bug fix: the original annotation claimed a single dict was returned.)

    Raises:
        HTTPException: 400 if the request cannot be transformed.
    """
    try:
        messages = openai_request.get("messages", [])

        # Split out the system prompt; when several system messages are
        # present the last one wins (preserves original behavior).
        system_prompt = "You are a helpful assistant"
        conversation = []
        for message in messages:
            role = message.get("role")
            if role == "system":
                system_prompt = message.get("content", "You are a helpful assistant")
            elif role in ["user", "assistant"]:
                conversation.append(message)

        # Flatten the remaining turns into a Human/Assistant transcript.
        prompt_parts = []
        for msg in conversation:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "user":
                prompt_parts.append(f"Human: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")

        prompt = "\n\n".join(prompt_parts)
        # The prompt must end with an open "Assistant:" turn for the model
        # to complete, unless it already does.
        if prompt_parts and not prompt.endswith("\n\nAssistant:"):
            prompt += "\n\nAssistant:"

        model = model_override or openai_request.get("model", DEFAULT_MODEL)

        # Map short OpenAI-style ids onto Replicate "owner/name" identifiers.
        model_mapping = {
            "claude-4-sonnet": "anthropic/claude-4-sonnet",
            "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
            "claude-3-sonnet": "anthropic/claude-3-sonnet",
            "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
            "claude-3-haiku": "anthropic/claude-3-haiku",
        }

        if model in model_mapping:
            model = model_mapping[model]
        elif not model.startswith("anthropic/"):
            # Unknown model id: fall back to a safe default rather than failing.
            model = "anthropic/claude-3.5-sonnet"

        replicate_request = {
            "stream": openai_request.get("stream", False),
            "input": {
                "prompt": prompt,
                "system_prompt": system_prompt,
                "max_tokens": openai_request.get("max_tokens", 4000),
                "temperature": openai_request.get("temperature", 0.7)
            }
        }

        logger.info("Transformed request for model: %s", model)
        return replicate_request, model

    except Exception as e:
        logger.error("Error transforming request: %s", e)
        raise HTTPException(status_code=400, detail=f"Request transformation error: {str(e)}")
|
|
async def create_replicate_prediction(session: aiohttp.ClientSession, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a Replicate prediction and return the parsed JSON response.

    Args:
        session: Open aiohttp session used for the POST.
        model: Replicate ``owner/name`` model identifier.
        data: Replicate prediction payload (``stream`` flag plus ``input``).

    Raises:
        HTTPException: with Replicate's own status code when the API rejects
            the request, 504 on timeout, 500 on unexpected failures.
    """
    try:
        url = f"{REPLICATE_BASE_URL}/models/{model}/predictions"
        headers = {
            "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
            "Content-Type": "application/json"
        }

        logger.info("Creating prediction for model: %s", model)
        logger.info("Request URL: %s", url)

        async with session.post(url, headers=headers, json=data, timeout=30) as response:
            response_text = await response.text()
            logger.info("Replicate response status: %s", response.status)
            logger.info("Replicate response: %s", response_text)

            # Replicate answers 201 Created on success; anything else is an error.
            if response.status != 201:
                logger.error(f"Replicate API error: {response.status} - {response_text}")
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Replicate API error: {response_text}"
                )

            return json.loads(response_text)

    except HTTPException:
        # Bug fix: the HTTPException raised above was previously caught by the
        # generic handler below and re-reported as a 500 with a misleading
        # "Prediction creation error" message. Propagate it unchanged.
        raise
    except asyncio.TimeoutError:
        logger.error("Timeout creating Replicate prediction")
        raise HTTPException(status_code=504, detail="Timeout creating prediction")
    except Exception as e:
        logger.error("Error creating prediction: %s", e)
        raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}")
|
|
async def stream_replicate_response(session: aiohttp.ClientSession, stream_url: str) -> AsyncGenerator[str, None]:
    """Stream the body of a Replicate prediction stream URL line by line.

    Yields each non-empty, stripped line of the response body as a string.
    Errors are logged and re-raised for the consumer to handle.
    """
    try:
        headers = {
            "Accept": "text/event-stream",
            "Cache-Control": "no-store"
        }

        logger.info(f"Starting stream from: {stream_url}")

        # Long total timeout (300s) because model output can take minutes.
        async with session.get(stream_url, headers=headers, timeout=300) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"Stream error: {response.status} - {error_text}")
                # NOTE(review): if the consumer has already started sending the
                # HTTP response, this HTTPException cannot change the status
                # code — it is surfaced via the caller's stream error path.
                raise HTTPException(status_code=response.status, detail=f"Stream error: {error_text}")

            # Iterate the response body; blank pieces (e.g. SSE separator
            # lines) are dropped after stripping.
            async for line in response.content:
                line = line.decode('utf-8').strip()
                if line:
                    yield line

    except Exception as e:
        logger.error(f"Stream error: {str(e)}")
        raise
|
|
def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
    """Translate one Replicate SSE data line into an OpenAI streaming chunk.

    Args:
        event_data: A single SSE line. Only lines starting with "data: " are
            handled; anything else returns "".
        model: Model name echoed back in the OpenAI chunk.

    Returns:
        An SSE-framed "chat.completion.chunk" JSON string (with a trailing
        "data: [DONE]" event after the final chunk), or "" for lines that
        should be skipped or cannot be parsed.
    """
    if not event_data.startswith("data: "):
        return ""

    try:
        # NOTE(review): this assumes the data payload is JSON carrying
        # "event"/"data"/"id" keys — confirm against Replicate's actual SSE
        # framing, which normally carries the event type on "event:" lines.
        data = json.loads(event_data[6:])
        event = data.get("event")
        prediction_id = data.get("id", "unknown")
        # Bug fix: "created" must be a Unix epoch timestamp. The original used
        # asyncio.get_event_loop().time(), a monotonic clock unrelated to
        # wall-clock time, which confuses OpenAI-compatible clients.
        created = int(time.time())

        if event == "output":
            # Incremental content chunk.
            chunk = {
                "id": f"chatcmpl-{prediction_id}",
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {
                        "content": data.get("data", "")
                    },
                    "finish_reason": None
                }]
            }
            return f"data: {json.dumps(chunk)}\n\n"

        if event == "done":
            # Final chunk plus the explicit OpenAI stream terminator.
            chunk = {
                "id": f"chatcmpl-{prediction_id}",
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": "stop"
                }]
            }
            return f"data: {json.dumps(chunk)}\n\ndata: [DONE]\n\n"

        # Unknown event types are silently skipped.
        return ""

    except json.JSONDecodeError as e:
        logger.warning(f"Failed to parse event data: {event_data}, error: {e}")
        return ""
|
|
@app.get("/")
async def root():
    """Health-check endpoint reporting basic service metadata."""
    info = {
        "message": "Replicate API Proxy for LobeChat",
        "status": "running",
        "replicate_token_configured": bool(REPLICATE_API_TOKEN),
        "version": "1.0.0",
    }
    return info
|
|
@app.get("/health")
async def health():
    """Detailed health check: token configuration plus current server time."""
    return {
        "status": "healthy",
        "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
        # Bug fix: report Unix epoch time. The original returned the event
        # loop's monotonic clock, which is meaningless to callers.
        "timestamp": time.time()
    }
|
|
@app.get("/v1/models")
async def list_models():
    """List the supported models in OpenAI-compatible format."""
    model_ids = [
        "claude-4-sonnet",
        "claude-3.5-sonnet",
        "claude-3.5-haiku",
        "claude-3-sonnet",
        "claude-3-haiku",
    ]
    # Every entry shares the same static metadata; only the id differs.
    data = [
        {
            "id": model_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": "anthropic",
        }
        for model_id in model_ids
    ]
    return {"object": "list", "data": data}
|
|
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Handle chat completion requests (OpenAI-compatible).

    Creates a Replicate prediction from the OpenAI-format body, then either
    streams the output as SSE chunks or polls until the prediction finishes
    and returns a single chat.completion response.

    Raises:
        HTTPException: 500 when the token is missing, the prediction fails or
            the stream URL is absent; 504 on polling timeout.
    """
    if not REPLICATE_API_TOKEN:
        logger.error("REPLICATE_API_TOKEN not configured")
        raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not configured")

    try:
        body = await request.json()
        logger.info("Received chat completion request")
        logger.info("Request body: %s", json.dumps(body, indent=2))

        replicate_data, model = transform_openai_to_replicate(body)

        if body.get("stream", False):
            # Create the prediction with a short-lived session. The streaming
            # session must NOT be this one: the original code returned the
            # StreamingResponse from inside `async with ClientSession()`, so
            # the session was already closed by the time the client consumed
            # the generator. The generator now owns its own session.
            async with aiohttp.ClientSession() as session:
                prediction = await create_replicate_prediction(session, model, replicate_data)
            prediction_id = prediction.get("id")
            logger.info("Created prediction: %s", prediction_id)

            stream_url = prediction.get("urls", {}).get("stream")
            if not stream_url:
                raise HTTPException(status_code=500, detail="Stream URL not available")

            return StreamingResponse(
                _stream_openai_events(stream_url, model),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                },
            )

        # Non-streaming: one session covers creation and polling.
        async with aiohttp.ClientSession() as session:
            prediction = await create_replicate_prediction(session, model, replicate_data)
            prediction_id = prediction.get("id")
            logger.info("Created prediction: %s", prediction_id)
            return await _poll_prediction(session, prediction_id, model)

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Unexpected error processing request: %s", e)
        logger.error("Traceback: %s", traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


async def _stream_openai_events(stream_url: str, model: str) -> AsyncGenerator[str, None]:
    """Own a ClientSession for the stream's lifetime and translate its events.

    Yields OpenAI-format SSE chunks; on failure emits a single SSE error event
    instead of raising, since HTTP headers may already have been sent.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async for event in stream_replicate_response(session, stream_url):
                openai_event = transform_replicate_to_openai_stream(event, model)
                if openai_event:
                    yield openai_event
    except Exception as e:
        logger.error("Stream generation error: %s", e)
        error_response = {
            "error": {
                "message": str(e),
                "type": "stream_error"
            }
        }
        yield f"data: {json.dumps(error_response)}\n\n"


async def _poll_prediction(session: aiohttp.ClientSession, prediction_id: str, model: str) -> Dict[str, Any]:
    """Poll a prediction once per second (max 60 tries) and build the response.

    Raises:
        HTTPException: 500 on failed/canceled predictions, 504 on timeout.
    """
    prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}

    for _ in range(60):
        async with session.get(prediction_url, headers=headers) as response:
            result = await response.json()
        status = result.get("status")
        logger.info("Prediction %s status: %s", prediction_id, status)

        if status == "succeeded":
            output = result.get("output", [])
            # Replicate may return output as a list of string chunks.
            content = "".join(output) if isinstance(output, list) else str(output)
            return {
                "id": f"chatcmpl-{prediction_id}",
                "object": "chat.completion",
                # Bug fix: epoch seconds; the original used the event loop's
                # monotonic clock.
                "created": int(time.time()),
                "model": model,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content
                    },
                    "finish_reason": "stop"
                }],
                "usage": {
                    # Rough whitespace-split estimate; Replicate does not
                    # report token usage here.
                    "prompt_tokens": 0,
                    "completion_tokens": len(content.split()),
                    "total_tokens": len(content.split())
                }
            }

        if status == "failed":
            error_msg = result.get("error", "Unknown error")
            logger.error("Prediction failed: %s", error_msg)
            raise HTTPException(status_code=500, detail=f"Prediction failed: {error_msg}")

        if status in ["canceled", "cancelled"]:
            raise HTTPException(status_code=500, detail="Prediction was canceled")

        await asyncio.sleep(1)

    raise HTTPException(status_code=504, detail="Prediction timeout")
|
|
if __name__ == "__main__":
    # PORT is read from the environment, defaulting to 7860.
    listen_port = int(os.getenv("PORT", 7860))
    logger.info("Starting server on port %s", listen_port)
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")