CHATSAM-Public / api /image_inputs.py
chanfasf's picture
init: CHATSAM public version with data persistence and keep-alive
bd4057b
Raw
History Blame Contribute Delete
12.8 kB
from __future__ import annotations
import base64
import binascii
import json
import mimetypes
import re
from pathlib import PurePosixPath
from typing import Any, TypeGuard
from urllib.parse import unquote, unquote_to_bytes, urlparse
from curl_cffi import requests
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from starlette.datastructures import UploadFile
from services.proxy_service import proxy_settings
# A fully materialised image: (raw bytes, filename, MIME type).
ImageInput = tuple[bytes, str, str]
# An image reference that may still need resolving: a URL / data-URL string,
# an uploaded file, or an already-decoded ImageInput tuple.
ImageSource = str | UploadFile | ImageInput
# Upper bound on the decoded size of a single image reference (50 MiB).
MAX_IMAGE_REFERENCE_BYTES = 50 * 1024 * 1024
# Form field names that are scanned for image references in form submissions.
IMAGE_REFERENCE_FIELDS = {"image", "image[]", "images", "images[]", "image_url", "image_url[]"}
def _clean(value: object, default: str = "") -> str:
"""清理字符串:转换为字符串并去掉首尾空白。"""
text = str(value if value is not None else default).strip()
return text or default
def _is_upload(value: object) -> TypeGuard[UploadFile]:
    """Tell whether ``value`` is a Starlette ``UploadFile`` instance."""
    matches_upload_type = isinstance(value, UploadFile)
    return matches_upload_type
def _parse_bool(value: object) -> bool | None:
"""解析布尔字段:兼容 JSON 布尔值和表单字符串。"""
if value is None or value == "":
return None
if isinstance(value, bool):
return value
text = _clean(value).lower()
if text in {"true", "1", "yes", "y", "on"}:
return True
if text in {"false", "0", "no", "n", "off"}:
return False
raise HTTPException(status_code=400, detail={"error": "stream must be a boolean"})
def _parse_count(value: object) -> int:
"""解析生成数量:保持图片接口的 1 到 4 限制。"""
try:
count = int(value or 1)
except (TypeError, ValueError) as exc:
raise HTTPException(status_code=400, detail={"error": "n must be an integer"}) from exc
if count < 1 or count > 4:
raise HTTPException(status_code=400, detail={"error": "n must be between 1 and 4"})
return count
def _payload_from_fields(fields: dict[str, Any]) -> dict[str, Any]:
    """Build the image-edit payload from already-extracted form or JSON fields.

    ``prompt`` is mandatory; ``model`` and ``response_format`` get defaults,
    ``size`` becomes ``None`` when blank, and ``client_task_id`` is copied
    only when the key is present in ``fields``.

    Raises:
        HTTPException: 400 when ``prompt`` is empty, or when ``n`` / ``stream``
            fail their own validation.
    """
    prompt = _clean(fields.get("prompt"))
    if not prompt:
        raise HTTPException(status_code=400, detail={"error": "prompt is required"})

    payload: dict[str, Any] = {"prompt": prompt}
    payload["model"] = _clean(fields.get("model"), "gpt-image-2")
    payload["n"] = _parse_count(fields.get("n"))
    size = _clean(fields.get("size"))
    payload["size"] = size if size else None
    payload["response_format"] = _clean(fields.get("response_format"), "b64_json")
    payload["stream"] = _parse_bool(fields.get("stream"))

    if "client_task_id" in fields:
        payload["client_task_id"] = _clean(fields.get("client_task_id"))
    return payload
