|
|
import asyncio
import base64
import json
import logging
import mimetypes
import os
import time
import traceback
from typing import Dict, Any, AsyncGenerator, List, Optional, Union

import aiohttp
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
|
|
|
|
|
|
|
|
# Root logging configuration: INFO level, each record prefixed with the
# timestamp, logger name and severity.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# FastAPI application object; all routes below are registered on it.
app = FastAPI(
    version="1.0.0",
    title="Replicate API Proxy for LobeChat",
    description="A proxy service to forward Replicate API requests in OpenAI-compatible format",
)

# CORS policy: every origin, method and header is allowed, with credentials.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
|
|
|
|
|
# Replicate credential, read once at import time. A missing or empty value is
# only logged here; the module still loads.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    logger.error("REPLICATE_API_TOKEN not found in environment variables")
|
|
|
|
|
|
|
|
# imgbb image-hosting settings.
# SECURITY: this key was previously available only as a hard-coded literal.
# The IMGBB_API_KEY environment variable now takes precedence; the literal is
# kept purely as a backward-compatible fallback and should be rotated and
# removed from source control.
IMGBB_API_KEY = os.getenv("IMGBB_API_KEY") or "78f0c4360135e80c46b24b44e1e20a20"
IMGBB_API_URL = "https://api.imgbb.com/1/upload"

# Replicate REST endpoint root and the model used when the client names none
# (or names one this proxy does not recognise).
REPLICATE_BASE_URL = "https://api.replicate.com/v1"
DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
|
|
|
|
|
|
|
|
# File extensions (lower-case, with leading dot) whose contents are treated
# as text and inlined into the prompt.
SUPPORTED_TEXT_EXTENSIONS = {
    # documents, markup, data and configuration
    '.txt', '.md', '.html', '.htm', '.css', '.xml', '.json', '.yaml', '.yml',
    '.toml', '.ini', '.conf', '.properties', '.env', '.lock', '.log', '.csv',
    '.tsv',
    # scripting languages and shells
    '.py', '.js', '.ts', '.sh', '.bash', '.zsh', '.fish', '.ps1', '.rb',
    '.php', '.r', '.sql',
    # compiled / JVM / systems languages
    '.java', '.c', '.cpp', '.cc', '.cxx', '.h', '.hpp', '.cs', '.go', '.rs',
    '.swift', '.kt', '.scala',
    # tooling files
    '.dockerfile', '.gitignore', '.gitattributes',
}

# File extensions accepted as images.
SUPPORTED_IMAGE_EXTENSIONS = {
    '.jpg',
    '.jpeg',
    '.png',
    '.gif',
    '.bmp',
    '.webp',
    '.svg',
}
|
|
|
|
|
|
|
|
# Settings shared by most models; individual entries override what differs.
#   min_max_tokens       - smallest max_tokens value the model is sent
#   default_max_tokens   - value used when the client sends none and the
#                          model has a limit
#   has_max_tokens_limit - whether a default max_tokens is applied at all
#   supports_vision / supports_files - capability flags (used for warnings)
#   image_format         - "url" (hosted image) or "data_url" (inline base64)
_BASE_MODEL_CONFIG = {
    "min_max_tokens": 1,
    "default_max_tokens": 4096,
    "has_max_tokens_limit": False,
    "supports_vision": True,
    "supports_files": True,
    "image_format": "data_url",
}

