Spaces:
Sleeping
Sleeping
| """Cross-encoder reranker for document retrieval.""" | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import logging | |
| from langchain.schema import Document | |
| logger = logging.getLogger(__name__) | |
| # Lazy import to avoid loading model at import time | |
| _cross_encoder = None | |
| _cross_encoder_model_name = None | |
| def _get_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): | |
| """Lazy load the cross-encoder model. | |
| Args: | |
| model_name: HuggingFace model identifier | |
| Returns: | |
| CrossEncoder instance | |
| """ | |
| global _cross_encoder, _cross_encoder_model_name | |
| if _cross_encoder is None or _cross_encoder_model_name != model_name: | |
| try: | |
| from sentence_transformers import CrossEncoder | |
| logger.info(f"Loading cross-encoder model: {model_name}") | |
| _cross_encoder = CrossEncoder(model_name, max_length=512) | |
| _cross_encoder_model_name = model_name | |
| except ImportError: | |
| logger.warning( | |
| "sentence-transformers not installed. " | |
| "Run: pip install sentence-transformers" | |
| ) | |
| return None | |
| except Exception as e: | |
| logger.warning(f"Failed to load cross-encoder: {e}") | |
| return None | |
| return _cross_encoder | |
class FastCrossEncoderReranker:
    """Cross-encoder reranker using sentence-transformers.

    Runs locally and is faster than LLM-based reranking. The model is loaded
    lazily on first use rather than in ``__init__``.
    """

    # Short aliases accepted for ``model_name``; any other value is treated
    # as a HuggingFace model ID and used unchanged.
    MODEL_OPTIONS = {
        "fast": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "balanced": "cross-encoder/ms-marco-MiniLM-L-12-v2",
        "tiny": "cross-encoder/ms-marco-TinyBERT-L-2-v2",
    }

    def __init__(
        self,
        model_name: str = "fast",
        max_length: int = 512,
        batch_size: int = 16,
    ) -> None:
        """Initialize cross-encoder reranker.

        Args:
            model_name: One of "fast", "balanced", "tiny", or a HuggingFace model ID.
            max_length: Maximum number of *characters* of each document's
                text sent to the model (the text is sliced to this length
                before scoring).
            batch_size: Batch size for scoring (higher = faster but more memory).
        """
        self.model_name = self.MODEL_OPTIONS.get(model_name, model_name)
        self.max_length = max_length
        self.batch_size = batch_size
        self._model = None  # filled in by _ensure_model()

    def _ensure_model(self) -> bool:
        """Load the model if it is not loaded yet.

        Returns:
            True if a model is available, False otherwise. A failed load is
            not remembered, so the next call retries.
        """
        if self._model is None:
            self._model = _get_cross_encoder(self.model_name)
        return self._model is not None

    def _score(
        self,
        query: str,
        documents: List[Document],
    ) -> List[Tuple[Document, float]]:
        """Score every document against ``query``; return pairs sorted best-first.

        Exceptions from the model's ``predict`` propagate to the caller.
        """
        pairs = [
            (query, self._get_text(doc)[: self.max_length])
            for doc in documents
        ]
        raw_scores = self._model.predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
        )
        # predict() returns a numpy array by default; convert each score so
        # callers receive plain Python floats as annotated.
        scored = [(doc, float(s)) for doc, s in zip(documents, raw_scores)]
        # list.sort is stable (also with reverse=True): ties keep input order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored

    @staticmethod
    def _fallback_scores(
        documents: List[Document],
        top_k: int,
    ) -> List[Tuple[Document, float]]:
        """Keep the original order with placeholder scores 1.0, 0.9, 0.8, ...

        The scores carry no relevance information and go negative past the
        eleventh document.
        """
        return [(doc, 1.0 - i * 0.1) for i, doc in enumerate(documents[:top_k])]

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Document]:
        """Rerank documents by relevance to query.

        Args:
            query: User query.
            documents: Documents to rerank.
            top_k: Number of top documents to return.

        Returns:
            A new list with at most ``top_k`` documents, most relevant first.
            If the model is unavailable or scoring fails, the first ``top_k``
            documents are returned in their original order.
        """
        if not documents:
            return []
        if len(documents) <= 1:
            # Nothing to reorder, but still honour top_k and return a new list.
            return documents[:top_k]
        if not self._ensure_model():
            logger.warning("Cross-encoder not available, returning original order")
            return documents[:top_k]
        try:
            scored = self._score(query, documents)
        except Exception as e:
            logger.warning("Reranking failed: %s, returning original order", e)
            return documents[:top_k]
        return [doc for doc, _ in scored[:top_k]]

    def rerank_with_scores(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Tuple[Document, float]]:
        """Rerank documents and return them with their scores.

        Args:
            query: User query.
            documents: Documents to rerank.
            top_k: Number of top documents to return.

        Returns:
            At most ``top_k`` ``(document, score)`` tuples sorted by score
            descending. A single document gets score 1.0. If the model is
            unavailable or scoring fails, the original order is kept with
            placeholder scores (1.0, 0.9, ...).
        """
        if not documents:
            return []
        if len(documents) <= 1:
            return [(doc, 1.0) for doc in documents[:top_k]]
        if not self._ensure_model():
            return self._fallback_scores(documents, top_k)
        try:
            scored = self._score(query, documents)
        except Exception as e:
            logger.warning("Reranking failed: %s", e)
            return self._fallback_scores(documents, top_k)
        return scored[:top_k]

    def _get_text(self, doc: Document) -> str:
        """Return ``doc.page_content`` if the attribute exists, else ``str(doc)``."""
        if hasattr(doc, "page_content"):
            return doc.page_content
        return str(doc)
