Sangyog10 commited on
Commit
273204e
·
1 Parent(s): 7ce4837

set up rag pipeline for chatbot

Browse files
Files changed (4) hide show
  1. README.md +3 -1
  2. app.py +3 -0
  3. features/rag_chatbot/rag_pipeline.py +40 -151
  4. requirements.txt +8 -0
README.md CHANGED
@@ -131,7 +131,9 @@ AI-Checker/
131
  2. **Run the API**
132
 
133
  ```bash
134
- uvicorn app:app --reload
 
 
135
  ```
136
 
137
  3. **Build Docker (optional)**
 
131
  2. **Run the API**
132
 
133
  ```bash
134
+ chroma run --path ./chroma_database ## to run chromadb locally
135
+ uvicorn app:app --reload --port 8001 ## fastapi (run after chromadb)
136
+
137
  ```
138
 
139
  3. **Build Docker (optional)**
app.py CHANGED
@@ -11,6 +11,7 @@ from features.nepali_text_classifier.routes import (
11
  )
12
  from features.image_classifier.routes import router as image_classifier_router
13
  from features.image_edit_detector.routes import router as image_edit_detector_router
 
14
  from fastapi.staticfiles import StaticFiles
15
 
16
  from config import ACCESS_RATE
@@ -41,6 +42,8 @@ app.include_router(text_classifier_router, prefix="/text")
41
  app.include_router(nepali_text_classifier_router, prefix="/NP")
42
  app.include_router(image_classifier_router, prefix="/AI-image")
43
  app.include_router(image_edit_detector_router, prefix="/detect")
 
 
44
 
45
 
46
  @app.get("/")
 
11
  )
12
  from features.image_classifier.routes import router as image_classifier_router
13
  from features.image_edit_detector.routes import router as image_edit_detector_router
14
+ from features.rag_chatbot.routes import router as rag_router
15
  from fastapi.staticfiles import StaticFiles
16
 
17
  from config import ACCESS_RATE
 
42
  app.include_router(nepali_text_classifier_router, prefix="/NP")
43
  app.include_router(image_classifier_router, prefix="/AI-image")
44
  app.include_router(image_edit_detector_router, prefix="/detect")
45
+ app.include_router(rag_router, prefix="/rag")
46
+
47
 
48
 
49
  @app.get("/")
features/rag_chatbot/rag_pipeline.py CHANGED
@@ -3,99 +3,38 @@ import chromadb
3
  from dotenv import load_dotenv
4
  from langchain_core.documents import Document
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.embeddings import HuggingFaceEmbeddings
7
- from langchain_community.llms import OpenAI
8
  from langchain.chains.question_answering import load_qa_chain
9
  from langchain_community.vectorstores import Chroma
10
  from langchain.chains import LLMChain
11
  from langchain.prompts import PromptTemplate
12
- from langchain.chat_models import ChatOpenAI
13
-
14
 
15
  load_dotenv()
16
 
17
- # ChromaDB configuration
18
- CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost") # change in env in production when hosted
19
  COLLECTION_NAME = "company_docs_collection"
20
 
21
- # LLM Provider Configuration
22
- LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
23
- LLM_API_KEY = os.getenv("LLM_API_KEY")
24
- LLM_MODEL = os.getenv("LLM_MODEL", "gpt-3.5-turbo")
25
- LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
26
- LLM_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "2048"))
27
-
28
- # Provider-specific configurations
29
- PROVIDER_CONFIGS = {
30
- "openai": {
31
- "api_base": "https://api.openai.com/v1",
32
- "default_model": "gpt-3.5-turbo"
33
- },
34
- "groq": {
35
- "api_base": "https://api.groq.com/openai/v1",
36
- "default_model": "llama-3.3-70b-versatile"
37
- },
38
- "openrouter": {
39
- "api_base": "https://openrouter.ai/api/v1",
40
- "default_model": "mistralai/mistral-small-3.2-24b-instruct:free"
41
- }
42
- }
43
-
44
  vector_store = None
45
  company_qa_chain = None
46
  query_router_chain = None
47
  cybersecurity_chain = None
