from __future__ import annotations

import base64
import binascii
import json
import mimetypes
import re
from pathlib import PurePosixPath
from typing import Any, TypeGuard
from urllib.parse import unquote, unquote_to_bytes, urlparse

from curl_cffi import requests
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from starlette.datastructures import UploadFile

from services.proxy_service import proxy_settings

# A fully-read image: (raw bytes, filename, MIME type).
ImageInput = tuple[bytes, str, str]
# An image reference before it is read: an http(s)/data URL string, a multipart upload,
# or inline data that has already been decoded into an ImageInput.
ImageSource = str | UploadFile | ImageInput

MAX_IMAGE_REFERENCE_BYTES = 50 * 1024 * 1024
# Multipart form field names whose values are treated as image references.
IMAGE_REFERENCE_FIELDS = {"image", "image[]", "images", "images[]", "image_url", "image_url[]"}
# Content-Type values that carry no real type information, so the URL suffix is consulted instead.
_GENERIC_BINARY_TYPES = frozenset({"application/octet-stream", "binary/octet-stream"})


def _clean(value: object, default: str = "") -> str:
    """Return ``value`` as a stripped string, or ``default`` when it is ``None`` or blank."""
    text = str(value if value is not None else default).strip()
    return text or default


def _is_upload(value: object) -> TypeGuard[UploadFile]:
    """Return True when ``value`` is a Starlette ``UploadFile`` (as returned by form parsing)."""
    return isinstance(value, UploadFile)


def _parse_bool(value: object) -> bool | None:
    """Parse the ``stream`` flag from a JSON boolean or a form string.

    Returns:
        ``None`` when the field is absent or empty, otherwise the parsed boolean.

    Raises:
        HTTPException: 400 when the value is not one of the recognised spellings.
    """
    if value is None or value == "":
        return None
    if isinstance(value, bool):
        return value
    text = _clean(value).lower()
    if text in {"true", "1", "yes", "y", "on"}:
        return True
    if text in {"false", "0", "no", "n", "off"}:
        return False
    raise HTTPException(status_code=400, detail={"error": "stream must be a boolean"})


def _parse_count(value: object) -> int:
    """Parse the requested image count ``n``, enforcing the images API's 1-4 range.

    A missing or empty value defaults to 1. Any other value must be an integer, an
    integer string, or an integral float. An explicit ``0`` is rejected rather than
    silently promoted to 1.

    Raises:
        HTTPException: 400 when ``n`` is not an integer or lies outside 1..4.
    """
    if value is None or value == "":
        return 1
    # bool is an int subclass, and int() silently truncates floats (and overflows on inf);
    # reject those up front instead of accepting a surprising count.
    if isinstance(value, bool) or (isinstance(value, float) and not value.is_integer()):
        raise HTTPException(status_code=400, detail={"error": "n must be an integer"})
    try:
        count = int(value)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "n must be an integer"}) from exc
    if count < 1 or count > 4:
        raise HTTPException(status_code=400, detail={"error": "n must be between 1 and 4"})
    return count


def _payload_from_fields(fields: dict[str, Any]) -> dict[str, Any]:
    """Build the image-edit payload from form or JSON fields.

    Raises:
        HTTPException: 400 when ``prompt`` is missing or blank, or ``n``/``stream`` are invalid.
    """
    prompt = _clean(fields.get("prompt"))
    if not prompt:
        raise HTTPException(status_code=400, detail={"error": "prompt is required"})
    payload: dict[str, Any] = {
        "prompt": prompt,
        "model": _clean(fields.get("model"), "gpt-image-2"),
        "n": _parse_count(fields.get("n")),
        "size": _clean(fields.get("size")) or None,
        "response_format": _clean(fields.get("response_format"), "b64_json"),
        "stream": _parse_bool(fields.get("stream")),
    }
    # Only forward client_task_id when the caller sent the key; otherwise leave it absent.
    if "client_task_id" in fields:
        payload["client_task_id"] = _clean(fields.get("client_task_id"))
    return payload


