import io
from typing import Any, Sequence, List

from PIL import Image

from services.exceptions import GenerationError
from services.image.ImageGenerationResult import ImageGenerationResult
from services.image.ImageGenerator import ImageGenerator
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.utils.google_service import (
    DEFAULT_GEMINI_IMAGE_MODEL,
    iter_response_parts,
)


@register_service
class GoogleImageGenerator(ImageGenerator):
    """Image generator backed by the Google Gemini API."""

    service_id = "google_gemini_image"

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the generator and its ``google-genai`` client.

        Args:
            model_name: Model to use; forwarded to ``ImageGenerator.__init__``.
            api_key: Explicit API key for the client. When ``None``, the client
                is constructed without an ``api_key`` argument.

        Raises:
            GenerationError: If the ``google-genai`` package cannot be imported.
        """
        super().__init__(model_name=model_name)
        try:
            from google import genai  # type: ignore[import-not-found]
        except ImportError as exc:
            # Catch ImportError, not only ModuleNotFoundError: when another
            # installed package provides the ``google`` namespace but
            # google-genai is missing, ``from google import genai`` raises a
            # plain ImportError ("cannot import name 'genai' from 'google'").
            raise GenerationError(
                "The 'google-genai' package is required to use the Google image backend."
            ) from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        # If api_key is omitted, the client will read GOOGLE_API_KEY from the environment
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the default Gemini image model name."""
        return DEFAULT_GEMINI_IMAGE_MODEL

    def close(self) -> None:
        """Release resources (no-op for this backend)."""
        # The google client does not expose an explicit close method currently
        return

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        aspect_ratio: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Generate images from a text prompt and optional input images.

        Args:
            prompt: Text sent as the first content part of the request.
            images: Optional input images; each is PNG-encoded and sent after
                the prompt, in order.
            progress: Optional progress callback, reported through
                ``call_progress``.
            aspect_ratio: Optional aspect ratio string (e.g. ``"16:9"``) sent
                as ``GenerateContentConfig.image_config.aspect_ratio``. When
                ``None``, no ``config`` argument is sent with the request.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An ``ImageGenerationResult`` containing every image found in the
            response (converted to RGBA) and the raw response object.

        Raises:
            GenerationError: If the installed client cannot express image
                inputs or an aspect ratio, if the response carries undecodable
                image bytes, or if it carries no image data at all.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: List[object] = [prompt]
        for image in images or []:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            request_parts.append(self._make_image_part(types, buffer.getvalue()))
        config = self._build_config(types, aspect_ratio)

        call_progress(progress, 0.4, "Calling Google Gemini image model")
        request_kwargs: dict[str, Any] = {
            "model": self.model_name,
            "contents": request_parts,
        }
        # Only send a config when there is something to configure, so the
        # default request is unchanged.
        if config is not None:
            request_kwargs["config"] = config
        response = self._client.models.generate_content(**request_kwargs)

        call_progress(progress, 0.7, "Decoding image")
        generated_images = self._decode_images(response)
        if not generated_images:
            raise GenerationError("The Google image model did not return any image data.")

        call_progress(progress, 0.95, "Preparing image result")
        return ImageGenerationResult(
            provider="google",
            model=self.model_name,
            images=generated_images,
            raw_response=response,
        )

    @staticmethod
    def _make_image_part(types: Any, image_bytes: bytes) -> Any:
        """Wrap PNG bytes in whichever image-part type the installed client offers."""
        # Prefer ``types.Part.from_bytes`` when the client provides it.
        from_bytes = getattr(getattr(types, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")
            if part is not None:
                return part
        # Fall back to ``types.InputImage`` for clients without Part.from_bytes.
        input_image_class = getattr(types, "InputImage", None)
        if input_image_class is not None:
            part = input_image_class(mime_type="image/png", data=image_bytes)
            if part is not None:
                return part
        raise GenerationError(
            "The installed google-genai client does not support image inputs."
        )

    @staticmethod
    def _build_config(types: Any, aspect_ratio: str | None) -> Any | None:
        """Build the request config carrying *aspect_ratio*, or ``None`` if unset."""
        if aspect_ratio is None:
            return None
        config_class = getattr(types, "GenerateContentConfig", None)
        image_config_class = getattr(types, "ImageConfig", None)
        if config_class is None or image_config_class is None:
            # Fail loudly rather than silently dropping the caller's request.
            raise GenerationError(
                "The installed google-genai client does not support the aspect_ratio option."
            )
        return config_class(image_config=image_config_class(aspect_ratio=aspect_ratio))

    @staticmethod
    def _decode_images(response: Any) -> List[Image.Image]:
        """Decode each inline-data part of *response* into an RGBA PIL image."""
        decoded: List[Image.Image] = []
        for part in iter_response_parts(response):
            # Parts without inline_data (or without bytes) are skipped.
            inline_data = getattr(part, "inline_data", None)
            data: bytes | None = getattr(inline_data, "data", None)
            if data is None:
                continue
            try:
                # convert() fully decodes the image, so the result no longer
                # depends on the BytesIO buffer and needs no extra load().
                image = Image.open(io.BytesIO(data)).convert("RGBA")
            except OSError as exc:
                raise GenerationError("Received invalid image data from Gemini.") from exc
            decoded.append(image)
        return decoded