| """ | |
| MEXAR - Cross-Encoder Reranking Module | |
| Improves retrieval precision by reranking candidates with a cross-encoder model. | |
| """ | |
| import logging | |
| from typing import List, Tuple, Any | |
| logger = logging.getLogger(__name__) | |
| # Lazy load to avoid slow import on startup | |
| _reranker_model = None | |
| def _get_reranker(): | |
| """Lazy load the cross-encoder model.""" | |
| global _reranker_model | |
| if _reranker_model is None: | |
| try: | |
| from sentence_transformers import CrossEncoder | |
| _reranker_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') | |
| logger.info("Cross-encoder reranker loaded successfully") | |
| except ImportError: | |
| logger.warning("sentence-transformers not installed. Install with: pip install sentence-transformers") | |
| _reranker_model = False | |
| except Exception as e: | |
| logger.warning(f"Failed to load cross-encoder: {e}") | |
| _reranker_model = False | |
| return _reranker_model | |
class Reranker:
    """
    Cross-encoder reranking for improved retrieval precision.

    Cross-encoders are more accurate than bi-encoders because they
    process query and document together, capturing fine-grained interactions.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Initialize reranker.

        Args:
            model_name: HuggingFace model name for cross-encoder
        """
        self.model_name = model_name
        self._model = None

    @property
    def model(self):
        """Lazy load model on first use."""
        if self._model is None:
            self._model = _get_reranker()
        return self._model
    def rerank(
        self,
        query: str,
        chunks: List[Any],
        top_k: int = 5
    ) -> List[Tuple[Any, float]]:
        """
        Rerank chunks using cross-encoder.

        Args:
            query: User's query
            chunks: List of DocumentChunk objects
            top_k: Number of top results to return

        Returns:
            List of (chunk, score) tuples, sorted by relevance
        """
        if not chunks:
            return []

        if not self.model:
            # Fallback: return chunks with placeholder scores
            logger.warning("Reranker not available, using original order")
            return [(chunk, 0.5) for chunk in chunks[:top_k]]

        try:
            # Create query-document pairs
            # Truncate content to avoid memory issues
            pairs = [[query, self._get_content(chunk)[:512]] for chunk in chunks]

            # Get cross-encoder scores
            scores = self.model.predict(pairs)

            # Combine chunks with scores
            chunk_scores = list(zip(chunks, scores))

            # Sort by score descending
            ranked = sorted(chunk_scores, key=lambda x: x[1], reverse=True)

            logger.info(f"Reranked {len(chunks)} chunks, returning top {top_k}")
            return ranked[:top_k]

        except Exception as e:
            logger.error(f"Reranking failed: {e}")
            return [(chunk, 0.5) for chunk in chunks[:top_k]]

    def _get_content(self, chunk) -> str:
        """Extract content from chunk object."""
        if hasattr(chunk, 'content'):
            return chunk.content
        elif isinstance(chunk, dict):
            return chunk.get('content', '')
        else:
            return str(chunk)
def create_reranker() -> Reranker:
    """Factory function to create Reranker."""
    return Reranker()
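

# --- Usage sketch ---
# A minimal example of how this module is intended to be used, not part of the
# original API surface. The chunk structure below (plain dicts with a 'content'
# key) is an assumption for illustration; in MEXAR the chunks would typically be
# DocumentChunk objects returned by the first-stage retriever. If
# sentence-transformers is not installed, the reranker falls back to the
# original order with placeholder scores, as implemented above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    candidate_chunks = [
        {"content": "Cross-encoders score a query and a document jointly."},
        {"content": "Bi-encoders embed the query and document independently."},
        {"content": "Reranking is applied to the top candidates from retrieval."},
    ]

    reranker = create_reranker()
    results = reranker.rerank("How do cross-encoders work?", candidate_chunks, top_k=2)
    for chunk, score in results:
        print(f"{score:.3f}  {chunk['content']}")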