Spaces:
Paused
Paused
| import asyncio | |
| import base64 | |
| import os | |
| import tempfile | |
| import time | |
| import uuid | |
| import logging | |
| import threading | |
| from typing import List, Dict, Any, Optional, Union, Tuple | |
| from ..sora_integration import SoraClient | |
| from ..config import Config | |
| from ..utils import localize_image_urls | |
| logger = logging.getLogger("sora-api.image_service") | |
| # 存储生成结果的全局字典 | |
| generation_results = {} | |
| # 存储任务与API密钥的映射关系 | |
| task_to_api_key = {} | |
| # 将处理中状态消息格式化为think代码块 | |
| def format_think_block(message: str) -> str: | |
| """将消息放入```think代码块中""" | |
| return f"```think\n{message}\n```" | |
| async def process_image_task( | |
| request_id: str, | |
| sora_client: SoraClient, | |
| task_type: str, | |
| prompt: str, | |
| **kwargs | |
| ) -> None: | |
| """ | |
| 统一的图像处理任务函数 | |
| Args: | |
| request_id: 请求ID | |
| sora_client: Sora客户端实例 | |
| task_type: 任务类型 ("generation" 或 "remix") | |
| prompt: 提示词 | |
| **kwargs: 其他参数,取决于任务类型 | |
| """ | |
| try: | |
| # 保存当前任务使用的API密钥,以便后续使用同一密钥进行操作 | |
| current_api_key = sora_client.auth_token | |
| task_to_api_key[request_id] = current_api_key | |
| # 更新状态为处理中 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在准备生成任务,请稍候..."), | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key # 记录使用的API密钥 | |
| } | |
| # 根据任务类型执行不同操作 | |
| if task_type == "generation": | |
| # 文本到图像生成 | |
| num_images = kwargs.get("num_images", 1) | |
| width = kwargs.get("width", 720) | |
| height = kwargs.get("height", 480) | |
| # 更新状态 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在生成图像,请耐心等待..."), | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key | |
| } | |
| # 生成图像 | |
| logger.info(f"[{request_id}] 开始生成图像, 提示词: {prompt}") | |
| image_urls = await sora_client.generate_image( | |
| prompt=prompt, | |
| num_images=num_images, | |
| width=width, | |
| height=height | |
| ) | |
| elif task_type == "remix": | |
| # 图像到图像生成(Remix) | |
| image_data = kwargs.get("image_data") | |
| num_images = kwargs.get("num_images", 1) | |
| if not image_data: | |
| raise ValueError("缺少图像数据") | |
| # 更新状态 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在处理上传的图片..."), | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key | |
| } | |
| # 保存base64图片到临时文件 | |
| temp_dir = tempfile.mkdtemp() | |
| temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png") | |
| try: | |
| # 解码并保存图片 | |
| with open(temp_image_path, "wb") as f: | |
| f.write(base64.b64decode(image_data)) | |
| # 更新状态 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在上传图片到Sora服务..."), | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key | |
| } | |
| # 上传图片 - 确保使用与初始请求相同的API密钥 | |
| upload_result = await sora_client.upload_image(temp_image_path) | |
| media_id = upload_result['id'] | |
| # 更新状态 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在基于图片生成新图像..."), | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key | |
| } | |
| # 执行remix生成 | |
| logger.info(f"[{request_id}] 开始生成Remix图像, 提示词: {prompt}") | |
| image_urls = await sora_client.generate_image_remix( | |
| prompt=prompt, | |
| media_id=media_id, | |
| num_images=num_images | |
| ) | |
| finally: | |
| # 清理临时文件 | |
| if os.path.exists(temp_image_path): | |
| os.remove(temp_image_path) | |
| if os.path.exists(temp_dir): | |
| os.rmdir(temp_dir) | |
| else: | |
| raise ValueError(f"未知的任务类型: {task_type}") | |
| # 验证生成结果 | |
| if isinstance(image_urls, str): | |
| logger.warning(f"[{request_id}] 图像生成失败或返回了错误信息: {image_urls}") | |
| generation_results[request_id] = { | |
| "status": "failed", | |
| "error": image_urls, | |
| "message": format_think_block(f"图像生成失败: {image_urls}"), | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key | |
| } | |
| return | |
| if not image_urls: | |
| logger.warning(f"[{request_id}] 图像生成返回了空列表") | |
| generation_results[request_id] = { | |
| "status": "failed", | |
| "error": "图像生成返回了空结果", | |
| "message": format_think_block("图像生成失败: 服务器返回了空结果"), | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key | |
| } | |
| return | |
| logger.info(f"[{request_id}] 成功生成 {len(image_urls)} 张图片") | |
| # 本地化图片URL | |
| if Config.IMAGE_LOCALIZATION: | |
| logger.info(f"[{request_id}] 准备进行图片本地化处理") | |
| try: | |
| localized_urls = await localize_image_urls(image_urls) | |
| logger.info(f"[{request_id}] 图片本地化处理完成") | |
| # 检查本地化结果 | |
| if not localized_urls: | |
| logger.warning(f"[{request_id}] 本地化处理返回了空列表,将使用原始URL") | |
| localized_urls = image_urls | |
| # 检查是否所有URL都被正确本地化 | |
| local_count = sum(1 for url in localized_urls if url.startswith("/static/") or "/static/" in url) | |
| logger.info(f"[{request_id}] 本地化结果: 总计 {len(localized_urls)} 张图片,成功本地化 {local_count} 张") | |
| if local_count == 0: | |
| logger.warning(f"[{request_id}] 警告:没有一个URL被成功本地化,将使用原始URL") | |
| localized_urls = image_urls | |
| image_urls = localized_urls | |
| except Exception as e: | |
| logger.error(f"[{request_id}] 图片本地化过程中发生错误: {str(e)}", exc_info=True) | |
| logger.info(f"[{request_id}] 由于错误,将使用原始URL") | |
| else: | |
| logger.info(f"[{request_id}] 图片本地化功能未启用,使用原始URL") | |
| # 存储结果 | |
| generation_results[request_id] = { | |
| "status": "completed", | |
| "image_urls": image_urls, | |
| "timestamp": int(time.time()), | |
| "api_key": current_api_key | |
| } | |
| # 30分钟后自动清理结果 | |
| def cleanup_task(): | |
| generation_results.pop(request_id, None) | |
| task_to_api_key.pop(request_id, None) | |
| threading.Timer(1800, cleanup_task).start() | |
| except Exception as e: | |
| error_message = f"图像生成失败: {str(e)}" | |
| generation_results[request_id] = { | |
| "status": "failed", | |
| "error": error_message, | |
| "message": format_think_block(error_message), | |
| "timestamp": int(time.time()), | |
| "api_key": sora_client.auth_token # 记录当前API密钥 | |
| } | |
| logger.error(f"图像生成失败 (ID: {request_id}): {str(e)}", exc_info=True) | |
| def get_generation_result(request_id: str) -> Dict[str, Any]: | |
| """获取生成结果""" | |
| if request_id not in generation_results: | |
| return { | |
| "status": "not_found", | |
| "error": f"找不到生成任务: {request_id}", | |
| "timestamp": int(time.time()) | |
| } | |
| return generation_results[request_id] | |
| def get_task_api_key(request_id: str) -> Optional[str]: | |
| """获取任务对应的API密钥""" | |
| return task_to_api_key.get(request_id) |