|
|
from typing import List, Dict |
|
|
|
|
|
from app.utils import timing_decorator_async |
|
|
from .config import get_settings |
|
|
from .gemini_client import GeminiClient, GeminiResponseError |
|
|
from loguru import logger |
|
|
import re |
|
|
import asyncio |
|
|
import hashlib |
|
|
import time |
|
|
from tenacity import retry, stop_after_attempt, wait_exponential |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retry_on_rerank_transient_error = retry( |
|
|
stop=stop_after_attempt(4), |
|
|
wait=wait_exponential(multiplier=5, min=10, max=60), |
|
|
retry=lambda retry_state: ( |
|
|
retry_state.outcome.failed |
|
|
and not isinstance(retry_state.outcome.exception(), GeminiResponseError) |
|
|
), |
|
|
before_sleep=lambda retry_state: logger.warning( |
|
|
f"[RERANK][RETRY] Rerank call failed with transient error, retrying... " |
|
|
f"Attempt: {retry_state.attempt_number}, Error: {retry_state.outcome.exception()}" |
|
|
), |
|
|
) |
|
|
|
|
|
|
|
|
class Reranker: |
|
|
def __init__(self): |
|
|
settings = get_settings() |
|
|
self.provider = getattr(settings, "rerank_provider", settings.llm_provider) |
|
|
self.model = getattr(settings, "rerank_model", settings.llm_model) |
|
|
if self.provider == "gemini": |
|
|
self.client = GeminiClient() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
raise NotImplementedError( |
|
|
f"Rerank provider {self.provider} not supported yet." |
|
|
) |
|
|
|
|
|
self._rerank_cache = {} |
|
|
self._cache_ttl = 3600 |
|
|
self._max_cache_size = 200 |
|
|
self._cache_timestamps = {} |
|
|
|
|
|
self.max_docs_to_rerank = settings.max_docs_to_rerank |
|
|
|
|
|
def _get_cache_key(self, query: str, docs: List[Dict]) -> str: |
|
|
"""Tạo cache key từ query và docs.""" |
|
|
|
|
|
query_normalized = query.lower().strip() |
|
|
doc_ids = [str(doc.get("id", "")) for doc in docs[:15]] |
|
|
cache_content = query_normalized + "|".join(sorted(doc_ids)) |
|
|
return hashlib.md5(cache_content.encode()).hexdigest() |
|
|
|
|
|
def _clean_cache(self): |
|
|
"""Dọn dẹp cache cũ và quản lý memory.""" |
|
|
current_time = time.time() |
|
|
|
|
|
|
|
|
expired_keys = [ |
|
|
key |
|
|
for key, timestamp in self._cache_timestamps.items() |
|
|
if current_time - timestamp > self._cache_ttl |
|
|
] |
|
|
|
|
|
for key in expired_keys: |
|
|
del self._rerank_cache[key] |
|
|
del self._cache_timestamps[key] |
|
|
|
|
|
|
|
|
if len(self._rerank_cache) > self._max_cache_size: |
|
|
sorted_keys = sorted( |
|
|
self._cache_timestamps.keys(), key=lambda k: self._cache_timestamps[k] |
|
|
) |
|
|
|
|
|
|
|
|
keys_to_remove = sorted_keys[: len(sorted_keys) // 5] |
|
|
for key in keys_to_remove: |
|
|
del self._rerank_cache[key] |
|
|
del self._cache_timestamps[key] |
|
|
|
|
|
logger.info( |
|
|
f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries" |
|
|
) |
|
|
|
|
|
def _get_cached_result(self, cache_key: str, min_score: float) -> List[Dict]: |
|
|
"""Lấy kết quả từ cache nếu có và còn hợp lệ.""" |
|
|
if cache_key in self._rerank_cache: |
|
|
current_time = time.time() |
|
|
if ( |
|
|
current_time - self._cache_timestamps.get(cache_key, 0) |
|
|
<= self._cache_ttl |
|
|
): |
|
|
|
|
|
cached_docs = self._rerank_cache[cache_key] |
|
|
cached_result = [ |
|
|
doc |
|
|
for doc in cached_docs |
|
|
if doc.get("rerank_score", 0) >= min_score |
|
|
] |
|
|
logger.info( |
|
|
f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results with score >= {min_score}" |
|
|
) |
|
|
return cached_result |
|
|
else: |
|
|
|
|
|
del self._rerank_cache[cache_key] |
|
|
del self._cache_timestamps[cache_key] |
|
|
|
|
|
return [] |
|
|
|
|
|
def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]): |
|
|
"""Lưu kết quả vào cache.""" |
|
|
self._rerank_cache[cache_key] = scored_docs |
|
|
self._cache_timestamps[cache_key] = time.time() |
|
|
|
|
|
|
|
|
if len(self._rerank_cache) > self._max_cache_size: |
|
|
self._clean_cache() |
|
|
|
|
|
@retry_on_rerank_transient_error |
|
|
async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]: |
|
|
""" |
|
|
Score nhiều documents cùng lúc bằng một prompt duy nhất. |
|
|
Không cắt bớt nội dung luật. |
|
|
""" |
|
|
if not docs: |
|
|
return [] |
|
|
|
|
|
|
|
|
docs_content = [] |
|
|
for i, doc in enumerate(docs): |
|
|
|
|
|
|
|
|
|
|
|
content = (doc.get("fullcontent") or "").strip() |
|
|
docs_content.append(f"{i+1}. {content}") |
|
|
|
|
|
|
|
|
batch_prompt = ( |
|
|
"Bạn là một hệ thống đánh giá. Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật được đánh số sau đây.\n" |
|
|
"QUY TẮC:\n" |
|
|
"1. Chỉ trả về MỘT DÒNG DUY NHẤT.\n" |
|
|
"2. Dòng đó CHỈ chứa danh sách các điểm số (từ 0 đến 10), mỗi điểm tương ứng với một đoạn luật.\n" |
|
|
"3. Các điểm số phải được phân cách bởi dấu phẩy.\n" |
|
|
"4. KHÔNG giải thích, KHÔNG định dạng markdown, KHÔNG thêm bất kỳ văn bản nào khác.\n\n" |
|
|
f"Câu hỏi: {query}\n\n" |
|
|
"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n" |
|
|
"ĐIỂM SỐ:" |
|
|
) |
|
|
|
|
|
try: |
|
|
if self.provider == "gemini": |
|
|
loop = asyncio.get_event_loop() |
|
|
logger.info( |
|
|
f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs" |
|
|
) |
|
|
response = await loop.run_in_executor( |
|
|
None, self.client.generate_text, batch_prompt |
|
|
) |
|
|
logger.info(f"[RERANK] Got batch scores from Gemini: {response}") |
|
|
|
|
|
|
|
|
scores_text = str(response).strip() |
|
|
scores_line = "" |
|
|
score_strings = [] |
|
|
|
|
|
|
|
|
lines = scores_text.split("\n") |
|
|
|
|
|
|
|
|
|
|
|
for line in reversed(lines): |
|
|
line = line.strip() |
|
|
if line and re.fullmatch(r"[0-9.,\s]+", line): |
|
|
scores_line = line |
|
|
logger.debug( |
|
|
f"[RERANK] Found pure score line (best case): '{scores_line}'" |
|
|
) |
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
if not scores_line: |
|
|
keyword_regex = ( |
|
|
r"(?i)(?:Kết quả|Scores|Trả về|Điểm số)[\s\*:]*([0-9.,\s]+)$" |
|
|
) |
|
|
for line in reversed(lines): |
|
|
line = line.strip() |
|
|
match = re.search(keyword_regex, line) |
|
|
if match: |
|
|
scores_line = match.group(1).strip() |
|
|
logger.debug( |
|
|
f"[RERANK] Found scores line using keyword regex: '{scores_line}'" |
|
|
) |
|
|
break |
|
|
|
|
|
if scores_line: |
|
|
|
|
|
score_strings = re.findall(r"\b\d+(?:\.\d+)?\b", scores_line) |
|
|
else: |
|
|
|
|
|
logger.warning( |
|
|
"[RERANK] Could not find a dedicated score line. Falling back to parsing all numbers from response." |
|
|
) |
|
|
score_strings = re.findall(r"\b\d+(?:\.\d+)?\b", scores_text) |
|
|
|
|
|
|
|
|
scores = [] |
|
|
for s in score_strings: |
|
|
try: |
|
|
score = float(s) |
|
|
|
|
|
if 0 <= score <= 10: |
|
|
scores.append(score) |
|
|
except (ValueError, TypeError): |
|
|
|
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
if len(scores) < len(docs): |
|
|
scores.extend([0.0] * (len(docs) - len(scores))) |
|
|
else: |
|
|
scores = scores[: len(docs)] |
|
|
|
|
|
for i, doc in enumerate(docs): |
|
|
doc["rerank_score"] = scores[i] |
|
|
|
|
|
logger.info( |
|
|
f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}" |
|
|
) |
|
|
return docs |
|
|
|
|
|
else: |
|
|
raise NotImplementedError( |
|
|
f"Rerank provider {self.provider} not supported yet in batch method." |
|
|
) |
|
|
|
|
|
except GeminiResponseError as e: |
|
|
|
|
|
|
|
|
logger.error(f"[RERANK] Lỗi nội dung không thể retry khi batch score: {e}") |
|
|
for doc in docs: |
|
|
doc["rerank_score"] = 0 |
|
|
return docs |
|
|
|
|
|
@retry_on_rerank_transient_error |
|
|
async def _score_doc(self, query: str, doc: Dict) -> Dict: |
|
|
""" |
|
|
Score một document với query. |
|
|
Không cắt bớt nội dung luật. |
|
|
""" |
|
|
tieude = (doc.get("tieude") or "").strip() |
|
|
noidung = (doc.get("noidung") or "").strip() |
|
|
content = f"{tieude} {noidung}".strip() |
|
|
prompt = ( |
|
|
f"Đánh giá mức độ liên quan:\n" |
|
|
f"Luật: {content}\n" |
|
|
f"Hỏi: {query}\n" |
|
|
f"Điểm (0-10):" |
|
|
) |
|
|
try: |
|
|
if self.provider == "gemini": |
|
|
loop = asyncio.get_event_loop() |
|
|
logger.info(f"[RERANK] Sending individual prompt to Gemini") |
|
|
score_response = await loop.run_in_executor( |
|
|
None, self.client.generate_text, prompt |
|
|
) |
|
|
logger.info( |
|
|
f"[RERANK] Got individual score from Gemini: {score_response}" |
|
|
) |
|
|
score_text = str(score_response).strip() |
|
|
try: |
|
|
clean_score = "".join( |
|
|
c for c in score_text if c.isdigit() or c == "." |
|
|
) |
|
|
if clean_score: |
|
|
score = float(clean_score) |
|
|
score = max(0, min(10, score)) |
|
|
else: |
|
|
score = 0 |
|
|
except (ValueError, TypeError): |
|
|
score = 0 |
|
|
doc["rerank_score"] = score |
|
|
return doc |
|
|
else: |
|
|
raise NotImplementedError( |
|
|
f"Rerank provider {self.provider} not supported yet in rerank method." |
|
|
) |
|
|
except GeminiResponseError as e: |
|
|
|
|
|
logger.error( |
|
|
f"[RERANK] Lỗi nội dung không thể retry khi tính score: {e} | doc: {doc}" |
|
|
) |
|
|
doc["rerank_score"] = 0 |
|
|
return doc |
|
|
|
|
|
@timing_decorator_async |
|
|
async def rerank( |
|
|
self, query: str, docs: List[Dict], min_score: float = 7.0 |
|
|
) -> List[Dict]: |
|
|
""" |
|
|
Rerank docs theo độ liên quan với query, trả về các docs có điểm >= min_score. |
|
|
Sử dụng batch processing và caching để tối ưu hiệu suất. |
|
|
""" |
|
|
logger.info( |
|
|
f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | min_score: {min_score}" |
|
|
) |
|
|
|
|
|
if not docs: |
|
|
return [] |
|
|
|
|
|
|
|
|
cache_key = self._get_cache_key(query, docs) |
|
|
cached_result = self._get_cached_result(cache_key, min_score) |
|
|
|
|
|
if cached_result: |
|
|
return cached_result |
|
|
|
|
|
|
|
|
max_docs_to_rerank = self.max_docs_to_rerank |
|
|
docs_to_rerank = docs[:max_docs_to_rerank] |
|
|
logger.info( |
|
|
f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})" |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
scored = await self._batch_score_docs(query, docs_to_rerank) |
|
|
logger.info( |
|
|
f"[RERANK] Batch processing completed, scored {len(scored)} docs" |
|
|
) |
|
|
except Exception as e: |
|
|
logger.error( |
|
|
f"[RERANK] Batch processing failed, falling back to individual scoring: {e}" |
|
|
) |
|
|
|
|
|
scored = [] |
|
|
for doc in docs_to_rerank: |
|
|
try: |
|
|
scored_doc = await self._score_doc(query, doc) |
|
|
scored.append(scored_doc) |
|
|
except Exception as e: |
|
|
logger.error(f"[RERANK] Error scoring individual doc: {e}") |
|
|
doc["rerank_score"] = 0 |
|
|
scored.append(doc) |
|
|
|
|
|
|
|
|
scored = sorted(scored, key=lambda x: x.get("rerank_score", 0), reverse=True) |
|
|
|
|
|
|
|
|
result = [doc for doc in scored if doc.get("rerank_score", 0) >= min_score] |
|
|
|
|
|
|
|
|
self._set_cached_result(cache_key, scored) |
|
|
|
|
|
logger.info( |
|
|
f"[RERANK] Found {len(result)} docs with score >= {min_score}. Top results: {result[:2]}...{result[-2:] if len(result) > 2 else ''}" |
|
|
) |
|
|
return result |
|
|
|