Spaces:
Running
Running
File size: 3,311 Bytes
"""Wrapper retriever that adds reranking and multi-query support."""
import logging
from typing import List, Optional, Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from code_chatbot.retrieval.reranker import Reranker
# Try to import MultiQueryRetriever - may not be available in all versions
try:
from langchain.retrievers.multi_query import MultiQueryRetriever
except ImportError:
try:
from langchain_community.retrievers import MultiQueryRetriever
except ImportError:
MultiQueryRetriever = None # type: ignore
logger = logging.getLogger(__name__)
class RerankingRetriever(BaseRetriever):
    """Retriever wrapper that reranks the results of another retriever.

    Candidate documents are fetched from ``base_retriever`` and then handed to
    ``reranker.rerank(query, docs, top_k=self.top_k)``; the value returned by
    that call is what this retriever returns.
    """

    # Retriever that supplies the candidate documents.
    base_retriever: BaseRetriever
    # Any object providing ``rerank(query, docs, top_k=...)``.
    reranker: Any
    # Forwarded to the reranker as its ``top_k`` argument.
    top_k: int = 5

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, base_retriever: BaseRetriever, reranker: Reranker, top_k: int = 5):
        """Store the wrapped retriever, the reranker and the ``top_k`` cut-off.

        Args:
            base_retriever: Retriever queried for candidate documents.
            reranker: Object whose ``rerank`` method orders the candidates.
            top_k: Value passed to ``reranker.rerank`` as ``top_k``.
        """
        super().__init__(base_retriever=base_retriever, reranker=reranker, top_k=top_k)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents from the base retriever and rerank them.

        Args:
            query: The search query.
            run_manager: Callback manager for this retriever run; its child
                manager is forwarded to the wrapped retriever.

        Returns:
            An empty list when the base retriever yields nothing, otherwise
            the result of ``reranker.rerank``.
        """
        # Forward the child callback manager so the wrapped retriever's run is
        # reported to the same callbacks/tracers as this run instead of being
        # started detached from it. Tolerate a missing manager for callers
        # that invoke this method directly.
        child_callbacks = run_manager.get_child() if run_manager is not None else None
        docs = self.base_retriever.invoke(
            query, config={"callbacks": child_callbacks}
        )
        logger.info("Base retriever returned %d documents", len(docs))

        # Nothing to rerank: skip the reranker call entirely.
        if not docs:
            return []

        reranked_docs = self.reranker.rerank(query, docs, top_k=self.top_k)
        logger.info("Reranked to %d top documents", len(reranked_docs))
        return reranked_docs
def build_enhanced_retriever(
    base_retriever: BaseRetriever,
    llm=None,
    use_multi_query: bool = False,
    use_reranking: bool = True,
    rerank_top_k: int = 5,
) -> BaseRetriever:
    """
    Assemble a retriever pipeline on top of ``base_retriever``.

    Two optional layers are stacked in order: multi-query expansion first,
    then reranking. A layer that is disabled (or cannot be built) is skipped
    and the retriever from the previous stage is passed through unchanged.

    Args:
        base_retriever: The starting retriever (e.g., from a vector store)
        llm: LLM used for multi-query expansion; expansion is skipped with a
            warning when this is falsy
        use_multi_query: Whether to wrap with MultiQueryRetriever
        use_reranking: Whether to wrap with RerankingRetriever
        rerank_top_k: ``top_k`` handed to the RerankingRetriever

    Returns:
        The outermost retriever of the assembled pipeline.
    """
    pipeline = base_retriever

    # Stage 1: multi-query expansion (needs both the class and an LLM).
    if use_multi_query:
        if MultiQueryRetriever is not None and llm:
            pipeline = MultiQueryRetriever.from_llm(retriever=pipeline, llm=llm)
            logger.info("Applied multi-query retriever for query expansion")
        elif MultiQueryRetriever is None:
            logger.warning("MultiQueryRetriever not available, skipping multi-query expansion")
        else:
            logger.warning("Multi-query retriever requires an LLM, skipping multi-query expansion")

    # Stage 2: reranking. Bail out early when it is not requested.
    if not use_reranking:
        return pipeline

    pipeline = RerankingRetriever(
        base_retriever=pipeline,
        reranker=Reranker(),
        top_k=rerank_top_k,
    )
    logger.info("Applied reranking to retriever")
    return pipeline
|