|
|
""" |
|
|
Reranker Agent |
|
|
|
|
|
Cross-encoder based reranking for improved retrieval precision. |
|
|
Implements production-grade retrieval techniques: LLM-based reranking, quality filtering, and deduplication of retrieved chunks.
|
|
|
|
|
Key Features: |
|
|
- LLM-based cross-encoder reranking |
|
|
- Relevance scoring with explanations |
|
|
- Diversity promotion to avoid redundancy |
|
|
- Quality filtering (removes low-quality chunks) |
|
|
- Chunk deduplication |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Dict, Any, Tuple |
|
|
from pydantic import BaseModel, Field |
|
|
from loguru import logger |
|
|
from dataclasses import dataclass |
|
|
import json |
|
|
import re |
|
|
from difflib import SequenceMatcher |
|
|
|
|
|
try: |
|
|
import httpx |
|
|
HTTPX_AVAILABLE = True |
|
|
except ImportError: |
|
|
HTTPX_AVAILABLE = False |
|
|
|
|
|
from .retriever import RetrievalResult |
|
|
|
|
|
|
|
|
class RerankerConfig(BaseModel):
    """Configuration for reranking.

    Groups the LLM backend settings, result-selection thresholds, and
    diversity/deduplication knobs used by ``RerankerAgent``.
    """

    # --- LLM backend (default port/model suggest an Ollama server — confirm) ---
    model: str = Field(default="llama3.2:3b")
    base_url: str = Field(default="http://localhost:11434")
    # Low temperature keeps relevance scoring near-deterministic.
    # Bound added: sampling temperature cannot be negative.
    temperature: float = Field(default=0.1, ge=0.0)

    # --- Result selection ---
    top_k: int = Field(default=5, ge=1)
    # Results whose relevance_score falls below this are dropped.
    min_relevance_score: float = Field(default=0.3, ge=0.0, le=1.0)

    # --- Diversity (MMR-style selection) ---
    enable_diversity: bool = Field(default=True)
    # Bounds added for consistency with min_relevance_score: this is a
    # SequenceMatcher ratio and is therefore always within [0, 1].
    diversity_threshold: float = Field(default=0.8, ge=0.0, le=1.0, description="Max similarity between chunks")

    # --- Deduplication ---
    # Bounds added for the same reason: compared against a [0, 1] similarity ratio.
    dedup_threshold: float = Field(default=0.9, ge=0.0, le=1.0, description="Similarity threshold for dedup")

    # When False (or when httpx is unavailable), heuristic reranking is used.
    use_llm_rerank: bool = Field(default=True)
|
|
|
|
|
|
|
|
class RankedResult(BaseModel):
    """A reranked result with relevance score.

    Carries a retrieval chunk together with the reranker's relevance
    judgement and the combined score used for final ordering.
    """

    # Identity of the chunk and its parent document.
    chunk_id: str
    document_id: str
    # Raw chunk text (also used for similarity/diversity comparisons).
    text: str

    # Score from the upstream retriever, before reranking.
    original_score: float
    # Reranker's relevance estimate (LLM score mapped to 0-1, or the
    # heuristic blend; on LLM failure this falls back to the original score).
    relevance_score: float
    # Weighted combination of original_score and relevance_score;
    # results are ordered by this value.
    final_score: float
    # Optional human-readable reason produced by the LLM scorer.
    relevance_explanation: Optional[str] = None

    # Provenance / presentation metadata carried over from retrieval.
    page: Optional[int] = None
    chunk_type: Optional[str] = None
    source_path: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Bounding box of the chunk on the page, when available.
    bbox: Optional[Dict[str, float]] = None
|
|
|
|
|
|
|
|
class RerankerAgent:
    """
    Reranks retrieval results for improved precision.

    Pipeline (see ``rerank``): deduplicate -> relevance scoring (LLM
    cross-encoder when enabled and httpx is available, heuristic
    otherwise) -> quality filter -> optional MMR-style diversity pass.

    Capabilities:
    1. Cross-encoder relevance scoring
    2. Diversity-aware reranking (MMR-style)
    3. Quality filtering
    4. Chunk deduplication
    """

    # Prompt sent (with query/passage substituted via str.format) to the LLM.
    # Double braces escape the literal JSON braces for .format().
    RERANK_PROMPT = """Score the relevance of this text passage to the given query.

Query: {query}

Passage: {passage}

Score the relevance on a scale of 0-10 where:
- 0-2: Completely irrelevant, no useful information
- 3-4: Marginally relevant, tangentially related
- 5-6: Somewhat relevant, contains some useful information
- 7-8: Highly relevant, directly addresses the query
- 9-10: Perfectly relevant, comprehensive answer to query

