dungeon29 committed on
Commit
b1943df
·
verified ·
1 Parent(s): 18cf1c3

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +235 -117
rag_engine.py CHANGED
@@ -1,117 +1,235 @@
1
- import os
2
- import glob
3
- from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader, JSONLoader
4
- from langchain_community.vectorstores import Chroma
5
- from langchain_huggingface import HuggingFaceEmbeddings
6
- from langchain_text_splitters import RecursiveCharacterTextSplitter
7
- from langchain_core.documents import Document
8
-
9
class RAGEngine:
    """Retrieval-Augmented Generation engine over a local Chroma store.

    Loads phishing-analysis knowledge from a directory (or a single
    fallback text file), splits it into overlapping chunks, embeds the
    chunks with the all-MiniLM-L6-v2 sentence-transformer, and answers
    similarity-search queries against the persisted collection.
    """

    def __init__(self, knowledge_base_dir="./knowledge_base", persist_directory="./chroma_db"):
        # Where the source documents live and where Chroma persists its DB.
        self.knowledge_base_dir = knowledge_base_dir
        self.persist_directory = persist_directory

        # Initialize Embeddings (using same model as before)
        self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        # Initialize Vector Store
        self.vector_store = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embedding_fn,
            collection_name="phishing_knowledge"
        )

        # Build index if empty or on init.
        # NOTE(review): get() fetches every stored id just to test emptiness;
        # acceptable for small collections but costly for large ones.
        if not self.vector_store.get()['ids']:
            self._build_index()

    def _build_index(self):
        """Load documents, chunk them, and add the chunks to the vector store."""
        print("πŸ”„ Building Knowledge Base Index...")

        documents = self._load_documents()
        if not documents:
            print("⚠️ No documents found to index.")
            return

        # Split documents into 500-char chunks with 50-char overlap so that
        # neighbouring chunks share context at their boundaries.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(documents)

        if chunks:
            # Add to vector store and flush the collection to disk.
            self.vector_store.add_documents(chunks)
            self.vector_store.persist()
            print(f"βœ… Indexed {len(chunks)} chunks from {len(documents)} documents.")
        else:
            print("⚠️ No chunks created.")

    def _load_documents(self):
        """Load source documents.

        Prefers the knowledge-base directory; falls back to a single
        'knowledge_base.txt' in the working directory. Returns a (possibly
        empty) list of LangChain Document objects; loader errors are
        reported to stdout rather than raised.
        """
        documents = []

        # Check for directory or fallback file
        target_path = self.knowledge_base_dir
        if not os.path.exists(target_path):
            if os.path.exists("knowledge_base.txt"):
                target_path = "knowledge_base.txt"
                print("⚠️ Using fallback 'knowledge_base.txt' in root.")
            else:
                print(f"❌ Knowledge base not found at {target_path}")
                return []

        try:
            if os.path.isfile(target_path):
                # Load single file; pick the loader from the extension.
                if target_path.endswith(".pdf"):
                    loader = PyPDFLoader(target_path)
                else:
                    loader = TextLoader(target_path, encoding="utf-8")
                documents.extend(loader.load())
            else:
                # Load directory: one recursive loader per supported type.
                loaders = [
                    DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
                ]

                # Best-effort: one failing loader must not abort the others.
                for loader in loaders:
                    try:
                        docs = loader.load()
                        documents.extend(docs)
                    except Exception as e:
                        print(f"⚠️ Error loading with {loader}: {e}")

        except Exception as e:
            print(f"❌ Error loading documents: {e}")

        return documents

    def refresh_knowledge_base(self):
        """Drop the Chroma collection and rebuild the index from disk."""
        print("♻️ Refreshing Knowledge Base...")
        # Clear existing collection
        self.vector_store.delete_collection()
        self.vector_store = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embedding_fn,
            collection_name="phishing_knowledge"
        )
        # Rebuild
        self._build_index()
        return "βœ… Knowledge Base Refreshed!"

    def retrieve(self, query, n_results=3):
        """Return the page_content of the top n_results most similar chunks."""
        # Search
        results = self.vector_store.similarity_search(query, k=n_results)

        # Format results
        if results:
            return [doc.page_content for doc in results]
        return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