48
- llm = None
49
-
50
- def get_llm_config():
51
- """Get the appropriate LLM configuration based on the provider."""
52
- if LLM_PROVIDER not in PROVIDER_CONFIGS:
53
- raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}. Supported: {list(PROVIDER_CONFIGS.keys())}")
54
-
55
- config = PROVIDER_CONFIGS[LLM_PROVIDER].copy()
56
-
57
- # Use provided model or fall back to default
58
- model = LLM_MODEL if LLM_MODEL != "gpt-3.5-turbo" else config["default_model"]
59
-
60
- return {
61
- "model": model,
62
- "openai_api_key": LLM_API_KEY,
63
- "openai_api_base": config["api_base"],
64
- "temperature": LLM_TEMPERATURE,
65
- "max_tokens": LLM_MAX_TOKENS,
66
- }
67
-
68
- def initialize_llm():
69
- """Initialize the LLM based on the configured provider."""
70
- if not LLM_API_KEY:
71
- raise ValueError(f"LLM_API_KEY environment variable is required for {LLM_PROVIDER}")
72
-
73
- config = get_llm_config()
74
-
75
- print(f"Initializing {LLM_PROVIDER.upper()} with model: {config['model']}")
76
-
77
- return ChatOpenAI(**config)
78
 
79
  def initialize_pipelines():
80
  """Initializes all required models, chains, and the vector store."""
81
  global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm
82
 
83
  try:
84
- # Initialize LLM
85
- llm = initialize_llm()
86
-
87
- # Initialize embeddings
88
- embeddings = HuggingFaceEmbeddings(
89
- model_name="all-MiniLM-L6-v2",
90
- model_kwargs={'device': 'cpu'},
91
- encode_kwargs={'normalize_embeddings': True}
92
- )
93
 
94
  # Initialize ChromaDB client
95
  try:
96
  chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
97
- chroma_client.heartbeat()
 
98
  except Exception as e:
 
 
99
  raise ConnectionError("Failed to connect to ChromaDB.") from e
100
 
101
  # Initialize vector store
@@ -106,14 +45,16 @@ def initialize_pipelines():
106
  )
107
 
108
  # Query Router Chain
109
- router_template = """You are a query classifier. Classify the following query into one of these categories:
110
- - COMPANY: Questions about our company, its products, services, or general information
111
- - CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
112
- - OFF_TOPIC: Questions that don't fit the above categories
113
-
114
- Query: {query}
115
-
116
- Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):"""
 
 
117
 
118
  router_prompt = PromptTemplate(
119
  input_variables=["query"],
@@ -125,34 +66,17 @@ Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):"""
125
  prompt=router_prompt
126
  )
127
 
128
- # Custom Company QA Chain
129
- company_qa_template = """You are a helpful assistant for CyberAlertNepal. Answer the following question about our company using the information provided and links if only available. Give a natural, direct and polite response.
130
-
131
- Question: {question}
132
-
133
- Information:
134
- {context}
135
-
136
- Answer:"""
137
 
138
- company_qa_prompt = PromptTemplate(
139
- input_variables=["question", "context"],
140
- template=company_qa_template
141
- )
142
 
143
- company_qa_chain = LLMChain(
144
- llm=llm,
145
- prompt=company_qa_prompt
146
- )
147
 
148
- # Cybersecurity Chain
149
- cybersecurity_template = """You are a cybersecurity professional. Answer the following question truthfully and concisely.
150
- If you are not 100% sure about the answer, simply respond with: "I am not sure about the answer."
151
- Do not add extra explanations or assumptions. Do not provide false or speculative information.
152
-
153
- Question: {question}
154
-
155
- Provide a comprehensive and accurate answer about cybersecurity:"""
156
 
157
  cybersecurity_prompt = PromptTemplate(
158
  input_variables=["question"],
@@ -164,8 +88,8 @@ Provide a comprehensive and accurate answer about cybersecurity:"""
164
  prompt=cybersecurity_prompt
165
  )
166
 
167
- print(f"Successfully initialized pipelines with {LLM_PROVIDER.upper()}")
168
-
169
  except Exception as e:
170
  print(f"Error initializing pipelines: {e}")
171
  raise
@@ -188,6 +112,7 @@ def add_document_to_rag(text: str, metadata: dict):
188
  print("Document was empty after splitting, not adding to ChromaDB.")
189
  return False
190
 
 
191
  vector_store.add_documents(docs)
192
  print("Successfully added documents.")
193
  return True
@@ -208,6 +133,7 @@ def route_and_process_query(query: str):
208
  route_result = query_router_chain.run(query)
209
  route = route_result.strip().upper()
210
 
 
211
 
212
  # 2. Route to appropriate logic
213
  if "CYBERSECURITY" in route:
@@ -215,47 +141,38 @@ def route_and_process_query(query: str):
215
  return {
216
  "answer": answer,
217
  "source": "Cybersecurity Knowledge Base",
218
- "route": "CYBERSECURITY",
219
- "provider": LLM_PROVIDER.upper(),
220
- "model": get_llm_config()["model"]
221
  }
