File size: 3,311 Bytes
5b89d45
 
 
 
 
 
 
a3bdcf1
5b89d45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Wrapper retriever that adds reranking and multi-query support."""

import logging
from typing import List, Optional, Any
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from code_chatbot.retrieval.reranker import Reranker

# Try to import MultiQueryRetriever - may not be available in all versions
try:
    from langchain.retrievers.multi_query import MultiQueryRetriever
except ImportError:
    try:
        from langchain_community.retrievers import MultiQueryRetriever
    except ImportError:
        MultiQueryRetriever = None  # type: ignore

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class RerankingRetriever(BaseRetriever):
    """Retriever that fetches candidates from a wrapped retriever and reranks them.

    The wrapped ``base_retriever`` is queried first; its results are then
    handed to ``reranker.rerank(query, docs, top_k=...)`` and whatever the
    reranker returns is the final result.
    """

    # Retriever that supplies the candidate documents.
    base_retriever: BaseRetriever
    # Object whose ``rerank(query, docs, top_k=...)`` method is called on the
    # candidates. Declared as ``Any`` here; the constructor annotates it as
    # ``Reranker``.
    reranker: Any
    # Forwarded to the reranker as its ``top_k`` argument.
    top_k: int = 5

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, base_retriever: BaseRetriever, reranker: Reranker, top_k: int = 5):
        """Create the wrapper.

        Args:
            base_retriever: Retriever used to fetch candidate documents.
            reranker: Object used to rerank the candidates.
            top_k: Value passed to the reranker as ``top_k``.
        """
        super().__init__(base_retriever=base_retriever, reranker=reranker, top_k=top_k)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents from the base retriever and rerank them.

        Args:
            query: The query string passed to both the base retriever and the
                reranker.
            run_manager: Callback manager for this retriever run. Its child
                manager is forwarded to the base retriever so the nested run
                is attached to the same callbacks.

        Returns:
            An empty list when the base retriever returns nothing; otherwise
            the value returned by ``reranker.rerank``.
        """
        # Forward child callbacks so the nested retrieval is not detached from
        # the parent run. A missing run manager (e.g. a direct call that passes
        # ``run_manager=None``) falls back to a plain invoke.
        child_config = (
            {"callbacks": run_manager.get_child()} if run_manager is not None else None
        )
        docs = self.base_retriever.invoke(query, config=child_config)
        logger.info("Base retriever returned %d documents", len(docs))

        # Nothing to rerank; skip the reranker call entirely.
        if not docs:
            return []

        reranked_docs = self.reranker.rerank(query, docs, top_k=self.top_k)
        logger.info("Reranked to %d top documents", len(reranked_docs))

        return reranked_docs


def build_enhanced_retriever(
    base_retriever: BaseRetriever,
    llm=None,
    use_multi_query: bool = False,
    use_reranking: bool = True,
    rerank_top_k: int = 5,
) -> BaseRetriever:
    """
    Assemble a retriever pipeline on top of ``base_retriever``.

    Two optional stages are layered in order: multi-query expansion first,
    then reranking. A stage that is disabled is simply not added. Multi-query
    expansion is also skipped (with a warning) when ``MultiQueryRetriever``
    could not be imported or when ``llm`` is falsy.

    Args:
        base_retriever: The retriever to start from (e.g., from vector store)
        llm: LLM handed to ``MultiQueryRetriever.from_llm``; needed only when
            use_multi_query=True
        use_multi_query: Whether to wrap the retriever in a multi-query retriever
        use_reranking: Whether to wrap the result in a ``RerankingRetriever``
        rerank_top_k: ``top_k`` given to the ``RerankingRetriever``

    Returns:
        The outermost retriever of the assembled pipeline; this is
        ``base_retriever`` itself when no stage was applied.
    """
    pipeline = base_retriever

    # Stage 1: query expansion, falling back to the plain retriever when the
    # prerequisites are missing.
    if use_multi_query and MultiQueryRetriever is None:
        logger.warning("MultiQueryRetriever not available, skipping multi-query expansion")
    elif use_multi_query and not llm:
        logger.warning("Multi-query retriever requires an LLM, skipping multi-query expansion")
    elif use_multi_query:
        pipeline = MultiQueryRetriever.from_llm(retriever=pipeline, llm=llm)
        logger.info("Applied multi-query retriever for query expansion")

    # Stage 2: reranking wraps whatever stage 1 produced.
    if not use_reranking:
        return pipeline

    reranking_pipeline = RerankingRetriever(
        base_retriever=pipeline,
        reranker=Reranker(),
        top_k=rerank_top_k,
    )
    logger.info("Applied reranking to retriever")
    return reranking_pipeline