from __future__ import annotations

import io
from abc import ABC, abstractmethod
from typing import Any, Sequence

import requests
from PIL import Image

from nodes.utils import encode_image
from services.exceptions import GenerationError
from services.image.ImageGenerationResult import ImageGenerationResult
from services.image.ImageGenerator import ImageGenerator
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.utils.fal_service import run_fal


class FalAIImageGenerator(ImageGenerator, ABC):
    """Abstract base class for fal.ai based image generators.

    This class implements the common pipeline:
    - select a fal model slug to call
    - build the input arguments
    - call fal.ai
    - download and decode the images

    Concrete subclasses are responsible for:
    - storing any model slugs they need
    - implementing the selection strategy
    - shaping the input arguments
    """

    service_id = "fal_ai_image"

    def __init__(
        self,
        *,
        service_model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Store the fal.ai credentials and default request arguments.

        Args:
            service_model_name: Forwarded to ``ImageGenerator`` as
                ``model_name``; only used for metadata in GenerationService.
            api_key: Passed unchanged to ``run_fal`` on every request.
            extra_arguments: Default payload entries merged into every
                request; per-call kwargs override them (see
                ``_base_arguments``).
        """
        super().__init__(model_name=service_model_name)
        self._api_key = api_key
        self._extra_arguments: dict[str, Any] = extra_arguments or {}

    def close(self) -> None:
        """No-op: this class holds no resources that need releasing."""
        return

    # small helpers that do not know any model names

    def _encode_images(self, images: Sequence[Image.Image]) -> list[str]:
        """Encode each image with ``encode_image``, preserving order."""
        return [encode_image(img) for img in images]

    def _attach_image_list_argument(
        self,
        arguments: dict[str, Any],
        images: Sequence[Image.Image],
        arg_name: str,
    ) -> None:
        """Put the encoded *images* under ``arguments[arg_name]``.

        If that key already holds a list, the encoded images are appended to
        it; any other existing value is replaced.
        """
        encoded_images = self._encode_images(images)
        existing_value = arguments.get(arg_name)
        if isinstance(existing_value, list):
            # Concatenate into a new list rather than extend(): _base_arguments
            # only makes a shallow copy, so the existing list may be shared
            # with self._extra_arguments.
            arguments[arg_name] = existing_value + encoded_images
        else:
            arguments[arg_name] = encoded_images

    def _base_arguments(self, **kwargs: Any) -> dict[str, Any]:
        """Start from configured extra arguments and apply kwargs as overrides.

        Returns a shallow copy of the configured extra arguments updated with
        every kwarg whose value is not ``None`` (``None`` means "not set").
        """
        arguments = dict(self._extra_arguments)
        for key, value in kwargs.items():
            if value is not None:
                arguments[key] = value
        return arguments

    # abstract hooks for concrete implementations

    @abstractmethod
    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the fal model slug to call for this request."""
        raise NotImplementedError

    @abstractmethod
    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return the arguments payload for fal.ai."""
        raise NotImplementedError

    # shared generation pipeline

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        aspect_ratio: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Run one fal.ai image request and return the decoded images.

        Args:
            prompt: Text prompt placed in the request payload by the subclass.
            images: Optional input images; subclasses use them to choose the
                model slug and to shape the payload.
            progress: Optional callback reported at fixed milestones.
            aspect_ratio: Forwarded to the subclass hooks as a kwarg when set.
            **kwargs: Forwarded to ``_select_model`` and ``_build_arguments``.

        Returns:
            An ``ImageGenerationResult`` holding every image that was
            downloaded and decoded (converted to RGBA), plus the raw response.

        Raises:
            GenerationError: If the response lists no images, a download
                fails or returns a non-200 status, image bytes cannot be
                decoded, or no usable image URL was present.
        """
        # Only forward aspect_ratio when the caller actually set it.
        if aspect_ratio is not None:
            kwargs = {**kwargs, "aspect_ratio": aspect_ratio}

        model = self._select_model(prompt=prompt, images=images, **kwargs)

        call_progress(progress, 0.1, "Encoding inputs for fal.ai image model")
        arguments = self._build_arguments(
            prompt=prompt,
            images=images,
            **kwargs,
        )

        call_progress(progress, 0.4, "Calling fal.ai image model")
        response = run_fal(
            model=model,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Downloading images from fal.ai")
        decoded_images = [
            self._download_image(url) for url in self._extract_image_urls(response)
        ]
        if not decoded_images:
            raise GenerationError(
                "fal.ai image model did not yield any decodable images."
            )

        call_progress(progress, 0.95, "Preparing fal.ai image result")
        return ImageGenerationResult(
            provider="fal.ai",
            model=model,
            images=decoded_images,
            raw_response=response,
        )

    @staticmethod
    def _extract_image_urls(response: Any) -> list[str]:
        """Return the non-empty string ``url`` values from ``response["images"]``.

        Entries that are not dicts, or whose ``url`` is missing, empty, or not
        a string, are skipped.

        Raises:
            GenerationError: If the response has no non-empty ``images`` list,
                including when the response does not support ``.get``
                (e.g. ``None``).
        """
        try:
            raw_images = response.get("images")
        except AttributeError:
            # A response without .get (e.g. None) is reported like any other
            # response that lacks images, instead of leaking AttributeError.
            raw_images = None
        if not isinstance(raw_images, list) or not raw_images:
            raise GenerationError(
                "fal.ai image model did not return any images in the response."
            )

        urls: list[str] = []
        for item in raw_images:
            url = item.get("url") if isinstance(item, dict) else None
            if isinstance(url, str) and url:
                urls.append(url)
        return urls

    @staticmethod
    def _download_image(url: str) -> Image.Image:
        """Download *url* and decode the body into an RGBA PIL image.

        Raises:
            GenerationError: On a request failure, a non-200 status code, or
                bytes that Pillow cannot decode.
        """
        try:
            resp = requests.get(url, timeout=30)
        except requests.RequestException as exc:
            raise GenerationError(
                f"Failed to download image from fal.ai URL {url!r}."
            ) from exc

        if resp.status_code != 200:
            raise GenerationError(
                f"fal.ai image URL {url!r} returned status code {resp.status_code}."
            )

        try:
            # Image.open is lazy; convert() forces the decode and returns an
            # independent copy, so the source image can be closed immediately.
            with Image.open(io.BytesIO(resp.content)) as source:
                return source.convert("RGBA")
        except OSError as exc:
            raise GenerationError(
                "Received invalid image data from fal.ai."
            ) from exc


# concrete model combinations


@register_service
class FalAINanoBananaGenerator(FalAIImageGenerator):
    """fal-ai/nano-banana text and edit combination."""

    service_id = "fal_ai_image_nano_banana"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the text-to-image and edit slugs for nano-banana."""
        super().__init__(
            service_model_name="fal-ai/nano-banana",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/nano-banana"
        self._edit_model: str = "fal-ai/nano-banana/edit"
        # nano banana edit expects images under "image_urls"
        self._image_argument: str = "image_urls"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Use the edit slug when any input image is given, else text-to-image."""
        if not images:
            return self._text_model
        return self._edit_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload: prompt, kwargs overrides and, if present, images."""
        arguments = self._base_arguments(**kwargs)
        arguments["prompt"] = prompt

        if not images:
            return arguments

        self._attach_image_list_argument(arguments, images, self._image_argument)
        return arguments

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name reported for this service."""
        return "fal-ai/nano-banana"


@register_service
class FalAIReveGenerator(FalAIImageGenerator):
    """fal-ai/reve combination:
    - text to image
    - fast edit for one image
    - fast remix for multiple images
    """

    service_id = "fal_ai_image_reve"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the text-to-image, fast edit and fast remix slugs for reve."""
        super().__init__(
            service_model_name="fal-ai/reve",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/reve/text-to-image"
        self._edit_model: str = "fal-ai/reve/fast/edit"
        self._remix_model: str = "fal-ai/reve/fast/remix"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Pick text-to-image (0 images), fast edit (1) or fast remix (2+)."""
        count = len(images) if images is not None else 0
        if count == 0:
            return self._text_model
        if count == 1:
            return self._edit_model
        return self._remix_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload in the shape the selected reve endpoint expects."""
        arguments = self._base_arguments(**kwargs)
        arguments["prompt"] = prompt

        if not images:
            # text to image ignores image inputs
            return arguments

        if len(images) == 1:
            # fast edit expects a single "image_url"
            arguments["image_url"] = encode_image(images[0])
            return arguments

        # fast remix expects "image_urls" as a list
        self._attach_image_list_argument(arguments, images, "image_urls")
        return arguments

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name reported for this service."""
        return "fal-ai/reve"