update query and rerank in parallel
Browse files- app/gemini_client.py +151 -80
- app/reranker.py +152 -90
app/gemini_client.py
CHANGED
|
@@ -6,29 +6,40 @@ from typing import Dict, List, Optional
|
|
| 6 |
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold
|
| 7 |
|
| 8 |
from .request_limit_manager import RequestLimitManager
|
| 9 |
-
from .utils import
|
| 10 |
-
_safe_truncate
|
| 11 |
-
)
|
| 12 |
from .config import get_settings
|
| 13 |
|
|
|
|
| 14 |
class GeminiResponseError(Exception):
|
| 15 |
"""Custom exception for non-retriable Gemini response issues like safety or token limits."""
|
|
|
|
| 16 |
def __init__(self, message, finish_reason=None, usage_metadata=None):
|
| 17 |
super().__init__(message)
|
| 18 |
self.finish_reason = finish_reason
|
| 19 |
self.usage_metadata = usage_metadata
|
| 20 |
|
| 21 |
def __str__(self):
|
| 22 |
-
usage_str =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
return f"{super().__str__()} (Finish Reason: {self.finish_reason}, Usage: {usage_str})"
|
| 24 |
|
|
|
|
| 25 |
class GeminiClient:
|
| 26 |
def __init__(self):
|
| 27 |
self.limit_manager = RequestLimitManager("gemini")
|
| 28 |
settings = get_settings()
|
| 29 |
-
num_keys =
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
self._cached_model = None
|
| 33 |
self._cached_key = None
|
| 34 |
self._cached_model_instance = None
|
|
@@ -44,26 +55,32 @@ class GeminiClient:
|
|
| 44 |
"""
|
| 45 |
Cache model instance để tránh recreate mỗi lần.
|
| 46 |
"""
|
| 47 |
-
if (
|
| 48 |
-
self.
|
| 49 |
-
self.
|
|
|
|
|
|
|
| 50 |
return self._cached_model_instance
|
| 51 |
-
|
| 52 |
# Configure và tạo model instance mới
|
| 53 |
configure(api_key=key)
|
| 54 |
self._cached_model_instance = GenerativeModel(model)
|
| 55 |
self._cached_key = key
|
| 56 |
self._cached_model = model
|
| 57 |
-
|
| 58 |
-
logger.info(
|
|
|
|
|
|
|
| 59 |
return self._cached_model_instance
|
| 60 |
|
| 61 |
def _clear_cache_if_needed(self, new_key: str, new_model: str):
|
| 62 |
"""
|
| 63 |
Chỉ clear cache khi key/model thực sự thay đổi.
|
| 64 |
"""
|
| 65 |
-
if
|
| 66 |
-
logger.info(
|
|
|
|
|
|
|
| 67 |
self._cached_model_instance = None
|
| 68 |
self._cached_key = None
|
| 69 |
self._cached_model = None
|
|
@@ -71,21 +88,19 @@ class GeminiClient:
|
|
| 71 |
def generate_text(self, prompt: str, **kwargs) -> str:
|
| 72 |
last_error = None
|
| 73 |
max_retries = 3
|
| 74 |
-
|
| 75 |
for attempt in range(max_retries):
|
| 76 |
try:
|
| 77 |
# Lấy current key/model từ manager
|
| 78 |
key, model = self.limit_manager.get_current_key_model()
|
| 79 |
-
|
| 80 |
# Sử dụng cached model instance
|
| 81 |
_model = self._get_model_instance(key, model)
|
| 82 |
-
|
| 83 |
response = _model.generate_content(
|
| 84 |
-
prompt,
|
| 85 |
-
safety_settings=self.safety_settings,
|
| 86 |
-
**kwargs
|
| 87 |
)
|
| 88 |
-
|
| 89 |
# Log toàn bộ nội dung response ở mức INFO để tiện gỡ lỗi
|
| 90 |
logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
|
| 91 |
|
|
@@ -93,47 +108,81 @@ class GeminiClient:
|
|
| 93 |
# 1. Kiểm tra response có hợp lệ không
|
| 94 |
if not response.candidates:
|
| 95 |
# Lỗi này nên được coi là lỗi tạm thời, thử lại với key/model khác
|
| 96 |
-
raise ValueError(
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
candidate = response.candidates[0]
|
| 99 |
-
finish_reason_name = getattr(
|
|
|
|
|
|
|
| 100 |
# Kiểm tra xem có nội dung thực sự không
|
| 101 |
# Sửa: Dùng getattr để tránh AttributeError nếu 'parts' không tồn tại
|
| 102 |
-
has_content = bool(
|
|
|
|
|
|
|
| 103 |
|
| 104 |
# 2. Phân loại lỗi và xử lý
|
| 105 |
# Case 1: Lỗi nội dung không thể thử lại (SAFETY, MAX_TOKENS, etc.)
|
| 106 |
if finish_reason_name != "STOP":
|
| 107 |
-
usage_metadata =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
error_message = f"Gemini response finished with non-OK reason: {finish_reason_name}."
|
| 109 |
raise GeminiResponseError(
|
| 110 |
-
error_message,
|
|
|
|
|
|
|
| 111 |
)
|
| 112 |
|
| 113 |
# Case 2: Lỗi có thể thử lại (STOP nhưng không có nội dung)
|
| 114 |
-
if
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
# Case 3: Thành công (STOP và có nội dung)
|
| 122 |
self.limit_manager.log_request(key, model, success=True)
|
| 123 |
-
if hasattr(response,
|
| 124 |
-
logger.info(
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
try:
|
| 127 |
-
logger.info(
|
|
|
|
|
|
|
| 128 |
return response.text
|
| 129 |
except ValueError as ve:
|
| 130 |
# Safety net: Nếu truy cập .text thất bại dù các kiểm tra trước đó đã qua,
|
| 131 |
# coi như đây là lỗi STOP_NO_CONTENT và ném ra để tầng trên xử lý.
|
| 132 |
-
usage_metadata =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
raise GeminiResponseError(
|
| 134 |
f"Gemini response has no valid content part. Original error: {ve}",
|
| 135 |
-
finish_reason=
|
| 136 |
-
usage_metadata=usage_metadata
|
| 137 |
) from ve
|
| 138 |
# --- END: Cải tiến logic xử lý response ---
|
| 139 |
except GeminiResponseError as e:
|
|
@@ -142,28 +191,41 @@ class GeminiClient:
|
|
| 142 |
raise e
|
| 143 |
except Exception as e:
|
| 144 |
import re
|
|
|
|
| 145 |
msg = str(e)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
last_error = e
|
| 160 |
-
continue
|
| 161 |
else:
|
| 162 |
-
#
|
| 163 |
-
# sẽ được propagate lên để lớp llm.py xử lý retry với backoff.
|
| 164 |
-
logger.error(
|
|
|
|
|
|
|
| 165 |
raise e
|
| 166 |
-
|
| 167 |
raise last_error or RuntimeError("No available Gemini API key/model")
|
| 168 |
|
| 169 |
def count_tokens(self, prompt: str) -> int:
|
|
@@ -175,51 +237,60 @@ class GeminiClient:
|
|
| 175 |
logger.error(f"[GEMINI] Error counting tokens: {e}")
|
| 176 |
return 0
|
| 177 |
|
| 178 |
-
def create_embedding(
|
|
|
|
|
|
|
| 179 |
last_error = None
|
| 180 |
max_retries = 3
|
| 181 |
-
|
| 182 |
for attempt in range(max_retries):
|
| 183 |
try:
|
| 184 |
key, default_model = self.limit_manager.get_current_key_model()
|
| 185 |
-
|
| 186 |
# Ưu tiên model được truyền vào parameter, chỉ fallback về default_model nếu không có
|
| 187 |
use_model = model if model and model.strip() else default_model
|
| 188 |
-
|
| 189 |
if not use_model:
|
| 190 |
raise ValueError("No model specified for embedding")
|
| 191 |
-
|
| 192 |
-
logger.info(
|
| 193 |
-
|
|
|
|
|
|
|
| 194 |
configure(api_key=key)
|
| 195 |
response = embed_content(
|
| 196 |
-
model=use_model,
|
| 197 |
-
content=text,
|
| 198 |
-
task_type=task_type
|
| 199 |
)
|
| 200 |
-
|
| 201 |
self.limit_manager.log_request(key, use_model, success=True)
|
| 202 |
-
logger.info(
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
| 205 |
except Exception as e:
|
| 206 |
import re
|
|
|
|
| 207 |
msg = str(e)
|
| 208 |
if "429" in msg or "rate limit" in msg.lower():
|
| 209 |
retry_delay = 60
|
| 210 |
-
m_retry = re.search(r
|
| 211 |
if m_retry:
|
| 212 |
retry_delay = int(m_retry.group(1))
|
| 213 |
-
|
| 214 |
# Log failure và trigger scan cho key/model mới
|
| 215 |
-
self.limit_manager.log_request(
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
last_error = e
|
| 219 |
continue
|
| 220 |
else:
|
| 221 |
logger.error(f"[GEMINI] Error creating embedding: {e}")
|
| 222 |
last_error = e
|
| 223 |
break
|
| 224 |
-
|
| 225 |
-
raise last_error or RuntimeError("No available Gemini API key/model")
|
|
|
|
| 6 |
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold
|
| 7 |
|
| 8 |
from .request_limit_manager import RequestLimitManager
|
| 9 |
+
from .utils import _safe_truncate
|
|
|
|
|
|
|
| 10 |
from .config import get_settings
|
| 11 |
|
| 12 |
+
|
| 13 |
class GeminiResponseError(Exception):
|
| 14 |
"""Custom exception for non-retriable Gemini response issues like safety or token limits."""
|
| 15 |
+
|
| 16 |
def __init__(self, message, finish_reason=None, usage_metadata=None):
|
| 17 |
super().__init__(message)
|
| 18 |
self.finish_reason = finish_reason
|
| 19 |
self.usage_metadata = usage_metadata
|
| 20 |
|
| 21 |
def __str__(self):
|
| 22 |
+
usage_str = (
|
| 23 |
+
f"Prompt: {self.usage_metadata.prompt_token_count}, Candidates: {self.usage_metadata.candidates_token_count}, Total: {self.usage_metadata.total_token_count}"
|
| 24 |
+
if self.usage_metadata
|
| 25 |
+
else "N/A"
|
| 26 |
+
)
|
| 27 |
return f"{super().__str__()} (Finish Reason: {self.finish_reason}, Usage: {usage_str})"
|
| 28 |
|
| 29 |
+
|
| 30 |
class GeminiClient:
|
| 31 |
def __init__(self):
|
| 32 |
self.limit_manager = RequestLimitManager("gemini")
|
| 33 |
settings = get_settings()
|
| 34 |
+
num_keys = (
|
| 35 |
+
len(settings.gemini_api_keys.split(",")) if settings.gemini_api_keys else 0
|
| 36 |
+
)
|
| 37 |
+
num_models = (
|
| 38 |
+
len(settings.gemini_models.split(",")) if settings.gemini_models else 0
|
| 39 |
+
)
|
| 40 |
+
logger.info(
|
| 41 |
+
f"[GEMINI_INIT] Limiter is considering {num_keys} API keys and {num_models} models."
|
| 42 |
+
)
|
| 43 |
self._cached_model = None
|
| 44 |
self._cached_key = None
|
| 45 |
self._cached_model_instance = None
|
|
|
|
| 55 |
"""
|
| 56 |
Cache model instance để tránh recreate mỗi lần.
|
| 57 |
"""
|
| 58 |
+
if (
|
| 59 |
+
self._cached_key == key
|
| 60 |
+
and self._cached_model == model
|
| 61 |
+
and self._cached_model_instance is not None
|
| 62 |
+
):
|
| 63 |
return self._cached_model_instance
|
| 64 |
+
|
| 65 |
# Configure và tạo model instance mới
|
| 66 |
configure(api_key=key)
|
| 67 |
self._cached_model_instance = GenerativeModel(model)
|
| 68 |
self._cached_key = key
|
| 69 |
self._cached_model = model
|
| 70 |
+
|
| 71 |
+
logger.info(
|
| 72 |
+
f"[GEMINI] Created new model instance for key={key[:5]}...{key[-5:]} model={model}"
|
| 73 |
+
)
|
| 74 |
return self._cached_model_instance
|
| 75 |
|
| 76 |
def _clear_cache_if_needed(self, new_key: str, new_model: str):
|
| 77 |
"""
|
| 78 |
Chỉ clear cache khi key/model thực sự thay đổi.
|
| 79 |
"""
|
| 80 |
+
if self._cached_key != new_key or self._cached_model != new_model:
|
| 81 |
+
logger.info(
|
| 82 |
+
f"[GEMINI] Clearing cache due to key/model change: {self._cached_key}->{new_key}, {self._cached_model}->{new_model}"
|
| 83 |
+
)
|
| 84 |
self._cached_model_instance = None
|
| 85 |
self._cached_key = None
|
| 86 |
self._cached_model = None
|
|
|
|
| 88 |
def generate_text(self, prompt: str, **kwargs) -> str:
|
| 89 |
last_error = None
|
| 90 |
max_retries = 3
|
| 91 |
+
|
| 92 |
for attempt in range(max_retries):
|
| 93 |
try:
|
| 94 |
# Lấy current key/model từ manager
|
| 95 |
key, model = self.limit_manager.get_current_key_model()
|
| 96 |
+
|
| 97 |
# Sử dụng cached model instance
|
| 98 |
_model = self._get_model_instance(key, model)
|
| 99 |
+
|
| 100 |
response = _model.generate_content(
|
| 101 |
+
prompt, safety_settings=self.safety_settings, **kwargs
|
|
|
|
|
|
|
| 102 |
)
|
| 103 |
+
|
| 104 |
# Log toàn bộ nội dung response ở mức INFO để tiện gỡ lỗi
|
| 105 |
logger.info(f"[GEMINI][RAW_RESPONSE] {response}")
|
| 106 |
|
|
|
|
| 108 |
# 1. Kiểm tra response có hợp lệ không
|
| 109 |
if not response.candidates:
|
| 110 |
# Lỗi này nên được coi là lỗi tạm thời, thử lại với key/model khác
|
| 111 |
+
raise ValueError(
|
| 112 |
+
"Gemini response is missing 'candidates' field. Retrying..."
|
| 113 |
+
)
|
| 114 |
|
| 115 |
+
candidate = response.candidates[0]
|
| 116 |
+
finish_reason_name = getattr(
|
| 117 |
+
getattr(candidate, "finish_reason", None), "name", "UNKNOWN"
|
| 118 |
+
)
|
| 119 |
# Kiểm tra xem có nội dung thực sự không
|
| 120 |
# Sửa: Dùng getattr để tránh AttributeError nếu 'parts' không tồn tại
|
| 121 |
+
has_content = bool(
|
| 122 |
+
candidate.content and getattr(candidate.content, "parts", None)
|
| 123 |
+
)
|
| 124 |
|
| 125 |
# 2. Phân loại lỗi và xử lý
|
| 126 |
# Case 1: Lỗi nội dung không thể thử lại (SAFETY, MAX_TOKENS, etc.)
|
| 127 |
if finish_reason_name != "STOP":
|
| 128 |
+
usage_metadata = (
|
| 129 |
+
response.usage_metadata
|
| 130 |
+
if hasattr(response, "usage_metadata")
|
| 131 |
+
else None
|
| 132 |
+
)
|
| 133 |
error_message = f"Gemini response finished with non-OK reason: {finish_reason_name}."
|
| 134 |
raise GeminiResponseError(
|
| 135 |
+
error_message,
|
| 136 |
+
finish_reason=finish_reason_name,
|
| 137 |
+
usage_metadata=usage_metadata,
|
| 138 |
)
|
| 139 |
|
| 140 |
# Case 2: Lỗi có thể thử lại (STOP nhưng không có nội dung)
|
| 141 |
+
if (
|
| 142 |
+
not has_content
|
| 143 |
+
): # Tại đây, ta biết chắc chắn finish_reason_name là "STOP"
|
| 144 |
+
usage_metadata = (
|
| 145 |
+
response.usage_metadata
|
| 146 |
+
if hasattr(response, "usage_metadata")
|
| 147 |
+
else None
|
| 148 |
+
)
|
| 149 |
+
last_error = GeminiResponseError(
|
| 150 |
+
"Gemini response finished with STOP but has no content parts.",
|
| 151 |
+
finish_reason="STOP_NO_CONTENT",
|
| 152 |
+
usage_metadata=usage_metadata,
|
| 153 |
+
)
|
| 154 |
+
logger.warning(
|
| 155 |
+
f"[GEMINI] Model returned STOP with no content. Retrying with another key/model... (Attempt {attempt + 1}/{max_retries})"
|
| 156 |
+
)
|
| 157 |
+
self.limit_manager.log_request(
|
| 158 |
+
key, model, success=False, retry_delay=5
|
| 159 |
+
)
|
| 160 |
+
continue # Thử lại vòng lặp với key/model mới
|
| 161 |
|
| 162 |
# Case 3: Thành công (STOP và có nội dung)
|
| 163 |
self.limit_manager.log_request(key, model, success=True)
|
| 164 |
+
if hasattr(response, "usage_metadata"):
|
| 165 |
+
logger.info(
|
| 166 |
+
f"[GEMINI][USAGE] Prompt Token Count: {response.usage_metadata.prompt_token_count} - Candidate Token Count: {response.usage_metadata.candidates_token_count} - Total Token Count: {response.usage_metadata.total_token_count}"
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
try:
|
| 170 |
+
logger.info(
|
| 171 |
+
f"[GEMINI][TEXT_RESPONSE] {_safe_truncate(response.text)}"
|
| 172 |
+
)
|
| 173 |
return response.text
|
| 174 |
except ValueError as ve:
|
| 175 |
# Safety net: Nếu truy cập .text thất bại dù các kiểm tra trước đó đã qua,
|
| 176 |
# coi như đây là lỗi STOP_NO_CONTENT và ném ra để tầng trên xử lý.
|
| 177 |
+
usage_metadata = (
|
| 178 |
+
response.usage_metadata
|
| 179 |
+
if hasattr(response, "usage_metadata")
|
| 180 |
+
else None
|
| 181 |
+
)
|
| 182 |
raise GeminiResponseError(
|
| 183 |
f"Gemini response has no valid content part. Original error: {ve}",
|
| 184 |
+
finish_reason="STOP_NO_CONTENT",
|
| 185 |
+
usage_metadata=usage_metadata,
|
| 186 |
) from ve
|
| 187 |
# --- END: Cải tiến logic xử lý response ---
|
| 188 |
except GeminiResponseError as e:
|
|
|
|
| 191 |
raise e
|
| 192 |
except Exception as e:
|
| 193 |
import re
|
| 194 |
+
|
| 195 |
msg = str(e)
|
| 196 |
+
# Kiểm tra lỗi rate limit hoặc lỗi server (5xx)
|
| 197 |
+
is_rate_limit = "429" in msg or "rate limit" in msg.lower()
|
| 198 |
+
is_server_error = any(
|
| 199 |
+
code in msg for code in ["500", "502", "503", "504"]
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
if is_rate_limit or is_server_error:
|
| 203 |
+
retry_delay = 60 # Mặc định cho lỗi server
|
| 204 |
+
if is_rate_limit:
|
| 205 |
+
m = re.search(r"retry_delay.*?seconds: (\d+)", msg)
|
| 206 |
+
if m:
|
| 207 |
+
retry_delay = int(m.group(1))
|
| 208 |
+
|
| 209 |
+
# Log lỗi và chặn cặp key/model hiện tại trong một khoảng thời gian
|
| 210 |
+
self.limit_manager.log_request(
|
| 211 |
+
key, model, success=False, retry_delay=retry_delay
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
error_type = "Rate limit" if is_rate_limit else "Server"
|
| 215 |
+
logger.warning(
|
| 216 |
+
f"[GEMINI] {error_type} error hit, will retry with new key/model "
|
| 217 |
+
f"(attempt {attempt + 1}/{max_retries}). Error: {e}"
|
| 218 |
+
)
|
| 219 |
last_error = e
|
| 220 |
+
continue # Tiếp tục vòng lặp để thử key/model mới
|
| 221 |
else:
|
| 222 |
+
# Các lỗi khác không phải rate limit hoặc server error (vd: network timeout, invalid argument)
|
| 223 |
+
# sẽ được propagate lên để lớp llm.py/reranker.py xử lý retry với backoff.
|
| 224 |
+
logger.error(
|
| 225 |
+
f"[GEMINI] Unhandled error generating text, propagating up: {e}"
|
| 226 |
+
)
|
| 227 |
raise e
|
| 228 |
+
|
| 229 |
raise last_error or RuntimeError("No available Gemini API key/model")
|
| 230 |
|
| 231 |
def count_tokens(self, prompt: str) -> int:
|
|
|
|
| 237 |
logger.error(f"[GEMINI] Error counting tokens: {e}")
|
| 238 |
return 0
|
| 239 |
|
| 240 |
+
def create_embedding(
|
| 241 |
+
self, text: str, model: Optional[str] = None, task_type: str = "retrieval_query"
|
| 242 |
+
) -> list:
|
| 243 |
last_error = None
|
| 244 |
max_retries = 3
|
| 245 |
+
|
| 246 |
for attempt in range(max_retries):
|
| 247 |
try:
|
| 248 |
key, default_model = self.limit_manager.get_current_key_model()
|
| 249 |
+
|
| 250 |
# Ưu tiên model được truyền vào parameter, chỉ fallback về default_model nếu không có
|
| 251 |
use_model = model if model and model.strip() else default_model
|
| 252 |
+
|
| 253 |
if not use_model:
|
| 254 |
raise ValueError("No model specified for embedding")
|
| 255 |
+
|
| 256 |
+
logger.info(
|
| 257 |
+
f"[GEMINI][EMBEDDING] Using model={use_model} (requested={model}, default={default_model}), task_type={task_type}"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
configure(api_key=key)
|
| 261 |
response = embed_content(
|
| 262 |
+
model=use_model, content=text, task_type=task_type
|
|
|
|
|
|
|
| 263 |
)
|
| 264 |
+
|
| 265 |
self.limit_manager.log_request(key, use_model, success=True)
|
| 266 |
+
logger.info(
|
| 267 |
+
f"[GEMINI][EMBEDDING][RAW_RESPONSE] {response['embedding'][:10]} ..... {response['embedding'][-10:]}"
|
| 268 |
+
)
|
| 269 |
+
return response["embedding"]
|
| 270 |
+
|
| 271 |
except Exception as e:
|
| 272 |
import re
|
| 273 |
+
|
| 274 |
msg = str(e)
|
| 275 |
if "429" in msg or "rate limit" in msg.lower():
|
| 276 |
retry_delay = 60
|
| 277 |
+
m_retry = re.search(r"retry_delay.*?seconds: (\d+)", msg)
|
| 278 |
if m_retry:
|
| 279 |
retry_delay = int(m_retry.group(1))
|
| 280 |
+
|
| 281 |
# Log failure và trigger scan cho key/model mới
|
| 282 |
+
self.limit_manager.log_request(
|
| 283 |
+
key, use_model, success=False, retry_delay=retry_delay
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
logger.warning(
|
| 287 |
+
f"[GEMINI] Rate limit hit in embedding, will retry with new key/model (attempt {attempt + 1}/{max_retries})"
|
| 288 |
+
)
|
| 289 |
last_error = e
|
| 290 |
continue
|
| 291 |
else:
|
| 292 |
logger.error(f"[GEMINI] Error creating embedding: {e}")
|
| 293 |
last_error = e
|
| 294 |
break
|
| 295 |
+
|
| 296 |
+
raise last_error or RuntimeError("No available Gemini API key/model")
|
app/reranker.py
CHANGED
|
@@ -2,27 +2,47 @@ from typing import List, Dict
|
|
| 2 |
|
| 3 |
from app.utils import timing_decorator_async
|
| 4 |
from .config import get_settings
|
| 5 |
-
from .gemini_client import GeminiClient
|
| 6 |
from loguru import logger
|
|
|
|
| 7 |
import asyncio
|
| 8 |
import hashlib
|
| 9 |
import time
|
|
|
|
|
|
|
| 10 |
# from .constants import BATCH_STATUS_MESSAGES
|
| 11 |
# from .utils import get_random_message
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
class Reranker:
|
| 14 |
def __init__(self):
|
| 15 |
settings = get_settings()
|
| 16 |
-
self.provider = getattr(settings,
|
| 17 |
-
self.model = getattr(settings,
|
| 18 |
-
if self.provider ==
|
| 19 |
self.client = GeminiClient()
|
| 20 |
# elif self.provider == 'openai':
|
| 21 |
# self.client = OpenAIClient(settings.openai_api_key, model=self.model)
|
| 22 |
# elif self.provider == 'cohere':
|
| 23 |
# self.client = CohereClient(settings.cohere_api_key, model=self.model)
|
| 24 |
else:
|
| 25 |
-
raise NotImplementedError(
|
|
|
|
|
|
|
| 26 |
# Cải thiện cache với TTL và quản lý memory
|
| 27 |
self._rerank_cache = {}
|
| 28 |
self._cache_ttl = 3600 # 1 giờ
|
|
@@ -35,65 +55,77 @@ class Reranker:
|
|
| 35 |
"""Tạo cache key từ query và docs."""
|
| 36 |
# Tối ưu hóa cache key generation
|
| 37 |
query_normalized = query.lower().strip()
|
| 38 |
-
doc_ids = [str(doc.get(
|
| 39 |
cache_content = query_normalized + "|".join(sorted(doc_ids))
|
| 40 |
return hashlib.md5(cache_content.encode()).hexdigest()
|
| 41 |
|
| 42 |
def _clean_cache(self):
|
| 43 |
"""Dọn dẹp cache cũ và quản lý memory."""
|
| 44 |
current_time = time.time()
|
| 45 |
-
|
| 46 |
# Xóa cache entries đã hết hạn
|
| 47 |
expired_keys = [
|
| 48 |
-
key
|
|
|
|
| 49 |
if current_time - timestamp > self._cache_ttl
|
| 50 |
]
|
| 51 |
-
|
| 52 |
for key in expired_keys:
|
| 53 |
del self._rerank_cache[key]
|
| 54 |
del self._cache_timestamps[key]
|
| 55 |
-
|
| 56 |
# Nếu cache vẫn quá lớn, xóa entries cũ nhất
|
| 57 |
if len(self._rerank_cache) > self._max_cache_size:
|
| 58 |
sorted_keys = sorted(
|
| 59 |
-
self._cache_timestamps.keys(),
|
| 60 |
-
key=lambda k: self._cache_timestamps[k]
|
| 61 |
)
|
| 62 |
-
|
| 63 |
# Xóa 20% cache entries cũ nhất
|
| 64 |
-
keys_to_remove = sorted_keys[:len(sorted_keys) // 5]
|
| 65 |
for key in keys_to_remove:
|
| 66 |
del self._rerank_cache[key]
|
| 67 |
del self._cache_timestamps[key]
|
| 68 |
-
|
| 69 |
-
logger.info(
|
|
|
|
|
|
|
| 70 |
|
| 71 |
def _get_cached_result(self, cache_key: str, min_score: float) -> List[Dict]:
|
| 72 |
"""Lấy kết quả từ cache nếu có và còn hợp lệ."""
|
| 73 |
if cache_key in self._rerank_cache:
|
| 74 |
current_time = time.time()
|
| 75 |
-
if
|
|
|
|
|
|
|
|
|
|
| 76 |
# Lọc theo điểm thay vì lấy top_k
|
| 77 |
cached_docs = self._rerank_cache[cache_key]
|
| 78 |
-
cached_result = [
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
return cached_result
|
| 81 |
else:
|
| 82 |
# Cache đã hết hạn, xóa
|
| 83 |
del self._rerank_cache[cache_key]
|
| 84 |
del self._cache_timestamps[cache_key]
|
| 85 |
-
|
| 86 |
return []
|
| 87 |
|
| 88 |
def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
|
| 89 |
"""Lưu kết quả vào cache."""
|
| 90 |
self._rerank_cache[cache_key] = scored_docs
|
| 91 |
self._cache_timestamps[cache_key] = time.time()
|
| 92 |
-
|
| 93 |
# Dọn dẹp cache nếu cần
|
| 94 |
if len(self._rerank_cache) > self._max_cache_size:
|
| 95 |
self._clean_cache()
|
| 96 |
|
|
|
|
| 97 |
async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
|
| 98 |
"""
|
| 99 |
Score nhiều documents cùng lúc bằng một prompt duy nhất.
|
|
@@ -101,16 +133,16 @@ class Reranker:
|
|
| 101 |
"""
|
| 102 |
if not docs:
|
| 103 |
return []
|
| 104 |
-
|
| 105 |
# Không giới hạn content length, giữ nguyên nội dung luật
|
| 106 |
docs_content = []
|
| 107 |
for i, doc in enumerate(docs):
|
| 108 |
# tieude = (doc.get('tieude') or '').strip()
|
| 109 |
# noidung = (doc.get('noidung') or '').strip()
|
| 110 |
# content = f"{tieude} {noidung}".strip()
|
| 111 |
-
content = (doc.get(
|
| 112 |
docs_content.append(f"{i+1}. {content}")
|
| 113 |
-
|
| 114 |
batch_prompt = (
|
| 115 |
f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
|
| 116 |
f"Câu hỏi: {query}\n\n"
|
|
@@ -118,63 +150,70 @@ class Reranker:
|
|
| 118 |
f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
|
| 119 |
f"Ví dụ: 8,5,7,3,9"
|
| 120 |
)
|
| 121 |
-
|
| 122 |
try:
|
| 123 |
-
if self.provider ==
|
| 124 |
loop = asyncio.get_event_loop()
|
| 125 |
-
logger.info(
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
|
| 128 |
-
|
| 129 |
-
# Cải thiện parsing scores
|
| 130 |
scores_text = str(response).strip()
|
|
|
|
|
|
|
|
|
|
| 131 |
scores = []
|
| 132 |
-
|
| 133 |
-
# Xử lý nhiều format response có thể có
|
| 134 |
-
if ',' in scores_text:
|
| 135 |
-
score_parts = scores_text.split(',')
|
| 136 |
-
elif ' ' in scores_text:
|
| 137 |
-
score_parts = scores_text.split()
|
| 138 |
-
else:
|
| 139 |
-
score_parts = scores_text.replace('.', ',').split(',')
|
| 140 |
-
|
| 141 |
-
for score_str in score_parts:
|
| 142 |
try:
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
score = max(0, min(10, score))
|
| 147 |
scores.append(score)
|
| 148 |
-
else:
|
| 149 |
-
scores.append(0)
|
| 150 |
except (ValueError, TypeError):
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
for i, doc in enumerate(docs):
|
| 157 |
-
doc[
|
| 158 |
-
|
| 159 |
-
logger.info(
|
|
|
|
|
|
|
| 160 |
return docs
|
| 161 |
-
|
| 162 |
else:
|
| 163 |
-
raise NotImplementedError(
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
for doc in docs:
|
| 168 |
-
doc[
|
| 169 |
return docs
|
| 170 |
|
|
|
|
| 171 |
async def _score_doc(self, query: str, doc: Dict) -> Dict:
|
| 172 |
"""
|
| 173 |
Score một document với query.
|
| 174 |
Không cắt bớt nội dung luật.
|
| 175 |
"""
|
| 176 |
-
tieude = (doc.get(
|
| 177 |
-
noidung = (doc.get(
|
| 178 |
content = f"{tieude} {noidung}".strip()
|
| 179 |
prompt = (
|
| 180 |
f"Đánh giá mức độ liên quan:\n"
|
|
@@ -183,14 +222,20 @@ class Reranker:
|
|
| 183 |
f"Điểm (0-10):"
|
| 184 |
)
|
| 185 |
try:
|
| 186 |
-
if self.provider ==
|
| 187 |
loop = asyncio.get_event_loop()
|
| 188 |
logger.info(f"[RERANK] Sending individual prompt to Gemini")
|
| 189 |
-
score_response = await loop.run_in_executor(
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
score_text = str(score_response).strip()
|
| 192 |
try:
|
| 193 |
-
clean_score =
|
|
|
|
|
|
|
| 194 |
if clean_score:
|
| 195 |
score = float(clean_score)
|
| 196 |
score = max(0, min(10, score))
|
|
@@ -198,44 +243,59 @@ class Reranker:
|
|
| 198 |
score = 0
|
| 199 |
except (ValueError, TypeError):
|
| 200 |
score = 0
|
| 201 |
-
doc[
|
| 202 |
return doc
|
| 203 |
else:
|
| 204 |
-
raise NotImplementedError(
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
return doc
|
| 209 |
|
| 210 |
@timing_decorator_async
|
| 211 |
-
async def rerank(
|
|
|
|
|
|
|
| 212 |
"""
|
| 213 |
Rerank docs theo độ liên quan với query, trả về các docs có điểm >= min_score.
|
| 214 |
Sử dụng batch processing và caching để tối ưu hiệu suất.
|
| 215 |
"""
|
| 216 |
-
logger.info(
|
| 217 |
-
|
|
|
|
|
|
|
| 218 |
if not docs:
|
| 219 |
return []
|
| 220 |
-
|
| 221 |
# Kiểm tra cache trước
|
| 222 |
cache_key = self._get_cache_key(query, docs)
|
| 223 |
cached_result = self._get_cached_result(cache_key, min_score)
|
| 224 |
-
|
| 225 |
if cached_result:
|
| 226 |
return cached_result
|
| 227 |
-
|
| 228 |
# Giới hạn số lượng docs để rerank - chỉ rerank top N docs có similarity cao nhất
|
| 229 |
max_docs_to_rerank = self.max_docs_to_rerank
|
| 230 |
docs_to_rerank = docs[:max_docs_to_rerank]
|
| 231 |
-
logger.info(
|
| 232 |
-
|
|
|
|
|
|
|
| 233 |
# Sử dụng batch processing thay vì individual scoring
|
| 234 |
try:
|
| 235 |
scored = await self._batch_score_docs(query, docs_to_rerank)
|
| 236 |
-
logger.info(
|
|
|
|
|
|
|
| 237 |
except Exception as e:
|
| 238 |
-
logger.error(
|
|
|
|
|
|
|
| 239 |
# Fallback về individual scoring nếu batch processing thất bại
|
| 240 |
scored = []
|
| 241 |
for doc in docs_to_rerank:
|
|
@@ -244,17 +304,19 @@ class Reranker:
|
|
| 244 |
scored.append(scored_doc)
|
| 245 |
except Exception as e:
|
| 246 |
logger.error(f"[RERANK] Error scoring individual doc: {e}")
|
| 247 |
-
doc[
|
| 248 |
scored.append(doc)
|
| 249 |
-
|
| 250 |
# Sort theo score
|
| 251 |
-
scored = sorted(scored, key=lambda x: x.get(
|
| 252 |
-
|
| 253 |
# Lọc theo min_score
|
| 254 |
-
result = [doc for doc in scored if doc.get(
|
| 255 |
-
|
| 256 |
# Cache kết quả đã được chấm điểm (toàn bộ, trước khi lọc)
|
| 257 |
self._set_cached_result(cache_key, scored)
|
| 258 |
-
|
| 259 |
-
logger.info(
|
| 260 |
-
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from app.utils import timing_decorator_async
|
| 4 |
from .config import get_settings
|
| 5 |
+
from .gemini_client import GeminiClient, GeminiResponseError
|
| 6 |
from loguru import logger
|
| 7 |
+
import re
|
| 8 |
import asyncio
|
| 9 |
import hashlib
|
| 10 |
import time
|
| 11 |
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
| 12 |
+
|
| 13 |
# from .constants import BATCH_STATUS_MESSAGES
|
| 14 |
# from .utils import get_random_message
|
| 15 |
|
| 16 |
+
# --- Retry decorator cho các lỗi tạm thời của Reranker (network, server-side) ---
|
| 17 |
+
retry_on_rerank_transient_error = retry(
|
| 18 |
+
stop=stop_after_attempt(4), # 1 lần gọi gốc + 3 lần thử lại
|
| 19 |
+
wait=wait_exponential(multiplier=5, min=10, max=60), # Chờ 10s, 20s, 40s
|
| 20 |
+
retry=lambda retry_state: (
|
| 21 |
+
retry_state.outcome.failed
|
| 22 |
+
and not isinstance(retry_state.outcome.exception(), GeminiResponseError)
|
| 23 |
+
),
|
| 24 |
+
before_sleep=lambda retry_state: logger.warning(
|
| 25 |
+
f"[RERANK][RETRY] Rerank call failed with transient error, retrying... "
|
| 26 |
+
f"Attempt: {retry_state.attempt_number}, Error: {retry_state.outcome.exception()}"
|
| 27 |
+
),
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
class Reranker:
|
| 32 |
def __init__(self):
|
| 33 |
settings = get_settings()
|
| 34 |
+
self.provider = getattr(settings, "rerank_provider", settings.llm_provider)
|
| 35 |
+
self.model = getattr(settings, "rerank_model", settings.llm_model)
|
| 36 |
+
if self.provider == "gemini":
|
| 37 |
self.client = GeminiClient()
|
| 38 |
# elif self.provider == 'openai':
|
| 39 |
# self.client = OpenAIClient(settings.openai_api_key, model=self.model)
|
| 40 |
# elif self.provider == 'cohere':
|
| 41 |
# self.client = CohereClient(settings.cohere_api_key, model=self.model)
|
| 42 |
else:
|
| 43 |
+
raise NotImplementedError(
|
| 44 |
+
f"Rerank provider {self.provider} not supported yet."
|
| 45 |
+
)
|
| 46 |
# Cải thiện cache với TTL và quản lý memory
|
| 47 |
self._rerank_cache = {}
|
| 48 |
self._cache_ttl = 3600 # 1 giờ
|
|
|
|
| 55 |
"""Tạo cache key từ query và docs."""
|
| 56 |
# Tối ưu hóa cache key generation
|
| 57 |
query_normalized = query.lower().strip()
|
| 58 |
+
doc_ids = [str(doc.get("id", "")) for doc in docs[:15]] # Chỉ cache top 15 docs
|
| 59 |
cache_content = query_normalized + "|".join(sorted(doc_ids))
|
| 60 |
return hashlib.md5(cache_content.encode()).hexdigest()
|
| 61 |
|
| 62 |
def _clean_cache(self):
|
| 63 |
"""Dọn dẹp cache cũ và quản lý memory."""
|
| 64 |
current_time = time.time()
|
| 65 |
+
|
| 66 |
# Xóa cache entries đã hết hạn
|
| 67 |
expired_keys = [
|
| 68 |
+
key
|
| 69 |
+
for key, timestamp in self._cache_timestamps.items()
|
| 70 |
if current_time - timestamp > self._cache_ttl
|
| 71 |
]
|
| 72 |
+
|
| 73 |
for key in expired_keys:
|
| 74 |
del self._rerank_cache[key]
|
| 75 |
del self._cache_timestamps[key]
|
| 76 |
+
|
| 77 |
# Nếu cache vẫn quá lớn, xóa entries cũ nhất
|
| 78 |
if len(self._rerank_cache) > self._max_cache_size:
|
| 79 |
sorted_keys = sorted(
|
| 80 |
+
self._cache_timestamps.keys(), key=lambda k: self._cache_timestamps[k]
|
|
|
|
| 81 |
)
|
| 82 |
+
|
| 83 |
# Xóa 20% cache entries cũ nhất
|
| 84 |
+
keys_to_remove = sorted_keys[: len(sorted_keys) // 5]
|
| 85 |
for key in keys_to_remove:
|
| 86 |
del self._rerank_cache[key]
|
| 87 |
del self._cache_timestamps[key]
|
| 88 |
+
|
| 89 |
+
logger.info(
|
| 90 |
+
f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries"
|
| 91 |
+
)
|
| 92 |
|
| 93 |
def _get_cached_result(self, cache_key: str, min_score: float) -> List[Dict]:
|
| 94 |
"""Lấy kết quả từ cache nếu có và còn hợp lệ."""
|
| 95 |
if cache_key in self._rerank_cache:
|
| 96 |
current_time = time.time()
|
| 97 |
+
if (
|
| 98 |
+
current_time - self._cache_timestamps.get(cache_key, 0)
|
| 99 |
+
<= self._cache_ttl
|
| 100 |
+
):
|
| 101 |
# Lọc theo điểm thay vì lấy top_k
|
| 102 |
cached_docs = self._rerank_cache[cache_key]
|
| 103 |
+
cached_result = [
|
| 104 |
+
doc
|
| 105 |
+
for doc in cached_docs
|
| 106 |
+
if doc.get("rerank_score", 0) >= min_score
|
| 107 |
+
]
|
| 108 |
+
logger.info(
|
| 109 |
+
f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results with score >= {min_score}"
|
| 110 |
+
)
|
| 111 |
return cached_result
|
| 112 |
else:
|
| 113 |
# Cache đã hết hạn, xóa
|
| 114 |
del self._rerank_cache[cache_key]
|
| 115 |
del self._cache_timestamps[cache_key]
|
| 116 |
+
|
| 117 |
return []
|
| 118 |
|
| 119 |
def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
|
| 120 |
"""Lưu kết quả vào cache."""
|
| 121 |
self._rerank_cache[cache_key] = scored_docs
|
| 122 |
self._cache_timestamps[cache_key] = time.time()
|
| 123 |
+
|
| 124 |
# Dọn dẹp cache nếu cần
|
| 125 |
if len(self._rerank_cache) > self._max_cache_size:
|
| 126 |
self._clean_cache()
|
| 127 |
|
| 128 |
+
@retry_on_rerank_transient_error
|
| 129 |
async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
|
| 130 |
"""
|
| 131 |
Score nhiều documents cùng lúc bằng một prompt duy nhất.
|
|
|
|
| 133 |
"""
|
| 134 |
if not docs:
|
| 135 |
return []
|
| 136 |
+
|
| 137 |
# Không giới hạn content length, giữ nguyên nội dung luật
|
| 138 |
docs_content = []
|
| 139 |
for i, doc in enumerate(docs):
|
| 140 |
# tieude = (doc.get('tieude') or '').strip()
|
| 141 |
# noidung = (doc.get('noidung') or '').strip()
|
| 142 |
# content = f"{tieude} {noidung}".strip()
|
| 143 |
+
content = (doc.get("fullcontent") or "").strip()
|
| 144 |
docs_content.append(f"{i+1}. {content}")
|
| 145 |
+
|
| 146 |
batch_prompt = (
|
| 147 |
f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
|
| 148 |
f"Câu hỏi: {query}\n\n"
|
|
|
|
| 150 |
f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
|
| 151 |
f"Ví dụ: 8,5,7,3,9"
|
| 152 |
)
|
| 153 |
+
|
| 154 |
try:
|
| 155 |
+
if self.provider == "gemini":
|
| 156 |
loop = asyncio.get_event_loop()
|
| 157 |
+
logger.info(
|
| 158 |
+
f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs"
|
| 159 |
+
)
|
| 160 |
+
response = await loop.run_in_executor(
|
| 161 |
+
None, self.client.generate_text, batch_prompt
|
| 162 |
+
)
|
| 163 |
logger.info(f"[RERANK] Got batch scores from Gemini: {response}")
|
| 164 |
+
|
| 165 |
+
# Cải thiện parsing scores bằng regex để chỉ lấy các số hợp lệ
|
| 166 |
scores_text = str(response).strip()
|
| 167 |
+
# Tìm tất cả các chuỗi số (integer hoặc float) trong văn bản trả về
|
| 168 |
+
score_strings = re.findall(r"\b\d+(?:\.\d+)?\b", scores_text)
|
| 169 |
+
|
| 170 |
scores = []
|
| 171 |
+
for s in score_strings:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
try:
|
| 173 |
+
score = float(s)
|
| 174 |
+
# Chỉ chấp nhận các điểm số trong khoảng 0-10 để tăng độ chính xác
|
| 175 |
+
if 0 <= score <= 10:
|
|
|
|
| 176 |
scores.append(score)
|
|
|
|
|
|
|
| 177 |
except (ValueError, TypeError):
|
| 178 |
+
# Bỏ qua các giá trị không phải là số hợp lệ
|
| 179 |
+
continue
|
| 180 |
+
|
| 181 |
+
# Đảm bảo số lượng điểm khớp với số lượng văn bản
|
| 182 |
+
# Nếu thiếu, thêm điểm 0. Nếu thừa, cắt bớt.
|
| 183 |
+
if len(scores) < len(docs):
|
| 184 |
+
scores.extend([0.0] * (len(docs) - len(scores)))
|
| 185 |
+
else:
|
| 186 |
+
scores = scores[: len(docs)]
|
| 187 |
+
|
| 188 |
for i, doc in enumerate(docs):
|
| 189 |
+
doc["rerank_score"] = scores[i]
|
| 190 |
+
|
| 191 |
+
logger.info(
|
| 192 |
+
f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}"
|
| 193 |
+
)
|
| 194 |
return docs
|
| 195 |
+
|
| 196 |
else:
|
| 197 |
+
raise NotImplementedError(
|
| 198 |
+
f"Rerank provider {self.provider} not supported yet in batch method."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
except GeminiResponseError as e:
|
| 202 |
+
# Lỗi nội dung không thể retry (safety, max_tokens), gán điểm 0 và trả về.
|
| 203 |
+
# Các lỗi khác (network, 500) sẽ được decorator retry.
|
| 204 |
+
logger.error(f"[RERANK] Lỗi nội dung không thể retry khi batch score: {e}")
|
| 205 |
for doc in docs:
|
| 206 |
+
doc["rerank_score"] = 0
|
| 207 |
return docs
|
| 208 |
|
| 209 |
+
@retry_on_rerank_transient_error
|
| 210 |
async def _score_doc(self, query: str, doc: Dict) -> Dict:
|
| 211 |
"""
|
| 212 |
Score một document với query.
|
| 213 |
Không cắt bớt nội dung luật.
|
| 214 |
"""
|
| 215 |
+
tieude = (doc.get("tieude") or "").strip()
|
| 216 |
+
noidung = (doc.get("noidung") or "").strip()
|
| 217 |
content = f"{tieude} {noidung}".strip()
|
| 218 |
prompt = (
|
| 219 |
f"Đánh giá mức độ liên quan:\n"
|
|
|
|
| 222 |
f"Điểm (0-10):"
|
| 223 |
)
|
| 224 |
try:
|
| 225 |
+
if self.provider == "gemini":
|
| 226 |
loop = asyncio.get_event_loop()
|
| 227 |
logger.info(f"[RERANK] Sending individual prompt to Gemini")
|
| 228 |
+
score_response = await loop.run_in_executor(
|
| 229 |
+
None, self.client.generate_text, prompt
|
| 230 |
+
)
|
| 231 |
+
logger.info(
|
| 232 |
+
f"[RERANK] Got individual score from Gemini: {score_response}"
|
| 233 |
+
)
|
| 234 |
score_text = str(score_response).strip()
|
| 235 |
try:
|
| 236 |
+
clean_score = "".join(
|
| 237 |
+
c for c in score_text if c.isdigit() or c == "."
|
| 238 |
+
)
|
| 239 |
if clean_score:
|
| 240 |
score = float(clean_score)
|
| 241 |
score = max(0, min(10, score))
|
|
|
|
| 243 |
score = 0
|
| 244 |
except (ValueError, TypeError):
|
| 245 |
score = 0
|
| 246 |
+
doc["rerank_score"] = score
|
| 247 |
return doc
|
| 248 |
else:
|
| 249 |
+
raise NotImplementedError(
|
| 250 |
+
f"Rerank provider {self.provider} not supported yet in rerank method."
|
| 251 |
+
)
|
| 252 |
+
except GeminiResponseError as e:
|
| 253 |
+
# Lỗi nội dung không thể retry (safety, max_tokens), gán điểm 0 và trả về.
|
| 254 |
+
logger.error(
|
| 255 |
+
f"[RERANK] Lỗi nội dung không thể retry khi tính score: {e} | doc: {doc}"
|
| 256 |
+
)
|
| 257 |
+
doc["rerank_score"] = 0
|
| 258 |
return doc
|
| 259 |
|
| 260 |
@timing_decorator_async
|
| 261 |
+
async def rerank(
|
| 262 |
+
self, query: str, docs: List[Dict], min_score: float = 7.0
|
| 263 |
+
) -> List[Dict]:
|
| 264 |
"""
|
| 265 |
Rerank docs theo độ liên quan với query, trả về các docs có điểm >= min_score.
|
| 266 |
Sử dụng batch processing và caching để tối ưu hiệu suất.
|
| 267 |
"""
|
| 268 |
+
logger.info(
|
| 269 |
+
f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | min_score: {min_score}"
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
if not docs:
|
| 273 |
return []
|
| 274 |
+
|
| 275 |
# Kiểm tra cache trước
|
| 276 |
cache_key = self._get_cache_key(query, docs)
|
| 277 |
cached_result = self._get_cached_result(cache_key, min_score)
|
| 278 |
+
|
| 279 |
if cached_result:
|
| 280 |
return cached_result
|
| 281 |
+
|
| 282 |
# Giới hạn số lượng docs để rerank - chỉ rerank top N docs có similarity cao nhất
|
| 283 |
max_docs_to_rerank = self.max_docs_to_rerank
|
| 284 |
docs_to_rerank = docs[:max_docs_to_rerank]
|
| 285 |
+
logger.info(
|
| 286 |
+
f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})"
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
# Sử dụng batch processing thay vì individual scoring
|
| 290 |
try:
|
| 291 |
scored = await self._batch_score_docs(query, docs_to_rerank)
|
| 292 |
+
logger.info(
|
| 293 |
+
f"[RERANK] Batch processing completed, scored {len(scored)} docs"
|
| 294 |
+
)
|
| 295 |
except Exception as e:
|
| 296 |
+
logger.error(
|
| 297 |
+
f"[RERANK] Batch processing failed, falling back to individual scoring: {e}"
|
| 298 |
+
)
|
| 299 |
# Fallback về individual scoring nếu batch processing thất bại
|
| 300 |
scored = []
|
| 301 |
for doc in docs_to_rerank:
|
|
|
|
| 304 |
scored.append(scored_doc)
|
| 305 |
except Exception as e:
|
| 306 |
logger.error(f"[RERANK] Error scoring individual doc: {e}")
|
| 307 |
+
doc["rerank_score"] = 0
|
| 308 |
scored.append(doc)
|
| 309 |
+
|
| 310 |
# Sort theo score
|
| 311 |
+
scored = sorted(scored, key=lambda x: x.get("rerank_score", 0), reverse=True)
|
| 312 |
+
|
| 313 |
# Lọc theo min_score
|
| 314 |
+
result = [doc for doc in scored if doc.get("rerank_score", 0) >= min_score]
|
| 315 |
+
|
| 316 |
# Cache kết quả đã được chấm điểm (toàn bộ, trước khi lọc)
|
| 317 |
self._set_cached_result(cache_key, scored)
|
| 318 |
+
|
| 319 |
+
logger.info(
|
| 320 |
+
f"[RERANK] Found {len(result)} docs with score >= {min_score}. Top results: {result[:2]}...{result[-2:] if len(result) > 2 else ''}"
|
| 321 |
+
)
|
| 322 |
+
return result
|