Spaces:
Paused
Paused
| import asyncio | |
| import concurrent.futures | |
| from typing import List, Dict, Any, Optional, Union | |
| import json | |
| import os | |
| import base64 | |
| import tempfile | |
| import uuid | |
| # 导入原有的SoraImageGenerator类 | |
| from .sora_generator import SoraImageGenerator | |
| class SoraClient: | |
| def __init__(self, proxy_host=None, proxy_port=None, proxy_user=None, proxy_pass=None, auth_token=None): | |
| """初始化Sora客户端,使用cloudscraper绕过CF验证""" | |
| self.generator = SoraImageGenerator( | |
| proxy_host=proxy_host, | |
| proxy_port=proxy_port, | |
| proxy_user=proxy_user, | |
| proxy_pass=proxy_pass, | |
| auth_token=auth_token | |
| ) | |
| # 保存原始的auth_token,用于检测是否已更新 | |
| self.auth_token = auth_token | |
| # 创建线程池执行器 | |
| self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=10) | |
| async def generate_image(self, prompt: str, num_images: int = 1, | |
| width: int = 720, height: int = 480) -> List[str]: | |
| """异步包装SoraImageGenerator.generate_image方法""" | |
| loop = asyncio.get_running_loop() | |
| # 使用线程池执行同步方法(因为cloudscraper不是异步的) | |
| result = await loop.run_in_executor( | |
| self.executor, | |
| lambda: self.generator.generate_image(prompt, num_images, width, height) | |
| ) | |
| # 检查generator中的auth_token是否已经被更新(由自动切换密钥机制) | |
| if self.generator.auth_token != self.auth_token: | |
| self.auth_token = self.generator.auth_token | |
| if isinstance(result, list): | |
| return result | |
| else: | |
| raise Exception(f"图像生成失败: {result}") | |
| async def upload_image(self, image_path: str) -> Dict: | |
| """异步包装上传图片方法""" | |
| loop = asyncio.get_running_loop() | |
| result = await loop.run_in_executor( | |
| self.executor, | |
| lambda: self.generator.upload_image(image_path) | |
| ) | |
| # 检查generator中的auth_token是否已经被更新 | |
| if self.generator.auth_token != self.auth_token: | |
| self.auth_token = self.generator.auth_token | |
| if isinstance(result, dict) and 'id' in result: | |
| return result | |
| else: | |
| raise Exception(f"图片上传失败: {result}") | |
| async def generate_image_remix(self, prompt: str, media_id: str, | |
| num_images: int = 1) -> List[str]: | |
| """异步包装remix方法""" | |
| loop = asyncio.get_running_loop() | |
| # 处理可能包含API密钥信息的media_id对象 | |
| if isinstance(media_id, dict) and 'id' in media_id: | |
| # 如果上传时使用的密钥与当前不同,则先切换密钥 | |
| if 'used_auth_token' in media_id and media_id['used_auth_token'] != self.auth_token: | |
| self.auth_token = media_id['used_auth_token'] | |
| # 同步更新generator的auth_token | |
| self.generator.auth_token = self.auth_token | |
| # 提取实际的media_id | |
| media_id = media_id['id'] | |
| result = await loop.run_in_executor( | |
| self.executor, | |
| lambda: self.generator.generate_image_remix(prompt, media_id, num_images) | |
| ) | |
| # 检查generator中的auth_token是否已经被更新 | |
| if self.generator.auth_token != self.auth_token: | |
| self.auth_token = self.generator.auth_token | |
| if isinstance(result, list): | |
| return result | |
| else: | |
| raise Exception(f"Remix生成失败: {result}") | |
| async def test_connection(self) -> Dict: | |
| """测试API连接是否有效""" | |
| try: | |
| # 简单测试上传功能,这个方法会调用API但不会真正上传文件 | |
| loop = asyncio.get_running_loop() | |
| result = await loop.run_in_executor( | |
| self.executor, | |
| lambda: self.generator.test_connection() | |
| ) | |
| # 检查generator中的auth_token是否已经被更新 | |
| if self.generator.auth_token != self.auth_token: | |
| self.auth_token = self.generator.auth_token | |
| # 直接返回generator.test_connection的结果,保留所有信息 | |
| return result | |
| except Exception as e: | |
| return {"status": "error", "message": f"API连接测试失败: {str(e)}"} | |
| def close(self): | |
| """关闭线程池""" | |
| self.executor.shutdown(wait=False) |