jessica45 committed on
Commit
5f04d6e
·
verified ·
1 Parent(s): e52d28d

updated rag

Browse files
Files changed (5) hide show
  1. embeddings_qdrant.py +382 -0
  2. index_docs.py +101 -0
  3. main.py +196 -0
  4. rag_with_gemini.py +201 -0
  5. requirements.txt +6 -10
embeddings_qdrant.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import google.generativeai as genai
4
+ from dotenv import load_dotenv
5
+ from typing import List, Dict, Optional, Union
6
+ import json
7
+ import pickle
8
+ import uuid
9
+ from qdrant_client import QdrantClient
10
+ from qdrant_client.http import models
11
+ from qdrant_client.http.models import Distance, VectorParams, PointStruct
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
+
16
class EmbeddingManager:
    """Thin wrapper around the Gemini embedding API.

    Uses ``models/text-embedding-004`` to produce float32 vectors for
    document chunks and search queries. On any API failure an empty array
    is returned instead of raising, so callers check ``.size``.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Configure Gemini with the given key or GEMINI_API_KEY from the environment."""
        self.api_key = api_key or os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found in environment variables")

        genai.configure(api_key=self.api_key)
        self.model_name = "models/text-embedding-004"

    def _embed(self, content: str, task_type: str, label: str) -> np.ndarray:
        # Shared call path for document and query embeddings; `label`
        # only customizes the error message.
        try:
            response = genai.embed_content(
                model=self.model_name,
                content=content,
                task_type=task_type,
            )
            return np.array(response['embedding'], dtype=np.float32)
        except Exception as e:
            print(f"Error generating {label}embedding: {e}")
            return np.array([])

    def generate_embedding(self, text: str) -> np.ndarray:
        """Generate embedding for a single text (document task type)."""
        return self._embed(text, "retrieval_document", "")

    def generate_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Generate embeddings for multiple texts, skipping any that fail."""
        collected: List[np.ndarray] = []
        total = len(texts)
        for position, text in enumerate(texts, start=1):
            print(f"Generating embedding {position}/{total}")
            vector = self.generate_embedding(text)
            if vector.size > 0:
                collected.append(vector)
            else:
                print(f"Failed to generate embedding for text {position}")
        return collected

    def generate_query_embedding(self, query: str) -> np.ndarray:
        """Generate embedding for a query (search task type)."""
        return self._embed(query, "retrieval_query", "query ")
63
+
64
+
65
class QdrantVectorStore:
    """Vector store backed by Qdrant (local instance or Qdrant Cloud).

    Stores text chunks with their Gemini embeddings and supports cosine
    similarity search, optional neighbouring-chunk context, and a
    RAG-ready formatted output.
    """

    def __init__(self, collection_name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None):
        """Connect to Qdrant Cloud when URL + API key are configured, else localhost:6333."""
        self.collection_name = collection_name or os.getenv('QDRANT_COLLECTION_NAME', 'rag_documents')

        qdrant_url = url or os.getenv('QDRANT_URL')
        qdrant_api_key = api_key or os.getenv('QDRANT_API_KEY')

        if qdrant_url and qdrant_api_key:
            # Qdrant Cloud
            print(f"Connecting to Qdrant Cloud at {qdrant_url}")
            self.client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        else:
            # Local Qdrant (default)
            print("Using local Qdrant instance at http://localhost:6333")
            self.client = QdrantClient("localhost", port=6333)

        # Dimension of models/text-embedding-004 vectors.
        self.embedding_dim = 768

    def create_collection(self, force_recreate: bool = False):
        """Create the collection if missing; drop and recreate when force_recreate is True.

        Raises:
            Exception: re-raises any client error after logging it.
        """
        try:
            collections = self.client.get_collections().collections
            collection_exists = any(col.name == self.collection_name for col in collections)

            if collection_exists and force_recreate:
                print(f"Deleting existing collection: {self.collection_name}")
                self.client.delete_collection(collection_name=self.collection_name)
                collection_exists = False

            if not collection_exists:
                print(f"Creating collection: {self.collection_name}")
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.embedding_dim, distance=Distance.COSINE),
                )
                print(f"✓ Collection '{self.collection_name}' created successfully")
            else:
                print(f"✓ Collection '{self.collection_name}' already exists")

        except Exception as e:
            print(f"Error creating collection: {e}")
            raise

    def add_documents(self, chunks: List[str], embeddings: List[np.ndarray], metadata: Optional[List[Dict]] = None, session_id: Optional[str] = None):
        """Add documents with their embeddings to Qdrant.

        Args:
            chunks: list of text chunks
            embeddings: list of numpy embeddings corresponding to chunks
            metadata: optional list of dicts with metadata per chunk
            session_id: optional session identifier to attach to each point payload

        Raises:
            ValueError: if the three list lengths differ.
        """
        if metadata is None:
            metadata = [{"index": i} for i in range(len(chunks))]

        if len(chunks) != len(embeddings) or len(chunks) != len(metadata):
            raise ValueError("chunks, embeddings, and metadata must have the same length")

        # Guard: avoid a pointless upsert call with an empty point list.
        if not chunks:
            print("No documents to add")
            return

        # Ensure collection exists
        self.create_collection()

        points = []
        for chunk, embedding, meta in zip(chunks, embeddings, metadata):
            payload = {
                "text": chunk,
                "metadata": meta
            }
            # Attach session info if provided
            if session_id is not None:
                payload["session_id"] = session_id

            points.append(PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding.tolist(),
                payload=payload
            ))

        try:
            print(f"Uploading {len(points)} documents to Qdrant...")
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            print(f"✓ Successfully uploaded {len(points)} documents")
        except Exception as e:
            print(f"Error uploading documents: {e}")
            raise

    def similarity_search(self, query_embedding: np.ndarray, top_k: int = 5, score_threshold: float = 0.0,
                          include_context: bool = False) -> List[Dict]:
        """
        Search for similar documents in Qdrant.

        Args:
            query_embedding: The query vector
            top_k: Number of results to return
            score_threshold: Minimum similarity score
            include_context: If True, try to include surrounding chunks for context

        Returns:
            List of result dicts with id/similarity/chunk/metadata/source/citation
            (and 'context' when requested); [] on error.
        """
        try:
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=top_k,
                score_threshold=score_threshold
            )

            results = []
            for hit in search_results:
                # Tolerate points written without the expected payload shape.
                payload = hit.payload or {}
                metadata = payload.get('metadata', {})

                result = {
                    'id': hit.id,
                    'similarity': hit.score,
                    'chunk': payload.get('text', ''),
                    'metadata': metadata,
                    'source': {
                        'file_name': metadata.get('file_name', 'Unknown'),
                        'file_path': metadata.get('file_path', 'Unknown'),
                        'chunk_index': metadata.get('chunk_index', 0)
                    }
                }

                if include_context:
                    result['context'] = self._get_surrounding_context(metadata)

                result['citation'] = f"{metadata.get('file_name', 'Unknown')} (chunk {metadata.get('chunk_index', 0)})"

                results.append(result)

            return results

        except Exception as e:
            print(f"Error searching documents: {e}")
            return []

    def _get_surrounding_context(self, metadata: Dict) -> Dict:
        """Get surrounding chunks for context (if available).

        NOTE(review): this searches with a zero "dummy" vector and limit=10,
        so for files with many chunks the neighbours may not be in the result
        set — best-effort only.
        """
        try:
            file_path = metadata.get('file_path')
            chunk_index = metadata.get('chunk_index', 0)

            # Filter to chunks from the same source file.
            context_filter = {
                "must": [
                    {"key": "metadata.file_path", "match": {"value": file_path}}
                ]
            }

            context_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=[0.0] * self.embedding_dim,  # Dummy vector
                query_filter=context_filter,
                limit=10,
                score_threshold=0.0
            )

            # Collect (index, text) pairs and order them by position in file.
            file_chunks = []
            for hit in context_results:
                hit_metadata = (hit.payload or {}).get('metadata', {})
                if hit_metadata.get('chunk_index') is not None:
                    file_chunks.append({
                        'index': hit_metadata['chunk_index'],
                        'text': (hit.payload or {}).get('text', '')
                    })

            file_chunks.sort(key=lambda x: x['index'])

            # Locate the current chunk among its siblings.
            current_idx = None
            for i, chunk in enumerate(file_chunks):
                if chunk['index'] == chunk_index:
                    current_idx = i
                    break

            return {
                'previous_chunk': file_chunks[current_idx - 1]['text'] if current_idx and current_idx > 0 else None,
                'next_chunk': file_chunks[current_idx + 1]['text'] if current_idx is not None and current_idx < len(file_chunks) - 1 else None,
                'total_chunks_in_file': len(file_chunks)
            }

        except Exception as e:
            print(f"Error getting context: {e}")
            return {'error': 'Could not retrieve context'}

    def get_relevant_passages(self, query_embedding: np.ndarray, top_k: int = 5) -> List[str]:
        """Return just the text passages for RAG prompt creation."""
        results = self.similarity_search(query_embedding, top_k)
        return [result['chunk'] for result in results if result['chunk']]

    def enhanced_search(self, query_embedding: np.ndarray, top_k: int = 5) -> str:
        """Return a formatted string with search results ready for RAG."""
        results = self.similarity_search(query_embedding, top_k, include_context=True)

        if not results:
            return "No relevant documents found."

        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted_result = f"""