222
 
223
  elif "COMPANY" in route:
224
  # Perform similarity search on ChromaDB
225
  docs = vector_store.similarity_search(query, k=3)
 
 
226
 
227
  if not docs:
228
  return {
229
  "answer": "I could not find any relevant information to answer your question.",
230
  "source": "Company Documents",
231
- "route": "COMPANY",
232
- "provider": LLM_PROVIDER.upper(),
233
- "model": get_llm_config()["model"]
234
  }
235
 
236
- # Combine document content for context
237
- context = "\n\n".join([doc.page_content for doc in docs])
238
-
239
- # Run the custom QA chain
240
- answer = company_qa_chain.run(question=query, context=context)
241
  sources = list(set([doc.metadata.get("source", "Unknown") for doc in docs]))
242
 
243
  return {
244
  "answer": answer,
245
  "source": "Company Documents",
246
  "documents": sources,
247
- "route": "COMPANY",
248
- "provider": LLM_PROVIDER.upper(),
249
- "model": get_llm_config()["model"]
250
  }
251
 
252
  else: # OFF_TOPIC
253
  return {
254
  "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
255
  "source": "N/A",
256
- "route": "OFF_TOPIC",
257
- "provider": LLM_PROVIDER.upper(),
258
- "model": get_llm_config()["model"]
259
  }
260
 
261
  except Exception as e:
@@ -263,9 +180,6 @@ def route_and_process_query(query: str):
263
  return {
264
  "answer": "I encountered an error while processing your query. Please try again.",
265
  "source": "Error",
266
- "route": None,
267
- "documents": None,
268
- "provider": LLM_PROVIDER.upper(),
269
  "error": str(e)
270
  }
271
 
@@ -281,42 +195,17 @@ def check_system_health():
281
  "vector_store": vector_store is not None,
282
  "company_qa_chain": company_qa_chain is not None,
283
  "query_router_chain": query_router_chain is not None,
284
- "cybersecurity_chain": cybersecurity_chain is not None,
285
- "llm": llm is not None
286
  }
287
 
288
  return {
289
  "status": "healthy" if all(components.values()) else "unhealthy",
290
- "components": components,
291
- "provider": LLM_PROVIDER.upper(),
292
- "model": get_llm_config()["model"] if llm else "Not initialized"
293
  }
294
 
295
  except Exception as e:
296
  return {
297
  "status": "unhealthy",
298
- "error": str(e),
299
- "provider": LLM_PROVIDER.upper()
300
- }
301
-
302
- def test_llm_connection():
303
- """Test the LLM API connection."""
304
- try:
305
- if not llm:
306
- initialize_pipelines()
307
-
308
- # Simple test query
309
- test_response = llm("Say 'Hello, LLM is working!'")
310
- return {
311
- "success": True,
312
- "provider": LLM_PROVIDER.upper(),
313
- "model": get_llm_config()["model"],
314
- "response": str(test_response)
315
- }
316
- except Exception as e:
317
- return {
318
- "success": False,
319
- "provider": LLM_PROVIDER.upper(),
320
  "error": str(e)
321
  }
322
 
 
3
  from dotenv import load_dotenv
4
  from langchain_core.documents import Document
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_openai import OpenAIEmbeddings, OpenAI
 
7
  from langchain.chains.question_answering import load_qa_chain
8
  from langchain_community.vectorstores import Chroma
9
  from langchain.chains import LLMChain
10
  from langchain.prompts import PromptTemplate
 
 
11
 
12
  load_dotenv()
13
 
14
+ CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
 
15
  COLLECTION_NAME = "company_docs_collection"
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  vector_store = None
18
  company_qa_chain = None
19
  query_router_chain = None
20
  cybersecurity_chain = None
21
+ llm = OpenAI(temperature=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def initialize_pipelines():
24
  """Initializes all required models, chains, and the vector store."""
25
  global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm
26
 
27
  try:
28
+ embeddings = OpenAIEmbeddings()
 
 
 
 
 
 
 
 
29
 
30
  # Initialize ChromaDB client
31
  try:
32
  chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
33
+ chroma_client.heartbeat() # Heartbeat check to confirm the connection
34
+ print("Successfully connected to ChromaDB.")
35
  except Exception as e:
36
+ print(f"FATAL: Could not connect to ChromaDB at {CHROMA_HOST}:8000. Please ensure the ChromaDB server is running.")
37
+ print(f"Error details: {e}")
38
  raise ConnectionError("Failed to connect to ChromaDB.") from e
39
 
40
  # Initialize vector store
 
45
  )
