Chatopus / app /reranker.py
VietCat's picture
add quota manager
8b81b1d
raw
history blame
2.51 kB
import asyncio
import re
from typing import List, Dict

from loguru import logger

from .config import get_settings
from .gemini_client import GeminiClient
class Reranker:
    """Rerank retrieved documents by asking an LLM for a 0-10 relevance score.

    The provider and model are read from the application settings
    (``rerank_provider`` / ``rerank_model``), falling back to the general
    ``llm_provider`` / ``llm_model`` values when the rerank-specific ones
    are absent. Only the ``gemini`` provider is currently wired up.
    """

    # First (optionally signed) number in the model reply. Both "." and ","
    # are accepted as the decimal separator, so "7,5" parses as 7.5.
    _SCORE_PATTERN = re.compile(r"-?\d+(?:[.,]\d+)?")
    # Bounds of the scale the prompt asks the model to use.
    _MIN_SCORE = 0.0
    _MAX_SCORE = 10.0

    def __init__(self):
        """Create the client for the configured rerank provider.

        Raises:
            NotImplementedError: if the configured provider is not ``gemini``.
        """
        settings = get_settings()
        self.provider = getattr(settings, 'rerank_provider', settings.llm_provider)
        self.model = getattr(settings, 'rerank_model', settings.llm_model)
        if self.provider == 'gemini':
            self.client = GeminiClient()
        else:
            # NOTE: other providers (e.g. OpenAI, Cohere) are not implemented yet.
            raise NotImplementedError(f"Rerank provider {self.provider} not supported yet.")

    @staticmethod
    def _build_prompt(query: str, doc: Dict) -> str:
        """Build the scoring prompt from the doc's ``tieude``/``noidung`` fields and the query."""
        content = (doc.get('tieude', '') or '') + ' ' + (doc.get('noidung', '') or '')
        return (
            f"Đoạn luật: {content}\n"
            f"Câu hỏi: {query}\n"
            "Hãy đánh giá mức độ liên quan giữa đoạn luật và câu hỏi trên thang điểm 0-10. "
            "Chỉ trả về một số duy nhất."
        )

    @classmethod
    def _parse_score(cls, raw) -> float:
        """Extract a relevance score from the raw model reply.

        The first number found in the reply is used, so answers such as
        ``"8"``, ``"8/10"``, ``"Điểm: 8"`` or ``"7,5"`` are all understood.
        The value is clamped to the 0-10 scale requested by the prompt so a
        stray out-of-range number cannot dominate the ranking.

        Raises:
            ValueError: if the reply contains no number at all.
        """
        match = cls._SCORE_PATTERN.search(str(raw))
        if match is None:
            raise ValueError(f"no numeric score in model reply: {raw!r}")
        value = float(match.group(0).replace(',', '.'))
        return max(cls._MIN_SCORE, min(cls._MAX_SCORE, value))

    async def _request_score(self, loop, prompt: str) -> float:
        """Send ``prompt`` to the configured provider and return the parsed score."""
        if self.provider != 'gemini':
            raise NotImplementedError(f"Rerank provider {self.provider} not supported yet in rerank method.")
        logger.info(f"[RERANK] Sending prompt to Gemini: {prompt}")
        # generate_text is called synchronously, so it is run in the loop's
        # default executor to keep the event loop free while waiting.
        raw = await loop.run_in_executor(None, self.client.generate_text, prompt)
        logger.info(f"[RERANK] Got score from Gemini: {raw}")
        return self._parse_score(raw)

    async def rerank(self, query: str, docs: List[Dict], top_k: int = 5) -> List[Dict]:
        """Rerank ``docs`` by relevance to ``query`` and return the best ``top_k``.

        Each doc is scored with one LLM call, issued sequentially. The score
        is stored on the doc itself under the ``rerank_score`` key (the input
        dicts are mutated). A doc whose scoring fails for any reason is logged
        and given a score of 0.0 instead of aborting the whole rerank.

        Args:
            query: The user question.
            docs: Candidate documents; ``tieude`` and ``noidung`` are read
                from each one, missing or ``None`` values count as empty.
            top_k: Maximum number of documents to return.

        Returns:
            The docs sorted by descending ``rerank_score``, truncated to
            ``top_k`` entries. Ties keep their original relative order.
        """
        logger.info(f"[RERANK] Start rerank for query: {query} | docs: {len(docs)} | top_k: {top_k}")
        # We are inside a coroutine, so a running loop is guaranteed; fetch it
        # once instead of once per document.
        loop = asyncio.get_running_loop()
        scored = []
        for doc in docs:
            prompt = self._build_prompt(query, doc)
            try:
                score = await self._request_score(loop, prompt)
            except Exception as e:
                # Best effort: one bad reply must not fail the whole request.
                logger.error(f"[RERANK] Lỗi khi tính score: {e} | doc: {doc}")
                score = 0.0
            doc['rerank_score'] = score
            scored.append(doc)
        scored = sorted(scored, key=lambda x: x['rerank_score'], reverse=True)
        logger.info(f"[RERANK] Top reranked docs: {scored[:top_k]}")
        return scored[:top_k]