| | import io |
| | from typing import Any, Sequence |
| |
|
| | from PIL import Image |
| |
|
| | from services.exceptions import GenerationError |
| | from services.progress import call_progress, ProgressCallback |
| | from services.registry import register_service |
| | from services.text.TextGenerationResult import TextGenerationResult |
| | from services.text.TextGenerator import TextGenerator |
| | from services.utils.google_service import DEFAULT_GEMINI_TEXT_MODEL, iter_response_parts, extract_usage_tokens |
| |
|
| |
|
@register_service
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API.

    The ``google-genai`` SDK is imported lazily inside the methods that need
    it, so a missing package surfaces as a ``GenerationError`` when the
    generator is constructed rather than when this module is imported.
    """

    service_id = "google_gemini_text"

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the generator and its Gemini client.

        Args:
            model_name: Model identifier forwarded to the ``TextGenerator``
                initializer.
            api_key: Optional API key. When ``None``, no key is passed and
                ``genai.Client`` is constructed without arguments.

        Raises:
            GenerationError: If the ``google-genai`` package is not installed.
        """
        super().__init__(model_name=model_name)
        try:
            from google import genai
        except ModuleNotFoundError as exc:
            raise GenerationError(
                "The 'google-genai' package is required to use the Google text backend."
            ) from exc

        # Only forward the key when the caller supplied one.
        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key

        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the default Gemini text model identifier."""
        return DEFAULT_GEMINI_TEXT_MODEL

    def close(self) -> None:
        """Do nothing; the underlying client is not explicitly closed here."""
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Pull the generated text out of a Gemini response object.

        Lookup order:

        1. ``response.text`` when it is a non-blank string (returned as-is,
           without stripping).
        2. The ``text`` of every part yielded by ``iter_response_parts``,
           concatenated.
        3. The ``text`` parts of the first entry in ``response.candidates``
           that contains any, concatenated.

        Args:
            response: Response object returned by the SDK; accessed only
                through ``getattr`` so missing attributes are tolerated.

        Returns:
            The extracted text.

        Raises:
            GenerationError: If no text could be found anywhere.
        """
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        parts_text: list[str] = []
        for part in iter_response_parts(response):
            value = getattr(part, "text", None)
            if isinstance(value, str):
                parts_text.append(value)
        if parts_text:
            return "".join(parts_text)

        # Last resort: walk the candidates manually and stop at the first one
        # that carries text, so alternative candidates are never glued together.
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            for part in getattr(content, "parts", None) or []:
                value = getattr(part, "text", None)
                if isinstance(value, str):
                    parts_text.append(value)
            if parts_text:
                return "".join(parts_text)

        raise GenerationError("The Google text model did not return any text output.")

    @staticmethod
    def _encode_png(image: Image.Image) -> bytes:
        """Serialize *image* to PNG bytes.

        The image is saved unchanged whenever the PNG writer accepts its mode.
        Only if that save fails with ``OSError`` (Pillow raises it for modes
        PNG cannot store, such as CMYK) is the image converted to RGB, or to
        RGBA when its mode name ends in an alpha band, and saved again.

        Raises:
            OSError: If the image cannot be written as PNG even after the
                fallback conversion.
        """
        buffer = io.BytesIO()
        try:
            image.save(buffer, format="PNG")
        except OSError as save_error:
            fallback_mode = "RGBA" if image.mode.endswith(("A", "a")) else "RGB"
            try:
                converted = image.convert(fallback_mode)
            except ValueError:
                # No conversion path exists: surface the original failure.
                raise save_error from None
            # Fresh buffer: the failed save may have left partial data behind.
            buffer = io.BytesIO()
            converted.save(buffer, format="PNG")
        return buffer.getvalue()

    @staticmethod
    def _build_image_part(types: Any, image_bytes: bytes) -> Any:
        """Wrap PNG bytes in whichever image container the SDK offers.

        ``types.Part.from_bytes`` is preferred; ``types.InputImage`` is used
        when the former is unavailable or returned ``None``.

        Args:
            types: The ``google.genai.types`` module.
            image_bytes: PNG-encoded image data.

        Raises:
            GenerationError: If neither construction path is available.
        """
        part: Any | None = None

        from_bytes = getattr(getattr(types, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")

        if part is None:
            input_image_class = getattr(types, "InputImage", None)
            if input_image_class is not None:
                part = input_image_class(mime_type="image/png", data=image_bytes)

        if part is None:
            raise GenerationError(
                "The installed google-genai client does not support image inputs."
            )
        return part

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for *prompt*, optionally conditioned on *images*.

        Args:
            prompt: Text prompt; sent as the first element of the request.
            images: Optional images, each PNG-encoded and appended to the
                request after the prompt, in order.
            progress: Optional callback handed to ``call_progress`` at each
                stage (0.1, 0.4, 0.7, 0.95).

        Returns:
            A ``TextGenerationResult`` carrying the provider name, the model
            name, the extracted text and the raw SDK response.

        Raises:
            GenerationError: If the SDK offers no way to attach images, or the
                response contains no text.
            OSError: If an image cannot be encoded as PNG.
        """
        from google.genai import types

        call_progress(progress, 0.1, "Encoding inputs")

        request_parts: list[object] = [prompt]
        for image in images or []:
            image_bytes = self._encode_png(image)
            request_parts.append(self._build_image_part(types, image_bytes))

        call_progress(progress, 0.4, "Calling Google Gemini text model")

        response = self._client.models.generate_content(
            model=self.model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")

        output_text = self._extract_text(response)

        # NOTE(review): the token counts are computed but not forwarded to the
        # result below — confirm whether TextGenerationResult should carry them.
        input_tokens, output_tokens = extract_usage_tokens(response)

        call_progress(progress, 0.95, "Preparing text result")

        return TextGenerationResult(
            provider="google",
            model=self.model_name,
            text=output_text,
            raw_response=response,
        )
| |
|