m19921414377's picture
Upload folder using huggingface_hub
6db48b4 verified
"""图片生成服务"""
import logging
import os
import uuid
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any, Generator, List, Optional, Tuple
from backend.config import Config
from backend.generators.factory import ImageGeneratorFactory
from backend.utils.image_compressor import compress_image
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class ImageService:
    """Image generation service.

    Wraps a provider-specific generator created by ``ImageGeneratorFactory``,
    saves every generated image (plus a thumbnail) into a per-task folder
    below ``history/``, and keeps per-task state in memory so that failed
    pages can be retried later.
    """

    # Concurrency settings
    MAX_CONCURRENT = 15  # worker threads used for parallel generation
    AUTO_RETRY_COUNT = 3  # attempts per image before reporting a failure

    def __init__(self, provider_name: Optional[str] = None):
        """Initialise the service for one image provider.

        Args:
            provider_name: Provider to use. When ``None`` the active provider
                reported by ``Config.get_active_image_provider()`` is used.
        """
        logger.debug("初始化 ImageService...")

        # Resolve the provider configuration.
        if provider_name is None:
            provider_name = Config.get_active_image_provider()
        logger.info(f"使用图片服务商: {provider_name}")
        provider_config = Config.get_image_provider_config(provider_name)

        # Build the generator; the provider name doubles as the type when the
        # configuration has no explicit 'type' entry.
        provider_type = provider_config.get('type', provider_name)
        logger.debug(f"创建生成器: type={provider_type}")
        self.generator = ImageGeneratorFactory.create(provider_type, provider_config)

        self.provider_name = provider_name
        self.provider_config = provider_config

        # Short-prompt mode sends only the page type and content.
        self.use_short_prompt = provider_config.get('short_prompt', False)

        # Prompt templates (an empty string means "template file missing").
        self.prompt_template = self._load_prompt_template()
        self.prompt_template_short = self._load_prompt_template(short=True)

        # Root folder for all task outputs: three directory levels above this
        # file, in "history".
        self.history_root_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "history"
        )
        os.makedirs(self.history_root_dir, exist_ok=True)

        # Directory of the most recently started task. Kept for backward
        # compatibility; generation itself resolves the directory from the
        # task id so that overlapping tasks cannot write into each other.
        self.current_task_dir = None

        # In-memory task state, used by the retry methods.
        self._task_states: Dict[str, Dict] = {}

        logger.info(f"ImageService 初始化完成: provider={provider_name}, type={provider_type}")

    def _load_prompt_template(self, short: bool = False) -> str:
        """Load a prompt template from the ``prompts`` directory.

        Args:
            short: Load ``image_prompt_short.txt`` instead of
                ``image_prompt.txt``.

        Returns:
            The template text, or ``""`` when the file does not exist.
        """
        filename = "image_prompt_short.txt" if short else "image_prompt.txt"
        prompt_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "prompts",
            filename
        )
        if not os.path.exists(prompt_path):
            # A missing template is not an error; callers treat "" as absent.
            return ""
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read()

    def _task_dir(self, task_id: str) -> str:
        """Return the output directory path for ``task_id`` (not created)."""
        return os.path.join(self.history_root_dir, task_id)

    def _prepare_task_dir(self, task_id: str) -> str:
        """Create the directory for ``task_id`` and record it as current.

        Returns:
            The task directory path.
        """
        task_dir = self._task_dir(task_id)
        os.makedirs(task_dir, exist_ok=True)
        self.current_task_dir = task_dir
        return task_dir

    def _save_image(self, image_data: bytes, filename: str, task_dir: Optional[str] = None) -> str:
        """Write an image and its thumbnail to disk.

        The thumbnail is produced with ``compress_image(..., max_size_kb=50)``
        and stored next to the original as ``thumb_<filename>``.

        Args:
            image_data: Raw image bytes.
            filename: File name of the original image.
            task_dir: Target directory; falls back to
                ``self.current_task_dir`` when ``None``.

        Returns:
            Path of the saved original image.

        Raises:
            ValueError: If no task directory is available.
        """
        if task_dir is None:
            task_dir = self.current_task_dir
        if task_dir is None:
            raise ValueError("任务目录未设置")

        # The directory may not exist yet when it was derived from a task id.
        os.makedirs(task_dir, exist_ok=True)

        # Original image.
        filepath = os.path.join(task_dir, filename)
        with open(filepath, "wb") as f:
            f.write(image_data)

        # Thumbnail.
        thumbnail_data = compress_image(image_data, max_size_kb=50)
        thumbnail_path = os.path.join(task_dir, f"thumb_{filename}")
        with open(thumbnail_path, "wb") as f:
            f.write(thumbnail_data)

        return filepath

    def _build_prompt(
        self,
        page_type: str,
        page_content: str,
        full_outline: str,
        user_topic: str
    ) -> str:
        """Fill in the prompt template selected by the provider configuration."""
        if self.use_short_prompt and self.prompt_template_short:
            # Short mode: page type and content only.
            prompt = self.prompt_template_short.format(
                page_content=page_content,
                page_type=page_type
            )
            logger.debug(f" 使用短 prompt 模式 ({len(prompt)} 字符)")
            return prompt

        # Full mode: also carries the outline and the user's original request.
        return self.prompt_template.format(
            page_content=page_content,
            page_type=page_type,
            full_outline=full_outline,
            user_topic=user_topic if user_topic else "未提供"
        )

    def _call_generator(
        self,
        prompt: str,
        reference_image: Optional[bytes],
        user_images: Optional[List[bytes]]
    ):
        """Invoke the generator with the arguments its provider type expects.

        Returns:
            Whatever ``self.generator.generate_image`` returns (presumably raw
            image bytes, since the result is written to disk as-is — confirm
            against the generator implementations).
        """
        config = self.provider_config
        provider_type = config.get('type')

        if provider_type == 'google_genai':
            logger.debug(" 使用 Google GenAI 生成器")
            return self.generator.generate_image(
                prompt=prompt,
                aspect_ratio=config.get('default_aspect_ratio', '3:4'),
                temperature=config.get('temperature', 1.0),
                model=config.get('model', 'gemini-3-pro-image-preview'),
                reference_image=reference_image,
            )

        if provider_type == 'image_api':
            logger.debug(" 使用 Image API 生成器")
            # This branch accepts several references: user uploads first,
            # then the cover image.
            reference_images = list(user_images) if user_images else []
            if reference_image:
                reference_images.append(reference_image)
            return self.generator.generate_image(
                prompt=prompt,
                aspect_ratio=config.get('default_aspect_ratio', '3:4'),
                temperature=config.get('temperature', 1.0),
                model=config.get('model', 'nano-banana-2'),
                reference_images=reference_images if reference_images else None,
            )

        logger.debug(" 使用 OpenAI 兼容生成器")
        return self.generator.generate_image(
            prompt=prompt,
            size=config.get('default_size', '1024x1024'),
            model=config.get('model'),
            quality=config.get('quality', 'standard'),
        )

    def _generate_single_image(
        self,
        page: Dict,
        task_id: str,
        reference_image: Optional[bytes] = None,
        retry_count: int = 0,
        full_outline: str = "",
        user_images: Optional[List[bytes]] = None,
        user_topic: str = ""
    ) -> Tuple[int, bool, Optional[str], Optional[str]]:
        """Generate one image, retrying automatically on failure.

        Up to ``AUTO_RETRY_COUNT`` attempts are made, sleeping ``2 ** attempt``
        seconds between them. The image is saved as ``<index>.png`` in the
        directory derived from ``task_id``.

        Args:
            page: Page data; must contain ``index``, ``type`` and ``content``.
            task_id: Task identifier, used to resolve the output directory.
            reference_image: Reference image (the cover), if any.
            retry_count: Unused; kept for interface compatibility.
            full_outline: Full outline text.
            user_images: Reference images uploaded by the user.
            user_topic: The user's original input.

        Returns:
            ``(index, success, filename, error_message)``.
        """
        index = page["index"]
        page_type = page["type"]
        page_content = page["content"]
        max_retries = self.AUTO_RETRY_COUNT

        # Resolve the directory from the task id instead of the shared
        # ``current_task_dir`` so that overlapping tasks and worker threads
        # cannot end up writing into another task's folder.
        task_dir = self._task_dir(task_id) if task_id else self.current_task_dir

        # The prompt does not change between attempts and a template error is
        # deterministic, so build it once and fail fast.
        try:
            prompt = self._build_prompt(page_type, page_content, full_outline, user_topic)
        except Exception as e:
            error_msg = str(e)
            logger.error(f"❌ 图片 [{index}] Prompt 构建失败: {error_msg[:200]}")
            return (index, False, None, error_msg)

        for attempt in range(max_retries):
            try:
                logger.debug(f"生成图片 [{index}]: type={page_type}, attempt={attempt + 1}/{max_retries}")
                image_data = self._call_generator(prompt, reference_image, user_images)

                filename = f"{index}.png"
                self._save_image(image_data, filename, task_dir)
                logger.info(f"✅ 图片 [{index}] 生成成功: {filename}")
                return (index, True, filename, None)
            except Exception as e:
                error_msg = str(e)
                logger.warning(f"图片 [{index}] 生成失败 (尝试 {attempt + 1}/{max_retries}): {error_msg[:200]}")
                if attempt < max_retries - 1:
                    # Exponential back-off before the next attempt.
                    wait_time = 2 ** attempt
                    logger.debug(f" 等待 {wait_time} 秒后重试...")
                    time.sleep(wait_time)
                    continue
                logger.error(f"❌ 图片 [{index}] 生成失败,已达最大重试次数")
                return (index, False, None, error_msg)

        # Only reached when AUTO_RETRY_COUNT is 0.
        return (index, False, None, "超过最大重试次数")

    @staticmethod
    def _split_cover_page(pages: list) -> Tuple[Optional[Dict], list]:
        """Split ``pages`` into ``(cover_page, other_pages)``.

        The first page whose type is ``"cover"`` becomes the cover; when none
        exists the first page is used. Every remaining page, including any
        additional cover-typed page, is returned in ``other_pages`` so that no
        page is dropped.
        """
        cover_page = None
        other_pages = []
        for page in pages:
            if cover_page is None and page["type"] == "cover":
                cover_page = page
            else:
                other_pages.append(page)
        if cover_page is None and pages:
            cover_page = pages[0]
            other_pages = list(pages[1:])
        return cover_page, other_pages

    @staticmethod
    def _complete_event(task_id: str, index: int, filename: str, phase: Optional[str] = None) -> Dict[str, Any]:
        """Build a ``complete`` event; ``phase`` is included only when given."""
        data = {
            "index": index,
            "status": "done",
            "image_url": f"/api/images/{task_id}/{filename}"
        }
        if phase is not None:
            data["phase"] = phase
        return {"event": "complete", "data": data}

    @staticmethod
    def _error_event(index: int, message: Optional[str], phase: Optional[str] = None) -> Dict[str, Any]:
        """Build a retryable ``error`` event; ``phase`` is included only when given."""
        data = {
            "index": index,
            "status": "error",
            "message": message,
            "retryable": True
        }
        if phase is not None:
            data["phase"] = phase
        return {"event": "error", "data": data}

    def generate_images(
        self,
        pages: list,
        task_id: str = None,
        full_outline: str = "",
        user_images: Optional[List[bytes]] = None,
        user_topic: str = ""
    ) -> Generator[Dict[str, Any], None, None]:
        """Generate all images of a task, yielding progress events.

        The cover is generated first; its compressed image is then passed as
        the reference for the remaining pages, which run either in a thread
        pool (provider option ``high_concurrency``) or one after another.

        Args:
            pages: Page dicts, each with ``index``, ``type`` and ``content``.
            task_id: Task identifier; a random one is created when ``None``.
            full_outline: Full outline text.
            user_images: Reference images uploaded by the user (optional).
            user_topic: The user's original input.

        Yields:
            Event dicts with ``event`` set to ``progress``, ``complete``,
            ``error`` and finally ``finish``.
        """
        if task_id is None:
            task_id = f"task_{uuid.uuid4().hex[:8]}"
        logger.info(f"开始图片生成任务: task_id={task_id}, pages={len(pages)}")

        # Per-task output directory.
        task_dir = self._prepare_task_dir(task_id)
        logger.debug(f"任务目录: {task_dir}")

        total = len(pages)
        generated_images: List[str] = []
        failed_pages: List[Dict] = []
        cover_image_data = None

        # Compress user uploads once (200 KB budget per image).
        compressed_user_images = None
        if user_images:
            compressed_user_images = [compress_image(img, max_size_kb=200) for img in user_images]

        # Task state, kept for later retries.
        state = {
            "pages": pages,
            "generated": {},
            "failed": {},
            "cover_image": None,
            "full_outline": full_outline,
            "user_images": compressed_user_images,
            "user_topic": user_topic
        }
        self._task_states[task_id] = state

        def record(page, index, success, filename, error, phase):
            """Store one page result and return the event describing it."""
            if success:
                generated_images.append(filename)
                state["generated"][index] = filename
                return self._complete_event(task_id, index, filename, phase)
            failed_pages.append(page)
            state["failed"][index] = error
            return self._error_event(index, error, phase)

        def generate(page):
            """Generate one content page using the cover as reference."""
            return self._generate_single_image(
                page,
                task_id,
                cover_image_data,
                0,
                full_outline,
                compressed_user_images,
                user_topic
            )

        # ==================== Phase 1: cover ====================
        cover_page, other_pages = self._split_cover_page(pages)

        if cover_page:
            yield {
                "event": "progress",
                "data": {
                    "index": cover_page["index"],
                    "status": "generating",
                    "message": "正在生成封面...",
                    "current": 1,
                    "total": total,
                    "phase": "cover"
                }
            }

            # The cover has no cover reference; only user uploads are passed.
            index, success, filename, error = self._generate_single_image(
                cover_page, task_id, reference_image=None, full_outline=full_outline,
                user_images=compressed_user_images, user_topic=user_topic
            )
            event = record(cover_page, index, success, filename, error, "cover")
            if success:
                # Re-read the saved cover and compress it for use as reference.
                with open(os.path.join(task_dir, filename), "rb") as f:
                    cover_image_data = compress_image(f.read(), max_size_kb=200)
                state["cover_image"] = cover_image_data
            yield event

        # ==================== Phase 2: remaining pages ====================
        if other_pages:
            high_concurrency = self.provider_config.get('high_concurrency', False)

            if high_concurrency:
                yield {
                    "event": "progress",
                    "data": {
                        "status": "batch_start",
                        "message": f"开始并发生成 {len(other_pages)} 页内容...",
                        "current": len(generated_images),
                        "total": total,
                        "phase": "content"
                    }
                }

                with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT) as executor:
                    future_to_page = {
                        executor.submit(generate, page): page
                        for page in other_pages
                    }

                    # Announce every page up front; they all start together.
                    for page in other_pages:
                        yield {
                            "event": "progress",
                            "data": {
                                "index": page["index"],
                                "status": "generating",
                                "current": len(generated_images) + 1,
                                "total": total,
                                "phase": "content"
                            }
                        }

                    # Report results in completion order.
                    for future in as_completed(future_to_page):
                        page = future_to_page[future]
                        try:
                            index, success, filename, error = future.result()
                        except Exception as e:
                            # The worker raised outside its own retry loop.
                            index, success, filename, error = page["index"], False, None, str(e)
                        yield record(page, index, success, filename, error, "content")
            else:
                yield {
                    "event": "progress",
                    "data": {
                        "status": "batch_start",
                        "message": f"开始顺序生成 {len(other_pages)} 页内容...",
                        "current": len(generated_images),
                        "total": total,
                        "phase": "content"
                    }
                }

                for page in other_pages:
                    yield {
                        "event": "progress",
                        "data": {
                            "index": page["index"],
                            "status": "generating",
                            "current": len(generated_images) + 1,
                            "total": total,
                            "phase": "content"
                        }
                    }
                    index, success, filename, error = generate(page)
                    yield record(page, index, success, filename, error, "content")

        # ==================== Done ====================
        yield {
            "event": "finish",
            "data": {
                "success": len(failed_pages) == 0,
                "task_id": task_id,
                "images": generated_images,
                "total": total,
                "completed": len(generated_images),
                "failed": len(failed_pages),
                "failed_indices": [p["index"] for p in failed_pages]
            }
        }

    def retry_single_image(
        self,
        task_id: str,
        page: Dict,
        use_reference: bool = True,
        full_outline: str = "",
        user_topic: str = ""
    ) -> Dict[str, Any]:
        """Retry generating a single image.

        Context that is not passed in is taken from the stored task state.
        When no cover image is held in memory, ``0.png`` in the task
        directory is loaded and compressed as the reference instead.

        Args:
            task_id: Task identifier.
            page: Page data.
            use_reference: Whether to use the cover as reference image.
            full_outline: Full outline text (overrides the stored one).
            user_topic: User's original input (overrides the stored one).

        Returns:
            A dict with ``success`` and ``index``, plus ``image_url`` on
            success or ``error`` and ``retryable`` on failure.
        """
        task_dir = self._prepare_task_dir(task_id)

        reference_image = None
        user_images = None

        # Prefer the context kept in the task state.
        task_state = self._task_states.get(task_id)
        if task_state is not None:
            if use_reference:
                reference_image = task_state.get("cover_image")
            if not full_outline:
                full_outline = task_state.get("full_outline", "")
            if not user_topic:
                user_topic = task_state.get("user_topic", "")
            user_images = task_state.get("user_images")

        # Fall back to the cover stored on disk.
        if use_reference and reference_image is None:
            cover_path = os.path.join(task_dir, "0.png")
            if os.path.exists(cover_path):
                with open(cover_path, "rb") as f:
                    cover_data = f.read()
                reference_image = compress_image(cover_data, max_size_kb=200)

        index, success, filename, error = self._generate_single_image(
            page,
            task_id,
            reference_image,
            0,
            full_outline,
            user_images,
            user_topic
        )

        if not success:
            return {
                "success": False,
                "index": index,
                "error": error,
                "retryable": True
            }

        # Look the state up again: it may have been cleaned up meanwhile.
        task_state = self._task_states.get(task_id)
        if task_state is not None:
            task_state["generated"][index] = filename
            task_state["failed"].pop(index, None)
        return {
            "success": True,
            "index": index,
            "image_url": f"/api/images/{task_id}/{filename}"
        }

    def retry_failed_images(
        self,
        task_id: str,
        pages: List[Dict]
    ) -> Generator[Dict[str, Any], None, None]:
        """Retry a batch of failed pages in parallel.

        The cover image, outline, user images and user topic are taken from
        the stored task state when it exists.

        Args:
            task_id: Task identifier.
            pages: Pages to retry.

        Yields:
            ``retry_start``, then one ``complete`` or ``error`` event per
            page, then ``retry_finish``.
        """
        # Make sure the output directory exists and is the current one.
        self._prepare_task_dir(task_id)

        # Restore the same context a single retry would use.
        task_state = self._task_states.get(task_id) or {}
        reference_image = task_state.get("cover_image")
        full_outline = task_state.get("full_outline", "")
        user_images = task_state.get("user_images")
        user_topic = task_state.get("user_topic", "")

        total = len(pages)
        success_count = 0
        failed_count = 0

        yield {
            "event": "retry_start",
            "data": {
                "total": total,
                "message": f"开始重试 {total} 张失败的图片"
            }
        }

        with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT) as executor:
            future_to_page = {
                executor.submit(
                    self._generate_single_image,
                    page,
                    task_id,
                    reference_image,
                    0,  # retry_count
                    full_outline,
                    user_images,
                    user_topic
                ): page
                for page in pages
            }

            for future in as_completed(future_to_page):
                page = future_to_page[future]
                try:
                    index, success, filename, error = future.result()
                except Exception as e:
                    # The worker raised outside its own retry loop.
                    index, success, filename, error = page["index"], False, None, str(e)

                current_state = self._task_states.get(task_id)
                if success:
                    success_count += 1
                    if current_state is not None:
                        current_state["generated"][index] = filename
                        current_state["failed"].pop(index, None)
                    yield self._complete_event(task_id, index, filename)
                else:
                    failed_count += 1
                    if current_state is not None:
                        # Keep the most recent error message for this page.
                        current_state["failed"][index] = error
                    yield self._error_event(index, error)

        yield {
            "event": "retry_finish",
            "data": {
                "success": failed_count == 0,
                "total": total,
                "completed": success_count,
                "failed": failed_count
            }
        }

    def regenerate_image(
        self,
        task_id: str,
        page: Dict,
        use_reference: bool = True,
        full_outline: str = "",
        user_topic: str = ""
    ) -> Dict[str, Any]:
        """Regenerate an image on user request, even if it succeeded before.

        Delegates to ``retry_single_image`` with the same arguments.

        Args:
            task_id: Task identifier.
            page: Page data.
            use_reference: Whether to use the cover as reference image.
            full_outline: Full outline text.
            user_topic: User's original input.

        Returns:
            The result dict of ``retry_single_image``.
        """
        return self.retry_single_image(
            task_id, page, use_reference,
            full_outline=full_outline,
            user_topic=user_topic
        )

    def get_image_path(self, task_id: str, filename: str) -> str:
        """Return the full path of an image inside a task directory.

        Args:
            task_id: Task identifier.
            filename: Image file name.

        Returns:
            ``<history_root_dir>/<task_id>/<filename>``.
        """
        # NOTE(review): the parts are joined without validation. If they come
        # from a request URL, the caller should reject path traversal ("..",
        # absolute paths) — confirm against the route handler.
        return os.path.join(self._task_dir(task_id), filename)

    def get_task_state(self, task_id: str) -> Optional[Dict]:
        """Return the stored state of a task, or ``None`` if unknown."""
        return self._task_states.get(task_id)

    def cleanup_task(self, task_id: str):
        """Drop the stored state of a task to free memory; no-op if unknown."""
        self._task_states.pop(task_id, None)
# Global service instance; stays None until get_image_service() creates it.
_service_instance = None
def get_image_service() -> ImageService:
    """Return the shared ImageService, building it on first access."""
    global _service_instance
    if _service_instance is not None:
        return _service_instance
    # First call (or first call after a reset): create and cache the service.
    _service_instance = ImageService()
    return _service_instance
def reset_image_service():
    """Discard the shared service instance (call after a configuration change).

    The next ``get_image_service()`` call builds a fresh instance.
    """
    global _service_instance
    _service_instance = None