File size: 11,930 Bytes
43efcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""
Main RAG (Retrieval-Augmented Generation) engine implementation.
"""

import os
import logging
from typing import List, Dict, Any, Optional, Tuple, Union
import numpy as np

# Configure logging
logger = logging.getLogger(__name__)


class RAGEngine:
    """Retrieval-Augmented Generation (RAG) engine for question answering.

    Retrieves relevant documents from a vector database and (optionally)
    passes them as context to a language model to generate an answer.
    """

    def __init__(
        self,
        embedder,
        vector_db,
        llm=None,
        top_k: int = 5,
        search_type: str = "hybrid",
        prompt_template: Optional[str] = None
    ):
        """
        Initialize the RAG engine.

        Args:
            embedder: Embedding model exposing an ``embed`` method.
            vector_db: Vector database for document storage and retrieval.
            llm: Language model for text generation (optional).
            top_k: Default number of documents to retrieve.
            search_type: Type of search ('semantic', 'keyword', 'hybrid').
            prompt_template: Optional custom prompt template; must contain
                ``{context}`` and ``{query}`` placeholders.
        """
        self.embedder = embedder
        self.vector_db = vector_db
        self.llm = llm
        self.top_k = top_k
        self.search_type = search_type

        # Fall back to the package-wide default template when none is given.
        if prompt_template is None:
            from ..config import DEFAULT_PROMPT_TEMPLATE
            self.prompt_template = DEFAULT_PROMPT_TEMPLATE
        else:
            self.prompt_template = prompt_template

    def add_documents(
        self,
        texts: List[str],
        metadata: Optional[List[Dict[str, Any]]] = None,
        batch_size: int = 32
    ) -> List[str]:
        """
        Embed and add documents to the database in batches.

        Args:
            texts: List of text chunks.
            metadata: Optional list of metadata dictionaries, one per text.
            batch_size: Batch size for embedding generation.

        Returns:
            List of document IDs assigned by the vector database.

        Raises:
            ValueError: If ``metadata`` is given but its length does not
                match ``texts``.
        """
        from ..storage.vector_db import Document

        # Default to one empty metadata dict per text.
        if metadata is None:
            metadata = [{} for _ in texts]
        elif len(metadata) != len(texts):
            raise ValueError(f"Length mismatch: got {len(texts)} texts but {len(metadata)} metadata entries")

        doc_ids = []

        # Embed in batches to bound memory use on large corpora.
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_metadata = metadata[i:i+batch_size]

            logger.info(f"Generating embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
            batch_embeddings = self.embedder.embed(batch_texts)

            # Pair each text with its metadata and embedding.
            documents = []
            for text, meta, embedding in zip(batch_texts, batch_metadata, batch_embeddings):
                doc = Document(text=text, metadata=meta, embedding=embedding)
                documents.append(doc)

            batch_ids = self.vector_db.add_documents(documents)
            doc_ids.extend(batch_ids)

        logger.info(f"Added {len(doc_ids)} documents to database")
        return doc_ids

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_type: Optional[str] = None,
        filter_dict: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search for documents relevant to ``query``.

        Args:
            query: Query string.
            top_k: Number of results to return (defaults to ``self.top_k``).
            search_type: Type of search (defaults to ``self.search_type``).
            filter_dict: Metadata filters; keys may use dotted paths
                (e.g. ``"a.b"``) to address nested metadata dictionaries.

        Returns:
            List of result dictionaries with ``id``, ``text``, ``metadata``
            and ``score`` keys.
        """
        if top_k is None:
            top_k = self.top_k

        if search_type is None:
            search_type = self.search_type

        # Build a per-document predicate from the metadata filters.
        filter_func = None
        if filter_dict:
            def filter_func(doc):
                for key, value in filter_dict.items():
                    # Handle nested keys (e.g., "metadata.source")
                    if "." in key:
                        parts = key.split(".")
                        current = doc.metadata
                        for part in parts[:-1]:
                            # A non-dict intermediate value means the nested
                            # path does not exist: exclude the document rather
                            # than raising TypeError on the `in` test.
                            if not isinstance(current, dict) or part not in current:
                                return False
                            current = current[part]
                        if (
                            not isinstance(current, dict)
                            or parts[-1] not in current
                            or current[parts[-1]] != value
                        ):
                            return False
                    elif key not in doc.metadata or doc.metadata[key] != value:
                        return False
                return True

        # Generate query embedding
        query_embedding = self.embedder.embed(query)

        # NOTE(review): search_type is resolved above but never forwarded to
        # vector_db.search(), so the backend cannot distinguish semantic /
        # keyword / hybrid — confirm against the vector_db interface.
        results = self.vector_db.search(query_embedding, top_k, filter_func)

        # Convert (document, score) pairs to plain dictionaries.
        return [
            {
                "id": doc.id,
                "text": doc.text,
                "metadata": doc.metadata,
                "score": score
            }
            for doc, score in results
        ]

    def generate_response(
        self,
        query: str,
        top_k: Optional[int] = None,
        search_type: Optional[str] = None,
        filter_dict: Optional[Dict[str, Any]] = None,
        max_tokens: int = 512
    ) -> Dict[str, Any]:
        """
        Generate a response to a query using RAG.

        Args:
            query: Query string.
            top_k: Number of documents to retrieve.
            search_type: Type of search.
            filter_dict: Optional metadata filter for document retrieval.
            max_tokens: Maximum number of tokens in the response.

        Returns:
            Dictionary with ``query``, ``response``, ``retrieved_documents``
            and ``search_type`` keys.
        """
        retrieved_docs = self.search(query, top_k, search_type, filter_dict)

        # Short-circuit with a default message when retrieval found nothing.
        if not retrieved_docs:
            return {
                "query": query,
                "response": "I couldn't find any relevant information to answer your question.",
                "retrieved_documents": [],
                "search_type": search_type or self.search_type
            }

        # Build the prompt from retrieved context and the user query.
        context = self._format_context(retrieved_docs)
        prompt = self.prompt_template.format(context=context, query=query)

        # Without an LLM, retrieval results are returned with a notice.
        if self.llm is None:
            logger.warning("No LLM provided, returning only retrieved documents")
            response = "No language model available to generate a response. Here's what I found in the documents."
        else:
            response = self._generate_llm_response(prompt, max_tokens)

        return {
            "query": query,
            "response": response,
            "retrieved_documents": retrieved_docs,
            "search_type": search_type or self.search_type
        }

    def _format_context(self, documents: List[Dict[str, Any]]) -> str:
        """
        Format retrieved documents into context for the prompt.

        Args:
            documents: List of retrieved document dictionaries (as returned
                by :meth:`search`).

        Returns:
            Formatted context string with one numbered entry per document.
        """
        context_parts = []

        for i, doc in enumerate(documents):
            text = doc["text"]
            metadata = doc["metadata"]
            # Source attribution helps the LLM (and user) trace answers.
            source = metadata.get("source", "Unknown")

            doc_text = f"Document {i+1}: [Source: {source}]\n{text}\n"
            context_parts.append(doc_text)

        return "\n".join(context_parts)

    def _generate_llm_response(self, prompt: str, max_tokens: int) -> str:
        """
        Generate a response using the LLM, dispatching on the adapter API.

        Args:
            prompt: The fully-formatted prompt.
            max_tokens: Maximum number of tokens in the response.

        Returns:
            Generated response text, or an error message on failure.
        """
        if hasattr(self.llm, "generate_openai_response"):
            # OpenAI-compatible LLM
            return self.llm.generate_openai_response(prompt, max_tokens)
        elif hasattr(self.llm, "generate_huggingface_response"):
            # HuggingFace-compatible LLM
            return self.llm.generate_huggingface_response(prompt, max_tokens)
        else:
            # Generic adapter: best-effort call with error containment so a
            # backend failure degrades to a message instead of propagating.
            try:
                return self.llm.generate_response(prompt, max_tokens)
            except Exception as e:
                logger.error(f"Error generating response: {e}")
                return "I encountered an error while generating a response."

    def update_prompt_template(self, new_template: str) -> None:
        """
        Replace the prompt template used for response generation.

        Args:
            new_template: New prompt template; should contain ``{context}``
                and ``{query}`` placeholders.
        """
        self.prompt_template = new_template
        logger.info("Updated prompt template")

    def count_documents(self) -> int:
        """
        Get the number of documents in the database.

        Returns:
            Number of documents currently stored.
        """
        return self.vector_db.count_documents()

    def clear_documents(self) -> None:
        """Clear all documents from the database."""
        self.vector_db.clear()
        logger.info("Cleared all documents from database")


