velai / services /image /GoogleImageGenerator.py
cansik's picture
Upload folder via script
abd08cb verified
import io
from typing import Any, Sequence, List
from PIL import Image
from services.exceptions import GenerationError
from services.image.ImageGenerationResult import ImageGenerationResult
from services.image.ImageGenerator import ImageGenerator
from services.progress import ProgressCallback, call_progress
from services.registry import register_service
from services.utils.google_service import DEFAULT_GEMINI_IMAGE_MODEL, \
iter_response_parts
@register_service
class GoogleImageGenerator(ImageGenerator):
    """Image generator backed by the Google Gemini API.

    The ``google-genai`` package is imported lazily so that the module can be
    imported (and the service registered) even when the package is missing;
    the failure only surfaces when an instance is created.
    """

    service_id = "google_gemini_image"

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the generator and its underlying ``genai.Client``.

        Args:
            model_name: Forwarded unchanged to ``ImageGenerator.__init__``.
            api_key: Explicit API key for the client. When ``None`` the client
                is constructed without an ``api_key`` argument.

        Raises:
            GenerationError: If the ``google-genai`` package is not installed.
        """
        super().__init__(model_name=model_name)
        try:
            from google import genai  # type: ignore[import-not-found]
        except ModuleNotFoundError as exc:
            raise GenerationError(
                "The 'google-genai' package is required to use the Google image backend."
            ) from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        # If api_key is omitted, the client will read GOOGLE_API_KEY from the environment
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the model name used when none is supplied explicitly."""
        return DEFAULT_GEMINI_IMAGE_MODEL

    def close(self) -> None:
        """Release resources held by the generator (currently a no-op)."""
        # The google client does not expose an explicit close method currently
        return

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        aspect_ratio: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Generate images from a prompt and optional reference images.

        Args:
            prompt: Text prompt; sent as the first element of the request contents.
            images: Optional reference images. Each one is PNG-encoded and
                appended to the request contents after the prompt.
            progress: Optional callback reported to through ``call_progress``.
            aspect_ratio: Optional aspect ratio forwarded to the model through
                the request's image config. Assumed to be a ratio string such
                as ``"16:9"`` -- TODO confirm against callers. Ignored when
                empty/``None`` or when the installed client lacks image-config
                support.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An ``ImageGenerationResult`` holding every image found in the
            response (converted to RGBA) together with the raw response.

        Raises:
            GenerationError: If the installed client cannot encode image
                inputs, if the response contains undecodable image data, or if
                the response contains no image data at all.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: list[object] = [prompt]
        for image in images or ():
            request_parts.append(self._encode_input_image(types, image))

        request_kwargs: dict[str, Any] = {
            "model": self.model_name,
            "contents": request_parts,
        }
        # Only attach a config when there is something to configure, so the
        # request is unchanged for callers that do not pass an aspect ratio.
        config = self._build_request_config(types, aspect_ratio)
        if config is not None:
            request_kwargs["config"] = config

        call_progress(progress, 0.4, "Calling Google Gemini image model")
        response = self._client.models.generate_content(**request_kwargs)

        call_progress(progress, 0.7, "Decoding image")
        generated_images = self._decode_response_images(response)
        if not generated_images:
            raise GenerationError("The Google image model did not return any image data.")

        call_progress(progress, 0.95, "Preparing image result")
        return ImageGenerationResult(
            provider="google",
            model=self.model_name,
            images=generated_images,
            raw_response=response,
        )

    @staticmethod
    def _encode_input_image(types_module: Any, image: Image.Image) -> Any:
        """PNG-encode *image* and wrap it in a request part of the installed client."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        image_bytes = buffer.getvalue()

        # Preferred path: ``types.Part.from_bytes``. ``getattr`` on ``None``
        # simply yields the default, so a missing ``Part`` class is handled too.
        from_bytes = getattr(getattr(types_module, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")
            if part is not None:
                return part

        # Fallback for clients that expose ``types.InputImage`` instead.
        input_image_class = getattr(types_module, "InputImage", None)
        if input_image_class is not None:
            part = input_image_class(mime_type="image/png", data=image_bytes)
            if part is not None:
                return part

        raise GenerationError(
            "The installed google-genai client does not support image inputs."
        )

    @staticmethod
    def _build_request_config(types_module: Any, aspect_ratio: str | None) -> Any | None:
        """Return a request config carrying *aspect_ratio*, or ``None`` if not applicable."""
        if not aspect_ratio:
            return None
        image_config_class = getattr(types_module, "ImageConfig", None)
        config_class = getattr(types_module, "GenerateContentConfig", None)
        if image_config_class is None or config_class is None:
            # Installed client has no image-config support: fall back to the
            # plain request rather than failing the whole generation.
            return None
        return config_class(image_config=image_config_class(aspect_ratio=aspect_ratio))

    @staticmethod
    def _decode_response_images(response: Any) -> list[Image.Image]:
        """Collect every inline image of *response* as a fully decoded RGBA image."""
        decoded: list[Image.Image] = []
        for part in iter_response_parts(response):
            inline_data = getattr(part, "inline_data", None)
            if inline_data is None:
                continue
            data: bytes | None = getattr(inline_data, "data", None)
            if data is None:
                continue
            try:
                # ``convert`` forces the full decode, so truncated or corrupt
                # payloads fail here, inside the guarded block.
                decoded.append(Image.open(io.BytesIO(data)).convert("RGBA"))
            except OSError as exc:
                raise GenerationError("Received invalid image data from Gemini.") from exc
        return decoded