Chatopus / app /gemini_client.py
VietCat's picture
update reranker
fd5dbb4
from google.generativeai.embedding import embed_content
from google.generativeai.client import configure
from google.generativeai.generative_models import GenerativeModel
from loguru import logger
from typing import Dict, List, Optional
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold
from .request_limit_manager import RequestLimitManager
from .utils import _safe_truncate
from .config import get_settings
class GeminiResponseError(Exception):
"""Custom exception for non-retriable Gemini response issues like safety or token limits."""
def __init__(self, message, finish_reason=None, usage_metadata=None):
super().__init__(message)
self.finish_reason = finish_reason
self.usage_metadata = usage_metadata
def __str__(self):
usage_str = (
f"Prompt: {self.usage_metadata.prompt_token_count}, Candidates: {self.usage_metadata.candidates_token_count}, Total: {self.usage_metadata.total_token_count}"
if self.usage_metadata
else "N/A"
)
return f"{super().__str__()} (Finish Reason: {self.finish_reason}, Usage: {usage_str})"
class GeminiClient:
def __init__(self):
self.limit_manager = RequestLimitManager("gemini")
settings = get_settings()
num_keys = (
len(settings.gemini_api_keys.split(",")) if settings.gemini_api_keys else 0
)
num_models = (
len(settings.gemini_models.split(",")) if settings.gemini_models else 0
)
logger.info(
f"[GEMINI_INIT] Limiter is considering {num_keys} API keys and {num_models} models."
)
self._cached_model = None
self._cached_key = None
self._cached_model_instance = None
# Thêm cấu hình safety_settings để bỏ chặn các phản hồi bị coi là không an toàn
self.safety_settings: Dict[HarmCategory, HarmBlockThreshold] = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
def _get_model_instance(self, key: str, model: str):
"""
Cache model instance để tránh recreate mỗi lần.
"""
if (
self._cached_key == key
and self._cached_model == model
and self._cached_model_instance is not None
):
return self._cached_model_instance
# Configure và tạo model instance mới
configure(api_key=key)
self._cached_model_instance = GenerativeModel(model)
self._cached_key = key
self._cached_model = model
logger.info(
f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}"
)
return self._cached_model_instance # noqa
def _clear_cache_if_needed(self, new_key: str, new_model: str):
"""
Chỉ clear cache khi key/model thực sự thay đổi.
"""
if self._cached_key != new_key or self._cached_model != new_model:
logger.info(
f"[GEMINI] Clearing cache due to key/model change: {self._cached_key}->{new_key}, {self._cached_model}->{new_model}"
)
self._cached_model_instance = None
self._cached_key = None
self._cached_model = None
def generate_text(self, prompt: str, **kwargs) -> str:
last_error = None
max_retries = 3
for attempt in range(max_retries):
try:
# Lấy current key/model từ manager
key, model = self.limit_manager.get_current_key_model()
# Sử dụng cached model instance
_model = self._get_model_instance(key, model)
response = _model.generate_content(
prompt, safety_settings=self.safety_settings, **kwargs
)
# Log toàn bộ nội dung response ở mức INFO để tiện gỡ lỗi
logger.debug(f"[GEMINI][RAW_RESPONSE] {response}")
# --- START: Cải tiến logic xử lý response ---
# 1. Kiểm tra response có hợp lệ không
if not response.candidates:
# Lỗi này nên được coi là lỗi tạm thời, thử lại với key/model khác
raise ValueError(
"Gemini response is missing 'candidates' field. Retrying..."
)
candidate = response.candidates[0]
finish_reason_name = getattr(
getattr(candidate, "finish_reason", None), "name", "UNKNOWN"
)
# Kiểm tra xem có nội dung thực sự không
# Sửa: Dùng getattr để tránh AttributeError nếu 'parts' không tồn tại
has_content = bool(
candidate.content and getattr(candidate.content, "parts", None)
)
# 2. Phân loại lỗi và xử lý
# Case 1: Lỗi nội dung không thể thử lại (SAFETY, MAX_TOKENS, etc.)
if finish_reason_name != "STOP":
usage_metadata = (
response.usage_metadata
if hasattr(response, "usage_metadata")
else None
)
error_message = f"Gemini response finished with non-OK reason: {finish_reason_name}."
raise GeminiResponseError(
error_message,
finish_reason=finish_reason_name,
usage_metadata=usage_metadata,
)
# Case 2: Lỗi có thể thử lại (STOP nhưng không có nội dung)
if (
not has_content
): # Tại đây, ta biết chắc chắn finish_reason_name là "STOP"
usage_metadata = (
response.usage_metadata
if hasattr(response, "usage_metadata")
else None
)
last_error = GeminiResponseError(
"Gemini response finished with STOP but has no content parts.",
finish_reason="STOP_NO_CONTENT",
usage_metadata=usage_metadata,
)
logger.warning(
f"[GEMINI] Model returned STOP with no content. Retrying with another key/model... (Attempt {attempt + 1}/{max_retries})"
)
self.limit_manager.log_request(
key, model, success=False, retry_delay=5
)
continue # Thử lại vòng lặp với key/model mới
# Case 3: Thành công (STOP và có nội dung)
self.limit_manager.log_request(key, model, success=True)
if hasattr(response, "usage_metadata"):
logger.info(
f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}" # noqa
)
try:
logger.debug(
f"[GEMINI][TEXT_RESPONSE] {_safe_truncate(response.text)}"
)
return response.text
except ValueError as ve:
# Safety net: Nếu truy cập .text thất bại dù các kiểm tra trước đó đã qua,
# coi như đây là lỗi STOP_NO_CONTENT và ném ra để tầng trên xử lý.
usage_metadata = (
response.usage_metadata
if hasattr(response, "usage_metadata")
else None
)
raise GeminiResponseError(
f"Gemini response has no valid content part. Original error: {ve}",
finish_reason="STOP_NO_CONTENT",
usage_metadata=usage_metadata,
) from ve
# --- END: Cải tiến logic xử lý response ---
except GeminiResponseError as e:
# Lỗi nội dung, không thể retry bằng cách đổi key. Propagate lên.
logger.error(f"[GEMINI] Non-retriable content error: {e}")
raise e
except Exception as e:
import re
msg = str(e)
# Kiểm tra lỗi rate limit hoặc lỗi server (5xx)
is_rate_limit = "429" in msg or "rate limit" in msg.lower()
is_server_error = any(
code in msg for code in ["500", "502", "503", "504"]
)
if is_rate_limit or is_server_error:
retry_delay = 60 # Mặc định cho lỗi server
if is_rate_limit:
m = re.search(r"retry_delay.*?seconds: (\d+)", msg)
if m:
retry_delay = int(m.group(1))
# Log lỗi và chặn cặp key/model hiện tại trong một khoảng thời gian
self.limit_manager.log_request(
key, model, success=False, retry_delay=retry_delay
)
error_type = "Rate limit" if is_rate_limit else "Server"
logger.warning(
f"[GEMINI] {error_type} error hit, will retry with new key/model "
f"(attempt {attempt + 1}/{max_retries}). Error: {e}"
)
last_error = e
continue # Tiếp tục vòng lặp để thử key/model mới
else:
# Các lỗi khác không phải rate limit hoặc server error (vd: network timeout, invalid argument)
# sẽ được propagate lên để lớp llm.py/reranker.py xử lý retry với backoff.
logger.error(
f"[GEMINI] Unhandled error generating text, propagating up: {e}"
)
raise e
raise last_error or RuntimeError("No available Gemini API key/model")
def count_tokens(self, prompt: str) -> int:
try:
key, model = self.limit_manager.get_current_key_model()
_model = self._get_model_instance(key, model)
return _model.count_tokens(prompt).total_tokens
except Exception as e:
logger.error(f"[GEMINI] Error counting tokens: {e}")
return 0
def create_embedding(
self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query"
) -> list:
last_error = None
max_retries = 3
for attempt in range(max_retries):
try:
key, default_model = self.limit_manager.get_current_key_model()
# Ưu tiên model được truyền vào parameter, chỉ fallback về default_model nếu không có
use_model = model if model and model.strip() else default_model
if not use_model:
raise ValueError("No model specified for embedding")
logger.debug(
f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}"
)
configure(api_key=key)
response = embed_content(
model=use_model, content=text, task_type=task_type
)
self.limit_manager.log_request(key, use_model, success=True)
logger.debug(
f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}"
)
return response["embedding"]
except Exception as e:
import re
msg = str(e)
if "429" in msg or "rate limit" in msg.lower():
retry_delay = 60
m_retry = re.search(r"retry_delay.*?seconds: (\d+)", msg)
if m_retry:
retry_delay = int(m_retry.group(1))
# Log failure và trigger scan cho key/model mới
self.limit_manager.log_request(
key, use_model, success=False, retry_delay=retry_delay
)
logger.warning(
f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})"
)
last_error = e
continue
else:
logger.error(f"[GEMINI] Error creating embedding: {e}")
last_error = e
break
raise last_error or RuntimeError("No available Gemini API key/model")