Spaces:
Running
Running
| import io | |
| import os | |
| from typing import Any | |
| from velai.services.exceptions import GenerationError | |
| from velai.services.progress import ProgressCallback, call_progress | |
| from velai.services.registry import register_service | |
| from velai.services.text.TextGenerator import TextGenerationInput, TextGenerationResult, TextGenerator | |
| from velai.services.utils.google_service import extract_usage_tokens, iter_response_parts | |
# Model name used when the generator is constructed without an explicit one.
# The GEMINI_TEXT_MODEL environment variable overrides the built-in default;
# it is read once, when this module is imported.
DEFAULT_GEMINI_TEXT_MODEL = os.getenv("GEMINI_TEXT_MODEL", "gemini-2.5-flash-lite-preview-09-2025")
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API.

    The generator sends a prompt (plus optional images, encoded as PNG) to the
    configured Gemini model through the ``google-genai`` client and returns
    the generated text together with an estimated cost.
    """

    # NOTE(review): `register_service` is imported by this module but is not
    # applied anywhere in the visible code - confirm whether a class decorator
    # is expected here.

    # Per-token rates used for the cost estimate in `generate`.
    # NOTE(review): presumably USD per token; the rates are fixed and do not
    # depend on the configured model name - TODO confirm against pricing.
    _INPUT_COST_PER_TOKEN = 0.10 / 1000000
    _OUTPUT_COST_PER_TOKEN = 0.40 / 1000000

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the Gemini client.

        Args:
            model_name: Model to call; falls back to ``DEFAULT_GEMINI_TEXT_MODEL``
                when ``None`` or empty.
            api_key: Key passed to ``genai.Client``. When ``None`` no key is
                passed and the client is constructed with its own defaults.

        Raises:
            GenerationError: If the ``google-genai`` package cannot be imported.
        """
        super().__init__()
        self._model_name = model_name or DEFAULT_GEMINI_TEXT_MODEL

        try:
            from google import genai  # type: ignore[import-not-found]
        except ImportError as exc:
            # ImportError (not only ModuleNotFoundError): when another package
            # provides the `google` namespace but `genai` is absent, the
            # `from ... import` form raises a plain ImportError.
            raise GenerationError("The 'google-genai' package is required to use the Google text backend.") from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def get_service_id(cls) -> str:
        """Return the identifier of this service."""
        return "google_text"

    @classmethod
    def get_name(cls) -> str:
        """Return the human-readable name of this service."""
        return "Google Gemini"

    @classmethod
    def is_available(cls) -> bool:
        """Return ``True`` when the ``GOOGLE_API_KEY`` environment variable is set."""
        return os.getenv("GOOGLE_API_KEY") is not None

    def close(self) -> None:
        """Release resources held by the generator (currently a no-op)."""
        # The google client does not expose an explicit close method currently
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Pull the generated text out of a Gemini response object.

        Three sources are tried in order, and the first one that yields any
        text wins:

        1. the response's ``text`` attribute, if it is a non-blank string;
        2. the ``text`` of every part yielded by ``iter_response_parts``;
        3. the ``text`` of every part in ``response.candidates[*].content.parts``.

        Raises:
            GenerationError: If none of the sources yields a string.
        """
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        parts_text: list[str] = [
            value
            for part in iter_response_parts(response)
            if isinstance(value := getattr(part, "text", None), str)
        ]
        if parts_text:
            return "".join(parts_text)

        # Last resort: walk the candidates structure directly.
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            for part in getattr(content, "parts", None) or []:
                value = getattr(part, "text", None)
                if isinstance(value, str):
                    parts_text.append(value)
        if parts_text:
            return "".join(parts_text)

        raise GenerationError("The Google text model did not return any text output.")

    @staticmethod
    def _image_to_part(image: Any, types: Any) -> Any:
        """Encode ``image`` as PNG and wrap it in an image part of the SDK.

        ``types.Part.from_bytes`` is preferred; ``types.InputImage`` is used
        as a fallback when the former is missing or returns ``None``.

        Raises:
            GenerationError: If the ``types`` module offers neither API.
        """
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            image_bytes = buffer.getvalue()

        from_bytes = getattr(getattr(types, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")
            if part is not None:
                return part

        input_image_class = getattr(types, "InputImage", None)
        if input_image_class is not None:
            return input_image_class(mime_type="image/png", data=image_bytes)

        raise GenerationError("The installed google-genai client does not support image inputs.")

    def generate(
        self, input_data: TextGenerationInput, *, progress: ProgressCallback | None = None
    ) -> TextGenerationResult:
        """Generate text for ``input_data`` with the configured Gemini model.

        Args:
            input_data: Supplies ``prompt`` and optional ``images``; each image
                is serialised through its ``save(buffer, format="PNG")`` method.
            progress: Optional callback notified at each stage via
                ``call_progress``.

        Returns:
            A ``TextGenerationResult`` carrying the service name, the generated
            text and the estimated cost (``None`` when token usage is not
            reported for both input and output).

        Raises:
            GenerationError: If the installed client cannot accept image inputs
                or the response contains no text.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: list[object] = [input_data.prompt]
        for image in input_data.images or []:
            request_parts.append(self._image_to_part(image, types))

        call_progress(progress, 0.4, "Calling Google Gemini text model")
        response = self._client.models.generate_content(
            model=self._model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")
        output_text = self._extract_text(response)

        input_tokens, output_tokens = extract_usage_tokens(response)
        generation_cost: float | None = None
        # Only estimate the cost when both token counts are known.
        if input_tokens is not None and output_tokens is not None:
            input_cost = input_tokens * self._INPUT_COST_PER_TOKEN
            output_cost = output_tokens * self._OUTPUT_COST_PER_TOKEN
            generation_cost = input_cost + output_cost

        call_progress(progress, 0.95, "Preparing text result")
        return TextGenerationResult(service=self.get_name(), text=output_text, cost=generation_cost)