"""
FAL API Client - Simplified and clean implementation
消除复杂的异步设计,使用同步模式简化代码
通过猴子补丁注入 X-Fal-Store-IO header
"""
import os
import uuid
import tempfile
import base64
import json
import io
from typing import Dict, Any, Optional, List
from PIL import Image
# Monkey patch: inject the X-Fal-Store-IO header into every httpx request.
# httpx.Client.request / httpx.AsyncClient.request are replaced on the classes
# themselves, so the header is added to every call that goes through these
# methods, from any client instance in this process.
import httpx

# If this module is executed again (e.g. importlib.reload), the methods on the
# httpx classes are already our wrappers. Recover the real originals from the
# marker attribute set below; otherwise the new wrapper would capture the old
# wrapper, whose global `_original_request` now points back at itself, and
# every request would recurse until RecursionError.
_original_request = getattr(
    httpx.Client.request, '_fal_store_io_original', httpx.Client.request
)
_original_async_request = getattr(
    httpx.AsyncClient.request, '_fal_store_io_original', httpx.AsyncClient.request
)


def _with_store_io_header(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return *kwargs* with ``X-Fal-Store-IO: 0`` set in its ``headers``.

    The caller's headers are copied into a new ``httpx.Headers`` instead of
    being mutated in place. As a result, ``headers=None``, a list of
    ``(name, value)`` tuples, an ``httpx.Headers`` object, and a dict reused
    across calls all work, and the header never leaks back into the caller's
    object.
    """
    headers = httpx.Headers(kwargs.get('headers'))
    headers['X-Fal-Store-IO'] = '0'
    kwargs['headers'] = headers
    return kwargs


def _patched_request(self, *args, **kwargs):
    """Synchronous ``httpx.Client.request`` that injects the privacy header."""
    return _original_request(self, *args, **_with_store_io_header(kwargs))


async def _patched_async_request(self, *args, **kwargs):
    """Asynchronous ``httpx.AsyncClient.request`` that injects the privacy header."""
    return await _original_async_request(self, *args, **_with_store_io_header(kwargs))


# Remember the wrapped originals so a re-execution of this module can unwrap them.
_patched_request._fal_store_io_original = _original_request
_patched_async_request._fal_store_io_original = _original_async_request

httpx.Client.request = _patched_request
httpx.AsyncClient.request = _patched_async_request
# 现在导入 fal_client,它会使用打过补丁的 httpx
import fal_client
class FALClient:
    """Simplified synchronous FAL client (avoids complex async handling).

    Wraps ``fal_client.SyncClient``. The constructor raises ``ValueError``
    when no API key is available. Every other public method catches
    exceptions and reports them as ``{'success': False, 'error': ...}``.
    """

    # Longest image side (px) allowed for uploads; larger images are downscaled.
    _MAX_IMAGE_SIDE = 2000
    # Size budget for re-encoded uploads (FAL compliance: <10 MB).
    _MAX_UPLOAD_BYTES = 10 * 1024 * 1024
    # JPEG qualities tried in order until the encoded image fits the budget.
    _JPEG_QUALITIES = (92, 85, 75, 65)
    # Sub-path suffixes stripped from an endpoint before status queries.
    _ENDPOINT_SUBPATHS = (
        '/edit', '/dev', '/image-to-image', '/text-to-image',
        '/api', '/realtime', '/playground',
    )

    def __init__(self, api_key: Optional[str] = None):
        """Create the client.

        Args:
            api_key: FAL API key. Falls back to the ``FAL_KEY`` environment
                variable when omitted or empty.

        Raises:
            ValueError: if no key is available from either source.
        """
        self.api_key = api_key or os.environ.get('FAL_KEY')
        if not self.api_key:
            raise ValueError("FAL API key is required")
        # Plain synchronous client. Note: SyncClient does not accept a
        # ``headers`` argument (unlike AsyncClient).
        self.client = fal_client.SyncClient(key=self.api_key)

    def generate_image(self, model_endpoint: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Submit a generation job to the FAL queue without waiting for it.

        Uses queue mode, so the call does not block on generation. The caller
        polls ``get_status`` with the returned request id.

        Returns:
            ``{'success': True, 'request_id': ..., 'status': 'submitted'}``
            on success. On failure, returns
            ``{'success': False, 'error': ..., 'request_id': ...}``, where the
            request id is a random UUID that does not identify any FAL request.
        """
        try:
            # Queue mode: submit without calling .get(), so this returns
            # immediately with a request id for polling.
            request_handle = self.client.submit(model_endpoint, arguments=arguments)
            return {
                'success': True,
                'request_id': request_handle.request_id,
                'status': 'submitted'
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                # Random placeholder; it does NOT correspond to a FAL request.
                'request_id': str(uuid.uuid4())
            }

    def get_status(self, model_endpoint: str, request_id: str) -> Dict[str, Any]:
        """Poll a queued request and, once it has completed, fetch its result.

        Dispatches on the fal-client status type (Queued / InProgress /
        Completed).

        Returns:
            On success, ``{'success': True, 'status': ..., 'logs': [...]}``.
            ``status`` is one of:

            - ``'in_queue'``, plus ``'queue_position'`` when known;
            - ``'in_progress'``;
            - ``'completed'``, plus ``'result'``;
            - ``'unknown'``.

            On failure, ``{'success': False, 'error': ...}``.
        """
        try:
            base_model = self._get_base_model(model_endpoint)
            st = self.client.status(base_model, request_id, with_logs=True)
            resp: Dict[str, Any] = {"success": True, "status": None, "logs": []}
            if isinstance(st, fal_client.Queued):
                resp["status"] = "in_queue"
                pos = self._queue_position(st)
                if pos is not None:
                    resp["queue_position"] = pos
            elif isinstance(st, fal_client.InProgress):
                resp["status"] = "in_progress"
                resp["logs"] = getattr(st, "logs", []) or []
            elif isinstance(st, fal_client.Completed):
                resp["status"] = "completed"
                resp["logs"] = getattr(st, "logs", []) or []
                # The final result has to be fetched with a separate call.
                resp["result"] = self.client.result(base_model, request_id)
            else:
                resp["status"] = "unknown"
            return resp
        except Exception as e:
            # Look the HTTP error type up defensively. If fal_client does not
            # export it, an ``except fal_client.FalClientHTTPError`` clause
            # would raise AttributeError out of this method instead of
            # returning an error dict.
            http_error = getattr(fal_client, "FalClientHTTPError", None)
            if http_error is not None and isinstance(e, http_error):
                return {"success": False, "error": f"HTTP {getattr(e, 'status_code', '?')}: {str(e)}"}
            return {"success": False, "error": str(e)}

    @staticmethod
    def _queue_position(status: Any) -> Optional[int]:
        """Return the queue position carried by a Queued status, or None.

        The attribute may be named ``position`` or ``queue_position``,
        depending on the fal-client version. ``is None`` checks ensure that
        position 0 (front of the queue) is reported rather than skipped as
        falsy.
        """
        for attr in ("position", "queue_position"):
            value = getattr(status, attr, None)
            if value is not None:
                return value
        return None

    def _get_base_model(self, model_endpoint: str) -> str:
        """Strip one known sub-path suffix from a FAL endpoint for status queries.

        Example: ``fal-ai/bytedance/seedream/v4/edit`` ->
        ``fal-ai/bytedance/seedream/v4``. Endpoints without a listed suffix
        are returned unchanged.
        """
        for subpath in self._ENDPOINT_SUBPATHS:
            if model_endpoint.endswith(subpath):
                return model_endpoint[:-len(subpath)]
        return model_endpoint

    def upload_file(self, file_data: str) -> Dict[str, Any]:
        """Upload an image to FAL storage and return its URL.

        Input that does not start with ``data:`` is treated as an existing URL
        and returned unchanged. For a ``data:`` URL, this method:

        1. base64-decodes the payload;
        2. re-encodes it as JPEG via ``_compress_image``;
        3. writes it to a temporary file;
        4. uploads the file with ``SyncClient.upload_file``.

        Returns:
            ``{'success': True, 'url': ...}`` or
            ``{'success': False, 'error': ...}``.
        """
        try:
            if not file_data.startswith('data:'):
                return {'success': True, 'url': file_data}
            # Everything after the first comma is the base64 payload. The
            # "data:<mime>;base64" prefix is ignored because the image is
            # re-encoded as JPEG anyway.
            _, base64_content = file_data.split(',', 1)
            jpeg_bytes = self._compress_image(base64.b64decode(base64_content))

            tmp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False)
            try:
                # Close before uploading so the bytes are flushed to disk.
                with tmp_file:
                    tmp_file.write(jpeg_bytes)
                url = self.client.upload_file(tmp_file.name)
                return {'success': True, 'url': url}
            finally:
                # Best-effort cleanup, also reached when the write fails; a
                # failed unlink must not mask the upload outcome.
                try:
                    os.unlink(tmp_file.name)
                except OSError:
                    pass
        except Exception as e:
            return {'success': False, 'error': str(e)}

    @classmethod
    def _compress_image(cls, image_bytes: bytes) -> bytes:
        """Re-encode raw image bytes as a JPEG within FAL upload limits.

        Steps:

        1. Apply the EXIF orientation tag to the pixels. The tag is not
           written back on save, so without this step rotated phone photos
           would be uploaded sideways.
        2. Convert to RGB.
        3. Downscale so the longest side is at most ``_MAX_IMAGE_SIDE`` px.
        4. Encode as JPEG, stepping through ``_JPEG_QUALITIES`` until the
           output is at most ``_MAX_UPLOAD_BYTES``. The lowest quality is
           returned even if it is still larger.
        """
        from PIL import ImageOps  # Pillow is already a dependency of this module.

        im = Image.open(io.BytesIO(image_bytes))
        try:
            im = ImageOps.exif_transpose(im)
        except Exception:
            # Malformed EXIF: keep the pixels as stored rather than failing the upload.
            pass
        im = im.convert("RGB")

        longest = max(im.width, im.height)
        if longest > cls._MAX_IMAGE_SIDE:
            scale = cls._MAX_IMAGE_SIDE / float(longest)
            # max(1, ...) keeps extremely thin images from collapsing to 0 px.
            new_size = (max(1, int(im.width * scale)), max(1, int(im.height * scale)))
            im = im.resize(new_size, Image.LANCZOS)

        data = b''
        for quality in cls._JPEG_QUALITIES:
            buf = io.BytesIO()
            im.save(buf, format="JPEG", quality=quality, optimize=True)
            data = buf.getvalue()
            if len(data) <= cls._MAX_UPLOAD_BYTES:
                break
        return data
