| from typing import List, Dict |
| from .config import get_settings |
| from .gemini_client import GeminiClient |
| from loguru import logger |
| import asyncio |
| import hashlib |
| import time |
| |
| |
|
|
class Reranker:
    """LLM-based reranker for retrieved law documents.

    Scores candidate documents against a query with a single batched LLM
    call (with a per-document fallback) and keeps an in-memory TTL cache
    of scored results keyed by query + document ids.
    """

    def __init__(self):
        """Configure provider/model from settings and initialize the cache.

        Raises:
            NotImplementedError: if the configured rerank provider is not
                supported (only 'gemini' is implemented).
        """
        settings = get_settings()
        # Fall back to the general LLM provider/model when no
        # rerank-specific settings are declared.
        self.provider = getattr(settings, 'rerank_provider', settings.llm_provider)
        self.model = getattr(settings, 'rerank_model', settings.llm_model)
        if self.provider == 'gemini':
            self.client = GeminiClient()
        else:
            raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")

        # cache_key -> list of scored docs; timestamps drive TTL expiry
        # and oldest-first eviction in _clean_cache().
        self._rerank_cache = {}
        self._cache_ttl = 3600        # seconds a cached result stays valid
        self._max_cache_size = 200    # soft cap before eviction kicks in
        self._cache_timestamps = {}

        # Upper bound on how many candidates are sent to the LLM per query.
        self.max_docs_to_rerank = settings.max_docs_to_rerank

    def _get_cache_key(self, query: str, docs: List[Dict]) -> str:
        """Build a cache key from the normalized query and candidate doc ids.

        Only the docs that will actually be reranked contribute to the key
        (previously a hard-coded docs[:15], which could alias different
        inputs or miss the cache when max_docs_to_rerank != 15). Ids are
        sorted so the key is independent of retrieval order. MD5 is used
        purely as a fast non-cryptographic fingerprint.
        """
        query_normalized = query.lower().strip()
        doc_ids = [str(doc.get('id', '')) for doc in docs[:self.max_docs_to_rerank]]
        cache_content = query_normalized + "|".join(sorted(doc_ids))
        return hashlib.md5(cache_content.encode()).hexdigest()

    def _clean_cache(self):
        """Drop expired entries, then evict the oldest ~20% if still oversized."""
        current_time = time.time()

        # Phase 1: TTL expiry.
        expired_keys = [
            key for key, timestamp in self._cache_timestamps.items()
            if current_time - timestamp > self._cache_ttl
        ]
        for key in expired_keys:
            del self._rerank_cache[key]
            del self._cache_timestamps[key]

        # Phase 2: size-based eviction, oldest insertion first.
        if len(self._rerank_cache) > self._max_cache_size:
            sorted_keys = sorted(
                self._cache_timestamps.keys(),
                key=lambda k: self._cache_timestamps[k]
            )
            # Remove the oldest fifth to avoid evicting on every insert.
            keys_to_remove = sorted_keys[:len(sorted_keys) // 5]
            for key in keys_to_remove:
                del self._rerank_cache[key]
                del self._cache_timestamps[key]

            logger.info(f"[RERANK] Cleaned cache: removed {len(keys_to_remove)} old entries")

    def _get_cached_result(self, cache_key: str, top_k: int) -> List[Dict]:
        """Return up to top_k cached docs for cache_key, or [] on miss.

        Stale entries are removed eagerly on access.
        """
        if cache_key in self._rerank_cache:
            current_time = time.time()
            if current_time - self._cache_timestamps.get(cache_key, 0) <= self._cache_ttl:
                cached_result = self._rerank_cache[cache_key][:top_k]
                logger.info(f"[RERANK] Cache hit for query, returning {len(cached_result)} cached results")
                return cached_result
            else:
                # Expired: drop it so it is not considered again.
                del self._rerank_cache[cache_key]
                del self._cache_timestamps[cache_key]

        return []

    def _set_cached_result(self, cache_key: str, scored_docs: List[Dict]):
        """Store scored docs under cache_key; clean up when over the size cap."""
        self._rerank_cache[cache_key] = scored_docs
        self._cache_timestamps[cache_key] = time.time()

        if len(self._rerank_cache) > self._max_cache_size:
            self._clean_cache()

    async def _batch_score_docs(self, query: str, docs: List[Dict]) -> List[Dict]:
        """Score all docs in a single LLM call.

        Builds one prompt containing every document's full text (never
        truncated), asks the model for comma-separated 0-10 scores, parses
        them defensively and writes each score into doc['rerank_score'].
        On any failure every doc is scored 0 so the pipeline keeps running.

        Returns:
            The same doc dicts, mutated in place with a 'rerank_score' key.
        """
        if not docs:
            return []

        docs_content = []
        for i, doc in enumerate(docs):
            # NOTE(review): this path reads doc['fullcontent'] while
            # _score_doc reads 'tieude'/'noidung' — confirm both schemas
            # exist on the documents passed in.
            content = (doc.get('fullcontent') or '').strip()
            docs_content.append(f"{i+1}. {content}")

        batch_prompt = (
            f"Đánh giá mức độ liên quan giữa câu hỏi và các đoạn luật sau:\n\n"
            f"Câu hỏi: {query}\n\n"
            f"Các đoạn luật:\n" + "\n".join(docs_content) + "\n\n"
            f"Trả về điểm số từ 0-10 cho từng đoạn, phân cách bằng dấu phẩy.\n"
            f"Ví dụ: 8,5,7,3,9"
        )

        try:
            if self.provider == 'gemini':
                # get_running_loop(): get_event_loop() is deprecated inside
                # coroutines since Python 3.10.
                loop = asyncio.get_running_loop()
                logger.info(f"[RERANK] Sending batch prompt to Gemini for {len(docs)} docs")
                response = await loop.run_in_executor(None, self.client.generate_text, batch_prompt)
                logger.info(f"[RERANK] Got batch scores from Gemini: {response}")

                scores_text = str(response).strip()
                scores = []

                # Accept comma-, whitespace- or dot-separated score lists.
                if ',' in scores_text:
                    score_parts = scores_text.split(',')
                elif ' ' in scores_text:
                    score_parts = scores_text.split()
                else:
                    score_parts = scores_text.replace('.', ',').split(',')

                for score_str in score_parts:
                    try:
                        # Keep only digits and the decimal point (drops
                        # stray punctuation the model may emit).
                        clean_score = ''.join(c for c in score_str.strip() if c.isdigit() or c == '.')
                        if clean_score:
                            score = float(clean_score)
                            score = max(0, min(10, score))  # clamp to [0, 10]
                            scores.append(score)
                        else:
                            scores.append(0)
                    except (ValueError, TypeError):
                        scores.append(0)

                # Pad with zeros if the model returned too few scores;
                # surplus scores are ignored by the indexed assignment below.
                while len(scores) < len(docs):
                    scores.append(0)

                for i, doc in enumerate(docs):
                    doc['rerank_score'] = scores[i]

                logger.info(f"[RERANK] Successfully scored {len(docs)} docs with scores: {scores}")
                return docs

            else:
                raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in batch method.")

        except Exception as e:
            # Best-effort: never let a scoring failure break retrieval.
            logger.error(f"[RERANK] Lỗi khi batch score: {e}")
            for doc in docs:
                doc['rerank_score'] = 0
            return docs

    async def _score_doc(self, query: str, doc: Dict) -> Dict:
        """Score a single document against the query (fallback path).

        The full law text (title + body) is sent untruncated; the parsed
        score is clamped to [0, 10] and any failure yields a score of 0.

        Returns:
            The same doc dict, mutated in place with a 'rerank_score' key.
        """
        tieude = (doc.get('tieude') or '').strip()
        noidung = (doc.get('noidung') or '').strip()
        content = f"{tieude} {noidung}".strip()
        prompt = (
            f"Đánh giá mức độ liên quan:\n"
            f"Luật: {content}\n"
            f"Hỏi: {query}\n"
            f"Điểm (0-10):"
        )
        try:
            if self.provider == 'gemini':
                # get_running_loop(): see _batch_score_docs.
                loop = asyncio.get_running_loop()
                logger.info("[RERANK] Sending individual prompt to Gemini")
                score_response = await loop.run_in_executor(None, self.client.generate_text, prompt)
                logger.info(f"[RERANK] Got individual score from Gemini: {score_response}")
                score_text = str(score_response).strip()
                try:
                    clean_score = ''.join(c for c in score_text if c.isdigit() or c == '.')
                    if clean_score:
                        score = float(clean_score)
                        score = max(0, min(10, score))  # clamp to [0, 10]
                    else:
                        score = 0
                except (ValueError, TypeError):
                    score = 0
                doc['rerank_score'] = score
                return doc
            else:
                raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
        except Exception as e:
            logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
            doc['rerank_score'] = 0
            return doc

    async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank docs by relevance to query and return the top_k best.

        Pipeline: cache lookup -> cap candidates at max_docs_to_rerank ->
        batch scoring (per-doc fallback on failure) -> sort by score ->
        cache the full scored list.
        """
        logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")

        if not docs:
            return []

        cache_key = self._get_cache_key(query, docs)
        cached_result = self._get_cached_result(cache_key, top_k)

        if cached_result:
            return cached_result

        max_docs_to_rerank = self.max_docs_to_rerank
        docs_to_rerank = docs[:max_docs_to_rerank]
        logger.info(f"[RERANK] Will rerank {len(docs_to_rerank)} docs (limited to top {max_docs_to_rerank})")

        try:
            scored = await self._batch_score_docs(query, docs_to_rerank)
            logger.info(f"[RERANK] Batch processing completed, scored {len(scored)} docs")
        except Exception as e:
            logger.error(f"[RERANK] Batch processing failed, falling back to individual scoring: {e}")
            # Fallback: score docs one by one so a single failure cannot
            # sink the whole batch.
            scored = []
            for doc in docs_to_rerank:
                try:
                    scored_doc = await self._score_doc(query, doc)
                    scored.append(scored_doc)
                except Exception as e:
                    logger.error(f"[RERANK] Error scoring individual doc: {e}")
                    doc['rerank_score'] = 0
                    scored.append(doc)

        # .get() so a doc that somehow lacks 'rerank_score' sorts last
        # instead of raising KeyError.
        scored = sorted(scored, key=lambda x: x.get('rerank_score', 0), reverse=True)
        result = scored[:top_k]

        # Cache the full scored list so a later call with a larger top_k
        # can still be served from cache.
        self._set_cached_result(cache_key, scored)

        logger.info(f"[RERANK] Top reranked docs: {result[:2]}...{result[-2:]}")
        return result