File size: 7,513 Bytes
7644eac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
Reranking module for improving retrieval relevance.

Rerankers use more sophisticated models to re-score and re-order initial retrieval
results, significantly improving relevance at the cost of additional computation.
"""
import os
from typing import List, Dict, Any, Optional
from langchain.schema import Document

# Optional imports (graceful degradation if not available)
try:
    import cohere
    COHERE_AVAILABLE = True
except ImportError:
    COHERE_AVAILABLE = False
    print("⚠️  Cohere not installed. Install with: pip install cohere")

try:
    from sentence_transformers import CrossEncoder
    CROSS_ENCODER_AVAILABLE = True
except ImportError:
    CROSS_ENCODER_AVAILABLE = False
    print("⚠️  sentence-transformers not installed. Install with: pip install sentence-transformers")


class Reranker:
    """
    Document reranker using Cohere API or local cross-encoder models.

    Reranking is a two-stage retrieval process:
    1. Initial retrieval (BM25 + semantic) gets ~20-50 candidates
    2. Reranker scores each candidate against the query for final ranking

    Every result returned by ``rerank`` is a dict with the keys
    ``'document'`` (the input Document object), ``'score'`` (relevance score)
    and ``'rank'`` (1-based position). When no backend is available, or a
    backend call raises, the input order is kept and each document's
    ``metadata['relevance_score']`` (default 0.5) is reported as its score.
    """

    # Cross-encoder used when running locally (explicitly or as a fallback).
    DEFAULT_LOCAL_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"

    def __init__(
        self,
        use_local: bool = False,
        cohere_api_key: Optional[str] = None,
        cohere_model: str = "rerank-english-v3.0",
        local_model: str = DEFAULT_LOCAL_MODEL,
    ):
        """
        Initialize reranker.

        Args:
            use_local: Use local cross-encoder instead of Cohere API
            cohere_api_key: Cohere API key (if using Cohere); falls back to
                the ``COHERE_API_KEY`` environment variable when omitted
            cohere_model: Cohere rerank model name
            local_model: Local cross-encoder model name. Also used when the
                Cohere backend cannot be initialized and we fall back to
                local reranking.
        """
        self.use_local = use_local
        self.cohere_client = None
        self.cross_encoder = None
        self.cohere_model = cohere_model
        # Remembered so a Cohere -> local fallback honours the caller's choice
        # instead of silently loading the default model.
        self.local_model = local_model

        if use_local:
            self._init_local_reranker(local_model)
        else:
            self._init_cohere_reranker(cohere_api_key)

    def _fall_back_to_local(self) -> None:
        """Switch to the local cross-encoder after Cohere setup failed."""
        self.use_local = True
        # getattr keeps subclasses that skip our __init__ working.
        self._init_local_reranker(
            getattr(self, "local_model", self.DEFAULT_LOCAL_MODEL)
        )

    def _init_cohere_reranker(self, api_key: Optional[str]) -> None:
        """
        Initialize the Cohere client, falling back to the local model on failure.

        Args:
            api_key: Explicit Cohere key; ``COHERE_API_KEY`` is read when falsy.
        """
        if not COHERE_AVAILABLE:
            print("❌ Cohere not available. Falling back to local reranker.")
            self._fall_back_to_local()
            return

        api_key = api_key or os.getenv("COHERE_API_KEY")
        if not api_key:
            print("❌ COHERE_API_KEY not set. Falling back to local reranker.")
            self._fall_back_to_local()
            return

        try:
            self.cohere_client = cohere.Client(api_key)
            print(f"✅ Cohere reranker initialized (model: {self.cohere_model})")
        except Exception as e:
            print(f"❌ Failed to initialize Cohere: {e}")
            print("Falling back to local reranker.")
            self._fall_back_to_local()

    def _init_local_reranker(self, model_name: str = DEFAULT_LOCAL_MODEL) -> None:
        """
        Load a local cross-encoder; leaves ``self.cross_encoder`` as None on failure.

        Args:
            model_name: Name of the cross-encoder model to load.
        """
        if not CROSS_ENCODER_AVAILABLE:
            print("❌ sentence-transformers not available. Reranking disabled.")
            return

        try:
            self.cross_encoder = CrossEncoder(model_name)
            print(f"✅ Local cross-encoder initialized (model: {model_name})")
        except Exception as e:
            print(f"❌ Failed to initialize cross-encoder: {e}")

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Rerank documents based on relevance to query.

        Args:
            query: Search query
            documents: List of Document objects to rerank
            top_k: Number of top results to return; values <= 0 yield ``[]``

        Returns:
            Up to ``top_k`` dicts ``{'document', 'score', 'rank'}`` ordered by
            descending relevance (or input order if no backend is usable).
        """
        # Non-positive top_k would otherwise slice as "all but the last N".
        if not documents or top_k <= 0:
            return []

        # Use `is not None`: model objects may define __len__/__bool__.
        if self.use_local and self.cross_encoder is not None:
            return self._rerank_local(query, documents, top_k)
        if self.cohere_client is not None:
            return self._rerank_cohere(query, documents, top_k)

        # No reranker available, return original order
        print("⚠️  No reranker available. Returning original order.")
        return self._original_order(documents, top_k)

    @staticmethod
    def _original_order(
        documents: List[Document],
        top_k: int
    ) -> List[Dict[str, Any]]:
        """Wrap the first ``top_k`` documents, in input order, as result dicts."""
        return [
            {
                'document': doc,
                'score': doc.metadata.get('relevance_score', 0.5),
                'rank': rank,
            }
            for rank, doc in enumerate(documents[:top_k], start=1)
        ]

    def _rerank_cohere(
        self,
        query: str,
        documents: List[Document],
        top_k: int
    ) -> List[Dict[str, Any]]:
        """Rerank using Cohere API; any failure degrades to the input order."""
        try:
            response = self.cohere_client.rerank(
                model=self.cohere_model,
                query=query,
                documents=[doc.page_content for doc in documents],
                top_n=top_k,
            )
            # Each result carries the index into the texts we sent, so map it
            # back to the original Document object.
            reranked = [
                {
                    'document': documents[result.index],
                    'score': result.relevance_score,
                    'rank': rank,
                }
                for rank, result in enumerate(response.results, start=1)
            ]
        except Exception as e:
            # Best-effort: a reranking outage should not break retrieval.
            print(f"❌ Cohere reranking failed: {e}")
            return self._original_order(documents, top_k)

        print(f"✅ Cohere reranked {len(documents)} → {len(reranked)} documents")
        return reranked

    def _rerank_local(
        self,
        query: str,
        documents: List[Document],
        top_k: int
    ) -> List[Dict[str, Any]]:
        """Rerank using local cross-encoder; any failure degrades to the input order."""
        try:
            scores = self.cross_encoder.predict(
                [[query, doc.page_content] for doc in documents]
            )
            # sorted() is stable, so equally scored documents keep input order.
            ranked = sorted(
                zip(documents, scores), key=lambda pair: pair[1], reverse=True
            )
            reranked = [
                {'document': doc, 'score': float(score), 'rank': rank}
                for rank, (doc, score) in enumerate(ranked[:top_k], start=1)
            ]
        except Exception as e:
            # Best-effort: a model error should not break retrieval.
            print(f"❌ Local reranking failed: {e}")
            return self._original_order(documents, top_k)

        print(f"✅ Local reranked {len(documents)} → {len(reranked)} documents")
        return reranked