def get_api_key_from_request(request) -> Optional[str]:
    """Extract the FAL API key from a request's ``Authorization`` header.

    Accepts ``Authorization: Bearer <token>``. The scheme is matched
    case-insensitively and surrounding whitespace is ignored. When the header
    is missing, uses another scheme, or carries an empty token, falls back to
    the ``FAL_KEY`` environment variable, which may be None.
    """
    auth_header = request.headers.get('Authorization', '') or ''
    scheme, _, token = auth_header.strip().partition(' ')
    token = token.strip()
    if scheme.lower() == 'bearer' and token:
        return token
    return os.environ.get('FAL_KEY')
def validate_generation_request(data: Dict[str, Any]) -> Dict[str, Any]:
    """Check that a generation payload exists and carries a non-empty prompt.

    Returns ``{'valid': True}`` when the payload is acceptable, otherwise
    ``{'valid': False, 'error': <message>}``.
    """
    problem = None
    if not data:
        problem = 'No data provided'
    elif not data.get('prompt'):
        problem = 'Prompt is required'

    if problem is None:
        return {'valid': True}
    return {'valid': False, 'error': problem}
# Endpoint substrings that identify video models (WAN 2.2 / 2.5).
_VIDEO_MODEL_MARKERS = ('wan-25-preview', 'wan/v2.2')

