|
|
"""API routes - OpenAI compatible endpoints""" |
|
|
from fastapi import APIRouter, Depends, HTTPException,Request |
|
|
from fastapi.responses import StreamingResponse, JSONResponse,HTMLResponse |
|
|
from datetime import datetime |
|
|
from typing import List |
|
|
import json |
|
|
import re |
|
|
from ..core.auth import verify_api_key_header |
|
|
from ..core.models import ChatCompletionRequest |
|
|
from ..services.generation_handler import GenerationHandler, MODEL_CONFIG |
|
|
|
|
|
router = APIRouter() |
|
|
|
|
|
|
|
|
# Module-level singleton, injected at application startup via
# set_generation_handler() before any route below is served.
generation_handler: GenerationHandler = None


def set_generation_handler(handler: GenerationHandler) -> None:
    """Set generation handler instance.

    Called once during app wiring so the route handlers in this module
    can reach the shared GenerationHandler.
    """
    global generation_handler
    generation_handler = handler
|
|
|
|
|
def _extract_remix_id(text: str) -> str: |
|
|
"""Extract remix ID from text |
|
|
|
|
|
Supports two formats: |
|
|
1. Full URL: https://sora.chatgpt.com/p/s_68e3a06dcd888191b150971da152c1f5 |
|
|
2. Short ID: s_68e3a06dcd888191b150971da152c1f5 |
|
|
|
|
|
Args: |
|
|
text: Text to search for remix ID |
|
|
|
|
|
Returns: |
|
|
Remix ID (s_[a-f0-9]{32}) or empty string if not found |
|
|
""" |
|
|
if not text: |
|
|
return "" |
|
|
|
|
|
|
|
|
match = re.search(r's_[a-f0-9]{32}', text) |
|
|
if match: |
|
|
return match.group(0) |
|
|
|
|
|
return "" |
|
|
|
|
|
@router.get("/v1/models")
async def list_models(api_key: str = Depends(verify_api_key_header)):
    """List available models in OpenAI-compatible format."""
    data = []

    for model_id, config in MODEL_CONFIG.items():
        # Image models advertise their resolution; video models their orientation.
        if config['type'] == 'image':
            detail = f"{config['width']}x{config['height']}"
        else:
            detail = config['orientation']

        data.append({
            "id": model_id,
            "object": "model",
            "owned_by": "sora2api",
            "description": f"{config['type'].capitalize()} generation - {detail}"
        })

    return {"object": "list", "data": data}
|
|
|
|
|
def _extract_generation_inputs(request: ChatCompletionRequest):
    """Pull generation inputs out of the last chat message.

    Handles both plain-string content and OpenAI-style multi-part content
    (text / image_url / input_video items). Data-URI payloads are stripped
    down to their base64 body.

    Returns:
        Tuple of (prompt, image_data, video_data, remix_target_id).

    Raises:
        HTTPException: 400 when messages are empty or content has an
            unsupported type.
    """
    if not request.messages:
        raise HTTPException(status_code=400, detail="Messages cannot be empty")

    content = request.messages[-1].content

    prompt = ""
    image_data = request.image
    video_data = request.video
    remix_target_id = request.remix_target_id

    if isinstance(content, str):
        prompt = content
        if not remix_target_id:
            remix_target_id = _extract_remix_id(prompt)
    elif isinstance(content, list):
        for item in content:
            if not isinstance(item, dict):
                continue
            if item.get("type") == "text":
                prompt = item.get("text", "")
                if not remix_target_id:
                    remix_target_id = _extract_remix_id(prompt)
            elif item.get("type") == "image_url":
                url = item.get("image_url", {}).get("url", "")
                if url.startswith("data:image"):
                    # Keep only the base64 payload of the data URI
                    image_data = url.split("base64,", 1)[1] if "base64," in url else url
            elif item.get("type") == "input_video":
                url = item.get("videoUrl", {}).get("url", "")
                if url.startswith("data:video") or url.startswith("data:application"):
                    video_data = url.split("base64,", 1)[1] if "base64," in url else url
                else:
                    # Plain URL reference to an existing video
                    video_data = url
    else:
        raise HTTPException(status_code=400, detail="Invalid content format")

    return prompt, image_data, video_data, remix_target_id