# Per-model configuration keyed by the full Replicate model name.
MODEL_CONFIGS = {
    "anthropic/claude-4-sonnet": {
        **_BASE_MODEL_CONFIG,
        "min_max_tokens": 1024,
        "default_max_tokens": 8192,
        "has_max_tokens_limit": True,
        "image_format": "url",
    },
    "anthropic/claude-3.5-sonnet": {**_BASE_MODEL_CONFIG, "default_max_tokens": 8192},
    "anthropic/claude-3-sonnet": dict(_BASE_MODEL_CONFIG),
    "anthropic/claude-3.5-haiku": dict(_BASE_MODEL_CONFIG),
    "anthropic/claude-3-haiku": dict(_BASE_MODEL_CONFIG),
    "google/gemini-2.5-pro": {**_BASE_MODEL_CONFIG, "default_max_tokens": 8192},
}
|
|
|
|
|
|
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Turn any unhandled exception into a JSON 500 response.

    The exception message and the current traceback are logged, and the
    message is echoed to the client inside an ``error`` object.
    """
    reason = str(exc)
    logger.error(f"Global exception: {reason}")
    logger.error(f"Traceback: {traceback.format_exc()}")

    error_body = {
        "message": f"Internal server error: {reason}",
        "type": "internal_error",
    }
    return JSONResponse(content={"error": error_body}, status_code=500)
|
|
|
|
|
def get_file_extension(filename: str) -> str:
    """Return the lower-cased extension of *filename*, including the dot.

    Dotfiles with no further suffix (``.gitignore``, ``.env``) are reported
    as their own extension. ``os.path.splitext`` alone yields ``''`` for
    them, so the dotfile entries in the supported-extension set could never
    match. Names without any extension return ``''``.
    """
    base_name = os.path.basename(filename.lower())
    stem, extension = os.path.splitext(base_name)
    if not extension and stem.startswith(".") and stem.strip("."):
        # splitext treats a leading dot as part of the stem.
        return stem
    return extension
|
|
|
|
|
def decode_base64_file(data_url: str) -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Decode a ``data:`` URL carrying a base64 payload.

    Args:
        data_url: String of the form ``data:<mime>[;name=<file>];base64,<payload>``.

    Returns:
        ``(mime_type, filename, content)``:

        * ``(None, None, None)`` when the input is not a parsable data URL;
        * ``(mime, filename, None)`` when the payload is not valid base64;
        * ``(mime, filename, text)`` for text payloads (MIME ``text/*`` or a
          filename ending in a supported text extension), decoded as UTF-8
          with a latin-1 fallback;
        * ``(mime, filename, base64_payload)`` for anything else, leaving the
          payload still base64-encoded.

        ``filename`` defaults to ``"unknown_file"`` when the header has no
        ``name=`` part.
    """
    try:
        if not data_url.startswith("data:"):
            return None, None, None

        header, base64_content = data_url.split(",", 1)
        header_parts = header.split(";")

        # Slice off the prefixes instead of str.replace, which would also
        # remove later occurrences inside the value itself.
        mime_type = header_parts[0][len("data:"):]
        filename = "unknown_file"
        for part in header_parts[1:]:
            if part.startswith("name="):
                filename = part[len("name="):]
                break
    except Exception as e:
        logger.error(f"Failed to parse data URL: {e}")
        return None, None, None

    try:
        decoded_bytes = base64.b64decode(base64_content)
    except Exception as e:
        logger.error(f"Failed to decode base64 content: {e}")
        return mime_type, filename, None

    is_text = mime_type.startswith("text/") or filename.lower().endswith(
        tuple(SUPPORTED_TEXT_EXTENSIONS)
    )
    if not is_text:
        # Binary payloads are passed through still base64-encoded.
        return mime_type, filename, base64_content

    try:
        return mime_type, filename, decoded_bytes.decode('utf-8')
    except UnicodeDecodeError:
        # latin-1 maps every byte value, so this fallback cannot fail.
        return mime_type, filename, decoded_bytes.decode('latin-1')
|
|
|
|
|
def _pick_image_content_type(url: str, header_value: str) -> str:
    """Choose the MIME type for a downloaded image.

    The response ``Content-Type`` wins when it is an ``image/*`` type (any
    parameters such as ``; charset=...`` are dropped). Otherwise the type is
    guessed from the URL path, ignoring query string and fragment, and
    defaults to ``image/jpeg``.
    """
    content_type = header_value.split(";", 1)[0].strip().lower()
    if content_type.startswith('image/'):
        return content_type

    path = url.split("#", 1)[0].split("?", 1)[0].lower()
    suffix_types = (
        (('.jpg', '.jpeg'), 'image/jpeg'),
        (('.png',), 'image/png'),
        (('.gif',), 'image/gif'),
        (('.webp',), 'image/webp'),
    )
    for suffixes, mime in suffix_types:
        if path.endswith(suffixes):
            return mime
    return 'image/jpeg'


async def download_image_from_url(url: str) -> Optional[str]:
    """Download an image and return it as a base64 ``data:`` URL.

    Args:
        url: HTTP(S) address of the image.

    Returns:
        ``data:<mime>;base64,<payload>`` on success, or ``None`` when the
        server does not answer 200, the request times out (30 s total), or
        any other error occurs. Errors are logged, never raised.
    """
    # NOTE(review): the URL originates from the client request, so this is a
    # server-side fetch of an arbitrary address; consider restricting hosts.
    try:
        logger.info(f"Downloading image from URL: {url}")

        # aiohttp expects a ClientTimeout object; a bare number is deprecated.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    logger.error(f"Failed to download image: HTTP {response.status}")
                    return None

                image_bytes = await response.read()
                content_type = _pick_image_content_type(
                    url, response.headers.get('content-type', '')
                )

        base64_data = base64.b64encode(image_bytes).decode('utf-8')
        data_url = f"data:{content_type};base64,{base64_data}"

        logger.info(f"Successfully downloaded image, size: {len(image_bytes)} bytes, base64 size: {len(base64_data)} chars")
        return data_url

    except asyncio.TimeoutError:
        logger.error(f"Timeout downloading image from {url}")
        return None
    except Exception as e:
        logger.error(f"Error downloading image from {url}: {e}")
        return None