# Maximum number of reference images forwarded in image-edit mode.
_MAX_EDIT_IMAGES = 10

# Optional parameters forwarded to image models only.
_IMAGE_PARAMS = ('image_size', 'num_images')

# Optional parameters forwarded to video models only (WAN 2.2 / 2.5).
_VIDEO_PARAMS = (
    'image_url',                     # I2V: single first frame
    'resolution',                    # e.g. "480p" | "580p" | "720p" | "1080p"
    'duration',                      # "5" | "10" (WAN 2.5)
    'frames_per_second',             # (WAN 2.2)
    'num_frames',                    # (WAN 2.2)
    'negative_prompt',
    'video_quality',                 # (WAN 2.2)
    'video_write_mode',              # (WAN 2.2)
    'acceleration',                  # (WAN 2.2)
    'guidance_scale',                # (WAN 2.2)
    'guidance_scale_2',              # (WAN 2.2)
    'interpolator_model',            # (WAN 2.2)
    'num_interpolated_frames',       # (WAN 2.2)
    'adjust_fps_for_interpolation',  # (WAN 2.2)
    'aspect_ratio',                  # (WAN 2.2)
    'end_image_url',                 # (WAN 2.2, optional)
    'audio_url',                     # (WAN 2.5, optional audio track)
    'enable_prompt_expansion',       # both 2.5 and 2.2 (2.5 defaults to true)
)


