| | """ |
| | Entity Reranker |
| | Improves precision by re-scoring and reordering retrieval results |
| | """ |
| |
|
| | from typing import List |
| |
|
| | from .models import RetrievalResult |
| | from .normalizer import get_normalizer |
| |
|
| |
|
class EntityReranker:
    """
    Reranks retrieval results to improve precision.

    Each incoming ``result.score`` is multiplied by a series of factors:

    - Name overlap, comparing normalized and stemmed names (at most one of
      these applies, checked in this order):
        * exact match                          -> x EXACT_MATCH_BOOST
        * query contained in the entity name   -> x (1 + ratio * SUBSTRING_RATIO_WEIGHT),
          where ratio = len(query) / len(entity name) of the matched form,
          so more specific (shorter) entity names get a larger boost
        * entity name contained in the query   -> x ENTITY_IN_QUERY_BOOST
    - Query first name equals entity first name        -> x FIRST_NAME_BOOST
    - Short query vs. very long entity name (ambiguous) -> x SHORT_QUERY_PENALTY
    - Entity's primary position appears in the query    -> x POSITION_BOOST
    - Entity's primary organization appears in the query -> x ORGANIZATION_BOOST
    """

    # Score multipliers and thresholds. Class attributes so a subclass can
    # tune them without re-implementing ``rerank``.
    EXACT_MATCH_BOOST = 10.0
    SUBSTRING_RATIO_WEIGHT = 3.0   # query-in-entity boost is 1 + ratio * this
    ENTITY_IN_QUERY_BOOST = 3.0
    FIRST_NAME_BOOST = 1.5
    SHORT_QUERY_MAX_LEN = 15       # normalized query shorter than this is "short"
    LONG_ENTITY_MIN_LEN = 40       # normalized entity name longer than this is "long"
    SHORT_QUERY_PENALTY = 0.7
    POSITION_BOOST = 2.0
    ORGANIZATION_BOOST = 1.5

    def __init__(self, debug: bool = False):
        """
        Args:
            debug: When true, print a trace line for every score adjustment.
        """
        self.debug = debug
        self.normalizer = get_normalizer()

    def rerank(self, query: str, results: List[RetrievalResult]) -> List[RetrievalResult]:
        """
        Rerank results based on query-specific signals.

        Args:
            query: Original user query.
            results: Retrieved results to rerank. Not modified.

        Returns:
            A new list of new ``RetrievalResult`` objects carrying the adjusted
            scores, sorted by score in descending order (ties keep their input
            order). An empty ``results`` is returned unchanged.
        """
        if not results:
            return results

        # Query-side features do not depend on the result; compute them once.
        query_norm = self.normalizer.normalize(query)
        query_stem = self.normalizer.stem(query_norm)
        query_parts = self.normalizer.extract_name_parts(query)
        query_first_name = query_parts.get("first_name")
        query_lower = query.lower()

        scored_results = []
        for result in results:
            entity = result.entity
            entity_norm = self.normalizer.normalize(entity.name)
            entity_stem = self.normalizer.stem(entity_norm)

            score = result.score * self._name_match_factor(
                query_norm, query_stem, entity_norm, entity_stem, entity.name
            )

            if query_first_name:
                entity_parts = self.normalizer.extract_name_parts(entity.name)
                if query_first_name == entity_parts.get("first_name"):
                    score *= self.FIRST_NAME_BOOST
                    if self.debug:
                        print(f" [Rerank] First name match: {entity.name} → ×{self.FIRST_NAME_BOOST:g}")

            # A short query that hits a very long entity name is likely a
            # weak, ambiguous match.
            if len(query_norm) < self.SHORT_QUERY_MAX_LEN and len(entity_norm) > self.LONG_ENTITY_MIN_LEN:
                score *= self.SHORT_QUERY_PENALTY
                if self.debug:
                    print(f" [Rerank] Short query penalty: {entity.name} → ×{self.SHORT_QUERY_PENALTY:g}")

            if entity.primary_position and entity.primary_position.lower() in query_lower:
                score *= self.POSITION_BOOST
                if self.debug:
                    print(f" [Rerank] Position match: {entity.primary_position} → ×{self.POSITION_BOOST:g}")

            if entity.primary_organization and entity.primary_organization.lower() in query_lower:
                score *= self.ORGANIZATION_BOOST
                if self.debug:
                    print(f" [Rerank] Organization match: {entity.primary_organization} → ×{self.ORGANIZATION_BOOST:g}")

            scored_results.append(RetrievalResult(
                entity=result.entity,
                score=score,
                match_type=result.match_type,
                matched_variant=result.matched_variant,
                normalized_query=result.normalized_query,
            ))

        scored_results.sort(key=lambda r: r.score, reverse=True)
        return scored_results

    def _name_match_factor(
        self,
        query_norm: str,
        query_stem: str,
        entity_norm: str,
        entity_stem: str,
        entity_name: str,
    ) -> float:
        """Return the name-overlap multiplier for one entity (1.0 when the names don't overlap)."""
        if query_norm == entity_norm or query_stem == entity_stem:
            if self.debug:
                print(f" [Rerank] Exact match: {entity_name} → ×{self.EXACT_MATCH_BOOST:g}")
            return self.EXACT_MATCH_BOOST

        # The empty string is a substring of every string, so an empty side
        # must never count as a containment match.
        query_in_norm = bool(query_norm) and query_norm in entity_norm
        query_in_stem = bool(query_stem) and query_stem in entity_stem
        if query_in_norm or query_in_stem:
            # Measure the ratio on the representation that actually matched.
            # The names are not equal here, so the containing string is
            # strictly longer: ratio is in (0, 1) (no division by zero) and
            # the bonus stays strictly below 1 + SUBSTRING_RATIO_WEIGHT.
            if query_in_norm:
                ratio = len(query_norm) / len(entity_norm)
            else:
                ratio = len(query_stem) / len(entity_stem)
            bonus = 1.0 + (ratio * self.SUBSTRING_RATIO_WEIGHT)
            if self.debug:
                print(f" [Rerank] Substring match: {entity_name} → ×{bonus:.2f}")
            return bonus

        entity_in_query = (bool(entity_norm) and entity_norm in query_norm) or (
            bool(entity_stem) and entity_stem in query_stem
        )
        if entity_in_query:
            if self.debug:
                print(f" [Rerank] Entity in query: {entity_name} → ×{self.ENTITY_IN_QUERY_BOOST:g}")
            return self.ENTITY_IN_QUERY_BOOST

        return 1.0

    def filter_ambiguous(self, results: List[RetrievalResult], threshold: float = 0.2) -> List[RetrievalResult]:
        """
        Narrow results down to the close contenders when the top two are ambiguous.

        The top two results are ambiguous when ``(top - second) / top`` is
        below ``threshold``. In that case every result scoring at least
        ``top * (1 - 2 * threshold)`` is kept. Otherwise the input is returned
        unchanged.

        Assumes ``results`` is already sorted by score, descending (as
        returned by ``rerank``).

        Args:
            results: Ranked results.
            threshold: Relative difference threshold (0.2 = 20%).

        Returns:
            A new list of the close contenders if ambiguous, otherwise
            ``results`` itself. Lists with fewer than two items, or with a
            non-positive top score, are returned unchanged.
        """
        if len(results) < 2:
            return results

        top_score = results[0].score
        second_score = results[1].score

        # The relative gap is only meaningful for a positive top score.
        ambiguous = top_score > 0 and (top_score - second_score) / top_score < threshold
        if not ambiguous:
            return results

        if self.debug:
            print(" [Rerank] Ambiguous results detected:")
            print(f" #1: {results[0].entity.name} (score: {top_score:.2f})")
            print(f" #2: {results[1].entity.name} (score: {second_score:.2f})")

        # Use a wider band (2x threshold) than the ambiguity test so that
        # near-ties just behind #2 are kept too.
        cutoff = top_score * (1 - threshold * 2)
        return [r for r in results if r.score >= cutoff]
| |
|
| |
|
| | |
def rerank_results(query: str, results: List[RetrievalResult], debug: bool = False) -> List[RetrievalResult]:
    """
    Rerank ``results`` for ``query`` using a freshly built EntityReranker.

    Shortcut for ``EntityReranker(debug=debug).rerank(query, results)``.
    """
    return EntityReranker(debug=debug).rerank(query, results)
| |
|
| |
|