seedream4 / api /fal_client.py
wapadil
fix(app_simple): apply critical fixes to production deployment
76ee88a
"""
FAL API Client - Simplified and clean implementation
消除复杂的异步设计,使用同步模式简化代码
通过猴子补丁注入 X-Fal-Store-IO header
"""
import os
import uuid
import tempfile
import base64
import json
import io
from typing import Dict, Any, Optional, List
from PIL import Image
# 猴子补丁:注入 X-Fal-Store-IO header 到所有 httpx 请求
import httpx
# Keep handles on the unpatched methods; the wrappers below delegate to them.
_original_request = httpx.Client.request
_original_async_request = httpx.AsyncClient.request


def _with_store_io_header(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Add the privacy header ``X-Fal-Store-IO: 0`` to the ``headers`` kwarg.

    The caller's headers (dict, list of pairs, ``httpx.Headers`` or ``None``)
    are copied into a fresh ``httpx.Headers`` instead of being mutated, so a
    headers dict reused across requests is left untouched. Setting the key on
    ``httpx.Headers`` replaces any differently-cased existing value.

    Args:
        kwargs: the keyword arguments of one ``request()`` call (a fresh dict
            per call, so updating it in place is safe).

    Returns:
        The same ``kwargs`` dict with ``headers`` replaced.
    """
    headers = httpx.Headers(kwargs.get('headers'))
    headers['X-Fal-Store-IO'] = '0'
    kwargs['headers'] = headers
    return kwargs


def _patched_request(self, *args, **kwargs):
    """Synchronous ``httpx.Client.request`` that always sends the privacy header."""
    return _original_request(self, *args, **_with_store_io_header(kwargs))


async def _patched_async_request(self, *args, **kwargs):
    """Asynchronous ``httpx.AsyncClient.request`` that always sends the privacy header."""
    return await _original_async_request(self, *args, **_with_store_io_header(kwargs))


# Patch at class level so every client instance - including those created
# internally by fal_client, which is imported after this point - goes through
# the wrappers above.
# NOTE(review): only calls routed through ``.request()`` are covered; code
# paths that build a request and call ``.send()``/``.stream()`` directly may
# bypass this - confirm for the fal_client version in use.
httpx.Client.request = _patched_request
httpx.AsyncClient.request = _patched_async_request
# 现在导入 fal_client,它会使用打过补丁的 httpx
import fal_client
class FALClient:
    """Thin synchronous wrapper around ``fal_client.SyncClient``.

    The request methods (``generate_image``, ``get_status``, ``upload_file``)
    never raise: they return a plain dict with a ``success`` flag and either
    the payload or an ``error`` message.
    """

    # Endpoint sub-paths stripped before a status lookup (first match wins,
    # at most one suffix is removed).
    _STATUS_SUBPATHS = (
        '/edit', '/dev', '/image-to-image', '/text-to-image',
        '/api', '/realtime', '/playground',
    )

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: FAL API key. Falls back to the ``FAL_KEY`` environment
                variable when omitted or empty.

        Raises:
            ValueError: if no key is available from either source.
        """
        self.api_key = api_key or os.environ.get('FAL_KEY')
        if not self.api_key:
            raise ValueError("FAL API key is required")
        # Unlike AsyncClient, SyncClient takes no ``headers`` argument; the
        # X-Fal-Store-IO header is injected by the module-level httpx patch.
        self.client = fal_client.SyncClient(key=self.api_key)

    def generate_image(self, model_endpoint: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Submit a job to the FAL queue and return immediately.

        Queue mode (``submit`` without ``.get()``) avoids blocking; callers poll
        :meth:`get_status` with the returned ``request_id``.

        Returns:
            ``{'success': True, 'request_id': ..., 'status': 'submitted'}``, or
            ``{'success': False, 'error': ..., 'request_id': ...}`` on failure.
        """
        try:
            handle = self.client.submit(model_endpoint, arguments=arguments)
            return {
                'success': True,
                'request_id': handle.request_id,
                'status': 'submitted'
            }
        except Exception as e:
            # NOTE(review): this request_id is a locally generated UUID, not a
            # FAL request id, so polling it cannot succeed. Kept because callers
            # may rely on the key being present.
            return {
                'success': False,
                'error': str(e),
                'request_id': str(uuid.uuid4())
            }

    def get_status(self, model_endpoint: str, request_id: str) -> Dict[str, Any]:
        """Poll a queued request and fetch its result once it has completed.

        fal_client status objects map to ``status`` strings as follows:
        ``Queued`` -> ``'in_queue'`` (plus ``queue_position`` when known),
        ``InProgress`` -> ``'in_progress'``, ``Completed`` -> ``'completed'``
        (plus ``result``), anything else -> ``'unknown'``.

        Returns:
            ``{'success': True, 'status': ..., 'logs': [...], ...}`` or
            ``{'success': False, 'error': ...}``.
        """
        try:
            base_model = self._get_base_model(model_endpoint)
            st = self.client.status(base_model, request_id, with_logs=True)
            resp: Dict[str, Any] = {"success": True, "status": None, "logs": []}

            if isinstance(st, fal_client.Queued):
                resp["status"] = "in_queue"
                # The attribute name differs between fal-client versions. Compare
                # with None explicitly: position 0 (front of the queue) is falsy
                # and an ``a or b`` chain would silently drop it.
                pos = getattr(st, "position", None)
                if pos is None:
                    pos = getattr(st, "queue_position", None)
                if pos is not None:
                    resp["queue_position"] = pos
            elif isinstance(st, fal_client.InProgress):
                resp["status"] = "in_progress"
                resp["logs"] = getattr(st, "logs", []) or []
            elif isinstance(st, fal_client.Completed):
                resp["status"] = "completed"
                resp["logs"] = getattr(st, "logs", []) or []
                # The status object carries no payload; fetch the result separately.
                resp["result"] = self.client.result(base_model, request_id)
            else:
                resp["status"] = "unknown"
            return resp
        except Exception as e:
            # Resolved lazily: if this fal-client version does not export
            # FalClientHTTPError, an ``except fal_client.FalClientHTTPError``
            # clause would itself raise AttributeError and mask the real error.
            http_error = getattr(fal_client, "FalClientHTTPError", None)
            if http_error is not None and isinstance(e, http_error):
                return {"success": False, "error": f"HTTP {getattr(e, 'status_code', '?')}: {str(e)}"}
            return {"success": False, "error": str(e)}

    def _get_base_model(self, model_endpoint: str) -> str:
        """Strip one known sub-path suffix from an endpoint for status lookups.

        Example: ``fal-ai/bytedance/seedream/v4/edit`` -> ``fal-ai/bytedance/seedream/v4``.
        Endpoints without a known suffix are returned unchanged.
        """
        for suffix in self._STATUS_SUBPATHS:
            if model_endpoint.endswith(suffix):
                return model_endpoint[:-len(suffix)]
        return model_endpoint

    def upload_file(self, file_data: str) -> Dict[str, Any]:
        """Upload a base64 ``data:`` URL to FAL storage; pass other strings through.

        Data URLs are decoded and re-encoded as an RGB JPEG whose longest side
        is at most 2000px (FAL compliance: max 2000px, <10MB; the 10MB limit is
        not checked explicitly), then uploaded. Any other string is assumed to
        already be a URL and is returned unchanged.

        Returns:
            ``{'success': True, 'url': ...}`` or ``{'success': False, 'error': ...}``.
        """
        try:
            if not file_data.startswith('data:'):
                return {'success': True, 'url': file_data}
            # Payload is everything after the first comma; the media-type
            # header before it is not inspected.
            _header, base64_content = file_data.split(',', 1)
            jpeg_bytes = self._to_compliant_jpeg(base64.b64decode(base64_content))
            return {'success': True, 'url': self._upload_bytes(jpeg_bytes, '.jpg')}
        except Exception as e:
            return {'success': False, 'error': str(e)}

    @staticmethod
    def _to_compliant_jpeg(image_bytes: bytes, max_side: int = 2000, quality: int = 92) -> bytes:
        """Re-encode image bytes as an RGB JPEG whose longest side is <= ``max_side``."""
        from PIL import ImageOps  # only ``Image`` is imported at module level

        im = Image.open(io.BytesIO(image_bytes))
        # Apply the EXIF Orientation tag before re-encoding: the saved JPEG has
        # no EXIF, so photos rotated only via that tag would come out sideways.
        im = ImageOps.exif_transpose(im)
        # NOTE: convert() discards any alpha channel rather than compositing it.
        im = im.convert("RGB")
        longest = max(im.width, im.height)
        if longest > max_side:
            scale = max_side / float(longest)
            # max(1, ...) keeps extreme aspect ratios from collapsing to 0px,
            # which resize() rejects.
            new_size = (max(1, int(im.width * scale)), max(1, int(im.height * scale)))
            im = im.resize(new_size, Image.LANCZOS)
        buf = io.BytesIO()
        im.save(buf, format="JPEG", quality=quality, optimize=True)
        return buf.getvalue()

    def _upload_bytes(self, data: bytes, suffix: str) -> str:
        """Write ``data`` to a temp file, upload it via fal_client, return the URL."""
        fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        try:
            # The write happens inside try/finally so a failed write still
            # removes the temp file.
            with os.fdopen(fd, 'wb') as tmp_file:
                tmp_file.write(data)
            return self.client.upload_file(tmp_path)
        finally:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not turn a
                # successful upload into an error.
                pass
def get_api_key_from_request(request) -> Optional[str]:
    """Return the FAL API key to use for ``request``.

    Uses the token from an ``Authorization: Bearer <token>`` header when one is
    present and non-empty; otherwise falls back to the ``FAL_KEY`` environment
    variable (``None`` if unset).

    NOTE(review): requests without a token therefore run on the server's own
    key - confirm this is intended for the deployment.
    """
    auth_header = request.headers.get('Authorization', '') or ''
    scheme, _, token = auth_header.partition(' ')
    # The auth scheme is case-insensitive (RFC 7235). Only the prefix is
    # removed - not every "Bearer " occurrence inside the token.
    if scheme.lower() == 'bearer':
        token = token.strip()
        if token:
            return token
    return os.environ.get('FAL_KEY')
def validate_generation_request(data: Dict[str, Any]) -> Dict[str, Any]:
    """Check that a generation request body is usable.

    Only checks that ``data`` is a non-empty dict with a non-blank ``prompt``.

    Returns:
        ``{'valid': True}`` or ``{'valid': False, 'error': <message>}``.
    """
    if not data:
        return {'valid': False, 'error': 'No data provided'}
    if not isinstance(data, dict):
        # e.g. a list body: the .get() below would raise AttributeError.
        return {'valid': False, 'error': 'Request body must be a JSON object'}
    prompt = data.get('prompt')
    # A whitespace-only prompt is as empty as a missing one.
    if not prompt or (isinstance(prompt, str) and not prompt.strip()):
        return {'valid': False, 'error': 'Prompt is required'}
    return {'valid': True}
def prepare_fal_arguments(data: Dict[str, Any], model_endpoint: str) -> Dict[str, Any]:
    """Build the FAL ``arguments`` payload for ``model_endpoint`` from ``data``.

    The endpoint string decides which request keys are forwarded:

    * video models (endpoint contains ``wan-25-preview`` or ``wan/v2.2``) get
      the WAN 2.2 / 2.5 video parameters listed below;
    * image-edit models (neither text-to-image nor video) get ``image_urls``
      (capped at 10) and ``max_images``;
    * every non-video model gets ``image_size`` and ``num_images``.

    ``prompt`` is always set (``None`` when missing). ``seed`` and
    ``enable_safety_checker`` are forwarded for every model when present.
    Keys absent from ``data`` are omitted rather than defaulted.
    """
    video_keys = (
        'image_url',                     # I2V: single first frame
        'resolution',                    # e.g. "480p" | "580p" | "720p" | "1080p"
        'duration',                      # "5" | "10" (WAN 2.5)
        'frames_per_second',             # WAN 2.2
        'num_frames',                    # WAN 2.2
        'negative_prompt',
        'video_quality',                 # WAN 2.2
        'video_write_mode',              # WAN 2.2
        'acceleration',                  # WAN 2.2
        'guidance_scale',                # WAN 2.2
        'guidance_scale_2',              # WAN 2.2
        'interpolator_model',            # WAN 2.2
        'num_interpolated_frames',       # WAN 2.2
        'adjust_fps_for_interpolation',  # WAN 2.2
        'aspect_ratio',                  # WAN 2.2
        'end_image_url',                 # WAN 2.2, optional
        'audio_url',                     # WAN 2.5, optional audio track
        'enable_prompt_expansion',       # WAN 2.2 and 2.5 (2.5 defaults to true)
    )

    is_video = any(marker in model_endpoint for marker in ('wan-25-preview', 'wan/v2.2'))
    is_t2i = 'text-to-image' in model_endpoint

    payload: Dict[str, Any] = {'prompt': data.get('prompt')}

    def forward(*keys: str) -> None:
        # Copy each key the caller actually supplied, preserving order.
        payload.update((key, data[key]) for key in keys if key in data)

    if not is_video:
        if not is_t2i:
            # Image-edit mode: reference images (at most 10) plus max_images.
            refs = data.get('image_urls', [])
            if refs:
                payload['image_urls'] = refs[:10]
            forward('max_images')
        forward('image_size', 'num_images')

    forward('seed')
    if is_video:
        forward(*video_keys)
    forward('enable_safety_checker')
    return payload