|
|
|
|
|
async def upload_image_to_imgbb(base64_data: str) -> Optional[str]:
    """Upload a base64 image to imgbb and return its public URL.

    Args:
        base64_data: Raw base64 payload, or a full ``data:`` URL (the header
            before the first comma is stripped).

    Returns:
        The hosted image URL, or ``None`` when the upload is rejected, times
        out (30 s total) or fails for any other reason. Errors are logged,
        never raised. The upload is requested with a 300-second expiration.
    """
    try:
        if base64_data.startswith("data:"):
            base64_content = base64_data.split(",", 1)[1]
        else:
            base64_content = base64_data

        form = {
            'key': IMGBB_API_KEY,
            'image': base64_content,
            'expiration': 300,
        }

        logger.info(f"Uploading image to imgbb, size: {len(base64_content)} chars")

        # aiohttp expects a ClientTimeout object; a bare number is deprecated.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(IMGBB_API_URL, data=form) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"imgbb upload error: {response.status} - {error_text}")
                    return None

                result = await response.json()

        if not result.get('success'):
            logger.error(f"imgbb upload failed: {result}")
            return None

        image_url = result['data']['url']
        logger.info(f"Image uploaded successfully: {image_url}")
        return image_url

    except asyncio.TimeoutError:
        logger.error("Timeout uploading image to imgbb")
        return None
    except Exception as e:
        logger.error(f"Failed to upload image to imgbb: {e}")
        return None
|
|
|
|
|
async def format_image_for_model(base64_data: str, model_config: Dict[str, Any]) -> str:
    """Shape image data the way the target model's config asks for it.

    ``model_config["image_format"]`` (default ``"data_url"``) selects:

    * ``"data_url"`` - return an inline ``data:`` URL;
    * ``"url"`` - upload to imgbb and return the hosted URL, falling back to
      a ``data:`` URL when the upload fails;
    * anything else - return *base64_data* untouched.
    """
    wanted_format = model_config.get("image_format", "data_url")

    if wanted_format == "data_url":
        return format_image_as_data_url(base64_data)

    if wanted_format != "url":
        return base64_data

    hosted_url = await upload_image_to_imgbb(base64_data)
    if hosted_url:
        return hosted_url

    logger.error("Failed to upload image, falling back to data URL")
    return format_image_as_data_url(base64_data)
|
|
|
|
|
def format_image_as_data_url(base64_data: str) -> str:
    """Wrap a raw base64 image payload in a ``data:image/...`` URL.

    Input that already starts with ``data:`` is returned unchanged.
    Otherwise the first 100 base64 characters are decoded and matched
    against the JPEG, PNG, GIF and WebP signatures. Unrecognised or
    undecodable data is labelled JPEG.
    """
    if base64_data.startswith("data:"):
        return base64_data

    subtype = "jpeg"  # default for unknown or undecodable data
    try:
        head = base64.b64decode(base64_data[:100])

        if head[:3] == b'\xff\xd8\xff':
            subtype = "jpeg"
        elif head[:8] == b'\x89PNG\r\n\x1a\n':
            subtype = "png"
        elif head[:6] in (b'GIF87a', b'GIF89a'):
            subtype = "gif"
        elif head[:4] == b'RIFF' and head[:20].find(b'WEBP') != -1:
            subtype = "webp"
    except Exception as e:
        logger.warning(f"Failed to detect image format: {e}, using JPEG as default")

    return f"data:image/{subtype};base64,{base64_data}"
|
|
|
|
|
def extract_images_from_context(content: str) -> List[str]: |
|
|
""" |
|
|
从系统上下文中提取图片URL |
|
|
""" |
|
|
images = [] |
|
|
try: |
|
|
|
|
|
import re |
|
|
pattern = r'<image[^>]+url="([^"]+)"[^>]*></image>' |
|
|
matches = re.findall(pattern, content) |
|
|
for url in matches: |
|
|
if url.startswith('http'): |
|
|
images.append(url) |
|
|
logger.info(f"Found image URL in context: {url}") |
|
|
except Exception as e: |
|
|
logger.error(f"Error extracting images from context: {e}") |
|
|
|
|
|
return images |
|
|
|
|
|
def _collect_attachment(file_url: str, images: List[str], files: List[Dict[str, str]]) -> None:
    """Decode a ``data:`` URL attachment and append it to *images* or *files*.

    Image attachments are stored as the original data URL; text attachments
    as ``{"filename", "content", "mime_type"}``. Anything that is not a data
    URL, cannot be decoded, or has an unsupported type is skipped.
    """
    if not file_url.startswith("data:"):
        return

    mime_type, filename, file_content = decode_base64_file(file_url)
    if file_content is None:
        return

    file_ext = get_file_extension(filename)
    if file_ext in SUPPORTED_IMAGE_EXTENSIONS and mime_type.startswith("image/"):
        images.append(file_url)
        logger.info(f"Found image file: {filename}")
    elif file_ext in SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith("text/"):
        files.append({
            "filename": filename,
            "content": file_content,
            "mime_type": mime_type,
        })
        logger.info(f"Found text file: {filename}, size: {len(file_content)} chars")
    else:
        logger.warning(f"Unsupported file type: {filename} ({mime_type})")


