| from __future__ import annotations |
|
|
| import io |
| from abc import ABC, abstractmethod |
| from typing import Any, Sequence |
|
|
| import requests |
| from PIL import Image |
|
|
| from nodes.utils import encode_image |
| from services.exceptions import GenerationError |
| from services.image.ImageGenerationResult import ImageGenerationResult |
| from services.image.ImageGenerator import ImageGenerator |
| from services.progress import ProgressCallback, call_progress |
| from services.registry import register_service |
| from services.utils.fal_service import run_fal |
|
|
|
|
class FalAIImageGenerator(ImageGenerator, ABC):
    """Abstract base class for fal.ai based image generators.

    This class implements the common pipeline:

    - select a fal model slug to call
    - build the input arguments
    - call fal.ai
    - download and decode the images

    Concrete subclasses are responsible for:

    - storing any model slugs they need
    - implementing the selection strategy
    - shaping the input arguments
    """

    service_id = "fal_ai_image"

    # Timeout (in seconds) handed to ``requests.get`` for every result image
    # download. Subclasses may override it.
    _DOWNLOAD_TIMEOUT_SECONDS: float = 30

    def __init__(
        self,
        *,
        service_model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Store the configuration shared by all fal.ai generators.

        Args:
            service_model_name: Forwarded to the parent constructor as
                ``model_name``.
            api_key: Passed to ``run_fal`` on every call; may be ``None``.
            extra_arguments: Default payload entries that every request
                starts from. ``None`` or an empty mapping means no defaults.
        """
        super().__init__(model_name=service_model_name)
        self._api_key = api_key
        # Copy so that later mutation of the caller's dict cannot silently
        # change the defaults used by this instance.
        self._extra_arguments: dict[str, Any] = (
            dict(extra_arguments) if extra_arguments else {}
        )

    def close(self) -> None:
        """Do nothing; this class holds no resources of its own."""
        return

    def _encode_images(self, images: Sequence[Image.Image]) -> list[str]:
        """Return ``encode_image(img)`` for each image, in input order."""
        return [encode_image(img) for img in images]

    def _attach_image_list_argument(
        self,
        arguments: dict[str, Any],
        images: Sequence[Image.Image],
        arg_name: str,
    ) -> None:
        """Store the encoded ``images`` in ``arguments[arg_name]`` in place.

        If ``arguments[arg_name]`` already holds a list, the encoded images
        are appended to a new list built from it (the existing list object is
        not mutated). Any other existing value is replaced.
        """
        encoded_images = self._encode_images(images)
        existing_value = arguments.get(arg_name)
        if isinstance(existing_value, list):
            arguments[arg_name] = existing_value + encoded_images
        else:
            arguments[arg_name] = encoded_images

    def _base_arguments(self, **kwargs: Any) -> dict[str, Any]:
        """Start from configured extra arguments and apply kwargs as overrides.

        Keyword arguments whose value is ``None`` are ignored, so they never
        erase a configured default. The returned dict is a fresh shallow copy.
        """
        arguments = dict(self._extra_arguments)
        arguments.update(
            (key, value) for key, value in kwargs.items() if value is not None
        )
        return arguments

    @abstractmethod
    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the fal model slug to call for this request."""
        raise NotImplementedError

    @abstractmethod
    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return the arguments payload for fal.ai."""
        raise NotImplementedError

    def _download_image(self, url: str) -> bytes:
        """Fetch ``url`` and return the response body.

        Raises:
            GenerationError: If the request fails or the status is not 200.
        """
        try:
            resp = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT_SECONDS)
        except requests.RequestException as exc:
            raise GenerationError(
                f"Failed to download image from fal.ai URL {url!r}."
            ) from exc

        if resp.status_code != 200:
            raise GenerationError(
                f"fal.ai image URL {url!r} returned status code {resp.status_code}."
            )
        return resp.content

    def _decode_image(self, data: bytes) -> Image.Image:
        """Decode raw image bytes into an RGBA ``PIL.Image.Image``.

        Raises:
            GenerationError: If the bytes cannot be decoded as an image.
        """
        try:
            # ``Image.open`` is lazy; ``convert`` forces the full decode, so
            # truncated or corrupt payloads fail inside this ``try`` block.
            with Image.open(io.BytesIO(data)) as source:
                return source.convert("RGBA")
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # DecompressionBombError is not an OSError; catch it explicitly
            # so callers only ever have to handle GenerationError.
            raise GenerationError(
                "Received invalid image data from fal.ai."
            ) from exc

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        aspect_ratio: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Run one request against fal.ai and return the decoded images.

        Args:
            prompt: Text prompt, forwarded to the subclass hooks.
            images: Optional input images, forwarded to the subclass hooks.
            progress: Optional callback, reported to through ``call_progress``
                at each stage.
            aspect_ratio: When not ``None``, merged into ``kwargs`` under the
                ``"aspect_ratio"`` key before the hooks are called.
            **kwargs: Extra options forwarded to ``_select_model`` and
                ``_build_arguments``.

        Returns:
            An ``ImageGenerationResult`` holding the RGBA images, the model
            slug that was called and the raw fal.ai response.

        Raises:
            GenerationError: If the response has no non-empty ``"images"``
                list, a download fails, image data is invalid, or no entry
                produced a decoded image.
        """
        if aspect_ratio is not None:
            kwargs = {**kwargs, "aspect_ratio": aspect_ratio}

        model = self._select_model(prompt=prompt, images=images, **kwargs)

        call_progress(progress, 0.1, "Encoding inputs for fal.ai image model")
        arguments = self._build_arguments(
            prompt=prompt,
            images=images,
            **kwargs,
        )

        call_progress(progress, 0.4, "Calling fal.ai image model")
        response = run_fal(
            model=model,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Downloading images from fal.ai")
        raw_images = response.get("images")
        if not isinstance(raw_images, list) or not raw_images:
            raise GenerationError(
                "fal.ai image model did not return any images in the response."
            )

        decoded_images: list[Image.Image] = []
        for item in raw_images:
            url = item.get("url") if isinstance(item, dict) else None
            if not isinstance(url, str) or not url:
                # Best effort: entries without a usable URL are skipped; the
                # request only fails if nothing at all could be decoded.
                continue
            decoded_images.append(self._decode_image(self._download_image(url)))

        if not decoded_images:
            raise GenerationError(
                "fal.ai image model did not yield any decodable images."
            )

        call_progress(progress, 0.95, "Preparing fal.ai image result")

        return ImageGenerationResult(
            provider="fal.ai",
            model=model,
            images=decoded_images,
            raw_response=response,
        )
|
|
|
|
| |
|
|
|
|
@register_service
class FalAINanoBananaGenerator(FalAIImageGenerator):
    """fal-ai/nano-banana generator covering text-to-image and editing.

    Requests without input images go to the text model; requests with at
    least one input image go to the edit model, with the encoded images sent
    under the ``image_urls`` argument.
    """

    service_id = "fal_ai_image_nano_banana"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the nano-banana model slugs and the image argument name."""
        super().__init__(
            service_model_name="fal-ai/nano-banana",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/nano-banana"
        self._edit_model: str = "fal-ai/nano-banana/edit"
        # Payload key under which encoded input images are sent.
        self._image_argument: str = "image_urls"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the edit slug when images are given, else the text slug."""
        return self._edit_model if images else self._text_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload: base arguments, the prompt, and any images."""
        payload = self._base_arguments(**kwargs)
        payload["prompt"] = prompt
        if images:
            self._attach_image_list_argument(
                payload, images, self._image_argument
            )
        return payload

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name reported for this service by default."""
        return "fal-ai/nano-banana"
|
|
|
|
@register_service
class FalAIReveGenerator(FalAIImageGenerator):
    """fal-ai/reve generator that picks an endpoint by input image count.

    - no images: text to image
    - exactly one image: fast edit
    - two or more images: fast remix
    """

    service_id = "fal_ai_image_reve"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the three reve model slugs used by this generator."""
        super().__init__(
            service_model_name="fal-ai/reve",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/reve/text-to-image"
        self._edit_model: str = "fal-ai/reve/fast/edit"
        self._remix_model: str = "fal-ai/reve/fast/remix"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the slug matching the number of input images."""
        image_count = 0 if images is None else len(images)
        # Zero and one image have dedicated endpoints; anything else remixes.
        by_count = {0: self._text_model, 1: self._edit_model}
        return by_count.get(image_count, self._remix_model)

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload, choosing the image key by image count.

        A single image is sent as one encoded value under ``image_url``;
        several images are sent as a list under ``image_urls``. Without
        images, only the base arguments and the prompt are returned.
        """
        payload = self._base_arguments(**kwargs)
        payload["prompt"] = prompt

        if images:
            if len(images) == 1:
                payload["image_url"] = encode_image(images[0])
            else:
                self._attach_image_list_argument(payload, images, "image_urls")

        return payload

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name reported for this service by default."""
        return "fal-ai/reve"
|
|