# rag_template/src/reranking.py
# Author: Guilherme Favaron
# Commit 1b447de: Major update: Add hybrid search, reranking, multiple LLMs, and UI improvements
"""
Sistema de reranking com cross-encoder
"""
from typing import List, Dict, Any, Optional
from sentence_transformers import CrossEncoder
from .config import RERANKER_MODEL_ID
class Reranker:
    """Rerank retrieved documents with a cross-encoder for better precision.

    The underlying ``CrossEncoder`` is created lazily on first use (see
    ``load_model``), so constructing a ``Reranker`` does not load the model.
    """

    def __init__(self, model_id: str = RERANKER_MODEL_ID):
        """
        Initialize the reranker without loading the model.

        Args:
            model_id: Identifier of the cross-encoder model, passed to
                ``CrossEncoder`` when the model is first loaded.
        """
        self.model_id = model_id
        # Stays None until load_model() is called for the first time.
        self.model: Optional[CrossEncoder] = None

    def load_model(self) -> CrossEncoder:
        """Return the cross-encoder, instantiating it on the first call (lazy loading)."""
        if self.model is None:
            self.model = CrossEncoder(self.model_id)
        return self.model

    def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_k: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Reorder documents by cross-encoder relevance to ``query``.

        Each input dict is updated in place with two keys:
          - ``'rerank_score'``: the cross-encoder score, as a float;
          - ``'original_score'``: the document's prior ``'score'`` (0.0 if absent).

        Args:
            query: The user query.
            documents: Documents with a ``'content'`` key and, optionally, ``'score'``.
            top_k: Return only the best ``top_k`` documents. ``None`` returns all;
                ``0`` returns an empty list.

        Returns:
            A new list holding the same (mutated) dicts, sorted by
            ``'rerank_score'`` in descending order. Documents with equal scores
            keep their input order because ``sorted`` is stable.
        """
        if not documents:
            return []

        model = self.load_model()

        # The cross-encoder scores each (query, document) pair jointly.
        pairs = [(query, doc['content']) for doc in documents]
        scores = model.predict(pairs)

        # Record the new score and keep the retriever's score for comparison.
        for doc, score in zip(documents, scores):
            doc['rerank_score'] = float(score)
            doc['original_score'] = doc.get('score', 0.0)

        reranked = sorted(documents, key=lambda d: d['rerank_score'], reverse=True)

        # Compare against None explicitly: a truthiness check would treat
        # top_k=0 as "no limit" and return every document.
        if top_k is not None:
            reranked = reranked[:top_k]
        return reranked

    def get_rerank_comparison(
        self,
        original_docs: List[Dict[str, Any]],
        reranked_docs: List[Dict[str, Any]],
        preview_chars: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Build before/after data showing how reranking moved each document.

        Documents are matched between the two lists by their ``'id'`` key.

        Args:
            original_docs: Documents in their original (pre-rerank) order.
            reranked_docs: Documents after reranking.
            preview_chars: Maximum number of content characters kept in
                ``'content_preview'``. ``"..."`` is appended only when the
                content was actually truncated.

        Returns:
            One dict per reranked document with these keys:
              - ``'new_rank'``: 1-based position after reranking.
              - ``'original_rank'``: 1-based original position, or ``-1`` if
                the id is not present in ``original_docs``.
              - ``'original_score'`` and ``'rerank_score'``: 0.0 when missing.
              - ``'position_change'``: ``original_rank - new_rank``, so a
                positive value means the document moved up; 0 when the
                original rank is unknown.
              - ``'content_preview'``: the truncated content.
        """
        # id -> 1-based original position. setdefault keeps the first
        # occurrence, so a repeated id does not report a later position.
        original_positions: Dict[Any, int] = {}
        for position, doc in enumerate(original_docs, 1):
            original_positions.setdefault(doc['id'], position)

        comparison = []
        for new_rank, doc in enumerate(reranked_docs, 1):
            original_rank = original_positions.get(doc['id'], -1)
            position_change = original_rank - new_rank if original_rank != -1 else 0

            content = doc['content']
            if len(content) > preview_chars:
                preview = content[:preview_chars] + "..."
            else:
                preview = content

            comparison.append({
                'new_rank': new_rank,
                'original_rank': original_rank,
                'original_score': doc.get('original_score', 0.0),
                'rerank_score': doc.get('rerank_score', 0.0),
                'position_change': position_change,
                'content_preview': preview
            })
        return comparison

    def is_available(self) -> bool:
        """
        Report whether the cross-encoder can be loaded.

        This calls ``load_model``, so the first call instantiates the model.
        """
        try:
            self.load_model()
            return True
        except Exception:
            # Deliberate best-effort probe: any load failure means "unavailable".
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return a summary of the reranker configuration.

        Computing ``'available'`` calls ``is_available``, which loads the model
        if it has not been loaded yet.
        """
        return {
            "model_id": self.model_id,
            "available": self.is_available(),
            "type": "cross-encoder"
        }