| import logging |
| from google import genai |
| from google.genai import types |
| import google.api_core.exceptions as exceptions |
| from .config import GEMINI_API_KEY, GEMINI_MODELS, DEFAULT_MODEL |
|
|
# Module-level logger, named after this module so log output can be filtered
# per module by the application's logging configuration.
logger = logging.getLogger(__name__)
|
|
class GeminiClient:
    """Wrapper around the Google GenAI SDK that rotates models on failure.

    ``self.models`` holds the candidate model names with ``DEFAULT_MODEL``
    first, and ``self.current_model_index`` points at the model used for the
    next request. When a request fails with a quota / availability /
    "model not found" style error, the index advances to the next model and
    the request is retried. The index is sticky across requests, so a model
    that just ran out of quota is not retried on the very next call.

    Once every model has been tried the request fails, but the next request
    starts again from the first model instead of failing forever.
    """

    # Upper-cased substrings of an error message meaning "this model is out
    # of quota or temporarily unavailable; try the next one".
    _RETRYABLE_MARKERS = ("429", "RESOURCE_EXHAUSTED", "QUOTA", "503", "UNAVAILABLE")
    # Upper-cased substrings meaning "this model does not exist or cannot
    # serve this request; try the next one".
    _UNSUPPORTED_MARKERS = ("404", "NOT FOUND", "UNSUPPORTED")

    class ResponseWrapper:
        """Uniform result object returned by ``generate_content`` and ``chat``."""

        def __init__(self, raw, text_content):
            self.raw = raw  # untouched SDK response
            self.text = text_content  # all text parts of the first candidate, joined
            self.candidates = raw.candidates
            # Exposed for both entry points so callers can read token usage
            # the same way regardless of which method produced the result.
            self.usage_metadata = getattr(raw, "usage_metadata", None)

    def __init__(self):
        """Build the model rotation list and, if an API key is set, the SDK client."""
        # Preferred model first, then every other configured model once.
        self.models = [DEFAULT_MODEL] + [m for m in GEMINI_MODELS if m != DEFAULT_MODEL]
        self.current_model_index = 0
        # Stays None without an API key; the request methods check for that.
        self.client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None

    def rotate_model(self) -> bool:
        """Advance to the next model in the list.

        Returns:
            True if another model is available, False if the list is exhausted
            (``current_model_index`` is then equal to ``len(self.models)``).
        """
        self.current_model_index += 1
        if self.current_model_index < len(self.models):
            logger.warning(
                "Rotating to next model: %s", self.models[self.current_model_index]
            )
            return True
        logger.error("All models exhausted.")
        return False

    def _prepare_config(self, system_instruction=None, temperature: float = 0.7,
                        max_output_tokens: int = 2048, tools=None):
        """Create the ``types.GenerateContentConfig`` for a request."""
        return types.GenerateContentConfig(
            system_instruction=system_instruction,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            tools=tools,
        )

    def _get_full_text(self, response) -> str:
        """Join all text parts of the first candidate of ``response``.

        Returns an empty string when there is no candidate or no text. When
        the first candidate has no parts and its finish reason is ``SAFETY``
        or ``OTHER``, a placeholder message naming the reason is returned
        instead (and a warning is logged).
        """
        if not response or not response.candidates:
            return ""

        candidate = response.candidates[0]
        content = candidate.content
        if not content or not content.parts:
            reason = candidate.finish_reason
            if reason == 'SAFETY' or reason == 'OTHER':
                logger.warning(f"Response blocked or truncated. Reason: {reason}")
                return f"[Nội dung bị chặn hoặc cắt ngang do lý do: {reason}]"
            return ""

        return "".join(part.text for part in content.parts if part.text)

    @staticmethod
    def _has_function_call(response) -> bool:
        """Return True if the first candidate contains a function-call part."""
        if not response or not response.candidates:
            return False
        content = response.candidates[0].content
        if not content or not content.parts:
            return False
        return any(getattr(part, "function_call", None) for part in content.parts)

    @staticmethod
    def _to_parts(message):
        """Convert one message dict's ``parts`` into a list of ``types.Part``.

        Each entry may carry ``text``, ``function_call`` (``name``/``args``)
        or ``function_response`` (``name``/``response``); the first of those
        keys present wins. Entries with none of them are skipped.
        """
        converted = []
        for entry in message.get("parts", []):
            if "text" in entry:
                converted.append(types.Part(text=entry["text"]))
            elif "function_call" in entry:
                call = entry["function_call"]
                converted.append(types.Part(
                    function_call=types.FunctionCall(name=call["name"], args=call["args"])
                ))
            elif "function_response" in entry:
                result = entry["function_response"]
                converted.append(types.Part(
                    function_response=types.FunctionResponse(
                        name=result["name"], response=result["response"]
                    )
                ))
        return converted

    def _run_with_rotation(self, attempt, log_context=""):
        """Run ``attempt(model_name)`` against successive models until one works.

        Args:
            attempt: callable taking a model name. It returns the result to
                hand back to the caller, or None to request a rotation to the
                next model.
            log_context: text inserted after the model name in the error log
                line (e.g. ``" in chat"``).

        Raises:
            Exception: when every model has been tried without success, or
                the original exception when the error is not one that
                rotation can help with.
        """
        # An earlier request may have run off the end of the list. Start over
        # so a past outage does not disable this client for the rest of the
        # process lifetime.
        if self.current_model_index >= len(self.models):
            self.current_model_index = 0

        while self.current_model_index < len(self.models):
            model_name = self.models[self.current_model_index]
            try:
                result = attempt(model_name)
            except Exception as exc:
                err_msg = str(exc).upper()
                if any(marker in err_msg for marker in self._RETRYABLE_MARKERS):
                    if not self.rotate_model():
                        raise Exception(
                            "Hết quota hoặc model không khả dụng cho tất cả các model Gemini. Vui lòng thử lại sau."
                        ) from exc
                    continue
                logger.error("Error calling model %s%s: %s", model_name, log_context, exc)
                if any(marker in err_msg for marker in self._UNSUPPORTED_MARKERS):
                    if not self.rotate_model():
                        break
                    continue
                # Not something another model would fix: propagate unchanged.
                raise

            if result is not None:
                return result
            if not self.rotate_model():
                break

        raise Exception("Không còn model khả dụng.")

    def generate_content(self, prompt, system_instruction=None, tools=None,
                         temperature: float = 0.7, max_output_tokens: int = 2048):
        """Call the Gemini API for a single prompt, rotating models on failure.

        A response is accepted when it contains text or a function call (the
        latter carries no text but is a valid answer when ``tools`` are in
        use). A response with neither triggers a rotation to the next model.

        Returns:
            A ``ResponseWrapper`` with ``raw``, ``text``, ``candidates`` and
            ``usage_metadata``.

        Raises:
            Exception: if no API key is configured, if all models are
                exhausted, or the SDK error when it is not rotation-related.
        """
        if not self.client:
            raise Exception("GEMINI_API_KEY is not configured.")

        # The config does not depend on the model, so build it once.
        config = self._prepare_config(system_instruction, temperature, max_output_tokens, tools)

        def attempt(model_name):
            response = self.client.models.generate_content(
                model=model_name,
                contents=prompt,
                config=config,
            )
            text = self._get_full_text(response)
            if text or self._has_function_call(response):
                return self.ResponseWrapper(response, text)
            return None  # nothing usable: let the caller rotate

        return self._run_with_rotation(attempt)

    def chat(self, messages, system_instruction=None, tools=None,
             temperature: float = 0.7, max_output_tokens: int = 2048):
        """Send the last of ``messages`` in a chat seeded with the earlier ones.

        Args:
            messages: non-empty list of dicts with a ``role`` and a ``parts``
                list; all but the last become the chat history and the last
                is sent as the new message.

        Returns:
            A ``ResponseWrapper``; unlike ``generate_content`` it is returned
            even when the response carries no text.

        Raises:
            ValueError: if ``messages`` is empty.
            Exception: if no API key is configured, if all models are
                exhausted, or the SDK error when it is not rotation-related.
        """
        if not self.client:
            raise Exception("GEMINI_API_KEY is not configured.")
        if not messages:
            raise ValueError("messages must contain at least one message.")

        # None of this depends on the model, so build it once.
        config = self._prepare_config(system_instruction, temperature, max_output_tokens, tools)
        history = [
            types.Content(role=msg["role"], parts=self._to_parts(msg))
            for msg in messages[:-1]
        ]
        # Same conversion as the history, so a trailing function response
        # (the usual last message in a tool-calling round trip) is not lost.
        last_parts = self._to_parts(messages[-1])

        def attempt(model_name):
            # Hand each session its own copy in case the SDK appends to it.
            session = self.client.chats.create(
                model=model_name, config=config, history=list(history)
            )
            response = session.send_message(message=last_parts)
            return self.ResponseWrapper(response, self._get_full_text(response))

        return self._run_with_rotation(attempt, " in chat")
|
|
| |
# Shared module-level instance, created once at import time. The client's
# current model index lives on this object, so every importer sees the same
# rotation state.
gemini_client = GeminiClient()
|
|