File size: 9,059 Bytes
3025bb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abd08cb
3025bb3
 
abd08cb
 
 
 
3025bb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
from __future__ import annotations

import io
from abc import ABC, abstractmethod
from typing import Any, Sequence

import requests
from PIL import Image

from nodes.utils import encode_image
from services.exceptions import GenerationError
from services.image.ImageGenerationResult import ImageGenerationResult
from services.image.ImageGenerator import ImageGenerator
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.utils.fal_service import run_fal


class FalAIImageGenerator(ImageGenerator, ABC):
    """Abstract base class for fal.ai based image generators.

    This class implements the common pipeline:

    - select a fal model slug to call
    - build the input arguments
    - call fal.ai
    - download and decode the images

    Concrete subclasses are responsible for:

    - storing any model slugs they need
    - implementing the selection strategy
    - shaping the input arguments
    """

    service_id = "fal_ai_image"

    # Seconds to wait on each result-image download; subclasses may override.
    _DOWNLOAD_TIMEOUT_SECONDS = 30

    def __init__(
            self,
            *,
            service_model_name: str | None = None,
            api_key: str | None = None,
            extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Store the credentials and default arguments used for every request.

        Args:
            service_model_name: Forwarded to the parent constructor as
                ``model_name``; only used as metadata.
            api_key: Passed to ``run_fal`` on every call.
            extra_arguments: Default payload entries that every request
                starts from. ``None`` means no defaults.
        """
        # service_model_name is only for metadata in GenerationService
        super().__init__(model_name=service_model_name)
        self._api_key = api_key
        # Copy so that later mutation of the caller's dict cannot silently
        # change the payload of subsequent requests.
        self._extra_arguments: dict[str, Any] = dict(extra_arguments or {})

    def close(self) -> None:
        """Do nothing; this class keeps no state that needs releasing."""
        return

    # small helpers that do not know any model names

    def _encode_images(self, images: Sequence[Image.Image]) -> list[str]:
        """Return ``encode_image(img)`` for every image, preserving order."""
        return [encode_image(img) for img in images]

    def _attach_image_list_argument(
            self,
            arguments: dict[str, Any],
            images: Sequence[Image.Image],
            arg_name: str,
    ) -> None:
        """Encode ``images`` and store them in ``arguments[arg_name]``.

        If ``arguments[arg_name]`` already holds a list, the encoded images
        are appended to a new list (the existing list object is not mutated).
        Any other existing value is replaced.
        """
        encoded_images = self._encode_images(images)
        existing_value = arguments.get(arg_name)
        if isinstance(existing_value, list):
            arguments[arg_name] = existing_value + encoded_images
        else:
            arguments[arg_name] = encoded_images

    def _base_arguments(self, **kwargs: Any) -> dict[str, Any]:
        """Start from configured extra arguments and apply kwargs as overrides.

        Keyword arguments whose value is ``None`` are ignored, so they never
        mask a configured default. The returned dict is a fresh shallow copy.
        """
        arguments = dict(self._extra_arguments)
        arguments.update(
            (key, value) for key, value in kwargs.items() if value is not None
        )
        return arguments

    def _download_image(self, url: str) -> Image.Image:
        """Download one result image and decode it to an RGBA ``Image``.

        Args:
            url: Location of the image, taken from the fal.ai response.

        Returns:
            The decoded image converted to ``"RGBA"`` mode.

        Raises:
            GenerationError: If the request fails, the status code is not
                200, or the payload cannot be decoded as an image.
        """
        try:
            resp = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT_SECONDS)
        except requests.RequestException as exc:
            raise GenerationError(
                f"Failed to download image from fal.ai URL {url!r}."
            ) from exc

        if resp.status_code != 200:
            raise GenerationError(
                f"fal.ai image URL {url!r} returned status code {resp.status_code}."
            )

        try:
            # convert() returns an independent copy, so the source image can
            # be released as soon as the conversion is done.
            with Image.open(io.BytesIO(resp.content)) as source:
                return source.convert("RGBA")
        except (OSError, Image.DecompressionBombError) as exc:
            # DecompressionBombError is not an OSError; without catching it
            # an oversized payload would escape as a non-GenerationError.
            raise GenerationError(
                "Received invalid image data from fal.ai."
            ) from exc

    # abstract hooks for concrete implementations

    @abstractmethod
    def _select_model(
            self,
            *,
            prompt: str,
            images: Sequence[Image.Image] | None,
            **kwargs: Any,
    ) -> str:
        """Return the fal model slug to call for this request."""
        raise NotImplementedError

    @abstractmethod
    def _build_arguments(
            self,
            *,
            prompt: str,
            images: Sequence[Image.Image] | None,
            **kwargs: Any,
    ) -> dict[str, Any]:
        """Return the arguments payload for fal.ai."""
        raise NotImplementedError

    # shared generation pipeline

    def generate(
            self,
            prompt: str,
            images: Sequence[Image.Image] | None = None,
            *,
            progress: ProgressCallback | None = None,
            aspect_ratio: str | None = None,
            **kwargs: Any,
    ) -> ImageGenerationResult:
        """Run the fal.ai pipeline and return the decoded result images.

        Args:
            prompt: Text prompt handed to the subclass hooks.
            images: Optional input images handed to the subclass hooks.
            progress: Optional callback, reported to via ``call_progress``
                before each pipeline stage.
            aspect_ratio: When not ``None``, forwarded to the hooks as the
                ``aspect_ratio`` keyword argument.
            **kwargs: Extra keyword arguments forwarded to both hooks.

        Returns:
            An ``ImageGenerationResult`` holding the selected model slug, the
            decoded RGBA images and the raw fal.ai response.

        Raises:
            GenerationError: If the response has no non-empty ``"images"``
                list, no entry yields a decodable image, or any download or
                decode step fails.
        """
        # Include aspect_ratio in kwargs if provided
        if aspect_ratio is not None:
            kwargs = {**kwargs, "aspect_ratio": aspect_ratio}

        model = self._select_model(prompt=prompt, images=images, **kwargs)

        call_progress(progress, 0.1, "Encoding inputs for fal.ai image model")
        arguments = self._build_arguments(
            prompt=prompt,
            images=images,
            **kwargs,
        )

        call_progress(progress, 0.4, "Calling fal.ai image model")
        response = run_fal(
            model=model,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Downloading images from fal.ai")
        raw_images = response.get("images")
        if not isinstance(raw_images, list) or not raw_images:
            raise GenerationError(
                "fal.ai image model did not return any images in the response."
            )

        decoded_images: list[Image.Image] = []
        for item in raw_images:
            url = item.get("url") if isinstance(item, dict) else None
            if not isinstance(url, str) or not url:
                # Best effort: entries without a usable URL are skipped; the
                # emptiness check below reports the case where none remain.
                continue
            decoded_images.append(self._download_image(url))

        if not decoded_images:
            raise GenerationError(
                "fal.ai image model did not yield any decodable images."
            )

        call_progress(progress, 0.95, "Preparing fal.ai image result")

        return ImageGenerationResult(
            provider="fal.ai",
            model=model,
            images=decoded_images,
            raw_response=response,
        )


# concrete model combinations


@register_service
class FalAINanoBananaGenerator(FalAIImageGenerator):
    """Generator pairing fal-ai/nano-banana with its edit variant.

    Requests without input images go to the text model; requests with at
    least one input image go to the edit model.
    """

    service_id = "fal_ai_image_nano_banana"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the text and edit slugs; no default extra arguments."""
        super().__init__(
            service_model_name="fal-ai/nano-banana",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/nano-banana"
        self._edit_model: str = "fal-ai/nano-banana/edit"
        # nano banana edit expects images under "image_urls"
        self._image_argument: str = "image_urls"

    def _select_model(
            self,
            *,
            prompt: str,
            images: Sequence[Image.Image] | None,
            **kwargs: Any,
    ) -> str:
        """Return the edit slug when images are supplied, else the text slug."""
        return self._edit_model if images else self._text_model

    def _build_arguments(
            self,
            *,
            prompt: str,
            images: Sequence[Image.Image] | None,
            **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload: base arguments, the prompt, and any images."""
        payload = self._base_arguments(**kwargs)
        payload["prompt"] = prompt

        # Only edit requests carry images; text requests stop at the prompt.
        if images:
            self._attach_image_list_argument(payload, images, self._image_argument)

        return payload

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name this service reports by default."""
        return "fal-ai/nano-banana"


@register_service
class FalAIReveGenerator(FalAIImageGenerator):
    """Generator covering three fal-ai/reve endpoints.

    The endpoint is chosen by the number of input images:

    - none: text to image
    - exactly one: fast edit
    - two or more: fast remix
    """

    service_id = "fal_ai_image_reve"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the three endpoint slugs; no default extra arguments."""
        super().__init__(
            service_model_name="fal-ai/reve",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/reve/text-to-image"
        self._edit_model: str = "fal-ai/reve/fast/edit"
        self._remix_model: str = "fal-ai/reve/fast/remix"

    def _select_model(
            self,
            *,
            prompt: str,
            images: Sequence[Image.Image] | None,
            **kwargs: Any,
    ) -> str:
        """Return the slug matching the number of supplied images."""
        image_count = 0 if images is None else len(images)

        if image_count > 1:
            return self._remix_model
        return self._edit_model if image_count == 1 else self._text_model

    def _build_arguments(
            self,
            *,
            prompt: str,
            images: Sequence[Image.Image] | None,
            **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload, shaping the image field per endpoint."""
        payload = self._base_arguments(**kwargs)
        payload["prompt"] = prompt

        # With no images this is a text-to-image request: prompt only.
        if images:
            if len(images) == 1:
                # fast edit expects a single "image_url"
                payload["image_url"] = encode_image(images[0])
            else:
                # fast remix expects "image_urls" as a list
                self._attach_image_list_argument(payload, images, "image_urls")

        return payload

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name this service reports by default."""
        return "fal-ai/reve"