def _copy_present(src: Dict[str, Any], dst: Dict[str, Any], keys: tuple) -> None:
    """Copy each key in *keys* that is present in *src* into *dst*."""
    for key in keys:
        if key in src:
            dst[key] = src[key]


def prepare_fal_arguments(data: Dict[str, Any], model_endpoint: str) -> Dict[str, Any]:
    """Build the FAL argument dict for an image or video model.

    The model kind is inferred from *model_endpoint*:

    - Video model: the endpoint contains one of ``_VIDEO_MODEL_MARKERS``.
    - Text-to-image model: the endpoint contains ``'text-to-image'``.
    - Image-edit model: anything else.

    Arguments are built as follows:

    - Always: ``prompt``.
    - Image-edit mode only: ``image_urls`` (at most ``_MAX_EDIT_IMAGES``; a
      single URL string is accepted) and ``max_images``.
    - All image models: ``_IMAGE_PARAMS``.
    - Video models: ``_VIDEO_PARAMS``.
    - Every model: ``seed`` and ``enable_safety_checker``.

    Optional keys are copied only when present in *data*.
    """
    arguments: Dict[str, Any] = {'prompt': data.get('prompt')}

    is_text_to_image = 'text-to-image' in model_endpoint
    is_video_model = any(marker in model_endpoint for marker in _VIDEO_MODEL_MARKERS)

    # Image-edit mode (neither T2I nor video).
    if not is_text_to_image and not is_video_model:
        image_urls = data.get('image_urls') or []
        if isinstance(image_urls, str):
            # A lone URL string would otherwise be sliced into its first characters.
            image_urls = [image_urls]
        if image_urls:
            arguments['image_urls'] = image_urls[:_MAX_EDIT_IMAGES]
        if 'max_images' in data:
            arguments['max_images'] = data['max_images']

    # Image-only parameters (never sent to video models).
    if not is_video_model:
        _copy_present(data, arguments, _IMAGE_PARAMS)

    # Shared by image and video models.
    _copy_present(data, arguments, ('seed',))

    if is_video_model:
        _copy_present(data, arguments, _VIDEO_PARAMS)

    # Other optional parameters.
    _copy_present(data, arguments, ('enable_safety_checker',))
    return arguments