AI-Coach / src /gemini_client.py
anhlehong
feat/enhance
1c58706
import logging
from google import genai
from google.genai import types
import google.api_core.exceptions as exceptions
from .config import GEMINI_API_KEY, GEMINI_MODELS, DEFAULT_MODEL
logger = logging.getLogger(__name__)

# Upper-cased substrings of an error message that indicate the current model is
# temporarily unusable (quota exhausted / service unavailable).
_RETRYABLE_MARKERS = ("429", "RESOURCE_EXHAUSTED", "QUOTA", "503", "UNAVAILABLE")
# Upper-cased substrings that indicate the model name itself is not usable.
_MODEL_MISSING_MARKERS = ("404", "NOT FOUND", "UNSUPPORTED")


class _ResponseWrapper:
    """Uniform result object returned by ``generate_content`` and ``chat``.

    Attributes:
        raw: The unmodified SDK response.
        text: All text parts of the first candidate joined together.
        candidates: Shortcut to ``raw.candidates``.
        usage_metadata: ``raw.usage_metadata`` when present, else ``None``.
    """

    def __init__(self, raw, text_content):
        self.raw = raw
        self.text = text_content
        self.candidates = raw.candidates
        self.usage_metadata = getattr(raw, "usage_metadata", None)


class GeminiClient:
    """
    A wrapper around the Google GenAI SDK that handles model rotation on quota exhaustion.

    Rotation is sticky: once a model is rotated away from, later calls keep
    using the next model. When every model has been rotated past, the next
    call starts again from the first model instead of failing forever.
    """

    def __init__(self):
        # DEFAULT_MODEL is tried first, then the remaining GEMINI_MODELS in
        # their configured order.
        self.models = [DEFAULT_MODEL] + [m for m in GEMINI_MODELS if m != DEFAULT_MODEL]
        self.current_model_index = 0
        # Stays None when no API key is configured; the public methods raise
        # in that case.
        self.client = None
        if GEMINI_API_KEY:
            self.client = genai.Client(api_key=GEMINI_API_KEY)

    def rotate_model(self):
        """Move to the next model in the list.

        Returns:
            True if another model is available, False if all are exhausted.
        """
        self.current_model_index += 1
        if self.current_model_index < len(self.models):
            logger.warning("Rotating to next model: %s", self.models[self.current_model_index])
            return True
        logger.error("All models exhausted.")
        return False

    def _reset_if_exhausted(self):
        """Restart from the first model if a previous call used up the list.

        Without this, a single burst of quota errors would leave the shared
        client permanently unusable until the process restarts.
        """
        if self.current_model_index >= len(self.models):
            logger.info("Model list was exhausted earlier; retrying from the first model.")
            self.current_model_index = 0

    def _prepare_config(self, system_instruction=None, temperature=0.7, max_output_tokens=2048, tools=None):
        """Build the ``GenerateContentConfig`` passed to the SDK."""
        # system_instruction is forwarded as given (callers pass a string).
        return types.GenerateContentConfig(
            system_instruction=system_instruction,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            tools=tools,
        )

    def _get_full_text(self, response):
        """Extract and join all text parts of the first candidate.

        Returns:
            The joined text, a bracketed placeholder when the candidate has no
            parts and its finish reason is 'SAFETY' or 'OTHER', or "" when
            there is nothing to extract.
        """
        if not response or not response.candidates:
            return ""
        candidate = response.candidates[0]
        if not candidate.content or not candidate.content.parts:
            if candidate.finish_reason in ('SAFETY', 'OTHER'):
                # Kept as f-strings so the rendered reason matches the
                # user-facing placeholder below.
                logger.warning(f"Response blocked or truncated. Reason: {candidate.finish_reason}")
                return f"[Nội dung bị chặn hoặc cắt ngang do lý do: {candidate.finish_reason}]"
            return ""
        return "".join(part.text for part in candidate.content.parts if part.text)

    @staticmethod
    def _has_function_call(response):
        """Return True if the first candidate contains a function-call part."""
        candidates = getattr(response, "candidates", None)
        if not candidates:
            return False
        content = getattr(candidates[0], "content", None)
        parts = getattr(content, "parts", None) or []
        return any(getattr(part, "function_call", None) for part in parts)

    @staticmethod
    def _to_parts(message):
        """Convert one message dict's ``parts`` entries into SDK ``Part`` objects.

        Recognised entry keys are ``text``, ``function_call`` (``name``,
        ``args``) and ``function_response`` (``name``, ``response``); entries
        with none of these keys are skipped.
        """
        converted = []
        for entry in message.get("parts", []):
            if "text" in entry:
                converted.append(types.Part(text=entry["text"]))
            elif "function_call" in entry:
                fc = entry["function_call"]
                converted.append(
                    types.Part(function_call=types.FunctionCall(name=fc["name"], args=fc["args"]))
                )
            elif "function_response" in entry:
                fr = entry["function_response"]
                converted.append(
                    types.Part(
                        function_response=types.FunctionResponse(name=fr["name"], response=fr["response"])
                    )
                )
        return converted

    def _rotate_after_error(self, exc, model_name, context=""):
        """Decide how to react to an exception raised while calling a model.

        Returns:
            True if the client rotated to another model and the caller should
            retry; False if the model was unusable and no model is left.

        Raises:
            Exception: quota/unavailability error and no model is left.
            The original exception: for any error that is not rotation-worthy.
        """
        err_msg = str(exc).upper()
        if any(marker in err_msg for marker in _RETRYABLE_MARKERS):
            if not self.rotate_model():
                raise Exception(
                    "Hết quota hoặc model không khả dụng cho tất cả các model Gemini. Vui lòng thử lại sau."
                ) from exc
            return True
        logger.error("Error calling model %s%s: %s", model_name, context, exc)
        if any(marker in err_msg for marker in _MODEL_MISSING_MARKERS):
            return self.rotate_model()
        raise exc

    def generate_content(self, prompt, system_instruction=None, tools=None, temperature=0.7, max_output_tokens=2048):
        """
        Calls the Gemini API with model rotation logic.

        Returns:
            A wrapper exposing ``raw``, ``text``, ``candidates`` and
            ``usage_metadata``. ``text`` may be "" when the model answered
            with a function call only.

        Raises:
            Exception: if the API key is not configured, or no model produced
                a usable response.
        """
        if not self.client:
            raise Exception("GEMINI_API_KEY is not configured.")
        self._reset_if_exhausted()
        # The config does not depend on the model, so build it once.
        config = self._prepare_config(system_instruction, temperature, max_output_tokens, tools)
        while self.current_model_index < len(self.models):
            model_name = self.models[self.current_model_index]
            try:
                response = self.client.models.generate_content(
                    model=model_name,
                    contents=prompt,
                    config=config,
                )
                text = self._get_full_text(response)
                # A function-call reply legitimately has no text; it must not
                # be mistaken for an empty response.
                if text or self._has_function_call(response):
                    return _ResponseWrapper(response, text)
            except Exception as exc:
                if not self._rotate_after_error(exc, model_name):
                    break
                continue
            # Empty response: try the next model.
            if not self.rotate_model():
                break
        raise Exception("Không còn model khả dụng.")

    def chat(self, messages, system_instruction=None, tools=None, temperature=0.7, max_output_tokens=2048):
        """
        Handles chat sessions with model rotation.

        Args:
            messages: Non-empty list of dicts with ``role`` and ``parts``. All
                but the last become the chat history; the last one is sent.

        Returns:
            A wrapper exposing ``raw``, ``text``, ``candidates`` and
            ``usage_metadata``.

        Raises:
            Exception: if the API key is not configured, or no model is usable.
            IndexError: if ``messages`` is empty.
        """
        if not self.client:
            raise Exception("GEMINI_API_KEY is not configured.")
        self._reset_if_exhausted()
        # None of this depends on the model, so build it once.
        config = self._prepare_config(system_instruction, temperature, max_output_tokens, tools)
        history = [
            types.Content(role=msg["role"], parts=self._to_parts(msg))
            for msg in messages[:-1]
        ]
        # Same conversion as history, so a trailing function_response is sent
        # rather than silently dropped.
        last_parts = self._to_parts(messages[-1])
        while self.current_model_index < len(self.models):
            model_name = self.models[self.current_model_index]
            try:
                # Fresh list per attempt in case the SDK chat object mutates
                # the history it is given.
                session = self.client.chats.create(model=model_name, config=config, history=list(history))
                response = session.send_message(message=last_parts)
                text = self._get_full_text(response)
                return _ResponseWrapper(response, text)
            except Exception as exc:
                if not self._rotate_after_error(exc, model_name, " in chat"):
                    break
                continue
        raise Exception("Không còn model khả dụng.")
# Module-level shared instance, created at import time. Its `client` attribute
# stays None when GEMINI_API_KEY is falsy, in which case generate_content()
# and chat() raise.
gemini_client = GeminiClient()