| """ |
| Rerank with cross encoder. |
| Ref: |
| https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c |
| https://github.com/langchain-ai/langchain/issues/13076 |
| """ |
|
|
| from __future__ import annotations |
| from typing import Optional, Sequence |
| from langchain.schema import Document |
| from langchain.pydantic_v1 import Extra |
|
|
| from langchain.callbacks.manager import Callbacks |
| from langchain.retrievers.document_compressors.base import BaseDocumentCompressor |
|
|
| from sentence_transformers import CrossEncoder |
|
|
|
|
class BgeRerank(BaseDocumentCompressor):
    """Re-rank documents with a sentence-transformers ``CrossEncoder``.

    Each document's ``page_content`` is scored against the query by the
    cross-encoder; the ``top_n`` highest-scoring documents are returned, best
    first, each carrying its score under ``metadata["relevance_score"]``.

    Ref:
        https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c
        https://github.com/langchain-ai/langchain/issues/13076
    Good to read:
        https://zhuanlan.zhihu.com/p/676008717 or its source
        https://teemukanstren.com/2023/12/25/llmrag-based-question-answering/
    """

    model_name: str = "jinaai/jina-reranker-v1-turbo-en"
    """Model name to use for reranking."""
    top_n: int = 6
    """Number of documents to return."""
    model: Optional[CrossEncoder] = None
    """CrossEncoder instance to use for reranking.

    When not supplied, one is built from ``model_name`` at construction time.
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Build the model per instance from the *configured* model_name.
        # A class-level default would be evaluated once at import time with
        # the default name: a custom model_name would be silently ignored and
        # importing this module would load a model as a side effect.
        if self.model is None:
            self.model = CrossEncoder(self.model_name, trust_remote_code=True)

    def bge_rerank(self, query: str, docs: Sequence[str]) -> list[tuple[int, float]]:
        """Score ``docs`` against ``query`` and return the best ``top_n``.

        Args:
            query: The query text.
            docs: Candidate passages to score.

        Returns:
            Up to ``top_n`` ``(index_into_docs, score)`` pairs, highest score
            first. Ties keep their original relative order.
        """
        # CrossEncoder.predict indexes its first input, so it fails on [].
        if not docs:
            return []
        model_inputs = [[query, doc] for doc in docs]
        scores = self.model.predict(model_inputs)
        # float() turns numpy scalars into plain, JSON-serializable floats.
        results = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return results[: self.top_n]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Re-rank documents with the configured cross-encoder model.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process
                (accepted for interface compatibility; not used here).

        Returns:
            Up to ``top_n`` documents, most relevant first. Each is a copy of
            the input document with ``metadata["relevance_score"]`` set to
            its cross-encoder score; the input documents are left unmodified.
        """
        if not documents:
            return []
        doc_list = list(documents)
        ranked = self.bge_rerank(query, [d.page_content for d in doc_list])
        # Return copies so the caller's Document objects (which may be shared
        # with a vector store or cache) are not mutated as a side effect.
        return [
            Document(
                page_content=doc_list[index].page_content,
                metadata={**doc_list[index].metadata, "relevance_score": score},
            )
            for index, score in ranked
        ]
|
|