""" Reranking module for improving retrieval relevance. Rerankers use more sophisticated models to re-score and re-order initial retrieval results, significantly improving relevance at the cost of additional computation. """ import os from typing import List, Dict, Any, Optional from langchain.schema import Document # Optional imports (graceful degradation if not available) try: import cohere COHERE_AVAILABLE = True except ImportError: COHERE_AVAILABLE = False print("⚠️ Cohere not installed. Install with: pip install cohere") try: from sentence_transformers import CrossEncoder CROSS_ENCODER_AVAILABLE = True except ImportError: CROSS_ENCODER_AVAILABLE = False print("⚠️ sentence-transformers not installed. Install with: pip install sentence-transformers") class Reranker: """ Document reranker using Cohere API or local cross-encoder models. Reranking is a two-stage retrieval process: 1. Initial retrieval (BM25 + semantic) gets ~20-50 candidates 2. Reranker scores each candidate against the query for final ranking """ def __init__( self, use_local: bool = False, cohere_api_key: Optional[str] = None, cohere_model: str = "rerank-english-v3.0", local_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" ): """ Initialize reranker. Args: use_local: Use local cross-encoder instead of Cohere API cohere_api_key: Cohere API key (if using Cohere) cohere_model: Cohere rerank model name local_model: Local cross-encoder model name """ self.use_local = use_local self.cohere_client = None self.cross_encoder = None self.cohere_model = cohere_model if use_local: self._init_local_reranker(local_model) else: self._init_cohere_reranker(cohere_api_key) def _init_cohere_reranker(self, api_key: Optional[str]) -> None: """Initialize Cohere reranker.""" if not COHERE_AVAILABLE: print("❌ Cohere not available. Falling back to local reranker.") self.use_local = True self._init_local_reranker() return api_key = api_key or os.getenv("COHERE_API_KEY") if not api_key: print("❌ COHERE_API_KEY not set. Falling back to local reranker.") self.use_local = True self._init_local_reranker() return try: self.cohere_client = cohere.Client(api_key) print(f"✅ Cohere reranker initialized (model: {self.cohere_model})") except Exception as e: print(f"❌ Failed to initialize Cohere: {e}") print("Falling back to local reranker.") self.use_local = True self._init_local_reranker() def _init_local_reranker(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2") -> None: """Initialize local cross-encoder reranker.""" if not CROSS_ENCODER_AVAILABLE: print("❌ sentence-transformers not available. Reranking disabled.") return try: self.cross_encoder = CrossEncoder(model_name) print(f"✅ Local cross-encoder initialized (model: {model_name})") except Exception as e: print(f"❌ Failed to initialize cross-encoder: {e}") def rerank( self, query: str, documents: List[Document], top_k: int = 5 ) -> List[Dict[str, Any]]: """ Rerank documents based on relevance to query. Args: query: Search query documents: List of Document objects to rerank top_k: Number of top results to return Returns: Reranked list of documents with scores """ if not documents: return [] # Use appropriate reranker if self.use_local and self.cross_encoder: return self._rerank_local(query, documents, top_k) elif self.cohere_client: return self._rerank_cohere(query, documents, top_k) else: # No reranker available, return original order print("⚠️ No reranker available. Returning original order.") return [ { 'document': doc, 'score': doc.metadata.get('relevance_score', 0.5), 'rank': i + 1 } for i, doc in enumerate(documents[:top_k]) ] def _rerank_cohere( self, query: str, documents: List[Document], top_k: int ) -> List[Dict[str, Any]]: """Rerank using Cohere API.""" try: # Prepare documents for Cohere doc_texts = [doc.page_content for doc in documents] # Call Cohere rerank API results = self.cohere_client.rerank( model=self.cohere_model, query=query, documents=doc_texts, top_n=top_k ) # Build reranked results reranked = [] for result in results.results: reranked.append({ 'document': documents[result.index], 'score': result.relevance_score, 'rank': len(reranked) + 1 }) print(f"✅ Cohere reranked {len(documents)} → {len(reranked)} documents") return reranked except Exception as e: print(f"❌ Cohere reranking failed: {e}") # Fallback to original order return [ { 'document': doc, 'score': doc.metadata.get('relevance_score', 0.5), 'rank': i + 1 } for i, doc in enumerate(documents[:top_k]) ] def _rerank_local( self, query: str, documents: List[Document], top_k: int ) -> List[Dict[str, Any]]: """Rerank using local cross-encoder.""" try: # Prepare query-document pairs pairs = [[query, doc.page_content] for doc in documents] # Get relevance scores scores = self.cross_encoder.predict(pairs) # Sort by score scored_docs = list(zip(documents, scores)) scored_docs.sort(key=lambda x: x[1], reverse=True) # Build results reranked = [] for doc, score in scored_docs[:top_k]: reranked.append({ 'document': doc, 'score': float(score), 'rank': len(reranked) + 1 }) print(f"✅ Local reranked {len(documents)} → {len(reranked)} documents") return reranked except Exception as e: print(f"❌ Local reranking failed: {e}") # Fallback to original order return [ { 'document': doc, 'score': doc.metadata.get('relevance_score', 0.5), 'rank': i + 1 } for i, doc in enumerate(documents[:top_k]) ]