# PatientSim — patientsim/client/google_client.py
# (restored from web scrape; original commit 5ddcef8 by dek924: "feat: remove unused label")
import os
import time
from google import genai
from google.genai import types
from typing import List, Optional
from dotenv import load_dotenv, find_dotenv
from ..utils import log
from ..utils.common_utils import exponential_backoff
class GeminiClient:
    """Thin wrapper around the Google Gemini API for multi-turn chat.

    Keeps the running conversation in ``self.histories`` and accumulates
    per-call token-usage statistics in ``self.token_usages``.
    """

    def __init__(self, model: str, api_key: Optional[str] = None):
        """Create a client for *model*, resolving the API key if not given.

        Args:
            model: Gemini model identifier (e.g. ``gemini-2.0-flash``).
            api_key: Explicit API key; if falsy, ``GOOGLE_API_KEY`` is read
                from the environment / nearest ``.env`` file.
        """
        self.model = model
        self._init_environment(api_key)
        self.histories: List[types.Content] = []  # chronological chat turns
        self.token_usages: dict = {}              # metric name -> list of per-call counts
        self.__first_turn = True                  # greeting may be injected only once

    def _init_environment(self, api_key: Optional[str] = None) -> None:
        """Build the ``genai.Client``, falling back to env/.env for the key."""
        if not api_key:
            # Load the nearest .env (searched from CWD) so GOOGLE_API_KEY is visible.
            load_dotenv(find_dotenv(usecwd=True), override=True)
            api_key = os.environ.get("GOOGLE_API_KEY", None)
        self.client = genai.Client(api_key=api_key)

    def reset_history(self, verbose: bool = True) -> None:
        """Clear the conversation history and token-usage statistics.

        Args:
            verbose: When True, log a confirmation message.
        """
        self.__first_turn = True
        self.histories = []
        self.token_usages = {}
        if verbose:
            log('Conversation history has been reset.', color=True)

    def __make_payload(self, user_prompt: str) -> List[types.Content]:
        """Wrap *user_prompt* as a single user-role ``Content`` message."""
        return [types.Content(role='user', parts=[types.Part.from_text(text=user_prompt)])]

    def __call__(self,
                 user_prompt: str,
                 system_prompt: Optional[str] = None,
                 using_multi_turn: bool = True,
                 greeting: Optional[str] = None,
                 verbose: bool = True,
                 **kwargs) -> str:
        """Send *user_prompt* to the model and return the reply text.

        Args:
            user_prompt: The user's message for this turn.
            system_prompt: Optional system instruction for the request.
            using_multi_turn: If False, history is reset before this turn.
            greeting: Optional canned model message injected once, before
                the first user turn.
            verbose: Forwarded to :meth:`reset_history` logging.
            **kwargs: Extra ``GenerateContentConfig`` fields. ``max_retry``
                (default 5) bounds empty-response retries; ``seed`` and
                ``verbose`` are stripped because they are not valid
                Gemini config parameters.

        Returns:
            The model's reply text, or a fixed fallback string
            ('Could you tell me again?') if the model keeps returning
            empty responses after ``max_retry`` retries.
        """
        if not using_multi_turn:
            self.reset_history(verbose)
        if greeting and self.__first_turn:
            self.histories.append(types.Content(role='model', parts=[types.Part.from_text(text=greeting)]))
            self.__first_turn = False
        self.histories += self.__make_payload(user_prompt)

        count = 0
        max_retry = kwargs.pop('max_retry', 5)
        kwargs.pop('seed', None)     # not a valid Gemini param
        kwargs.pop('verbose', None)  # internal flag, not for API

        # Minimise thinking tokens unless the caller configured them explicitly.
        if 'thinking_config' not in kwargs:
            model_name = self.model.lower()
            if 'gemini-2' in model_name:
                kwargs['thinking_config'] = types.ThinkingConfig(thinking_budget=0)
            elif 'gemini-3' in model_name:
                kwargs['thinking_config'] = types.ThinkingConfig(thinking_level="minimal")

        while True:
            response = self.client.models.generate_content(
                model=self.model,
                contents=self.histories,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    **kwargs
                )
            )
            usage = response.usage_metadata
            if usage:
                # Some counts may come back as None; coerce anything non-int to 0.
                for key, value in (
                    ("prompt_tokens", usage.prompt_token_count),
                    ("completion_tokens", usage.candidates_token_count),
                    ("total_tokens", usage.total_token_count),
                    ("reasoning_tokens", usage.thoughts_token_count),
                ):
                    self.token_usages.setdefault(key, []).append(value if isinstance(value, int) else 0)

            # BUGFIX: check the response BEFORE the retry budget, so a valid
            # reply obtained on the final attempt is never discarded.
            if response.text is not None:
                break
            if count >= max_retry:
                # Give up: inject a clarification request so the dialogue can continue.
                replace_text = 'Could you tell me again?'
                self.histories.append(types.Content(role='model', parts=[types.Part.from_text(text=replace_text)]))
                return replace_text
            # Empty response: back off exponentially, then retry.
            time.sleep(exponential_backoff(count))
            count += 1

        self.histories.append(types.Content(role='model', parts=[types.Part.from_text(text=response.text)]))
        return response.text