Spaces:
Sleeping
Sleeping
File size: 14,574 Bytes
6db48b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 |
"""Image API 图片生成器"""
import base64
import functools
import logging
import random
import time
from typing import Dict, Any, Optional, List, Union

import requests

from .base import ImageGeneratorBase
from ..utils.image_compressor import compress_image
logger = logging.getLogger(__name__)


def retry_on_error(
    max_retries: int = 3,
    base_delay: float = 2,
    jitter: float = 1.0,
    retry_on: Union[type, tuple] = (Exception,),
):
    """Decorator factory: retry the wrapped call with exponential back-off.

    The wrapped function is attempted up to ``max_retries`` times (always at
    least once). After failed attempt ``k`` (0-based) the wrapper sleeps
    ``base_delay * 2**k + uniform(0, jitter)`` seconds before trying again.
    Only exceptions matching ``retry_on`` are retried; anything else propagates
    immediately. If the final attempt fails, its exception propagates unchanged.

    Args:
        max_retries: Total number of attempts; values below 1 are treated as 1.
        base_delay: Base back-off in seconds, doubled after every failure.
        jitter: Upper bound (seconds) of the random delay added to each back-off.
        retry_on: Exception class, or tuple of classes, that trigger a retry.

    Returns:
        A decorator that wraps a function with the retry loop, preserving its
        name and docstring via ``functools.wraps``.
    """
    # Guard against 0/negative values: the original loop would never call the
    # function and then hit ``raise None`` (a TypeError).
    attempts = max(1, max_retries)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Every attempt except the last may be retried after a back-off.
            for attempt in range(attempts - 1):
                try:
                    return func(*args, **kwargs)
                except retry_on as e:
                    delay = base_delay * (2 ** attempt) + random.uniform(0, jitter)
                    logger.warning(f"请求失败,{delay:.1f}秒后重试 (尝试 {attempt + 2}/{attempts}): {str(e)[:100]}")
                    time.sleep(delay)
            # Final attempt: any exception reaches the caller as-is.
            return func(*args, **kwargs)
        return wrapper
    return decorator
