# velai/services/image/FalAIImageGenerator.py
# Author: cansik. "Upload folder via script" (commit abd08cb, verified).
from __future__ import annotations

import base64
import io
from abc import ABC, abstractmethod
from typing import Any, Sequence

import requests
from PIL import Image

from nodes.utils import encode_image
from services.exceptions import GenerationError
from services.image.ImageGenerationResult import ImageGenerationResult
from services.image.ImageGenerator import ImageGenerator
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.utils.fal_service import run_fal
class FalAIImageGenerator(ImageGenerator, ABC):
    """Abstract base class for fal.ai based image generators.

    This class implements the common pipeline:

    - select a fal model slug to call
    - build the input arguments
    - call fal.ai
    - download and decode the images

    Concrete subclasses are responsible for:

    - storing any model slugs they need
    - implementing the selection strategy
    - shaping the input arguments
    """

    service_id = "fal_ai_image"

    # Timeout in seconds for fetching each generated image from its URL.
    # Subclasses may override this class attribute.
    _download_timeout: float = 30

    def __init__(
        self,
        *,
        service_model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Store credentials and the default request arguments.

        Args:
            service_model_name: Model name passed to the ``ImageGenerator``
                base class as ``model_name``.
            api_key: Forwarded as-is to ``run_fal`` on every request.
            extra_arguments: Default payload entries included in every
                request. Per-call keyword arguments that are not ``None``
                override them. The mapping is copied, so later changes by
                the caller do not affect this generator.
        """
        # service_model_name is only for metadata in GenerationService
        super().__init__(model_name=service_model_name)
        self._api_key = api_key
        # Copy so that a caller mutating its dict later cannot silently
        # change this generator's defaults.
        self._extra_arguments: dict[str, Any] = dict(extra_arguments or {})

    def close(self) -> None:
        """No-op: no connection or handle is kept between ``generate`` calls."""
        return

    # small helpers that do not know any model names

    def _encode_images(self, images: Sequence[Image.Image]) -> list[str]:
        """Encode each PIL image with ``encode_image``, preserving order."""
        return [encode_image(img) for img in images]

    def _attach_image_list_argument(
        self,
        arguments: dict[str, Any],
        images: Sequence[Image.Image],
        arg_name: str,
    ) -> None:
        """Store the encoded ``images`` under ``arguments[arg_name]``.

        If ``arguments`` already holds a list under ``arg_name`` (for example
        from the configured extra arguments), the encoded images are appended
        to a new list. Otherwise any existing value is replaced.
        """
        encoded_images = self._encode_images(images)
        existing_value = arguments.get(arg_name)
        if isinstance(existing_value, list):
            # Concatenation builds a new list, so the shared defaults are
            # never mutated.
            arguments[arg_name] = existing_value + encoded_images
        else:
            arguments[arg_name] = encoded_images

    def _base_arguments(self, **kwargs: Any) -> dict[str, Any]:
        """Start from configured extra arguments and apply kwargs as overrides.

        Keyword arguments whose value is ``None`` are skipped, so they never
        erase a configured default.
        """
        arguments = dict(self._extra_arguments)
        for key, value in kwargs.items():
            if value is not None:
                arguments[key] = value
        return arguments

    # abstract hooks for concrete implementations

    @abstractmethod
    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the fal model slug to call for this request."""
        raise NotImplementedError

    @abstractmethod
    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return the arguments payload for fal.ai."""
        raise NotImplementedError

    # shared generation pipeline

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        aspect_ratio: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Run one fal.ai image request and return the decoded images.

        Args:
            prompt: Text prompt forwarded to the subclass hooks.
            images: Optional input images. Subclasses use them to pick the
                model and to shape the payload.
            progress: Optional callback that receives coarse progress updates.
            aspect_ratio: When given, added to ``kwargs`` as ``aspect_ratio``.
            **kwargs: Extra request options forwarded to the subclass hooks.

        Returns:
            An ``ImageGenerationResult`` holding the decoded RGBA images and
            the raw fal.ai response.

        Raises:
            GenerationError: If the response contains no images, or if a
                referenced image cannot be downloaded or decoded. Errors
                raised by ``run_fal`` itself propagate unchanged.
        """
        # Include aspect_ratio in kwargs if provided
        if aspect_ratio is not None:
            kwargs = {**kwargs, "aspect_ratio": aspect_ratio}

        model = self._select_model(prompt=prompt, images=images, **kwargs)

        call_progress(progress, 0.1, "Encoding inputs for fal.ai image model")
        arguments = self._build_arguments(
            prompt=prompt,
            images=images,
            **kwargs,
        )

        call_progress(progress, 0.4, "Calling fal.ai image model")
        response = run_fal(
            model=model,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Downloading images from fal.ai")
        decoded_images = self._decode_response_images(response)

        call_progress(progress, 0.95, "Preparing fal.ai image result")
        return ImageGenerationResult(
            provider="fal.ai",
            model=model,
            images=decoded_images,
            raw_response=response,
        )

    def _decode_response_images(self, response: Any) -> list[Image.Image]:
        """Download and decode every image entry referenced in ``response``.

        Entries that are not dicts, or that lack a non-empty string ``url``,
        are skipped. A failure on any usable entry aborts the whole call.
        """
        try:
            raw_images = response.get("images")
        except AttributeError:
            # A response without a mapping interface cannot contain images.
            raw_images = None
        if not isinstance(raw_images, list) or not raw_images:
            raise GenerationError(
                "fal.ai image model did not return any images in the response."
            )

        decoded_images: list[Image.Image] = []
        for item in raw_images:
            if not isinstance(item, dict):
                continue
            url = item.get("url")
            if not isinstance(url, str) or not url:
                continue
            decoded_images.append(self._load_image(url))

        if not decoded_images:
            raise GenerationError(
                "fal.ai image model did not yield any decodable images."
            )
        return decoded_images

    def _load_image(self, url: str) -> Image.Image:
        """Fetch the image at ``url`` and decode it as an RGBA PIL image."""
        data = self._fetch_image_bytes(url)
        try:
            with Image.open(io.BytesIO(data)) as source:
                # convert() decodes the pixels and returns an independent
                # in-memory image, so the source can be closed right away.
                return source.convert("RGBA")
        except (OSError, Image.DecompressionBombError) as exc:
            # DecompressionBombError is not an OSError, so it is listed
            # explicitly to keep all decode failures as GenerationError.
            raise GenerationError(
                "Received invalid image data from fal.ai."
            ) from exc

    def _fetch_image_bytes(self, url: str) -> bytes:
        """Return the raw bytes behind ``url``.

        Base64 ``data:`` URIs are decoded locally. These can appear when
        fal.ai is asked to inline results, and ``requests`` cannot fetch
        them. Every other URL is downloaded over HTTP.
        """
        if url.startswith("data:"):
            header, _, payload = url.partition(",")
            if header.lower().endswith(";base64"):
                try:
                    return base64.b64decode(payload)
                except ValueError as exc:  # binascii.Error subclasses ValueError
                    raise GenerationError(
                        "Received invalid base64 image data from fal.ai."
                    ) from exc

        try:
            resp = requests.get(url, timeout=self._download_timeout)
        except requests.RequestException as exc:
            raise GenerationError(
                f"Failed to download image from fal.ai URL {url!r}."
            ) from exc
        if resp.status_code != 200:
            raise GenerationError(
                f"fal.ai image URL {url!r} returned status code {resp.status_code}."
            )
        return resp.content
# concrete model combinations
@register_service
class FalAINanoBananaGenerator(FalAIImageGenerator):
    """fal-ai/nano-banana text and edit combination.

    Requests without input images call ``fal-ai/nano-banana``. Requests with
    one or more images call ``fal-ai/nano-banana/edit`` and send the encoded
    images as the ``image_urls`` list.
    """

    service_id = "fal_ai_image_nano_banana"

    def __init__(
        self,
        api_key: str | None = None,
        *,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Configure the nano-banana model pair.

        Args:
            api_key: fal.ai API key forwarded to the base class.
            extra_arguments: Optional default payload entries sent with every
                request. Defaults to none, which matches the previous
                behavior.
        """
        super().__init__(
            service_model_name="fal-ai/nano-banana",
            api_key=api_key,
            extra_arguments=extra_arguments if extra_arguments is not None else {},
        )
        self._text_model: str = "fal-ai/nano-banana"
        self._edit_model: str = "fal-ai/nano-banana/edit"
        # nano banana edit expects images under "image_urls"
        self._image_argument: str = "image_urls"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Use the text model without images and the edit model otherwise."""
        if not images:
            return self._text_model
        return self._edit_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload: defaults + kwargs, the prompt, and any images."""
        arguments = self._base_arguments(**kwargs)
        arguments["prompt"] = prompt
        if not images:
            return arguments
        self._attach_image_list_argument(arguments, images, self._image_argument)
        return arguments

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name reported for this service."""
        return "fal-ai/nano-banana"
@register_service
class FalAIReveGenerator(FalAIImageGenerator):
    """fal-ai/reve combination:

    - text to image (no input images)
    - fast edit for one image (sent as ``image_url``)
    - fast remix for multiple images (sent as the ``image_urls`` list)
    """

    service_id = "fal_ai_image_reve"

    def __init__(
        self,
        api_key: str | None = None,
        *,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Configure the reve text, edit and remix models.

        Args:
            api_key: fal.ai API key forwarded to the base class.
            extra_arguments: Optional default payload entries sent with every
                request. Defaults to none, which matches the previous
                behavior.
        """
        super().__init__(
            service_model_name="fal-ai/reve",
            api_key=api_key,
            extra_arguments=extra_arguments if extra_arguments is not None else {},
        )
        self._text_model: str = "fal-ai/reve/text-to-image"
        self._edit_model: str = "fal-ai/reve/fast/edit"
        self._remix_model: str = "fal-ai/reve/fast/remix"
        # fast edit expects a single "image_url"; fast remix expects a list
        # under "image_urls"
        self._single_image_argument: str = "image_url"
        self._multi_image_argument: str = "image_urls"

    @staticmethod
    def _image_count(images: Sequence[Image.Image] | None) -> int:
        """Return the number of input images, treating ``None`` as zero."""
        return len(images) if images is not None else 0

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Pick text, edit or remix depending on how many images are given."""
        count = self._image_count(images)
        if count == 0:
            return self._text_model
        if count == 1:
            return self._edit_model
        return self._remix_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the payload that matches the model chosen by ``_select_model``."""
        arguments = self._base_arguments(**kwargs)
        arguments["prompt"] = prompt
        count = self._image_count(images)
        if count == 0:
            # text to image ignores image inputs
            return arguments
        if count == 1:
            # fast edit expects a single encoded image, not a list
            arguments[self._single_image_argument] = self._encode_images(images)[0]
            return arguments
        # fast remix expects "image_urls" as a list
        self._attach_image_list_argument(
            arguments, images, self._multi_image_argument
        )
        return arguments

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name reported for this service."""
        return "fal-ai/reve"