RedInk / backend /generators /image_api.py
m19921414377's picture
Upload folder using huggingface_hub
6db48b4 verified
"""Image API 图片生成器"""
import logging
import time
import random
import base64
import requests
from typing import Dict, Any, Optional, List, Union
from .base import ImageGeneratorBase
from ..utils.image_compressor import compress_image
# Module-level logger, named after this module so log configuration can target it.
logger = logging.getLogger(name=__name__)
def retry_on_error(max_retries: int = 3, base_delay: float = 2):
    """Retry decorator with exponential backoff and jitter.

    The wrapped callable is invoked up to ``max_retries`` times (at least
    once). Any ``Exception`` triggers a retry after sleeping
    ``base_delay * 2**attempt`` seconds plus up to one second of random
    jitter. When every attempt fails, the exception from the last attempt is
    re-raised.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as
            1, so the function is always called at least once (previously 0
            attempts ended in ``raise None`` -> ``TypeError``).
        base_delay: Base delay in seconds for the exponential backoff.
    """
    from functools import wraps

    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)  # keep the wrapped function's name/docstring
        def wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    # No sleep after the final attempt; just fall through and re-raise.
                    if attempt < attempts - 1:
                        delay = base_delay * (2 ** attempt) + random.uniform(0, 1)
                        logger.warning(f"请求失败,{delay:.1f}秒后重试 (尝试 {attempt + 2}/{attempts}): {str(e)[:100]}")
                        time.sleep(delay)
            raise last_error
        return wrapper
    return decorator
