CHATSAM / api /image_tasks.py
ldl
feat: support image edit urls
ec77a9f
Raw
History Blame Contribute Delete
3.62 kB
from __future__ import annotations
from fastapi import APIRouter, Header, HTTPException, Query, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from api.image_inputs import parse_image_edit_request, read_image_sources
from api.support import require_identity, resolve_image_base_url
from services.content_filter import check_request
from services.image_task_service import image_task_service
from services.log_service import LoggedCall
class ImageGenerationTaskRequest(BaseModel):
    # Request body for POST /api/image-tasks/generations.
    # Documented with comments rather than a docstring because pydantic copies
    # a model docstring into the generated JSON schema description.

    # Task identifier supplied by the client; required and non-empty.
    client_task_id: str = Field(..., min_length=1)

    # Text prompt; required and non-empty.
    prompt: str = Field(..., min_length=1)

    # Model name; falls back to "gpt-image-2" when the client omits it.
    model: str = "gpt-image-2"

    # Optional size value, forwarded unchanged (None when omitted).
    size: str | None = None
def _parse_task_ids(value: str) -> list[str]:
return [item.strip() for item in value.split(",") if item.strip()]
async def filter_or_log(call: LoggedCall, text: str) -> None:
    """Run the content filter on *text*, recording a rejection on *call*.

    ``check_request`` is executed through ``run_in_threadpool``. When it
    raises ``HTTPException``, the failure is written with ``call.log`` (status
    ``"failed"``, error set to the stringified exception detail) and the same
    exception is re-raised. Any other exception propagates without being
    logged here.
    """
    try:
        await run_in_threadpool(check_request, text)
    except HTTPException as rejection:
        reason = str(rejection.detail)
        call.log("调用失败", status="failed", error=reason)
        raise
def create_router() -> APIRouter:
    """Build the router that exposes the image-task endpoints.

    Routes are registered in this order:

    * ``GET /api/image-tasks`` - list tasks, selected by a comma-separated
      ``ids`` query parameter.
    * ``POST /api/image-tasks/generations`` - submit a generation task from a
      JSON body.
    * ``POST /api/image-tasks/edits`` - submit an edit task from a request
      parsed by ``parse_image_edit_request``.

    Every handler first resolves the caller with ``require_identity`` from the
    ``Authorization`` header. Handlers are annotated with comments instead of
    docstrings because FastAPI publishes an endpoint docstring as the
    operation description in the generated OpenAPI schema.
    """
    image_router = APIRouter()

    def _bad_request(error: ValueError) -> HTTPException:
        # A ValueError raised while submitting a task is reported as HTTP 400.
        return HTTPException(status_code=400, detail={"error": str(error)})

    @image_router.get("/api/image-tasks")
    async def list_image_tasks(
        ids: str = Query(default=""),
        authorization: str | None = Header(default=None),
    ):
        identity = require_identity(authorization)
        wanted_ids = _parse_task_ids(ids)
        return await run_in_threadpool(image_task_service.list_tasks, identity, wanted_ids)

    @image_router.post("/api/image-tasks/generations")
    async def create_generation_task(
        body: ImageGenerationTaskRequest,
        request: Request,
        authorization: str | None = Header(default=None),
    ):
        identity = require_identity(authorization)

        # Screen the prompt before anything is submitted; a rejection is
        # logged on this call record and re-raised by filter_or_log.
        call = LoggedCall(
            identity,
            "/api/image-tasks/generations",
            body.model,
            "文生图任务",
            request_text=body.prompt,
        )
        await filter_or_log(call, body.prompt)

        try:
            # Arguments are gathered inside the try so that a ValueError from
            # building them is mapped to 400 just like one from the service.
            submission = {
                "client_task_id": body.client_task_id,
                "prompt": body.prompt,
                "model": body.model,
                "size": body.size,
                "base_url": resolve_image_base_url(request),
            }
            return await run_in_threadpool(
                image_task_service.submit_generation,
                identity,
                **submission,
            )
        except ValueError as exc:
            raise _bad_request(exc) from exc

    @image_router.post("/api/image-tasks/edits")
    async def create_edit_task(
        request: Request,
        authorization: str | None = Header(default=None),
    ):
        identity = require_identity(authorization)
        payload, image_sources = await parse_image_edit_request(request)

        # client_task_id is validated here: missing, falsy or
        # whitespace-only values are rejected with a 400.
        raw_task_id = payload.get("client_task_id")
        client_task_id = str(raw_task_id or "").strip()
        if not client_task_id:
            raise HTTPException(status_code=400, detail={"error": "client_task_id is required"})

        # NOTE(review): "prompt", "model" and "size" are read with plain
        # indexing, so the parsed payload is presumably guaranteed to carry
        # them - confirm against parse_image_edit_request.
        prompt = str(payload["prompt"])
        model = str(payload["model"])

        call = LoggedCall(
            identity,
            "/api/image-tasks/edits",
            model,
            "图生图任务",
            request_text=prompt,
        )
        await filter_or_log(call, prompt)

        # Image sources are only read once the prompt has passed the filter.
        images = await read_image_sources(image_sources)

        try:
            submission = {
                "client_task_id": client_task_id,
                "prompt": prompt,
                "model": model,
                "size": payload["size"],
                "base_url": resolve_image_base_url(request),
                "images": images,
            }
            return await run_in_threadpool(
                image_task_service.submit_edit,
                identity,
                **submission,
            )
        except ValueError as exc:
            raise _bad_request(exc) from exc

    return image_router