andrewammann committed on
Commit
24c1b63
·
verified ·
1 Parent(s): 06c5826

Create rag_system

Browse files
Files changed (1) hide show
  1. rag_system +309 -0
rag_system ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ from typing import List, Dict, Any, Optional
5
+ from pathlib import Path
6
+
7
+ # LangChain imports for RAG
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain_community.vectorstores import Chroma
10
+ from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
11
+ from langchain.chains import RetrievalQA
12
+ from langchain.prompts import PromptTemplate
13
+ from langchain.schema import Document
14
+
15
+ # Google Gemini imports
16
+ from google import genai
17
+
18
class RAGSystem:
    """
    Complete RAG (Retrieval-Augmented Generation) system using Google Gemini.

    Handles document ingestion, chunking, embedding, and question answering
    over a persistent ChromaDB vector store. All Gemini-backed components are
    created lazily so that constructing the class does not require an API key.
    """

    def __init__(self, persist_directory: str = "./chroma_db"):
        """Initialize the RAG system with Google Gemini and ChromaDB.

        Args:
            persist_directory: Directory where ChromaDB persists its index.
        """
        self.persist_directory = persist_directory
        self.gemini_api_key: Optional[str] = None

        # Initialize components (lazy loading — see _initialize_components)
        self.embeddings = None
        self.llm = None
        self.vectorstore = None
        self.retriever = None
        self.qa_chain = None

        # Text splitter for document chunking
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
            separators=["\n\n", "\n", " ", ""]
        )

        # Track ingested documents (metadata summaries, not full text)
        self.ingested_documents: List[Dict[str, Any]] = []

    def _initialize_components(self):
        """Lazily initialize the Gemini LLM, embeddings, and vector store.

        Raises:
            ValueError: If the GEMINI_API_KEY environment variable is unset.
        """
        if self.llm is None:
            self.gemini_api_key = os.getenv('GEMINI_API_KEY')
            if not self.gemini_api_key:
                raise ValueError("GEMINI_API_KEY environment variable must be set")

            # Initialize Google Gemini LLM (low temperature for factual answers)
            self.llm = ChatGoogleGenerativeAI(
                model="gemini-2.5-flash",
                temperature=0.1,
                max_tokens=2048,
                google_api_key=self.gemini_api_key
            )

            # Initialize Google embeddings
            self.embeddings = GoogleGenerativeAIEmbeddings(
                model="models/text-embedding-004",
                google_api_key=self.gemini_api_key
            )

            # Initialize or load existing vector store
            self._initialize_vectorstore()

    def _initialize_vectorstore(self):
        """Initialize the ChromaDB vector store and its retriever.

        Chroma loads an existing index from ``persist_directory`` when one is
        present and creates a new empty collection otherwise, so a single
        constructor call covers both cases.

        Raises:
            Exception: If the vector store cannot be created or loaded.
        """
        try:
            self.vectorstore = Chroma(
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings
            )

            # Set up retriever over the store
            self.retriever = self.vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 5}  # Retrieve top 5 most similar chunks
            )

        except Exception as e:
            raise Exception(f"Failed to initialize vector store: {str(e)}")

    def ingest_document(self, text_content: str, metadata: Dict[str, Any]) -> Dict[str, Any]:
        """
        Ingest a document into the RAG system.

        Args:
            text_content: The full text content of the document
            metadata: Document metadata (filename, type, etc.)

        Returns:
            Dict with ingestion results: on success, ``status``,
            ``chunks_created`` and ``document_info``; on failure, ``status``
            and ``error``.
        """
        try:
            # Initialize components if needed
            self._initialize_components()

            # Create document object
            document = Document(
                page_content=text_content,
                metadata=metadata
            )

            # Split document into chunks
            chunks = self.text_splitter.split_documents([document])

            # Add chunk numbers to metadata so sources can be traced back
            for i, chunk in enumerate(chunks):
                chunk.metadata.update({
                    'chunk_id': i,
                    'total_chunks': len(chunks)
                })

            # Add chunks to vector store
            self.vectorstore.add_documents(chunks)

            # Persist the changes. Chroma >= 0.4 persists automatically and
            # no longer exposes persist(); the guard keeps both versions working.
            if hasattr(self.vectorstore, "persist"):
                self.vectorstore.persist()

            # Track ingested document
            doc_info = {
                'filename': metadata.get('filename', 'Unknown'),
                'document_type': metadata.get('document_type', 'Unknown'),
                'chunks_created': len(chunks),
                'ingestion_timestamp': metadata.get('ingestion_timestamp', 'Unknown')
            }

            self.ingested_documents.append(doc_info)

            return {
                'status': 'success',
                'chunks_created': len(chunks),
                'document_info': doc_info
            }

        except Exception as e:
            return {
                'status': 'error',
                'error': str(e)
            }

    def query(self, question: str, return_source_docs: bool = True) -> Dict[str, Any]:
        """
        Query the RAG system with a question.

        Args:
            question: User's question
            return_source_docs: Whether to return source documents

        Returns:
            Dict with answer and source information.
        """
        try:
            # Initialize components if needed
            self._initialize_components()

            if not self.vectorstore:
                return {
                    'status': 'error',
                    'error': 'No documents have been ingested yet. Please upload and process some PDFs first.'
                }

            # Create RAG chain if not exists
            if not self.qa_chain:
                self._setup_qa_chain()

            # Execute query. RetrievalQA accepts only the "query" input key;
            # source-document return is configured on the chain itself
            # (return_source_documents=True in _setup_qa_chain). Passing extra
            # keys into invoke() raises a ValueError, which is why the
            # original version failed on every call.
            result = self.qa_chain.invoke({"query": question})

            # Format response
            response = {
                'status': 'success',
                'answer': result.get('result', ''),
                'question': question
            }

            # Add source documents if requested
            if return_source_docs and 'source_documents' in result:
                response['sources'] = []
                for doc in result['source_documents']:
                    response['sources'].append({
                        'content': doc.page_content[:200] + '...',  # Preview
                        'metadata': doc.metadata
                    })

            return response

        except Exception as e:
            return {
                'status': 'error',
                'error': f"Query failed: {str(e)}"
            }

    def _setup_qa_chain(self):
        """Set up the question-answering chain with a custom prompt."""

        # Custom prompt template for better responses
        prompt_template = """
You are an AI assistant that answers questions based on the provided document context.
Use the following context to answer the question accurately and comprehensively.

If the answer cannot be found in the context, say "I don't have enough information in the provided documents to answer this question."

Context:
{context}

Question: {question}

Answer:"""

        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )

        # Create RetrievalQA chain; source documents are attached to every
        # result so query() can surface them to the caller.
        self.qa_chain = RetrievalQA.from_llm(
            llm=self.llm,
            retriever=self.retriever,
            prompt=prompt,
            return_source_documents=True
        )

    def get_document_list(self) -> List[Dict[str, Any]]:
        """Return a copy of the ingested-document summaries (safe to mutate)."""
        return self.ingested_documents.copy()

    def get_vector_store_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store.

        Returns:
            Dict with ``total_chunks``/``total_documents``/``status`` on
            success, or ``status='error'`` with an ``error`` message.
        """
        try:
            self._initialize_components()

            if not self.vectorstore:
                return {'total_chunks': 0, 'status': 'empty'}

            # NOTE: _collection is a private langchain/Chroma attribute; there
            # is no public chunk-count API on this wrapper version.
            collection = self.vectorstore._collection
            stats = {
                'total_chunks': collection.count(),
                'total_documents': len(self.ingested_documents),
                'status': 'active'
            }

            return stats

        except Exception as e:
            return {
                'status': 'error',
                'error': str(e)
            }

    def clear_knowledge_base(self) -> Dict[str, Any]:
        """Clear all documents from the knowledge base.

        Deletes the on-disk Chroma directory and resets every component that
        referenced the old store so the next query rebuilds from scratch.
        """
        try:
            # Delete vector store directory
            import shutil
            if os.path.exists(self.persist_directory):
                shutil.rmtree(self.persist_directory)

            # Reset components. The retriever must be dropped too — it holds a
            # reference to the deleted store (the original left it stale).
            self.vectorstore = None
            self.retriever = None
            self.qa_chain = None
            self.ingested_documents = []

            return {'status': 'success', 'message': 'Knowledge base cleared successfully'}

        except Exception as e:
            return {'status': 'error', 'error': str(e)}

    def search_similar_chunks(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search for similar document chunks.

        Args:
            query: Free-text query to embed and match.
            k: Number of chunks to return.

        Returns:
            List of dicts with ``content``, ``metadata`` and a short
            ``preview``; empty list when the store is missing or on error.
        """
        try:
            self._initialize_components()

            if not self.vectorstore:
                return []

            # Perform similarity search
            docs = self.vectorstore.similarity_search(query, k=k)

            results = []
            for doc in docs:
                results.append({
                    'content': doc.page_content,
                    'metadata': doc.metadata,
                    'preview': doc.page_content[:150] + '...'
                })

            return results

        except Exception as e:
            print(f"Search error: {e}")
            return []