li / api /image_tasks.py
xiaolongmr
Add shared account pool controls
07dd5b7
Raw
History Blame Contribute Delete
5.11 kB
from __future__ import annotations
from fastapi import APIRouter, Header, HTTPException, Query, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from api.image_inputs import parse_image_edit_request, read_image_sources
from api.support import require_identity, resolve_image_base_url
from services.content_filter import check_request
from services.image_task_service import image_task_service
from services.log_service import LoggedCall
class ImageGenerationTaskRequest(BaseModel):
    """JSON body accepted by ``POST /api/image-tasks/generations``."""

    # Caller-chosen task identifier; required and must be non-empty.
    client_task_id: str = Field(..., min_length=1)
    # Text prompt for the image; required and must be non-empty.
    prompt: str = Field(..., min_length=1)
    # Image model name forwarded to the task service.
    model: str = "gpt-image-2"
    # Requested output size; None means the caller did not specify one
    # (how the service picks a size then is not visible here — confirm).
    size: str | None = None
    # Requested account-pool strategy. The router passes it through
    # _effective_account_pool_strategy, which may override it for callers
    # that are neither admins nor pool-enabled.
    account_pool_strategy: str = "own_first"
    # Conversation-memory options, forwarded to the service as
    # memory_enabled / memory_conversation_id / memory_reset.
    memory: bool = False
    memory_conversation_id: str = ""
    memory_reset: bool = False
def _parse_task_ids(value: str) -> list[str]:
return [item.strip() for item in value.split(",") if item.strip()]
def _credit_kwargs(identity: dict[str, object]) -> dict[str, object]:
if str(identity.get("role") or "").strip().lower() != "normal":
return {}
return {"credit_user_id": str(identity.get("id") or "").strip(), "credit_amount": 1}
def _effective_account_pool_strategy(identity: dict[str, object], value: object) -> str:
if identity.get("role") == "admin" or bool(identity.get("account_pool_enabled")):
return str(value or "own_first")
return "admin_only"
async def filter_or_log(call: LoggedCall, text: str) -> None:
    """Run the content filter on *text* off the event loop.

    When ``check_request`` rejects the text with an ``HTTPException``, the
    rejection is recorded on *call* as a failed call and the same exception
    is re-raised unchanged to the caller.
    """
    try:
        await run_in_threadpool(check_request, text)
    except HTTPException as rejection:
        reason = str(rejection.detail)
        call.log("调用失败", status="failed", error=reason)
        raise
# String spellings treated as "off" when a boolean flag arrives as text.
_FALSE_FLAG_STRINGS = frozenset({"", "0", "false", "no", "off"})


def _is_truthy_flag(value: object) -> bool:
    """Interpret a request flag that may be a real bool or a raw string.

    ``bool("false")`` is ``True``, so common "off" spellings are mapped to
    ``False`` explicitly. Every other value keeps plain ``bool()`` semantics.
    NOTE(review): this assumes edit payloads may carry raw form strings —
    confirm against parse_image_edit_request.
    """
    if isinstance(value, str) and value.strip().lower() in _FALSE_FLAG_STRINGS:
        return False
    return bool(value)


def _http_error_from_value_error(exc: ValueError) -> HTTPException:
    """Map a task-service ``ValueError`` to the HTTP error returned to clients."""
    message = str(exc)
    # The service reports exhausted usage credits with "次数不足"
    # ("insufficient count"); that becomes 402, anything else 400.
    status_code = 402 if "次数不足" in message else 400
    return HTTPException(status_code=status_code, detail={"error": message})


def create_router() -> APIRouter:
    """Build the router for the image-task endpoints.

    Routes:
        GET  /api/image-tasks              list the caller's tasks, optionally
                                           filtered by comma-separated ``ids``.
        POST /api/image-tasks/generations  submit a text-to-image task.
        POST /api/image-tasks/edits        submit an image-to-image task.

    Service ``ValueError``s become HTTP 402 (insufficient credits) or 400.
    """
    router = APIRouter()

    @router.get("/api/image-tasks")
    async def list_image_tasks(
        ids: str = Query(default=""),
        authorization: str | None = Header(default=None),
    ):
        identity = require_identity(authorization)
        return await run_in_threadpool(image_task_service.list_tasks, identity, _parse_task_ids(ids))

    @router.post("/api/image-tasks/generations")
    async def create_generation_task(
        body: ImageGenerationTaskRequest,
        request: Request,
        authorization: str | None = Header(default=None),
    ):
        identity = require_identity(authorization)
        # Trim the id like the edits route does (and like _parse_task_ids
        # trims ids when listing), so a padded id stays retrievable and a
        # whitespace-only id is rejected instead of slipping past min_length.
        client_task_id = body.client_task_id.strip()
        if not client_task_id:
            raise HTTPException(status_code=400, detail={"error": "client_task_id is required"})
        await filter_or_log(LoggedCall(identity, "/api/image-tasks/generations", body.model, "文生图任务", request_text=body.prompt), body.prompt)
        try:
            return await run_in_threadpool(
                image_task_service.submit_generation,
                identity,
                client_task_id=client_task_id,
                prompt=body.prompt,
                model=body.model,
                size=body.size,
                base_url=resolve_image_base_url(request),
                memory_enabled=body.memory,
                memory_conversation_id=body.memory_conversation_id,
                memory_reset=body.memory_reset,
                account_pool_strategy=_effective_account_pool_strategy(identity, body.account_pool_strategy),
                **_credit_kwargs(identity),
            )
        except ValueError as exc:
            raise _http_error_from_value_error(exc) from exc

    @router.post("/api/image-tasks/edits")
    async def create_edit_task(
        request: Request,
        authorization: str | None = Header(default=None),
    ):
        identity = require_identity(authorization)
        payload, image_sources = await parse_image_edit_request(request)
        client_task_id = str(payload.get("client_task_id") or "").strip()
        if not client_task_id:
            raise HTTPException(status_code=400, detail={"error": "client_task_id is required"})
        prompt = str(payload["prompt"])
        model = str(payload["model"])
        # Filter the prompt before reading the uploaded images so rejected
        # requests skip that work.
        await filter_or_log(LoggedCall(identity, "/api/image-tasks/edits", model, "图生图任务", request_text=prompt), prompt)
        images = await read_image_sources(image_sources)
        try:
            return await run_in_threadpool(
                image_task_service.submit_edit,
                identity,
                client_task_id=client_task_id,
                prompt=prompt,
                model=model,
                size=payload["size"],
                base_url=resolve_image_base_url(request),
                images=images,
                memory_enabled=_is_truthy_flag(payload.get("memory")),
                memory_conversation_id=str(payload.get("memory_conversation_id") or ""),
                memory_reset=_is_truthy_flag(payload.get("memory_reset")),
                account_pool_strategy=_effective_account_pool_strategy(identity, payload.get("account_pool_strategy") or "own_first"),
                **_credit_kwargs(identity),
            )
        except ValueError as exc:
            raise _http_error_from_value_error(exc) from exc

    return router