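"""
FAL API client.

A thin, synchronous wrapper around ``fal_client.SyncClient`` (no async code
paths). Importing this module monkey-patches ``httpx.Client.request`` and
``httpx.AsyncClient.request`` process-wide so that every request sent through
them carries the ``X-Fal-Store-IO: 0`` header.
"""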
| import os |
| import uuid |
| import tempfile |
| import base64 |
| import json |
| import io |
| from typing import Dict, Any, Optional, List |
| from PIL import Image |
|
|
| |
| import httpx |
|
|
# References to the unpatched httpx request methods, captured before they are
# replaced further down; the header-injecting wrappers delegate to these.
_original_request, _original_async_request = (
    httpx.Client.request,
    httpx.AsyncClient.request,
)
|
|
def _patched_request(self, *args, **kwargs):
    """Send a synchronous httpx request with ``X-Fal-Store-IO: 0`` added.

    Replacement for ``httpx.Client.request``; all arguments are forwarded to
    the original method unchanged except ``headers``.

    httpx's convenience methods (``get``, ``post``, ...) pass ``headers=None``
    explicitly when the caller supplied none. Callers may also pass a dict, a
    list of ``(name, value)`` pairs or an ``httpx.Headers``. Whatever arrives
    is copied into a fresh ``httpx.Headers``, so None and pair-lists work and
    the caller's own headers object is never mutated.
    """
    headers = httpx.Headers(kwargs.get('headers'))
    # httpx.Headers is case-insensitive, so this also overrides any
    # differently-cased value the caller may have set.
    headers['X-Fal-Store-IO'] = '0'
    kwargs['headers'] = headers
    return _original_request(self, *args, **kwargs)
|
|
async def _patched_async_request(self, *args, **kwargs):
    """Send an asynchronous httpx request with ``X-Fal-Store-IO: 0`` added.

    Replacement for ``httpx.AsyncClient.request``; all arguments are forwarded
    to the original coroutine unchanged except ``headers``.

    httpx's convenience methods (``get``, ``post``, ...) pass ``headers=None``
    explicitly when the caller supplied none. Callers may also pass a dict, a
    list of ``(name, value)`` pairs or an ``httpx.Headers``. Whatever arrives
    is copied into a fresh ``httpx.Headers``, so None and pair-lists work and
    the caller's own headers object is never mutated.
    """
    headers = httpx.Headers(kwargs.get('headers'))
    # Case-insensitive set: overrides any differently-cased caller value.
    headers['X-Fal-Store-IO'] = '0'
    kwargs['headers'] = headers
    return await _original_async_request(self, *args, **kwargs)
|
|
# Install the wrappers on the classes themselves, so every httpx.Client and
# httpx.AsyncClient instance in this process (existing or future) uses them.
# NOTE(review): httpx's ``Client.stream`` / ``send`` do not go through
# ``request``, so traffic sent that way will not carry the header.
setattr(httpx.Client, "request", _patched_request)
setattr(httpx.AsyncClient, "request", _patched_async_request)
|
|
| |
| import fal_client |
|
|
|
|
class FALClient:
    """Synchronous FAL client built on ``fal_client.SyncClient``.

    Apart from ``__init__``, every public method catches exceptions and
    reports them in a result dict carrying a ``'success'`` flag rather than
    raising.
    """

    # Input-image limits enforced by upload_file: longest side in pixels and
    # maximum encoded size in bytes.
    MAX_IMAGE_SIDE = 2000
    MAX_UPLOAD_BYTES = 10 * 1024 * 1024
    # JPEG qualities tried in order until the encoding fits MAX_UPLOAD_BYTES.
    _JPEG_QUALITIES = (92, 85, 75, 60)

    # Trailing endpoint segments stripped before status/result queries,
    # checked in this order.
    _ENDPOINT_SUBPATHS = (
        '/edit', '/dev', '/image-to-image', '/text-to-image',
        '/api', '/realtime', '/playground',
    )

    def __init__(self, api_key: Optional[str] = None):
        """Create the client.

        Args:
            api_key: FAL API key; when falsy, the ``FAL_KEY`` environment
                variable is used instead.

        Raises:
            ValueError: if no key is available from either source.
        """
        self.api_key = api_key or os.environ.get('FAL_KEY')
        if not self.api_key:
            raise ValueError("FAL API key is required")
        self.client = fal_client.SyncClient(key=self.api_key)

    def generate_image(self, model_endpoint: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Submit a generation job to the FAL queue without waiting for it.

        Poll the returned ``request_id`` with ``get_status``.

        Returns:
            ``{'success': True, 'request_id': ..., 'status': 'submitted'}`` on
            success, otherwise ``{'success': False, 'error': ..., 'request_id': ...}``.
            On failure the ``request_id`` is a locally generated placeholder
            UUID; it was never submitted to FAL and cannot be polled.
        """
        try:
            request_handle = self.client.submit(model_endpoint, arguments=arguments)
            return {
                'success': True,
                'request_id': request_handle.request_id,
                'status': 'submitted'
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                # Placeholder kept so callers can rely on the key being present.
                'request_id': str(uuid.uuid4())
            }

    def get_status(self, model_endpoint: str, request_id: str) -> Dict[str, Any]:
        """Report the state of a queued request, fetching its result when done.

        Returns:
            ``{'success': True, 'status': ..., 'logs': [...]}``, where ``status``
            is one of ``'in_queue'`` (adds ``'queue_position'`` when known),
            ``'in_progress'``, ``'completed'`` (adds ``'result'``) or
            ``'unknown'``; or ``{'success': False, 'error': ...}``.
        """
        try:
            base_model = self._get_base_model(model_endpoint)
            st = self.client.status(base_model, request_id, with_logs=True)

            resp: Dict[str, Any] = {"success": True, "status": None, "logs": []}

            if isinstance(st, fal_client.Queued):
                resp["status"] = "in_queue"
                # Position 0 (front of the queue) is falsy, so fall back and
                # test for None explicitly instead of chaining with ``or``.
                pos = getattr(st, "position", None)
                if pos is None:
                    pos = getattr(st, "queue_position", None)
                if pos is not None:
                    resp["queue_position"] = pos

            elif isinstance(st, fal_client.InProgress):
                resp["status"] = "in_progress"
                resp["logs"] = getattr(st, "logs", []) or []

            elif isinstance(st, fal_client.Completed):
                resp["status"] = "completed"
                resp["logs"] = getattr(st, "logs", []) or []
                resp["result"] = self.client.result(base_model, request_id)

            else:
                resp["status"] = "unknown"

            return resp

        except fal_client.FalClientHTTPError as e:
            return {"success": False, "error": f"HTTP {getattr(e, 'status_code', '?')}: {str(e)}"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    def _get_base_model(self, model_endpoint: str) -> str:
        """Strip one known trailing sub-path from a FAL endpoint id.

        Example: ``fal-ai/bytedance/seedream/v4/edit`` -> ``fal-ai/bytedance/seedream/v4``.
        Endpoints without a known suffix are returned unchanged.
        """
        for subpath in self._ENDPOINT_SUBPATHS:
            if model_endpoint.endswith(subpath):
                return model_endpoint[:-len(subpath)]
        return model_endpoint

    def _encode_jpeg(self, image_bytes: bytes) -> bytes:
        """Re-encode raw image bytes as an RGB JPEG within the upload limits.

        The image is first rotated according to its EXIF orientation.
        Re-encoding drops EXIF, so without this step photos that rely on the
        orientation tag would arrive sideways. It is then downscaled so its
        longest side is at most MAX_IMAGE_SIDE, and saved at decreasing JPEG
        quality until it fits MAX_UPLOAD_BYTES.

        Raises:
            ValueError: if even the lowest quality exceeds MAX_UPLOAD_BYTES.
        """
        # Local import: only ``Image`` is imported from PIL at module level.
        from PIL import ImageOps

        im = Image.open(io.BytesIO(image_bytes))
        im = ImageOps.exif_transpose(im).convert("RGB")

        max_side = max(im.width, im.height)
        if max_side > self.MAX_IMAGE_SIDE:
            scale = self.MAX_IMAGE_SIDE / float(max_side)
            # max(1, ...) stops extreme aspect ratios collapsing to 0 px.
            new_size = (max(1, int(im.width * scale)), max(1, int(im.height * scale)))
            im = im.resize(new_size, Image.LANCZOS)

        for quality in self._JPEG_QUALITIES:
            buf = io.BytesIO()
            im.save(buf, format="JPEG", quality=quality, optimize=True)
            if buf.tell() <= self.MAX_UPLOAD_BYTES:
                return buf.getvalue()
        raise ValueError(
            f"Image exceeds {self.MAX_UPLOAD_BYTES} bytes even at JPEG quality "
            f"{self._JPEG_QUALITIES[-1]}"
        )

    def upload_file(self, file_data: str) -> Dict[str, Any]:
        """Upload an image to FAL storage and return its URL.

        Input that does not start with ``data:`` is returned unchanged, on the
        assumption that it is already a reachable URL. A data URL is
        base64-decoded and re-encoded by ``_encode_jpeg``. The result is
        written to a temporary ``.jpg`` file, uploaded, and the file is then
        deleted.

        Returns:
            ``{'success': True, 'url': ...}`` or ``{'success': False, 'error': ...}``.
        """
        try:
            if not file_data.startswith('data:'):
                return {'success': True, 'url': file_data}

            _, base64_content = file_data.split(',', 1)
            jpeg_bytes = self._encode_jpeg(base64.b64decode(base64_content))

            # delete=False: the file is closed before the upload reopens it by path.
            with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file:
                tmp_file.write(jpeg_bytes)
                tmp_file_path = tmp_file.name

            try:
                url = self.client.upload_file(tmp_file_path)
                return {'success': True, 'url': url}
            finally:
                # Best-effort cleanup; a leftover temp file must not mask the
                # upload outcome. Only filesystem errors are ignored.
                try:
                    os.unlink(tmp_file_path)
                except OSError:
                    pass

        except Exception as e:
            return {'success': False, 'error': str(e)}
|
|
|
|
def get_api_key_from_request(request) -> Optional[str]:
    """Return the FAL API key sent with *request*, else the ``FAL_KEY`` env var.

    The key is read from an ``Authorization: Bearer <key>`` header. The scheme
    is matched case-insensitively (RFC 7235 auth schemes are case-insensitive)
    and whitespace around the token is removed. Only the leading scheme is
    stripped, never text inside the token. The function falls back to
    ``FAL_KEY``, which may itself be None, in these cases:

    - the header is missing;
    - the header uses another scheme;
    - the bearer token is empty.

    Args:
        request: object with a dict-like ``headers`` attribute supporting
            ``.get(name, default)`` (presumably a Flask/Werkzeug request;
            confirm against callers).
    """
    auth_header = request.headers.get('Authorization', '') or ''
    scheme, _, token = auth_header.strip().partition(' ')
    token = token.strip()
    if scheme.lower() == 'bearer' and token:
        return token
    return os.environ.get('FAL_KEY')
|
|
|
|
def validate_generation_request(data: Dict[str, Any]) -> Dict[str, Any]:
    """Check that a generation request body is usable.

    Args:
        data: parsed request body; expected to be a dict with a ``prompt``.

    Returns:
        ``{'valid': True}``, or ``{'valid': False, 'error': <message>}`` when
        the body is empty, is not a JSON object, or lacks a non-blank string
        ``prompt``.
    """
    if not data:
        return {'valid': False, 'error': 'No data provided'}

    # A JSON array/string body would otherwise crash on ``.get`` below.
    if not isinstance(data, dict):
        return {'valid': False, 'error': 'Request body must be a JSON object'}

    prompt = data.get('prompt')
    if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
        return {'valid': False, 'error': 'Prompt is required'}
    if not isinstance(prompt, str):
        return {'valid': False, 'error': 'Prompt must be a string'}

    return {'valid': True}
|
|
|
|
def prepare_fal_arguments(data: Dict[str, Any], model_endpoint: str) -> Dict[str, Any]:
    """Build the keyword arguments for a FAL model call from a request payload.

    The model family is inferred from substrings of *model_endpoint*:

    - endpoints containing ``'wan-25-preview'`` or ``'wan/v2.2'`` are video models;
    - endpoints containing ``'text-to-image'`` are text-to-image models;
    - anything else is treated as an image-edit model.

    ``prompt`` is always set, possibly to None. Every other key is copied
    only when present in *data*:

    - image-edit models: ``image_urls`` (first 10 entries, only if non-empty)
      and ``max_images``;
    - all non-video models: ``image_size``, ``num_images``;
    - every model: ``seed``, then ``enable_safety_checker`` last;
    - video models: the video-specific keys listed below.
    """
    text_to_image = 'text-to-image' in model_endpoint
    video = any(marker in model_endpoint for marker in ('wan-25-preview', 'wan/v2.2'))

    arguments: Dict[str, Any] = {'prompt': data.get('prompt')}

    def forward(*keys: str) -> None:
        # Copy each key from the payload only if the caller supplied it.
        for key in keys:
            if key in data:
                arguments[key] = data[key]

    if not (text_to_image or video):
        urls = data.get('image_urls', [])
        if urls:
            arguments['image_urls'] = urls[:10]
        # NOTE(review): max_images is forwarded only for image-edit models;
        # seedream text-to-image also accepts it, so confirm this is intended.
        forward('max_images')

    if not video:
        forward('image_size', 'num_images')

    forward('seed')

    if video:
        forward(
            'image_url',
            'resolution',
            'duration',
            'frames_per_second',
            'num_frames',
            'negative_prompt',
            'video_quality',
            'video_write_mode',
            'acceleration',
            'guidance_scale',
            'guidance_scale_2',
            'interpolator_model',
            'num_interpolated_frames',
            'adjust_fps_for_interpolation',
            'aspect_ratio',
            'end_image_url',
            'audio_url',
            'enable_prompt_expansion',
        )

    forward('enable_safety_checker')
    return arguments