from __future__ import annotations
import base64
import binascii
import json
import mimetypes
import re
from pathlib import PurePosixPath
from typing import Any, TypeGuard
from urllib.parse import unquote, unquote_to_bytes, urlparse
from curl_cffi import requests
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from starlette.datastructures import UploadFile
from services.proxy_service import proxy_settings
# A fully materialized image: (raw bytes, filename, MIME type).
ImageInput = tuple[bytes, str, str]
# An image reference that may still need resolving: a URL / data-URL string,
# an uploaded file, or an already-decoded ImageInput tuple.
ImageSource = str | UploadFile | ImageInput
# Upper bound (50 MiB) applied to decoded or downloaded image payloads.
MAX_IMAGE_REFERENCE_BYTES = 50 * 1024 * 1024
# Multipart form field names that are scanned for image references.
IMAGE_REFERENCE_FIELDS = {"image", "image[]", "images", "images[]", "image_url", "image_url[]"}
def _clean(value: object, default: str = "") -> str:
"""清理字符串:转换为字符串并去掉首尾空白。"""
text = str(value if value is not None else default).strip()
return text or default
def _is_upload(value: object) -> TypeGuard[UploadFile]:
    """Report whether ``value`` is a Starlette ``UploadFile`` (as produced by form parsing)."""
    is_uploaded_file = isinstance(value, UploadFile)
    return is_uploaded_file
def _parse_bool(value: object) -> bool | None:
"""解析布尔字段:兼容 JSON 布尔值和表单字符串。"""
if value is None or value == "":
return None
if isinstance(value, bool):
return value
text = _clean(value).lower()
if text in {"true", "1", "yes", "y", "on"}:
return True
if text in {"false", "0", "no", "n", "off"}:
return False
raise HTTPException(status_code=400, detail={"error": "stream must be a boolean"})
def _parse_count(value: object) -> int:
"""解析生成数量:保持图片接口的 1 到 4 限制。"""
try:
count = int(value or 1)
except (TypeError, ValueError) as exc:
raise HTTPException(status_code=400, detail={"error": "n must be an integer"}) from exc
if count < 1 or count > 4:
raise HTTPException(status_code=400, detail={"error": "n must be between 1 and 4"})
return count
def _payload_from_fields(fields: dict[str, Any]) -> dict[str, Any]:
    """Build the image-edit payload from form or JSON fields.

    Extracts the parameters shared by both request encodings and applies
    defaults for ``model`` and ``response_format``. ``client_task_id`` is
    copied only when the key is present in ``fields``.

    Raises:
        HTTPException: 400 when ``prompt`` is missing or blank, or when
            ``n`` / ``stream`` fail validation.
    """
    prompt_text = _clean(fields.get("prompt"))
    if prompt_text == "":
        raise HTTPException(status_code=400, detail={"error": "prompt is required"})

    result: dict[str, Any] = {"prompt": prompt_text}
    result["model"] = _clean(fields.get("model"), "gpt-image-2")
    result["n"] = _parse_count(fields.get("n"))
    size_text = _clean(fields.get("size"))
    result["size"] = size_text if size_text else None
    result["response_format"] = _clean(fields.get("response_format"), "b64_json")
    result["stream"] = _parse_bool(fields.get("stream"))

    if "client_task_id" in fields:
        result["client_task_id"] = _clean(fields.get("client_task_id"))
    return result
