Asish Karthikeya Gogineni
Refactor: Code Structure Update & UI Redesign
a3bdcf1
import logging
from typing import List
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
logger = logging.getLogger(__name__)
class Reranker:
    """
    Uses a Cross-Encoder to re-rank documents retrieved by the vector store.

    Each (query, document) pair is scored jointly by the cross-encoder, which
    is typically more precise than the vector-similarity ranking used for the
    initial retrieval.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        **model_kwargs,
    ):
        """
        Load the cross-encoder model.

        Args:
            model_name: Model identifier passed to ``CrossEncoder``.
            **model_kwargs: Extra keyword arguments forwarded unchanged to
                ``CrossEncoder`` (e.g. ``device`` or ``max_length``).
        """
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info("Loading Reranker model: %s", model_name)
        self.model = CrossEncoder(model_name, **model_kwargs)

    def rerank(self, query: str, documents: List[Document], top_k: int = 5) -> List[Document]:
        """
        Score ``documents`` against ``query`` and return the best ``top_k``.

        Side effect: every scored input document gets its cross-encoder score
        written in place to ``doc.metadata["rerank_score"]`` (as a float).

        Args:
            query: The search query.
            documents: Candidate documents; each one's ``page_content`` is scored.
            top_k: Maximum number of documents to return. A value <= 0 returns
                an empty list; ``None`` returns all documents, ranked.

        Returns:
            Documents sorted by descending score, truncated to ``top_k``.
            Documents with equal scores keep their input order (stable sort).
        """
        if not documents:
            return []
        # A negative top_k would otherwise slice from the end of the ranking
        # (e.g. [:-1] = "all but the worst"), which is never what callers mean.
        # Bail out before running the model since nothing will be returned.
        if top_k is not None and top_k <= 0:
            return []

        # Cross-encoders score the query and document text together.
        pairs = [[query, doc.page_content] for doc in documents]
        # Convert each score to a plain float once. Widening to Python float is
        # exact, so the ranking is unchanged.
        scores = [float(score) for score in self.model.predict(pairs)]

        for doc, score in zip(documents, scores):
            doc.metadata["rerank_score"] = score

        ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _ in ranked[:top_k]]