def extract_content_from_message(message: Dict[str, Any]) -> tuple[str, List[str], List[Dict[str, str]]]:
    """Split an OpenAI-style chat message into text, images and files.

    Args:
        message: Message dict whose ``content`` is a string, a list of
            content parts (``text`` / ``image_url`` / ``file`` dicts or plain
            strings), or some other value.

    Returns:
        ``(text_content, image_list, file_list)``. Images are data URLs or
        HTTP URLs; files are dicts with ``filename``, ``content`` and
        ``mime_type``. Text parts of a list are joined with single spaces.
        Image URLs embedded in text as ``<image url=...>`` tags are also
        collected. A ``None`` content is treated as an empty string; any
        other non-string, non-list content is converted with ``str()``.
    """
    content = message.get("content", "")
    if content is None:
        # Avoid the literal text "None" ending up in the prompt.
        content = ""

    images: List[str] = []
    files: List[Dict[str, str]] = []

    if isinstance(content, str):
        images.extend(extract_images_from_context(content))
        return content, images, files

    if not isinstance(content, list):
        return str(content), images, files

    text_parts = []
    for item in content:
        if isinstance(item, str):
            text_parts.append(item)
            images.extend(extract_images_from_context(item))
            continue
        if not isinstance(item, dict):
            continue

        item_type = item.get("type", "")

        if item_type == "text":
            text_content = item.get("text", "")
            text_parts.append(text_content)
            images.extend(extract_images_from_context(text_content))

        elif item_type == "image_url":
            url = item.get("image_url", {}).get("url", "")

            if url.startswith("data:image/"):
                if ";base64," in url:
                    base64_data = url.split(";base64,", 1)[1]
                    images.append(url)
                    logger.info(f"Found base64 image, size: {len(base64_data)} chars")
                else:
                    logger.warning(f"Image URL format not supported: {url[:100]}...")
            elif url.startswith("http"):
                images.append(url)
                logger.info(f"Found external image URL: {url}")
            elif url.startswith("data:"):
                # Non-image data URL delivered through an image_url part:
                # treat it as a file attachment. The original branch meant
                # for this case could never be reached.
                _collect_attachment(url, images, files)
            else:
                logger.warning(f"Unsupported image URL format: {url}")

        elif item_type == "file":
            _collect_attachment(item.get("file_url", {}).get("url", ""), images, files)

    return " ".join(text_parts), images, files
|
|
|
|
|
def format_files_for_prompt(files: List[Dict[str, str]], max_length: int = 10000) -> str:
    """Render uploaded text files as one prompt section.

    Args:
        files: Dicts with ``filename`` and ``content`` keys and an optional
            ``mime_type`` (default ``"text/plain"``).
        max_length: Maximum number of characters kept per file. Longer
            content is cut and a truncation notice naming this limit is
            appended.

    Returns:
        One delimited block per file, joined with newlines; ``""`` when
        *files* is empty.
    """
    if not files:
        return ""

    sections = []
    for file_data in files:
        filename = file_data["filename"]
        content = file_data["content"]
        mime_type = file_data.get("mime_type", "text/plain")

        if len(content) > max_length:
            content = content[:max_length] + f"\n\n[文件内容已截断,显示前 {max_length} 字符]"

        sections.append(
            f"\n\n--- 文件: {filename} ({mime_type}) ---\n{content}\n--- 文件结束 ---\n"
        )

    return "\n".join(sections)
|
|
|
|
|
# Short model names accepted from clients, mapped to Replicate model names.
_MODEL_ALIASES = {
    "claude-4-sonnet": "anthropic/claude-4-sonnet",
    "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
    "claude-3-sonnet": "anthropic/claude-3-sonnet",
    "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
    "claude-3-haiku": "anthropic/claude-3-haiku",
    "gemini-2.5-pro": "google/gemini-2.5-pro",
}


def _resolve_replicate_model(requested: Optional[str]) -> str:
    """Map a client-supplied model name to a Replicate model name.

    Known aliases are translated; names already prefixed with ``anthropic/``
    or ``google/`` pass through; anything else (including a missing name)
    becomes ``DEFAULT_MODEL``.
    """
    model = requested or DEFAULT_MODEL
    if model in _MODEL_ALIASES:
        return _MODEL_ALIASES[model]
    if not model.startswith(("anthropic/", "google/")):
        return DEFAULT_MODEL
    return model


def _resolve_max_tokens(client_max_tokens: Optional[int], model_config: Dict[str, Any], model: str) -> Optional[int]:
    """Decide the ``max_tokens`` value to send, or ``None`` to omit it.

    A client value is raised to the model minimum if needed. Without a
    client value the model default is used only when the model is flagged
    as having a limit.
    """
    if client_max_tokens is not None:
        minimum = model_config["min_max_tokens"]
        if client_max_tokens < minimum:
            logger.info(f"Adjusting max_tokens from {client_max_tokens} to {minimum} (model minimum)")
            return minimum
        return client_max_tokens

    if model_config["has_max_tokens_limit"]:
        default_value = model_config["default_max_tokens"]
        logger.info(f"Using default max_tokens {default_value} for model {model}")
        return default_value

    logger.info(f"No max_tokens limit for model {model}, allowing unlimited")
    return None