4
+ from langchain_qdrant import Qdrant
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
7
+ from langchain_core.documents import Document
8
+ from qdrant_client import QdrantClient, models
9
+ from datasets import load_dataset
10
+
11
class RAGEngine:
    """Retrieval-Augmented Generation engine backed by Qdrant Cloud.

    Documents are embedded with the all-MiniLM-L6-v2 sentence-transformer
    (384-dimensional vectors) and stored in the 'phishing_knowledge'
    collection. Connection credentials MUST come from the QDRANT_URL and
    QDRANT_API_KEY environment variables; when either is missing the engine
    degrades gracefully and retrieve() returns no context.
    """

    # Output dimension of all-MiniLM-L6-v2; must match the collection schema.
    _VECTOR_SIZE = 384

    def __init__(self, knowledge_base_dir="./knowledge_base"):
        """Connect to Qdrant Cloud and ensure the collection exists.

        Args:
            knowledge_base_dir: directory of local source documents used by
                _build_index() (falls back to ./knowledge_base.txt).
        """
        self.knowledge_base_dir = knowledge_base_dir

        # Initialize Embeddings
        self.embedding_fn = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

        # Qdrant Cloud configuration — environment variables ONLY.
        # SECURITY FIX: the previous revision shipped a hard-coded cluster URL
        # and API key as fallbacks. A key committed to source control is
        # compromised and must be rotated; never re-add it here.
        self.qdrant_url = os.environ.get("QDRANT_URL")
        self.qdrant_api_key = os.environ.get("QDRANT_API_KEY")
        self.collection_name = "phishing_knowledge"

        # Always define client/vector_store so later attribute checks
        # (e.g. refresh_knowledge_base) cannot raise AttributeError.
        self.client = None
        self.vector_store = None

        if not self.qdrant_url or not self.qdrant_api_key:
            print("⚠️ QDRANT_URL or QDRANT_API_KEY not set. RAG will not function correctly.")
            return

        print(f"☁️ Connecting to Qdrant Cloud: {self.qdrant_url}...")

        # Initialize Qdrant Client
        self.client = QdrantClient(
            url=self.qdrant_url,
            api_key=self.qdrant_api_key
        )

        # Initialize Vector Store Wrapper
        self.vector_store = Qdrant(
            client=self.client,
            collection_name=self.collection_name,
            embeddings=self.embedding_fn
        )

        # Check if collection exists/is empty and build if needed
        try:
            if not self.client.collection_exists(self.collection_name):
                print(f"⚠️ Collection '{self.collection_name}' not found. Creating...")
                self._create_collection()
                print(f"βœ… Collection '{self.collection_name}' created!")
                self._build_index()
            else:
                # Count only vectors that came from the HF phishing dataset
                # (tagged metadata.source == "hf_dataset" at index time).
                dataset_filter = models.Filter(
                    must=[
                        models.FieldCondition(
                            key="metadata.source",
                            match=models.MatchValue(value="hf_dataset")
                        )
                    ]
                )
                dataset_count = self.client.count(
                    collection_name=self.collection_name,
                    count_filter=dataset_filter
                ).count

                print(f"βœ… Qdrant Collection '{self.collection_name}' ready with {dataset_count} vectors.")

                if dataset_count == 0:
                    print("⚠️ Phishing dataset not found. Please run 'index_dataset_colab.ipynb' to populate.")
                    # self.load_from_huggingface() # Disabled to prevent timeout

        except Exception as e:
            print(f"⚠️ Collection check/creation failed: {e}")
            # Try to build anyway, maybe wrapper handles it
            self._build_index()

    def _create_collection(self):
        """Create the Qdrant collection with the cosine/384-dim schema."""
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(size=self._VECTOR_SIZE, distance=models.Distance.COSINE)
        )

    def _build_index(self):
        """Load local documents, chunk them, and push the chunks to Qdrant."""
        print("πŸ”„ Building Knowledge Base Index on Qdrant Cloud...")

        documents = self._load_documents()
        if not documents:
            print("⚠️ No documents found to index.")
            return

        # Split documents into 500-char chunks with 50-char overlap so that
        # neighbouring chunks share context at their boundaries.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = text_splitter.split_documents(documents)

        if chunks:
            # Add to vector store (Qdrant handles persistence automatically)
            try:
                self.vector_store.add_documents(chunks)
                print(f"βœ… Indexed {len(chunks)} chunks to Qdrant Cloud.")
            except Exception as e:
                print(f"❌ Error indexing to Qdrant: {e}")
        else:
            print("⚠️ No chunks created.")

    def _load_documents(self):
        """Load source documents.

        Prefers the knowledge-base directory; falls back to a single
        'knowledge_base.txt' in the working directory. Returns a (possibly
        empty) list of Document objects; loader errors are reported to
        stdout rather than raised.
        """
        documents = []

        # Check for directory or fallback file
        target_path = self.knowledge_base_dir
        if not os.path.exists(target_path):
            if os.path.exists("knowledge_base.txt"):
                target_path = "knowledge_base.txt"
                print("⚠️ Using fallback 'knowledge_base.txt' in root.")
            else:
                print(f"❌ Knowledge base not found at {target_path}")
                return []

        try:
            if os.path.isfile(target_path):
                # Load single file; pick the loader from the extension.
                if target_path.endswith(".pdf"):
                    loader = PyPDFLoader(target_path)
                else:
                    loader = TextLoader(target_path, encoding="utf-8")
                documents.extend(loader.load())
            else:
                # Load directory: one recursive loader per supported type.
                loaders = [
                    DirectoryLoader(target_path, glob="**/*.txt", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={"encoding": "utf-8"}),
                    DirectoryLoader(target_path, glob="**/*.pdf", loader_cls=PyPDFLoader),
                ]

                # Best-effort: one failing loader must not abort the others.
                for loader in loaders:
                    try:
                        docs = loader.load()
                        documents.extend(docs)
                    except Exception as e:
                        print(f"⚠️ Error loading with {loader}: {e}")

        except Exception as e:
            print(f"❌ Error loading documents: {e}")

        return documents

    def load_from_huggingface(self):
        """Download the ealvaradob phishing dataset and index it into Qdrant.

        Each row ({'text': ..., 'label': ...}) becomes a Document tagged
        metadata.source == "hf_dataset" so __init__ can detect whether the
        dataset is already indexed. Chunks are uploaded in batches of 100.
        """
        dataset_url = "https://huggingface.co/datasets/ealvaradob/phishing-dataset/resolve/main/combined_reduced.json"
        print(f"πŸ“₯ Downloading dataset from {dataset_url}...")

        try:
            import requests

            # FIX: requests.get without a timeout can hang forever; bound it.
            response = requests.get(dataset_url, timeout=60)
            if response.status_code != 200:
                print(f"❌ Failed to download dataset: {response.status_code}")
                return

            data = response.json()
            print(f"βœ… Dataset downloaded. Processing {len(data)} rows...")

            documents = []
            for row in data:
                # Structure: text, label
                content = row.get('text', '')
                label = row.get('label', -1)

                if content:
                    doc = Document(
                        page_content=content,
                        metadata={"source": "hf_dataset", "label": label}
                    )
                    documents.append(doc)

            if documents:
                # Batch add to vector store
                print(f"πŸ”„ Indexing {len(documents)} documents to Qdrant...")

                # Use a larger chunk size for efficiency since these are likely short texts
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=1000,
                    chunk_overlap=100
                )
                chunks = text_splitter.split_documents(documents)

                # Add in batches to avoid hitting API limits or timeouts
                batch_size = 100
                total_chunks = len(chunks)
                total_batches = (total_chunks + batch_size - 1) // batch_size

                for i in range(0, total_chunks, batch_size):
                    batch = chunks[i:i + batch_size]
                    try:
                        self.vector_store.add_documents(batch)
                        print(f"  - Indexed batch {i//batch_size + 1}/{total_batches}")
                    except Exception as e:
                        print(f"  ⚠️ Error indexing batch {i}: {e}")

                print(f"βœ… Successfully indexed {total_chunks} chunks from dataset!")
            else:
                print("⚠️ No valid documents found in dataset.")

        except Exception as e:
            print(f"❌ Error loading HF dataset: {e}")

    def refresh_knowledge_base(self):
        """Drop and recreate the cloud collection, then rebuild the index."""
        print("♻️ Refreshing Knowledge Base...")
        if self.client:
            try:
                self.client.delete_collection(self.collection_name)
                # FIX: recreate the collection before re-indexing; the Qdrant
                # wrapper's add_documents() expects it to exist.
                self._create_collection()
                self._build_index()
                self.load_from_huggingface()
                return "βœ… Knowledge Base Refreshed on Cloud!"
            except Exception as e:
                return f"❌ Error refreshing: {e}"
        return "❌ Qdrant Client not initialized."

    def retrieve(self, query, n_results=3):
        """Return the page_content of the top n_results most similar chunks.

        Returns an empty list when the store is unavailable or the search
        fails — callers never have to handle exceptions from here.
        """
        if not self.vector_store:
            return []

        # Search
        try:
            results = self.vector_store.similarity_search(query, k=n_results)
            if results:
                return [doc.page_content for doc in results]
        except Exception as e:
            print(f"⚠️ Retrieval Error: {e}")

        return []