File size: 9,059 Bytes
3025bb3 abd08cb 3025bb3 abd08cb 3025bb3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 | from __future__ import annotations
import io
from abc import ABC, abstractmethod
from typing import Any, Sequence
import requests
from PIL import Image
from nodes.utils import encode_image
from services.exceptions import GenerationError
from services.image.ImageGenerationResult import ImageGenerationResult
from services.image.ImageGenerator import ImageGenerator
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.utils.fal_service import run_fal
class FalAIImageGenerator(ImageGenerator, ABC):
    """Abstract base class for fal.ai based image generators.

    This class implements the common pipeline:

    - select a fal model slug to call
    - build the input arguments
    - call fal.ai
    - download and decode the images

    Concrete subclasses are responsible for:

    - storing any model slugs they need
    - implementing the selection strategy
    - shaping the input arguments
    """

    service_id = "fal_ai_image"

    # Timeout (seconds) passed to ``requests.get`` for every result download.
    _DOWNLOAD_TIMEOUT_SECONDS = 30

    def __init__(
        self,
        *,
        service_model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """Remember the API key and the default fal.ai arguments.

        Args:
            service_model_name: Forwarded to the parent constructor as
                ``model_name``.
            api_key: Passed to ``run_fal`` on every call; may be ``None``.
            extra_arguments: Defaults merged into every request payload.
                ``None`` (or an empty dict) means no defaults.
        """
        # service_model_name is only for metadata in GenerationService
        super().__init__(model_name=service_model_name)
        self._api_key = api_key
        self._extra_arguments: dict[str, Any] = extra_arguments or {}

    def close(self) -> None:
        """Do nothing; this base class has no cleanup step."""
        return

    # small helpers that do not know any model names
    def _encode_images(self, images: Sequence[Image.Image]) -> list[str]:
        """Run ``encode_image`` over every image, preserving order."""
        return [encode_image(img) for img in images]

    def _attach_image_list_argument(
        self,
        arguments: dict[str, Any],
        images: Sequence[Image.Image],
        arg_name: str,
    ) -> None:
        """Store the encoded ``images`` in ``arguments[arg_name]``.

        When ``arguments[arg_name]`` already holds a list, the encoded images
        are appended to a *new* list (the existing list object is not
        mutated); any other existing value is replaced.
        """
        encoded_images = self._encode_images(images)
        existing_value = arguments.get(arg_name)
        if isinstance(existing_value, list):
            arguments[arg_name] = existing_value + encoded_images
        else:
            arguments[arg_name] = encoded_images

    def _base_arguments(self, **kwargs: Any) -> dict[str, Any]:
        """Start from configured extra arguments and apply kwargs as overrides.

        Keyword arguments whose value is ``None`` are ignored, so they never
        overwrite a configured default.
        """
        # Shallow copy so per-request overrides never leak into the defaults.
        arguments = dict(self._extra_arguments)
        for key, value in kwargs.items():
            if value is not None:
                arguments[key] = value
        return arguments

    def _download_image(self, url: str) -> Image.Image:
        """Fetch ``url`` and decode the body into an RGBA image.

        Raises:
            GenerationError: If the request fails, the status code is not
                200, or the body cannot be decoded as an image.
        """
        try:
            resp = requests.get(url, timeout=self._DOWNLOAD_TIMEOUT_SECONDS)
        except requests.RequestException as exc:
            raise GenerationError(
                f"Failed to download image from fal.ai URL {url!r}."
            ) from exc
        if resp.status_code != 200:
            raise GenerationError(
                f"fal.ai image URL {url!r} returned status code {resp.status_code}."
            )
        try:
            img = Image.open(io.BytesIO(resp.content)).convert("RGBA")
            # Image.open is lazy; force the pixel data to be fully decoded
            # while decode errors are still being translated.
            img.load()
        except (OSError, Image.DecompressionBombError) as exc:
            # DecompressionBombError does not derive from OSError, so it has
            # to be listed explicitly to be reported as a GenerationError.
            raise GenerationError(
                "Received invalid image data from fal.ai."
            ) from exc
        return img

    # abstract hooks for concrete implementations
    @abstractmethod
    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the fal model slug to call for this request."""
        raise NotImplementedError

    @abstractmethod
    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return the arguments payload for fal.ai."""
        raise NotImplementedError

    # shared generation pipeline
    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        aspect_ratio: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Run one fal.ai request and return the decoded images.

        Args:
            prompt: Text prompt handed to the subclass hooks.
            images: Optional input images handed to the subclass hooks.
            progress: Optional callback; invoked through ``call_progress``
                at 0.1, 0.4, 0.7 and 0.95.
            aspect_ratio: When not ``None`` it is added to ``kwargs`` under
                the key ``"aspect_ratio"``.
            **kwargs: Passed unchanged to ``_select_model`` and
                ``_build_arguments``.

        Returns:
            An ``ImageGenerationResult`` holding the RGBA images, the model
            slug that was called and the raw fal.ai response.

        Raises:
            GenerationError: If the response has no non-empty ``"images"``
                list, a download or decode fails, or no entry yields an
                image. Exceptions raised by ``run_fal`` are not handled here.
        """
        # Include aspect_ratio in kwargs if provided
        if aspect_ratio is not None:
            kwargs = {**kwargs, "aspect_ratio": aspect_ratio}

        model = self._select_model(prompt=prompt, images=images, **kwargs)

        call_progress(progress, 0.1, "Encoding inputs for fal.ai image model")
        arguments = self._build_arguments(
            prompt=prompt,
            images=images,
            **kwargs,
        )

        call_progress(progress, 0.4, "Calling fal.ai image model")
        response = run_fal(
            model=model,
            arguments=arguments,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Downloading images from fal.ai")
        raw_images = response.get("images")
        if not isinstance(raw_images, list) or not raw_images:
            raise GenerationError(
                "fal.ai image model did not return any images in the response."
            )

        decoded_images: list[Image.Image] = []
        for item in raw_images:
            # Entries without a usable URL are skipped on purpose; the check
            # after the loop reports the case where nothing was usable.
            if not isinstance(item, dict):
                continue
            url = item.get("url")
            if not isinstance(url, str) or not url:
                continue
            decoded_images.append(self._download_image(url))

        if not decoded_images:
            raise GenerationError(
                "fal.ai image model did not yield any decodable images."
            )

        call_progress(progress, 0.95, "Preparing fal.ai image result")
        return ImageGenerationResult(
            provider="fal.ai",
            model=model,
            images=decoded_images,
            raw_response=response,
        )
# concrete model combinations
@register_service
class FalAINanoBananaGenerator(FalAIImageGenerator):
    """Generator pairing fal-ai/nano-banana (text) with its edit endpoint."""

    service_id = "fal_ai_image_nano_banana"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the two model slugs and the image argument name.

        Args:
            api_key: Forwarded to the base class; may be ``None``.
        """
        super().__init__(
            service_model_name="fal-ai/nano-banana",
            api_key=api_key,
            extra_arguments={},
        )
        # Slug used when the request carries no input images.
        self._text_model: str = "fal-ai/nano-banana"
        # Slug used when at least one input image is supplied.
        self._edit_model: str = "fal-ai/nano-banana/edit"
        # nano banana edit expects images under "image_urls"
        self._image_argument: str = "image_urls"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Pick the edit slug when images are given, else the text slug."""
        return self._edit_model if images else self._text_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the payload: base arguments, prompt and optional images."""
        payload = self._base_arguments(**kwargs)
        payload["prompt"] = prompt
        if images:
            # Only edit requests carry the encoded image list.
            self._attach_image_list_argument(
                payload, images, self._image_argument
            )
        return payload

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name this service reports by default."""
        return "fal-ai/nano-banana"
@register_service
class FalAIReveGenerator(FalAIImageGenerator):
    """fal-ai/reve combination.

    The endpoint is chosen by the number of input images:

    - none: text to image
    - exactly one: fast edit
    - two or more: fast remix
    """

    service_id = "fal_ai_image_reve"

    def __init__(self, api_key: str | None = None) -> None:
        """Configure the three reve model slugs.

        Args:
            api_key: Forwarded to the base class; may be ``None``.
        """
        super().__init__(
            service_model_name="fal-ai/reve",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/reve/text-to-image"
        self._edit_model: str = "fal-ai/reve/fast/edit"
        self._remix_model: str = "fal-ai/reve/fast/remix"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the slug matching the number of supplied images."""
        image_count = 0 if images is None else len(images)
        if image_count > 1:
            return self._remix_model
        if image_count == 1:
            return self._edit_model
        return self._text_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the payload in the shape the selected endpoint expects."""
        payload = self._base_arguments(**kwargs)
        payload["prompt"] = prompt

        if images:
            if len(images) == 1:
                # fast edit expects a single "image_url"
                payload["image_url"] = encode_image(images[0])
            else:
                # fast remix expects "image_urls" as a list
                self._attach_image_list_argument(payload, images, "image_urls")

        # With no images the payload is a plain text-to-image request.
        return payload

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name this service reports by default."""
        return "fal-ai/reve"
|