async def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> tuple[Dict[str, Any], str]:
    """Convert an OpenAI chat-completions request into a Replicate request.

    Args:
        openai_request: Parsed OpenAI-style request body.
        model_override: Optional model name taking precedence over the
            request's ``model`` field.

    Returns:
        ``(replicate_request, model)`` where ``replicate_request`` is
        ``{"stream": bool, "input": {...}}`` and ``model`` is the resolved
        Replicate model name. ``input`` always carries ``prompt`` (a
        ``Human:``/``Assistant:`` transcript ending in ``Assistant:``) and
        optionally ``image``, ``system_prompt``, ``max_tokens`` and the
        sampling parameters copied from the request.

    Raises:
        HTTPException: 500 when the selected image cannot be formatted (this
            used to be rewrapped as a 400 by the blanket handler); 400 for
            any other failure while transforming the request.
    """
    try:
        messages = openai_request.get("messages", [])

        system_prompt = None
        user_messages = []
        has_images = False
        has_files = False
        all_files = []
        primary_image = None

        for message in messages:
            role = message.get("role")
            if role == "system":
                system_content = message.get("content", "")
                if isinstance(system_content, str):
                    system_prompt = system_content
                else:
                    # Structured (list) system content: keep only its text.
                    system_prompt = extract_content_from_message(message)[0]
            elif role in ["user", "assistant"]:
                text_content, image_list, file_list = extract_content_from_message(message)

                user_messages.append({
                    "role": role,
                    "content": text_content,
                    "images": image_list,
                    "files": file_list,
                })

                if image_list:
                    has_images = True
                    # Only one image is forwarded: the first image of the
                    # most recent user message that has any.
                    if role == "user":
                        primary_image = image_list[0]

                if file_list:
                    has_files = True
                    all_files.extend(file_list)

        model = _resolve_replicate_model(model_override or openai_request.get("model"))
        model_config = MODEL_CONFIGS.get(model, MODEL_CONFIGS[DEFAULT_MODEL])

        if has_images and not model_config.get("supports_vision", False):
            logger.warning(f"Model {model} may not support vision")
        if has_files and not model_config.get("supports_files", False):
            logger.warning(f"Model {model} may not support file processing")

        formatted_image = None
        if has_images and primary_image:
            logger.info(f"Processing image for model {model} with format {model_config.get('image_format')}")

            if primary_image.startswith("http"):
                logger.info(f"Downloading external image: {primary_image}")
                downloaded_image = await download_image_from_url(primary_image)
                if downloaded_image:
                    primary_image = downloaded_image
                    logger.info("External image downloaded and converted to base64")
                else:
                    # The request continues without the image.
                    logger.error("Failed to download external image")
                    primary_image = None

            if primary_image:
                formatted_image = await format_image_for_model(primary_image, model_config)
                if not formatted_image:
                    logger.error("Failed to format image for model")
                    raise HTTPException(status_code=500, detail="Failed to process image")

        replicate_input = {}
        prompt_parts = []

        if has_files:
            files_section = format_files_for_prompt(all_files)
            if files_section:
                prompt_parts.append("以下是用户上传的文件内容:")
                prompt_parts.append(files_section)
                prompt_parts.append("请根据上述文件内容回答用户的问题。")

        for msg in user_messages:
            if msg["role"] == "user":
                prompt_parts.append(f"Human: {msg['content']}")
            elif msg["role"] == "assistant":
                prompt_parts.append(f"Assistant: {msg['content']}")

        prompt = "\n\n".join(prompt_parts)
        if not prompt.endswith("\n\nAssistant:"):
            prompt += "\n\nAssistant:"
        replicate_input["prompt"] = prompt

        if formatted_image:
            replicate_input["image"] = formatted_image
            if formatted_image.startswith("http"):
                logger.info(f"Added image URL to request for model {model}: {formatted_image}")
            else:
                logger.info(f"Added image data to request for model {model}: {formatted_image[:100]}...")

        if system_prompt:
            replicate_input["system_prompt"] = system_prompt

        max_tokens = _resolve_max_tokens(openai_request.get("max_tokens"), model_config, model)
        if max_tokens is not None:
            replicate_input["max_tokens"] = max_tokens

        # Sampling parameters are forwarded verbatim when the client set them.
        for option in ("temperature", "top_p", "frequency_penalty", "presence_penalty"):
            if option in openai_request:
                replicate_input[option] = openai_request[option]

        replicate_request = {
            "stream": openai_request.get("stream", False),
            "input": replicate_input,
        }

        logger.info(f"Transformed request for model: {model}")
        logger.info(f"Message count: {len(messages)} (system: {1 if system_prompt else 0}, user/assistant: {len(user_messages)})")
        logger.info(f"Has images: {has_images}, Has files: {has_files}")
        if has_files:
            logger.info(f"Files: {[f['filename'] for f in all_files]}")
        logger.info(f"Parameters: max_tokens={max_tokens}, temperature={replicate_input.get('temperature', 'not set')}")

        return replicate_request, model

    except HTTPException:
        # Keep the status code chosen at the raise site.
        raise
    except Exception as e:
        logger.error(f"Error transforming request: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Request transformation error: {str(e)}")