**Result {i}** (Similarity: {result['similarity']:.3f})
**Source**: {result['citation']}
**Content**: {result['chunk']}
"""
            # Add context if available
            if 'context' in result and not result['context'].get('error'):
                context = result['context']
                if context.get('previous_chunk'):
                    formatted_result += f"\n**Previous Context**: ...{context['previous_chunk'][-100:]}"
                if context.get('next_chunk'):
                    formatted_result += f"\n**Following Context**: {context['next_chunk'][:100]}..."

            formatted_results.append(formatted_result)

        # BUG FIX: the "=" divider was previously concatenated once before the
        # join ("\n" + "="*50 + "\n".join(...)), so only the first result got a
        # divider. Make the divider part of the separator between every result.
        separator = "\n" + "=" * 50 + "\n"
        return separator + separator.join(formatted_results)

    def get_collection_info(self) -> Dict:
        """Get information about the collection; {} on error."""
        try:
            info = self.client.get_collection(collection_name=self.collection_name)
            return {
                'name': self.collection_name,
                'points_count': info.points_count,
                # vectors_count is not exposed by some newer qdrant-client
                # versions; report None instead of crashing.
                'vectors_count': getattr(info, 'vectors_count', None),
                'status': info.status
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}

    def delete_collection(self):
        """Delete the collection, logging (not raising) on failure."""
        try:
            self.client.delete_collection(collection_name=self.collection_name)
            print(f"✓ Collection '{self.collection_name}' deleted")
        except Exception as e:
            print(f"Error deleting collection: {e}")
324
+
325
+
326
if __name__ == "__main__":
    # Smoke test: embed a few sample texts, store them in Qdrant, then run
    # a basic and an enhanced search against them.
    print("Testing Qdrant Vector Store...")

    try:
        embedding_manager = EmbeddingManager()
        qdrant_store = QdrantVectorStore()

        sample_texts = [
            "This is a sample document about machine learning and artificial intelligence.",
            "Python is a great programming language for data science and AI development.",
            "Qdrant is a vector database that enables similarity search at scale."
        ]

        print("Generating embeddings...")
        embeddings = embedding_manager.generate_embeddings_batch(sample_texts)

        if embeddings:
            metadata = [
                {"source": "sample_doc", "topic": "machine_learning", "index": 0},
                {"source": "sample_doc", "topic": "programming", "index": 1},
                {"source": "sample_doc", "topic": "database", "index": 2}
            ]

            qdrant_store.add_documents(sample_texts, embeddings, metadata)

            query = "What is vector database?"
            query_embedding = embedding_manager.generate_query_embedding(query)

            if query_embedding.size > 0:
                print(f"\n🔍 BASIC SEARCH: {query}")
                for result in qdrant_store.similarity_search(query_embedding, top_k=2):
                    print(f"Similarity: {result['similarity']:.4f}")
                    print(f"Source: {result['citation']}")
                    print(f"Text: {result['chunk']}")
                    print(f"Topic: {result['metadata']['topic']}")
                    print("---")

                print(f"\n🚀 ENHANCED SEARCH (RAG-ready format):")
                print(qdrant_store.enhanced_search(query_embedding, top_k=2))

            print(f"\nCollection Info: {qdrant_store.get_collection_info()}")

    except Exception as e:
        print(f"Error in test: {e}")
        print("Make sure:")
        print("1. Your GEMINI_API_KEY is valid in .env file")
        print("2. Qdrant is running (docker run -p 6333:6333 qdrant/qdrant) or configure Qdrant Cloud")
index_docs.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from docx import Document
3
+ try:
4
+ import fitz # PyMuPDF
5
+ except Exception:
6
+ # fall back to pymupdf module name if present
7
+ import pymupdf as fitz
8
+
9
def load_pdf_text(file_path: str) -> str:
    """Extract plain text from a PDF via PyMuPDF; return '' on any error."""
    try:
        document = fitz.open(file_path)
        parts = []
        # Iterate directly over pages.
        for page in document:
            try:
                # Standard PyMuPDF API.
                content = page.get_text()
            except Exception:
                # Older releases exposed getText() instead.
                content = page.getText() if hasattr(page, 'getText') else ''
            if content:
                parts.append(content + "\n")
        try:
            document.close()
        except Exception:
            pass
        return "".join(parts).strip()
    except Exception as e:
        print(f"Error reading PDF {file_path}: {e}")
        return ""
31
+
32
+
33
def load_docx_text(file_path: str) -> str:
    """Extract non-empty paragraph text from a .docx file; return '' on error."""
    try:
        parsed = Document(file_path)
        lines = [paragraph.text for paragraph in parsed.paragraphs if paragraph.text]
        return "\n".join(lines).strip()
    except Exception as e:
        print(f"Error reading DOCX {file_path}: {e}")
        return ""
41
+
42
+
43
def load_txt_text(file_path: str) -> str:
    """Read a UTF-8 text file; return '' if it cannot be read."""
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except Exception as e:
        print(f"Error reading TXT {file_path}: {e}")
        return ""
50
+
51
+
52
def extract_text_from_path(path: str) -> Optional[str]:
    """Dispatch on file extension to the right loader; None for unsupported types."""
    lowered = path.lower()
    if lowered.endswith('.pdf'):
        return load_pdf_text(path)
    if lowered.endswith('.docx'):
        return load_docx_text(path)
    if lowered.endswith('.txt'):
        return load_txt_text(path)
    return None
60
+
61
+
62
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 100) -> list:
    """Split *text* into chunks of up to ``chunk_size`` chars with ``overlap`` shared chars.

    Args:
        text: the input string (empty input yields []).
        chunk_size: maximum chunk length; must be positive.
        overlap: characters shared between consecutive chunks; values
            >= chunk_size are clamped (the original code looped forever).

    Returns:
        List of string chunks covering the whole input.

    Raises:
        ValueError: if chunk_size is not positive.
    """
    if chunk_size <= 0:
        # BUG FIX: previously this spun forever (step could be <= 0).
        raise ValueError("chunk_size must be positive")
    if not text:
        return []
    if overlap >= chunk_size or overlap < 0:
        # BUG FIX: a non-positive step made the original while-loop infinite.
        overlap = max(chunk_size // 5, 0)
    step = chunk_size - overlap
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
72
+
73
+
74
if __name__ == '__main__':
    import sys

    # Manual smoke test: extract text from one file and preview its chunks.
    if len(sys.argv) < 2:
        print('Usage: python src/index_docs.py <path-to-file-or-folder> [chunk_size]')
        sys.exit(1)

    path = sys.argv[1]
    chunk_size = int(sys.argv[2]) if len(sys.argv) > 2 else 500

    print(f'Testing extraction for: {path}')
    text = extract_text_from_path(path)
    if not text:
        print('No text extracted or unsupported file type.')
        sys.exit(1)

    print('Characters extracted:', len(text))
    chunks = chunk_text(text, chunk_size=chunk_size)
    print('Chunks produced:', len(chunks))
    if chunks:
        preview = 300
        print('\n--- First chunk preview ---')
        print(chunks[0][:preview])
        print('\n--- Second chunk preview ---')
        print(chunks[1][:preview] if len(chunks) > 1 else '<none>')
main.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import hashlib
5
+ from typing import List
6
+ from dotenv import load_dotenv
7
+ from rag_with_gemini import RAGSystem
8
+
9
+ # Load environment variables
10
+ load_dotenv()
11
+
12
+ # --- PAGE CONFIG ---
13
+ st.set_page_config(
14
+ page_title="RAG Document Assistant",
15
+ page_icon="🤖",
16
+ layout="wide",
17
+ initial_sidebar_state="expanded"
18
+ )
19
+
20
+ # --- SESSION STATE INIT ---
21
def initialize_session_state():
    """Seed every st.session_state key the app relies on, once per session."""
    defaults = {
        'rag_system': None,
        'documents_processed': [],
        # SHA256 hashes of processed files, to avoid reprocessing the same
        # file within a session.
        'processed_hashes': set(),
        'chat_history': [],
        'processing_status': "",
        'system_initialized': False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
+ st.session_state.system_initialized = False
35
+
36
+ # --- RAG SYSTEM INIT ---
37
def initialize_rag_system():
    """Build the RAGSystem once per session; returns True when it is ready."""
    if st.session_state.system_initialized:
        return True
    try:
        gemini_api_key = os.getenv('GEMINI_API_KEY')
        qdrant_url = os.getenv('QDRANT_URL')
        qdrant_api_key = os.getenv('QDRANT_API_KEY')

        if not (gemini_api_key and qdrant_url and qdrant_api_key):
            st.error("❌ Missing API keys in your .env file.")
            return False

        with st.spinner("🚀 Initializing RAG system..."):
            st.session_state.rag_system = RAGSystem(gemini_api_key, qdrant_url, qdrant_api_key)
            st.session_state.system_initialized = True
            return True
    except Exception as e:
        st.error(f"❌ Initialization error: {e}")
        return False
57
+
58
+ # --- DOCUMENT PROCESSING ---
59
def process_uploaded_files(uploaded_files):
    """Deduplicate uploads by content hash, write temp files, and index them.

    Args:
        uploaded_files: Streamlit UploadedFile objects from the sidebar.

    Returns:
        True when processing succeeded or there was nothing new to do;
        False on failure. A human-readable outcome is always written to
        st.session_state.processing_status.
    """
    if not uploaded_files or not st.session_state.rag_system:
        return False
    try:
        temp_paths = []
        to_process = []
        skipped = []

        # Hash file contents so re-uploads of identical bytes are skipped.
        for uploaded_file in uploaded_files:
            data = uploaded_file.getvalue()
            digest = hashlib.sha256(data).hexdigest()
            if digest in st.session_state.processed_hashes:
                skipped.append(uploaded_file.name)
                continue

            # BUG FIX: the old f".{name.split('.')[-1]}" produced a bogus
            # suffix (e.g. ".readme" for a dotless name); splitext yields the
            # real extension, which the indexer dispatches on.
            suffix = os.path.splitext(uploaded_file.name)[1] or ".bin"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(data)
                temp_paths.append(tmp.name)
            to_process.append((uploaded_file.name, digest))

        # Short-circuit when every upload was a duplicate.
        if not temp_paths:
            st.session_state.processing_status = (
                f"⚠️ No new files to process. Skipped: {', '.join(skipped)}"
                if skipped else "⚠️ No files provided."
            )
            return True

        try:
            with st.spinner("📄 Processing documents..."):
                success = st.session_state.rag_system.add_documents(temp_paths)
        finally:
            # BUG FIX: cleanup used to run only on the non-raising path,
            # leaking temp files; and the bare `except:` also swallowed
            # KeyboardInterrupt. Always delete, catching only OS errors.
            for path in temp_paths:
                try:
                    os.unlink(path)
                except OSError:
                    pass

        if success:
            # Record processed filenames and their hashes.
            for name, digest in to_process:
                st.session_state.documents_processed.append(name)
                st.session_state.processed_hashes.add(digest)

            status_msg = f"✅ Processed {len(to_process)} documents!"
            if skipped:
                status_msg += f" Skipped {len(skipped)} duplicate(s): {', '.join(skipped)}"
            st.session_state.processing_status = status_msg
            return True
        st.session_state.processing_status = "❌ Failed to process documents."
        return False
    except Exception as e:
        st.session_state.processing_status = f"❌ Error: {str(e)}"
        return False
113
+
114
+ # --- CHAT DISPLAY ---
115
def display_chat_message(role: str, content: str, sources: List[str] = None):
    """Render one chat bubble with a role-specific avatar.

    `sources` is accepted for call-site compatibility but is not rendered here.
    """
    if role == "assistant":
        avatar_url = "https://cdn-icons-png.flaticon.com/512/4712/4712035.png"
    else:
        avatar_url = "https://cdn-icons-png.flaticon.com/512/1077/1077012.png"
    with st.chat_message(role, avatar=avatar_url):
        st.markdown(content)
123
+
124
+ # --- MAIN ---
125
def main():
    """Streamlit entry point: sidebar handles uploads, main pane hosts the chat."""
    initialize_session_state()
    st.markdown('<h1 class="main-header">RAG Document Assistant</h1>', unsafe_allow_html=True)

    if not initialize_rag_system():
        st.stop()

    # --- Sidebar: uploads, status, processed-file list, chat reset ---
    with st.sidebar:
        st.markdown("### 📁 Upload Documents")
        uploaded_files = st.file_uploader("Choose files", type=['pdf', 'txt', 'docx'], accept_multiple_files=True)
        if uploaded_files and st.button("📤 Process Documents"):
            if process_uploaded_files(uploaded_files):
                st.rerun()

        if st.session_state.processing_status:
            msg = st.session_state.processing_status
            cls = "success-message" if "✅" in msg else "error-message"
            st.markdown(f'<div class="{cls}">{msg}</div>', unsafe_allow_html=True)

        if st.session_state.documents_processed:
            st.markdown("### ✅ Processed Files")
            for doc in st.session_state.documents_processed:
                st.write(f"- {doc}")

        if st.button("🗑️ Clear Chat"):
            st.session_state.chat_history = []
            st.rerun()

    # Welcome banner on a completely fresh session.
    if not st.session_state.chat_history and not st.session_state.documents_processed:
        st.markdown("""
        <div style="text-align:center; padding:3rem; color:#9ca3af;">
        <h3>👋 Welcome to your RAG Assistant</h3>
        <p>Upload documents in the sidebar, then ask me anything about their content.</p>
        </div>
        """, unsafe_allow_html=True)

    # Replay the conversation so far.
    for message in st.session_state.chat_history:
        display_chat_message(message["role"], message["content"], message.get("sources", []))

    # New user turn.
    if prompt := st.chat_input("💬 Ask me anything..."):
        if not st.session_state.documents_processed:
            st.warning("⚠️ Upload and process documents first!")
            return

        st.session_state.chat_history.append({"role": "user", "content": prompt})
        display_chat_message("user", prompt)

        with st.chat_message("assistant"):
            with st.spinner("🤔 Thinking..."):
                try:
                    result = st.session_state.rag_system.query(prompt)
                    st.markdown(result['answer'])
                    st.session_state.chat_history.append({
                        "role": "assistant",
                        "content": result['answer'],
                        "sources": result['sources']
                    })
                except Exception as e:
                    error_msg = f"❌ Error: {str(e)}"
                    st.error(error_msg)
                    st.session_state.chat_history.append({"role": "assistant", "content": error_msg})


if __name__ == "__main__":
    main()
rag_with_gemini.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG (Retrieval-Augmented Generation) system with Gemini
3
+ """
4
+
5
+ import google.generativeai as genai
6
+ import time
7
+ import logging
8
+ from typing import List, Dict, Any, Optional
9
+ from embeddings_qdrant import EmbeddingManager, QdrantVectorStore
10
+ from index_docs import extract_text_from_path, chunk_text
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
class RAGSystem:
    """Complete RAG pipeline: Gemini for generation, Qdrant for retrieval.

    Lifecycle: construct (recreates the Qdrant collection), call
    add_documents() to index files, then query() for retrieval-augmented
    answers.
    """

    def __init__(self, gemini_api_key: str, qdrant_url: str, qdrant_api_key: str):
        """Configure Gemini and connect to Qdrant.

        Raises:
            RuntimeError: if the Qdrant connection or collection setup fails.
        """
        # Configure Gemini
        genai.configure(api_key=gemini_api_key)
        self.model = genai.GenerativeModel("models/gemini-2.5-flash")

        # Initialize components
        self.embedding_manager = EmbeddingManager(gemini_api_key)

        try:
            self.vector_store = QdrantVectorStore(url=qdrant_url, api_key=qdrant_api_key)
            self.vector_store.create_collection(force_recreate=True)
            logger.info("✅ Connected to Qdrant Cloud")
            self.using_qdrant = True
        except Exception as e:
            # BUG FIX: the old "fallback" branch called
            # self.vector_store.create_collection() even though
            # self.vector_store was never assigned when the constructor above
            # raised, guaranteeing an AttributeError. No simple/local fallback
            # store exists in this project, so fail fast with a clear error.
            logger.error(f"❌ Qdrant Cloud connection failed: {e}")
            self.using_qdrant = False
            raise RuntimeError(
                "Could not connect to Qdrant; check QDRANT_URL/QDRANT_API_KEY"
            ) from e

    def add_documents(self, file_paths: List[str], session_id: Optional[str] = None) -> bool:
        """Extract, chunk, embed, and store the given files in the vector store.

        Args:
            file_paths: paths to .pdf/.docx/.txt files.
            session_id: optional session identifier stored with each point.

        Returns:
            True only when at least one chunk was embedded and stored.
        """
        try:
            all_chunks = []

            for file_path in file_paths:
                logger.info(f"Processing {file_path}")

                text = extract_text_from_path(file_path)
                if not text:
                    logger.warning(f"No text extracted from {file_path}")
                    continue

                for chunk in chunk_text(text):
                    all_chunks.append({
                        'text': chunk,
                        'source': file_path,
                        'chunk_id': len(all_chunks)  # running index across all files
                    })

            if not all_chunks:
                logger.error("No chunks to process")
                return False

            logger.info(f"Generating embeddings for {len(all_chunks)} chunks")

            embeddings = []
            texts = []
            metadata_list = []

            for i, chunk in enumerate(all_chunks):
                try:
                    embedding = self.embedding_manager.generate_embedding(chunk['text'])
                    # BUG FIX: generate_embedding signals failure with an empty
                    # array; previously these were appended anyway and would
                    # break the Qdrant upsert.
                    if embedding.size == 0:
                        logger.error(f"Empty embedding for chunk {i}; skipping")
                        continue

                    embeddings.append(embedding)
                    texts.append(chunk['text'])
                    metadata_list.append({
                        'source': chunk['source'],
                        'chunk_id': chunk['chunk_id']
                    })

                    logger.info(f"Generated embedding {i+1}/{len(all_chunks)}")

                    # Small delay to avoid rate limits
                    time.sleep(0.1)

                except Exception as e:
                    logger.error(f"Error processing chunk {i}: {e}")
                    continue

            if not embeddings:
                # BUG FIX: previously returned True even when nothing was stored.
                logger.error("No embeddings generated; nothing stored")
                return False

            logger.info(f"Storing {len(embeddings)} embeddings in vector database (session={session_id})")
            # Forward session_id so it is stored with each point
            self.vector_store.add_documents(texts, embeddings, metadata_list, session_id=session_id)

            logger.info("Document processing completed successfully!")
            return True

        except Exception as e:
            logger.error(f"Error adding documents: {e}")
            return False

    def make_rag_prompt(self, query: str, context_passages: List[str]) -> str:
        """Build the grounded-answer prompt from the retrieved passages."""
        context = "\n\n".join([f"Context {i+1}: {passage}" for i, passage in enumerate(context_passages)])

        prompt = f"""You are a helpful assistant. Answer the user's question based on the provided context. If the context doesn't contain enough information to answer the question, say so clearly.

Context:
{context}

Question: {query}

Answer:"""

        return prompt

    def generate_answer(self, prompt: str, max_retries: int = 3) -> str:
        """Generate an answer with Gemini, retrying with exponential backoff on 429s.

        Never raises: every failure path returns an apologetic string instead.
        """
        for attempt in range(max_retries):
            try:
                response = self.model.generate_content(prompt)

                if response and response.text:
                    return response.text.strip()
                logger.warning(f"Empty response on attempt {attempt + 1}")

            except Exception as e:
                logger.error(f"Error generating answer (attempt {attempt + 1}): {e}")

                if "429" in str(e) or "quota" in str(e).lower():
                    if attempt < max_retries - 1:
                        wait_time = (2 ** attempt) * 2  # Exponential backoff
                        logger.info(f"Rate limit hit, waiting {wait_time} seconds...")
                        time.sleep(wait_time)
                    else:
                        return "I'm sorry, I'm currently experiencing high demand. Please try again in a few minutes."
                elif attempt < max_retries - 1:
                    time.sleep(1)
                else:
                    return f"I encountered an error while generating the answer: {str(e)}"

        return "I'm sorry, I couldn't generate an answer at this time. Please try again."

    def query(self, question: str, top_k: int = 3) -> Dict[str, Any]:
        """Run the full RAG loop: embed, retrieve, prompt, generate.

        Returns:
            Dict with 'answer' (str), 'sources' (deduplicated list of file
            paths) and 'context_used' (list of retrieved passages).
        """
        try:
            logger.info(f"Processing query: {question}")

            # BUG FIX: use the query-specific embedding (task_type
            # "retrieval_query"); the old code embedded the question with the
            # document task type, which degrades retrieval quality.
            query_embedding = self.embedding_manager.generate_query_embedding(question)
            if query_embedding.size == 0:
                return {
                    'answer': "I couldn't find relevant information to answer your question.",
                    'sources': [],
                    'context_used': []
                }

            search_results = self.vector_store.similarity_search(
                query_embedding=query_embedding,
                top_k=top_k
            )

            if not search_results:
                return {
                    'answer': "I couldn't find relevant information to answer your question.",
                    'sources': [],
                    'context_used': []
                }

            context_passages = [result.get('chunk', '') for result in search_results]
            sources = [result.get('metadata', {}).get('source', 'Unknown') for result in search_results]

            rag_prompt = self.make_rag_prompt(question, context_passages)
            answer = self.generate_answer(rag_prompt)

            return {
                'answer': answer,
                'sources': list(set(sources)),  # Remove duplicates
                'context_used': context_passages
            }

        except Exception as e:
            logger.error(f"Error in query processing: {e}")
            return {
                'answer': f"I encountered an error while processing your question: {str(e)}",
                'sources': [],
                'context_used': []
            }
+ }
198
+
199
+ def handle_query(rag_system: RAGSystem, query: str) -> Dict[str, Any]:
200
+ """Handle a single query through the RAG system"""
201
+ return rag_system.query(query)
requirements.txt CHANGED
@@ -1,12 +1,8 @@
1
  google-generativeai>=0.3.0
2
- chromadb>=0.4.0
3
- pdfplumber
4
- pip<24.1
5
  python-docx>=0.8.11
6
- langchain>=0.1.0
7
- numpy>=1.21.0
8
- python-dotenv>=0.19.0
9
- streamlit>=1.18.0
10
- typing-extensions>=3.7.4
11
- tika
12
- pymupdf
 
1
  google-generativeai>=0.3.0
2
+ python-dotenv>=1.0.0
3
+ numpy>=1.24.0
4
+ pymupdf>=1.23.0
5
  python-docx>=0.8.11
6
+ qdrant-client>=1.7.0
7
+ streamlit>=1.28.0
8
+ scikit-learn>=1.3.0