"""FAL API client - simplified, synchronous implementation.

Replaces a complicated async design with fal_client's synchronous client.
Every httpx request made in this process is monkey-patched to carry the
``X-Fal-Store-IO: 0`` header (see ``_patched_request`` below).
"""
import os
import uuid
import tempfile
import base64
import json
import io
from typing import Dict, Any, Optional, List

from PIL import Image, ImageOps

# Monkey patch: inject the X-Fal-Store-IO header into every httpx request.
# The patch is installed at import time, before ``fal_client`` is imported.
import httpx

_STORE_IO_HEADER = 'X-Fal-Store-IO'
_STORE_IO_VALUE = '0'

_original_request = httpx.Client.request
_original_async_request = httpx.AsyncClient.request


def _inject_store_io_header(kwargs: Dict[str, Any]) -> None:
    """Set the privacy header on a *copy* of ``kwargs['headers']``.

    The caller's headers object is copied rather than mutated (it may be a
    dict reused across requests), ``headers=None`` (httpx's default) is
    handled, and an existing header with different casing is replaced
    instead of duplicated.
    """
    headers = httpx.Headers(kwargs.get('headers'))
    headers[_STORE_IO_HEADER] = _STORE_IO_VALUE
    kwargs['headers'] = headers


def _patched_request(self, *args, **kwargs):
    """Sync ``httpx.Client.request`` replacement that adds the privacy header."""
    _inject_store_io_header(kwargs)
    return _original_request(self, *args, **kwargs)


async def _patched_async_request(self, *args, **kwargs):
    """Async ``httpx.AsyncClient.request`` replacement that adds the privacy header."""
    _inject_store_io_header(kwargs)
    return await _original_async_request(self, *args, **kwargs)


httpx.Client.request = _patched_request
httpx.AsyncClient.request = _patched_async_request

# Imported after the patch above has been installed.
import fal_client

# Looked up defensively: if the installed fal-client does not export this
# name, naming it directly in an ``except`` clause would raise AttributeError
# while another exception is being handled.
_FAL_HTTP_ERROR = getattr(fal_client, 'FalClientHTTPError', None)

# Upload limits applied by FALClient.upload_file ("max 2000px, <10MB").
_MAX_IMAGE_SIDE = 2000
_MAX_UPLOAD_BYTES = 10 * 1024 * 1024
# JPEG qualities tried in order until the encoded image fits the size limit.
_JPEG_QUALITIES = (92, 85, 75, 60)

# Endpoint suffixes stripped by FALClient._get_base_model (first match wins).
_ENDPOINT_SUBPATHS = (
    '/edit', '/dev', '/image-to-image', '/text-to-image',
    '/api', '/realtime', '/playground',
)