Respond with ONLY a JSON object:
{{"score": <number>, "explanation": "<brief reason>"}}"""

    def __init__(self, config: Optional[RerankerConfig] = None):
        """
        Initialize Reranker Agent.

        Args:
            config: Reranker configuration; defaults to ``RerankerConfig()``.
        """
        self.config = config or RerankerConfig()
        logger.info(f"RerankerAgent initialized (model={self.config.model})")

    def rerank(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: Optional[int] = None,
    ) -> List[RankedResult]:
        """
        Rerank retrieval results by relevance to query.

        Args:
            query: Original search query
            results: Retrieval results to rerank
            top_k: Number of results to return; defaults to ``config.top_k``

        Returns:
            At most ``top_k`` reranked results with relevance scores.
            May return fewer (even zero) results when chunks are filtered
            out by ``min_relevance_score``.
        """
        if not results:
            return []

        top_k = top_k or self.config.top_k

        # Step 1: drop near-duplicate chunks before spending LLM calls on them.
        deduped = self._deduplicate(results)

        # Step 2: score each surviving chunk against the query.
        if self.config.use_llm_rerank and HTTPX_AVAILABLE:
            scored = self._llm_rerank(query, deduped)
        else:
            scored = self._heuristic_rerank(query, deduped)

        # Step 3: quality filter — remove low-relevance chunks.
        filtered = [
            r for r in scored
            if r.relevance_score >= self.config.min_relevance_score
        ]

        # Step 4: pick the final top_k, optionally trading some relevance
        # for diversity (MMR); otherwise a plain score sort.
        if self.config.enable_diversity:
            diverse = self._promote_diversity(filtered, top_k)
        else:
            diverse = sorted(filtered, key=lambda x: x.final_score, reverse=True)[:top_k]

        return diverse

    def _deduplicate(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
        """Remove near-duplicate chunks.

        Keeps the first occurrence in input order; a chunk is dropped when
        its similarity to any already-kept chunk exceeds
        ``config.dedup_threshold``. O(n^2) pairwise comparisons — assumed
        acceptable for typical retrieval result sizes.
        """
        if not results:
            return []

        # The first result is always kept as the seed.
        deduped = [results[0]]

        for result in results[1:]:
            is_dup = False
            for existing in deduped:
                similarity = self._text_similarity(result.text, existing.text)
                if similarity > self.config.dedup_threshold:
                    is_dup = True
                    break

            if not is_dup:
                deduped.append(result)

        if len(results) != len(deduped):
            logger.debug(f"Deduplication: {len(results)} -> {len(deduped)} chunks")

        return deduped

    def _text_similarity(self, text1: str, text2: str) -> float:
        """Compute case-insensitive text similarity in [0, 1] using SequenceMatcher."""
        return SequenceMatcher(None, text1.lower(), text2.lower()).ratio()

    def _llm_rerank(
        self,
        query: str,
        results: List[RetrievalResult],
    ) -> List[RankedResult]:
        """Use LLM for cross-encoder style reranking.

        Scores each passage individually (one LLM call per result). On any
        scoring failure the result is kept with its original retriever
        score, so a flaky LLM backend degrades gracefully instead of
        dropping results.
        """
        ranked = []

        for result in results:
            try:
                # LLM returns a 0-10 score plus a brief explanation.
                relevance_score, explanation = self._score_passage(query, result.text)

                # Blend: 30% original retriever score, 70% LLM relevance
                # (normalized from the 0-10 scale to 0-1).
                final_score = 0.3 * result.score + 0.7 * (relevance_score / 10.0)

                ranked.append(RankedResult(
                    chunk_id=result.chunk_id,
                    document_id=result.document_id,
                    text=result.text,
                    original_score=result.score,
                    relevance_score=relevance_score / 10.0,
                    final_score=final_score,
                    relevance_explanation=explanation,
                    page=result.page,
                    chunk_type=result.chunk_type,
                    source_path=result.source_path,
                    metadata=result.metadata,
                    bbox=result.bbox,
                ))

            except Exception as e:
                logger.warning(f"Failed to score passage: {e}")

                # Fallback: reuse the retriever score as the relevance score.
                # NOTE(review): this assumes retriever scores are on a 0-1
                # scale comparable to min_relevance_score — confirm upstream.
                ranked.append(RankedResult(
                    chunk_id=result.chunk_id,
                    document_id=result.document_id,
                    text=result.text,
                    original_score=result.score,
                    relevance_score=result.score,
                    final_score=result.score,
                    page=result.page,
                    chunk_type=result.chunk_type,
                    source_path=result.source_path,
                    metadata=result.metadata,
                    bbox=result.bbox,
                ))

        return ranked

    def _score_passage(self, query: str, passage: str) -> Tuple[float, str]:
        """Score a single passage using LLM.

        Truncates the passage to its first 1000 characters to bound prompt
        size, then posts to ``{base_url}/api/generate`` (presumably an
        Ollama-compatible endpoint, given the default port — confirm).

        Returns:
            (score in 0-10, explanation string).

        Raises:
            httpx.HTTPError on transport/status failures (raise_for_status);
            callers (``_llm_rerank``) catch and fall back.
        """
        prompt = self.RERANK_PROMPT.format(
            query=query,
            passage=passage[:1000],
        )

        with httpx.Client(timeout=30.0) as client:
            response = client.post(
                f"{self.config.base_url}/api/generate",
                json={
                    "model": self.config.model,
                    "prompt": prompt,
                    "stream": False,
                    "options": {
                        "temperature": self.config.temperature,
                        # Cap generation length; only a small JSON object is expected.
                        "num_predict": 256,
                    },
                },
            )
            response.raise_for_status()
            result = response.json()

        response_text = result.get("response", "")
        return self._parse_score_response(response_text)

    def _parse_score_response(self, text: str) -> Tuple[float, str]:
        """Parse score and explanation from LLM response.

        Tries, in order:
        1. A JSON object anywhere in the text (score clamped to [0, 10]).
        2. A bare integer 0-10 (no explanation available).
        3. A neutral default of 5.0.
        Never raises — malformed output degrades to the neutral score.
        """
        try:
            # Greedy match grabs the outermost {...} span in the response.
            json_match = re.search(r'\{[\s\S]*\}', text)
            if json_match:
                data = json.loads(json_match.group())
                score = float(data.get("score", 5))
                explanation = data.get("explanation", "")
                # Clamp to the advertised 0-10 scale.
                return min(max(score, 0), 10), explanation
        except Exception:
            # Fall through to the bare-number heuristic below.
            pass

        # Fallback: first standalone integer 0-10 in the text.
        num_match = re.search(r'\b([0-9]|10)\b', text)
        if num_match:
            return float(num_match.group()), ""

        # Neutral midpoint when nothing parseable is found.
        return 5.0, "Could not parse score"

    def _heuristic_rerank(
        self,
        query: str,
        results: List[RetrievalResult],
    ) -> List[RankedResult]:
        """Fast heuristic-based reranking (no LLM calls).

        Relevance = 0.5 * query-term overlap
                  + 0.3 * exact-phrase bonus
                  + 0.2 * length score (saturates at 500 chars).
        Final score blends 40% retriever score with 60% heuristic relevance.
        """
        query_terms = set(query.lower().split())
        ranked = []

        for result in results:
            text_lower = result.text.lower()

            # Fraction of query terms that appear in the chunk.
            text_terms = set(text_lower.split())
            overlap = len(query_terms & text_terms) / len(query_terms) if query_terms else 0

            # Bonus when the full query appears verbatim in the chunk.
            phrase_bonus = 0.2 if query.lower() in text_lower else 0

            # Mildly prefer longer chunks, capped so length never dominates.
            length = len(result.text)
            length_score = min(length, 500) / 500

            relevance = 0.5 * overlap + 0.3 * phrase_bonus + 0.2 * length_score
            final_score = 0.4 * result.score + 0.6 * relevance

            ranked.append(RankedResult(
                chunk_id=result.chunk_id,
                document_id=result.document_id,
                text=result.text,
                original_score=result.score,
                relevance_score=relevance,
                final_score=final_score,
                page=result.page,
                chunk_type=result.chunk_type,
                source_path=result.source_path,
                metadata=result.metadata,
                bbox=result.bbox,
            ))

        return ranked

    def _promote_diversity(
        self,
        results: List[RankedResult],
        top_k: int,
    ) -> List[RankedResult]:
        """
        Promote diversity using MMR-style selection.

        Maximal Marginal Relevance balances relevance with diversity:
        greedily picks the candidate maximizing
        ``0.7 * final_score - 0.3 * max_similarity_to_selected``.
        The top-scored result is always selected first; output order is
        selection order, not strictly score order.
        """
        if not results:
            return []

        sorted_results = sorted(results, key=lambda x: x.final_score, reverse=True)

        # Seed with the single best result; the rest compete via MMR.
        selected = [sorted_results[0]]
        remaining = sorted_results[1:]

        while len(selected) < top_k and remaining:
            # Sentinel below any expected MMR value so one candidate is
            # always chosen per round (assuming non-negative final scores).
            best_mmr = -1
            best_idx = 0

            for i, candidate in enumerate(remaining):
                relevance = candidate.final_score

                # Similarity to the closest already-selected chunk — the
                # redundancy penalty. O(k) SequenceMatcher calls per candidate.
                max_sim = max(
                    self._text_similarity(candidate.text, s.text)
                    for s in selected
                )

                mmr = 0.7 * relevance - 0.3 * max_sim

                if mmr > best_mmr:
                    best_mmr = mmr
                    best_idx = i

            selected.append(remaining.pop(best_idx))

        return selected
|
|
|