def _json_reference_value(value: object) -> object:
"""解析表单图片引用:支持把 images 字段写成 JSON 字符串。"""
if not isinstance(value, str):
return value
text = value.strip()
if not text or text[0] not in "[{":
return value
try:
return json.loads(text)
except json.JSONDecodeError:
return value
def _decode_base64_image(value: object, filename: str, mime_type: str) -> ImageInput:
    """Decode inline base64 image data into a standard image tuple.

    Args:
        value: Base64 text (stringified and trimmed before decoding).
        filename: Filename to attach to the decoded image.
        mime_type: MIME type to attach to the decoded image.

    Returns:
        ``(data, filename, mime_type)``.

    Raises:
        HTTPException: 400 when the text is not strict base64, decodes to
            nothing, or exceeds ``MAX_IMAGE_REFERENCE_BYTES``.
    """
    try:
        # validate=True rejects any character outside the base64 alphabet.
        data = base64.b64decode(str(value).strip(), validate=True)
    except (binascii.Error, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "invalid base64 image data"}) from exc
    if not data:
        raise HTTPException(status_code=400, detail={"error": "image file is empty"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        # This path handles inline data, not a URL, so the message says so.
        raise HTTPException(status_code=400, detail={"error": "image data exceeds 50MB limit"})
    return data, filename, mime_type
def _source_from_object(value: dict[str, Any]) -> list[ImageSource]:
    """Extract image sources from a reference object.

    Accepts inline base64 (``b64_json`` / ``base64``) or a URL under
    ``image_url`` / ``url`` (optionally nested as ``{"url": ...}``).
    ``file_id`` references are rejected explicitly.

    Raises:
        HTTPException: 400 for ``file_id`` references or when neither inline
            data nor a URL key is present.
    """
    if value.get("file_id"):
        raise HTTPException(
            status_code=400,
            detail={"error": "file_id image references are not supported; use image_url instead"},
        )

    inline_data = value.get("b64_json") or value.get("base64")
    if inline_data:
        name = _clean(value.get("filename") or value.get("file_name"), "image.png")
        mime = _clean(value.get("mime_type") or value.get("mimeType"), "image/png")
        return [_decode_base64_image(inline_data, name, mime)]

    if "image_url" not in value and "url" not in value:
        raise HTTPException(status_code=400, detail={"error": "image reference must include image_url"})

    # "image_url" wins whenever the key exists, even if its value is None.
    reference = value["image_url"] if "image_url" in value else value.get("url")
    if isinstance(reference, dict):
        reference = reference.get("url")
    return _sources_from_value(reference)
def _sources_from_value(value: object) -> list[ImageSource]:
    """Flatten an image reference into a list of image sources.

    Handles uploaded files, strings (``data:`` / ``http(s)://`` URLs are kept
    as-is, any other text is treated as bare base64), lists (flattened
    recursively), reference objects, and ``None`` (no sources).

    Raises:
        HTTPException: 400 for any other value type.
    """
    resolved = _json_reference_value(value)
    if resolved is None:
        return []
    if _is_upload(resolved):
        return [resolved]
    if isinstance(resolved, str):
        stripped = resolved.strip()
        if stripped == "":
            return []
        looks_like_url = stripped.lower().startswith(("data:", "http://", "https://"))
        if looks_like_url:
            return [stripped]
        return [_decode_base64_image(stripped, "image.png", "image/png")]
    if isinstance(resolved, list):
        return [found for entry in resolved for found in _sources_from_value(entry)]
    if isinstance(resolved, dict):
        return _source_from_object(resolved)
    raise HTTPException(status_code=400, detail={"error": "invalid image reference"})
def _json_image_sources(body: dict[str, Any]) -> list[ImageSource]:
    """Collect image sources from a JSON body.

    Reads the ``images``, ``image`` and ``image_url`` keys, in that order,
    and concatenates whatever each one yields.
    """
    collected: list[ImageSource] = []
    for field in ("images", "image", "image_url"):
        if field not in body:
            continue
        collected += _sources_from_value(body[field])
    return collected
async def parse_image_edit_request(request: Request) -> tuple[dict[str, Any], list[ImageSource]]:
    """Parse an image-edit request from either a JSON body or a form upload.

    ``application/json`` requests are read as a JSON object. Every other
    content type is read through ``request.form()``.

    Returns:
        ``(payload, sources)``: the validated generation parameters and the
        unresolved image sources found in the request.

    Raises:
        HTTPException: 400 when the JSON body cannot be decoded or is not an
            object, or when field validation fails.
    """
    content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if content_type == "application/json":
        try:
            body = await request.json()
        except ValueError as exc:
            # ValueError covers json.JSONDecodeError and also the
            # UnicodeDecodeError raised when the body bytes are not valid
            # UTF-8, which would otherwise surface as a 500.
            raise HTTPException(status_code=400, detail={"error": "invalid JSON body"}) from exc
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail={"error": "JSON body must be an object"})
        return _payload_from_fields(body), _json_image_sources(body)

    form = await request.form()
    fields: dict[str, Any] = {}
    for key in ("client_task_id", "prompt", "model", "n", "size", "response_format", "stream"):
        value = form.get(key)
        # Scalar parameters must be plain text; non-string values are ignored.
        if isinstance(value, str):
            fields[key] = value
    sources: list[ImageSource] = []
    # multi_items() keeps repeated fields, so several files can share a name.
    for key, value in form.multi_items():
        if key in IMAGE_REFERENCE_FIELDS:
            sources.extend(_sources_from_value(value))
    return _payload_from_fields(fields), sources
def _extension_from_mime(mime_type: str) -> str:
"""推导图片扩展名:把 MIME 类型转换为常见文件后缀。"""
subtype = mime_type.split("/", 1)[1].split("+", 1)[0] if "/" in mime_type else "png"
if subtype == "jpeg":
return "jpg"
return re.sub(r"[^a-z0-9]+", "", subtype.lower()) or "png"
def _safe_filename(name: str, mime_type: str, fallback: str) -> str:
"""生成安全文件名:清理 URL 文件名并补齐扩展名。"""
cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._")
if not cleaned:
cleaned = fallback
if "." not in cleaned:
cleaned = f"{cleaned}.{_extension_from_mime(mime_type)}"
return cleaned
def _decode_data_url(url: str) -> ImageInput:
    """Decode a ``data:`` URL into a standard image tuple.

    The MIME type is read from the header (defaulting to ``image/png`` when
    omitted) and must start with ``image/``. The payload is base64-decoded
    when the header contains ``;base64`` and percent-decoded otherwise.

    Raises:
        HTTPException: 400 when the URL has no comma separator, is not an
            image, cannot be decoded, is empty, or exceeds the size limit.
    """
    header, comma, body = url.partition(",")
    if comma == "":
        raise HTTPException(status_code=400, detail={"error": "invalid data image URL"})

    declared_type = header.split(";", 1)[0].removeprefix("data:")
    mime_type = declared_type if declared_type else "image/png"
    if not mime_type.startswith("image/"):
        raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})

    is_base64 = ";base64" in header
    try:
        if is_base64:
            data = base64.b64decode(body, validate=True)
        else:
            data = unquote_to_bytes(body)
    except (binascii.Error, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "invalid data image URL"}) from exc

    if len(data) == 0:
        raise HTTPException(status_code=400, detail={"error": "image URL is empty"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image URL exceeds 50MB limit"})

    filename = f"image_url.{_extension_from_mime(mime_type)}"
    return data, filename, mime_type
def _response_mime_type(response: requests.Response, parsed_path: str) -> str:
"""识别下载图片类型:优先响应头,必要时按 URL 后缀推断。"""
header_type = str(response.headers.get("content-type") or "").split(";", 1)[0].strip().lower()
guessed_type = mimetypes.guess_type(parsed_path)[0] or ""
if header_type.startswith("image/"):
return header_type
if header_type and header_type not in {"application/octet-stream", "binary/octet-stream"}:
raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})
if guessed_type.startswith("image/"):
return guessed_type
if not header_type or header_type in {"application/octet-stream", "binary/octet-stream"}:
return "image/png"
raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})
def _filename_from_url(parsed_path: str, mime_type: str) -> str:
    """Build a filename for a URL image from the last segment of its path.

    The path is percent-decoded, its final component is taken, and the result
    is sanitized; ``image_url`` is the fallback stem.
    """
    decoded_path = unquote(parsed_path)
    last_segment = PurePosixPath(decoded_path).name
    return _safe_filename(last_segment, mime_type, "image_url")
