"""
Reranking module for improving retrieval precision.
Uses cross-encoder models to reorder initial retrieval results
based on query-document relevance.
"""
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
# Model cache for lazy loading
_reranker_model = None
_reranker_model_name = None
@dataclass
class RerankResult:
"""Result from reranking operation."""
chunks: List[Dict[str, Any]]
model_used: str
reranked: bool
def _get_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
"""
Lazy load and cache cross-encoder model.
Args:
model_name: HuggingFace model name for cross-encoder
Returns:
CrossEncoder model instance
"""
global _reranker_model, _reranker_model_name
if _reranker_model is None or _reranker_model_name != model_name:
try:
from sentence_transformers import CrossEncoder
_reranker_model = CrossEncoder(model_name)
_reranker_model_name = model_name
except ImportError:
raise ImportError(
"sentence-transformers not installed. "
"Install with: pip install sentence-transformers"
)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load cross-encoder model '{model_name}': {e}"
            ) from e
return _reranker_model
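
# Note on caching: the cache above is keyed on the model name, so asking for a
# different cross-encoder replaces the cached instance rather than keeping both.
# Illustrative sketch only (the second checkpoint name is an example, not used
# elsewhere in this module):
#
#   _get_cross_encoder()                                          # loads the default MiniLM L-6 model
#   _get_cross_encoder("cross-encoder/ms-marco-MiniLM-L-12-v2")   # reloads the cache with a larger checkpoint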
def rerank_chunks(
query: str,
chunks: List[Dict[str, Any]],
top_k: int = 5,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
) -> RerankResult:
"""
Rerank chunks using a cross-encoder model.
Cross-encoders process query and document together, enabling
more nuanced relevance scoring than bi-encoders.
Args:
query: User query
chunks: List of chunks to rerank
top_k: Number of top results to return
model_name: Cross-encoder model to use
Returns:
RerankResult with reranked chunks
"""
if not chunks:
return RerankResult(chunks=[], model_used="none", reranked=False)
if len(chunks) <= 1:
return RerankResult(chunks=chunks, model_used="none", reranked=False)
try:
model = _get_cross_encoder(model_name)
# Prepare query-document pairs
pairs = []
for chunk in chunks:
text = chunk.get("text", "")
if not text and "metadata" in chunk:
text = chunk["metadata"].get("text", "")
pairs.append((query, text))
# Get relevance scores
scores = model.predict(pairs)
# Combine chunks with scores and sort
scored_chunks = list(zip(chunks, scores))
scored_chunks.sort(key=lambda x: x[1], reverse=True)
# Build result with rerank scores
results = []
for chunk, score in scored_chunks[:top_k]:
result_chunk = chunk.copy()
result_chunk["rerank_score"] = float(score)
results.append(result_chunk)
return RerankResult(
chunks=results,
model_used=model_name,
reranked=True
)
except Exception as e:
# Fallback: return original chunks without reranking
return RerankResult(
chunks=chunks[:top_k],
model_used=f"fallback (error: {str(e)[:50]})",
reranked=False
)
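
# Illustrative usage of rerank_chunks (a sketch, not executed on import; the literal
# chunks below are made up and only mirror the {"text": ...} shape this module expects):
#
#   docs = [
#       {"text": "Cross-encoders score the query and document jointly."},
#       {"text": "Bi-encoders embed query and document separately."},
#   ]
#   result = rerank_chunks("how do cross-encoders score relevance?", docs, top_k=1)
#   # result.chunks[0]["rerank_score"] holds the cross-encoder score of the top match;
#   # result.reranked is False (and order unchanged) if the model could not be loaded.
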
def rerank_with_llm(
query: str,
chunks: List[Dict[str, Any]],
top_k: int = 5
) -> RerankResult:
"""
Rerank chunks using LLM-based scoring (fallback method).
    Slower and costlier per call than a cross-encoder, but requires no additional
    model download beyond the configured LLM provider.
Args:
query: User query
chunks: List of chunks to rerank
top_k: Number of top results to return
Returns:
RerankResult with reranked chunks
"""
if not chunks:
return RerankResult(chunks=[], model_used="none", reranked=False)
if len(chunks) <= 1:
return RerankResult(chunks=chunks, model_used="none", reranked=False)
try:
from src.llm_providers import call_llm
# Build scoring prompt
chunk_texts = []
        for i, chunk in enumerate(chunks):
            # Match rerank_chunks: fall back to metadata text, then truncate for prompt size
            text = chunk.get("text", "") or chunk.get("metadata", {}).get("text", "")
            chunk_texts.append(f"[{i}] {text[:500]}")
prompt = f"""Rate the relevance of each document to the query on a scale of 0-10.
Return ONLY a comma-separated list of scores in order (e.g., "8,3,7,5").
Query: {query}
Documents:
{chr(10).join(chunk_texts)}
Scores:"""
response = call_llm(prompt=prompt, temperature=0.0, max_tokens=100)
scores_text = response.get("text", "").strip()
# Parse scores
try:
scores = [float(s.strip()) for s in scores_text.split(",")]
if len(scores) != len(chunks):
raise ValueError("Score count mismatch")
except (ValueError, AttributeError):
# Fallback to original order
return RerankResult(
chunks=chunks[:top_k],
model_used="llm_parse_failed",
reranked=False
)
# Sort by scores
scored_chunks = list(zip(chunks, scores))
scored_chunks.sort(key=lambda x: x[1], reverse=True)
results = []
for chunk, score in scored_chunks[:top_k]:
result_chunk = chunk.copy()
result_chunk["rerank_score"] = float(score)
results.append(result_chunk)
return RerankResult(
chunks=results,
model_used="llm",
reranked=True
)
except Exception as e:
return RerankResult(
chunks=chunks[:top_k],
model_used=f"llm_fallback (error: {str(e)[:50]})",
reranked=False
)
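

if __name__ == "__main__":
    # Minimal smoke test (a sketch): needs sentence-transformers installed and a
    # network connection for the first model download; if loading fails, rerank_chunks
    # falls back and returns the chunks in their original order. rerank_with_llm is
    # not exercised here because it depends on the project's src.llm_providers setup.
    sample_chunks = [
        {"text": "The reranker reorders retrieved chunks by query relevance."},
        {"text": "Zero-storage privacy means documents are never persisted to disk."},
        {"text": "Cross-encoders are slower but usually more precise than bi-encoders."},
    ]
    demo = rerank_chunks("why use a cross-encoder for reranking?", sample_chunks, top_k=2)
    print(f"model_used={demo.model_used} reranked={demo.reranked}")
    for c in demo.chunks:
        print(f"  {c.get('rerank_score', 'n/a')}: {c['text'][:60]}")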