@router.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key_header)
):
    """Create chat completion (unified endpoint for image and video generation)

    Streams SSE chunks when ``request.stream`` is set; otherwise drains the
    generation and returns the final JSON payload.
    """
    try:
        prompt, image_data, video_data, remix_target_id = _extract_generation_inputs(request)

        if request.model not in MODEL_CONFIG:
            raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}")

        if request.stream:
            async def generate():
                try:
                    async for chunk in generation_handler.handle_generation(
                        model=request.model,
                        prompt=prompt,
                        image=image_data,
                        video=video_data,
                        remix_target_id=remix_target_id,
                        stream=True
                    ):
                        yield chunk
                except Exception as e:
                    # Surface the failure as a terminal SSE event so clients
                    # are not left hanging on an open stream.
                    error_response = {
                        "error": {
                            "message": str(e),
                            "type": "server_error",
                            "param": None,
                            "code": None
                        }
                    }
                    yield f'data: {json.dumps(error_response)}\n\n'
                    yield 'data: [DONE]\n\n'

            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"  # disable nginx buffering for SSE
                }
            )

        # Non-streaming: drain the generator, keep the final chunk.
        result = None
        async for chunk in generation_handler.handle_generation(
            model=request.model,
            prompt=prompt,
            image=image_data,
            video=video_data,
            remix_target_id=remix_target_id,
            stream=False
        ):
            result = chunk

        if result:
            return JSONResponse(content=json.loads(result))
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "message": "Availability check failed",
                    "type": "server_error",
                    "param": None,
                    "code": None
                }
            }
        )

    except HTTPException:
        # Re-raise intentional 4xx errors; the generic handler below would
        # otherwise mask them as 500s.
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "message": str(e),
                    "type": "server_error",
                    "param": None,
                    "code": None
                }
            }
        )
|
|
|
|
|
@router.post("/v1/tasks")
async def submit_task(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key_header)
):
    """Submit an asynchronous generation task.

    Parses the last chat message for a prompt and optional image/video
    attachments, queues the generation, and returns the task id immediately.
    """
    try:
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")

        last_message = request.messages[-1]
        content = last_message.content

        prompt = ""
        image_data = request.image
        video_data = request.video
        remix_target_id = request.remix_target_id

        if isinstance(content, str):
            prompt = content
            if not remix_target_id:
                remix_target_id = _extract_remix_id(prompt)
        elif isinstance(content, list):
            for item in content:
                if not isinstance(item, dict):
                    continue
                if item.get("type") == "text":
                    prompt = item.get("text", "")
                    if not remix_target_id:
                        remix_target_id = _extract_remix_id(prompt)
                elif item.get("type") == "image_url":
                    url = item.get("image_url", {}).get("url", "")
                    if url.startswith("data:image"):
                        # Keep only the base64 payload of the data URI
                        image_data = url.split("base64,", 1)[1] if "base64," in url else url
                elif item.get("type") == "input_video":
                    url = item.get("videoUrl", {}).get("url", "")
                    if url.startswith("data:video") or url.startswith("data:application"):
                        video_data = url.split("base64,", 1)[1] if "base64," in url else url
                    else:
                        # Plain URL reference to an existing video
                        video_data = url

        task_id = await generation_handler.submit_generation_task(
            model=request.model,
            prompt=prompt,
            image=image_data,
            video=video_data,
            remix_target_id=remix_target_id
        )

        return {
            "id": task_id,
            "object": "task",
            "created": int(datetime.now().timestamp()),
            "status": "processing"
        }

    except HTTPException:
        # Re-raise the 400 above; the generic handler would otherwise
        # convert it into a 500.
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "message": str(e),
                    "type": "server_error",
                    "param": None,
                    "code": None
                }
            }
        )
