Spaces:
Running
Running
File size: 1,957 Bytes
6c58cf4 dfa6a46 6c58cf4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 | from typing import List
from langchain.schema import Document
from flashrank.Ranker import Ranker, RerankRequest
from project.utils.config_loader import load_config
from project.logger.logging import get_logger
# Module-level logger obtained from the project's logging helper
# (project.logger.logging.get_logger), named after this module.
logger = get_logger(__name__)
class DocumentReranker:
    """Rerank LangChain documents against a query using a FlashRank model.

    Settings are read from the ``reranker`` section of the project config
    returned by ``load_config``:

    - ``model_name``: FlashRank model to load (default ``'rank-T5-flan'``).
    - ``cache_dir``: optional directory forwarded to ``Ranker``; when it is
      absent or empty, ``Ranker``'s own default is used.
    - ``top_k``: default number of documents returned by ``rerank``
      (default ``3``).
    """

    def __init__(self, config_path: str = None):
        """Load the configuration and initialise the FlashRank ranker.

        Args:
            config_path: Passed unchanged to ``load_config``. ``None`` is
                forwarded as-is (presumably the loader then uses its own
                default location; confirm in ``load_config``).
        """
        self.config = load_config(config_path)
        # ``or {}`` also covers an explicit ``reranker: null`` entry, for which
        # ``.get('reranker', {})`` would return None and the lookups below fail.
        reranker_config = self.config.get('reranker') or {}
        model_name = reranker_config.get('model_name', 'rank-T5-flan')
        cache_dir = reranker_config.get('cache_dir')
        self.top_k = reranker_config.get('top_k', 3)

        # Only forward cache_dir when it is configured, so Ranker otherwise
        # keeps its own default cache location.
        ranker_kwargs = {'model_name': model_name}
        if cache_dir:
            ranker_kwargs['cache_dir'] = cache_dir
        self.ranker = Ranker(**ranker_kwargs)
        logger.info("FlashRank reranker initialized with model: %s", model_name)

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = None
    ) -> List[Document]:
        """Return the ``top_k`` documents that the ranker scores best for ``query``.

        Args:
            query: Text the documents are scored against.
            documents: Candidate documents to rerank.
            top_k: Maximum number of documents to return. ``None`` uses the
                configured default (``self.top_k``).

        Returns:
            Up to ``top_k`` documents, in the order of the ranker's results
            (FlashRank returns passages best-first). Each document has its score
            stored in ``metadata["rerank_score"]`` as a built-in ``float``.
            Returns an empty list when ``documents`` is empty.

        Raises:
            ValueError: If ``top_k`` is negative.

        Note:
            The returned objects are the same ``Document`` instances that were
            passed in, and their ``metadata`` dicts are updated in place.
        """
        if top_k is None:
            top_k = self.top_k
        if top_k < 0:
            # A negative slice bound would silently drop results from the end
            # of the ranking instead of limiting how many are returned.
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if not documents:
            logger.warning("No documents to rerank")
            return []

        # Use each document's list index as the passage id, so every ranker
        # result can be mapped back to its original Document.
        passages = [
            {"id": i, "text": doc.page_content, "meta": doc.metadata}
            for i, doc in enumerate(documents)
        ]
        results = self.ranker.rerank(RerankRequest(query=query, passages=passages))

        reranked_docs = []
        for result in results[:top_k]:
            doc = documents[result["id"]]
            # FlashRank may return NumPy scalar scores (e.g. numpy.float32).
            # Cast to a built-in float so the metadata stays JSON-serialisable.
            doc.metadata["rerank_score"] = float(result["score"])
            reranked_docs.append(doc)

        logger.info(
            "Reranked %d documents, returning top %d",
            len(documents),
            len(reranked_docs),
        )
        return reranked_docs
|