"""API routes - OpenAI compatible endpoints""" from fastapi import APIRouter, Depends, HTTPException,Request from fastapi.responses import StreamingResponse, JSONResponse,HTMLResponse from datetime import datetime from typing import List import json import re from ..core.auth import verify_api_key_header from ..core.models import ChatCompletionRequest from ..services.generation_handler import GenerationHandler, MODEL_CONFIG router = APIRouter() # Dependency injection will be set up in main.py generation_handler: GenerationHandler = None def set_generation_handler(handler: GenerationHandler): """Set generation handler instance""" global generation_handler generation_handler = handler def _extract_remix_id(text: str) -> str: """Extract remix ID from text Supports two formats: 1. Full URL: https://sora.chatgpt.com/p/s_68e3a06dcd888191b150971da152c1f5 2. Short ID: s_68e3a06dcd888191b150971da152c1f5 Args: text: Text to search for remix ID Returns: Remix ID (s_[a-f0-9]{32}) or empty string if not found """ if not text: return "" # Match Sora share link format: s_[a-f0-9]{32} match = re.search(r's_[a-f0-9]{32}', text) if match: return match.group(0) return "" @router.get("/v1/models") async def list_models(api_key: str = Depends(verify_api_key_header)): """List available models""" models = [] for model_id, config in MODEL_CONFIG.items(): description = f"{config['type'].capitalize()} generation" if config['type'] == 'image': description += f" - {config['width']}x{config['height']}" else: description += f" - {config['orientation']}" models.append({ "id": model_id, "object": "model", "owned_by": "sora2api", "description": description }) return { "object": "list", "data": models } @router.post("/v1/chat/completions") async def create_chat_completion( request: ChatCompletionRequest, api_key: str = Depends(verify_api_key_header) ): """Create chat completion (unified endpoint for image and video generation)""" try: # Extract prompt from messages if not request.messages: raise HTTPException(status_code=400, detail="Messages cannot be empty") last_message = request.messages[-1] content = last_message.content # Handle both string and array format (OpenAI multimodal) prompt = "" image_data = request.image # Default to request.image if provided video_data = request.video # Video parameter remix_target_id = request.remix_target_id # Remix target ID if isinstance(content, str): # Simple string format prompt = content # Extract remix_target_id from prompt if not already provided if not remix_target_id: remix_target_id = _extract_remix_id(prompt) elif isinstance(content, list): # Array format (OpenAI multimodal) for item in content: if isinstance(item, dict): if item.get("type") == "text": prompt = item.get("text", "") # Extract remix_target_id from prompt if not already provided if not remix_target_id: remix_target_id = _extract_remix_id(prompt) elif item.get("type") == "image_url": # Extract base64 image from data URI image_url = item.get("image_url", {}) url = image_url.get("url", "") if url.startswith("data:image"): # Extract base64 data from data URI if "base64," in url: image_data = url.split("base64,", 1)[1] else: image_data = url elif item.get("type") == "input_video": # Extract video from input_video video_url = item.get("videoUrl", {}) url = video_url.get("url", "") if url.startswith("data:video") or url.startswith("data:application"): # Extract base64 data from data URI if "base64," in url: video_data = url.split("base64,", 1)[1] else: video_data = url else: # It's a URL, pass it as-is (will be downloaded in generation_handler) video_data = url else: raise HTTPException(status_code=400, detail="Invalid content format") # Validate model if request.model not in MODEL_CONFIG: raise HTTPException(status_code=400, detail=f"Invalid model: {request.model}") # Check if this is a video model model_config = MODEL_CONFIG[request.model] is_video_model = model_config["type"] == "video" # For video models with video parameter, we need streaming if is_video_model and (video_data or remix_target_id): if not request.stream: # Non-streaming mode: only check availability result = None async for chunk in generation_handler.handle_generation( model=request.model, prompt=prompt, image=image_data, video=video_data, remix_target_id=remix_target_id, stream=False ): result = chunk if result: import json return JSONResponse(content=json.loads(result)) else: return JSONResponse( status_code=500, content={ "error": { "message": "Availability check failed", "type": "server_error", "param": None, "code": None } } ) # Handle streaming if request.stream: async def generate(): import json as json_module # Import inside function to avoid scope issues try: async for chunk in generation_handler.handle_generation( model=request.model, prompt=prompt, image=image_data, video=video_data, remix_target_id=remix_target_id, stream=True ): yield chunk except Exception as e: # Return OpenAI-compatible error format error_response = { "error": { "message": str(e), "type": "server_error", "param": None, "code": None } } error_chunk = f'data: {json_module.dumps(error_response)}\n\n' yield error_chunk yield 'data: [DONE]\n\n' return StreamingResponse( generate(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no" } ) else: # Non-streaming response (availability check only) result = None async for chunk in generation_handler.handle_generation( model=request.model, prompt=prompt, image=image_data, video=video_data, remix_target_id=remix_target_id, stream=False ): result = chunk if result: import json return JSONResponse(content=json.loads(result)) else: # Return OpenAI-compatible error format return JSONResponse( status_code=500, content={ "error": { "message": "Availability check failed", "type": "server_error", "param": None, "code": None } } ) except Exception as e: return JSONResponse( status_code=500, content={ "error": { "message": str(e), "type": "server_error", "param": None, "code": None } } ) @router.post("/v1/tasks") async def submit_task( request: ChatCompletionRequest, api_key: str = Depends(verify_api_key_header) ): """Submit an asynchronous generation task""" try: # Extract prompt from messages (reuse logic from create_chat_completion) if not request.messages: raise HTTPException(status_code=400, detail="Messages cannot be empty") last_message = request.messages[-1] content = last_message.content prompt = "" image_data = request.image video_data = request.video remix_target_id = request.remix_target_id if isinstance(content, str): prompt = content if not remix_target_id: remix_target_id = _extract_remix_id(prompt) elif isinstance(content, list): for item in content: if isinstance(item, dict): if item.get("type") == "text": prompt = item.get("text", "") if not remix_target_id: remix_target_id = _extract_remix_id(prompt) elif item.get("type") == "image_url": image_url = item.get("image_url", {}) url = image_url.get("url", "") if url.startswith("data:image"): if "base64," in url: image_data = url.split("base64,", 1)[1] else: image_data = url elif item.get("type") == "input_video": video_url = item.get("videoUrl", {}) url = video_url.get("url", "") if url.startswith("data:video") or url.startswith("data:application"): if "base64," in url: video_data = url.split("base64,", 1)[1] else: video_data = url else: video_data = url task_id = await generation_handler.submit_generation_task( model=request.model, prompt=prompt, image=image_data, video=video_data, remix_target_id=remix_target_id ) return { "id": task_id, "object": "task", "created": int(datetime.now().timestamp()), "status": "processing" } except Exception as e: return JSONResponse( status_code=500, content={ "error": { "message": str(e), "type": "server_error", "param": None, "code": None } } ) @router.get("/v1/tasks/{task_id}") async def get_task_status( task_id: str, api_key: str = Depends(verify_api_key_header) ): """Query task status""" try: task = await generation_handler.db.get_task(task_id) if not task: raise HTTPException(status_code=404, detail=f"Task {task_id} not found") response = { "id": task.task_id, "object": "task", "status": task.status, "created": int(task.created_at.timestamp()) if task.created_at else 0, "model": task.model, "progress": f"{task.progress:.0f}%" } if task.status == "completed": response["result"] = { "url": json.loads(task.result_urls)[0] if task.result_urls else None } elif task.status == "failed": response["error"] = { "message": task.error_message } return response except HTTPException: raise except Exception as e: return JSONResponse( status_code=500, content={ "error": { "message": str(e), "type": "server_error", "param": None, "code": None } } ) @router.post("/v1beta/models/gemini-3-pro-image-preview:generateContent") async def proxy_gemini_vision(request: Request, key: str): """ Direct proxy for gemini-3-pro-image-preview:generateContent """ try: body = await request.json() target_url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-3-pro-image-preview:generateContent?key={key}" headers = { "Content-Type": "application/json" } # Use httpx for async request import httpx async with httpx.AsyncClient() as client: response = await client.post(target_url, json=body, headers=headers, timeout=60) # Forward the status code and content if response.status_code != 200: return JSONResponse(status_code=response.status_code, content=response.json()) return response.json() except Exception as e: # logger.error(f"Proxy error: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/", response_class=HTMLResponse) async def root(): html_content = f"""
服务运行中
Sora API Endpoints available at /v1/...