def _json_reference_value(value: object) -> object:
"""解析表单图片引用:支持把 images 字段写成 JSON 字符串。"""
if not isinstance(value, str):
return value
text = value.strip()
if not text or text[0] not in "[{":
return value
try:
return json.loads(text)
except json.JSONDecodeError:
return value
def _decode_base64_image(value: object, filename: str, mime_type: str) -> ImageInput:
    """Decode inline base64 image data into a ``(bytes, filename, mime_type)`` tuple.

    The value is converted to ``str``, stripped and decoded strictly
    (``validate=True``), so characters outside the base64 alphabet are
    rejected rather than ignored.

    Raises:
        HTTPException: 400 when the data is not valid base64, decodes to
            nothing, or exceeds ``MAX_IMAGE_REFERENCE_BYTES``.
    """
    try:
        data = base64.b64decode(str(value).strip(), validate=True)
    except (binascii.Error, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "invalid base64 image data"}) from exc
    if not data:
        raise HTTPException(status_code=400, detail={"error": "image file is empty"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        # This path handles inline base64 data, not a URL, so the message
        # names the data rather than an "image URL".
        raise HTTPException(status_code=400, detail={"error": "image data exceeds 50MB limit"})
    return data, filename, mime_type
def _source_from_object(value: dict[str, Any]) -> list[ImageSource]:
    """Expand an image reference object into a list of image sources.

    Inline base64 (``b64_json`` / ``base64``) takes precedence over a URL
    given under ``image_url`` or ``url``; ``image_url`` may itself be an
    object carrying a ``url`` key. ``file_id`` references are rejected.

    Raises:
        HTTPException: 400 for a ``file_id`` reference, or when the object
            has neither inline data nor an ``image_url`` / ``url`` key.
    """
    if value.get("file_id"):
        raise HTTPException(
            status_code=400,
            detail={"error": "file_id image references are not supported; use image_url instead"},
        )

    inline = value.get("b64_json") or value.get("base64")
    if inline:
        filename = _clean(value.get("filename") or value.get("file_name"), "image.png")
        mime_type = _clean(value.get("mime_type") or value.get("mimeType"), "image/png")
        decoded = _decode_base64_image(inline, filename, mime_type)
        return [decoded]

    if "image_url" not in value and "url" not in value:
        raise HTTPException(status_code=400, detail={"error": "image reference must include image_url"})

    # "image_url" wins whenever the key exists, even if its value is None.
    reference = value["image_url"] if "image_url" in value else value.get("url")
    if isinstance(reference, dict):
        reference = reference.get("url")
    return _sources_from_value(reference)
def _sources_from_value(value: object) -> list[ImageSource]:
    """Flatten a string, list or object image reference into image sources.

    Strings starting with ``data:``, ``http://`` or ``https://`` (compared
    case-insensitively) are kept as-is; any other non-blank string is treated
    as raw base64 image data. Lists are expanded recursively, dicts are
    handed to ``_source_from_object``, and ``None`` / blank strings
    contribute nothing.

    Raises:
        HTTPException: 400 when the value is of an unsupported type.
    """
    resolved = _json_reference_value(value)
    if resolved is None:
        return []
    if _is_upload(resolved):
        return [resolved]
    if isinstance(resolved, dict):
        return _source_from_object(resolved)
    if isinstance(resolved, list):
        return [source for item in resolved for source in _sources_from_value(item)]
    if not isinstance(resolved, str):
        raise HTTPException(status_code=400, detail={"error": "invalid image reference"})

    text = resolved.strip()
    if not text:
        return []
    looks_like_url = text.lower().startswith(("data:", "http://", "https://"))
    if looks_like_url:
        return [text]
    return [_decode_base64_image(text, "image.png", "image/png")]
def _json_image_sources(body: dict[str, Any]) -> list[ImageSource]:
    """Collect image sources from a JSON body.

    The ``images``, ``image`` and ``image_url`` keys are read in that order,
    and the sources found under each present key are concatenated.
    """
    return [
        source
        for key in ("images", "image", "image_url")
        if key in body
        for source in _sources_from_value(body[key])
    ]
async def parse_image_edit_request(request: Request) -> tuple[dict[str, Any], list[ImageSource]]:
    """Parse an image-edit request sent either as JSON or as a form.

    A request whose content type is ``application/json`` is read as a JSON
    object; every other content type is read through ``request.form()``.

    Returns:
        A ``(payload, sources)`` pair: the validated generation parameters and
        the still-unresolved image references.

    Raises:
        HTTPException: 400 when the JSON body cannot be decoded or is not an
            object, or when field validation fails.
    """
    content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if content_type == "application/json":
        try:
            body = await request.json()
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError, and so is the
            # UnicodeDecodeError raised when the body bytes are not validly
            # encoded; both are client errors and must map to a 400.
            raise HTTPException(status_code=400, detail={"error": "invalid JSON body"}) from exc
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail={"error": "JSON body must be an object"})
        return _payload_from_fields(body), _json_image_sources(body)

    form = await request.form()

    # Only plain-string values are accepted for scalar parameters; an uploaded
    # file under one of these names is ignored.
    fields: dict[str, Any] = {}
    for key in ("client_task_id", "prompt", "model", "n", "size", "response_format", "stream"):
        value = form.get(key)
        if isinstance(value, str):
            fields[key] = value

    # multi_items() keeps repeated keys, so several "image" parts all count.
    sources: list[ImageSource] = []
    for key, value in form.multi_items():
        if key in IMAGE_REFERENCE_FIELDS:
            sources.extend(_sources_from_value(value))
    return _payload_from_fields(fields), sources
def _extension_from_mime(mime_type: str) -> str:
"""推导图片扩展名:把 MIME 类型转换为常见文件后缀。"""
subtype = mime_type.split("/", 1)[1].split("+", 1)[0] if "/" in mime_type else "png"
if subtype == "jpeg":
return "jpg"
return re.sub(r"[^a-z0-9]+", "", subtype.lower()) or "png"
def _safe_filename(name: str, mime_type: str, fallback: str) -> str:
"""生成安全文件名:清理 URL 文件名并补齐扩展名。"""
cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._")
if not cleaned:
cleaned = fallback
if "." not in cleaned:
cleaned = f"{cleaned}.{_extension_from_mime(mime_type)}"
return cleaned
def _decode_data_url(url: str) -> ImageInput:
    """Decode a ``data:`` URL into a ``(bytes, filename, mime_type)`` tuple.

    The header (everything before the first comma) is matched
    case-insensitively, so ``DATA:Image/PNG;BASE64,...`` is handled like its
    lower-case form. A missing media type defaults to ``image/png``. The
    payload is base64-decoded when the header contains ``;base64`` and
    percent-decoded otherwise. The returned MIME type is lower-cased.

    Raises:
        HTTPException: 400 when the URL has no comma separator, the media
            type is not ``image/*``, the payload is not valid base64, the
            decoded data is empty, or it exceeds ``MAX_IMAGE_REFERENCE_BYTES``.
    """
    header, separator, payload = url.partition(",")
    if not separator:
        raise HTTPException(status_code=400, detail={"error": "invalid data image URL"})

    # Scheme, media type and the base64 marker are all case-insensitive, so
    # normalise the header once and match against the lower-cased form.
    normalized_header = header.strip().lower()
    mime_type = normalized_header.split(";", 1)[0].removeprefix("data:").strip() or "image/png"
    if not mime_type.startswith("image/"):
        raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})

    try:
        if ";base64" in normalized_header:
            data = base64.b64decode(payload, validate=True)
        else:
            data = unquote_to_bytes(payload)
    except (binascii.Error, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "invalid data image URL"}) from exc

    if not data:
        raise HTTPException(status_code=400, detail={"error": "image URL is empty"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image URL exceeds 50MB limit"})
    return data, f"image_url.{_extension_from_mime(mime_type)}", mime_type
def _response_mime_type(response: requests.Response, parsed_path: str) -> str:
"""识别下载图片类型:优先响应头,必要时按 URL 后缀推断。"""
header_type = str(response.headers.get("content-type") or "").split(";", 1)[0].strip().lower()
guessed_type = mimetypes.guess_type(parsed_path)[0] or ""
if header_type.startswith("image/"):
return header_type
if header_type and header_type not in {"application/octet-stream", "binary/octet-stream"}:
raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})
if guessed_type.startswith("image/"):
return guessed_type
if not header_type or header_type in {"application/octet-stream", "binary/octet-stream"}:
return "image/png"
raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})
def _filename_from_url(parsed_path: str, mime_type: str) -> str:
    """Build a safe filename from the last segment of a URL path.

    The path is percent-decoded first; sanitising and extension handling are
    delegated to ``_safe_filename`` with ``"image_url"`` as the fallback name.
    """
    decoded_path = unquote(parsed_path)
    last_segment = PurePosixPath(decoded_path).name
    return _safe_filename(last_segment, mime_type, "image_url")
