# book-rec-with-LLMs / src/core/rag/reranker.py
# Author: ymlin105 — "chore: remove obsolete files and update project structure" (commit 6ad997d)
"""
Reranker: Cross-Encoder (torch/ONNX) or ColBERT (optional).
Backend selectable via RERANKER_BACKEND env: cross_encoder | onnx | colbert.
ONNX ~2x faster than torch; ColBERT requires llama-index-postprocessor-colbert-rerank.
"""
from typing import List, Any
from src.config import RERANKER_BACKEND
from src.utils import setup_logger
# Module-level logger, created via the project's shared setup helper.
logger = setup_logger(__name__)
# Default cross-encoder checkpoint used by every cross-encoder/ONNX backend load.
DEFAULT_RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
def _load_cross_encoder(backend: str):
    """Build a CrossEncoder reranker on MPS when available, else CPU.

    Args:
        backend: "onnx" selects the ONNX runtime; any other value means torch.

    Returns:
        A loaded ``sentence_transformers.CrossEncoder`` instance. If the ONNX
        backend fails to load (e.g. onnxruntime not installed), a torch-backed
        model is returned instead; failures of the torch backend propagate.
    """
    import torch
    from sentence_transformers import CrossEncoder

    # Prefer Apple-silicon GPU when present; otherwise stay on CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    chosen = "onnx" if backend == "onnx" else "torch"
    try:
        logger.info(
            "Loading Reranker (%s) backend=%s on %s...",
            DEFAULT_RERANKER_MODEL,
            chosen,
            device,
        )
        model = CrossEncoder(DEFAULT_RERANKER_MODEL, device=device, backend=chosen)
    except Exception as exc:
        # Only the ONNX path gets a graceful degradation; torch errors are fatal.
        if chosen != "onnx":
            raise
        logger.warning(
            "ONNX backend failed (pip install onnxruntime?), falling back to torch: %s",
            exc,
        )
        return CrossEncoder(
            DEFAULT_RERANKER_MODEL, device=device, backend="torch"
        )
    logger.info("Reranker model loaded.")
    return model
def _load_colbert():
    """Instantiate the optional ColBERT reranker.

    Returns:
        A ``ColbertRerank`` postprocessor, or ``None`` when the optional
        ``llama-index-postprocessor-colbert-rerank`` package is not installed.
    """
    try:
        from llama_index.postprocessor.colbert_rerank import ColbertRerank

        # top_n bounds how many nodes the postprocessor keeps per query.
        return ColbertRerank(model_name="colbert-ir/colbertv2.0", top_n=10)
    except ImportError as exc:
        logger.warning(
            "ColBERT not available (pip install llama-index-postprocessor-colbert-rerank): %s",
            exc,
        )
        return None
def _get_text(doc: Any) -> str:
if hasattr(doc, "page_content"):
return doc.page_content
return doc.get("description") or doc.get("page_content") or str(doc)
def _set_score(doc: Any, score: float) -> None:
if hasattr(doc, "metadata"):
doc.metadata["relevance_score"] = score
else:
doc["score"] = score
def _get_score(doc: Any) -> float:
if hasattr(doc, "metadata"):
return doc.metadata.get("relevance_score", 0)
return doc.get("score", 0)
class RerankerService:
    """
    Singleton reranker: Cross-Encoder (torch/ONNX) or ColBERT.

    The backend is resolved once, at first instantiation, from the
    RERANKER_BACKEND config value:
      - "colbert": ColBERT via llama-index (degrades to a torch
        CrossEncoder if the optional dependency is missing),
      - "onnx": CrossEncoder on the ONNX runtime (degrades to torch on
        load failure, see _load_cross_encoder),
      - anything else: CrossEncoder on torch.
    """

    _instance = None  # class-level holder for the single shared instance

    def __new__(cls):
        # Lazily create the singleton. Attributes are seeded here so that
        # __init__ (which runs on *every* RerankerService() call) can detect
        # "model not loaded yet" via `model is None`.
        if cls._instance is None:
            cls._instance = super(RerankerService, cls).__new__(cls)
            cls._instance.model = None
            cls._instance._backend = None
        return cls._instance

    def __init__(self):
        if self.model is None:
            self._load_model()

    def _load_model(self):
        """Resolve RERANKER_BACKEND and load the matching model into self.model."""
        backend = (RERANKER_BACKEND or "").lower()
        if backend == "colbert":
            self.model = _load_colbert()
            self._backend = "colbert" if self.model else "cross_encoder"
            if self._backend == "cross_encoder":
                # Optional ColBERT dependency missing: fall back to torch.
                self.model = _load_cross_encoder("torch")
        else:
            self._backend = "onnx" if backend == "onnx" else "cross_encoder"
            self.model = _load_cross_encoder(self._backend)

    def rerank(self, query: str, docs: List[Any], top_k: int = 5) -> List[Any]:
        """
        Rerank documents by relevance to query.
        docs: List of dicts or LangChain Document with description/page_content.

        Returns at most top_k documents, most relevant first. If no model is
        loaded or docs is empty, the first top_k docs are returned unchanged.
        """
        if not self.model or not docs:
            return docs[:top_k]
        if self._backend == "colbert":
            return self._rerank_colbert(query, docs, top_k)
        return self._rerank_cross_encoder(query, docs, top_k)

    def _rerank_cross_encoder(
        self, query: str, docs: List[Any], top_k: int
    ) -> List[Any]:
        """Score (query, text) pairs with the CrossEncoder and sort by score.

        NOTE: scores are written onto the docs and the *caller's* list is
        sorted in place (existing contract, preserved).
        """
        pairs = [[query, _get_text(d)] for d in docs]
        scores = self.model.predict(pairs)
        for i, doc in enumerate(docs):
            _set_score(doc, float(scores[i]))
        docs.sort(key=_get_score, reverse=True)
        return docs[:top_k]

    def _rerank_colbert(
        self, query: str, docs: List[Any], top_k: int
    ) -> List[Any]:
        """Rerank via the ColBERT postprocessor, returning the original doc objects.

        Fix: the original imported `llama_index.schema`, the pre-0.10 path,
        while _load_colbert uses the 0.10+ namespaced colbert_rerank package;
        on modern installs that import fails. Try `llama_index.core.schema`
        first and keep the legacy path as a fallback for old installs.
        """
        try:
            from llama_index.core.schema import NodeWithScore, TextNode
        except ImportError:
            from llama_index.schema import NodeWithScore, TextNode
        # Keep ref to original doc for metadata (isbn, etc.)
        nodes = []
        for d in docs:
            node = TextNode(text=_get_text(d), metadata={"__original": d})
            nodes.append(NodeWithScore(node=node, score=0.0))
        reranked = self.model.postprocess_nodes(nodes, query_str=query)
        result = []
        for nws in reranked[:top_k]:
            orig = getattr(nws.node, "metadata", {}).get("__original")
            if orig is not None:
                _set_score(orig, float(nws.score or 0))
                result.append(orig)
            else:
                # Original doc reference lost (defensive path): wrap the node
                # text in a fresh LangChain Document carrying the score.
                from langchain_core.documents import Document

                doc = Document(
                    page_content=nws.node.text,
                    metadata={"relevance_score": float(nws.score or 0)},
                )
                result.append(doc)
        return result
# Eagerly-built module-level singleton: importing this module triggers model
# loading via RerankerService.__init__ (an intentional import-time side effect).
reranker = RerankerService()
__all__ = ["RerankerService", "reranker"]