"""图片输入解析与下载。""" from __future__ import annotations import asyncio import base64 import imghdr import mimetypes import urllib.parse import urllib.request from dataclasses import dataclass SUPPORTED_IMAGE_MIME_TYPES = { "image/png", "image/jpeg", "image/webp", "image/gif", } MAX_IMAGE_BYTES = 10 * 1024 * 1024 MAX_IMAGE_COUNT = 5 @dataclass class PreparedImage: filename: str mime_type: str data: bytes def _validate_image_bytes(data: bytes, mime_type: str) -> None: if mime_type not in SUPPORTED_IMAGE_MIME_TYPES: raise ValueError(f"暂不支持的图片类型: {mime_type}") if len(data) > MAX_IMAGE_BYTES: raise ValueError("单张图片不能超过 10MB") def _default_filename(mime_type: str, *, prefix: str = "image") -> str: ext = mimetypes.guess_extension(mime_type) or ".bin" if ext == ".jpe": ext = ".jpg" return f"{prefix}{ext}" def parse_data_url(url: str, *, prefix: str = "image") -> PreparedImage: if not url.startswith("data:") or ";base64," not in url: raise ValueError("仅支持 data:image/...;base64,... 格式") header, payload = url.split(",", 1) mime_type = header[5:].split(";", 1)[0].strip().lower() data = base64.b64decode(payload, validate=True) _validate_image_bytes(data, mime_type) return PreparedImage( filename=_default_filename(mime_type, prefix=prefix), mime_type=mime_type, data=data, ) def parse_base64_image( data_b64: str, mime_type: str, *, prefix: str = "image", ) -> PreparedImage: mime = mime_type.strip().lower() data = base64.b64decode(data_b64, validate=True) _validate_image_bytes(data, mime) return PreparedImage( filename=_default_filename(mime, prefix=prefix), mime_type=mime, data=data, ) def _sniff_mime_type(data: bytes, url: str) -> str: kind = imghdr.what(None, data) if kind == "jpeg": return "image/jpeg" if kind in {"png", "gif", "webp"}: return f"image/{kind}" guessed, _ = mimetypes.guess_type(url) return (guessed or "application/octet-stream").lower() def _download_remote_image_sync(url: str, *, prefix: str = "image") -> PreparedImage: parsed = urllib.parse.urlparse(url) if 
parsed.scheme not in {"http", "https"}: raise ValueError("image_url 仅支持 http/https 或 data URL") req = urllib.request.Request( url, headers={"User-Agent": "web2api/1.0", "Accept": "image/*"}, ) with urllib.request.urlopen(req, timeout=20) as resp: data = resp.read(MAX_IMAGE_BYTES + 1) mime_type = str(resp.headers.get_content_type() or "").lower() if not mime_type or mime_type == "application/octet-stream": mime_type = _sniff_mime_type(data, url) _validate_image_bytes(data, mime_type) filename = urllib.parse.unquote( parsed.path.rsplit("/", 1)[-1] ) or _default_filename(mime_type, prefix=prefix) if "." not in filename: filename = _default_filename(mime_type, prefix=prefix) return PreparedImage(filename=filename, mime_type=mime_type, data=data) async def download_remote_image(url: str, *, prefix: str = "image") -> PreparedImage: return await asyncio.to_thread(_download_remote_image_sync, url, prefix=prefix)
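

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module).
# A minimal, hedged illustration of the input format that parse_data_url
# accepts: "data:<mime>;base64,<payload>". The helper name build_data_url is
# hypothetical and exists only for this example.
# ---------------------------------------------------------------------------
def build_data_url(data: bytes, mime_type: str) -> str:
    """Encode raw bytes as a base64 data URL (hypothetical example helper)."""
    import base64  # local import keeps this sketch self-contained

    payload = base64.b64encode(data).decode("ascii")
    return f"data:{mime_type};base64,{payload}"


# Example round trip, assuming the functions defined earlier in this module:
#     url = build_data_url(png_bytes, "image/png")
#     image = parse_data_url(url)
#     assert image.data == png_bytes and image.mime_type == "image/png"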