class NoOpReranker:
    """Pass-through reranker that keeps the incoming order and only applies ``top_k``.

    Serves as a drop-in fallback when no cross-encoder can be loaded.
    """

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Document]:
        """Return the first ``top_k`` documents unchanged; ``query`` is ignored."""
        kept = documents[:top_k]
        return kept

    def rerank_with_scores(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Tuple[Document, float]]:
        """Pair the first ``top_k`` documents with descending placeholder scores.

        Scores start at 1.0 and drop by 0.05 per position. They carry no
        relevance information.
        """
        step = 0.05
        results = []
        for position, doc in enumerate(documents[:top_k]):
            results.append((doc, 1.0 - position * step))
        return results
def get_reranker(
    model_name: str = "fast",
    fallback_to_noop: bool = True,
) -> FastCrossEncoderReranker | NoOpReranker:
    """Build a reranker, preferring the cross-encoder.

    The cross-encoder model is loaded here, so a missing dependency or a
    failed load is detected now rather than on the first ``rerank`` call.

    Args:
        model_name: Model alias ("fast", "balanced", "tiny") or HuggingFace model ID.
        fallback_to_noop: If True, return a ``NoOpReranker`` when the
            cross-encoder cannot be created or loaded.

    Returns:
        A ``FastCrossEncoderReranker`` whose model has been loaded, or a
        ``NoOpReranker`` when falling back.

    Raises:
        RuntimeError: If the cross-encoder is unavailable and
            ``fallback_to_noop`` is False. When construction raised, the
            original exception is attached as ``__cause__``.
    """
    error: Optional[Exception] = None
    try:
        reranker = FastCrossEncoderReranker(model_name)
        # Load eagerly to test that the model is actually usable.
        if reranker._ensure_model():
            return reranker
    except Exception as e:
        error = e
        logger.warning("Failed to create cross-encoder reranker: %s", e)

    if fallback_to_noop:
        logger.info("Using no-op reranker as fallback")
        return NoOpReranker()
    raise RuntimeError("Cross-encoder reranker not available") from error