File size: 4,834 Bytes
0f8b3a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import io
import os
from typing import Any

from velai.services.exceptions import GenerationError
from velai.services.progress import ProgressCallback, call_progress
from velai.services.registry import register_service
from velai.services.text.TextGenerator import TextGenerationInput, TextGenerationResult, TextGenerator
from velai.services.utils.google_service import extract_usage_tokens, iter_response_parts

# Model name used when no explicit ``model_name`` is supplied to the generator.
# The GEMINI_TEXT_MODEL environment variable, when set, overrides the built-in value.
DEFAULT_GEMINI_TEXT_MODEL = os.environ.get(
    "GEMINI_TEXT_MODEL",
    "gemini-2.5-flash-lite-preview-09-2025",
)


@register_service(TextGenerator)
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API.

    The ``google-genai`` package is imported lazily so that merely importing
    this module does not require the dependency to be installed.
    """

    # Per-token rates used for the cost estimate (0.10 and 0.40 per million tokens).
    # NOTE(review): presumably USD list prices for the default model; they are applied
    # regardless of which model name is configured — confirm against current pricing.
    _INPUT_COST_PER_TOKEN = 0.10 / 1000000
    _OUTPUT_COST_PER_TOKEN = 0.40 / 1000000

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the Gemini client.

        Args:
            model_name: Model to call. Falls back to ``DEFAULT_GEMINI_TEXT_MODEL``
                when omitted or empty.
            api_key: Key forwarded to ``genai.Client``. When ``None`` the client
                is constructed without an explicit key.

        Raises:
            GenerationError: If the ``google-genai`` package is not installed.
        """
        super().__init__()
        self._model_name = model_name or DEFAULT_GEMINI_TEXT_MODEL
        try:
            from google import genai  # type: ignore[import-not-found]
        except ModuleNotFoundError as exc:
            raise GenerationError("The 'google-genai' package is required to use the Google text backend.") from exc

        # Only pass the key when the caller supplied one, so the client's own
        # defaults apply otherwise.
        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key

        self._client = genai.Client(**client_kwargs)

    @classmethod
    def get_service_id(cls) -> str:
        """Return the identifier string of this service."""
        return "google_text"

    @classmethod
    def get_name(cls) -> str:
        """Return the human-readable name of this service."""
        return "Google Gemini"

    @classmethod
    def is_available(cls) -> bool:
        """Return True when ``GOOGLE_API_KEY`` is set to a non-empty value.

        An empty value is treated as unset, since a blank key cannot authenticate.
        """
        return bool(os.getenv("GOOGLE_API_KEY"))

    def close(self) -> None:
        """Release resources held by the generator (currently a no-op)."""
        # The google client does not expose an explicit close method currently
        return

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Pull the generated text out of a Gemini response object.

        Three sources are tried in order, and the first one that yields
        non-blank text wins:

        1. the ``text`` attribute of the response,
        2. the ``text`` of every part yielded by ``iter_response_parts``,
        3. the ``text`` of every part under ``response.candidates[*].content.parts``.

        Args:
            response: Object returned by the Gemini client.

        Returns:
            The extracted text, unmodified (no stripping is applied to it).

        Raises:
            GenerationError: If none of the sources contains non-blank text.
        """

        def joined_text(parts: Any) -> str:
            # Concatenate the string ``text`` attributes of the given parts,
            # skipping parts that carry no text (e.g. non-text payloads).
            texts = (getattr(part, "text", None) for part in parts)
            return "".join(value for value in texts if isinstance(value, str))

        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        # Apply the same "non-blank" rule as above to the fallbacks so that
        # empty or whitespace-only parts do not mask the missing-output error.
        from_parts = joined_text(iter_response_parts(response))
        if from_parts.strip():
            return from_parts

        candidate_texts: list[str] = []
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            candidate_texts.append(joined_text(getattr(content, "parts", None) or []))

        from_candidates = "".join(candidate_texts)
        if from_candidates.strip():
            return from_candidates

        raise GenerationError("The Google text model did not return any text output.")

    @staticmethod
    def _encode_image(types_module: Any, image: Any) -> Any:
        """Serialize *image* as PNG and wrap it in a request part.

        Args:
            types_module: The ``google.genai.types`` module of the installed client.
            image: Object exposing ``save(fileobj, format=...)``.

        Returns:
            A part built with ``types.Part.from_bytes`` when that exists,
            otherwise an instance of ``types.InputImage``.

        Raises:
            GenerationError: If the installed client offers neither API.
        """
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            image_bytes = buffer.getvalue()

        part: Any | None = None

        # Preferred constructor; looked up defensively because the attribute
        # may be absent in the installed client version.
        from_bytes = getattr(getattr(types_module, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")

        if part is None:
            input_image_class = getattr(types_module, "InputImage", None)
            if input_image_class is not None:
                part = input_image_class(mime_type="image/png", data=image_bytes)

        if part is None:
            raise GenerationError("The installed google-genai client does not support image inputs.")

        return part

    def generate(
        self, input_data: TextGenerationInput, *, progress: ProgressCallback | None = None
    ) -> TextGenerationResult:
        """Generate text for a prompt and optional images.

        Args:
            input_data: Request carrying ``prompt`` and an optional ``images``
                collection; each image is sent to the model as PNG bytes.
            progress: Optional callback notified at 0.1, 0.4, 0.7 and 0.95.

        Returns:
            A ``TextGenerationResult`` with the service name, the generated
            text, and an estimated cost. The cost is ``None`` unless both the
            input and output token counts are available.

        Raises:
            GenerationError: If an image cannot be attached with the installed
                client, or the model returns no text.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")

        # The prompt always comes first, followed by one part per image.
        request_parts: list[object] = [input_data.prompt]
        for image in input_data.images or []:
            request_parts.append(self._encode_image(types, image))

        call_progress(progress, 0.4, "Calling Google Gemini text model")

        response = self._client.models.generate_content(
            model=self._model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")

        output_text = self._extract_text(response)

        input_tokens, output_tokens = extract_usage_tokens(response)

        generation_cost: float | None = None
        if input_tokens is not None and output_tokens is not None:
            input_cost = input_tokens * self._INPUT_COST_PER_TOKEN
            output_cost = output_tokens * self._OUTPUT_COST_PER_TOKEN
            generation_cost = input_cost + output_cost

        call_progress(progress, 0.95, "Preparing text result")

        return TextGenerationResult(service=self.get_name(), text=output_text, cost=generation_cost)