# uae-kb / ir /reranker.py
# Demon1212122's picture
# Initial UAE Knowledge System demo
# 8124364
"""
Entity Reranker
Improves precision by re-scoring and reordering retrieval results
"""
from typing import List
from .models import RetrievalResult
from .normalizer import get_normalizer
class EntityReranker:
    """
    Reranks retrieval results to improve precision.

    Each result's incoming score is multiplied by query-specific factors:
    - Full name match (highest priority)
    - Name length ratio (prefer specific over generic)
    - Position/organization match
    - Penalize ambiguous short matches
    """

    def __init__(self, debug: bool = False):
        """
        Args:
            debug: If True, print every score adjustment to stdout.
        """
        self.debug = debug
        self.normalizer = get_normalizer()

    @staticmethod
    def _same_text(left: str, right: str) -> bool:
        """Return True when both strings are equal and non-empty."""
        # Two empty strings carry no evidence of a name match.
        return bool(left) and left == right

    @staticmethod
    def _contains(needle: str, haystack: str) -> bool:
        """Return True when a non-empty needle occurs inside haystack."""
        # "" is a substring of every string, so an empty needle must not
        # be treated as a match.
        return bool(needle) and needle in haystack

    def rerank(self, query: str, results: "List[RetrievalResult]") -> "List[RetrievalResult]":
        """
        Rerank results based on query-specific signals.

        The input results are not modified; a new RetrievalResult carrying the
        adjusted score is built for each one.

        Args:
            query: Original user query
            results: Retrieved results to rerank

        Returns:
            New list of results sorted by adjusted score, highest first.
            An empty input is returned as-is.
        """
        if not results:
            return results

        # Query-side values are the same for every result: compute them once.
        query_norm = self.normalizer.normalize(query)
        query_stem = self.normalizer.stem(query_norm)
        query_parts = self.normalizer.extract_name_parts(query)
        query_lower = query.lower()

        scored_results = [
            RetrievalResult(
                entity=result.entity,
                score=self._adjusted_score(
                    result, query_norm, query_stem, query_parts, query_lower
                ),
                match_type=result.match_type,
                matched_variant=result.matched_variant,
                normalized_query=result.normalized_query,
            )
            for result in results
        ]

        # list.sort is stable, so ties keep their incoming relative order.
        scored_results.sort(key=lambda r: r.score, reverse=True)
        return scored_results

    def _adjusted_score(self, result, query_norm: str, query_stem: str,
                        query_parts, query_lower: str) -> float:
        """
        Apply every scoring factor to one result and return the new score.

        Args:
            result: Result to score; its ``score`` and ``entity`` are read.
            query_norm: Normalized query text.
            query_stem: Stem of the normalized query.
            query_parts: Name parts extracted from the raw query.
            query_lower: Lower-cased raw query.

        Returns:
            The result's score multiplied by all applicable factors.
        """
        score = result.score
        entity = result.entity

        # Entity's normalized form
        entity_norm = self.normalizer.normalize(entity.name)
        entity_stem = self.normalizer.stem(entity_norm)

        # Factors 1-3 are mutually exclusive: only the strongest one applies.
        # Factor 1: Exact match bonus
        if self._same_text(query_norm, entity_norm) or self._same_text(query_stem, entity_stem):
            score *= 10.0
            if self.debug:
                print(f" [Rerank] Exact match: {entity.name} → ×10")
        # Factor 2: Query is substring of entity (specific query)
        elif self._contains(query_norm, entity_norm) or self._contains(query_stem, entity_stem):
            # Length ratio bonus (longer query = more specific). A stem-only
            # match can pair a query longer than the entity name, so the ratio
            # is capped to keep the bonus within the intended 4x ceiling; a
            # zero-length entity name yields a neutral bonus instead of a
            # division by zero.
            ratio = min(1.0, len(query_norm) / len(entity_norm)) if entity_norm else 0.0
            bonus = 1.0 + (ratio * 3.0)  # Up to 4x for long matches
            score *= bonus
            if self.debug:
                print(f" [Rerank] Substring match: {entity.name} → ×{bonus:.2f}")
        # Factor 3: Entity is substring of query (query contains name)
        elif self._contains(entity_norm, query_norm) or self._contains(entity_stem, query_stem):
            score *= 3.0
            if self.debug:
                print(f" [Rerank] Entity in query: {entity.name} → ×3")

        # Factor 4: First name match (important for Arabic names)
        if query_parts.get("first_name"):
            entity_parts = self.normalizer.extract_name_parts(entity.name)
            if query_parts["first_name"] == entity_parts.get("first_name"):
                score *= 1.5
                if self.debug:
                    print(f" [Rerank] First name match: {entity.name} → ×1.5")

        # Factor 5: Penalize very short matches (likely ambiguous)
        if len(query_norm) < 15 and len(entity_norm) > 40:
            # Short query matching long name - might be too generic
            penalty = 0.7
            score *= penalty
            if self.debug:
                print(f" [Rerank] Short query penalty: {entity.name} → ×{penalty}")

        # Factor 6: Position/Organization match bonus
        if entity.primary_position and entity.primary_position.lower() in query_lower:
            score *= 2.0
            if self.debug:
                print(f" [Rerank] Position match: {entity.primary_position} → ×2")
        if entity.primary_organization and entity.primary_organization.lower() in query_lower:
            score *= 1.5
            if self.debug:
                print(f" [Rerank] Organization match: {entity.primary_organization} → ×1.5")

        return score

    def filter_ambiguous(self, results: "List[RetrievalResult]",
                         threshold: float = 0.2) -> "List[RetrievalResult]":
        """
        Filter results where top scores are too close (ambiguous).

        The first two entries are taken as the best and second-best results,
        so the list is expected to be ordered by score, highest first.

        Args:
            results: Ranked results
            threshold: Relative difference threshold (0.2 = 20%)

        Returns:
            When the top two scores differ by less than ``threshold`` relative
            to the top score: only the results scoring at least
            ``top_score * (1 - 2 * threshold)``. Otherwise the original list.
        """
        if len(results) < 2:
            return results

        top_score = results[0].score
        second_score = results[1].score

        # If scores are very close, might be ambiguous. A non-positive top
        # score is never treated as ambiguous (and avoids dividing by zero).
        if top_score > 0 and (top_score - second_score) / top_score < threshold:
            if self.debug:
                print(" [Rerank] Ambiguous results detected:")
                print(f" #1: {results[0].entity.name} (score: {top_score:.2f})")
                print(f" #2: {results[1].entity.name} (score: {second_score:.2f})")
            # Return all close results for further disambiguation
            cutoff = top_score * (1 - threshold * 2)
            return [r for r in results if r.score >= cutoff]

        return results
# Module-level shortcut
def rerank_results(query: str, results: List[RetrievalResult], debug: bool = False) -> List[RetrievalResult]:
    """
    Rerank ``results`` for ``query`` with a freshly constructed EntityReranker.

    Args:
        query: Original user query
        results: Retrieved results to rerank
        debug: Passed through to EntityReranker

    Returns:
        Whatever EntityReranker.rerank returns for the given arguments.
    """
    return EntityReranker(debug=debug).rerank(query, results)