| from __future__ import annotations |
|
|
| from fastapi import APIRouter, Header, HTTPException, Query, Request |
| from fastapi.concurrency import run_in_threadpool |
| from pydantic import BaseModel, Field |
|
|
| from api.image_inputs import parse_image_edit_request, read_image_sources |
| from api.support import require_identity, resolve_image_base_url |
| from services.content_filter import check_request |
| from services.image_task_service import image_task_service |
| from services.log_service import LoggedCall |
|
|
|
|
class ImageGenerationTaskRequest(BaseModel):
    """JSON body accepted by ``POST /api/image-tasks/generations`` (text-to-image task)."""

    # Caller-chosen task id (must be non-empty); forwarded to the task service as ``client_task_id``.
    client_task_id: str = Field(..., min_length=1)
    # Text prompt (must be non-empty); it is run through the content filter before submission.
    prompt: str = Field(..., min_length=1)
    # Image model name forwarded to the task service.
    model: str = "gpt-image-2"
    # Optional output size, passed through unchanged; None when the client omits it.
    size: str | None = None
    # Requested account-pool strategy. The route handler may replace it with
    # "admin_only" depending on the caller's role / account_pool_enabled flag.
    account_pool_strategy: str = "own_first"
    # Conversation-memory options, forwarded as memory_enabled,
    # memory_conversation_id and memory_reset respectively.
    memory: bool = False
    memory_conversation_id: str = ""
    memory_reset: bool = False
|
|
|
|
| def _parse_task_ids(value: str) -> list[str]: |
| return [item.strip() for item in value.split(",") if item.strip()] |
|
|
|
|
| def _credit_kwargs(identity: dict[str, object]) -> dict[str, object]: |
| if str(identity.get("role") or "").strip().lower() != "normal": |
| return {} |
| return {"credit_user_id": str(identity.get("id") or "").strip(), "credit_amount": 1} |
|
|
|
|
| def _effective_account_pool_strategy(identity: dict[str, object], value: object) -> str: |
| if identity.get("role") == "admin" or bool(identity.get("account_pool_enabled")): |
| return str(value or "own_first") |
| return "admin_only" |
|
|
|
|
async def filter_or_log(call: LoggedCall, text: str) -> None:
    """Screen ``text`` with the content filter, logging ``call`` as failed on rejection.

    ``check_request`` runs in a worker thread via ``run_in_threadpool``
    (presumably because it is blocking). If it raises ``HTTPException``, the
    exception's detail is recorded on ``call`` with status ``"failed"`` and the
    same exception is re-raised unchanged. Any other exception propagates
    without being logged.
    """
    try:
        await run_in_threadpool(check_request, text)
    except HTTPException as rejection:
        failure_detail = str(rejection.detail)
        # "调用失败" means "call failed".
        call.log("调用失败", status="failed", error=failure_detail)
        raise
|
|
|
|
# Substring of a task-service ValueError message ("insufficient count/quota")
# that is reported to the client as 402 instead of 400.
_INSUFFICIENT_QUOTA_MARKER = "次数不足"

# String spellings accepted as "true" for request flags (compared lower-cased).
_TRUTHY_FLAG_STRINGS = frozenset({"1", "true", "yes", "on"})


def _payload_flag(value: object) -> bool:
    """Coerce a loosely-typed request flag to ``bool``.

    Non-string values (bool, int, None) keep normal Python truthiness. Strings
    are parsed case-insensitively, so ``"false"``/``"0"``/``""`` become False.
    Plain ``bool("false")`` would return True for these.
    """
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_FLAG_STRINGS
    return bool(value)


def _submission_error(exc: ValueError) -> HTTPException:
    """Translate a task-service ValueError into the HTTP error sent to the client."""
    message = str(exc)
    status_code = 402 if _INSUFFICIENT_QUOTA_MARKER in message else 400
    return HTTPException(status_code=status_code, detail={"error": message})