class ImageApiGenerator(ImageGeneratorBase):
    """Image generator for OpenAI-compatible HTTP APIs.

    Two request styles are supported, chosen from the configured endpoint path:

    * images endpoint (default ``/v1/images/generations``): the image is read
      from ``data[0].b64_json``, falling back to downloading ``data[0].url``;
    * chat endpoint (any path containing ``chat`` or ``completions``, e.g.
      ``/v1/chat/completions``): the image is extracted from the assistant
      message (Markdown image link, Markdown/bare data URL, or bare URL).
    """

    def __init__(self, config: Dict[str, Any]):
        """Read connection and generation defaults from ``config``.

        Recognised keys: ``base_url``, ``model``, ``default_aspect_ratio``,
        ``image_size`` and ``endpoint_type`` (a path, or the shorthands
        ``'images'`` / ``'chat'``). ``self.api_key`` is not set here; it is
        presumably populated by ``ImageGeneratorBase.__init__``.
        """
        super().__init__(config)
        logger.debug("初始化 ImageApiGenerator...")
        # ``or`` also covers keys that are present but None/empty.
        self.base_url = self._normalize_base_url(config.get('base_url') or 'https://api.example.com')
        self.model = config.get('model', 'default-model')
        self.default_aspect_ratio = config.get('default_aspect_ratio', '3:4')
        self.image_size = config.get('image_size', '4K')
        self.endpoint_type = self._normalize_endpoint(config.get('endpoint_type') or '/v1/images/generations')
        logger.info(f"ImageApiGenerator 初始化完成: base_url={self.base_url}, model={self.model}, endpoint={self.endpoint_type}")

    @staticmethod
    def _normalize_base_url(url: str) -> str:
        """Drop trailing slashes and a single trailing ``/v1`` path segment.

        The endpoint paths (by default) already start with ``/v1``. Note that
        ``str.rstrip('/v1')`` must not be used here: it strips any trailing run
        of the characters '/', 'v' and '1' (``https://ai.dev`` -> ``https://ai.de``).
        """
        url = url.rstrip('/')
        if url.endswith('/v1'):
            url = url[:-len('/v1')].rstrip('/')
        return url

    @staticmethod
    def _normalize_endpoint(endpoint: str) -> str:
        """Expand legacy shorthands and make sure the path starts with '/'."""
        # Legacy short forms kept for backward compatibility.
        aliases = {
            'images': '/v1/images/generations',
            'chat': '/v1/chat/completions',
        }
        endpoint = aliases.get(endpoint, endpoint)
        if not endpoint.startswith('/'):
            endpoint = '/' + endpoint
        return endpoint

    def validate_config(self) -> bool:
        """Return True if an API key is configured.

        Raises:
            ValueError: when ``self.api_key`` is empty.
        """
        if not self.api_key:
            logger.error("Image API Key 未配置")
            raise ValueError(
                "Image API Key 未配置。\n"
                "解决方案:在系统设置页面编辑该服务商,填写 API Key"
            )
        return True

    def get_supported_sizes(self) -> List[str]:
        """Return the image size labels this generator advertises."""
        return ["1K", "2K", "4K"]

    def get_supported_aspect_ratios(self) -> List[str]:
        """Return the aspect ratios this generator advertises."""
        return ["1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"]

    def generate_image(
        self,
        prompt: str,
        aspect_ratio: str = None,
        temperature: float = 1.0,
        model: str = None,
        reference_image: Optional[bytes] = None,
        reference_images: Optional[List[bytes]] = None,
        **kwargs
    ) -> bytes:
        """Generate one image and return its raw bytes.

        Args:
            prompt: Image description.
            aspect_ratio: Aspect ratio; defaults to ``self.default_aspect_ratio``.
            temperature: Unused; kept for interface compatibility.
            model: Model name; defaults to ``self.model``.
            reference_image: Single reference image (backward compatible).
            reference_images: List of reference images.
            **kwargs: Ignored.

        Returns:
            The generated image as bytes.

        Raises:
            ValueError: if no API key is configured (raised immediately, not retried).
            Exception: if the request still fails after the retries.
        """
        # Validate before the retry loop: a missing key will not fix itself,
        # so retrying it only adds back-off sleeps.
        self.validate_config()
        if aspect_ratio is None:
            aspect_ratio = self.default_aspect_ratio
        if model is None:
            model = self.model
        logger.info(f"Image API 生成图片: model={model}, aspect_ratio={aspect_ratio}, endpoint={self.endpoint_type}")
        return self._generate_with_retry(prompt, aspect_ratio, model, reference_image, reference_images)

    @retry_on_error(max_retries=3, base_delay=2)
    def _generate_with_retry(
        self,
        prompt: str,
        aspect_ratio: str,
        model: str,
        reference_image: Optional[bytes],
        reference_images: Optional[List[bytes]]
    ) -> bytes:
        """Dispatch to the chat or images flow; failures are retried by the decorator."""
        if 'chat' in self.endpoint_type or 'completions' in self.endpoint_type:
            return self._generate_via_chat_api(prompt, aspect_ratio, model, reference_image, reference_images)
        return self._generate_via_images_api(prompt, aspect_ratio, model, reference_image, reference_images)

    def _build_headers(self) -> Dict[str, str]:
        """Return the JSON + bearer-token headers shared by both endpoints."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    @staticmethod
    def _collect_reference_images(
        reference_image: Optional[bytes],
        reference_images: Optional[List[bytes]]
    ) -> List[bytes]:
        """Merge the list and single-image arguments, skipping a duplicate single image."""
        collected = list(reference_images or [])
        if reference_image and reference_image not in collected:
            collected.append(reference_image)
        return collected

    @staticmethod
    def _encode_reference_images(images: List[bytes]) -> List[str]:
        """Compress each image (max 200 KB) and return it as a base64 data URI."""
        data_uris = []
        for idx, img_data in enumerate(images):
            compressed_img = compress_image(img_data, max_size_kb=200)
            logger.debug(f" 参考图 {idx}: {len(img_data)} -> {len(compressed_img)} bytes")
            base64_image = base64.b64encode(compressed_img).decode('utf-8')
            # NOTE(review): the MIME type is always declared as image/png; confirm
            # compress_image actually emits PNG.
            data_uris.append(f"data:image/png;base64,{base64_image}")
        return data_uris

    @staticmethod
    def _decode_base64_image(value: str) -> bytes:
        """Decode a bare base64 string or a ``data:...;base64,`` URI."""
        if value.startswith('data:'):
            value = value.split(',', 1)[1]
        return base64.b64decode(value)

    def _generate_via_images_api(
        self,
        prompt: str,
        aspect_ratio: str,
        model: str,
        reference_image: Optional[bytes] = None,
        reference_images: Optional[List[bytes]] = None
    ) -> bytes:
        """Generate an image through an images endpoint (e.g. /v1/images/generations).

        Reference images are sent as data URIs in ``payload["image"]`` and the
        prompt is wrapped in style-transfer instructions.

        Raises:
            Exception: on a non-200 status or when no image can be extracted.
        """
        headers = self._build_headers()
        payload = {
            "model": model,
            "prompt": prompt,
            "response_format": "b64_json",
            "aspect_ratio": aspect_ratio,
            "image_size": self.image_size
        }
        all_reference_images = self._collect_reference_images(reference_image, reference_images)
        if all_reference_images:
            logger.debug(f" 添加 {len(all_reference_images)} 张参考图片")
            payload["image"] = self._encode_reference_images(all_reference_images)
            ref_count = len(all_reference_images)
            payload["prompt"] = "\n".join([
                f"参考提供的 {ref_count} 张图片的风格(色彩、光影、构图、氛围),生成一张新图片。",
                f"新图片内容:{prompt}",
                "要求:",
                "1. 保持相似的色调和氛围",
                "2. 使用相似的光影处理",
                "3. 保持一致的画面质感",
                "4. 如果参考图中有人物或产品,可以适当融入",
            ])

        api_url = f"{self.base_url}{self.endpoint_type}"
        logger.debug(f" 发送请求到: {api_url}")
        response = requests.post(api_url, headers=headers, json=payload, timeout=300)
        if response.status_code != 200:
            error_detail = response.text[:500]
            logger.error(f"Image API 请求失败: status={response.status_code}, error={error_detail}")
            raise Exception(
                f"Image API 请求失败 (状态码: {response.status_code})\n"
                f"错误详情: {error_detail}\n"
                f"请求地址: {api_url}\n"
                "可能原因:\n"
                "1. API密钥无效或已过期\n"
                "2. 请求参数不符合API要求\n"
                "3. API服务端错误\n"
                "4. Base URL配置错误\n"
                "建议:检查API密钥和base_url配置"
            )

        result = response.json()
        # Tolerate a missing/null "data" field or a non-dict body instead of crashing.
        items = (result.get("data") if isinstance(result, dict) else None) or []
        logger.debug(f" API 响应: data 长度={len(items)}")
        item = items[0] if items else None
        if isinstance(item, dict):
            b64_value = item.get("b64_json")
            if b64_value:
                image_data = self._decode_base64_image(b64_value)
                logger.info(f"✅ Image API 图片生成成功: {len(image_data)} bytes")
                return image_data
            # Some providers ignore response_format and return a URL instead.
            image_url = item.get("url")
            if image_url:
                logger.info("Image API 返回图片 URL,开始下载")
                return self._download_image(image_url)

        logger.error(f"无法从响应中提取图片数据: {str(result)[:200]}")
        raise Exception(
            f"图片数据提取失败:未找到 b64_json 数据。\n"
            f"API响应片段: {str(result)[:500]}\n"
            "可能原因:\n"
            "1. API返回格式与预期不符\n"
            "2. response_format 参数未生效\n"
            "3. 该模型不支持 b64_json 格式\n"
            "建议:检查API文档确认返回格式要求"
        )

    def _generate_via_chat_api(
        self,
        prompt: str,
        aspect_ratio: str,
        model: str,
        reference_image: Optional[bytes] = None,
        reference_images: Optional[List[bytes]] = None
    ) -> bytes:
        """Generate an image through a chat endpoint (e.g. /v1/chat/completions).

        ``aspect_ratio`` is accepted for signature parity but not sent. Reference
        images are attached as ``image_url`` content parts.

        Raises:
            Exception: on a non-200 status (specific messages for 401 and 429)
                or when the reply contains no recognisable image.
        """
        headers = self._build_headers()
        user_content: Any = prompt
        all_reference_images = self._collect_reference_images(reference_image, reference_images)
        if all_reference_images:
            logger.debug(f" 添加 {len(all_reference_images)} 张参考图片到 chat 消息")
            content_parts: List[Dict[str, Any]] = [{"type": "text", "text": prompt}]
            content_parts.extend(
                {"type": "image_url", "image_url": {"url": uri}}
                for uri in self._encode_reference_images(all_reference_images)
            )
            user_content = content_parts

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": user_content}],
            "max_tokens": 4096,
            "temperature": 1.0
        }
        api_url = f"{self.base_url}{self.endpoint_type}"
        logger.info(f"Chat API 生成图片: {api_url}, model={model}")
        response = requests.post(api_url, headers=headers, json=payload, timeout=300)
        if response.status_code != 200:
            error_detail = response.text[:500]
            status_code = response.status_code
            if status_code == 401:
                raise Exception(
                    "❌ API Key 认证失败\n\n"
                    "【可能原因】\n"
                    "1. API Key 无效或已过期\n"
                    "2. API Key 格式错误\n\n"
                    "【解决方案】\n"
                    "在系统设置页面检查 API Key 是否正确"
                )
            if status_code == 429:
                raise Exception(
                    "⏳ API 配额或速率限制\n\n"
                    "【解决方案】\n"
                    "1. 稍后再试\n"
                    "2. 检查 API 配额使用情况"
                )
            raise Exception(
                f"❌ Chat API 请求失败 (状态码: {status_code})\n\n"
                f"【错误详情】\n{error_detail[:300]}\n\n"
                f"【请求地址】{api_url}\n"
                f"【模型】{model}"
            )

        result = response.json()
        logger.debug(f"Chat API 响应: {str(result)[:500]}")
        image_data = self._extract_image_from_chat_result(result)
        if image_data is not None:
            return image_data
        raise Exception(
            "❌ 无法从 Chat API 响应中提取图片数据\n\n"
            f"【响应内容】\n{str(result)[:500]}\n\n"
            "【可能原因】\n"
            "1. 该模型不支持图片生成\n"
            "2. 响应格式与预期不符\n"
            "3. 提示词被安全过滤\n\n"
            "【解决方案】\n"
            "1. 确认模型名称正确\n"
            "2. 修改提示词后重试"
        )

    def _extract_image_from_chat_result(self, result: Any) -> Optional[bytes]:
        """Return image bytes found in ``choices[0].message.content``, else None.

        Checked in order: Markdown image with an http(s) URL (downloaded),
        Markdown image with a base64 data URL, a bare data URL, a bare URL.
        Only string content is inspected.
        """
        import re

        choices = result.get("choices") if isinstance(result, dict) else None
        if not choices or not isinstance(choices[0], dict):
            return None
        message = choices[0].get("message")
        content = message.get("content") if isinstance(message, dict) else None
        if not isinstance(content, str):
            return None

        # Markdown image link: ![alt](https://...)
        urls = re.findall(r'!\[.*?\]\((https?://[^\s\)]+)\)', content)
        if urls:
            logger.info(f"从 Markdown 提取到 {len(urls)} 张图片,下载第一张...")
            return self._download_image(urls[0])

        # Markdown image with inline base64: ![alt](data:image/png;base64,...)
        base64_urls = re.findall(r'!\[.*?\]\((data:image\/[^;]+;base64,[^\s\)]+)\)', content)
        if base64_urls:
            logger.info("从 Markdown 提取到 Base64 图片数据")
            return base64.b64decode(base64_urls[0].split(",")[1])

        # Bare data URL / URL; surrounding whitespace is ignored.
        stripped = content.strip()
        if stripped.startswith("data:image"):
            logger.info("检测到 Base64 图片数据")
            return base64.b64decode(stripped.split(",")[1])
        if stripped.startswith(("http://", "https://")):
            logger.info("检测到图片 URL")
            return self._download_image(stripped)
        return None

    def _download_image(self, url: str) -> bytes:
        """Download ``url`` (60 s timeout) and return the response body.

        Raises:
            Exception: on timeout, transport error, or a non-200 status.
        """
        logger.info(f"下载图片: {url[:100]}...")
        try:
            response = requests.get(url, timeout=60)
        except requests.exceptions.Timeout as e:
            raise Exception("❌ 下载图片超时,请重试") from e
        except requests.exceptions.RequestException as e:
            raise Exception(f"❌ 下载图片失败: {str(e)}") from e
        # Checked outside the try so this error is not re-wrapped
        # (previously: "❌ 下载图片失败: 下载图片失败: HTTP 404").
        if response.status_code != 200:
            raise Exception(f"❌ 下载图片失败: HTTP {response.status_code}")
        logger.info(f"✅ 图片下载成功: {len(response.content)} bytes")
        return response.content
|