velai-workshop / velai /services /text /GoogleTextGenerator.py
kratadata's picture
Upload folder via script
0f8b3a0 verified
import io
import os
from typing import Any
from velai.services.exceptions import GenerationError
from velai.services.progress import ProgressCallback, call_progress
from velai.services.registry import register_service
from velai.services.text.TextGenerator import TextGenerationInput, TextGenerationResult, TextGenerator
from velai.services.utils.google_service import extract_usage_tokens, iter_response_parts
# Model used when no explicit model name is given to the generator. The
# GEMINI_TEXT_MODEL environment variable, read once at import time, overrides
# the built-in fallback.
DEFAULT_GEMINI_TEXT_MODEL = os.environ.get(
    "GEMINI_TEXT_MODEL",
    "gemini-2.5-flash-lite-preview-09-2025",
)
@register_service(TextGenerator)
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API.

    The ``google-genai`` package is imported lazily inside the methods that
    need it, so importing this module does not require the package.
    """

    # Flat per-token rates used for the cost estimate in ``generate``
    # (0.10 and 0.40 per million input/output tokens).
    # NOTE(review): presumably USD list prices for the default model; the same
    # rates are applied whichever model name is configured -- confirm.
    _INPUT_COST_PER_TOKEN = 0.10 / 1000000
    _OUTPUT_COST_PER_TOKEN = 0.40 / 1000000

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the Gemini client.

        Args:
            model_name: Model to call. Falls back to
                ``DEFAULT_GEMINI_TEXT_MODEL`` when omitted or empty.
            api_key: Explicit API key. When ``None`` no key is passed and the
                client is constructed with its own defaults.

        Raises:
            GenerationError: If the ``google-genai`` package cannot be imported.
        """
        super().__init__()
        self._model_name = model_name or DEFAULT_GEMINI_TEXT_MODEL
        try:
            from google import genai  # type: ignore[import-not-found]
        except ImportError as exc:
            # ``from google import genai`` raises a plain ImportError (not
            # ModuleNotFoundError) when the ``google`` namespace package exists
            # but the ``genai`` submodule does not, so catch the parent class.
            raise GenerationError(
                "The 'google-genai' package is required to use the Google text backend."
            ) from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def get_service_id(cls) -> str:
        """Return the identifier string of this service."""
        return "google_text"

    @classmethod
    def get_name(cls) -> str:
        """Return the display name of this service."""
        return "Google Gemini"

    @classmethod
    def is_available(cls) -> bool:
        """Return True when a non-empty ``GOOGLE_API_KEY`` is set in the environment."""
        # An empty value is treated like a missing one: it cannot be a usable key.
        return bool(os.getenv("GOOGLE_API_KEY"))

    def close(self) -> None:
        """Release resources held by the generator (currently a no-op)."""
        # The google client does not expose an explicit close method currently
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Pull the generated text out of a Gemini response object.

        Sources are tried in order: the ``text`` attribute of the response,
        the parts yielded by ``iter_response_parts``, and finally the
        ``content.parts`` of each entry in ``response.candidates``. The first
        source that yields non-blank text wins.

        Raises:
            GenerationError: If no source yields non-blank text.
        """

        def joined_text(parts: Any) -> str:
            # Concatenate the string ``text`` attributes, skipping non-text parts.
            values = (getattr(part, "text", None) for part in parts)
            return "".join(value for value in values if isinstance(value, str))

        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        # Check the joined text rather than the number of parts, so that parts
        # carrying only empty strings do not count as output.
        from_parts = joined_text(iter_response_parts(response))
        if from_parts.strip():
            return from_parts

        candidate_parts: list[Any] = []
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            candidate_parts.extend(getattr(content, "parts", None) or [])
        from_candidates = joined_text(candidate_parts)
        if from_candidates.strip():
            return from_candidates

        raise GenerationError("The Google text model did not return any text output.")

    def generate(
        self, input_data: TextGenerationInput, *, progress: ProgressCallback | None = None
    ) -> TextGenerationResult:
        """Generate text for a prompt and optional images.

        Each image in ``input_data.images`` is serialised to PNG and sent to
        the model together with ``input_data.prompt``.

        Args:
            input_data: Prompt and optional images. Images must provide a
                ``save(buffer, format="PNG")`` method.
                NOTE(review): presumably PIL images -- confirm.
            progress: Optional callback notified via ``call_progress`` at each
                stage.

        Returns:
            A ``TextGenerationResult`` with the service name, the generated
            text and an estimated cost. The cost is ``None`` when the response
            carries no token usage.

        Raises:
            GenerationError: If the installed client cannot accept image
                inputs, or the response contains no text.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: list[object] = [input_data.prompt]

        # Resolve the image constructors once; they do not change per image.
        from_bytes = getattr(getattr(types, "Part", None), "from_bytes", None)
        input_image_class = getattr(types, "InputImage", None)

        for image in input_data.images or []:
            with io.BytesIO() as buffer:
                image.save(buffer, format="PNG")
                image_bytes = buffer.getvalue()

            if callable(from_bytes):
                part = from_bytes(data=image_bytes, mime_type="image/png")
            elif input_image_class is not None:
                # Fallback for client versions without ``Part.from_bytes``.
                part = input_image_class(mime_type="image/png", data=image_bytes)
            else:
                raise GenerationError("The installed google-genai client does not support image inputs.")
            request_parts.append(part)

        call_progress(progress, 0.4, "Calling Google Gemini text model")
        response = self._client.models.generate_content(
            model=self._model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")
        output_text = self._extract_text(response)

        input_tokens, output_tokens = extract_usage_tokens(response)
        generation_cost: float | None = None
        if input_tokens is not None and output_tokens is not None:
            input_cost = input_tokens * self._INPUT_COST_PER_TOKEN
            output_cost = output_tokens * self._OUTPUT_COST_PER_TOKEN
            generation_cost = input_cost + output_cost

        call_progress(progress, 0.95, "Preparing text result")
        return TextGenerationResult(service=self.get_name(), text=output_text, cost=generation_cost)