import re
from typing import Dict, List, Optional

from google.generativeai.client import configure
from google.generativeai.embedding import embed_content
from google.generativeai.generative_models import GenerativeModel
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold
from loguru import logger

from .config import get_settings
from .request_limit_manager import RequestLimitManager
from .utils import _safe_truncate

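# Expected interfaces of the local helpers, as inferred from their call sites in
# this module (see the corresponding modules for the actual definitions):
#   - RequestLimitManager("gemini").get_current_key_model() -> (api_key, model_name)
#   - RequestLimitManager.log_request(key, model, success=..., retry_delay=...)
#   - _safe_truncate(text) -> shortened string suitable for logging
#   - get_settings() -> settings object exposing `gemini_api_keys` and
#     `gemini_models` as comma-separated strings
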
class GeminiResponseError(Exception):
    """Custom exception for non-retriable Gemini response issues like safety or token limits."""

    def __init__(self, message, finish_reason=None, usage_metadata=None):
        super().__init__(message)
        self.finish_reason = finish_reason
        self.usage_metadata = usage_metadata

    def __str__(self):
        usage_str = (
            f"Prompt: {self.usage_metadata.prompt_token_count}, "
            f"Candidates: {self.usage_metadata.candidates_token_count}, "
            f"Total: {self.usage_metadata.total_token_count}"
            if self.usage_metadata
            else "N/A"
        )
        return f"{super().__str__()} (Finish Reason: {self.finish_reason}, Usage: {usage_str})"

class GeminiClient:
    """Thin wrapper around the google.generativeai SDK with key/model rotation via RequestLimitManager."""

    def __init__(self):
        self.limit_manager = RequestLimitManager("gemini")
        settings = get_settings()
        num_keys = (
            len(settings.gemini_api_keys.split(",")) if settings.gemini_api_keys else 0
        )
        num_models = (
            len(settings.gemini_models.split(",")) if settings.gemini_models else 0
        )
        logger.info(
            f"[GEMINI_INIT] Limiter is considering {num_keys} API keys and {num_models} models."
        )
        # Cache of the last (key, model) pair and its GenerativeModel instance.
        self._cached_model = None
        self._cached_key = None
        self._cached_model_instance = None

        # Do not block any harm category at the SDK level; finish reasons are
        # still checked explicitly in generate_text().
        self.safety_settings: Dict[HarmCategory, HarmBlockThreshold] = {
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        }

    def _get_model_instance(self, key: str, model: str):
        """
        Cache the model instance to avoid recreating it on every call.
        """
        if (
            self._cached_key == key
            and self._cached_model == model
            and self._cached_model_instance is not None
        ):
            return self._cached_model_instance

        configure(api_key=key)
        self._cached_model_instance = GenerativeModel(model)
        self._cached_key = key
        self._cached_model = model

        logger.info(
            f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}"
        )
        return self._cached_model_instance

    def _clear_cache_if_needed(self, new_key: str, new_model: str):
        """
        Only clear the cache when the key/model actually changes.
        """
        if self._cached_key != new_key or self._cached_model != new_model:
            logger.info(
                f"[GEMINI] Clearing cache due to key/model change: {self._cached_key}->{new_key}, {self._cached_model}->{new_model}"
            )
            self._cached_model_instance = None
            self._cached_key = None
            self._cached_model = None

    def generate_text(self, prompt: str, **kwargs) -> str:
        """Generate text for `prompt`, rotating keys/models and retrying on rate-limit or server errors."""
        last_error = None
        max_retries = 3

        for attempt in range(max_retries):
            try:
                key, model = self.limit_manager.get_current_key_model()
                _model = self._get_model_instance(key, model)

                response = _model.generate_content(
                    prompt, safety_settings=self.safety_settings, **kwargs
                )

                logger.debug(f"[GEMINI][RAW_RESPONSE] {response}")

                if not response.candidates:
                    raise ValueError(
                        "Gemini response is missing 'candidates' field. Retrying..."
                    )

                candidate = response.candidates[0]
                finish_reason_name = getattr(
                    getattr(candidate, "finish_reason", None), "name", "UNKNOWN"
                )

                has_content = bool(
                    candidate.content and getattr(candidate.content, "parts", None)
                )

                # Any finish reason other than STOP (safety block, token limit, ...) is non-retriable.
                if finish_reason_name != "STOP":
                    usage_metadata = (
                        response.usage_metadata
                        if hasattr(response, "usage_metadata")
                        else None
                    )
                    error_message = f"Gemini response finished with non-OK reason: {finish_reason_name}."
                    raise GeminiResponseError(
                        error_message,
                        finish_reason=finish_reason_name,
                        usage_metadata=usage_metadata,
                    )

                # STOP with no content parts: log a failure and retry with another key/model.
                if not has_content:
                    usage_metadata = (
                        response.usage_metadata
                        if hasattr(response, "usage_metadata")
                        else None
                    )
                    last_error = GeminiResponseError(
                        "Gemini response finished with STOP but has no content parts.",
                        finish_reason="STOP_NO_CONTENT",
                        usage_metadata=usage_metadata,
                    )
                    logger.warning(
                        f"[GEMINI] Model returned STOP with no content. Retrying with another key/model... (Attempt {attempt + 1}/{max_retries})"
                    )
                    self.limit_manager.log_request(
                        key, model, success=False, retry_delay=5
                    )
                    continue

                self.limit_manager.log_request(key, model, success=True)
                if hasattr(response, "usage_metadata"):
                    logger.info(
                        f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}"
                    )

                try:
                    logger.debug(
                        f"[GEMINI][TEXT_RESPONSE] {_safe_truncate(response.text)}"
                    )
                    return response.text
                except ValueError as ve:
                    # Accessing response.text raises ValueError when there is no valid content part.
                    usage_metadata = (
                        response.usage_metadata
                        if hasattr(response, "usage_metadata")
                        else None
                    )
                    raise GeminiResponseError(
                        f"Gemini response has no valid content part. Original error: {ve}",
                        finish_reason="STOP_NO_CONTENT",
                        usage_metadata=usage_metadata,
                    ) from ve

            except GeminiResponseError as e:
                logger.error(f"[GEMINI] Non-retriable content error: {e}")
                raise e
            except Exception as e:
                msg = str(e)

                is_rate_limit = "429" in msg or "rate limit" in msg.lower()
                is_server_error = any(
                    code in msg for code in ["500", "502", "503", "504"]
                )

                if is_rate_limit or is_server_error:
                    retry_delay = 60
                    if is_rate_limit:
                        m = re.search(r"retry_delay.*?seconds: (\d+)", msg)
                        if m:
                            retry_delay = int(m.group(1))

                    self.limit_manager.log_request(
                        key, model, success=False, retry_delay=retry_delay
                    )

                    error_type = "Rate limit" if is_rate_limit else "Server"
                    logger.warning(
                        f"[GEMINI] {error_type} error hit, will retry with new key/model "
                        f"(attempt {attempt + 1}/{max_retries}). Error: {e}"
                    )
                    last_error = e
                    continue
                else:
                    logger.error(
                        f"[GEMINI] Unhandled error generating text, propagating up: {e}"
                    )
                    raise e

        raise last_error or RuntimeError("No available Gemini API key/model")

    def count_tokens(self, prompt: str) -> int:
        """Return the token count for `prompt`, or 0 if counting fails."""
        try:
            key, model = self.limit_manager.get_current_key_model()
            _model = self._get_model_instance(key, model)
            return _model.count_tokens(prompt).total_tokens
        except Exception as e:
            logger.error(f"[GEMINI] Error counting tokens: {e}")
            return 0

    def create_embedding(
        self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query"
    ) -> list:
        """Embed `text`, preferring an explicitly requested model over the limiter's default."""
        last_error = None
        max_retries = 3

        for attempt in range(max_retries):
            try:
                key, default_model = self.limit_manager.get_current_key_model()

                # Use the explicitly requested embedding model when given; otherwise fall back to the limiter's default.
                use_model = model if model and model.strip() else default_model

                if not use_model:
                    raise ValueError("No model specified for embedding")

                logger.debug(
                    f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}"
                )

                configure(api_key=key)
                response = embed_content(
                    model=use_model, content=text, task_type=task_type
                )

                self.limit_manager.log_request(key, use_model, success=True)
                logger.debug(
                    f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}"
                )
                return response["embedding"]

            except Exception as e:
                msg = str(e)
                if "429" in msg or "rate limit" in msg.lower():
                    retry_delay = 60
                    m_retry = re.search(r"retry_delay.*?seconds: (\d+)", msg)
                    if m_retry:
                        retry_delay = int(m_retry.group(1))

                    self.limit_manager.log_request(
                        key, use_model, success=False, retry_delay=retry_delay
                    )

                    logger.warning(
                        f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})"
                    )
                    last_error = e
                    continue
                else:
                    logger.error(f"[GEMINI] Error creating embedding: {e}")
                    last_error = e
                    break

        raise last_error or RuntimeError("No available Gemini API key/model")
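
# --- Usage sketch (illustrative only, not part of the public API) ---
# A minimal example of how this client might be exercised, assuming valid
# Gemini API keys and models are already configured in the settings consumed
# by get_settings().
if __name__ == "__main__":
    client = GeminiClient()

    # Token counting and plain text generation.
    print("tokens:", client.count_tokens("Hello, Gemini!"))
    print(client.generate_text("Write one sentence about retry logic."))

    # Embedding with the limiter's default model (pass `model=` to override).
    vector = client.create_embedding("semantic search query")
    print("embedding dims:", len(vector))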