def _download_image_url(url: str) -> ImageInput:
    """Resolve an image URL into a standard image tuple.

    ``data:`` URLs are decoded locally; ``http`` / ``https`` URLs are fetched
    with redirects enabled and a 60 second timeout, using the proxy settings
    from ``proxy_settings.build_session_kwargs()``.

    Returns:
        ``(data, filename, mime_type)``.

    Raises:
        HTTPException: 400 when the URL scheme is unsupported, the fetch
            fails or returns a non-2xx status, the body is empty or larger
            than ``MAX_IMAGE_REFERENCE_BYTES``, or the content is not an image.
    """
    source = _clean(url)
    # URI schemes are case-insensitive and reference parsing already accepts
    # "DATA:" in any case, so match it the same way here. The prefix is
    # canonicalized because the data-URL decoder strips a lowercase "data:".
    if source[:5].lower() == "data:":
        return _decode_data_url("data:" + source[5:])
    parsed = urlparse(source)
    if parsed.scheme not in {"http", "https"} or not parsed.netloc:
        raise HTTPException(status_code=400, detail={"error": "image_url must be an http or https URL"})
    try:
        response = requests.get(
            source,
            headers={"Accept": "image/*,*/*;q=0.8", "User-Agent": "chatgpt2api image fetcher"},
            timeout=60,
            allow_redirects=True,
            **proxy_settings.build_session_kwargs(),
        )
    except Exception as exc:
        # Any transport failure is reported to the client as a bad request.
        raise HTTPException(status_code=400, detail={"error": f"image_url fetch failed: {exc}"}) from exc
    if not 200 <= response.status_code < 300:
        raise HTTPException(status_code=400, detail={"error": f"image_url fetch failed: HTTP {response.status_code}"})
    # NOTE(review): the request is not streamed, so the body has presumably
    # been fully buffered before these size checks run - confirm whether the
    # limit should be enforced during the download instead.
    content_length = _clean(response.headers.get("content-length"))
    if content_length and content_length.isdigit() and int(content_length) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image_url exceeds 50MB limit"})
    data = response.content
    if not data:
        raise HTTPException(status_code=400, detail={"error": "image_url returned empty content"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image_url exceeds 50MB limit"})
    mime_type = _response_mime_type(response, parsed.path)
    return data, _filename_from_url(parsed.path, mime_type), mime_type
async def read_image_sources(sources: list[ImageSource]) -> list[ImageInput]:
    """Resolve every image source into an ``(data, filename, mime_type)`` tuple.

    Tuples pass through unchanged, uploaded files are read and then closed,
    and any other source is resolved by ``_download_image_url`` in a thread
    pool.

    Raises:
        HTTPException: 400 when an uploaded file is empty or when no image
            was produced at all.
    """
    loaded: list[ImageInput] = []
    for item in sources:
        if isinstance(item, tuple):
            # Already decoded (e.g. inline base64): nothing left to do.
            loaded.append(item)
        elif _is_upload(item):
            try:
                payload = await item.read()
            finally:
                await item.close()
            if not payload:
                raise HTTPException(status_code=400, detail={"error": "image file is empty"})
            filename = item.filename or "image.png"
            mime_type = item.content_type or "image/png"
            loaded.append((payload, filename, mime_type))
        else:
            # The download helper is synchronous, so it runs off the event loop.
            loaded.append(await run_in_threadpool(_download_image_url, item))
    if len(loaded) == 0:
        raise HTTPException(status_code=400, detail={"error": "image file or image_url is required"})
    return loaded
|