File size: 9,617 Bytes
ad2cc89
 
 
4af6f42
ad2cc89
 
 
 
 
 
76ee88a
ad2cc89
76ee88a
4af6f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad2cc89
 
 
 
 
 
 
 
 
 
 
49e37f6
 
 
ad2cc89
 
 
fc3d048
 
ad2cc89
 
fc3d048
 
ad2cc89
 
fc3d048
 
ad2cc89
 
 
 
 
 
 
 
fc3d048
 
6224b7f
 
fc3d048
 
 
6224b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc3d048
6224b7f
fc3d048
 
 
 
 
 
 
 
 
 
 
 
 
ad2cc89
 
76ee88a
ad2cc89
 
 
 
 
 
 
 
 
76ee88a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad2cc89
76ee88a
 
ad2cc89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76ee88a
ad2cc89
 
76ee88a
 
 
 
 
 
ad2cc89
 
 
 
 
 
 
76ee88a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad2cc89
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
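"""
FAL API client - simplified, clean implementation.

Avoids a complex asynchronous design by using synchronous calls throughout,
and injects the X-Fal-Store-IO header into httpx requests via a monkey patch.
"""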
import os
import uuid
import tempfile
import base64
import json
import io
from typing import Dict, Any, Optional, List
from PIL import Image

# Monkey patch: inject the X-Fal-Store-IO header into every request made through httpx Client.request / AsyncClient.request
import httpx

# Name of the privacy header forced onto every patched httpx request.
_STORE_IO_HEADER = 'X-Fal-Store-IO'

_original_request = httpx.Client.request
_original_async_request = httpx.AsyncClient.request


def _with_store_io_header(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return *kwargs* with ``headers`` replaced by a copy carrying X-Fal-Store-IO: 0.

    The headers are copied into a fresh ``httpx.Headers`` rather than edited in
    place, so:
      * ``headers=None`` and a list of ``(name, value)`` pairs both work
        (the previous ``kwargs['headers'][...] = ...`` raised TypeError on those);
      * a dict owned by the caller is never mutated behind its back;
      * an existing header with the same name in a different case is replaced
        instead of duplicated (``httpx.Headers`` keys are case-insensitive).
    """
    headers = httpx.Headers(kwargs.get('headers'))
    headers[_STORE_IO_HEADER] = '0'
    kwargs['headers'] = headers
    return kwargs


def _patched_request(self, *args, **kwargs):
    """Synchronous httpx request that always sends the privacy header."""
    return _original_request(self, *args, **_with_store_io_header(kwargs))


async def _patched_async_request(self, *args, **kwargs):
    """Asynchronous httpx request that always sends the privacy header."""
    return await _original_async_request(self, *args, **_with_store_io_header(kwargs))


httpx.Client.request = _patched_request
httpx.AsyncClient.request = _patched_async_request

# Import fal_client only now, so that it uses the already-patched httpx
import fal_client


class FALClient:
    """Simplified synchronous FAL client.

    Wraps ``fal_client.SyncClient``. Apart from the constructor, every public
    method catches exceptions and reports failure as a dict of the form
    ``{'success': False, 'error': <message>}`` instead of raising.
    """

    # Longest allowed side, in pixels, of an image re-encoded by upload_file.
    MAX_UPLOAD_SIDE = 2000
    # JPEG quality used when upload_file re-encodes an image.
    UPLOAD_JPEG_QUALITY = 92

    def __init__(self, api_key: Optional[str] = None):
        """Create a client.

        Args:
            api_key: FAL API key. Falls back to the ``FAL_KEY`` environment
                variable when omitted or empty.

        Raises:
            ValueError: if no key is available from either source.
        """
        self.api_key = api_key or os.environ.get('FAL_KEY')
        if not self.api_key:
            raise ValueError("FAL API key is required")

        # Plain synchronous client.
        # Note: SyncClient does not accept a ``headers`` argument (unlike
        # AsyncClient), which is why the header is injected by patching httpx.
        self.client = fal_client.SyncClient(key=self.api_key)

    def generate_image(self, model_endpoint: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Submit a generation job to the FAL queue without waiting for it.

        Queue mode is used (no ``.get()`` on the handle) so the caller can poll
        ``get_status`` with the returned ``request_id`` instead of blocking.

        Returns:
            ``{'success': True, 'request_id': ..., 'status': 'submitted'}`` on
            success, or ``{'success': False, 'error': ..., 'request_id': ...}``
            on failure. On failure the request_id is a random UUID that does not
            correspond to any submitted request.
        """
        try:
            request_handle = self.client.submit(model_endpoint, arguments=arguments)
            return {
                'success': True,
                'request_id': request_handle.request_id,
                'status': 'submitted'
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e),
                'request_id': str(uuid.uuid4())
            }

    @staticmethod
    def _queue_position(status_obj: Any) -> Optional[int]:
        """Return the queue position from a fal ``Queued`` status, or None.

        Depending on the fal-client version the attribute may be called
        ``position`` or ``queue_position``. ``is None`` is checked explicitly
        because position 0 (front of the queue) is a valid, falsy value.
        """
        for attr in ("position", "queue_position"):
            value = getattr(status_obj, attr, None)
            if value is not None:
                return value
        return None

    def get_status(self, model_endpoint: str, request_id: str) -> Dict[str, Any]:
        """Poll the status of a queued request.

        Maps the fal status object type to a string:

        * ``Queued`` becomes "in_queue", plus ``queue_position`` when known.
        * ``InProgress`` becomes "in_progress", plus ``logs``.
        * ``Completed`` becomes "completed", plus ``logs`` and the fetched ``result``.
        * Anything else becomes "unknown".

        Returns:
            ``{'success': True, 'status': ..., 'logs': [...], ...}`` or
            ``{'success': False, 'error': ...}``.
        """
        try:
            base_model = self._get_base_model(model_endpoint)

            # Returns one of Queued / InProgress / Completed.
            st = self.client.status(base_model, request_id, with_logs=True)

            resp: Dict[str, Any] = {"success": True, "status": None, "logs": []}

            if isinstance(st, fal_client.Queued):
                resp["status"] = "in_queue"
                pos = self._queue_position(st)
                if pos is not None:
                    resp["queue_position"] = pos

            elif isinstance(st, fal_client.InProgress):
                resp["status"] = "in_progress"
                resp["logs"] = getattr(st, "logs", []) or []

            elif isinstance(st, fal_client.Completed):
                resp["status"] = "completed"
                resp["logs"] = getattr(st, "logs", []) or []
                # The status object does not carry the output; fetch it separately.
                resp["result"] = self.client.result(base_model, request_id)

            else:
                resp["status"] = "unknown"

            return resp

        except Exception as e:
            # Look the HTTP error class up lazily: naming a missing attribute in an
            # ``except`` clause would raise AttributeError and mask the real error
            # on fal-client versions that do not define FalClientHTTPError.
            http_error_cls = getattr(fal_client, "FalClientHTTPError", None)
            if http_error_cls is not None and isinstance(e, http_error_cls):
                return {"success": False, "error": f"HTTP {getattr(e, 'status_code', '?')}: {str(e)}"}
            return {"success": False, "error": str(e)}

    def _get_base_model(self, model_endpoint: str) -> str:
        """Strip one known sub-path suffix from a FAL endpoint for status queries.

        Example: ``fal-ai/bytedance/seedream/v4/edit`` -> ``fal-ai/bytedance/seedream/v4``.
        Only the first matching suffix is removed; unknown suffixes are kept.
        """
        subpaths = ['/edit', '/dev', '/image-to-image', '/text-to-image', '/api', '/realtime', '/playground']
        for subpath in subpaths:
            if model_endpoint.endswith(subpath):
                return model_endpoint[:-len(subpath)]
        return model_endpoint

    @staticmethod
    def _fit_within(width: int, height: int, max_side: int = 2000) -> tuple:
        """Return ``(width, height)`` scaled so the longest side is at most *max_side*.

        The aspect ratio is preserved. Sizes already within bounds are returned
        unchanged. Results are rounded (so the long side lands exactly on
        *max_side*) and clamped to at least 1 pixel, so very thin images never
        produce a zero-sized dimension.
        """
        longest = max(width, height)
        if longest <= max_side:
            return width, height
        scale = max_side / float(longest)
        return max(1, round(width * scale)), max(1, round(height * scale))

    def upload_file(self, file_data: str) -> Dict[str, Any]:
        """Upload an image to FAL storage and return its URL.

        Values that are not ``data:`` URLs are assumed to already be URLs and
        are returned unchanged. Data URLs are base64-decoded and re-encoded as
        JPEG before upload:

        * EXIF orientation is applied, then the image is converted to RGB.
        * The image is downscaled so its longest side is at most
          ``MAX_UPLOAD_SIDE`` px.
        * It is saved at ``UPLOAD_JPEG_QUALITY``.

        NOTE(review): the "<10MB" FAL limit is not checked explicitly here; it
        relies on the resize + JPEG re-encode keeping the file small.

        Returns:
            ``{'success': True, 'url': ...}`` or ``{'success': False, 'error': ...}``.
        """
        try:
            if not file_data.startswith('data:'):
                return {'success': True, 'url': file_data}

            _, base64_content = file_data.split(',', 1)
            image_bytes = base64.b64decode(base64_content)

            # PIL is already a dependency of this module (Image is imported at top).
            from PIL import ImageOps

            im = Image.open(io.BytesIO(image_bytes))
            # Re-encoding drops EXIF metadata, so bake the orientation tag into
            # the pixels first; otherwise rotated camera photos upload sideways.
            im = ImageOps.exif_transpose(im).convert("RGB")

            target_size = self._fit_within(im.width, im.height, self.MAX_UPLOAD_SIDE)
            if target_size != (im.width, im.height):
                im = im.resize(target_size, Image.LANCZOS)

            buf = io.BytesIO()
            im.save(buf, format="JPEG", quality=self.UPLOAD_JPEG_QUALITY, optimize=True)

            # The SDK uploads from a path, so stage the bytes in a temp file.
            # mkstemp + try/finally guarantees removal even if the write fails.
            fd, tmp_file_path = tempfile.mkstemp(suffix='.jpg')
            try:
                with os.fdopen(fd, 'wb') as tmp_file:
                    tmp_file.write(buf.getvalue())
                url = self.client.upload_file(tmp_file_path)
            finally:
                try:
                    os.unlink(tmp_file_path)
                except OSError:
                    pass  # best-effort cleanup; must not mask the upload outcome
            return {'success': True, 'url': url}

        except Exception as e:
            return {'success': False, 'error': str(e)}


def get_api_key_from_request(request) -> Optional[str]:
    """Extract the FAL API key for an incoming request.

    Uses the token from an ``Authorization: Bearer <token>`` header when it is
    present and non-empty. Otherwise falls back to the ``FAL_KEY`` environment
    variable, which yields ``None`` if that is unset as well.

    Args:
        request: object exposing a ``headers`` mapping with ``.get`` —
            presumably a Flask/Werkzeug request; confirm against callers.
    """
    prefix = 'Bearer '
    auth_header = request.headers.get('Authorization', '') or ''
    if auth_header.startswith(prefix):
        # Slice off only the leading scheme. ``str.replace`` would also delete
        # any later "Bearer " occurrence inside the token itself.
        token = auth_header[len(prefix):].strip()
        if token:
            return token
    return os.environ.get('FAL_KEY')


def validate_generation_request(data: Dict[str, Any]) -> Dict[str, Any]:
    """Validate the parameters of a generation request.

    Args:
        data: parsed request payload, which may be untrusted client input.

    Returns:
        ``{'valid': True}``, or ``{'valid': False, 'error': <message>}`` when the
        payload is empty, is not a dict, or has a missing/blank ``prompt``.
    """
    if not data:
        return {'valid': False, 'error': 'No data provided'}

    # A non-object payload (e.g. a list body) has no ``.get`` and would
    # otherwise crash the caller with AttributeError.
    if not isinstance(data, dict):
        return {'valid': False, 'error': 'Invalid request data: expected an object'}

    prompt = data.get('prompt')
    # Whitespace-only prompts are as unusable as empty ones.
    if not prompt or (isinstance(prompt, str) and not prompt.strip()):
        return {'valid': False, 'error': 'Prompt is required'}

    return {'valid': True}


def prepare_fal_arguments(data: Dict[str, Any], model_endpoint: str) -> Dict[str, Any]:
    """Build the FAL argument payload for an image or video model.

    The model family is detected from substrings of *model_endpoint*:

    * Video models: ``wan-25-preview`` or ``wan/v2.2``. Only the prompt, the
      shared keys and the WAN-specific keys below are forwarded.
    * Text-to-image: ``text-to-image``. The prompt plus image sizing keys.
    * Everything else is treated as image editing. It additionally forwards up
      to 10 ``image_urls`` and ``max_images``.

    Keys are copied only when present in *data*. Insertion order of the result
    is prompt, edit keys, image keys, seed, video keys, safety checker.
    """
    video = any(marker in model_endpoint for marker in ('wan-25-preview', 'wan/v2.2'))
    text_to_image = 'text-to-image' in model_endpoint

    payload: Dict[str, Any] = {'prompt': data.get('prompt')}

    def copy_present(keys):
        # Forward each listed key that the caller actually supplied.
        payload.update((key, data[key]) for key in keys if key in data)

    if not video:
        if not text_to_image:
            # Image-editing mode: reference images (capped at 10) and max_images.
            urls = data.get('image_urls', [])
            if urls:
                payload['image_urls'] = urls[:10]
            copy_present(('max_images',))
        # Image-only sizing parameters.
        copy_present(('image_size', 'num_images'))

    # Supported by both image and video models.
    copy_present(('seed',))

    if video:
        copy_present((
            'image_url',                     # I2V: single first frame
            'resolution',                    # e.g. "480p" | "580p" | "720p" | "1080p"
            'duration',                      # "5" | "10" (WAN 2.5)
            'frames_per_second',             # (WAN 2.2)
            'num_frames',                    # (WAN 2.2)
            'negative_prompt',
            'video_quality',                 # (WAN 2.2)
            'video_write_mode',              # (WAN 2.2)
            'acceleration',                  # (WAN 2.2)
            'guidance_scale',                # (WAN 2.2)
            'guidance_scale_2',              # (WAN 2.2)
            'interpolator_model',            # (WAN 2.2)
            'num_interpolated_frames',       # (WAN 2.2)
            'adjust_fps_for_interpolation',  # (WAN 2.2)
            'aspect_ratio',                  # (WAN 2.2)
            'end_image_url',                 # (WAN 2.2 optional)
            'audio_url',                     # (WAN 2.5 optional: audio track)
            'enable_prompt_expansion',       # both 2.5 and 2.2 (2.5 defaults to true)
        ))

    # Optional for every model type.
    copy_present(('enable_safety_checker',))

    return payload