def create_router() -> APIRouter:
    """Build the router for the image-task endpoints.

    Routes:
        GET  /api/image-tasks              -- list tasks, optionally filtered by ``ids``.
        POST /api/image-tasks/generations  -- submit a text-to-image task.
        POST /api/image-tasks/edits        -- submit an image-edit task.

    Every route authenticates via ``require_identity``. Submission routes run
    the prompt through ``filter_or_log`` first. They map a ``ValueError`` from
    the task service to HTTP 402 when its message contains the
    insufficient-quota marker, and to HTTP 400 otherwise.
    """
    router = APIRouter()

    @router.get("/api/image-tasks")
    async def list_image_tasks(
        ids: str = Query(default=""),
        authorization: str | None = Header(default=None),
    ):
        """List tasks for the authenticated identity, restricted to the comma-separated ``ids`` if given."""
        identity = require_identity(authorization)
        return await run_in_threadpool(image_task_service.list_tasks, identity, _parse_task_ids(ids))

    @router.post("/api/image-tasks/generations")
    async def create_generation_task(
        body: ImageGenerationTaskRequest,
        request: Request,
        authorization: str | None = Header(default=None),
    ):
        """Submit a text-to-image task described by the JSON ``body``."""
        identity = require_identity(authorization)
        # "文生图任务" = "text-to-image task" (log label).
        call = LoggedCall(identity, "/api/image-tasks/generations", body.model, "文生图任务", request_text=body.prompt)
        await filter_or_log(call, body.prompt)
        try:
            return await run_in_threadpool(
                image_task_service.submit_generation,
                identity,
                client_task_id=body.client_task_id,
                prompt=body.prompt,
                model=body.model,
                size=body.size,
                base_url=resolve_image_base_url(request),
                memory_enabled=body.memory,
                memory_conversation_id=body.memory_conversation_id,
                memory_reset=body.memory_reset,
                account_pool_strategy=_effective_account_pool_strategy(identity, body.account_pool_strategy),
                **_credit_kwargs(identity),
            )
        except ValueError as exc:
            raise _submission_error(exc) from exc

    @router.post("/api/image-tasks/edits")
    async def create_edit_task(
        request: Request,
        authorization: str | None = Header(default=None),
    ):
        """Submit an image-edit task parsed from the raw request by ``parse_image_edit_request``."""
        identity = require_identity(authorization)
        payload, image_sources = await parse_image_edit_request(request)
        client_task_id = str(payload.get("client_task_id") or "").strip()
        if not client_task_id:
            raise HTTPException(status_code=400, detail={"error": "client_task_id is required"})
        prompt = str(payload["prompt"])
        model = str(payload["model"])
        # "图生图任务" = "image-to-image task" (log label).
        call = LoggedCall(identity, "/api/image-tasks/edits", model, "图生图任务", request_text=prompt)
        await filter_or_log(call, prompt)
        # Images are read only after the prompt passes the filter, so a rejected
        # prompt never triggers image reads.
        images = await read_image_sources(image_sources)
        try:
            return await run_in_threadpool(
                image_task_service.submit_edit,
                identity,
                client_task_id=client_task_id,
                prompt=prompt,
                model=model,
                size=payload["size"],
                base_url=resolve_image_base_url(request),
                images=images,
                # NOTE(review): the payload may come from multipart form data, where
                # flags arrive as strings; _payload_flag keeps "false"/"0" falsy.
                memory_enabled=_payload_flag(payload.get("memory")),
                memory_conversation_id=str(payload.get("memory_conversation_id") or ""),
                memory_reset=_payload_flag(payload.get("memory_reset")),
                # The helper already falls back to "own_first" for missing values.
                account_pool_strategy=_effective_account_pool_strategy(identity, payload.get("account_pool_strategy")),
                **_credit_kwargs(identity),
            )
        except ValueError as exc:
            raise _submission_error(exc) from exc

    return router
|
|