def _download_image_url(url: str) -> ImageInput:
    """Resolve an image URL into a ``(bytes, filename, mime_type)`` tuple.

    ``data:`` URLs are decoded locally; ``http`` / ``https`` URLs are fetched
    with the proxy settings from ``proxy_settings.build_session_kwargs()``.
    This function performs blocking I/O.

    Raises:
        HTTPException: 400 when the URL scheme is unsupported, the fetch
            fails or returns a non-2xx status, the body is empty or larger
            than ``MAX_IMAGE_REFERENCE_BYTES``, or the content is not an image.
    """
    source = _clean(url)

    # URL schemes are case-insensitive, and _sources_from_value already
    # recognises "data:" without regard to case. Match it the same way here
    # and hand over a canonical lower-case prefix, otherwise "DATA:..." would
    # be rejected below as a non-http URL.
    if source[:5].lower() == "data:":
        return _decode_data_url("data:" + source[5:])

    parsed = urlparse(source)
    if parsed.scheme not in {"http", "https"} or not parsed.netloc:
        raise HTTPException(status_code=400, detail={"error": "image_url must be an http or https URL"})

    # SECURITY NOTE(review): the URL is caller-supplied and fetched from the
    # server with redirects enabled; nothing in this function restricts the
    # target host (loopback, link-local or private ranges). Confirm whether
    # SSRF protection exists elsewhere before exposing this publicly.
    try:
        response = requests.get(
            source,
            headers={"Accept": "image/*,*/*;q=0.8", "User-Agent": "chatgpt2api image fetcher"},
            timeout=60,
            allow_redirects=True,
            **proxy_settings.build_session_kwargs(),
        )
    except Exception as exc:
        # Any transport-level failure is reported to the client as a bad
        # image_url rather than as a server error.
        raise HTTPException(status_code=400, detail={"error": f"image_url fetch failed: {exc}"}) from exc

    if not 200 <= response.status_code < 300:
        raise HTTPException(status_code=400, detail={"error": f"image_url fetch failed: HTTP {response.status_code}"})

    # NOTE(review): no streaming option is passed to requests.get, so the
    # body has presumably been fetched in full by this point; the checks
    # below validate the size but do not bound memory use during download.
    content_length = _clean(response.headers.get("content-length"))
    if content_length and content_length.isdigit() and int(content_length) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image_url exceeds 50MB limit"})

    data = response.content
    if not data:
        raise HTTPException(status_code=400, detail={"error": "image_url returned empty content"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image_url exceeds 50MB limit"})

    mime_type = _response_mime_type(response, parsed.path)
    return data, _filename_from_url(parsed.path, mime_type), mime_type
async def read_image_sources(sources: list[ImageSource]) -> list[ImageInput]:
    """Resolve every image source into a ``(bytes, filename, mime_type)`` tuple.

    Tuples are passed through untouched, uploaded files are read and then
    closed, and everything else is handed to ``_download_image_url`` via
    ``run_in_threadpool``. Sources are processed one after another, in order.

    Raises:
        HTTPException: 400 when an uploaded file is empty or when no image
            was produced at all.
    """
    resolved: list[ImageInput] = []
    for entry in sources:
        if isinstance(entry, tuple):
            resolved.append(entry)
        elif _is_upload(entry):
            try:
                payload = await entry.read()
            finally:
                # Close the upload whether or not the read succeeded.
                await entry.close()
            if not payload:
                raise HTTPException(status_code=400, detail={"error": "image file is empty"})
            resolved.append((payload, entry.filename or "image.png", entry.content_type or "image/png"))
        else:
            resolved.append(await run_in_threadpool(_download_image_url, entry))

    if resolved:
        return resolved
    raise HTTPException(status_code=400, detail={"error": "image file or image_url is required"})