|
|
|
|
|
def create_log_safe_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a request payload that is safe to write to the log.

    The copy is made through a JSON round-trip, so *data* is never modified.
    Inside ``input``, an ``image`` value is replaced by a short placeholder
    (the URL itself for HTTP images, only the length for inline data) and a
    ``prompt`` longer than 1000 characters is truncated.
    """
    snapshot = json.loads(json.dumps(data))
    if "input" not in snapshot:
        return snapshot

    payload = snapshot["input"]

    if "image" in payload:
        image = payload["image"]
        payload["image"] = (
            f"[IMAGE_URL: {image}]"
            if image.startswith("http")
            else f"[IMAGE_DATA_{len(image)}]"
        )

    if "prompt" in payload and len(payload["prompt"]) > 1000:
        payload["prompt"] = payload["prompt"][:1000] + "...[TRUNCATED]"

    return snapshot
|
|
|
|
|
async def create_replicate_prediction(session: aiohttp.ClientSession, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a prediction for *model* on Replicate.

    Args:
        session: Open aiohttp session used for the POST.
        model: Replicate model name (``owner/name``).
        data: Request body, sent as JSON.

    Returns:
        The parsed JSON body of Replicate's 201 response.

    Raises:
        HTTPException: with Replicate's own status code when it answers
            anything other than 201 (previously rewrapped as 500 by the
            blanket handler); 504 on timeout (30 s total); 500 for any
            other failure.
    """
    try:
        url = f"{REPLICATE_BASE_URL}/models/{model}/predictions"
        headers = {
            "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
            "Content-Type": "application/json",
        }

        logger.info(f"Creating prediction for model: {model}")

        # Log a redacted copy so image payloads and long prompts stay out of
        # the log.
        log_data = create_log_safe_data(data)
        logger.info(f"Request data: {json.dumps(log_data, indent=2)}")

        # aiohttp expects a ClientTimeout object; a bare number is deprecated.
        timeout = aiohttp.ClientTimeout(total=30)
        async with session.post(url, headers=headers, json=data, timeout=timeout) as response:
            response_text = await response.text()
            logger.info(f"Replicate response status: {response.status}")

            if response.status != 201:
                logger.error(f"Replicate API error: {response.status} - {response_text}")
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Replicate API error: {response_text}",
                )

            return json.loads(response_text)

    except HTTPException:
        # Preserve the upstream status code.
        raise
    except asyncio.TimeoutError:
        logger.error("Timeout creating Replicate prediction")
        raise HTTPException(status_code=504, detail="Timeout creating prediction")
    except Exception as e:
        logger.error(f"Error creating prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}")
|
|
|
|
|
class SSEParser:
    """Incremental, line-oriented Server-Sent Events parser.

    Feed one line at a time (without its line terminator) to
    :meth:`parse_line`; a blank line completes the pending event.
    """

    def __init__(self):
        # State of the event currently being assembled.
        self.event_type = None   # value of the last "event" field
        self.event_id = None     # value of the last "id" field
        self.data_buffer = []    # one entry per "data" line

    def parse_line(self, line: str):
        """Consume one line of an SSE stream.

        Fields are accepted both as ``name: value`` and ``name:value``
        (a single space after the colon is optional, so a bare ``data:``
        contributes an empty data line). Lines starting with ``:`` are
        comments. Unknown fields are ignored.

        Returns:
            ``{"event", "id", "data"}`` when *line* is blank and an event
            type or data has been collected (multiple data lines are joined
            with ``"\\n"``); otherwise ``None``. The parser state is reset
            after an event is returned.
        """
        if line == '':
            if not (self.data_buffer or self.event_type):
                return None
            event = {
                'event': self.event_type,
                'id': self.event_id,
                'data': '\n'.join(self.data_buffer),
            }
            self.event_type = None
            self.event_id = None
            self.data_buffer = []
            return event

        if line.startswith(':'):
            # Comment / keep-alive line.
            return None

        field, _, value = line.partition(':')
        if value.startswith(' '):
            # Only the first space after the colon is a separator.
            value = value[1:]

        if field == 'event':
            self.event_type = value.strip()
        elif field == 'id':
            self.event_id = value.strip()
        elif field == 'data':
            self.data_buffer.append(value)
        return None
