Spaces:
Running
Running
| from typing import List | |
| from sentence_transformers import CrossEncoder | |
| from langchain_core.documents import Document | |
class Reranker:
    """Cross-encoder reranker for retrieved docs.

    Scores every (query, document text) pair with a ``CrossEncoder`` model
    and returns the documents ordered from highest to lowest score.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> None:
        """Instantiate the cross-encoder used for scoring.

        Args:
            model_name: Model identifier passed straight to ``CrossEncoder``.
        """
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, docs: List[Document], top_k: int = 5) -> List[Document]:
        """Return up to ``top_k`` documents, best-scoring first.

        Args:
            query: Text that each document is scored against.
            docs: Candidate documents; only ``page_content`` is read.
            top_k: Maximum number of documents to return. Values less than
                or equal to zero yield an empty list.

        Returns:
            The highest-scoring documents in descending score order.
            Documents with equal scores keep their relative input order
            because the sort is stable. Empty when ``docs`` is empty or
            ``top_k`` is not positive.
        """
        # A non-positive top_k can never select anything. Bail out before
        # running the model: a negative value would otherwise reach the
        # slice below and silently mean "drop the last |top_k| results".
        if not docs or top_k <= 0:
            return []

        # One [query, text] pair per candidate, in the same order as docs,
        # so that scores[i] lines up with docs[i].
        pairs = [[query, doc.page_content] for doc in docs]
        # NOTE(review): assumes predict() yields one scalar score per pair --
        # confirm if a multi-label cross-encoder is ever configured.
        scores = self.model.predict(pairs)

        ranked = sorted(zip(docs, scores), key=lambda item: item[1], reverse=True)
        return [doc for doc, _ in ranked[:top_k]]