class ImageApiGenerator(ImageGeneratorBase):
    """Image generator backed by an OpenAI-style HTTP API.

    The request shape is chosen from the configured endpoint path:

    * an images endpoint (default ``/v1/images/generations``) receives a JSON
      payload with ``response_format="b64_json"``;
    * a chat endpoint (any path containing ``chat`` or ``completions``) is sent
      a chat message, and the reply text is scanned for a Markdown image link,
      a base64 data URI, or a bare URL.
    """

    # Markdown image whose target is an http(s) URL: ![alt](https://...)
    _MARKDOWN_URL_PATTERN = r'!\[.*?\]\((https?://[^\s\)]+)\)'
    # Markdown image whose target is an inline base64 data URI: ![alt](data:image/...;base64,...)
    _MARKDOWN_BASE64_PATTERN = r'!\[.*?\]\((data:image\/[^;]+;base64,[^\s\)]+)\)'

    def __init__(self, config: Dict[str, Any]):
        """Read connection and generation settings from ``config``.

        Keys read here: ``base_url``, ``model``, ``default_aspect_ratio``,
        ``image_size`` and ``endpoint_type``. ``self.api_key`` is used by this
        class but not set here; presumably ``ImageGeneratorBase.__init__`` sets
        it from ``config`` -- TODO confirm.
        """
        super().__init__(config)
        logger.debug("初始化 ImageApiGenerator...")
        self.base_url = self._normalize_base_url(config.get('base_url') or 'https://api.example.com')
        self.model = config.get('model', 'default-model')
        self.default_aspect_ratio = config.get('default_aspect_ratio', '3:4')
        self.image_size = config.get('image_size', '4K')
        # The endpoint path is configurable (custom gateways use different paths).
        self.endpoint_type = self._normalize_endpoint(config.get('endpoint_type') or '/v1/images/generations')
        logger.info(f"ImageApiGenerator 初始化完成: base_url={self.base_url}, model={self.model}, endpoint={self.endpoint_type}")

    @staticmethod
    def _normalize_base_url(url: str) -> str:
        """Strip trailing slashes and a single trailing ``/v1`` segment.

        The built-in endpoint paths already start with ``/v1``, so keeping it
        on the base URL would duplicate the version segment. ``str.rstrip``
        must not be used for this: ``rstrip('/v1')`` removes any trailing run
        of the characters ``/``, ``v`` and ``1`` and would turn, for example,
        ``https://api.dev`` into ``https://api.de``.
        """
        url = url.rstrip('/')
        if url.endswith('/v1'):
            url = url[:-len('/v1')].rstrip('/')
        return url

    @staticmethod
    def _normalize_endpoint(endpoint: str) -> str:
        """Expand the legacy shorthands and ensure the path starts with ``/``."""
        # Backward compatibility with the old short-form values.
        shorthands = {
            'images': '/v1/images/generations',
            'chat': '/v1/chat/completions',
        }
        endpoint = shorthands.get(endpoint, endpoint)
        if not endpoint.startswith('/'):
            endpoint = '/' + endpoint
        return endpoint

    def validate_config(self) -> bool:
        """Return True if the configuration is usable.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            logger.error("Image API Key 未配置")
            raise ValueError(
                "Image API Key 未配置。\n"
                "解决方案:在系统设置页面编辑该服务商,填写 API Key"
            )
        return True

    def get_supported_sizes(self) -> List[str]:
        """Return the supported image size labels."""
        return ["1K", "2K", "4K"]

    def get_supported_aspect_ratios(self) -> List[str]:
        """Return the supported aspect ratios."""
        return ["1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"]

    def generate_image(
        self,
        prompt: str,
        aspect_ratio: Optional[str] = None,
        temperature: float = 1.0,
        model: Optional[str] = None,
        reference_image: Optional[bytes] = None,
        reference_images: Optional[List[bytes]] = None,
        **kwargs
    ) -> bytes:
        """Generate one image and return its raw bytes.

        Args:
            prompt: Image description.
            aspect_ratio: Aspect ratio; defaults to ``self.default_aspect_ratio``.
                Only the images endpoint sends it; the chat payload does not.
            temperature: Unused; kept for interface compatibility.
            model: Model name; defaults to ``self.model``.
            reference_image: A single reference image (backward compatible).
            reference_images: A list of reference images.
            **kwargs: Accepted and ignored.

        Returns:
            The generated image bytes.

        Raises:
            ValueError: If no API key is configured (raised immediately, without retries).
            Exception: The error from the last attempt once all request
                attempts have failed.
        """
        # Configuration errors are not transient: validate before entering the
        # retry loop instead of sleeping through pointless retries.
        self.validate_config()
        if aspect_ratio is None:
            aspect_ratio = self.default_aspect_ratio
        if model is None:
            model = self.model
        logger.info(f"Image API 生成图片: model={model}, aspect_ratio={aspect_ratio}, endpoint={self.endpoint_type}")
        return self._generate_with_retry(prompt, aspect_ratio, model, reference_image, reference_images)

    @retry_on_error(max_retries=3, base_delay=2)
    def _generate_with_retry(
        self,
        prompt: str,
        aspect_ratio: str,
        model: str,
        reference_image: Optional[bytes],
        reference_images: Optional[List[bytes]]
    ) -> bytes:
        """Run one generation attempt against the endpoint-specific implementation."""
        if 'chat' in self.endpoint_type or 'completions' in self.endpoint_type:
            return self._generate_via_chat_api(prompt, aspect_ratio, model, reference_image, reference_images)
        return self._generate_via_images_api(prompt, aspect_ratio, model, reference_image, reference_images)

    def _auth_headers(self) -> Dict[str, str]:
        """Build the JSON + Bearer-token headers shared by both endpoints."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    @staticmethod
    def _collect_reference_images(
        reference_image: Optional[bytes],
        reference_images: Optional[List[bytes]]
    ) -> List[bytes]:
        """Merge the list and single-image arguments, skipping an exact duplicate."""
        collected: List[bytes] = list(reference_images or [])
        if reference_image and reference_image not in collected:
            collected.append(reference_image)
        return collected

    @staticmethod
    def _encode_reference_images(images: List[bytes]) -> List[str]:
        """Compress each image (target 200 KB) and encode it as a base64 data URI."""
        data_uris = []
        for idx, img_data in enumerate(images):
            compressed_img = compress_image(img_data, max_size_kb=200)
            logger.debug(f" 参考图 {idx}: {len(img_data)} -> {len(compressed_img)} bytes")
            base64_image = base64.b64encode(compressed_img).decode('utf-8')
            # NOTE(review): the MIME type is always image/png; confirm that
            # compress_image really emits PNG.
            data_uris.append(f"data:image/png;base64,{base64_image}")
        return data_uris

    @staticmethod
    def _decode_base64_image(data: str) -> bytes:
        """Decode either a raw base64 string or a ``data:...;base64,`` URI."""
        if data.startswith('data:'):
            data = data.split(',', 1)[1]
        return base64.b64decode(data)

    def _generate_via_images_api(
        self,
        prompt: str,
        aspect_ratio: str,
        model: str,
        reference_image: Optional[bytes] = None,
        reference_images: Optional[List[bytes]] = None
    ) -> bytes:
        """Generate an image through an images endpoint (e.g. ``/v1/images/generations``).

        Raises:
            Exception: On a non-200 response, or when no image can be
                extracted from the response body.
        """
        payload: Dict[str, Any] = {
            "model": model,
            "prompt": prompt,
            "response_format": "b64_json",
            "aspect_ratio": aspect_ratio,
            "image_size": self.image_size
        }

        all_reference_images = self._collect_reference_images(reference_image, reference_images)
        if all_reference_images:
            logger.debug(f" 添加 {len(all_reference_images)} 张参考图片")
            payload["image"] = self._encode_reference_images(all_reference_images)
            # Wrap the prompt so the model uses the images as a style reference.
            ref_count = len(all_reference_images)
            payload["prompt"] = f"""参考提供的 {ref_count} 张图片的风格(色彩、光影、构图、氛围),生成一张新图片。
新图片内容:{prompt}
要求:
1. 保持相似的色调和氛围
2. 使用相似的光影处理
3. 保持一致的画面质感
4. 如果参考图中有人物或产品,可以适当融入"""

        api_url = f"{self.base_url}{self.endpoint_type}"
        logger.debug(f" 发送请求到: {api_url}")
        response = requests.post(api_url, headers=self._auth_headers(), json=payload, timeout=300)

        if response.status_code != 200:
            error_detail = response.text[:500]
            logger.error(f"Image API 请求失败: status={response.status_code}, error={error_detail}")
            raise Exception(
                f"Image API 请求失败 (状态码: {response.status_code})\n"
                f"错误详情: {error_detail}\n"
                f"请求地址: {api_url}\n"
                "可能原因:\n"
                "1. API密钥无效或已过期\n"
                "2. 请求参数不符合API要求\n"
                "3. API服务端错误\n"
                "4. Base URL配置错误\n"
                "建议:检查API密钥和base_url配置"
            )

        result = response.json()
        # `or []` also covers an explicit "data": null in the response.
        items = result.get("data") or []
        logger.debug(f" API 响应: data 长度={len(items)}")

        if items:
            item = items[0]
            b64_value = item.get("b64_json")
            if b64_value:
                image_data = self._decode_base64_image(b64_value)
                logger.info(f"✅ Image API 图片生成成功: {len(image_data)} bytes")
                return image_data
            # Some providers ignore response_format and return a hosted URL instead.
            image_url = item.get("url")
            if image_url:
                logger.info("Image API 返回图片 URL,开始下载")
                return self._download_image(image_url)

        logger.error(f"无法从响应中提取图片数据: {str(result)[:200]}")
        raise Exception(
            "图片数据提取失败:未找到 b64_json 数据。\n"
            f"API响应片段: {str(result)[:500]}\n"
            "可能原因:\n"
            "1. API返回格式与预期不符\n"
            "2. response_format 参数未生效\n"
            "3. 该模型不支持 b64_json 格式\n"
            "建议:检查API文档确认返回格式要求"
        )

    def _generate_via_chat_api(
        self,
        prompt: str,
        aspect_ratio: str,
        model: str,
        reference_image: Optional[bytes] = None,
        reference_images: Optional[List[bytes]] = None
    ) -> bytes:
        """Generate an image through a chat-completions endpoint (e.g. the Jimeng / 即梦 API).

        ``aspect_ratio`` is accepted for signature parity but is not sent in
        the chat payload.

        Raises:
            Exception: On a non-200 response, or when no image can be
                extracted from the reply.
        """
        # A plain-text message, or a multimodal message when reference images are given.
        user_content: Any = prompt
        all_reference_images = self._collect_reference_images(reference_image, reference_images)
        if all_reference_images:
            logger.debug(f" 添加 {len(all_reference_images)} 张参考图片到 chat 消息")
            user_content = [{"type": "text", "text": prompt}] + [
                {"type": "image_url", "image_url": {"url": data_uri}}
                for data_uri in self._encode_reference_images(all_reference_images)
            ]

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": user_content}],
            "max_tokens": 4096,
            "temperature": 1.0
        }

        api_url = f"{self.base_url}{self.endpoint_type}"
        logger.info(f"Chat API 生成图片: {api_url}, model={model}")
        response = requests.post(api_url, headers=self._auth_headers(), json=payload, timeout=300)

        if response.status_code != 200:
            self._raise_chat_http_error(response, api_url, model)

        result = response.json()
        logger.debug(f"Chat API 响应: {str(result)[:500]}")

        content = self._chat_message_content(result)
        if isinstance(content, str):
            image_data = self._extract_image_from_text(content)
            if image_data is not None:
                return image_data

        raise Exception(
            "❌ 无法从 Chat API 响应中提取图片数据\n\n"
            f"【响应内容】\n{str(result)[:500]}\n\n"
            "【可能原因】\n"
            "1. 该模型不支持图片生成\n"
            "2. 响应格式与预期不符\n"
            "3. 提示词被安全过滤\n\n"
            "【解决方案】\n"
            "1. 确认模型名称正确\n"
            "2. 修改提示词后重试"
        )

    @staticmethod
    def _raise_chat_http_error(response: Any, api_url: str, model: str) -> None:
        """Log a failed chat response and always raise a user-facing exception."""
        status_code = response.status_code
        error_detail = response.text[:500]
        logger.error(f"Chat API 请求失败: status={status_code}, error={error_detail}")
        if status_code == 401:
            raise Exception(
                "❌ API Key 认证失败\n\n"
                "【可能原因】\n"
                "1. API Key 无效或已过期\n"
                "2. API Key 格式错误\n\n"
                "【解决方案】\n"
                "在系统设置页面检查 API Key 是否正确"
            )
        if status_code == 429:
            raise Exception(
                "⏳ API 配额或速率限制\n\n"
                "【解决方案】\n"
                "1. 稍后再试\n"
                "2. 检查 API 配额使用情况"
            )
        raise Exception(
            f"❌ Chat API 请求失败 (状态码: {status_code})\n\n"
            f"【错误详情】\n{error_detail[:300]}\n\n"
            f"【请求地址】{api_url}\n"
            f"【模型】{model}"
        )

    @staticmethod
    def _chat_message_content(result: Any) -> Any:
        """Return ``choices[0].message.content``, or None if the path is missing."""
        if not isinstance(result, dict):
            return None
        choices = result.get("choices") or []
        if not choices:
            return None
        message = choices[0].get("message") or {}
        return message.get("content")

    def _extract_image_from_text(self, content: str) -> Optional[bytes]:
        """Find an image in chat reply text; return its bytes, or None if none is found.

        Checked in order: a Markdown image with an http(s) URL (downloaded),
        a Markdown image with a base64 data URI, a bare data URI, and a bare
        URL (downloaded).
        """
        import re

        urls = re.findall(self._MARKDOWN_URL_PATTERN, content)
        if urls:
            logger.info(f"从 Markdown 提取到 {len(urls)} 张图片,下载第一张...")
            return self._download_image(urls[0])

        base64_urls = re.findall(self._MARKDOWN_BASE64_PATTERN, content)
        if base64_urls:
            logger.info("从 Markdown 提取到 Base64 图片数据")
            return self._decode_base64_image(base64_urls[0])

        # Surrounding whitespace/newlines would otherwise defeat the prefix checks.
        text = content.strip()
        if text.startswith("data:image"):
            logger.info("检测到 Base64 图片数据")
            return self._decode_base64_image(text)
        if text.startswith(("http://", "https://")):
            logger.info("检测到图片 URL")
            return self._download_image(text)
        return None

    def _download_image(self, url: str) -> bytes:
        """Download ``url`` and return the response body.

        Raises:
            Exception: On timeout, on any ``requests`` transport error, or on a
                non-200 status. The original error is chained as ``__cause__``.
        """
        logger.info(f"下载图片: {url[:100]}...")
        try:
            response = requests.get(url, timeout=60)
        except requests.exceptions.Timeout as e:
            raise Exception("❌ 下载图片超时,请重试") from e
        except requests.exceptions.RequestException as e:
            raise Exception(f"❌ 下载图片失败: {str(e)}") from e
        # The status check lives outside the try so its error is not re-wrapped
        # into a doubled "下载图片失败: 下载图片失败: ..." message.
        if response.status_code != 200:
            raise Exception(f"❌ 下载图片失败: HTTP {response.status_code}")
        logger.info(f"✅ 图片下载成功: {len(response.content)} bytes")
        return response.content