File size: 10,652 Bytes
24c1b63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
import os
import json
import tempfile
from typing import List, Dict, Any, Optional
from pathlib import Path

# LangChain imports for RAG
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.schema import Document

# Google Gemini imports
from google import genai

class RAGSystem:
    """
    Complete RAG (Retrieval-Augmented Generation) system using Google Gemini
    Handles document ingestion, chunking, embedding, and question answering
    """
    
    def __init__(self, persist_directory: str = "./chroma_db"):
        """Set up the RAG system; expensive components are created lazily.

        Args:
            persist_directory: Folder where ChromaDB persists its index.
        """
        self.persist_directory = persist_directory
        # Read from the environment later, inside _initialize_components().
        self.gemini_api_key = None

        # Gemini clients and the vector store are costly to build, so they
        # stay None until the first operation that actually needs them.
        for lazy_attr in ('embeddings', 'llm', 'vectorstore', 'retriever', 'qa_chain'):
            setattr(self, lazy_attr, None)

        # Recursive splitter: ~1000-char chunks, 200-char overlap, preferring
        # paragraph, then line, then word boundaries.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

        # Per-session log of documents added via ingest_document().
        self.ingested_documents = []
    
    def _initialize_components(self):
        """Lazily initialize the Gemini LLM, embeddings, and vector store.

        No-op when the LLM already exists.

        Raises:
            ValueError: if the GEMINI_API_KEY environment variable is unset.
        """
        if self.llm is not None:
            return

        self.gemini_api_key = os.getenv('GEMINI_API_KEY')
        if not self.gemini_api_key:
            raise ValueError("GEMINI_API_KEY environment variable must be set")

        # Build both clients locally before assigning, so a failure partway
        # through cannot leave the instance half-initialized (previously,
        # self.llm was set first; if the embeddings constructor then raised,
        # the `llm is None` guard would never retry initialization).
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash",
            temperature=0.1,
            max_tokens=2048,
            google_api_key=self.gemini_api_key
        )
        embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=self.gemini_api_key
        )

        # Assign embeddings first: llm being non-None implies embeddings are ready.
        self.embeddings = embeddings
        self.llm = llm

        # Create or load the ChromaDB store (reads self.embeddings).
        self._initialize_vectorstore()
    
    def _initialize_vectorstore(self):
        """Create or load the ChromaDB vector store and its retriever.

        Chroma transparently loads an existing collection from
        persist_directory or creates a fresh one, so the previous
        os.path.exists() check selected between two identical branches;
        the dead branch is removed.

        Raises:
            RuntimeError: wrapping any underlying initialization failure
                (still an Exception subclass, so existing broad handlers work).
        """
        try:
            self.vectorstore = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings
            )

            # Retriever used by the QA chain: top-5 similarity search.
            self.retriever = self.vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 5}
            )

        except Exception as e:
            # Chain the original cause instead of discarding it.
            raise RuntimeError(f"Failed to initialize vector store: {str(e)}") from e
    
    def ingest_document(self, text_content: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Split a document into chunks, embed them, and persist them.

        Args:
            text_content: Full text of the document.
            metadata: Document metadata (filename, document_type, etc.),
                copied onto every chunk.

        Returns:
            On success: {'status': 'success', 'chunks_created': int,
            'document_info': dict}. On failure: {'status': 'error',
            'error': str} — errors are reported, never raised.
        """
        try:
            self._initialize_components()

            source_doc = Document(page_content=text_content, metadata=metadata)

            # Chunk the document with the configured splitter.
            chunks = self.text_splitter.split_documents([source_doc])
            chunk_total = len(chunks)

            # Tag each chunk with its position so sources can be traced back.
            for position, chunk in enumerate(chunks):
                chunk.metadata['chunk_id'] = position
                chunk.metadata['total_chunks'] = chunk_total

            # Embed, store, and flush to disk.
            self.vectorstore.add_documents(chunks)
            self.vectorstore.persist()

            # Record the ingestion in the session log.
            doc_info = {
                'filename': metadata.get('filename', 'Unknown'),
                'document_type': metadata.get('document_type', 'Unknown'),
                'chunks_created': chunk_total,
                'ingestion_timestamp': metadata.get('ingestion_timestamp', 'Unknown')
            }
            self.ingested_documents.append(doc_info)

            return {
                'status': 'success',
                'chunks_created': chunk_total,
                'document_info': doc_info
            }

        except Exception as e:
            return {
                'status': 'error',
                'error': str(e)
            }
    
    def query(self, question: str, return_source_docs: bool = True) -> Dict[str, Any]:
        """
        Query the RAG system with a question
        
        Args:
            question: User's question
            return_source_docs: Whether to return source documents
            
        Returns:
            Dict with answer and source information
        """
        try:
            # Initialize components if needed
            self._initialize_components()
            
            if not self.vectorstore:
                return {
                    'status': 'error',
                    'error': 'No documents have been ingested yet. Please upload and process some PDFs first.'
                }
            
            # Create RAG chain if not exists
            if not self.qa_chain:
                self._setup_qa_chain()
            
            # Execute query
            result = self.qa_chain.invoke({
                "query": question,
                "return_source_documents": return_source_docs
            })
            
            # Format response
            response = {
                'status': 'success',
                'answer': result.get('result', ''),
                'question': question
            }
            
            # Add source documents if requested
            if return_source_docs and 'source_documents' in result:
                response['sources'] = []
                for doc in result['source_documents']:
                    response['sources'].append({
                        'content': doc.page_content[:200] + '...',  # Preview
                        'metadata': doc.metadata
                    })
            
            return response
            
        except Exception as e:
            return {
                'status': 'error',
                'error': f"Query failed: {str(e)}"
            }
    
    def _setup_qa_chain(self):
        """Build the RetrievalQA chain with a grounding prompt.

        The prompt instructs the model to answer only from retrieved
        context and to admit when the documents are insufficient.
        """
        # PromptTemplate expecting the retrieved {context} and the {question}.
        answer_prompt = PromptTemplate(
            input_variables=["context", "question"],
            template="""
        You are an AI assistant that answers questions based on the provided document context. 
        Use the following context to answer the question accurately and comprehensively.
        
        If the answer cannot be found in the context, say "I don't have enough information in the provided documents to answer this question."
        
        Context:
        {context}
        
        Question: {question}
        
        Answer:"""
        )

        # Wire LLM + retriever into a RetrievalQA chain that also surfaces
        # the chunks it answered from.
        self.qa_chain = RetrievalQA.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            prompt=answer_prompt,
            return_source_documents=True
        )
    
    def get_document_list(self) -> List[Dict[str, Any]]:
        """Get list of ingested documents"""
        return self.ingested_documents.copy()
    
    def get_vector_store_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        try:
            self._initialize_components()
            
            if not self.vectorstore:
                return {'total_chunks': 0, 'status': 'empty'}
            
            # Get collection info
            collection = self.vectorstore._collection
            stats = {
                'total_chunks': collection.count(),
                'total_documents': len(self.ingested_documents),
                'status': 'active'
            }
            
            return stats
            
        except Exception as e:
            return {
                'status': 'error',
                'error': str(e)
            }
    
    def clear_knowledge_base(self) -> Dict[str, Any]:
        """Clear all documents from the knowledge base"""
        try:
            # Delete vector store directory
            import shutil
            if os.path.exists(self.persist_directory):
                shutil.rmtree(self.persist_directory)
            
            # Reset components
            self.vectorstore = None
            self.qa_chain = None
            self.ingested_documents = []
            
            return {'status': 'success', 'message': 'Knowledge base cleared successfully'}
            
        except Exception as e:
            return {'status': 'error', 'error': str(e)}
    
    def search_similar_chunks(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar document chunks"""
        try:
            self._initialize_components()
            
            if not self.vectorstore:
                return []
            
            # Perform similarity search
            docs = self.vectorstore.similarity_search(query, k=k)
            
            results = []
            for doc in docs:
                results.append({
                    'content': doc.page_content,
                    'metadata': doc.metadata,
                    'preview': doc.page_content[:150] + '...'
                })
            
            return results
            
        except Exception as e:
            print(f"Search error: {e}")
            return []