Hamza4100 committed on
Commit
d1646a9
·
verified ·
1 Parent(s): 3f2f81d

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +794 -783
rag_engine.py CHANGED
@@ -1,784 +1,795 @@
1
- """
2
- RAG Engine Module
3
- =================
4
- Handles all RAG pipeline operations:
5
- - PDF text extraction
6
- - Text chunking with overlap
7
- - Embedding generation using SentenceTransformers
8
- - FAISS vector storage and retrieval
9
- - Metadata and document registry management
10
- - Persistence of embeddings and metadata
11
- """
12
-
13
- import os
14
- import json
15
- import hashlib
16
- from datetime import datetime
17
- from typing import List, Dict, Tuple, Optional
18
- import numpy as np
19
- import faiss
20
- from sentence_transformers import SentenceTransformer
21
- import PyPDF2
22
- import google.generativeai as genai
23
- from PIL import Image
24
- import io
25
-
26
- # OCR imports (optional)
27
- try:
28
- import pytesseract
29
-
30
- OCR_AVAILABLE = True
31
- except ImportError:
32
- OCR_AVAILABLE = False
33
- print("Warning: pytesseract not installed. OCR functionality will be disabled.")
34
-
35
# ============================================
# CONFIGURATION
# ============================================

# Chunking parameters (sizes are counted in whitespace-separated words).
DEFAULT_CHUNK_SIZE = 200  # words per chunk
DEFAULT_OVERLAP_SIZE = 50  # words shared by two consecutive chunks

# Retrieval parameters
DEFAULT_TOP_K = 5  # number of chunks to retrieve per query

# Embedding model
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Dimension passed to faiss.IndexFlatL2 whenever an index is created.
# NOTE(review): presumably the output width of EMBEDDING_MODEL_NAME - confirm.
EMBEDDING_DIMENSION = 384
49
-
50
-
51
-
52
- class RAGEngine:
53
- """
54
- Main RAG Engine class that handles:
55
- - Document processing and embedding
56
- - FAISS index management
57
- - Query processing and answer generation
58
- - Persistence of all data
59
- """
60
-
61
def __init__(self, gemini_api_key: str, storage_dir: Optional[str] = None):
    """
    Set up storage paths, models, and in-memory state, then load saved data.

    Args:
        gemini_api_key: API key handed to ``genai.configure``.
        storage_dir: Directory holding the index and JSON files. When
            omitted, a ``storage`` folder next to this module is used.
    """
    # Resolve where everything lives on disk.
    base_dir = (
        storage_dir
        if storage_dir is not None
        else os.path.join(os.path.dirname(__file__), "storage")
    )
    self.storage_dir = base_dir
    self.faiss_index_path = os.path.join(base_dir, "faiss.index")
    self.metadata_path = os.path.join(base_dir, "metadata.json")
    self.documents_path = os.path.join(base_dir, "documents.json")
    os.makedirs(base_dir, exist_ok=True)

    # Sentence embedding model.
    print("Loading embedding model...")
    self.embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

    # Gemini client used for answer generation and source verification.
    genai.configure(api_key=gemini_api_key)
    self.gemini_model = genai.GenerativeModel("gemini-2.5-flash")

    # In-memory state; overwritten by whatever is found on disk.
    self.index: Optional[faiss.IndexFlatL2] = None
    self.metadata: List[Dict] = []  # one entry per chunk: text, source, page
    self.documents: Dict[str, Dict] = {}  # document registry keyed by doc_id
    self._load_persistent_data()

    doc_count, chunk_count = len(self.documents), len(self.metadata)
    print(f"RAG Engine initialized. Documents: {doc_count}, Chunks: {chunk_count}")
