File size: 3,496 Bytes
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ca635c
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""

RAG Engine - Orchestrates the retrieval-augmented generation pipeline.



Handles:

- Automatic indexing of text columns from the database

- Semantic retrieval using FAISS

- Context building for the LLM

"""

import logging
from typing import List, Dict, Any, Optional, Tuple

from .document_processor import Document, get_document_processor
from .vector_store import VectorStore, get_vector_store
from .embeddings import get_embedding_provider

logger = logging.getLogger(__name__)


class RAGEngine:
    """Main RAG engine for semantic retrieval from database text."""
    
    def __init__(self, vector_store: Optional[VectorStore] = None):
        self.vector_store = vector_store or get_vector_store()
        self.doc_processor = get_document_processor()
        self.indexed_tables: Dict[str, bool] = {}
    
    def index_table(

        self,

        table_name: str,

        rows: List[Dict[str, Any]],

        text_columns: List[str],

        primary_key_column: Optional[str] = None

    ) -> int:
        """

        Index text data from a table.

        

        Returns:

            Number of documents indexed

        """
        documents = list(self.doc_processor.process_rows(
            rows, table_name, text_columns, primary_key_column
        ))
        
        if documents:
            self.vector_store.add_documents(documents)
            self.indexed_tables[table_name] = True
            logger.info(f"Indexed {len(documents)} documents from {table_name}")
        
        return len(documents)
    
    def search(

        self, 

        query: str, 

        top_k: int = 5,

        table_filter: Optional[List[str]] = None

    ) -> List[Tuple[Document, float]]:
        """

        Search for relevant documents.

        

        Args:

            query: Search query

            top_k: Number of results

            table_filter: Optional list of tables to search in

            

        Returns:

            List of (document, score) tuples

        """
        results = self.vector_store.search(query, top_k=top_k * 2)
        
        if table_filter is not None:
            results = [
                (doc, score) for doc, score in results 
                if doc.table_name in table_filter
            ]
        
        return results[:top_k]
    
    def get_context(

        self, 

        query: str, 

        top_k: int = 5,

        table_filter: Optional[List[str]] = None

    ) -> str:
        """

        Get formatted context for LLM from search results.

        """
        results = self.search(query, top_k, table_filter)
        
        if not results:
            return "No relevant information found in the database."
        
        context_parts = []
        for doc, score in results:
            context_parts.append(doc.to_context_string())
        
        return "\n\n---\n\n".join(context_parts)
    
    def clear_index(self):
        """Clear the entire index."""
        self.vector_store.clear()
        self.indexed_tables = {}
    
    def save(self):
        """Save the index to disk."""
        self.vector_store.save()
    
    @property
    def document_count(self) -> int:
        return len(self.vector_store)


_rag_engine: Optional[RAGEngine] = None


def get_rag_engine() -> RAGEngine:
    global _rag_engine
    if _rag_engine is None:
        _rag_engine = RAGEngine()
    return _rag_engine