Pranesh64 committed on
Commit
2e0cf55
Β·
verified Β·
1 Parent(s): 0721a21

Create rag.py

Browse files
Files changed (1) hide show
  1. backend/rag.py +203 -0
backend/rag.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG engine module for embeddings and FAISS-based retrieval.
3
+ """
4
+
5
+ import os
6
+ import warnings
7
+ from typing import List, Dict, Tuple
8
+ import numpy as np
9
+ import faiss
10
+ import pickle
11
+
12
+ # Suppress PyTorch internal warnings
13
+ warnings.filterwarnings('ignore', category=UserWarning, module='torch')
14
+
15
+ from sentence_transformers import SentenceTransformer
16
+
17
+
18
class RAGEngine:
    """In-memory RAG engine using FAISS for similarity search with persistence."""

    def __init__(self, index_path: str = "faiss_index"):
        """Initialize the RAG engine and try to load a persisted index.

        Args:
            index_path: Directory where the FAISS index and chunk metadata
                are saved/loaded.
        """
        self.model = None      # SentenceTransformer, lazily loaded on first use
        self.index = None      # faiss.IndexFlatL2, created or loaded on demand
        self.chunks = []       # Chunk dicts, row-parallel to the FAISS index
        self.dimension = 384   # all-MiniLM-L6-v2 embedding dimension
        self.index_path = index_path

        # Try to load existing index
        self._load_index()

    def _load_model(self):
        """Lazy load the embedding model (idempotent)."""
        if self.model is None:
            print("πŸ”„ Loading embedding model...")
            self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            print("βœ… Embedding model loaded")

    def _initialize_index(self):
        """Initialize FAISS index if not already created."""
        if self.index is None:
            self._load_model()
            self.index = faiss.IndexFlatL2(self.dimension)

    def _save_index(self):
        """Save FAISS index and chunks to disk (best-effort; errors are logged)."""
        if self.index is not None and self.index.ntotal > 0:
            try:
                # Create directory if it doesn't exist
                os.makedirs(self.index_path, exist_ok=True)

                # Save FAISS index
                index_file = os.path.join(self.index_path, "index.faiss")
                faiss.write_index(self.index, index_file)

                # Save chunks metadata
                chunks_file = os.path.join(self.index_path, "chunks.pkl")
                with open(chunks_file, 'wb') as f:
                    pickle.dump(self.chunks, f)

                print(f"πŸ’Ύ Index saved to {self.index_path}")
            except Exception as e:
                print(f"❌ Failed to save index: {str(e)}")

    def _load_index(self):
        """Load FAISS index and chunks from disk.

        Returns:
            True if an existing index was loaded, False otherwise.
        """
        if os.path.exists(self.index_path):
            try:
                index_file = os.path.join(self.index_path, "index.faiss")
                chunks_file = os.path.join(self.index_path, "chunks.pkl")

                if os.path.exists(index_file) and os.path.exists(chunks_file):
                    # Load FAISS index
                    self.index = faiss.read_index(index_file)

                    # Load chunks metadata.
                    # SECURITY NOTE: pickle.load executes arbitrary code if the
                    # file is tampered with — only load indexes you wrote yourself.
                    with open(chunks_file, 'rb') as f:
                        self.chunks = pickle.load(f)

                    print(f"βœ… Loaded existing index with {len(self.chunks)} chunks")
                    return True
            except Exception as e:
                print(f"⚠️ Failed to load existing index: {str(e)}")
                print("πŸ”„ Will create new index...")

        return False

    def add_documents(self, chunks: List[Dict], save_index: bool = True):
        """
        Add document chunks to the index.

        Args:
            chunks: List of dictionaries with 'text', 'source', 'chunk_id' keys
            save_index: Whether to save index to disk after adding
        """
        if not chunks:
            return

        self._initialize_index()
        # BUGFIX: _initialize_index only loads the model when the index is
        # freshly created; when an index was loaded from disk, self.model is
        # still None here and encode() would raise AttributeError. Load it
        # explicitly (idempotent, so no cost when already loaded).
        self._load_model()

        # Extract texts for embedding
        texts = [chunk['text'] for chunk in chunks]

        print(f"🧠 Generating embeddings for {len(texts)} chunks...")

        # Generate embeddings; FAISS requires float32
        embeddings = self.model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
        embeddings = embeddings.astype('float32')

        # Add to FAISS index
        self.index.add(embeddings)

        # Store chunk metadata (row i of the index maps to self.chunks[i])
        self.chunks.extend(chunks)

        print(f"βœ… Added {len(chunks)} chunks to index")

        # Save index to disk
        if save_index:
            self._save_index()

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Search for similar chunks.

        Args:
            query: Search query text
            top_k: Number of results to return

        Returns:
            List of dictionaries with 'text', 'source', 'chunk_id', 'score' keys
        """
        if self.index is None or self.index.ntotal == 0:
            return []

        self._load_model()

        # Generate query embedding
        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')

        # Search in FAISS
        k = min(top_k, self.index.ntotal)
        distances, indices = self.index.search(query_embedding, k)

        # Format results
        results = []
        for i, idx in enumerate(indices[0]):
            # BUGFIX: FAISS pads missing results with -1; the original
            # `idx < len(...)` check let -1 through and silently returned
            # the *last* chunk. Reject negative indices explicitly.
            if 0 <= idx < len(self.chunks):
                chunk_data = self.chunks[idx].copy()
                # Convert L2 distance to similarity score (lower distance = higher similarity)
                distance = float(distances[0][i])
                # Simple similarity: 1 / (1 + distance)
                similarity = 1.0 / (1.0 + distance)
                chunk_data['score'] = similarity
                chunk_data['distance'] = distance
                results.append(chunk_data)

        return results

    def get_chunk_count(self) -> int:
        """Get total number of indexed chunks."""
        if self.index is None:
            return 0
        return self.index.ntotal

    def reset(self):
        """Reset the index and clear all chunks, removing persisted files."""
        self.index = None
        self.chunks = []

        # Remove saved index files
        if os.path.exists(self.index_path):
            try:
                import shutil
                shutil.rmtree(self.index_path)
                print("πŸ—‘οΈ Removed saved index files")
            except Exception as e:
                print(f"⚠️ Failed to remove index files: {str(e)}")

    def rebuild_from_data(self, data_dir: str = "data"):
        """
        Rebuild the entire index from documents in data directory.

        Args:
            data_dir: Directory containing documents to index

        Returns:
            Number of chunks indexed (0 on failure or when no documents exist).
        """
        from .processing import process_documents_from_directory

        # Reset current index
        self.reset()

        # Process documents and build index
        try:
            chunks = process_documents_from_directory(data_dir)
            if chunks:
                self.add_documents(chunks)
                return len(chunks)
            else:
                print("⚠️ No documents found to process")
                return 0
        except Exception as e:
            print(f"❌ Failed to rebuild index: {str(e)}")
            return 0