|
|
|
|
|
def create_openai_chunk(content: str, model: str, prediction_id: str, finish_reason=None):
    """Build one OpenAI ``chat.completion.chunk`` as an SSE ``data:`` frame.

    Args:
        content: Text delta; placed in ``delta.content`` only when non-empty
            and no *finish_reason* is given.
        model: Model name echoed in the chunk.
        prediction_id: Replicate prediction id, used to form the chunk id.
        finish_reason: ``None`` for content chunks, e.g. ``"stop"`` for the
            final chunk.

    Returns:
        ``"data: <json>\\n\\n"``.
    """
    chunk = {
        "id": f"chatcmpl-{prediction_id}",
        "object": "chat.completion.chunk",
        # Unix timestamp in seconds. The event loop clock used before is
        # monotonic and not an epoch time.
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": finish_reason,
        }],
    }

    if content and not finish_reason:
        chunk["choices"][0]["delta"]["content"] = content

    return f"data: {json.dumps(chunk)}\n\n"
|
|
|
|
|
@app.get("/")
async def root():
    """Health-check endpoint describing the proxy and its capabilities."""
    text_types = list(SUPPORTED_TEXT_EXTENSIONS)
    image_types = list(SUPPORTED_IMAGE_EXTENSIONS)
    model_names = list(MODEL_CONFIGS.keys())

    summary = {
        "message": "Replicate API Proxy for LobeChat with Vision and File Support",
        "status": "running",
        "replicate_token_configured": bool(REPLICATE_API_TOKEN),
        "imgbb_token_configured": bool(IMGBB_API_KEY),
        "version": "1.3.0",
        "supported_models": model_names,
        "vision_support": True,
        "file_support": True,
        "external_image_support": True,
        "supported_text_files": text_types,
        "supported_image_files": image_types,
        "claude4_vision_support": "Full support via imgbb image hosting",
    }
    return summary
|
|
|
|
|
@app.get("/health")
async def health():
    """Detailed health check.

    Reports whether the Replicate and imgbb credentials are set, the current
    Unix time, the full model configuration table and the supported file
    extensions.
    """
    return {
        "status": "healthy",
        "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
        "imgbb_token": "configured" if IMGBB_API_KEY else "missing",
        # Wall-clock Unix time. The event loop clock used before is
        # monotonic and meaningless to clients.
        "timestamp": time.time(),
        "model_configs": MODEL_CONFIGS,
        "supported_file_types": {
            "text": list(SUPPORTED_TEXT_EXTENSIONS),
            "image": list(SUPPORTED_IMAGE_EXTENSIONS),
        },
    }
|
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """List the available models in OpenAI ``/v1/models`` format.

    Every short model name the proxy can route is advertised, now including
    ``gemini-2.5-pro``, which is configured but was missing from this list.
    """
    advertised = (
        ("claude-4-sonnet", "anthropic"),
        ("claude-3.5-sonnet", "anthropic"),
        ("claude-3.5-haiku", "anthropic"),
        ("claude-3-sonnet", "anthropic"),
        ("claude-3-haiku", "anthropic"),
        ("gemini-2.5-pro", "google"),
    )
    models = [
        {
            "id": model_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": owner,
        }
        for model_id, owner in advertised
    ]
    return {"object": "list", "data": models}
|
|
|
|
|
def _sse_error_frame(message: str) -> str:
    """Format a ``stream_error`` payload as an SSE ``data:`` frame."""
    error_response = {
        "error": {
            "message": message,
            "type": "stream_error",
        }
    }
    return f"data: {json.dumps(error_response)}\n\n"