|
|
|
|
|
@router.get("/v1/tasks/{task_id}")
async def get_task_status(
    task_id: str,
    api_key: str = Depends(verify_api_key_header)
):
    """Query task status"""
    try:
        task = await generation_handler.db.get_task(task_id)
        if task is None:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

        payload = {
            "id": task.task_id,
            "object": "task",
            "status": task.status,
            "created": int(task.created_at.timestamp()) if task.created_at else 0,
            "model": task.model,
            "progress": f"{task.progress:.0f}%"
        }

        # Attach result/error details once the task reaches a terminal state.
        if task.status == "completed":
            first_url = json.loads(task.result_urls)[0] if task.result_urls else None
            payload["result"] = {"url": first_url}
        elif task.status == "failed":
            payload["error"] = {"message": task.error_message}

        return payload

    except HTTPException:
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "error": {
                    "message": str(e),
                    "type": "server_error",
                    "param": None,
                    "code": None
                }
            }
        )
|
|
|
|
|
|
|
|
@router.post("/v1beta/models/gemini-3-pro-image-preview:generateContent")
async def proxy_gemini_vision(request: Request, key: str):
    """
    Direct proxy for gemini-3-pro-image-preview:generateContent

    Forwards the JSON body unchanged to the Google Generative Language API
    and relays the upstream response (and status code) back to the caller.
    """
    try:
        body = await request.json()
        target_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-image-preview:generateContent"

        # Local import keeps httpx optional for deployments not using this route.
        import httpx
        async with httpx.AsyncClient() as client:
            response = await client.post(
                target_url,
                params={"key": key},  # let httpx URL-encode the key safely
                json=body,
                headers={"Content-Type": "application/json"},
                timeout=60
            )

        try:
            payload = response.json()
        except ValueError:
            # Upstream returned non-JSON (e.g. an HTML error page); relay the
            # raw text instead of crashing into a generic 500 below.
            payload = {"error": {"message": response.text}}

        if response.status_code != 200:
            return JSONResponse(status_code=response.status_code, content=payload)

        return payload

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
@router.get("/", response_class=HTMLResponse)
async def root():
    """Serve a minimal static status page.

    Note: this was an f-string with every CSS brace escaped as ``{{ }}``
    even though nothing is interpolated; a plain string literal renders
    the identical HTML without the escaping hazard.
    """
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Sora API 服务</title>
        <style>
            body {
                font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
                max-width: 800px;
                margin: 0 auto;
                padding: 20px;
                line-height: 1.6;
            }
            h1 {
                color: #333;
                text-align: center;
                margin-bottom: 30px;
            }
            .info-box {
                background-color: #f8f9fa;
                border: 1px solid #dee2e6;
                border-radius: 4px;
                padding: 20px;
                margin-bottom: 20px;
            }
            .status {
                color: #28a745;
                font-weight: bold;
            }
        </style>
    </head>
    <body>
        <h1>🤖 Sora API 服务</h1>

        <div class="info-box">
            <h2>🟢 运行状态</h2>
            <p class="status">服务运行中</p>
            <!-- <p>可用API密钥数量: {len(key_manager.api_keys)}</p> -->
            <!-- <p>可用模型数量: {len(GeminiClient.AVAILABLE_MODELS)}</p> -->
        </div>

        <div class="info-box">
            <h2>⚙️ 环境配置</h2>
            <!-- <p>每分钟请求限制: {MAX_REQUESTS_PER_MINUTE}</p> -->
            <!-- <p>每IP每日请求限制: {MAX_REQUESTS_PER_DAY_PER_IP}</p> -->
            <!-- <p>最大重试次数: {len(key_manager.api_keys)}</p> -->
            <p>Sora API Endpoints available at /v1/...</p>
        </div>
    </body>
    </html>
    """
    return html_content
|
|
|