def _json_reference_value(value: object) -> object:
    """Decode a string that holds a JSON array/object (e.g. a form ``images`` field).

    Non-strings, strings that do not start with ``[`` or ``{``, and invalid JSON are
    returned unchanged.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    if not text or text[0] not in "[{":
        return value
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return value


def _decode_base64_image(value: object, filename: str, mime_type: str) -> ImageInput:
    """Decode inline base64 image data into an ``ImageInput`` tuple.

    Whitespace inside the payload (e.g. 76-column line wrapping) is ignored; any
    other character outside the base64 alphabet is rejected.

    Raises:
        HTTPException: 400 when the data is not valid base64, decodes to nothing,
            or exceeds ``MAX_IMAGE_REFERENCE_BYTES``.
    """
    compact = "".join(str(value).split())
    try:
        data = base64.b64decode(compact, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "invalid base64 image data"}) from exc
    if not data:
        raise HTTPException(status_code=400, detail={"error": "image file is empty"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image data exceeds 50MB limit"})
    return data, filename, mime_type


def _source_from_object(value: dict[str, Any]) -> list[ImageSource]:
    """Expand one image-reference object into image sources.

    Accepts inline ``b64_json``/``base64`` data (with optional ``filename``/``file_name``
    and ``mime_type``/``mimeType``), or an ``image_url``/``url`` entry that is either a
    string or a ``{"url": ...}`` object.

    Raises:
        HTTPException: 400 for ``file_id`` references or objects with no usable image field.
    """
    has_url = "image_url" in value or "url" in value
    if value.get("file_id"):
        raise HTTPException(
            status_code=400,
            detail={"error": "file_id image references are not supported; use image_url instead"},
        )
    inline = value.get("b64_json") or value.get("base64")
    if inline:
        filename = _clean(value.get("filename") or value.get("file_name"), "image.png")
        mime_type = _clean(value.get("mime_type") or value.get("mimeType"), "image/png")
        return [_decode_base64_image(inline, filename, mime_type)]
    if not has_url:
        raise HTTPException(status_code=400, detail={"error": "image reference must include image_url"})
    image_url = value.get("image_url", value.get("url"))
    if isinstance(image_url, dict):
        image_url = image_url.get("url")
    return _sources_from_value(image_url)


def _sources_from_value(value: object) -> list[ImageSource]:
    """Flatten an image reference (upload, string, list or object) into a list of sources.

    Strings starting with ``data:``, ``http://`` or ``https://`` (case-insensitive) are
    kept as URLs to be resolved later; any other non-empty string is decoded as raw
    base64. Lists are expanded recursively, and ``None``/blank strings yield nothing.

    Raises:
        HTTPException: 400 for unsupported value types or invalid inline data.
    """
    value = _json_reference_value(value)
    if _is_upload(value):
        return [value]
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return []
        if text.lower().startswith(("data:", "http://", "https://")):
            return [text]
        return [_decode_base64_image(text, "image.png", "image/png")]
    if isinstance(value, list):
        sources: list[ImageSource] = []
        for item in value:
            sources.extend(_sources_from_value(item))
        return sources
    if isinstance(value, dict):
        return _source_from_object(value)
    if value is None:
        return []
    raise HTTPException(status_code=400, detail={"error": "invalid image reference"})


def _json_image_sources(body: dict[str, Any]) -> list[ImageSource]:
    """Collect image sources from a JSON body's ``images``, ``image`` and ``image_url`` keys, in that order."""
    sources: list[ImageSource] = []
    for key in ("images", "image", "image_url"):
        if key in body:
            sources.extend(_sources_from_value(body.get(key)))
    return sources


