Spaces:
Running
Running
File size: 4,834 Bytes
0f8b3a0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import io
import os
from typing import Any
from velai.services.exceptions import GenerationError
from velai.services.progress import ProgressCallback, call_progress
from velai.services.registry import register_service
from velai.services.text.TextGenerator import TextGenerationInput, TextGenerationResult, TextGenerator
from velai.services.utils.google_service import extract_usage_tokens, iter_response_parts
# Model identifier used when a generator is created without an explicit
# model name; the GEMINI_TEXT_MODEL environment variable overrides the
# built-in fallback below.
DEFAULT_GEMINI_TEXT_MODEL = os.environ.get(
    "GEMINI_TEXT_MODEL",
    "gemini-2.5-flash-lite-preview-09-2025",
)
@register_service(TextGenerator)
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API.

    The generator sends a prompt (plus optional images, PNG-encoded) to the
    configured Gemini model through the ``google-genai`` client and returns
    the decoded text together with a cost estimate derived from the token
    usage reported on the response.
    """

    # Flat per-token rates used for the cost estimate. They are applied
    # regardless of which model is configured.
    # NOTE(review): presumably USD (0.10 / 0.40 per million tokens) — confirm
    # against the pricing of the model actually in use.
    _INPUT_COST_PER_TOKEN = 0.10 / 1000000
    _OUTPUT_COST_PER_TOKEN = 0.40 / 1000000

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the underlying Gemini client.

        Args:
            model_name: Model identifier. Falls back to
                ``DEFAULT_GEMINI_TEXT_MODEL`` when ``None`` or empty.
            api_key: Explicit API key. When ``None`` no key is passed and the
                client is constructed with its own defaults.

        Raises:
            GenerationError: If the ``google-genai`` package cannot be imported.
        """
        super().__init__()
        self._model_name = model_name or DEFAULT_GEMINI_TEXT_MODEL
        try:
            from google import genai  # type: ignore[import-not-found]
        except ImportError as exc:
            # ``from google import genai`` raises a plain ImportError (not
            # ModuleNotFoundError) when some other ``google.*`` distribution
            # is installed but google-genai is missing, so catch the parent
            # class to cover both situations.
            raise GenerationError("The 'google-genai' package is required to use the Google text backend.") from exc
        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key
        self._client = genai.Client(**client_kwargs)

    @classmethod
    def get_service_id(cls) -> str:
        """Return the identifier string of this service."""
        return "google_text"

    @classmethod
    def get_name(cls) -> str:
        """Return the human-readable name of this service."""
        return "Google Gemini"

    @classmethod
    def is_available(cls) -> bool:
        """Return True when a non-empty ``GOOGLE_API_KEY`` is set in the environment."""
        # An empty value cannot authenticate, so treat it like a missing key.
        return bool(os.getenv("GOOGLE_API_KEY"))

    def close(self) -> None:
        """Release resources held by the generator (currently a no-op)."""
        # The google client does not expose an explicit close method currently
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Pull the generated text out of a Gemini response object.

        Three sources are tried in order: the response's ``text`` attribute
        (when it is a non-blank string), the ``text`` of every part yielded by
        ``iter_response_parts``, and finally the ``text`` of the parts found
        under ``response.candidates[*].content.parts``.

        Raises:
            GenerationError: If none of the sources yields any string.
        """
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        parts_text: list[str] = [
            value
            for value in (getattr(part, "text", None) for part in iter_response_parts(response))
            if isinstance(value, str)
        ]
        if parts_text:
            return "".join(parts_text)

        # Last resort: walk the candidate structure directly.
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            for part in getattr(content, "parts", None) or []:
                value = getattr(part, "text", None)
                if isinstance(value, str):
                    parts_text.append(value)
        if parts_text:
            return "".join(parts_text)

        raise GenerationError("The Google text model did not return any text output.")

    @staticmethod
    def _image_to_part(types_module: Any, image: Any) -> Any:
        """Encode *image* as PNG and wrap it in an image input object of the SDK.

        ``types_module.Part.from_bytes`` is preferred; ``types_module.InputImage``
        is used when the former is missing or returns ``None``.

        Args:
            types_module: The ``google.genai.types`` module.
            image: Object exposing ``save(buffer, format="PNG")``
                (presumably a PIL image — verify against callers).

        Raises:
            GenerationError: If neither image input type is available.
        """
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            image_bytes = buffer.getvalue()

        part: Any | None = None
        from_bytes = getattr(getattr(types_module, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")
        if part is None:
            input_image_class = getattr(types_module, "InputImage", None)
            if input_image_class is not None:
                part = input_image_class(mime_type="image/png", data=image_bytes)
        if part is None:
            raise GenerationError("The installed google-genai client does not support image inputs.")
        return part

    def generate(
        self, input_data: TextGenerationInput, *, progress: ProgressCallback | None = None
    ) -> TextGenerationResult:
        """Generate text for *input_data* with the configured Gemini model.

        Args:
            input_data: Supplies ``prompt`` and optional ``images``; each image
                is PNG-encoded and appended to the request after the prompt.
            progress: Optional callback handed to ``call_progress`` at each
                stage of the generation.

        Returns:
            A ``TextGenerationResult`` carrying the service name, the generated
            text and a cost estimate (``None`` when the response does not
            report both input and output token counts).

        Raises:
            GenerationError: If the SDK cannot accept image inputs or the
                response contains no text.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")
        request_parts: list[object] = [input_data.prompt]
        for image in input_data.images or []:
            request_parts.append(self._image_to_part(types, image))

        call_progress(progress, 0.4, "Calling Google Gemini text model")
        response = self._client.models.generate_content(
            model=self._model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")
        output_text = self._extract_text(response)

        input_tokens, output_tokens = extract_usage_tokens(response)
        generation_cost: float | None = None
        if input_tokens is not None and output_tokens is not None:
            input_cost = input_tokens * self._INPUT_COST_PER_TOKEN
            output_cost = output_tokens * self._OUTPUT_COST_PER_TOKEN
            generation_cost = input_cost + output_cost

        call_progress(progress, 0.95, "Preparing text result")
        return TextGenerationResult(service=self.get_name(), text=output_text, cost=generation_cost)
|