RAG-Project / project /model /reranking.py
github-actions[bot]
Deploy from GitHub Actions
dfa6a46
from typing import List, Optional

from flashrank.Ranker import Ranker, RerankRequest
from langchain.schema import Document

from project.logger.logging import get_logger
from project.utils.config_loader import load_config
logger = get_logger(__name__)
class DocumentReranker:
    """Rerank LangChain documents against a query using a FlashRank ``Ranker``.

    Settings are read from the ``reranker`` section of the loaded config:

    * ``model_name`` -- FlashRank model to load (default ``'rank-T5-flan'``).
    * ``cache_dir``  -- optional model cache directory; only passed to
      ``Ranker`` when it is set to a truthy value.
    * ``top_k``      -- default number of documents returned by
      :meth:`rerank` (default ``3``).
    """

    def __init__(self, config_path: Optional[str] = None):
        """Load the configuration and build the underlying ranker.

        Args:
            config_path: Forwarded unchanged to ``load_config``; ``None``
                leaves the choice of config file to that helper.
        """
        self.config = load_config(config_path)
        # ``or {}`` also covers a ``reranker`` key that is present but null,
        # which ``.get('reranker', {})`` alone would return as ``None``.
        reranker_config = self.config.get('reranker') or {}

        model_name = reranker_config.get('model_name', 'rank-T5-flan')
        cache_dir = reranker_config.get('cache_dir')
        self.top_k = reranker_config.get('top_k', 3)

        # Only pass cache_dir when configured so the Ranker's own default
        # cache location is used otherwise.
        ranker_kwargs = {"model_name": model_name}
        if cache_dir:
            ranker_kwargs["cache_dir"] = cache_dir
        self.ranker = Ranker(**ranker_kwargs)

        logger.info(f"FlashRank reranker initialized with model: {model_name}")

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: Optional[int] = None,
    ) -> List[Document]:
        """Return the best ``top_k`` documents for ``query``.

        The returned list contains the *same* ``Document`` objects that were
        passed in (not copies); each returned document gets a
        ``"rerank_score"`` entry written into its ``metadata`` in place.

        Args:
            query: Query text the documents are scored against.
            documents: Candidate documents; ``page_content`` is used as the
                passage text and ``metadata`` is forwarded as passage meta.
            top_k: Maximum number of documents to return. ``None`` falls back
                to the configured ``self.top_k``.

        Returns:
            Up to ``top_k`` documents, in the order produced by the ranker
            (presumably best match first -- the ranker output is not
            re-sorted here). An empty list when ``documents`` is empty or
            ``top_k`` is zero or negative.
        """
        if top_k is None:
            top_k = self.top_k

        if not documents:
            logger.warning("No documents to rerank")
            return []

        # A negative value would otherwise slice from the end of the results
        # (e.g. ``results[:-1]`` drops only the last hit), which is never what
        # a caller asking for "top k" means. Skip the model call entirely.
        # ``top_k`` may still be None if the config explicitly sets it to
        # null; that keeps the "return everything" slice semantics.
        if top_k is not None and top_k <= 0:
            logger.warning(f"Non-positive top_k ({top_k}); returning no documents")
            return []

        # The passage id is the index into ``documents`` so each result can be
        # mapped back to its original Document object.
        passages = [
            {"id": idx, "text": doc.page_content, "meta": doc.metadata}
            for idx, doc in enumerate(documents)
        ]

        results = self.ranker.rerank(RerankRequest(query=query, passages=passages))

        reranked_docs = []
        for result in results[:top_k]:
            doc = documents[result["id"]]
            # Coerce to a built-in float: the ranker may hand back a NumPy
            # scalar, which is not JSON-serializable when metadata is dumped.
            doc.metadata["rerank_score"] = float(result["score"])
            reranked_docs.append(doc)

        logger.info(f"Reranked {len(documents)} documents, returning top {len(reranked_docs)}")
        return reranked_docs