Spaces:
Sleeping
Sleeping
File size: 1,692 Bytes
08d20f8 6339746 d56b9a1 08d20f8 76e4e13 08d20f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
# app/services/reranker.py
from sentence_transformers import CrossEncoder
from app.utils.logger import setup_logger
# Module-level logger, obtained from the project's setup_logger helper by
# passing this module's name.
logger = setup_logger(__name__)
class RerankerService:
    """
    Cross-Encoder based re-ranker for improving top-k retrieval precision.

    Wraps a ``CrossEncoder`` model and uses the scores it predicts for
    (query, document text) pairs to reorder retrieval results.
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-base"):
        """
        Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
        """
        logger.info(f"[RERANKER] Loading reranker model: {model_name}")
        self.model = CrossEncoder(model_name)

    @staticmethod
    def _document_text(result: dict) -> str:
        """Return the text to score for one result, or "" when it is absent."""
        # A missing/None payload or a missing/None text must not crash the
        # whole batch; such a document is simply scored as empty text.
        payload = result.get("payload") or {}
        return payload.get("searchable_text") or ""

    def rerank(self, query: str, results: list, top_k: int = 5) -> list:
        """
        Re-rank retrieved documents using CrossEncoder scores.

        Args:
            query: User query text
            results: List of FAISS results [{"payload": {...}, "score": float}].
                Each dict is updated in place with a "rerank_score" key.
            top_k: Maximum number of reranked items to return. Values <= 0
                yield an empty list.

        Returns:
            New list holding at most top_k of the input dicts, ordered by
            "rerank_score" (descending). Empty when results is empty.
        """
        if not results:
            return []

        # A negative bound would otherwise act as a "drop from the end" slice
        # and silently return the wrong number of items.
        limit = max(top_k, 0)

        pairs = [(query, self._document_text(r)) for r in results]
        logger.info(f"[RERANKER] Scoring - {len(pairs)} query-document pairs...")
        scores = self.model.predict(pairs)

        # Attach rerank score to each document (scores are in input order)
        for result, score in zip(results, scores):
            result["rerank_score"] = float(score)

        # Sort by rerank_score (descending); sorted() leaves the input list's
        # order untouched.
        reranked = sorted(results, key=lambda r: r["rerank_score"], reverse=True)
        top = reranked[:limit]

        logger.info(
            f"[RERANKER] Top reranked scores: "
            f"{[round(r['rerank_score'], 3) for r in top]}"
        )
        return top
# Global instance shared by importers of this module.
# NOTE: it is constructed at import time, so importing this module runs
# RerankerService.__init__ with its default model_name and builds the
# CrossEncoder right away.
reranker = RerankerService()
|