46
 
47
  # Query Router Chain
48
+ router_template = """
49
+ You are a query classifier. Classify the following query into one of these categories:
50
+ - COMPANY: Questions about company policies, procedures, documents, or internal information
51
+ - CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
52
+ - OFF_TOPIC: Questions that don't fit the above categories
53
+
54
+ Query: {query}
55
+
56
+ Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):
57
+ """
58
 
59
  router_prompt = PromptTemplate(
60
  input_variables=["query"],
 
66
  prompt=router_prompt
67
  )
68
 
69
+ # Company QA Chain
70
+ company_qa_chain = load_qa_chain(llm, chain_type="stuff")
 
 
 
 
 
 
 
71
 
72
+ # Cybersecurity Chain
73
+ cybersecurity_template = """
74
+ You are a cybersecurity expert. Answer the following cybersecurity question based on your knowledge:
 
75
 
76
+ Question: {question}
 
 
 
77
 
78
+ Provide a comprehensive and accurate answer about cybersecurity:
79
+ """
 
 
 
 
 
 
80
 
81
  cybersecurity_prompt = PromptTemplate(
82
  input_variables=["question"],
 
88
  prompt=cybersecurity_prompt
89
  )
90
 
91
+ print("All pipelines initialized successfully!")
92
+
93
  except Exception as e:
94
  print(f"Error initializing pipelines: {e}")
95
  raise
 
112
  print("Document was empty after splitting, not adding to ChromaDB.")
113
  return False
114
 
115
+ print(f"Adding {len(docs)} document chunks to ChromaDB...")
116
  vector_store.add_documents(docs)
117
  print("Successfully added documents.")
118
  return True
 
133
  route_result = query_router_chain.run(query)
134
  route = route_result.strip().upper()
135
 
136
+ print(f"Query routed to: {route}")
137
 
138
  # 2. Route to appropriate logic
139
  if "CYBERSECURITY" in route:
 
141
  return {
142
  "answer": answer,
143
  "source": "Cybersecurity Knowledge Base",
144
+ "route": "CYBERSECURITY"
 
 
145
  }
146
 
147
  elif "COMPANY" in route:
148
  # Perform similarity search on ChromaDB
149
  docs = vector_store.similarity_search(query, k=3)
150
+ print(f"Found {len(docs)} relevant documents.")
151
+ print(f"Documents: {[doc.metadata.get('source', 'Unknown') for doc in docs]}")
152
 
153
  if not docs:
154
  return {
155
  "answer": "I could not find any relevant information to answer your question.",
156
  "source": "Company Documents",
157
+ "route": "COMPANY"
 
 
158
  }
159
 
160
+ # Run the QA chain
161
+ answer = company_qa_chain.run(input_documents=docs, question=query)
 
 
 
162
  sources = list(set([doc.metadata.get("source", "Unknown") for doc in docs]))
163
 
164
  return {
165
  "answer": answer,
166
  "source": "Company Documents",
167
  "documents": sources,
168
+ "route": "COMPANY"
 
 
169
  }
170
 
171
  else: # OFF_TOPIC
172
  return {
173
  "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
174
  "source": "N/A",
175
+ "route": "OFF_TOPIC"
 
 
176
  }
177
 
178
  except Exception as e:
 
180
  return {
181
  "answer": "I encountered an error while processing your query. Please try again.",
182
  "source": "Error",
 
 
 
183
  "error": str(e)
184
  }
185
 
 
195
  "vector_store": vector_store is not None,
196
  "company_qa_chain": company_qa_chain is not None,
197
  "query_router_chain": query_router_chain is not None,
198
+ "cybersecurity_chain": cybersecurity_chain is not None
 
199
  }
200
 
201
  return {
202
  "status": "healthy" if all(components.values()) else "unhealthy",
203
+ "components": components
 
 
204
  }
205
 
206
  except Exception as e:
207
  return {
208
  "status": "unhealthy",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  "error": str(e)
210
  }
211
 
requirements.txt CHANGED
@@ -21,3 +21,11 @@ tools
21
  pandas
22
  requests
23
  beautifulsoup4
 
 
 
 
 
 
 
 
 
21
  pandas
22
  requests
23
  beautifulsoup4
24
+ langchain
25
+ langchain-community
26
+ langchain-openai
27
+ faiss-cpu
28
+ PyPDF2
29
+ tiktoken
30
+ chromadb
31
+ langchain_chroma