"""Cross-encoder reranker for document retrieval."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
import logging
from langchain.schema import Document
logger = logging.getLogger(__name__)
# Lazy import to avoid loading model at import time
_cross_encoder = None
_cross_encoder_model_name = None
def _get_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
"""Lazy load the cross-encoder model.
Args:
model_name: HuggingFace model identifier
Returns:
CrossEncoder instance
"""
global _cross_encoder, _cross_encoder_model_name
if _cross_encoder is None or _cross_encoder_model_name != model_name:
try:
from sentence_transformers import CrossEncoder
logger.info(f"Loading cross-encoder model: {model_name}")
_cross_encoder = CrossEncoder(model_name, max_length=512)
_cross_encoder_model_name = model_name
except ImportError:
logger.warning(
"sentence-transformers not installed. "
"Run: pip install sentence-transformers"
)
return None
except Exception as e:
logger.warning(f"Failed to load cross-encoder: {e}")
return None
return _cross_encoder
class FastCrossEncoderReranker:
    """Cross-encoder reranker using sentence-transformers.

    Scores (query, document) pairs with a locally-run cross-encoder model,
    which is faster than LLM-based reranking. Falls back to the original
    document order when the model is unavailable or scoring fails.
    """

    # Aliases for common ms-marco cross-encoder checkpoints.
    MODEL_OPTIONS = {
        "fast": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "balanced": "cross-encoder/ms-marco-MiniLM-L-12-v2",
        "tiny": "cross-encoder/ms-marco-TinyBERT-L-2-v2",
    }

    def __init__(
        self,
        model_name: str = "fast",
        max_length: int = 512,
        batch_size: int = 16,
    ) -> None:
        """Initialize cross-encoder reranker.

        Args:
            model_name: One of "fast", "balanced", "tiny", or a HuggingFace model ID
            max_length: Maximum sequence length for encoding
            batch_size: Batch size for scoring (higher = faster but more memory)
        """
        # Resolve alias; unknown names pass through as HuggingFace model IDs.
        self.model_name = self.MODEL_OPTIONS.get(model_name, model_name)
        self.max_length = max_length
        self.batch_size = batch_size
        self._model = None  # loaded lazily by _ensure_model()

    def _ensure_model(self) -> bool:
        """Lazily load the model.

        Returns:
            True if the model is available, False otherwise.
        """
        if self._model is None:
            self._model = _get_cross_encoder(self.model_name)
        return self._model is not None

    def _score_pairs(
        self,
        query: str,
        documents: List[Document],
    ) -> List[Tuple[Document, float]]:
        """Score every document against the query, sorted best-first.

        Shared core of rerank() and rerank_with_scores(). Propagates any
        exception from the underlying model; callers handle the fallback.
        """
        # NOTE(review): this truncates by *characters* while max_length is a
        # token limit in the model -- may trim more text than necessary.
        pairs = [
            (query, self._get_text(doc)[:self.max_length])
            for doc in documents
        ]
        # Score all pairs (batched for efficiency).
        scores = self._model.predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
        )
        # Coerce scores to plain float so callers get stdlib types rather
        # than numpy scalars, then sort by score descending.
        return sorted(
            ((doc, float(score)) for doc, score in zip(documents, scores)),
            key=lambda pair: pair[1],
            reverse=True,
        )

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Document]:
        """Rerank documents by relevance to query.

        Args:
            query: User query
            documents: Documents to rerank
            top_k: Number of top documents to return

        Returns:
            Reranked documents (most relevant first); original order on failure.
        """
        # Delegate to the scored variant and drop the scores.
        return [doc for doc, _ in self.rerank_with_scores(query, documents, top_k)]

    def rerank_with_scores(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Tuple[Document, float]]:
        """Rerank documents and return with scores.

        Args:
            query: User query
            documents: Documents to rerank
            top_k: Number of top documents to return

        Returns:
            List of (document, score) tuples, sorted by score descending.
            When the model is unavailable or fails, returns the original
            order with synthetic decaying scores.
        """
        if not documents:
            return []
        if len(documents) <= 1:
            # Nothing to rank; give the lone document a perfect score.
            return [(doc, 1.0) for doc in documents]
        if not self._ensure_model():
            logger.warning("Cross-encoder not available, returning original order")
            return [(doc, 1.0 - i * 0.1) for i, doc in enumerate(documents[:top_k])]
        try:
            return self._score_pairs(query, documents)[:top_k]
        except Exception as e:
            logger.warning(f"Reranking failed: {e}, returning original order")
            return [(doc, 1.0 - i * 0.1) for i, doc in enumerate(documents[:top_k])]

    def _get_text(self, doc: Document) -> str:
        """Extract text content from document.

        Args:
            doc: LangChain Document

        Returns:
            Text content (falls back to str() for non-Document inputs).
        """
        if hasattr(doc, 'page_content'):
            return doc.page_content
        return str(doc)
class NoOpReranker:
    """Pass-through reranker that keeps the incoming document order.

    Drop-in fallback for when the cross-encoder model is not available.
    """

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Document]:
        """Return the first top_k documents unchanged."""
        return documents[:top_k]

    def rerank_with_scores(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Tuple[Document, float]]:
        """Pair the first top_k documents with synthetic, slowly decaying scores."""
        scored = []
        for rank, doc in enumerate(documents[:top_k]):
            scored.append((doc, 1.0 - rank * 0.05))
        return scored
def get_reranker(
    model_name: str = "fast",
    fallback_to_noop: bool = True,
) -> FastCrossEncoderReranker | NoOpReranker:
    """Factory function to get a reranker instance.

    Args:
        model_name: Model name or alias
        fallback_to_noop: If True, return NoOpReranker when cross-encoder fails

    Returns:
        A FastCrossEncoderReranker when the model loads, otherwise a
        NoOpReranker (when fallback_to_noop is True).

    Raises:
        RuntimeError: If the model cannot be loaded and fallback_to_noop is False.
    """
    try:
        reranker = FastCrossEncoderReranker(model_name)
        # Eagerly load the model so failures surface here, not on first use.
        if reranker._ensure_model():
            return reranker
    except Exception as e:
        logger.warning(f"Failed to create cross-encoder reranker: {e}")
    if fallback_to_noop:
        logger.info("Using no-op reranker as fallback")
        return NoOpReranker()
    raise RuntimeError("Cross-encoder reranker not available")