| import io |
| from typing import Any, Sequence, List |
|
|
| from PIL import Image |
|
|
| from services.exceptions import GenerationError |
| from services.image.ImageGenerationResult import ImageGenerationResult |
| from services.image.ImageGenerator import ImageGenerator |
| from services.progress import ProgressCallback, call_progress |
| from services.registry import register_service |
| from services.utils.google_service import DEFAULT_GEMINI_IMAGE_MODEL, \ |
| iter_response_parts |
|
|
|
|
@register_service
class GoogleImageGenerator(ImageGenerator):
    """Image generator backed by the Google Gemini API.

    The ``google-genai`` SDK is imported lazily (inside ``__init__`` and
    ``generate``) so that importing this module does not require the SDK to
    be installed.
    """

    service_id = "google_gemini_image"

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the underlying ``genai.Client``.

        Args:
            model_name: Forwarded unchanged to the base-class initializer.
            api_key: Passed to ``genai.Client`` when not ``None``; otherwise
                the client is constructed without an explicit key.

        Raises:
            GenerationError: If ``google.genai`` cannot be imported.
        """
        super().__init__(model_name=model_name)
        try:
            from google import genai
        except ImportError as exc:
            # Catch ImportError rather than only ModuleNotFoundError: when a
            # different distribution already provides the ``google`` package,
            # ``from google import genai`` fails with a plain ImportError
            # ("cannot import name 'genai'"), which is not a
            # ModuleNotFoundError.
            raise GenerationError(
                "The 'google-genai' package is required to use the Google image backend."
            ) from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key

        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the module-level default Gemini image model name."""
        return DEFAULT_GEMINI_IMAGE_MODEL

    def close(self) -> None:
        """Do nothing; this backend performs no cleanup here."""
        return

    @staticmethod
    def _encode_image_part(types_module: Any, image: Image.Image) -> Any:
        """Serialize *image* as PNG and wrap it in an SDK request part.

        Prefers ``types.Part.from_bytes``; falls back to ``types.InputImage``
        when the former is not available in ``types_module``.

        Raises:
            GenerationError: If neither constructor is available.
        """
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            image_bytes = buffer.getvalue()

        # The attributes are probed with getattr because they may be absent
        # from the installed SDK; a missing ``Part`` yields ``None`` here and
        # getattr(None, "from_bytes", None) is simply ``None`` as well.
        part_class = getattr(types_module, "Part", None)
        from_bytes = getattr(part_class, "from_bytes", None)
        if callable(from_bytes):
            encoded = from_bytes(data=image_bytes, mime_type="image/png")
            if encoded is not None:
                return encoded

        input_image_class = getattr(types_module, "InputImage", None)
        if input_image_class is not None:
            return input_image_class(mime_type="image/png", data=image_bytes)

        raise GenerationError(
            "The installed google-genai client does not support image inputs."
        )

    @staticmethod
    def _decode_inline_image(data: bytes) -> Image.Image:
        """Decode raw image bytes into a fully loaded RGBA image.

        Raises:
            GenerationError: If Pillow cannot read *data* as an image.
        """
        try:
            # Image.open is lazy, so the pixel data is only read by
            # convert()/load(). Both stay inside the try so that truncated or
            # corrupt payloads are reported as GenerationError too.
            decoded = Image.open(io.BytesIO(data)).convert("RGBA")
            decoded.load()
        except OSError as exc:
            raise GenerationError("Received invalid image data from Gemini.") from exc
        return decoded

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        aspect_ratio: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Generate one or more images from a prompt and optional input images.

        Args:
            prompt: Text prompt; sent as the first element of the request.
            images: Optional input images, each re-encoded as PNG and appended
                to the request after the prompt.
            progress: Optional callback reported to via ``call_progress`` at
                the encoding, request, decoding and result stages.
            aspect_ratio: Currently unused by this backend.
            **kwargs: Currently unused by this backend.

        Returns:
            An ``ImageGenerationResult`` holding every image found in the
            response (converted to RGBA) together with the raw response.

        Raises:
            GenerationError: If the SDK cannot accept image inputs, if the
                response contains undecodable image data, or if it contains
                no image data at all.
        """
        from google.genai import types

        call_progress(progress, 0.1, "Encoding inputs")

        request_parts: List[object] = [prompt]
        for image in images or []:
            request_parts.append(self._encode_image_part(types, image))

        call_progress(progress, 0.4, "Calling Google Gemini image model")

        response = self._client.models.generate_content(
            model=self.model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding image")

        generated_images: List[Image.Image] = []
        for response_part in iter_response_parts(response):
            # Parts without inline binary data are skipped.
            inline_data = getattr(response_part, "inline_data", None)
            if inline_data is None:
                continue

            data: bytes | None = getattr(inline_data, "data", None)
            if data is None:
                continue

            generated_images.append(self._decode_inline_image(data))

        if not generated_images:
            raise GenerationError("The Google image model did not return any image data.")

        call_progress(progress, 0.95, "Preparing image result")

        return ImageGenerationResult(
            provider="google",
            model=self.model_name,
            images=generated_images,
            raw_response=response,
        )
|
|