File size: 4,582 Bytes
e7069ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8bafcb
e7069ae
 
 
 
f8bafcb
e7069ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import time
from google import genai
from google.genai import types
from typing import List, Optional
from dotenv import load_dotenv, find_dotenv

from ..utils import log
from ..utils.common_utils import exponential_backoff


class GeminiClient:
    """Thin wrapper around the google-genai SDK that keeps a multi-turn
    conversation history and accumulates per-call token usage.

    Attributes:
        model: Name of the Gemini model to query (e.g. ``gemini-2.0-flash``).
        client: Underlying ``genai.Client`` instance.
        histories: Ordered list of ``types.Content`` turns sent so far.
        token_usages: Mapping of usage category -> list of per-call counts.
    """

    def __init__(self, model: str, api_key: Optional[str] = None):
        """Create a client for *model*, resolving the API key if not given."""
        self.model = model
        self._init_environment(api_key)
        self.histories = list()
        self.token_usages = dict()
        self.__first_turn = True

    def _init_environment(self, api_key: Optional[str] = None) -> None:
        """Build ``self.client``, falling back to ``GOOGLE_API_KEY`` from a
        ``.env`` file or the process environment when no key is supplied."""
        if not api_key:
            # usecwd=True searches upward from the working directory, so the
            # .env file is found even when running from a subdirectory.
            dotenv_path = find_dotenv(usecwd=True)
            load_dotenv(dotenv_path, override=True)
            api_key = os.environ.get("GOOGLE_API_KEY", None)
        self.client = genai.Client(api_key=api_key)

    def reset_history(self, verbose: bool = True) -> None:
        """Drop all conversation turns and token accounting, restoring the
        client to its first-turn state."""
        self.__first_turn = True
        self.histories = list()
        self.token_usages = dict()
        if verbose:
            log('Conversation history has been reset.', color=True)

    def __make_payload(self, user_prompt: str) -> List[types.Content]:
        """Wrap a raw prompt string as a single user-role content list."""
        return [types.Content(role='user', parts=[types.Part.from_text(text=user_prompt)])]

    def _record_usage(self, usage) -> None:
        """Append this call's token counts to ``token_usages``; fields that
        are not ints (e.g. ``None`` from the API) are recorded as 0."""
        def _as_int(value) -> int:
            return value if isinstance(value, int) else 0

        self.token_usages.setdefault("prompt_tokens", []).append(_as_int(usage.prompt_token_count))
        self.token_usages.setdefault("completion_tokens", []).append(_as_int(usage.candidates_token_count))
        self.token_usages.setdefault("total_tokens", []).append(_as_int(usage.total_token_count))
        self.token_usages.setdefault("reasoning_tokens", []).append(_as_int(usage.thoughts_token_count))

    def __call__(self,
                 user_prompt: str,
                 system_prompt: Optional[str] = None,
                 using_multi_turn: bool = True,
                 greeting: Optional[str] = None,
                 verbose: bool = True,
                 **kwargs) -> str:
        """Send *user_prompt* to the model and return its text reply.

        Args:
            user_prompt: The user message for this turn.
            system_prompt: Optional system instruction for this call.
            using_multi_turn: When False, the history is reset before sending
                so the call behaves as a fresh single-turn request.
            greeting: Optional canned model message prepended once, on the
                first turn of a conversation.
            verbose: Forwarded to ``reset_history`` logging.
            **kwargs: Extra ``GenerateContentConfig`` fields. ``max_retry``
                (default 5) bounds retries on empty responses; ``seed`` and
                ``verbose`` are stripped because the API rejects them.

        Returns:
            The model's reply text, or a fixed fallback message after
            exhausting retries on empty responses. Either way the reply is
            appended to ``histories`` as a model turn.
        """
        if not using_multi_turn:
            self.reset_history(verbose)

        if greeting and self.__first_turn:
            self.histories.append(types.Content(role='model', parts=[types.Part.from_text(text=greeting)]))
            self.__first_turn = False

        self.histories += self.__make_payload(user_prompt)

        attempt = 0
        max_retry = kwargs.pop('max_retry', 5)
        kwargs.pop('seed', None)       # not a valid Gemini param
        kwargs.pop('verbose', None)    # internal flag, not for API

        # Minimise thinking tokens unless the caller configured them.
        if 'thinking_config' not in kwargs:
            model_name = self.model.lower()
            if 'gemini-2' in model_name:
                kwargs['thinking_config'] = types.ThinkingConfig(thinking_budget=0)
            elif 'gemini-3' in model_name:
                kwargs['thinking_config'] = types.ThinkingConfig(thinking_level="minimal")

        while True:
            response = self.client.models.generate_content(
                model=self.model,
                contents=self.histories,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    **kwargs
                )
            )

            if response.usage_metadata:
                self._record_usage(response.usage_metadata)

            # BUG FIX: accept a valid response even on the final attempt.
            # Previously the retry-limit check ran first, discarding a
            # just-received valid response and returning the fallback.
            if response.text is not None:
                break

            if attempt >= max_retry:
                replace_text = 'Could you tell me again?'
                self.histories.append(types.Content(role='model', parts=[types.Part.from_text(text=replace_text)]))
                return replace_text

            time.sleep(exponential_backoff(attempt))
            attempt += 1

        self.histories.append(types.Content(role='model', parts=[types.Part.from_text(text=response.text)]))
        return response.text