class FALClient:
    """Simplified synchronous FAL client.

    Wraps ``fal_client.SyncClient`` and reports every outcome as a plain dict
    carrying a ``'success'`` flag instead of raising.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Create the client.

        Args:
            api_key: FAL API key; falls back to the ``FAL_KEY`` environment
                variable when omitted or empty.

        Raises:
            ValueError: if no key is available from either source.
        """
        self.api_key = api_key or os.environ.get('FAL_KEY')
        if not self.api_key:
            raise ValueError("FAL API key is required")
        # Note: SyncClient does not accept a ``headers`` argument (unlike
        # AsyncClient), hence the module-level httpx patch.
        self.client = fal_client.SyncClient(key=self.api_key)

    def generate_image(self, model_endpoint: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Submit a generation job to the FAL queue without waiting for it.

        Uses the queue API (``submit`` without ``.get()``) so callers can poll
        :meth:`get_status` with the returned ``request_id`` instead of
        blocking on a long-running request.

        Returns:
            ``{'success': True, 'request_id': ..., 'status': 'submitted'}`` on
            success, otherwise ``{'success': False, 'error': ...,
            'request_id': ...}``. NOTE: on failure ``request_id`` is a locally
            generated random UUID unknown to FAL; it must not be polled.
        """
        try:
            request_handle = self.client.submit(model_endpoint, arguments=arguments)
            return {
                'success': True,
                'request_id': request_handle.request_id,
                'status': 'submitted',
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'request_id': str(uuid.uuid4()),
            }

    def get_status(self, model_endpoint: str, request_id: str) -> Dict[str, Any]:
        """Poll a queued request and fetch its result once it has completed.

        ``client.status`` returns a ``Queued``, ``InProgress`` or ``Completed``
        object. Each is mapped to a status string: ``'in_queue'``,
        ``'in_progress'``, ``'completed'``, or ``'unknown'`` for anything else.

        Returns:
            ``{'success': True, 'status': ..., 'logs': [...]}``, plus
            ``'queue_position'`` (queued, when reported) or ``'result'``
            (completed). On any failure, returns ``{'success': False, 'error': ...}``.
        """
        try:
            base_model = self._get_base_model(model_endpoint)
            st = self.client.status(base_model, request_id, with_logs=True)
            resp: Dict[str, Any] = {"success": True, "status": None, "logs": []}

            if isinstance(st, fal_client.Queued):
                resp["status"] = "in_queue"
                # The attribute is ``position`` or ``queue_position`` depending
                # on the fal-client version. Compare with None explicitly so a
                # position of 0 (next in line) is not mistaken for "missing".
                pos = getattr(st, "position", None)
                if pos is None:
                    pos = getattr(st, "queue_position", None)
                if pos is not None:
                    resp["queue_position"] = pos
            elif isinstance(st, fal_client.InProgress):
                resp["status"] = "in_progress"
                resp["logs"] = getattr(st, "logs", []) or []
            elif isinstance(st, fal_client.Completed):
                resp["status"] = "completed"
                resp["logs"] = getattr(st, "logs", []) or []
                # The final output has to be fetched with a separate call.
                resp["result"] = self.client.result(base_model, request_id)
            else:
                resp["status"] = "unknown"
            return resp
        except Exception as e:
            if _FAL_HTTP_ERROR is not None and isinstance(e, _FAL_HTTP_ERROR):
                return {"success": False, "error": f"HTTP {getattr(e, 'status_code', '?')}: {str(e)}"}
            return {"success": False, "error": str(e)}

    def _get_base_model(self, model_endpoint: str) -> str:
        """Strip one known subpath suffix from a FAL endpoint for status queries.

        Example: ``fal-ai/bytedance/seedream/v4/edit`` ->
        ``fal-ai/bytedance/seedream/v4``. Only the first matching suffix is
        removed; endpoints without a known suffix are returned unchanged.
        """
        for subpath in _ENDPOINT_SUBPATHS:
            if model_endpoint.endswith(subpath):
                return model_endpoint[:-len(subpath)]
        return model_endpoint

    @staticmethod
    def _normalize_image(image_bytes: bytes) -> bytes:
        """Re-encode raw image bytes as an upright RGB JPEG within FAL limits."""
        with Image.open(io.BytesIO(image_bytes)) as src:
            # Apply the EXIF orientation now: the JPEG below is saved without
            # EXIF, so the orientation tag would otherwise be lost and photos
            # relying on it would be uploaded rotated.
            im = ImageOps.exif_transpose(src).convert("RGB")

        longest = max(im.width, im.height)
        if longest > _MAX_IMAGE_SIDE:
            scale = _MAX_IMAGE_SIDE / float(longest)
            # max(1, ...) keeps extreme aspect ratios from resizing to 0 px.
            new_size = (max(1, int(im.width * scale)), max(1, int(im.height * scale)))
            im = im.resize(new_size, Image.LANCZOS)

        encoded = b''
        for quality in _JPEG_QUALITIES:
            buf = io.BytesIO()
            im.save(buf, format="JPEG", quality=quality, optimize=True)
            encoded = buf.getvalue()
            if len(encoded) < _MAX_UPLOAD_BYTES:
                break
        # If even the lowest quality exceeds the limit, the smallest encoding
        # is returned and the server decides whether to accept it.
        return encoded

    def upload_file(self, file_data: str) -> Dict[str, Any]:
        """Upload a file to FAL storage, normalizing images first.

        Inputs that do not start with ``data:`` are treated as existing URLs
        and returned unchanged. A ``data:`` URI is base64-decoded as an image
        and processed as follows:

        - rotated per its EXIF orientation;
        - downscaled so the longest side is at most 2000px;
        - re-encoded as JPEG at quality 92, lowered step-wise if needed to
          stay under 10MB;
        - uploaded from a temporary file.

        Returns:
            ``{'success': True, 'url': ...}`` or ``{'success': False, 'error': ...}``.
        """
        try:
            if not file_data.startswith('data:'):
                return {'success': True, 'url': file_data}

            _header, base64_content = file_data.split(',', 1)
            jpeg_bytes = self._normalize_image(base64.b64decode(base64_content))

            # client.upload_file is given a file path, so stage the bytes on
            # disk. The file is created before writing so it is always
            # removed, even if the write itself fails.
            fd, tmp_file_path = tempfile.mkstemp(suffix='.jpg')
            try:
                with os.fdopen(fd, 'wb') as tmp_file:
                    tmp_file.write(jpeg_bytes)
                url = self.client.upload_file(tmp_file_path)
                return {'success': True, 'url': url}
            finally:
                try:
                    os.unlink(tmp_file_path)
                except OSError:
                    pass  # best-effort cleanup; must not mask the upload result
        except Exception as e:
            return {'success': False, 'error': str(e)}


def get_api_key_from_request(request) -> Optional[str]:
    """Extract the FAL API key from a request's ``Authorization`` header.

    Accepts ``Authorization: Bearer <key>``. Without such a header, or when
    the token is empty, falls back to the ``FAL_KEY`` environment variable.

    NOTE(review): the fallback lets callers without credentials use the
    server's own key — confirm this is intended.
    """
    auth_header = request.headers.get('Authorization', '')
    prefix = 'Bearer '
    if auth_header.startswith(prefix):
        # Slice off only the leading prefix (str.replace would also remove
        # any later occurrence of "Bearer " inside the token).
        token = auth_header[len(prefix):].strip()
        if token:
            return token
    return os.environ.get('FAL_KEY')


def validate_generation_request(data: Dict[str, Any]) -> Dict[str, Any]:
    """Validate a generation request body.

    Returns ``{'valid': True}``, or ``{'valid': False, 'error': ...}`` when the
    body is empty or has no (truthy) ``prompt``.
    """
    if not data:
        return {'valid': False, 'error': 'No data provided'}
    if not data.get('prompt'):
        return {'valid': False, 'error': 'Prompt is required'}
    return {'valid': True}


# Video-only parameters forwarded by prepare_fal_arguments (WAN 2.2 / 2.5).
_VIDEO_PARAMS = (
    'image_url',                     # I2V: single first frame
    'resolution',                    # e.g. "480p" | "580p" | "720p" | "1080p"
    'duration',                      # "5" | "10" (WAN 2.5)
    'frames_per_second',             # (WAN 2.2)
    'num_frames',                    # (WAN 2.2)
    'negative_prompt',
    'video_quality',                 # (WAN 2.2)
    'video_write_mode',              # (WAN 2.2)
    'acceleration',                  # (WAN 2.2)
    'guidance_scale',                # (WAN 2.2)
    'guidance_scale_2',              # (WAN 2.2)
    'interpolator_model',            # (WAN 2.2)
    'num_interpolated_frames',       # (WAN 2.2)
    'adjust_fps_for_interpolation',  # (WAN 2.2)
    'aspect_ratio',                  # (WAN 2.2)
    'end_image_url',                 # (WAN 2.2 optional)
    'audio_url',                     # (WAN 2.5 optional: audio track)
    'enable_prompt_expansion',       # both 2.5 and 2.2 (2.5 defaults to true)
)


def prepare_fal_arguments(data: Dict[str, Any], model_endpoint: str) -> Dict[str, Any]:
    """Build the FAL ``arguments`` dict for an image or video model.

    The model type is inferred from substrings of *model_endpoint*:
    ``'text-to-image'`` marks T2I, and ``'wan-25-preview'`` / ``'wan/v2.2'``
    mark video models. Only keys present in *data* are forwarded, except
    ``prompt``, which is always set (possibly to None).
    """
    arguments = {'prompt': data.get('prompt')}

    is_text_to_image = 'text-to-image' in model_endpoint
    is_video_model = 'wan-25-preview' in model_endpoint or 'wan/v2.2' in model_endpoint

    # Image-edit mode (neither T2I nor video): forward input images.
    if not is_text_to_image and not is_video_model:
        image_urls = data.get('image_urls', [])
        if image_urls:
            arguments['image_urls'] = image_urls[:10]  # at most 10 images
        if 'max_images' in data:
            arguments['max_images'] = data['max_images']

    # Common image parameters (not used for video).
    if not is_video_model:
        for param in ('image_size', 'num_images'):
            if param in data:
                arguments[param] = data[param]

    # Shared by image and video models.
    if 'seed' in data:
        arguments['seed'] = data['seed']

    # Video-only parameters (WAN 2.2/2.5).
    if is_video_model:
        for param in _VIDEO_PARAMS:
            if param in data:
                arguments[param] = data[param]

    # Other optional parameters.
    if 'enable_safety_checker' in data:
        arguments['enable_safety_checker'] = data['enable_safety_checker']

    return arguments