| from typing import List, Dict |
| from .config import get_settings |
| from .gemini_client import GeminiClient |
| from loguru import logger |
| import asyncio |
|
|
| class Reranker: |
| def __init__(self): |
| settings = get_settings() |
| self.provider = getattr(settings, 'rerank_provider', settings.llm_provider) |
| self.model = getattr(settings, 'rerank_model', settings.llm_model) |
| if self.provider == 'gemini': |
| self.client = GeminiClient() |
| |
| |
| |
| |
| else: |
| raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.") |
|
|
| async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]: |
| """ |
| Rerank docs theo độ liên quan với query, trả về top_k docs. |
| """ |
| logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}") |
| scored = [] |
| for doc in docs: |
| content = (doc.get('tieude', '') or '') + ' ' + (doc.get('noidung', '') or '') |
| prompt = ( |
| f"Đoạn luật: {content}\n" |
| f"Câu hỏi: {query}\n" |
| "Hãy đánh giá mức độ liên quan giữa đoạn luật và câu hỏi trên thang điểm 0-10. " |
| "Chỉ trả về một số duy nhất." |
| ) |
| try: |
| if self.provider == 'gemini': |
| loop = asyncio.get_event_loop() |
| logger.info(f"[RERANK] Sending prompt to Gemini: {prompt}") |
| score = await loop.run_in_executor(None, self.client.generate_text, prompt) |
| logger.info(f"[RERANK] Got score from Gemini: {score}") |
| else: |
| raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.") |
| score = float(str(score).strip().split()[0]) |
| except Exception as e: |
| logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}") |
| score = 0 |
| doc['rerank_score'] = score |
| scored.append(doc) |
| scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True) |
| logger.info(f"[RERANK] Top reranked docs: {scored[:top_k]}") |
| return scored[:top_k] |