98
-
99
- # ============================================
100
- # PERSISTENCE METHODS
101
- # ============================================
102
-
103
def _load_persistent_data(self):
    """
    Load the document registry, chunk metadata, and FAISS index from disk.

    Files that do not exist are skipped, leaving the matching attribute
    untouched. The on-disk index is only trusted when it holds exactly one
    vector per metadata entry and has the configured dimension: search hits
    are mapped to metadata by position, so an out-of-sync index would
    attribute text to the wrong chunk. When the index is missing or out of
    sync it is rebuilt by re-embedding the stored chunk texts.
    """
    # Document registry
    if os.path.exists(self.documents_path):
        with open(self.documents_path, "r", encoding="utf-8") as f:
            self.documents = json.load(f)
        print(f"Loaded {len(self.documents)} documents from registry")

    # Chunk metadata
    if os.path.exists(self.metadata_path):
        with open(self.metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        print(f"Loaded {len(self.metadata)} chunks metadata")

    # Nothing indexed yet: start from an empty index.
    if not self.metadata:
        self.index = faiss.IndexFlatL2(EMBEDDING_DIMENSION)
        print("Created new FAISS index")
        return

    loaded = None
    if os.path.exists(self.faiss_index_path):
        loaded = faiss.read_index(self.faiss_index_path)

    in_sync = (
        loaded is not None
        and loaded.ntotal == len(self.metadata)
        and loaded.d == EMBEDDING_DIMENSION
    )
    if in_sync:
        self.index = loaded
        print(f"Loaded FAISS index with {self.index.ntotal} vectors")
        return

    # Metadata exists but the vectors are missing or do not line up with it.
    print("FAISS index missing or out of sync with metadata; rebuilding")
    self.index = faiss.IndexFlatL2(EMBEDDING_DIMENSION)
    self.index.add(self.generate_embeddings([m["text"] for m in self.metadata]))
    print(f"Rebuilt FAISS index with {self.index.ntotal} vectors")
126
-
127
- def _save_persistent_data(self):
128
- """Save FAISS index, metadata, and document registry to disk."""
129
-
130
- # Save document registry
131
- with open(self.documents_path, "w", encoding="utf-8") as f:
132
- json.dump(self.documents, f, indent=2, ensure_ascii=False)
133
-
134
- # Save metadata
135
- with open(self.metadata_path, "w", encoding="utf-8") as f:
136
- json.dump(self.metadata, f, indent=2, ensure_ascii=False)
137
-
138
- # Save FAISS index
139
- if self.index is not None and self.index.ntotal > 0:
140
- faiss.write_index(self.index, self.faiss_index_path)
141
-
142
- print("Persistent data saved successfully")
143
-
144
- # ============================================
145
- # DOCUMENT PROCESSING METHODS
146
- # ============================================
147
-
148
@staticmethod
def compute_file_hash(file_content: bytes) -> str:
    """
    Return the SHA-256 digest of a file's bytes.

    Args:
        file_content: Raw bytes of the file

    Returns:
        Digest as a lowercase hexadecimal string
    """
    hasher = hashlib.sha256()
    hasher.update(file_content)
    return hasher.hexdigest()
160
-
161
@staticmethod
def chunk_text_with_overlap(text: str, chunk_size: int = DEFAULT_CHUNK_SIZE,
                            overlap_size: int = DEFAULT_OVERLAP_SIZE) -> List[str]:
    """
    Split text into overlapping chunks of whitespace-separated words.

    Args:
        text: Input text to chunk
        chunk_size: Number of words per chunk (must be positive)
        overlap_size: Number of words shared by consecutive chunks
            (must satisfy ``0 <= overlap_size < chunk_size``)

    Returns:
        List of text chunks; empty when the text contains no words

    Raises:
        ValueError: If the sizes would make the window stand still or move
            backwards, which previously caused an endless loop.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive number of words")
    if not 0 <= overlap_size < chunk_size:
        raise ValueError("overlap_size must satisfy 0 <= overlap_size < chunk_size")

    words = text.split()
    step = chunk_size - overlap_size
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # Once a chunk reaches the last word, any further window would only
        # repeat words already contained in this chunk.
        if end >= len(words):
            break
    return chunks
187
-
188
@staticmethod
def extract_text_from_image(image: Image.Image) -> str:
    """
    Run OCR on an image and return the recognized text.

    Args:
        image: PIL Image object

    Returns:
        Stripped OCR text, or an empty string when OCR is unavailable
        or fails
    """
    if not OCR_AVAILABLE:
        return ""

    try:
        # The image is handed to Tesseract in RGB mode.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        return pytesseract.image_to_string(rgb_image, lang='eng').strip()
    except Exception as ocr_error:
        print(f"OCR error: {ocr_error}")
    return ""
213
-
214
def _ocr_page_images(self, page) -> str:
    """Return the OCR text of every image XObject on *page*, one image per line."""
    if '/Resources' not in page:
        return ""
    resources = page['/Resources']
    if '/XObject' not in resources:
        return ""
    x_objects = resources['/XObject'].get_object()

    recognized = []
    for name in x_objects:
        try:
            candidate = x_objects[name]
            if '/Subtype' not in candidate or candidate['/Subtype'] != '/Image':
                continue

            size = (candidate['/Width'], candidate['/Height'])
            data = candidate.get_data()

            # Only DeviceGray maps to "L"; every other (or missing) colour
            # space is tried as RGB.
            color_space = candidate['/ColorSpace'] if '/ColorSpace' in candidate else None
            mode = "L" if color_space == '/DeviceGray' else "RGB"

            try:
                image = Image.frombytes(mode, size, data)
            except Exception:
                # Not raw pixel data; let PIL sniff an encoded image instead.
                image = Image.open(io.BytesIO(data))

            img_text = self.extract_text_from_image(image)
            if img_text:
                recognized.append(img_text)
        except Exception as image_error:
            # Best effort: one unreadable image must not block the others.
            print(f"Skipping image {name}: {image_error}")
    return "\n".join(recognized)

def extract_text_from_pdf(self, pdf_content: bytes) -> List[Dict]:
    """
    Extract text from a PDF page by page, including OCR text from images.

    Args:
        pdf_content: Raw bytes of PDF file

    Returns:
        List of dicts with ``page_num`` (1-based), ``text`` and ``has_ocr``.
        Pages that yield no text at all are omitted.

    Raises:
        Exception: Any error raised while opening or reading the PDF is
            logged and re-raised.
    """
    pages = []

    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
        for page_num, page in enumerate(reader.pages, start=1):
            text = (page.extract_text() or "").strip()

            ocr_text = ""
            if OCR_AVAILABLE:
                try:
                    ocr_text = self._ocr_page_images(page).strip()
                except Exception as e:
                    print(f"Error extracting images from page {page_num}: {e}")

            # Combine regular text and OCR text
            if text and ocr_text:
                combined_text = f"{text}\n\n[Text from images:]\n{ocr_text}"
            else:
                combined_text = text or ocr_text

            if combined_text:
                pages.append({
                    "page_num": page_num,
                    "text": combined_text,
                    "has_ocr": bool(ocr_text),
                })
    except Exception as e:
        print(f"Error extracting PDF text: {e}")
        raise

    return pages
297
-
298
def process_pdf(self, filename: str, file_content: bytes,
                chunk_size: int = DEFAULT_CHUNK_SIZE,
                overlap_size: int = DEFAULT_OVERLAP_SIZE) -> List[Dict]:
    """
    Turn a PDF into a flat list of chunk records ready for indexing.

    Args:
        filename: Original filename, stored as each chunk's ``source``
        file_content: Raw bytes of PDF
        chunk_size: Words per chunk
        overlap_size: Overlap between chunks

    Returns:
        One dict per chunk with ``text``, ``source``, ``page`` and
        ``has_ocr``, in page order
    """
    return [
        {
            "text": piece,
            "source": filename,
            "page": page["page_num"],
            "has_ocr": page.get("has_ocr", False),
        }
        for page in self.extract_text_from_pdf(file_content)
        for piece in self.chunk_text_with_overlap(page["text"], chunk_size, overlap_size)
    ]
333
-
334
- # ============================================
335
- # DUPLICATE DETECTION METHODS
336
- # ============================================
337
-
338
def check_duplicate(self, file_hash: str) -> Optional[Dict]:
    """
    Look up a registered document by content hash.

    Args:
        file_hash: SHA-256 hash of the file

    Returns:
        The first matching registry entry with its ``doc_id`` added,
        or None when no document has this hash
    """
    match = next(
        (
            (doc_id, info)
            for doc_id, info in self.documents.items()
            if info.get("hash") == file_hash
        ),
        None,
    )
    if match is None:
        return None
    found_id, found_info = match
    return {"doc_id": found_id, **found_info}
352
-
353
def get_document_by_filename(self, filename: str) -> Optional[Dict]:
    """
    Look up a registered document by its original filename.

    Args:
        filename: Original filename

    Returns:
        The first matching registry entry with its ``doc_id`` added,
        or None when no document has this filename
    """
    matching_ids = (
        doc_id
        for doc_id in self.documents
        if self.documents[doc_id].get("filename") == filename
    )
    first_id = next(matching_ids, None)
    if first_id is None:
        return None
    return {"doc_id": first_id, **self.documents[first_id]}
367
-
368
- # ============================================
369
- # EMBEDDING AND INDEXING METHODS
370
- # ============================================
371
-
372
def generate_embeddings(self, texts: List[str]) -> np.ndarray:
    """
    Encode texts with the embedding model.

    Args:
        texts: List of text strings

    Returns:
        float32 numpy array built from the model's output
    """
    raw_vectors = self.embed_model.encode(texts)
    as_array = np.array(raw_vectors)
    return as_array.astype("float32")
384
-
385
def add_to_index(self, chunks_metadata: List[Dict]) -> int:
    """
    Embed new chunks and append them to the FAISS index and metadata list.

    Args:
        chunks_metadata: List of chunk dicts with text, source, page

    Returns:
        Number of chunks added (0 for empty input)
    """
    if not chunks_metadata:
        return 0

    # Vectors and metadata are appended in the same order so that a
    # vector's position in the index equals its position in the metadata.
    vectors = self.generate_embeddings([entry["text"] for entry in chunks_metadata])
    self.index.add(vectors)
    self.metadata.extend(chunks_metadata)

    return len(chunks_metadata)
411
-
412
def remove_document_from_index(self, filename: str):
    """
    Remove all chunks of a document from the index and metadata.

    A fresh IndexFlatL2 is built from the chunks that remain. When the
    current index lines up with the metadata (one vector per entry, same
    dimension) the stored vectors are read back with ``reconstruct_n`` and
    reused, so nothing is re-embedded. Otherwise the remaining texts are
    re-embedded as a fallback.

    ``self.metadata`` and ``self.index`` are only replaced after the new
    index has been built, so a failure leaves the engine unchanged.

    Args:
        filename: Filename of document to remove
    """
    keep_positions = [
        pos for pos, entry in enumerate(self.metadata)
        if entry["source"] != filename
    ]
    if len(keep_positions) == len(self.metadata):
        return  # Nothing to remove

    rebuilt = faiss.IndexFlatL2(EMBEDDING_DIMENSION)
    if keep_positions:
        reusable = (
            self.index is not None
            and self.index.ntotal == len(self.metadata)
            and self.index.d == EMBEDDING_DIMENSION
        )
        if reusable:
            stored = self.index.reconstruct_n(0, self.index.ntotal)
            vectors = np.ascontiguousarray(stored[keep_positions], dtype="float32")
        else:
            vectors = self.generate_embeddings(
                [self.metadata[pos]["text"] for pos in keep_positions]
            )
        rebuilt.add(vectors)

    self.metadata = [self.metadata[pos] for pos in keep_positions]
    self.index = rebuilt

    print(f"Removed document '{filename}' from index")
440
-
441
- # ============================================
442
- # DOCUMENT UPLOAD METHODS
443
- # ============================================
444
-
445
def upload_document(self, filename: str, file_content: bytes,
                    action: str = "auto") -> Dict:
    """
    Upload and process a document.

    The file is hashed first. If a document with the same hash is already
    registered, ``action`` decides what happens; otherwise ``action`` is
    ignored and the file is processed.

    Args:
        filename: Original filename
        file_content: Raw bytes of PDF
        action: "auto" (report the duplicate), "use_existing", "replace",
            or "cancel"

    Returns:
        Result dict with ``status`` ("success", "duplicate", "cancelled"
        or "error") and details. Processing failures are reported as an
        "error" result rather than raised.
    """
    file_hash = self.compute_file_hash(file_content)
    existing_doc = self.check_duplicate(file_hash)

    if existing_doc:
        if action == "auto":
            # Let the caller choose how to resolve the duplicate.
            return {
                "status": "duplicate",
                "filename": filename,
                "existing_filename": existing_doc["filename"],
                "hash": file_hash,
                "message": f"Document already exists as '{existing_doc['filename']}'",
                "options": ["use_existing", "replace", "cancel"]
            }
        if action == "use_existing":
            return {
                "status": "success",
                "filename": existing_doc["filename"],
                "message": "Using existing document embeddings",
                "chunks": 0,
                "reused": True
            }
        if action == "cancel":
            return {
                "status": "cancelled",
                "filename": filename,
                "message": "Upload cancelled"
            }
        if action != "replace":
            # An unrecognized action must not fall through and index the
            # same content a second time.
            return {
                "status": "error",
                "filename": filename,
                "message": (
                    f"Unknown action '{action}'; expected one of: "
                    "auto, use_existing, replace, cancel"
                )
            }

    try:
        chunks_metadata = self.process_pdf(filename, file_content)

        if not chunks_metadata:
            return {
                "status": "error",
                "filename": filename,
                "message": "No text could be extracted from PDF"
            }

        if existing_doc:
            # Replace: drop the old copy only now that the new one has
            # been processed, so a failure above keeps the old document.
            self.remove_document_from_index(existing_doc["filename"])
            del self.documents[existing_doc["doc_id"]]

        num_chunks = self.add_to_index(chunks_metadata)

        # Register document under an id that is not already taken.
        now = datetime.now()
        stamp = int(now.timestamp())
        ordinal = len(self.documents) + 1
        doc_id = f"doc_{ordinal}_{stamp}"
        while doc_id in self.documents:
            ordinal += 1
            doc_id = f"doc_{ordinal}_{stamp}"

        self.documents[doc_id] = {
            "filename": filename,
            "hash": file_hash,
            "upload_timestamp": now.isoformat(),
            "num_chunks": num_chunks,
            # Highest page number that produced at least one chunk.
            "num_pages": max(c["page"] for c in chunks_metadata)
        }

        self._save_persistent_data()

        return {
            "status": "success",
            "filename": filename,
            "message": "Document processed successfully",
            "chunks": num_chunks,
            "pages": self.documents[doc_id]["num_pages"]
        }

    except Exception as e:
        return {
            "status": "error",
            "filename": filename,
            "message": f"Error processing document: {str(e)}"
        }
535
-
536
- # ============================================
537
- # QUERY AND RETRIEVAL METHODS
538
- # ============================================
539
-
540
def retrieve_relevant_chunks(self, query: str, top_k: int = DEFAULT_TOP_K) -> List[Dict]:
    """
    Retrieve most relevant chunks for a query.

    Args:
        query: User's question
        top_k: Number of chunks to retrieve; capped at the index size

    Returns:
        Chunk metadata dicts in search order, each extended with
        ``distance`` and ``relevance_rank`` (1-based). Empty when the
        index is empty or ``top_k`` is not positive.
    """
    if self.index is None or self.index.ntotal == 0:
        return []

    # Limit top_k to available chunks
    top_k = min(top_k, self.index.ntotal)
    if top_k <= 0:
        return []

    query_embedding = self.embed_model.encode([query]).astype("float32")
    distances, indices = self.index.search(query_embedding, k=top_k)

    results = []
    for rank, (idx, distance) in enumerate(zip(indices[0], distances[0]), start=1):
        idx = int(idx)
        # FAISS pads missing neighbours with -1; a negative value would
        # otherwise index the metadata list from the end.
        if not 0 <= idx < len(self.metadata):
            continue
        results.append({
            **self.metadata[idx],
            "distance": float(distance),
            "relevance_rank": rank
        })

    return results
574
-
575
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
    """
    Ask Gemini to answer a question from the retrieved chunks only.

    Args:
        query: User's question
        context_chunks: Retrieved relevant chunks

    Returns:
        The model's answer text, a fixed notice when no chunks were
        supplied, or an "Error generating answer: ..." string on failure
    """
    if not context_chunks:
        return "I don't have enough information to answer this question. Please upload relevant documents first."

    # Each chunk is labelled with its source so the model can attribute it.
    context = "\n\n".join(
        f"[Source: {chunk['source']}, Page {chunk['page']}]\n{chunk['text']}"
        for chunk in context_chunks
    )

    prompt = f"""You are a helpful assistant that answers questions based ONLY on the provided context.
Do NOT make up information that is not in the context.
If the context doesn't contain enough information to answer, say so clearly.
You may summarize, combine, or rephrase information from the context to make your answer clear and helpful.

CONTEXT:
{context}

QUESTION:
{query}

ANSWER:"""

    try:
        return self.gemini_model.generate_content(prompt).text
    except Exception as generation_error:
        return f"Error generating answer: {str(generation_error)}"
616
-
617
def verify_sources(self, query: str, answer: str, context_chunks: List[Dict]) -> List[int]:
    """
    Verify which chunks actually support the generated answer.

    The model is asked for a comma-separated list of chunk numbers. Its
    reply is parsed leniently: every integer found in the text is taken,
    so brackets, spaces or trailing punctuation do not discard valid
    numbers. Out-of-range values and repeats are dropped.

    Args:
        query: User's question
        answer: Generated answer
        context_chunks: All retrieved chunks

    Returns:
        Indices into ``context_chunks`` in the order the model listed
        them. Empty when there are no chunks or the model answers "NONE".
        If the model call fails, every index is returned as a fallback.
    """
    import re  # local import keeps this method self-contained

    if not context_chunks:
        return []

    # Build context with numbered chunks
    context = "\n\n".join(
        f"[{i}] Source: {chunk['source']}, Page {chunk['page']}\n{chunk['text']}"
        for i, chunk in enumerate(context_chunks)
    )

    prompt = f"""You are a citation verification assistant. Given a question, an answer, and numbered source chunks, identify which chunks were actually used to generate the answer.

Return ONLY a comma-separated list of chunk numbers that directly support the answer (e.g., "0,2,3").
If no chunks support the answer, return "NONE".
Do not include explanations or any other text.

QUESTION:
{query}

ANSWER:
{answer}

NUMBERED CHUNKS:
{context}

CHUNK NUMBERS THAT SUPPORT THE ANSWER:"""

    try:
        response = self.gemini_model.generate_content(prompt)
        result = response.text.strip()

        if result.upper().startswith("NONE"):
            return []

        used_indices = []
        # "-?" keeps a negative number from being read as its positive digits.
        for token in re.findall(r"-?\d+", result):
            idx = int(token)
            if 0 <= idx < len(context_chunks) and idx not in used_indices:
                used_indices.append(idx)

        return used_indices
    except Exception as e:
        print(f"Error verifying sources: {e}")
        # Fallback: return all chunks if verification fails
        return list(range(len(context_chunks)))
681
-
682
def ask(self, query: str, top_k: int = DEFAULT_TOP_K) -> Dict:
    """
    Answer a question: retrieve chunks, generate an answer, keep cited sources.

    Args:
        query: User's question
        top_k: Number of chunks to retrieve

    Returns:
        Dict with ``answer``, ``sources`` (unique file/page pairs that
        the verification step kept), ``num_chunks_used`` (length of
        ``sources``) and ``num_chunks_retrieved``
    """
    relevant_chunks = self.retrieve_relevant_chunks(query, top_k)
    answer = self.generate_answer(query, relevant_chunks)
    used_indices = self.verify_sources(query, answer, relevant_chunks)

    # Keyed by "<source>_<page>" so each page is listed once, in first-seen order.
    unique_sources = {}
    for idx in used_indices:
        if idx >= len(relevant_chunks):
            continue
        chunk = relevant_chunks[idx]
        source_key = f"{chunk['source']}_{chunk['page']}"
        if source_key in unique_sources:
            continue
        unique_sources[source_key] = {
            "file": chunk["source"],
            "page": chunk["page"]
        }
    sources = list(unique_sources.values())

    return {
        "answer": answer,
        "sources": sources,
        "num_chunks_used": len(sources),
        "num_chunks_retrieved": len(relevant_chunks)
    }
722
-
723
- # ============================================
724
- # DOCUMENT MANAGEMENT METHODS
725
- # ============================================
726
-
727
def get_all_documents(self) -> List[Dict]:
    """
    List every registered document.

    Returns:
        One dict per document: the registry entry plus its ``doc_id``
    """
    listing = []
    for doc_id, info in self.documents.items():
        entry = {"doc_id": doc_id}
        entry.update(info)
        listing.append(entry)
    return listing
738
-
739
def delete_document(self, doc_id: str) -> Dict:
    """
    Delete a document, its chunks, and its registry entry, then persist.

    Args:
        doc_id: Document ID to delete

    Returns:
        Result dict with ``status`` ("success" or "error") and ``message``
    """
    known = doc_id in self.documents
    if not known:
        return {
            "status": "error",
            "message": f"Document {doc_id} not found"
        }

    filename = self.documents[doc_id]["filename"]

    # Chunks are matched by filename, so the index goes first while the
    # registry entry is still available.
    self.remove_document_from_index(filename)
    self.documents.pop(doc_id)
    self._save_persistent_data()

    outcome = {
        "status": "success",
        "message": f"Document '{filename}' deleted successfully"
    }
    return outcome
770
-
771
def get_stats(self) -> Dict:
    """
    Report counts and embedding settings for the engine.

    Returns:
        Dict with document count, chunk count, number of vectors in the
        index (0 when there is no index), and the embedding model name
        and dimension
    """
    vector_count = self.index.ntotal if self.index else 0
    stats = {}
    stats["total_documents"] = len(self.documents)
    stats["total_chunks"] = len(self.metadata)
    stats["index_size"] = vector_count
    stats["embedding_model"] = EMBEDDING_MODEL_NAME
    stats["embedding_dimension"] = EMBEDDING_DIMENSION
    return stats
 
1
+ """
2
+ RAG Engine Module
3
+ =================
4
+ Handles all RAG pipeline operations:
5
+ - PDF text extraction
6
+ - Text chunking with overlap
7
+ - Embedding generation using SentenceTransformers
8
+ - FAISS vector storage and retrieval
9
+ - Metadata and document registry management
10
+ - Persistence of embeddings and metadata
11
+ """
12
+
13
+ import os
14
+ import json
15
+ import hashlib
16
+ from datetime import datetime
17
+ from typing import List, Dict, Tuple, Optional
18
+ import numpy as np
19
+ import faiss
20
+ from sentence_transformers import SentenceTransformer
21
+ import PyPDF2
22
+ import google.generativeai as genai
23
+ from PIL import Image
24
+ import io
25
+
26
+ # OCR imports (optional)
27
+ try:
28
+ import pytesseract
29
+
30
+ OCR_AVAILABLE = True
31
+ except ImportError:
32
+ OCR_AVAILABLE = False
33
+ print("Warning: pytesseract not installed. OCR functionality will be disabled.")
34
+
35
# ============================================
# CONFIGURATION
# ============================================

# Chunking parameters (sizes are counted in whitespace-separated words).
DEFAULT_CHUNK_SIZE = 200  # words per chunk
DEFAULT_OVERLAP_SIZE = 50  # words shared by two consecutive chunks

# Retrieval parameters
DEFAULT_TOP_K = 5  # number of chunks to retrieve per query

# Embedding model
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Dimension passed to faiss.IndexFlatL2 whenever an index is created.
# NOTE(review): presumably the output width of EMBEDDING_MODEL_NAME - confirm.
EMBEDDING_DIMENSION = 384
49
+
50
+
51
+
52
+ class RAGEngine:
53
+ """
54
+ Main RAG Engine class that handles:
55
+ - Document processing and embedding
56
+ - FAISS index management
57
+ - Query processing and answer generation
58
+ - Persistence of all data
59
+ """
60
+
61
def __init__(self, gemini_api_key: str, storage_dir: Optional[str] = None):
    """
    Set up storage paths, models, and in-memory state, then load saved data.

    Args:
        gemini_api_key: API key handed to ``genai.configure``.
        storage_dir: Directory holding the index and JSON files. When
            omitted, a ``storage`` folder next to this module is used.
    """
    # Resolve where everything lives on disk.
    base_dir = (
        storage_dir
        if storage_dir is not None
        else os.path.join(os.path.dirname(__file__), "storage")
    )
    self.storage_dir = base_dir
    self.faiss_index_path = os.path.join(base_dir, "faiss.index")
    self.metadata_path = os.path.join(base_dir, "metadata.json")
    self.documents_path = os.path.join(base_dir, "documents.json")
    os.makedirs(base_dir, exist_ok=True)

    # Sentence embedding model.
    print("Loading embedding model...")
    self.embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

    # Gemini client used for answer generation and source verification.
    genai.configure(api_key=gemini_api_key)
    self.gemini_model = genai.GenerativeModel("gemini-2.5-flash")

    # In-memory state; overwritten by whatever is found on disk.
    self.index: Optional[faiss.IndexFlatL2] = None
    self.metadata: List[Dict] = []  # one entry per chunk: text, source, page
    self.documents: Dict[str, Dict] = {}  # document registry keyed by doc_id
    self._load_persistent_data()

    doc_count, chunk_count = len(self.documents), len(self.metadata)
    print(f"RAG Engine initialized. Documents: {doc_count}, Chunks: {chunk_count}")
98
+
99
+ # ============================================
100
+ # PERSISTENCE METHODS
101
+ # ============================================
102
+
103
def _load_persistent_data(self):
    """
    Load the document registry, chunk metadata, and FAISS index from disk.

    Files that do not exist are skipped, leaving the matching attribute
    untouched. The on-disk index is only trusted when it holds exactly one
    vector per metadata entry and has the configured dimension: search hits
    are mapped to metadata by position, so an out-of-sync index would
    attribute text to the wrong chunk. When the index is missing or out of
    sync it is rebuilt by re-embedding the stored chunk texts.
    """
    # Document registry
    if os.path.exists(self.documents_path):
        with open(self.documents_path, "r", encoding="utf-8") as f:
            self.documents = json.load(f)
        print(f"Loaded {len(self.documents)} documents from registry")

    # Chunk metadata
    if os.path.exists(self.metadata_path):
        with open(self.metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        print(f"Loaded {len(self.metadata)} chunks metadata")

    # Nothing indexed yet: start from an empty index.
    if not self.metadata:
        self.index = faiss.IndexFlatL2(EMBEDDING_DIMENSION)
        print("Created new FAISS index")
        return

    loaded = None
    if os.path.exists(self.faiss_index_path):
        loaded = faiss.read_index(self.faiss_index_path)

    in_sync = (
        loaded is not None
        and loaded.ntotal == len(self.metadata)
        and loaded.d == EMBEDDING_DIMENSION
    )
    if in_sync:
        self.index = loaded
        print(f"Loaded FAISS index with {self.index.ntotal} vectors")
        return

    # Metadata exists but the vectors are missing or do not line up with it.
    print("FAISS index missing or out of sync with metadata; rebuilding")
    self.index = faiss.IndexFlatL2(EMBEDDING_DIMENSION)
    self.index.add(self.generate_embeddings([m["text"] for m in self.metadata]))
    print(f"Rebuilt FAISS index with {self.index.ntotal} vectors")
126
+
127
+ def _save_persistent_data(self):
128
+ """Save FAISS index, metadata, and document registry to disk."""
129
+
130
+ # Save document registry
131
+ with open(self.documents_path, "w", encoding="utf-8") as f:
132
+ json.dump(self.documents, f, indent=2, ensure_ascii=False)
133
+
134
+ # Save metadata
135
+ with open(self.metadata_path, "w", encoding="utf-8") as f:
136
+ json.dump(self.metadata, f, indent=2, ensure_ascii=False)
137
+
138
+ # Save FAISS index
139
+ if self.index is not None and self.index.ntotal > 0:
140
+ faiss.write_index(self.index, self.faiss_index_path)
141
+
142
+ print("Persistent data saved successfully")
143
+
144
+ # ============================================
145
+ # DOCUMENT PROCESSING METHODS
146
+ # ============================================
147
+
148
@staticmethod
def compute_file_hash(file_content: bytes) -> str:
    """
    Return the SHA-256 digest of a file's bytes.

    Args:
        file_content: Raw bytes of the file

    Returns:
        Digest as a lowercase hexadecimal string
    """
    hasher = hashlib.sha256()
    hasher.update(file_content)
    return hasher.hexdigest()
160
+
161
@staticmethod
def chunk_text_with_overlap(text: str, chunk_size: int = DEFAULT_CHUNK_SIZE,
                            overlap_size: int = DEFAULT_OVERLAP_SIZE) -> List[str]:
    """
    Split text into overlapping chunks of whitespace-separated words.

    Args:
        text: Input text to chunk
        chunk_size: Number of words per chunk (must be positive)
        overlap_size: Number of words shared by consecutive chunks
            (must satisfy ``0 <= overlap_size < chunk_size``)

    Returns:
        List of text chunks; empty when the text contains no words

    Raises:
        ValueError: If the sizes would make the window stand still or move
            backwards, which previously caused an endless loop.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive number of words")
    if not 0 <= overlap_size < chunk_size:
        raise ValueError("overlap_size must satisfy 0 <= overlap_size < chunk_size")

    words = text.split()
    step = chunk_size - overlap_size
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        # Once a chunk reaches the last word, any further window would only
        # repeat words already contained in this chunk.
        if end >= len(words):
            break
    return chunks
187
+
188
@staticmethod
def extract_text_from_image(image: Image.Image) -> str:
    """
    Run OCR on an image and return the recognized text.

    Args:
        image: PIL Image object

    Returns:
        Stripped OCR text, or an empty string when OCR is unavailable
        or fails
    """
    if not OCR_AVAILABLE:
        return ""

    try:
        # The image is handed to Tesseract in RGB mode.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        return pytesseract.image_to_string(rgb_image, lang='eng').strip()
    except Exception as ocr_error:
        print(f"OCR error: {ocr_error}")
    return ""
213
+
214
def _ocr_page_images(self, page) -> str:
    """Return the OCR text of every image XObject on *page*, one image per line."""
    if '/Resources' not in page:
        return ""
    resources = page['/Resources']
    if '/XObject' not in resources:
        return ""
    x_objects = resources['/XObject'].get_object()

    recognized = []
    for name in x_objects:
        try:
            candidate = x_objects[name]
            if '/Subtype' not in candidate or candidate['/Subtype'] != '/Image':
                continue

            size = (candidate['/Width'], candidate['/Height'])
            data = candidate.get_data()

            # Only DeviceGray maps to "L"; every other (or missing) colour
            # space is tried as RGB.
            color_space = candidate['/ColorSpace'] if '/ColorSpace' in candidate else None
            mode = "L" if color_space == '/DeviceGray' else "RGB"

            try:
                image = Image.frombytes(mode, size, data)
            except Exception:
                # Not raw pixel data; let PIL sniff an encoded image instead.
                image = Image.open(io.BytesIO(data))

            img_text = self.extract_text_from_image(image)
            if img_text:
                recognized.append(img_text)
        except Exception as image_error:
            # Best effort: one unreadable image must not block the others.
            print(f"Skipping image {name}: {image_error}")
    return "\n".join(recognized)

def extract_text_from_pdf(self, pdf_content: bytes) -> List[Dict]:
    """
    Extract text from a PDF page by page, including OCR text from images.

    Args:
        pdf_content: Raw bytes of PDF file

    Returns:
        List of dicts with ``page_num`` (1-based), ``text`` and ``has_ocr``.
        Pages that yield no text at all are omitted.

    Raises:
        Exception: Any error raised while opening or reading the PDF is
            logged and re-raised.
    """
    pages = []

    try:
        reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
        for page_num, page in enumerate(reader.pages, start=1):
            text = (page.extract_text() or "").strip()

            ocr_text = ""
            if OCR_AVAILABLE:
                try:
                    ocr_text = self._ocr_page_images(page).strip()
                except Exception as e:
                    print(f"Error extracting images from page {page_num}: {e}")

            # Combine regular text and OCR text
            if text and ocr_text:
                combined_text = f"{text}\n\n[Text from images:]\n{ocr_text}"
            else:
                combined_text = text or ocr_text

            if combined_text:
                pages.append({
                    "page_num": page_num,
                    "text": combined_text,
                    "has_ocr": bool(ocr_text),
                })
    except Exception as e:
        print(f"Error extracting PDF text: {e}")
        raise

    return pages
297
+
298
def process_pdf(self, filename: str, file_content: bytes,
                chunk_size: int = DEFAULT_CHUNK_SIZE,
                overlap_size: int = DEFAULT_OVERLAP_SIZE) -> List[Dict]:
    """
    Turn a PDF into a flat list of chunk records ready for indexing.

    Text is extracted per page (including OCR text when available), each
    page is split via ``chunk_text_with_overlap``, and every resulting
    piece is tagged with its origin.

    Args:
        filename: Original filename, stored as each chunk's "source"
        file_content: Raw bytes of PDF
        chunk_size: Words per chunk
        overlap_size: Overlap between chunks

    Returns:
        List of dicts with keys "text", "source", "page" and "has_ocr"
    """
    extracted_pages = self.extract_text_from_pdf(file_content)

    # One record per chunk, in page order then chunk order.
    return [
        {
            "text": piece,
            "source": filename,
            "page": entry["page_num"],
            "has_ocr": entry.get("has_ocr", False)
        }
        for entry in extracted_pages
        for piece in self.chunk_text_with_overlap(
            entry["text"], chunk_size, overlap_size
        )
    ]
333
+
334
+ # ============================================
335
+ # DUPLICATE DETECTION METHODS
336
+ # ============================================
337
+
338
def check_duplicate(self, file_hash: str) -> Optional[Dict]:
    """
    Look up a registered document by content hash.

    Args:
        file_hash: SHA-256 hash of the file

    Returns:
        The first registry entry whose "hash" equals ``file_hash``, as a
        dict with an added "doc_id" key, or None when nothing matches
    """
    matches = (
        {"doc_id": key, **record}
        for key, record in self.documents.items()
        if record.get("hash") == file_hash
    )
    return next(matches, None)
352
+
353
def get_document_by_filename(self, filename: str) -> Optional[Dict]:
    """
    Look up a registered document by its original filename.

    Args:
        filename: Original filename

    Returns:
        The first registry entry whose "filename" equals ``filename``, as
        a dict with an added "doc_id" key, or None when nothing matches
    """
    found = None
    for key in self.documents:
        record = self.documents[key]
        if record.get("filename") != filename:
            continue
        found = {"doc_id": key, **record}
        break
    return found
367
+
368
+ # ============================================
369
+ # EMBEDDING AND INDEXING METHODS
370
+ # ============================================
371
+
372
def generate_embeddings(self, texts: List[str]) -> np.ndarray:
    """
    Encode texts with the configured embedding model.

    Args:
        texts: List of text strings

    Returns:
        float32 numpy array built from whatever ``self.embed_model.encode``
        returns for ``texts``
    """
    raw_vectors = self.embed_model.encode(texts)
    matrix = np.array(raw_vectors)
    # The result is handed to FAISS by callers, hence the float32 cast.
    return matrix.astype("float32")
384
+
385
def add_to_index(self, chunks_metadata: List[Dict]) -> int:
    """
    Embed new chunks and append them to the FAISS index and metadata list.

    Vectors are added to the index before the metadata list is extended,
    so both grow by the same number of entries in the same order.

    Args:
        chunks_metadata: List of chunk dicts; each must have a "text" key

    Returns:
        Number of chunks added (0 for an empty input, which is a no-op)
    """
    if not chunks_metadata:
        return 0

    vectors = self.generate_embeddings(
        [entry["text"] for entry in chunks_metadata]
    )
    self.index.add(vectors)
    self.metadata.extend(chunks_metadata)

    return len(chunks_metadata)
411
+
412
def remove_document_from_index(self, filename: str):
    """
    Drop every chunk whose "source" is ``filename`` and rebuild the index.

    FAISS IndexFlatL2 doesn't support removal, so the surviving chunks are
    re-embedded and loaded into a brand-new index. Nothing happens (and
    nothing is printed) when no chunk belongs to ``filename``.

    Args:
        filename: Filename of document to remove
    """
    kept = [entry for entry in self.metadata if entry["source"] != filename]

    # Same length means no chunk matched: leave index and metadata alone.
    if len(kept) == len(self.metadata):
        return

    self.metadata = kept

    # Embed first, then swap in the fresh index, then fill it.
    vectors = None
    if kept:
        vectors = self.generate_embeddings([entry["text"] for entry in kept])

    self.index = faiss.IndexFlatL2(EMBEDDING_DIMENSION)
    if vectors is not None:
        self.index.add(vectors)

    print(f"Removed document '{filename}' from index")
440
+
441
+ # ============================================
442
+ # DOCUMENT UPLOAD METHODS
443
+ # ============================================
444
+
445
def upload_document(self, filename: str, file_content: bytes,
                    action: str = "auto") -> Dict:
    """
    Upload and process a document, handling content-hash duplicates.

    When a document with the same hash is already registered, ``action``
    decides what happens:
      - "use_existing": report success without re-indexing
      - "cancel": report cancellation
      - "replace": index the upload and drop the previous registration
      - "auto" or any unrecognised value: return a "duplicate" warning
        listing the valid options; nothing is modified

    For "replace", the existing document is removed only after the new
    upload has produced chunks, so a failed extraction leaves the previous
    document intact.

    Args:
        filename: Original filename
        file_content: Raw bytes of PDF
        action: "auto", "use_existing", "replace", or "cancel"

    Returns:
        Result dict with "status" ("success", "duplicate", "cancelled" or
        "error"), "filename", "message" and status-specific extras
    """
    # Compute hash
    file_hash = self.compute_file_hash(file_content)

    # Check for duplicate
    existing_doc = self.check_duplicate(file_hash)
    replace_target = None

    if existing_doc:
        if action == "use_existing":
            return {
                "status": "success",
                "filename": existing_doc["filename"],
                "message": "Using existing document embeddings",
                "chunks": 0,
                "reused": True
            }
        if action == "cancel":
            return {
                "status": "cancelled",
                "filename": filename,
                "message": "Upload cancelled"
            }
        if action != "replace":
            # "auto" and unknown actions must never index a second copy.
            return {
                "status": "duplicate",
                "filename": filename,
                "existing_filename": existing_doc["filename"],
                "hash": file_hash,
                "message": f"Document already exists as '{existing_doc['filename']}'",
                "options": ["use_existing", "replace", "cancel"]
            }
        replace_target = existing_doc

    # Process new document
    try:
        chunks_metadata = self.process_pdf(filename, file_content)

        if not chunks_metadata:
            return {
                "status": "error",
                "filename": filename,
                "message": "No text could be extracted from PDF"
            }

        # Remove the old copy before adding the new chunks: chunks are
        # matched by filename, so removing afterwards would also delete the
        # new chunks when both uploads share a name.
        if replace_target is not None:
            self.remove_document_from_index(replace_target["filename"])
            del self.documents[replace_target["doc_id"]]

        # Add to index
        num_chunks = self.add_to_index(chunks_metadata)

        # Register document under an id that cannot clobber an existing
        # entry (count + timestamp alone can repeat after a deletion).
        now = datetime.now()
        base_id = f"doc_{len(self.documents) + 1}_{int(now.timestamp())}"
        doc_id = base_id
        suffix = 1
        while doc_id in self.documents:
            suffix += 1
            doc_id = f"{base_id}_{suffix}"

        self.documents[doc_id] = {
            "filename": filename,
            "hash": file_hash,
            "upload_timestamp": now.isoformat(),
            "num_chunks": num_chunks,
            # Highest page number that produced text, not the page count.
            "num_pages": max(c["page"] for c in chunks_metadata)
        }

        # Persist changes
        self._save_persistent_data()

        return {
            "status": "success",
            "filename": filename,
            "message": "Document processed successfully",
            "chunks": num_chunks,
            "pages": self.documents[doc_id]["num_pages"]
        }

    except Exception as e:
        return {
            "status": "error",
            "filename": filename,
            "message": f"Error processing document: {str(e)}"
        }
535
+
536
+ # ============================================
537
+ # QUERY AND RETRIEVAL METHODS
538
+ # ============================================
539
+
540
def retrieve_relevant_chunks(self, query: str, top_k: int = DEFAULT_TOP_K, doc_id: Optional[str] = None) -> List[Dict]:
    """
    Retrieve most relevant chunks for a query.

    Args:
        query: User's question
        top_k: Maximum number of chunks to return; values <= 0 yield []
        doc_id: Optional registry id. When it names a registered document,
            only chunks whose "source" equals that document's filename are
            returned. A falsy or unregistered id leaves the search unscoped.

    Returns:
        Up to ``top_k`` chunk metadata dicts in ascending distance order,
        each extended with "distance" (float) and "relevance_rank"
        (1-based position in the returned list)
    """
    if self.index is None or self.index.ntotal == 0:
        return []
    if top_k <= 0:
        return []

    # Limit top_k to available chunks
    top_k = min(top_k, self.index.ntotal)

    # If doc_id provided, determine filename to filter by
    filename_filter = None
    if doc_id and doc_id in self.documents:
        filename_filter = self.documents[doc_id].get('filename')

    # Embed query
    query_embedding = self.embed_model.encode([query]).astype("float32")

    # A scoped query must rank every vector: the target document's chunks
    # may all sit outside any fixed-size global top-N window.
    if filename_filter:
        k_search = self.index.ntotal
    else:
        k_search = max(top_k, min(50, self.index.ntotal))
    distances, indices = self.index.search(query_embedding, k=k_search)

    # Gather results and apply optional filename filter
    results = []
    for distance, idx in zip(distances[0], indices[0]):
        idx = int(idx)
        # FAISS pads missing hits with -1; a negative index would silently
        # wrap around to the end of the metadata list.
        if not 0 <= idx < len(self.metadata):
            continue
        meta = self.metadata[idx]
        if filename_filter and meta.get('source') != filename_filter:
            continue
        results.append({
            **meta,
            "distance": float(distance),
            "relevance_rank": len(results) + 1
        })
        if len(results) >= top_k:
            break

    return results
585
+
586
def generate_answer(self, query: str, context_chunks: List[Dict]) -> str:
    """
    Ask Gemini to answer ``query`` using only the retrieved chunks.

    Args:
        query: User's question
        context_chunks: Retrieved chunks; each needs "source", "page" and
            "text" keys

    Returns:
        The model's answer text. A fixed fallback sentence is returned
        when there are no chunks, and an "Error generating answer: ..."
        string when the model call fails.
    """
    if not context_chunks:
        return "I don't have enough information to answer this question. Please upload relevant documents first."

    # Each chunk is introduced by a citation header the model can refer to.
    context = "\n\n".join(
        f"[Source: {chunk['source']}, Page {chunk['page']}]\n{chunk['text']}"
        for chunk in context_chunks
    )

    prompt = f"""You are a helpful assistant that answers questions based ONLY on the provided context.
Do NOT make up information that is not in the context.
If the context doesn't contain enough information to answer, say so clearly.
You may summarize, combine, or rephrase information from the context to make your answer clear and helpful.

CONTEXT:
{context}

QUESTION:
{query}

ANSWER:"""

    try:
        return self.gemini_model.generate_content(prompt).text
    except Exception as exc:
        return f"Error generating answer: {str(exc)}"
627
+
628
def verify_sources(self, query: str, answer: str, context_chunks: List[Dict]) -> List[int]:
    """
    Verify which chunks actually support the generated answer.

    The model is asked for a comma-separated list of chunk numbers. The
    reply is parsed leniently: newlines count as separators, and brackets,
    quotes and trailing punctuation around a number are ignored, so
    replies such as "[0, 2]." are still understood.

    Args:
        query: User's question
        answer: Generated answer
        context_chunks: All retrieved chunks; each needs "source", "page"
            and "text" keys

    Returns:
        In-range chunk indices in the order the model listed them, without
        duplicates. Empty when there are no chunks or the model answers
        "NONE". If the model call fails, every index is returned.
    """
    if not context_chunks:
        return []

    # Build context with numbered chunks
    context_parts = []
    for i, chunk in enumerate(context_chunks):
        context_parts.append(
            f"[{i}] Source: {chunk['source']}, Page {chunk['page']}\n{chunk['text']}"
        )
    context = "\n\n".join(context_parts)

    # Create verification prompt
    prompt = f"""You are a citation verification assistant. Given a question, an answer, and numbered source chunks, identify which chunks were actually used to generate the answer.

Return ONLY a comma-separated list of chunk numbers that directly support the answer (e.g., "0,2,3").
If no chunks support the answer, return "NONE".
Do not include explanations or any other text.

QUESTION:
{query}

ANSWER:
{answer}

NUMBERED CHUNKS:
{context}

CHUNK NUMBERS THAT SUPPORT THE ANSWER:"""

    try:
        response = self.gemini_model.generate_content(prompt)
        result = response.text.strip()

        # Parse the response
        if result.upper() == "NONE":
            return []

        # Extract numbers, tolerating list decoration around each one
        used_indices = []
        for part in result.replace("\n", ",").split(","):
            token = part.strip().strip("[]()\"'. ")
            try:
                idx = int(token)
            except ValueError:
                continue
            if 0 <= idx < len(context_chunks) and idx not in used_indices:
                used_indices.append(idx)

        return used_indices
    except Exception as e:
        print(f"Error verifying sources: {e}")
        # Fallback: return all chunks if verification fails
        return list(range(len(context_chunks)))
692
+
693
def ask(self, query: str, top_k: int = DEFAULT_TOP_K, doc_id: str = None) -> Dict:
    """
    Answer a question end to end: retrieve, generate, then cite.

    Args:
        query: User's question
        top_k: Number of chunks to retrieve
        doc_id: Passed through to ``retrieve_relevant_chunks`` to
            optionally scope retrieval to one document

    Returns:
        Dict with "answer", "sources" (unique file/page pairs among the
        chunks reported as supporting the answer), "num_chunks_used"
        (length of "sources") and "num_chunks_retrieved"
    """
    retrieved = self.retrieve_relevant_chunks(query, top_k, doc_id=doc_id)
    answer = self.generate_answer(query, retrieved)
    supporting = self.verify_sources(query, answer, retrieved)

    # Insertion-ordered mapping doubles as the "already cited" set.
    cited = {}
    for position in supporting:
        if position >= len(retrieved):
            continue
        chunk = retrieved[position]
        source_key = f"{chunk['source']}_{chunk['page']}"
        if source_key in cited:
            continue
        cited[source_key] = {
            "file": chunk["source"],
            "page": chunk["page"]
        }
    sources = list(cited.values())

    return {
        "answer": answer,
        "sources": sources,
        "num_chunks_used": len(sources),
        "num_chunks_retrieved": len(retrieved)
    }
733
+
734
+ # ============================================
735
+ # DOCUMENT MANAGEMENT METHODS
736
+ # ============================================
737
+
738
def get_all_documents(self) -> List[Dict]:
    """
    List every registered document.

    Returns:
        One dict per registry entry, in registry order, consisting of the
        stored info plus a "doc_id" key
    """
    listing = []
    for key, record in self.documents.items():
        listing.append({"doc_id": key, **record})
    return listing
749
+
750
def delete_document(self, doc_id: str) -> Dict:
    """
    Remove a document from the index and the registry, then persist.

    Args:
        doc_id: Document ID to delete

    Returns:
        Dict with "status" ("success" or "error") and a "message"; the
        error case means ``doc_id`` is not registered and nothing changed
    """
    if doc_id in self.documents:
        filename = self.documents[doc_id]["filename"]

        # Chunks first, then the registry entry, then save both.
        self.remove_document_from_index(filename)
        del self.documents[doc_id]
        self._save_persistent_data()

        return {
            "status": "success",
            "message": f"Document '{filename}' deleted successfully"
        }

    return {
        "status": "error",
        "message": f"Document {doc_id} not found"
    }
781
+
782
def get_stats(self) -> Dict:
    """
    Summarise the current state of the engine.

    Returns:
        Dict with the number of registered documents, the number of
        stored chunk records, the FAISS vector count (0 when the index is
        falsy) and the embedding model name and dimension constants
    """
    index_size = 0
    if self.index:
        index_size = self.index.ntotal

    stats = {
        "total_documents": len(self.documents),
        "total_chunks": len(self.metadata),
        "index_size": index_size,
        "embedding_model": EMBEDDING_MODEL_NAME,
        "embedding_dimension": EMBEDDING_DIMENSION
    }
    return stats