velai / services /text /GoogleTextGenerator.py
cansik's picture
Upload folder via script
691f45a verified
import io
from typing import Any, Sequence
from PIL import Image
from services.exceptions import GenerationError
from services.progress import call_progress, ProgressCallback
from services.registry import register_service
from services.text.TextGenerationResult import TextGenerationResult
from services.text.TextGenerator import TextGenerator
from services.utils.google_service import DEFAULT_GEMINI_TEXT_MODEL, iter_response_parts, extract_usage_tokens
@register_service
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API.

    Sends a text prompt, optionally accompanied by PIL images encoded as PNG,
    to ``client.models.generate_content`` and wraps the returned text in a
    :class:`TextGenerationResult`.
    """

    # Identifier for this backend; presumably the key @register_service uses
    # to look the service up — confirm against services.registry.
    service_id = "google_gemini_text"

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the generator and its ``google-genai`` client.

        Args:
            model_name: Gemini model identifier, forwarded to ``TextGenerator``.
                Handling of ``None`` is up to the base class (presumably it
                falls back to :meth:`default_model_name` — confirm).
            api_key: Explicit API key. When ``None``, no key is passed to
                ``genai.Client`` and the SDK's own configuration applies.

        Raises:
            GenerationError: If the ``google-genai`` package is not installed.
        """
        super().__init__(model_name=model_name)
        try:
            # Imported here rather than at module level so this module can be
            # imported even when the optional package is missing.
            from google import genai  # type: ignore[import-not-found]
        except ModuleNotFoundError as exc:
            raise GenerationError(
                "The 'google-genai' package is required to use the Google text backend."
            ) from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the default Gemini text model name (``DEFAULT_GEMINI_TEXT_MODEL``)."""
        return DEFAULT_GEMINI_TEXT_MODEL

    def close(self) -> None:
        """Release resources; a no-op for this backend."""
        # The google client does not expose an explicit close method currently
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Return the generated text from a Gemini response.

        Sources are tried in order, and the first one yielding non-blank text
        wins:

        1. ``response.text``;
        2. the ``text`` of every part yielded by ``iter_response_parts``;
        3. the ``text`` of every part under ``response.candidates[*].content.parts``.

        Raises:
            GenerationError: If none of the sources yields non-blank text.
        """

        def joined_text(parts: Any) -> str:
            # Concatenate only string-valued ``text`` attributes; non-text
            # parts (``text`` missing or not a str) are skipped.
            texts = (getattr(part, "text", None) for part in parts)
            return "".join(text for text in texts if isinstance(text, str))

        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        # Each fallback uses the same non-blank check as ``response.text`` so
        # that empty text parts fall through instead of yielding "".
        from_parts = joined_text(iter_response_parts(response))
        if from_parts.strip():
            return from_parts

        candidate_parts: list[Any] = []
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is not None:
                candidate_parts.extend(getattr(content, "parts", None) or [])
        from_candidates = joined_text(candidate_parts)
        if from_candidates.strip():
            return from_candidates

        raise GenerationError("The Google text model did not return any text output.")

    @staticmethod
    def _resolve_image_part_factory(types: Any) -> Any:
        """Return a callable mapping PNG bytes to a request part.

        Prefers ``types.Part.from_bytes`` and falls back to
        ``types.InputImage`` when ``from_bytes`` is not available.

        Raises:
            GenerationError: If neither constructor exists in ``types``.
        """
        from_bytes = getattr(getattr(types, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            return lambda data: from_bytes(data=data, mime_type="image/png")
        input_image_class = getattr(types, "InputImage", None)
        if input_image_class is not None:
            return lambda data: input_image_class(mime_type="image/png", data=data)
        raise GenerationError(
            "The installed google-genai client does not support image inputs."
        )

    @staticmethod
    def _encode_png(image: Image.Image) -> bytes:
        """Serialize *image* to PNG bytes in memory."""
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            return buffer.getvalue()

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for *prompt*, optionally conditioned on *images*.

        Args:
            prompt: Text prompt; sent as the first request part.
            images: Optional images, each PNG-encoded and appended after the
                prompt in the given order.
            progress: Optional callback passed to ``call_progress`` at fixed
                stages (0.1, 0.4, 0.7, 0.95).

        Returns:
            A ``TextGenerationResult`` with ``provider="google"``, the model
            name, the extracted text, and the raw SDK response.

        Raises:
            GenerationError: If images are given but the installed SDK cannot
                build image parts, or if the response contains no text.
                Errors raised by ``generate_content`` propagate unchanged.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: list[object] = [prompt]
        if images:
            # Resolve the SDK's image constructor once, before encoding any
            # image, instead of re-probing ``types`` for every image.
            make_part = self._resolve_image_part_factory(types)
            request_parts.extend(make_part(self._encode_png(image)) for image in images)

        call_progress(progress, 0.4, "Calling Google Gemini text model")
        response = self._client.models.generate_content(
            model=self.model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")
        output_text = self._extract_text(response)
        # NOTE(review): usage tokens are extracted but not forwarded to the
        # result. TextGenerationResult's fields are not visible here — confirm
        # whether it accepts input/output token counts and pass them if so.
        _input_tokens, _output_tokens = extract_usage_tokens(response)

        call_progress(progress, 0.95, "Preparing text result")
        return TextGenerationResult(
            provider="google",
            model=self.model_name,
            text=output_text,
            raw_response=response,
        )