async def parse_image_edit_request(request: Request) -> tuple[dict[str, Any], list[ImageSource]]:
    """Parse an image-edit request sent as JSON or as a multipart/urlencoded form.

    Returns:
        ``(payload, sources)``: the normalised edit parameters and the unread image sources.

    Raises:
        HTTPException: 400 for a malformed JSON body, a non-object JSON body, or invalid fields.
    """
    content_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if content_type == "application/json":
        try:
            body = await request.json()
        except ValueError as exc:
            # Covers json.JSONDecodeError and UnicodeDecodeError (non-UTF-8 bodies); both are ValueErrors.
            raise HTTPException(status_code=400, detail={"error": "invalid JSON body"}) from exc
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail={"error": "JSON body must be an object"})
        return _payload_from_fields(body), _json_image_sources(body)

    form = await request.form()
    fields: dict[str, Any] = {}
    for key in ("client_task_id", "prompt", "model", "n", "size", "response_format", "stream"):
        value = form.get(key)
        # Ignore file parts posted under a scalar field name.
        if isinstance(value, str):
            fields[key] = value
    sources: list[ImageSource] = []
    for key, value in form.multi_items():
        if key in IMAGE_REFERENCE_FIELDS:
            sources.extend(_sources_from_value(value))
    return _payload_from_fields(fields), sources


def _extension_from_mime(mime_type: str) -> str:
    """Derive a file extension from a MIME type (``image/jpeg`` -> ``jpg``, ``image/svg+xml`` -> ``svg``).

    Falls back to ``png`` when there is no subtype or it contains no usable characters.
    """
    _, slash, subtype = mime_type.lower().partition("/")
    subtype = subtype.split("+", 1)[0] if slash else "png"
    if subtype == "jpeg":
        return "jpg"
    return re.sub(r"[^a-z0-9]+", "", subtype) or "png"


def _safe_filename(name: str, mime_type: str, fallback: str) -> str:
    """Sanitise a filename to ``[A-Za-z0-9._-]`` and append an extension if it has none."""
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._")
    if not cleaned:
        cleaned = fallback
    if "." not in cleaned:
        cleaned = f"{cleaned}.{_extension_from_mime(mime_type)}"
    return cleaned


