import re

from google.generativeai.client import configure
from google.generativeai.embedding import embed_content
from google.generativeai.generative_models import GenerativeModel
from loguru import logger
from typing import Optional

from .request_limit_manager import RequestLimitManager
from .utils import _safe_truncate
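
# Assumed interface of RequestLimitManager, inferred from how it is used below
# (a sketch of the contract, not the definitive implementation):
#   get_current_key_model() -> tuple[str, str]          # (api_key, model_name)
#   log_request(key, model, success, retry_delay=None)  # record the outcome;
#       on a rate-limit failure the manager rotates to a fresh key/model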

class GeminiClient:
    """
    Thin wrapper around the google-generativeai SDK that rotates API keys and
    models through RequestLimitManager and caches the active GenerativeModel.
    """

    def __init__(self):
        self.limit_manager = RequestLimitManager("gemini")
        self._cached_model = None
        self._cached_key = None
        self._cached_model_instance = None

    def _get_model_instance(self, key: str, model: str):
        """
        Cache the model instance to avoid recreating it on every call.
        """
        if (self._cached_key == key and
                self._cached_model == model and
                self._cached_model_instance is not None):
            return self._cached_model_instance
        # Configure the API key and create a fresh model instance
        configure(api_key=key)
        self._cached_model_instance = GenerativeModel(model)
        self._cached_key = key
        self._cached_model = model
        logger.info(f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}")
        return self._cached_model_instance

    def _clear_cache_if_needed(self, new_key: str, new_model: str):
        """
        Clear the cache only when the key/model actually changes.
        """
        if self._cached_key != new_key or self._cached_model != new_model:
            logger.info(f"[GEMINI] Clearing cache due to key/model change: {self._cached_key}->{new_key}, {self._cached_model}->{new_model}")
            self._cached_model_instance = None
            self._cached_key = None
            self._cached_model = None

    def generate_text(self, prompt: str, **kwargs) -> str:
        last_error = None
        max_retries = 3
        key = model = None
        for attempt in range(max_retries):
            try:
                # Get the current key/model from the manager
                key, model = self.limit_manager.get_current_key_model()
                # Reuse the cached model instance
                _model = self._get_model_instance(key, model)
                response = _model.generate_content(prompt, **kwargs)
                self.limit_manager.log_request(key, model, success=True)
                if hasattr(response, 'usage_metadata'):
                    logger.info(f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}")
                if hasattr(response, 'text'):
                    logger.info(f"[GEMINI][TEXT_RESPONSE] {_safe_truncate(response.text)}")
                    return response.text
                elif hasattr(response, 'candidates') and response.candidates:
                    logger.info(f"[GEMINI][CANDIDATES_RESPONSE] {_safe_truncate(response.candidates[0].content.parts[0].text)}")
                    return response.candidates[0].content.parts[0].text
                logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
                return str(response)
            except Exception as e:
                msg = str(e)
                if "429" in msg or "rate limit" in msg.lower():
                    # The error message may embed a server-suggested delay,
                    # e.g. "retry_delay { seconds: 42 }"; default to 60s otherwise
                    retry_delay = 60
                    m = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                    if m:
                        retry_delay = int(m.group(1))
                    # Log the failure against the key/model actually in use
                    if key is not None:
                        self.limit_manager.log_request(key, model, success=False, retry_delay=retry_delay)
                    # Don't clear the model cache here: it is invalidated only
                    # when the key/model changes, avoiding needless recreation
                    logger.warning(f"[GEMINI] Rate limit hit, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
                    last_error = e
                    continue
                else:
                    # Any error other than a rate limit is not retried
                    logger.error(f"[GEMINI] Error generating text: {e}")
                    last_error = e
                    break
        raise last_error or RuntimeError("No available Gemini API key/model")

    def count_tokens(self, prompt: str) -> int:
        try:
            key, model = self.limit_manager.get_current_key_model()
            _model = self._get_model_instance(key, model)
            return _model.count_tokens(prompt).total_tokens
        except Exception as e:
            logger.error(f"[GEMINI] Error counting tokens: {e}")
            return 0

    def create_embedding(self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query") -> list:
        last_error = None
        max_retries = 3
        key = use_model = None
        for attempt in range(max_retries):
            try:
                key, default_model = self.limit_manager.get_current_key_model()
                # Prefer the model passed in as a parameter; fall back to
                # default_model only when none is given
                use_model = model if model and model.strip() else default_model
                if not use_model:
                    raise ValueError("No model specified for embedding")
                logger.info(f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}")
                configure(api_key=key)
                response = embed_content(
                    model=use_model,
                    content=text,
                    task_type=task_type
                )
                self.limit_manager.log_request(key, use_model, success=True)
                logger.info(f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}")
                return response['embedding']
            except Exception as e:
                msg = str(e)
                if "429" in msg or "rate limit" in msg.lower():
                    retry_delay = 60
                    m_retry = re.search(r'retry_delay.*?seconds: (\d+)', msg)
                    if m_retry:
                        retry_delay = int(m_retry.group(1))
                    # Log the failure and trigger a scan for a new key/model
                    if key is not None and use_model is not None:
                        self.limit_manager.log_request(key, use_model, success=False, retry_delay=retry_delay)
                    logger.warning(f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})")
                    last_error = e
                    continue
                else:
                    logger.error(f"[GEMINI] Error creating embedding: {e}")
                    last_error = e
                    break
        raise last_error or RuntimeError("No available Gemini API key/model")
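

# Minimal usage sketch (assumes RequestLimitManager is configured with at
# least one valid Gemini API key and model; "models/text-embedding-004" is an
# illustrative embedding model name, not one mandated by this module):
if __name__ == "__main__":
    client = GeminiClient()

    # Plain text generation; extra kwargs are forwarded to generate_content()
    answer = client.generate_text("Explain rate limiting in one sentence.")
    print(answer)

    # Token counting against the currently active key/model
    print(client.count_tokens("How many tokens is this prompt?"))

    # Embedding with an explicitly requested model
    vector = client.create_embedding(
        "rate limiting",
        model="models/text-embedding-004",
        task_type="retrieval_query",
    )
    print(len(vector))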