import io
from typing import Any, Callable, Sequence

from PIL import Image

from services.exceptions import GenerationError
from services.progress import call_progress, ProgressCallback
from services.registry import register_service
from services.text.TextGenerationResult import TextGenerationResult
from services.text.TextGenerator import TextGenerator
from services.utils.google_service import DEFAULT_GEMINI_TEXT_MODEL, iter_response_parts, extract_usage_tokens


@register_service
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API."""

    service_id = "google_gemini_text"

    # Image modes that Pillow's PNG writer cannot store; they are converted
    # to RGB before encoding instead of failing inside ``Image.save``.
    _PNG_RGB_CONVERT_MODES = ("CMYK", "YCbCr")

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the generator and its ``genai.Client``.

        Args:
            model_name: Forwarded to ``TextGenerator.__init__``.
            api_key: Passed to ``genai.Client`` when given; otherwise the
                client is constructed without an explicit key.

        Raises:
            GenerationError: If the ``google-genai`` package is not installed.
        """
        super().__init__(model_name=model_name)
        try:
            from google import genai  # type: ignore[import-not-found]
        except ModuleNotFoundError as exc:
            raise GenerationError(
                "The 'google-genai' package is required to use the Google text backend."
            ) from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the default Gemini text model identifier."""
        return DEFAULT_GEMINI_TEXT_MODEL

    def close(self) -> None:
        """Release resources; a no-op for this backend."""
        # The google client does not expose an explicit close method currently
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Return the text content of a Gemini response.

        Tries, in order: a non-blank ``response.text``; the concatenated
        ``text`` of the parts yielded by ``iter_response_parts``; and the
        concatenated ``text`` of ``candidate.content.parts`` across all
        candidates.

        Raises:
            GenerationError: If none of these sources yields any text.
        """
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        parts_text: list[str] = []
        for part in iter_response_parts(response):
            value = getattr(part, "text", None)
            if isinstance(value, str):
                parts_text.append(value)
        if parts_text:
            return "".join(parts_text)

        # Last resort: walk the raw candidate structure directly.
        candidates = getattr(response, "candidates", None)
        if candidates:
            for candidate in candidates:
                content = getattr(candidate, "content", None)
                if content is None:
                    continue
                content_parts = getattr(content, "parts", None) or []
                for part in content_parts:
                    value = getattr(part, "text", None)
                    if isinstance(value, str):
                        parts_text.append(value)
        if parts_text:
            return "".join(parts_text)

        raise GenerationError("The Google text model did not return any text output.")

    @staticmethod
    def _resolve_image_part_factory(types: Any) -> Callable[[bytes], Any] | None:
        """Return a callable that wraps PNG bytes in a request part, or ``None``.

        Prefers ``types.Part.from_bytes``; falls back to ``types.InputImage``
        when the installed client exposes that instead. Resolved once per
        request rather than once per image.
        """
        part_class = getattr(types, "Part", None)
        from_bytes = getattr(part_class, "from_bytes", None)
        if callable(from_bytes):
            return lambda data: from_bytes(data=data, mime_type="image/png")

        input_image_class = getattr(types, "InputImage", None)
        if input_image_class is not None:
            return lambda data: input_image_class(mime_type="image/png", data=data)

        return None

    @classmethod
    def _encode_png(cls, image: Image.Image) -> bytes:
        """Serialize *image* to PNG bytes.

        CMYK and YCbCr images are converted to RGB first, because Pillow's
        PNG writer cannot store those modes and would raise on save.
        """
        if image.mode in cls._PNG_RGB_CONVERT_MODES:
            image = image.convert("RGB")
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return buffer.getvalue()

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for *prompt* (and optional *images*) with Gemini.

        Args:
            prompt: Text prompt; sent as the first request part.
            images: Optional images; each is PNG-encoded and appended after
                the prompt.
            progress: Optional callback, forwarded to ``call_progress``.

        Returns:
            A ``TextGenerationResult`` with provider ``"google"``, the model
            name, the extracted text, and the raw response.

        Raises:
            GenerationError: If images are given but the installed
                google-genai client offers no supported image part type, or
                if the response contains no text.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: list[object] = [prompt]
        if images:
            make_part = self._resolve_image_part_factory(types)
            if make_part is None:
                raise GenerationError(
                    "The installed google-genai client does not support image inputs."
                )
            request_parts.extend(make_part(self._encode_png(image)) for image in images)

        call_progress(progress, 0.4, "Calling Google Gemini text model")
        response = self._client.models.generate_content(
            model=self.model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")
        output_text = self._extract_text(response)
        # NOTE(review): token counts are extracted but never forwarded to
        # TextGenerationResult; confirm whether the result type accepts them.
        input_tokens, output_tokens = extract_usage_tokens(response)

        call_progress(progress, 0.95, "Preparing text result")
        return TextGenerationResult(
            provider="google",
            model=self.model_name,
            text=output_text,
            raw_response=response,
        )