# Factory function to create the RAG engine
def create_rag_engine(
    embedder=None,
    vector_db=None,
    llm=None,
    config=None
) -> RAGEngine:
    """
    Factory function to create a RAG engine.

    Args:
        embedder: Embedding model (if None, created based on config)
        vector_db: Vector database (if None, created based on config)
        llm: Language model (if None, created based on config)
        config: Configuration module or dictionary

    Returns:
        Configured RAGEngine instance
    """
    def _cfg(cfg, key, default):
        """Read *key* from a dict-like (``get``) or module-like (attribute)
        config object, supporting both forms the docstring promises."""
        if hasattr(cfg, "get"):
            return cfg.get(key, default)
        return getattr(cfg, key, default)

    # Load configuration: package defaults when none is supplied, otherwise
    # read from the provided config (dict or module).
    if config is None:
        from ..config import (
            TOP_K,
            SEARCH_TYPE,
            DEFAULT_PROMPT_TEMPLATE
        )
    else:
        TOP_K = _cfg(config, "TOP_K", 5)
        SEARCH_TYPE = _cfg(config, "SEARCH_TYPE", "hybrid")
        DEFAULT_PROMPT_TEMPLATE = _cfg(
            config,
            "DEFAULT_PROMPT_TEMPLATE",
            """
            Answer the following question based ONLY on the provided context.
            
            Context:
            {context}
            
            Question: {query}
            
            Answer:
            """
        )

    # Create embedding model if not provided
    if embedder is None:
        from ..embedding.model import create_embedding_model
        embedder = create_embedding_model()

    # Create vector database if not provided; dimension must match embedder.
    if vector_db is None:
        from ..storage.vector_db import create_vector_database
        vector_db = create_vector_database(dimension=embedder.dimension)

    # Create language model if not provided; the LLM is optional, so a
    # missing module degrades gracefully to retrieval-only operation.
    if llm is None:
        try:
            from ..llm.model import create_llm
            llm = create_llm()
        except (ImportError, ModuleNotFoundError):
            logger.warning("LLM module not found, proceeding without an LLM")

    # Create and return the RAG engine
    return RAGEngine(
        embedder=embedder,
        vector_db=vector_db,
        llm=llm,
        top_k=TOP_K,
        search_type=SEARCH_TYPE,
        prompt_template=DEFAULT_PROMPT_TEMPLATE
    )