def _decode_data_url(url: str) -> ImageInput:
    """Decode an RFC 2397 ``data:`` URL into an ``ImageInput`` tuple.

    The media type must be ``image/*``; an empty media type is treated as ``image/png``.
    A ``base64`` parameter selects base64 decoding, otherwise the payload is
    percent-decoded. The scheme, media type and ``base64`` marker are matched
    case-insensitively.

    Raises:
        HTTPException: 400 when the URL has no ``,`` separator, is not an image, cannot
            be decoded, is empty, or exceeds ``MAX_IMAGE_REFERENCE_BYTES``.
    """
    header, separator, payload = url.partition(",")
    if not separator:
        raise HTTPException(status_code=400, detail={"error": "invalid data image URL"})
    if header[:5].lower() == "data:":
        header = header[5:]
    media_type, *parameters = header.split(";")
    mime_type = media_type.strip().lower() or "image/png"
    if not mime_type.startswith("image/"):
        raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})
    is_base64 = any(parameter.strip().lower() == "base64" for parameter in parameters)
    try:
        if is_base64:
            # Valid base64 never contains '%' or whitespace, so unquoting and removing
            # whitespace only helps clients that percent-encode or line-wrap the payload.
            data = base64.b64decode("".join(unquote(payload).split()), validate=True)
        else:
            data = unquote_to_bytes(payload)
    except (binascii.Error, ValueError) as exc:
        raise HTTPException(status_code=400, detail={"error": "invalid data image URL"}) from exc
    if not data:
        raise HTTPException(status_code=400, detail={"error": "image URL is empty"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image URL exceeds 50MB limit"})
    return data, f"image_url.{_extension_from_mime(mime_type)}", mime_type


def _response_mime_type(response: requests.Response, parsed_path: str) -> str:
    """Determine the MIME type of a downloaded image.

    An ``image/*`` Content-Type wins. Any other specific Content-Type is rejected. When
    the header is missing or a generic binary type, the URL path suffix is used if it
    names an image, otherwise ``image/png`` is assumed.

    Raises:
        HTTPException: 400 when the server declares a non-image, non-generic content type.
    """
    header_type = str(response.headers.get("content-type") or "").split(";", 1)[0].strip().lower()
    if header_type.startswith("image/"):
        return header_type
    if header_type and header_type not in _GENERIC_BINARY_TYPES:
        raise HTTPException(status_code=400, detail={"error": "image_url must point to an image"})
    guessed_type = mimetypes.guess_type(parsed_path)[0] or ""
    if guessed_type.startswith("image/"):
        return guessed_type
    return "image/png"


def _filename_from_url(parsed_path: str, mime_type: str) -> str:
    """Build a safe filename from the last (percent-decoded) component of a URL path."""
    raw_name = PurePosixPath(unquote(parsed_path)).name
    return _safe_filename(raw_name, mime_type, "image_url")


def _download_image_url(url: str) -> ImageInput:
    """Resolve a ``data:`` or http(s) image URL into an ``ImageInput`` tuple.

    This is a blocking call (it performs network I/O); callers on the event loop should
    run it in a thread pool.

    Raises:
        HTTPException: 400 for unsupported schemes, fetch failures, non-2xx responses,
            empty or oversized bodies, or non-image content.
    """
    source = _clean(url)
    # Match the case-insensitive "data:" check in _sources_from_value.
    if source[:5].lower() == "data:":
        return _decode_data_url(source)
    parsed = urlparse(source)
    if parsed.scheme not in {"http", "https"} or not parsed.netloc:
        raise HTTPException(status_code=400, detail={"error": "image_url must be an http or https URL"})
    # NOTE(review): the URL comes from the request and is fetched with no host allow-list,
    # so internal/private addresses are reachable (SSRF). Restrict destinations if this
    # endpoint is exposed to untrusted clients.
    session_kwargs = proxy_settings.build_session_kwargs()
    try:
        response = requests.get(
            source,
            headers={"Accept": "image/*,*/*;q=0.8", "User-Agent": "chatgpt2api image fetcher"},
            timeout=60,
            allow_redirects=True,
            **session_kwargs,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail={"error": f"image_url fetch failed: {exc}"}) from exc
    if not 200 <= response.status_code < 300:
        raise HTTPException(status_code=400, detail={"error": f"image_url fetch failed: HTTP {response.status_code}"})
    content_length = _clean(response.headers.get("content-length"))
    if content_length and content_length.isdigit() and int(content_length) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image_url exceeds 50MB limit"})
    # NOTE(review): response.content buffers the whole body, so when Content-Length is absent
    # the 50MB cap is only enforced after the download completes. TODO: consider a streamed
    # read with a running byte count to bound memory.
    data = response.content
    if not data:
        raise HTTPException(status_code=400, detail={"error": "image_url returned empty content"})
    if len(data) > MAX_IMAGE_REFERENCE_BYTES:
        raise HTTPException(status_code=400, detail={"error": "image_url exceeds 50MB limit"})
    mime_type = _response_mime_type(response, parsed.path)
    return data, _filename_from_url(parsed.path, mime_type), mime_type


async def read_image_sources(sources: list[ImageSource]) -> list[ImageInput]:
    """Read every image source into ``(bytes, filename, mime_type)`` tuples, preserving order.

    Pre-decoded tuples pass through unchanged. Uploads are read and then closed. URL
    strings are resolved via ``_download_image_url`` in a worker thread.

    Raises:
        HTTPException: 400 when an upload is empty, a URL cannot be resolved,
            or no images were supplied at all.
    """
    images: list[ImageInput] = []
    for source in sources:
        if isinstance(source, tuple):
            images.append(source)
            continue
        if _is_upload(source):
            try:
                image_data = await source.read()
            finally:
                await source.close()
            if not image_data:
                raise HTTPException(status_code=400, detail={"error": "image file is empty"})
            images.append((image_data, source.filename or "image.png", source.content_type or "image/png"))
            continue
        images.append(await run_in_threadpool(_download_image_url, source))
    if not images:
        raise HTTPException(status_code=400, detail={"error": "image file or image_url is required"})
    return images