File size: 4,500 Bytes
691f45a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import io
from typing import Any, Sequence

from PIL import Image

from services.exceptions import GenerationError
from services.progress import call_progress, ProgressCallback
from services.registry import register_service
from services.text.TextGenerationResult import TextGenerationResult
from services.text.TextGenerator import TextGenerator
from services.utils.google_service import DEFAULT_GEMINI_TEXT_MODEL, iter_response_parts, extract_usage_tokens


@register_service
class GoogleTextGenerator(TextGenerator):
    """Text generator backed by the Google Gemini API."""
    service_id = "google_gemini_text"

    # Image modes Pillow can write directly as PNG. Images in any other mode
    # (e.g. CMYK, common in some JPEGs) are converted before encoding, because
    # ``Image.save(..., format="PNG")`` raises for modes PNG cannot represent.
    _PNG_WRITABLE_MODES = frozenset({"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"})

    def __init__(self, model_name: str | None = None, api_key: str | None = None) -> None:
        """Create the Gemini client used for all requests.

        Args:
            model_name: Model to use; forwarded to ``TextGenerator.__init__``.
            api_key: Explicit API key. When ``None``, no ``api_key`` argument
                is passed to ``genai.Client``.

        Raises:
            GenerationError: If the ``google-genai`` package is not installed.
        """
        super().__init__(model_name=model_name)
        try:
            from google import genai  # type: ignore[import-not-found]
        except ModuleNotFoundError as exc:
            raise GenerationError(
                "The 'google-genai' package is required to use the Google text backend."
            ) from exc

        client_kwargs: dict[str, Any] = {}
        if api_key is not None:
            client_kwargs["api_key"] = api_key

        self._client = genai.Client(**client_kwargs)

    @classmethod
    def default_model_name(cls) -> str:
        """Return the Gemini text model used when no model name is given."""
        return DEFAULT_GEMINI_TEXT_MODEL

    def close(self) -> None:
        """Release the underlying client if it offers a ``close`` method.

        If the installed client has no callable ``close`` attribute (or the
        client was never created), this is a no-op.
        """
        client_close = getattr(getattr(self, "_client", None), "close", None)
        if callable(client_close):
            client_close()

    @staticmethod
    def _text_values(parts: Any) -> list[str]:
        """Return the ``text`` attribute of every part whose ``text`` is a string."""
        texts: list[str] = []
        for part in parts:
            value = getattr(part, "text", None)
            if isinstance(value, str):
                texts.append(value)
        return texts

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Pull the generated text out of a Gemini response.

        Tries, in order: a non-blank ``response.text``; the ``text`` of the
        parts yielded by ``iter_response_parts``; and finally the ``text`` of
        every ``candidate.content.parts`` entry.

        Raises:
            GenerationError: If none of those sources yields any text.
        """
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text

        parts_text = GoogleTextGenerator._text_values(iter_response_parts(response))
        if parts_text:
            return "".join(parts_text)

        # Last resort: walk candidates directly in case the helper above
        # did not surface their parts.
        for candidate in getattr(response, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            content_parts = getattr(content, "parts", None) or []
            parts_text.extend(GoogleTextGenerator._text_values(content_parts))
        if parts_text:
            return "".join(parts_text)

        raise GenerationError("The Google text model did not return any text output.")

    @classmethod
    def _encode_image_part(cls, types: Any, image: Image.Image) -> Any:
        """PNG-encode ``image`` and wrap it in a google-genai content part.

        Images whose mode PNG cannot store are converted to RGBA (if they
        have an alpha band) or RGB first; the caller's image is not modified.

        Raises:
            GenerationError: If the SDK ``types`` module provides neither a
                usable ``Part.from_bytes`` nor an ``InputImage`` type.
        """
        if image.mode not in cls._PNG_WRITABLE_MODES:
            has_alpha = any(band.upper() == "A" for band in image.getbands())
            image = image.convert("RGBA" if has_alpha else "RGB")

        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        image_bytes = buffer.getvalue()

        # Prefer ``types.Part.from_bytes``; fall back to ``types.InputImage``
        # when ``from_bytes`` is unavailable or returns None.
        part: Any | None = None
        from_bytes = getattr(getattr(types, "Part", None), "from_bytes", None)
        if callable(from_bytes):
            part = from_bytes(data=image_bytes, mime_type="image/png")

        if part is None:
            input_image_class = getattr(types, "InputImage", None)
            if input_image_class is not None:
                part = input_image_class(mime_type="image/png", data=image_bytes)

        if part is None:
            raise GenerationError(
                "The installed google-genai client does not support image inputs."
            )
        return part

    def generate(
            self,
            prompt: str,
            images: Sequence[Image.Image] | None = None,
            *,
            progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for ``prompt`` and optional ``images``.

        Args:
            prompt: Text prompt, sent as the first content part.
            images: Optional images; each is PNG-encoded and appended after
                the prompt in the given order.
            progress: Optional callback, invoked via ``call_progress`` at
                fixed checkpoints (0.1, 0.4, 0.7, 0.95).

        Returns:
            A ``TextGenerationResult`` with provider ``"google"``, this
            generator's model name, the extracted text and the raw response.

        Raises:
            GenerationError: If image parts cannot be built with the installed
                SDK, or the response contains no text.
        """
        from google.genai import types  # type: ignore[import-not-found]

        call_progress(progress, 0.1, "Encoding inputs")

        request_parts: list[object] = [prompt]
        for image in images or []:
            request_parts.append(self._encode_image_part(types, image))

        call_progress(progress, 0.4, "Calling Google Gemini text model")

        response = self._client.models.generate_content(
            model=self.model_name,
            contents=request_parts,
        )

        call_progress(progress, 0.7, "Decoding text")

        output_text = self._extract_text(response)

        # NOTE(review): usage tokens are extracted but not forwarded to the
        # result. TextGenerationResult's constructor is not visible from here;
        # if it accepts token counts, pass these through instead of dropping them.
        input_tokens, output_tokens = extract_usage_tokens(response)

        call_progress(progress, 0.95, "Preparing text result")

        return TextGenerationResult(
            provider="google",
            model=self.model_name,
            text=output_text,
            raw_response=response,
        )