File size: 4,191 Bytes
da9db52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
UAE Knowledge System - Backend Services
Handles knowledge base and retriever initialization
"""
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from ir.retriever import EntityRetriever, RetrievalOutput
from ir.knowledge_base import KnowledgeBase

# ============================================================
# Global State
# ============================================================
# Process-wide singletons: both stay None until the matching get_* accessor
# in this module populates them on first use.
_retriever = None
_knowledge_base = None

# Filesystem locations, anchored at the parent of the directory holding this file
PROJECT_ROOT = Path(__file__).parent.parent
INDEX_CACHE_PATH = PROJECT_ROOT.joinpath("ir", "cache", "dense_index")


def get_knowledge_base() -> KnowledgeBase:
    """Return the shared KnowledgeBase, constructing it on first use.

    The instance is kept in the module-level ``_knowledge_base`` so that
    subsequent calls reuse it instead of building a new one.
    """
    global _knowledge_base

    # Fast path: already constructed by an earlier call.
    if _knowledge_base is not None:
        return _knowledge_base

    print("Loading knowledge base...")
    _knowledge_base = KnowledgeBase(debug=False)
    return _knowledge_base


def get_retriever():
    """Return the shared dense retriever, preparing its index on first use.

    The first call creates a ``DenseRetriever`` and then either loads its
    index from ``INDEX_CACHE_PATH`` or builds it from the knowledge base and
    saves it there. The ready retriever is stored in the module-level
    ``_retriever`` and handed back directly on every later call.
    """
    global _retriever

    if _retriever is None:
        # Imported here, so the dense module is only loaded once a retriever
        # is actually requested.
        from ir.retrievers.dense import DenseRetriever

        print("Loading dense retriever...")
        dense = DenseRetriever(model_name="bge-m3", debug=False)
        knowledge_base = get_knowledge_base()
        cache_location = str(INDEX_CACHE_PATH)

        if not INDEX_CACHE_PATH.exists():
            # No cache on disk yet: build from scratch and persist the result.
            print("Building dense index (this may take a while)...")
            dense.build_index_from_knowledge_base(knowledge_base)
            INDEX_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
            dense.save_index(cache_location)
            print("Index built and cached!")
        else:
            print(f"Loading cached index from {INDEX_CACHE_PATH}...")
            cache_loaded = dense.load_index(cache_location)
            if cache_loaded:
                print("Cached index loaded!")
            else:
                # The cache exists but could not be used: rebuild and overwrite it.
                print("Cache load failed, building index...")
                dense.build_index_from_knowledge_base(knowledge_base)
                dense.save_index(cache_location)

        # Publish only after the index is ready, so a failure above leaves
        # the module-level slot empty and the next call retries.
        _retriever = dense

    return _retriever


def _format_search_result(metadata, score, raw_data):
    """Build one result dict from a retriever hit and its raw KB entity, if any."""
    result = {
        "entity_id": metadata.get("entity_id", ""),
        "entity_name": metadata.get("entity_name", "Unknown"),
        "score": score,
        "chunk_type": metadata.get("chunk_type", ""),
        "subcategory": "",
        "emirate": "",
        "is_royal": False,
        "summary": "",
        "must_answer": []
    }

    # No entity behind this hit: return the defaults above.
    if not raw_data:
        return result

    # dict.get's default only applies when the key is absent. `or` also
    # covers keys that are present but hold None, which would otherwise make
    # the nested lookups below raise and abort the whole search.
    facts_data = raw_data.get('facts') or {}
    metadata_kb = raw_data.get('metadata') or {}

    result["subcategory"] = raw_data.get('subcategory', '')
    result["emirate"] = metadata_kb.get('emirate', '')
    result["is_royal"] = metadata_kb.get('is_royal', False)
    result["summary"] = facts_data.get('summary_paragraph', '')

    # Extract must-answer facts (at most the first five). Dict entries
    # contribute their 'fact' value; anything else is stringified.
    must_answer = facts_data.get('must_answer') or []
    result["must_answer"] = [
        fact.get('fact', fact) if isinstance(fact, dict) else str(fact)
        for fact in must_answer[:5]
    ]

    # Include full entity data for detailed view
    result["full_entity"] = raw_data
    return result


def search_knowledge_base(query: str, top_k: int = 5):
    """
    Search the knowledge base and return formatted results.

    Args:
        query: Text passed to the retriever's ``search``.
        top_k: Number of hits requested from the retriever.

    Returns:
        A list with one dict per hit, in the order the retriever returned
        them. Every dict has ``entity_id``, ``entity_name``, ``score``,
        ``chunk_type``, ``subcategory``, ``emirate``, ``is_royal``,
        ``summary`` and ``must_answer``. ``full_entity`` is added only when
        the knowledge base returns data for the hit's ``entity_id``.
    """
    retriever = get_retriever()
    kb = get_knowledge_base()

    formatted_results = []
    for metadata, score in retriever.search(query, top_k=top_k):
        entity_id = metadata.get("entity_id", "")
        # Skip the KB lookup when the hit carries no entity id.
        raw_data = kb.get_raw_entity(entity_id) if entity_id else None
        formatted_results.append(_format_search_result(metadata, score, raw_data))

    return formatted_results


def get_stats():
    """Get knowledge base statistics.

    Returns:
        A dict with ``entities`` (number of entities in the knowledge base),
        ``categories`` and ``version``. If the knowledge base cannot be
        loaded or inspected, ``entities`` is 0 and an ``error`` key holds the
        exception message; ``categories`` and ``version`` are still present,
        so both outcomes share the same base shape.
    """
    # Fixed: 8 knowledge categories as defined in the system
    stats = {
        "entities": 0,
        "categories": 8,
        "version": "2.3.0"
    }
    try:
        stats["entities"] = len(get_knowledge_base().entities)
    except Exception as e:
        # Deliberate best-effort: report the failure to the caller instead
        # of raising.
        stats["error"] = str(e)
    return stats