async def _generate_openai_stream(model: str, replicate_data: Dict[str, Any]) -> AsyncGenerator[str, None]:
    """Create a prediction and relay its SSE stream as OpenAI chunks.

    Yields ``chat.completion.chunk`` frames for every ``output`` event, then
    a ``stop`` chunk and ``data: [DONE]``. Failures are reported in-band as
    ``stream_error`` frames because response headers have already been sent.
    """
    # Bound up front so the timeout handler below can always reference it.
    prediction_id = None

    async with aiohttp.ClientSession() as session:
        try:
            prediction = await create_replicate_prediction(session, model, replicate_data)
            prediction_id = prediction.get('id')
            logger.info(f"Created prediction: {prediction_id}")

            stream_url = prediction.get("urls", {}).get("stream")
            if not stream_url:
                yield _sse_error_frame("Stream URL not available")
                return

            logger.info(f"Starting stream from: {stream_url}")

            headers = {
                "Accept": "text/event-stream",
                "Cache-Control": "no-store",
            }
            sse_parser = SSEParser()

            # aiohttp expects a ClientTimeout object; a bare number is
            # deprecated. The 120 s cap covers the whole stream.
            stream_timeout = aiohttp.ClientTimeout(total=120)
            async with session.get(stream_url, headers=headers, timeout=stream_timeout) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Stream error: {response.status} - {error_text}")
                    yield _sse_error_frame(f"Stream error: {error_text}")
                    return

                async for raw_line in response.content:
                    line = raw_line.decode('utf-8').rstrip('\r\n')

                    event = sse_parser.parse_line(line)
                    if not event:
                        continue

                    event_type = event.get('event')
                    data = event.get('data', '')

                    if event_type == 'output':
                        # Forward whitespace-only tokens too: dropping them
                        # loses spaces and newlines from the answer.
                        if data:
                            yield create_openai_chunk(data, model, prediction_id)
                    elif event_type == 'error':
                        # Timeout notices are ignored only here. Filtering
                        # every raw line for "408"/"timeout" also discarded
                        # model output containing those strings.
                        if '408' in data or 'timeout' in data.lower():
                            logger.info(f"Ignoring timeout message: {data}")
                        else:
                            logger.error(f"Stream error event: {data}")
                    elif event_type == 'done':
                        logger.info("Stream completed with done event")
                        yield create_openai_chunk("", model, prediction_id, "stop")
                        yield "data: [DONE]\n\n"
                        return

            logger.info("Stream ended without done event, sending manual completion")
            yield create_openai_chunk("", model, prediction_id, "stop")
            yield "data: [DONE]\n\n"

        except asyncio.TimeoutError:
            logger.error("Stream timeout")
            yield create_openai_chunk("", model, prediction_id or "unknown", "stop")
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"Stream generation error: {e}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            yield _sse_error_frame(str(e))


async def _poll_prediction(session: aiohttp.ClientSession, prediction_id: str, model: str, max_attempts: int = 60) -> Dict[str, Any]:
    """Poll a prediction once per second until it reaches a final state.

    Returns the OpenAI ``chat.completion`` response on success. Raises
    ``HTTPException`` 500 when the prediction fails or is canceled and 504
    when it is still unfinished after *max_attempts* polls.
    """
    prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}

    for _ in range(max_attempts):
        async with session.get(prediction_url, headers=headers) as response:
            result = await response.json()

        status = result.get("status")
        logger.info(f"Prediction {prediction_id} status: {status}")

        if status == "succeeded":
            # A null output must become "", not the text "None".
            output = result.get("output") or []
            if isinstance(output, list):
                content = "".join(str(part) for part in output)
            else:
                content = str(output)

            # Whitespace word count, used as a rough stand-in for tokens.
            approx_tokens = len(content.split())
            return {
                "id": f"chatcmpl-{prediction_id}",
                "object": "chat.completion",
                # Unix timestamp in seconds, not the monotonic loop clock.
                "created": int(time.time()),
                "model": model,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": content,
                    },
                    "finish_reason": "stop",
                }],
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": approx_tokens,
                    "total_tokens": approx_tokens,
                },
            }

        if status == "failed":
            error_msg = result.get('error', 'Unknown error')
            logger.error(f"Prediction failed: {error_msg}")
            raise HTTPException(status_code=500, detail=f"Prediction failed: {error_msg}")

        if status in ["canceled", "cancelled"]:
            raise HTTPException(status_code=500, detail="Prediction was canceled")

        await asyncio.sleep(1)

    raise HTTPException(status_code=504, detail="Prediction timeout")


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Handle an OpenAI-compatible chat-completions request.

    The body is translated to a Replicate prediction. With ``stream: true``
    the answer is relayed as an SSE stream of OpenAI chunks; otherwise the
    prediction is polled and a single ``chat.completion`` object returned.

    Raises:
        HTTPException: 500 when the Replicate token is not configured or an
            unexpected error occurs; 400 when the body is not a JSON object;
            otherwise whatever the transformation, prediction-creation or
            polling helpers raise.
    """
    if not REPLICATE_API_TOKEN:
        logger.error("REPLICATE_API_TOKEN not configured")
        raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not configured")

    try:
        try:
            body = await request.json()
        except ValueError as e:
            # Malformed JSON is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}")
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")

        logger.info(f"Received chat completion request")
        logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}")
        logger.info(f"Message count: {len(body.get('messages', []))}")

        replicate_data, model = await transform_openai_to_replicate(body)

        if body.get("stream", False):
            return StreamingResponse(
                _generate_openai_stream(model, replicate_data),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                    # Ask reverse proxies not to buffer the event stream.
                    "X-Accel-Buffering": "no",
                },
            )

        async with aiohttp.ClientSession() as session:
            prediction = await create_replicate_prediction(session, model, replicate_data)
            prediction_id = prediction.get('id')
            logger.info(f"Created prediction: {prediction_id}")
            return await _poll_prediction(session, prediction_id, model)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error processing request: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the PORT environment variable overrides the
    # default port 7860; the server listens on all interfaces.
    listen_port = int(os.getenv("PORT", 7860))
    logger.info(f"Starting server on port {listen_port}")
    uvicorn.run(app, port=listen_port, host="0.0.0.0", log_level="info")