|
|
""" |
|
|
Knowledge Base Loader |
|
|
Loads structured data from uae_knowledge_build/data/uae_knowledge_safety/ |
|
|
|
|
|
Supports the schema with: |
|
|
- entities.json (entities with aliases, facts, sensitive topics) |
|
|
- alias_index.json (pre-built alias mappings) |
|
|
- sensitive_topics.json (sensitive topic entries) |
|
|
- category_metadata.json (categories with must-answer topics) |
|
|
|
|
|
This module is designed to be extensible for future knowledge bases |
|
|
beyond UAE (e.g., Saudi Arabia, Qatar, etc.) |
|
|
""" |
|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
from ir.models import Entity |
|
|
|
|
|
|
|
|
class KnowledgeBase:
    """
    Loads and manages knowledge base data.

    Currently optimized for UAE knowledge base format, but designed
    to be extended for other regional knowledge bases.

    On construction, entities.json, alias_index.json, sensitive_topics.json
    and category_metadata.json are read from the data directory. A file that
    is missing, unreadable, or not a JSON array is treated as an empty list,
    so a partially populated directory still yields a usable object.
    """

    # Default data location, resolved relative to the parent of the directory
    # that contains this module (see __init__).
    DEFAULT_DATA_DIR = "uae_knowledge_build/data/unified_KB"

    def __init__(self, data_dir: Optional[str] = None, debug: bool = False):
        """
        Initialize knowledge base and load all data files.

        Args:
            data_dir: Path to data directory containing JSON files.
                      When omitted (or empty), DEFAULT_DATA_DIR is used,
                      resolved relative to the parent of this module's
                      directory.
            debug: Enable debug output
        """
        self.debug = debug

        if data_dir:
            self.data_dir = Path(data_dir)
        else:
            self.data_dir = Path(__file__).parent.parent / self.DEFAULT_DATA_DIR

        # Raw records exactly as read from the JSON files.
        self._entities_raw: List[dict] = []
        self._alias_index_raw: List[dict] = []
        self._sensitive_topics_raw: List[dict] = []
        self._category_metadata_raw: List[dict] = []

        # Successfully converted entities, in file order.
        self.entities: List["Entity"] = []

        # Lookup tables derived from the raw records.
        self._id_to_entity: Dict[str, "Entity"] = {}
        self._id_to_raw: Dict[str, dict] = {}
        self._alias_lookup: Dict[str, str] = {}
        self._trigger_patterns: Dict[str, dict] = {}

        self._load_all()

    def _load_all(self) -> None:
        """Load all knowledge base files and build the lookup indexes."""
        self._entities_raw = self._load_json("entities.json")
        self._alias_index_raw = self._load_json("alias_index.json")
        self._sensitive_topics_raw = self._load_json("sensitive_topics.json")
        self._category_metadata_raw = self._load_json("category_metadata.json")

        self._convert_entities()

        self._build_alias_index()
        self._build_trigger_index()

        if self.debug:
            print(f"✅ KnowledgeBase loaded from {self.data_dir}")
            print(f" Entities: {len(self.entities)}")
            print(f" Aliases: {len(self._alias_lookup)}")
            print(f" Sensitive patterns: {len(self._trigger_patterns)}")

    def _load_json(self, filename: str) -> list:
        """
        Load a JSON file from the data directory.

        Args:
            filename: File name relative to ``self.data_dir``.

        Returns:
            The list of JSON objects stored in the file. An empty list is
            returned when the file is missing, cannot be read or parsed, or
            does not contain a top-level JSON array. Array items that are not
            JSON objects are dropped, because every consumer in this class
            calls ``.get()`` on the records.
        """
        filepath = self.data_dir / filename

        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, NotADirectoryError):
            # A missing file is an expected situation; only mention it in debug mode.
            if self.debug:
                print(f"⚠️ File not found: {filepath}")
            return []
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"❌ Error loading {filepath}: {e}")
            return []

        if not isinstance(data, list):
            print(
                f"❌ Error loading {filepath}: expected a JSON array, "
                f"got {type(data).__name__}"
            )
            return []

        records = [item for item in data if isinstance(item, dict)]
        if self.debug and len(records) != len(data):
            print(
                f"⚠️ Skipped {len(data) - len(records)} non-object entries in {filepath}"
            )
        return records

    def _convert_entities(self) -> None:
        """
        Convert raw entity data to Entity objects.

        Every successfully converted entity is appended to ``self.entities``.
        Entities that carry an ``id`` are additionally registered in the
        id lookup tables; an entity without one stays reachable through
        iteration only instead of aborting the load with a KeyError.
        """
        for raw in self._entities_raw:
            entity = self._raw_to_entity(raw)
            if entity is None:
                continue

            self.entities.append(entity)

            entity_id = raw.get("id")
            if entity_id is None:
                continue
            self._id_to_entity[entity_id] = entity
            self._id_to_raw[entity_id] = raw

    @staticmethod
    def _join_list(value: Any) -> Any:
        """Collapse a list into a comma-separated string; pass other values through."""
        if isinstance(value, list):
            return ", ".join(str(item) for item in value)
        return value

    def _raw_to_entity(self, raw: dict) -> Optional["Entity"]:
        """
        Convert raw JSON entity to Entity object.

        Optional sections (``canonical_name``, ``aliases``, ``facts``,
        ``metadata``, ``sources``, ``raw_content``) may be absent or null;
        they are then treated as empty rather than discarding the entity.

        Returns:
            The Entity, or None when conversion fails.
        """
        try:
            entity_id = raw.get("id", "")
            canonical = raw.get("canonical_name") or {}

            # Fall back to the id when no English canonical name is provided.
            primary_name = canonical.get("en") or entity_id

            variants = [primary_name]
            if canonical.get("ar"):
                variants.append(canonical["ar"])

            for alias in raw.get("aliases") or []:
                if not isinstance(alias, dict):
                    continue
                alias_name = alias.get("name", "")
                if alias_name and alias_name not in variants:
                    variants.append(alias_name)

            facts = raw.get("facts") or {}
            metadata = raw.get("metadata") or {}

            # These metadata fields may be a single string or a list of strings.
            emirate = self._join_list(metadata.get("emirate", ""))
            family_name = self._join_list(metadata.get("family_name", ""))

            sources = raw.get("sources") or []
            first_source = sources[0] if sources else {}
            url = first_source.get("url", "") if isinstance(first_source, dict) else ""

            subcategory = raw.get("subcategory", "")

            return Entity(
                name=primary_name,
                id=entity_id,
                variants=variants,
                source="knowledge_base",
                title=subcategory,
                url=url,
                raw_text=(raw.get("raw_content") or {}).get("text", ""),
                summary=facts.get("summary_paragraph", ""),
                primary_position=subcategory,
                primary_organization=emirate,
                family_name=family_name,
                city=emirate,
                country="UAE",
                facts=raw,
            )
        except Exception as e:
            # Deliberately broad: one malformed record (or a validation error
            # raised by Entity itself) must not prevent the rest from loading.
            if self.debug:
                print(f"⚠️ Error converting entity: {e}")
            return None

    def _build_alias_index(self) -> None:
        """
        Build alias → entity_id lookup from alias_index.json.

        Aliases are stored lower-cased. When the same alias occurs more than
        once, the first entry wins.
        """
        for alias_entry in self._alias_index_raw:
            alias = alias_entry.get("alias_normalized")
            entity_id = alias_entry.get("canonical_entity_id")
            if not isinstance(alias, str) or not alias or not entity_id:
                continue
            # setdefault keeps the first mapping for a duplicated alias.
            self._alias_lookup.setdefault(alias.lower(), entity_id)

    def _build_trigger_index(self) -> None:
        """
        Build trigger pattern → sensitive topic lookup.

        Patterns are stored lower-cased. Blank or non-string patterns are
        skipped: an empty pattern is a substring of every text and would flag
        all input as sensitive. When two topics share a pattern, the later
        topic replaces the earlier one.
        """
        for topic in self._sensitive_topics_raw:
            for pattern in topic.get("trigger_patterns") or []:
                if not isinstance(pattern, str) or not pattern.strip():
                    continue
                self._trigger_patterns[pattern.lower()] = topic

    def get_entity_by_id(self, entity_id: str) -> Optional["Entity"]:
        """Get entity by ID, or None when the ID is unknown."""
        return self._id_to_entity.get(entity_id)

    def get_raw_entity(self, entity_id: str) -> Optional[dict]:
        """Get raw entity data (with full facts, sensitive topics, etc.)"""
        return self._id_to_raw.get(entity_id)

    def lookup_alias(self, text: str) -> Optional[str]:
        """Look up entity ID by alias (case-insensitive); None when not found."""
        if not text:
            return None
        return self._alias_lookup.get(text.lower())

    def get_alias_entries(self) -> List[dict]:
        """
        Get all alias entries for building custom indexes.

        The internal list is returned directly (not a copy).
        """
        return self._alias_index_raw

    def check_sensitive_triggers(self, text: str) -> List[dict]:
        """
        Check if text contains any sensitive trigger patterns.

        Matching is a case-insensitive substring test.

        Returns list of matching sensitive topics, each topic at most once.
        """
        if not text:
            return []

        text_lower = text.lower()
        matches: List[dict] = []

        for pattern, topic in self._trigger_patterns.items():
            # A topic may be reachable through several patterns; report it once.
            if pattern in text_lower and topic not in matches:
                matches.append(topic)

        return matches

    def get_category_metadata(self, category_id: int) -> Optional[dict]:
        """Get metadata for a category (1-8), or None when it is not present."""
        for cat in self._category_metadata_raw:
            if cat.get("category_id") == category_id:
                return cat
        return None

    def get_must_answer_topics(self, category_id: int) -> List[str]:
        """Get must-answer topics for a category (empty list when unknown)."""
        cat = self.get_category_metadata(category_id)
        if not cat:
            return []
        return cat.get("must_answer_topics") or []

    def __len__(self) -> int:
        return len(self.entities)

    def __iter__(self):
        return iter(self.entities)

    def get_statistics(self) -> dict:
        """
        Get knowledge base statistics.

        ``total_entities`` counts successfully converted entities, whereas
        ``by_entity_type`` and ``by_category`` are tallied over the raw
        records, so the breakdowns can sum to more than the total when some
        records failed to convert.
        """
        type_counts: Dict[Any, int] = {}
        cat_counts: Dict[Any, int] = {}

        for raw in self._entities_raw:
            etype = raw.get("entity_type", "unknown")
            type_counts[etype] = type_counts.get(etype, 0) + 1

            cat = raw.get("category", 0)
            cat_counts[cat] = cat_counts.get(cat, 0) + 1

        return {
            "total_entities": len(self.entities),
            "total_aliases": len(self._alias_lookup),
            "sensitive_topics": len(self._sensitive_topics_raw),
            "trigger_patterns": len(self._trigger_patterns),
            "by_entity_type": type_counts,
            "by_category": cat_counts,
            "data_dir": str(self.data_dir),
        }
|
|
|
|
|
|