|
|
""" |
|
|
Entity IR Models |
|
|
Data structures for the entity retrieval system.
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
|
|
|
@dataclass
class Entity:
    """
    Represents an entity from the knowledge base.

    Designed to work with both wiki_summary.json and dhow_summary.json formats.

    After construction, a non-empty ``name`` is guaranteed to appear in
    ``variants`` (prepended when missing). ``variants`` is copied during
    construction, so the list passed by the caller is never modified and
    entities built from the same list do not share state.
    """

    # --- identity / matching ---
    # Canonical name; always present in ``variants`` when non-empty.
    name: str
    # Presumably the record id in the source summary file -- confirm with loader.
    id: str = ""
    # Alternative surface forms of the name; all are indexed for search.
    variants: List[str] = field(default_factory=list)
    # Presumably identifies which summary file the record came from -- TODO confirm.
    source: str = ""

    # --- document content ---
    title: str = ""
    url: str = ""
    raw_text: str = ""
    summary: str = ""

    # --- structured attributes ---
    # position / organization / family name are appended to the search text
    # by get_searchable_text(); city and country are not.
    primary_position: str = ""
    primary_organization: str = ""
    family_name: str = ""
    city: str = ""
    country: str = ""

    # --- free-form extra attributes ---
    facts: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Ensure ``name`` is in ``variants`` without mutating caller input."""
        # Copy first: inserting into the caller's list would leak this entity's
        # name into that list (and into any other Entity built from it).
        # A missing/None variants value is treated as "no variants".
        self.variants = list(self.variants or ())
        if self.name and self.name not in self.variants:
            self.variants.insert(0, self.name)

    def get_searchable_text(self) -> str:
        """
        Return combined text for BM25 indexing.

        The result is all name variants followed by the non-empty
        primary position, primary organization and family name (added to
        help disambiguate same-named entities), joined by single spaces.
        """
        parts = list(self.variants)
        parts.extend(
            extra
            for extra in (
                self.primary_position,
                self.primary_organization,
                self.family_name,
            )
            if extra
        )
        return " ".join(parts)
|
|
|
|
|
|
|
|
@dataclass
class RetrievalResult:
    """
    One scored hit produced by entity retrieval.

    Pairs the retrieved ``Entity`` with its score and a label describing how
    it was matched, plus optional diagnostics about the match.
    """

    # The entity that was retrieved.
    entity: Entity
    # Ranking score; presumably higher is better -- confirm against retriever.
    score: float
    # Free-form label for how the match was made (e.g. alias vs. BM25 --
    # verify the exact values the retriever emits).
    match_type: str

    # --- optional diagnostics (both default to None) ---
    # The variant string that produced the match, when known.
    matched_variant: Optional[str] = field(default=None)
    # The query after normalization, when the retriever recorded it.
    normalized_query: Optional[str] = field(default=None)
|
|
|
|
|
|
|
|
@dataclass
class RetrievalConfig:
    """
    Tunable parameters for the retrieval system.

    Field order is part of the positional constructor signature; keep it.
    """

    # --- BM25 scoring ---
    # Standard Okapi BM25 parameters: k1 governs term-frequency saturation,
    # b governs document-length normalization (0 = none, 1 = full).
    bm25_k1: float = field(default=1.8)
    bm25_b: float = field(default=0.4)

    # --- ranking ---
    # Maximum number of results, presumably returned per query.
    top_k: int = field(default=5)
    # Score boosts for alias and exact-name matches; how they are combined
    # with BM25 scores is decided by the retriever -- TODO confirm.
    alias_boost: float = field(default=10.0)
    exact_match_boost: float = field(default=5.0)

    # --- filtering ---
    # Presumably results scoring below this value are dropped -- verify.
    min_score_threshold: float = field(default=0.1)
|
|
|
|
|
|
|
|
@dataclass
class BenchmarkResult:
    """
    Aggregate metrics from running a benchmark evaluation.

    ``precision_at_1`` and ``recall_at_5`` are rendered with the ``%`` format
    by ``__str__``, so they are expected as fractions (0.5 -> "50.00%").
    """

    total_queries: int
    precision_at_1: float
    recall_at_5: float
    avg_latency_ms: float
    p95_latency_ms: float

    # --- outcome breakdown counters ---
    alias_hits: int = field(default=0)
    bm25_hits: int = field(default=0)
    misses: int = field(default=0)

    def __str__(self) -> str:
        """Return a multi-line, human-readable summary of the benchmark."""
        report = [
            "Benchmark Results:",
            f"Total Queries: {self.total_queries}",
            f"Precision@1: {self.precision_at_1:.2%}",
            f"Recall@5: {self.recall_at_5:.2%}",
            f"Avg Latency: {self.avg_latency_ms:.2f}ms",
            f"P95 Latency: {self.p95_latency_ms:.2f}ms",
            "",
            "Breakdown:",
            f"Alias Hits: {self.alias_hits}",
            f"BM25 Hits: {self.bm25_hits}",
            f"Misses: {self.misses}",
        ]
        # Explicit line list: output does not depend on how this method is
        # indented in the source file.
        return "\n".join(report).strip()
|
|
|
|
|
|