File size: 4,500 Bytes
691f45a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | import io
from typing import Any, Sequence
from PIL import Image
from services.exceptions import GenerationError
from services.progress import call_progress, ProgressCallback
from services.registry import register_service
from services.text.TextGenerationResult import TextGenerationResult
from services.text.TextGenerator import TextGenerator
from services.utils.google_service import DEFAULT_GEMINI_TEXT_MODEL, iter_response_parts, extract_usage_tokens
@register_service
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API.

    Wraps a ``google.genai.Client`` and sends a prompt, plus optional PIL
    images encoded as PNG, to ``client.models.generate_content``.
    """

    service_id = "google_gemini_text"

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the underlying Gemini client.

        Args:
            model_name: Model identifier forwarded to ``TextGenerator.__init__``
                (presumably ``None`` selects ``default_model_name()`` — confirm
                in the base class).
            api_key: Explicit API key. When ``None`` the client is constructed
                without one and presumably resolves credentials itself.

        Raises:
            GenerationError: If the ``google-genai`` package is not installed.
        """
        super().__init__(model_name=model_name)
        try:
            # Imported lazily so the SDK is only required when this backend is used.
            from google import genai  # type: ignore[import-not-found]
        except ModuleNotFoundError as exc:
            raise GenerationError(
                "The 'google-genai' package is required to use the Google text backend."
            ) from exc
        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the Gemini text model used when no model name is given."""
        return DEFAULT_GEMINI_TEXT_MODEL

    def close(self) -> None:
        """Release resources; currently a no-op."""
        # The google client does not expose an explicit close method currently
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Return the text content of a Gemini ``generate_content`` response.

        Sources are tried in order:

        1. ``response.text``, if it is a string with non-whitespace content.
        2. The ``text`` of every part yielded by ``iter_response_parts``.
        3. The ``text`` of every part in each candidate's ``content.parts``.

        NOTE(review): in steps 2 and 3 a part whose ``text`` is ``""`` still
        counts as text, so this can return an empty string.

        Raises:
            GenerationError: If no source yields any string text.
        """

        def string_texts(parts: Any) -> list[str]:
            # Keep only parts carrying a real ``str`` text attribute.
            found: list[str] = []
            for part in parts:
                value = getattr(part, "text", None)
                if isinstance(value, str):
                    found.append(value)
            return found

        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        pieces = string_texts(iter_response_parts(response))
        if pieces:
            return "".join(pieces)

        # Last resort: walk candidates -> content -> parts directly.
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            pieces.extend(string_texts(getattr(content, "parts", None) or []))
        if pieces:
            return "".join(pieces)

        raise GenerationError("The Google text model did not return any text output.")

    @staticmethod
    def _encode_png(image: Image.Image) -> bytes:
        """Serialize *image* to PNG bytes.

        Pillow raises ``OSError`` when asked to write a mode the PNG format
        cannot store (e.g. ``CMYK``). In that case the image is converted to
        ``RGBA`` if it has an alpha band, otherwise ``RGB``, and encoded again.
        Images that Pillow can already write as PNG are encoded unchanged.
        """
        buffer = io.BytesIO()
        try:
            image.save(buffer, format="PNG")
        except OSError:
            bands = image.getbands()
            target_mode = "RGBA" if ("A" in bands or "a" in bands) else "RGB"
            # Fresh buffer: the failed save may have written partial data.
            buffer = io.BytesIO()
            image.convert(target_mode).save(buffer, format="PNG")
        return buffer.getvalue()

    @staticmethod
    def _build_image_part(types: Any, image_bytes: bytes) -> object:
        """Wrap PNG bytes in an image part understood by the installed client.

        Prefers ``types.Part.from_bytes``. If that is unavailable or returns
        ``None``, falls back to ``types.InputImage`` when the module exposes it.

        Raises:
            GenerationError: If neither constructor is available.
        """
        part: Any | None = None
        part_class = getattr(types, "Part", None)
        from_bytes = getattr(part_class, "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")
        if part is None:
            input_image_class = getattr(types, "InputImage", None)
            if input_image_class is not None:
                part = input_image_class(mime_type="image/png", data=image_bytes)
        if part is None:
            raise GenerationError(
                "The installed google-genai client does not support image inputs."
            )
        return part

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for *prompt*, optionally conditioned on *images*.

        Args:
            prompt: Text prompt; sent as the first element of ``contents``.
            images: Optional PIL images, each PNG-encoded and appended after
                the prompt in order.
            progress: Optional callback, invoked through ``call_progress`` with
                a fractional progress value and a status message.

        Returns:
            A ``TextGenerationResult`` with provider ``"google"``, the model
            name, the extracted text and the raw SDK response.

        Raises:
            GenerationError: If the installed client cannot represent image
                inputs, or the response contains no text.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: list[object] = [prompt]
        for image in images or []:
            image_bytes = self._encode_png(image)
            request_parts.append(self._build_image_part(types, image_bytes))

        call_progress(progress, 0.4, "Calling Google Gemini text model")
        response = self._client.models.generate_content(
            model=self.model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")
        output_text = self._extract_text(response)
        # NOTE(review): usage tokens are extracted but never forwarded to the
        # result; confirm whether TextGenerationResult accepts usage fields.
        input_tokens, output_tokens = extract_usage_tokens(response)

        call_progress(progress, 0.95, "Preparing text result")
        return TextGenerationResult(
            provider="google",
            model=self.model_name,
            text=output_text,
            raw_response=response,
        )
|