Sangyog10 commited on
Commit
350fb44
·
1 Parent(s): 339f42a

Added 3-model selection with env support

Browse files
.env-example CHANGED
@@ -31,4 +31,4 @@ MY_SECRET_TOKEN="SECRET_CODE_TOKEN"
31
  # LLM_TEMPERATURE=0.1
32
 
33
  # Maximum tokens for response
34
- # LLM_MAX_TOKENS=4096
 
31
  # LLM_TEMPERATURE=0.1
32
 
33
  # Maximum tokens for response
34
+ # LLM_MAX_TOKENS=4096
features/rag_chatbot/rag_pipeline.py CHANGED
@@ -14,11 +14,32 @@ from langchain.chat_models import ChatOpenAI
14
 
15
  load_dotenv()
16
 
17
- CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
 
18
  COLLECTION_NAME = "company_docs_collection"
19
 
20
- # OpenRouter configuration
21
- OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  vector_store = None
24
  company_qa_chain = None
@@ -26,36 +47,54 @@ query_router_chain = None
26
  cybersecurity_chain = None
27
  llm = None
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def initialize_pipelines():
30
  """Initializes all required models, chains, and the vector store."""
31
  global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm
32
 
33
  try:
34
- # Check for required API keys
35
- if not OPENROUTER_API_KEY:
36
- raise ValueError("OPENROUTER_API_KEY environment variable is required")
37
-
38
-
39
- # Initialize LLM with OpenRouter
40
- llm = ChatOpenAI(
41
- model="meta-llama/llama-3.3-70b-instruct:free",
42
- openai_api_key=OPENROUTER_API_KEY,
43
- openai_api_base="https://openrouter.ai/api/v1",
44
- temperature=0,
45
- max_tokens=2048,
46
- )
47
 
 
48
  embeddings = HuggingFaceEmbeddings(
49
  model_name="all-MiniLM-L6-v2",
50
  model_kwargs={'device': 'cpu'},
51
- encode_kwargs={'normalize_embeddings': True} # Normalize embeddings for better similarity search
52
  )
53
 
54
-
55
  # Initialize ChromaDB client
56
  try:
57
  chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
58
- chroma_client.heartbeat() # Heartbeat check to confirm the connection
59
  except Exception as e:
60
  raise ConnectionError("Failed to connect to ChromaDB.") from e
61
 
@@ -86,8 +125,8 @@ Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):"""
86
  prompt=router_prompt
87
  )
88
 
89
- # Custom Company QA Chain with natural prompt
90
- company_qa_template = """You are a helpful assistant for CyberAlertNepal. Answer the following question about our company using the information provided and links if only available. Give a natural, direct and polite response .
91
 
92
  Question: {question}
93
 
@@ -109,8 +148,7 @@ Answer:"""
109
  # Cybersecurity Chain
110
  cybersecurity_template = """You are a cybersecurity professional. Answer the following question truthfully and concisely.
111
  If you are not 100% sure about the answer, simply respond with: "I am not sure about the answer."
112
- Do not add extra explanations or assumptions. Do not provide false or speculative information.
113
-
114
 
115
  Question: {question}
116
 
@@ -126,6 +164,7 @@ Provide a comprehensive and accurate answer about cybersecurity:"""
126
  prompt=cybersecurity_prompt
127
  )
128
 
 
129
 
130
  except Exception as e:
131
  print(f"Error initializing pipelines: {e}")
@@ -176,7 +215,9 @@ def route_and_process_query(query: str):
176
  return {
177
  "answer": answer,
178
  "source": "Cybersecurity Knowledge Base",
179
- "route": "CYBERSECURITY"
 
 
180
  }
181
 
182
  elif "COMPANY" in route:
@@ -187,7 +228,9 @@ def route_and_process_query(query: str):
187
  return {
188
  "answer": "I could not find any relevant information to answer your question.",
189
  "source": "Company Documents",
190
- "route": "COMPANY"
 
 
191
  }
192
 
193
  # Combine document content for context
@@ -201,14 +244,18 @@ def route_and_process_query(query: str):
201
  "answer": answer,
202
  "source": "Company Documents",
203
  "documents": sources,
204
- "route": "COMPANY"
 
 
205
  }
206
 
207
  else: # OFF_TOPIC
208
  return {
209
  "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
210
  "source": "N/A",
211
- "route": "OFF_TOPIC"
 
 
212
  }
213
 
214
  except Exception as e:
@@ -216,6 +263,9 @@ def route_and_process_query(query: str):
216
  return {
217
  "answer": "I encountered an error while processing your query. Please try again.",
218
  "source": "Error",
 
 
 
219
  "error": str(e)
220
  }
221
 
@@ -237,28 +287,38 @@ def check_system_health():
237
 
238
  return {
239
  "status": "healthy" if all(components.values()) else "unhealthy",
240
- "components": components
 
 
241
  }
242
 
243
  except Exception as e:
244
  return {
245
  "status": "unhealthy",
246
- "error": str(e)
 
247
  }
248
 
249
- # Test function to verify OpenRouter connection
250
- def test_openrouter_connection():
251
- """Test the OpenRouter API connection."""
252
  try:
253
  if not llm:
254
  initialize_pipelines()
255
 
256
  # Simple test query
257
- test_response = llm("Say 'Hello, OpenRouter is working!'")
258
- return True
 
 
 
 
 
259
  except Exception as e:
260
- print(f"OpenRouter connection test failed: {e}")
261
- return False
 
 
 
262
 
263
  # Initialize pipelines on module import
264
  try:
 
14
 
15
  load_dotenv()
16
 
17
+ # ChromaDB configuration
18
+ CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost") # change in env in production when hosted
19
  COLLECTION_NAME = "company_docs_collection"
20
 
21
+ # LLM Provider Configuration
22
+ LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
23
+ LLM_API_KEY = os.getenv("LLM_API_KEY")
24
+ LLM_MODEL = os.getenv("LLM_MODEL", "gpt-3.5-turbo")
25
+ LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
26
+ LLM_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "2048"))
27
+
28
+ # Provider-specific configurations
29
+ PROVIDER_CONFIGS = {
30
+ "openai": {
31
+ "api_base": "https://api.openai.com/v1",
32
+ "default_model": "gpt-3.5-turbo"
33
+ },
34
+ "groq": {
35
+ "api_base": "https://api.groq.com/openai/v1",
36
+ "default_model": "llama-3.3-70b-versatile"
37
+ },
38
+ "openrouter": {
39
+ "api_base": "https://openrouter.ai/api/v1",
40
+ "default_model": "mistralai/mistral-small-3.2-24b-instruct:free"
41
+ }
42
+ }
43
 
44
  vector_store = None
45
  company_qa_chain = None
 
47
  cybersecurity_chain = None
48
  llm = None
49
 
50
+ def get_llm_config():
51
+ """Get the appropriate LLM configuration based on the provider."""
52
+ if LLM_PROVIDER not in PROVIDER_CONFIGS:
53
+ raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}. Supported: {list(PROVIDER_CONFIGS.keys())}")
54
+
55
+ config = PROVIDER_CONFIGS[LLM_PROVIDER].copy()
56
+
57
+ # Use provided model or fall back to default
58
+ model = LLM_MODEL if LLM_MODEL != "gpt-3.5-turbo" else config["default_model"]
59
+
60
+ return {
61
+ "model": model,
62
+ "openai_api_key": LLM_API_KEY,
63
+ "openai_api_base": config["api_base"],
64
+ "temperature": LLM_TEMPERATURE,
65
+ "max_tokens": LLM_MAX_TOKENS,
66
+ }
67
+
68
+ def initialize_llm():
69
+ """Initialize the LLM based on the configured provider."""
70
+ if not LLM_API_KEY:
71
+ raise ValueError(f"LLM_API_KEY environment variable is required for {LLM_PROVIDER}")
72
+
73
+ config = get_llm_config()
74
+
75
+ print(f"Initializing {LLM_PROVIDER.upper()} with model: {config['model']}")
76
+
77
+ return ChatOpenAI(**config)
78
+
79
  def initialize_pipelines():
80
  """Initializes all required models, chains, and the vector store."""
81
  global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm
82
 
83
  try:
84
+ # Initialize LLM
85
+ llm = initialize_llm()
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+ # Initialize embeddings
88
  embeddings = HuggingFaceEmbeddings(
89
  model_name="all-MiniLM-L6-v2",
90
  model_kwargs={'device': 'cpu'},
91
+ encode_kwargs={'normalize_embeddings': True}
92
  )
93
 
 
94
  # Initialize ChromaDB client
95
  try:
96
  chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
97
+ chroma_client.heartbeat()
98
  except Exception as e:
99
  raise ConnectionError("Failed to connect to ChromaDB.") from e
100
 
 
125
  prompt=router_prompt
126
  )
127
 
128
+ # Custom Company QA Chain
129
+ company_qa_template = """You are a helpful assistant for CyberAlertNepal. Answer the following question about our company using the information provided and links if only available. Give a natural, direct and polite response.
130
 
131
  Question: {question}
132
 
 
148
  # Cybersecurity Chain
149
  cybersecurity_template = """You are a cybersecurity professional. Answer the following question truthfully and concisely.
150
  If you are not 100% sure about the answer, simply respond with: "I am not sure about the answer."
151
+ Do not add extra explanations or assumptions. Do not provide false or speculative information.
 
152
 
153
  Question: {question}
154
 
 
164
  prompt=cybersecurity_prompt
165
  )
166
 
167
+ print(f"Successfully initialized pipelines with {LLM_PROVIDER.upper()}")
168
 
169
  except Exception as e:
170
  print(f"Error initializing pipelines: {e}")
 
215
  return {
216
  "answer": answer,
217
  "source": "Cybersecurity Knowledge Base",
218
+ "route": "CYBERSECURITY",
219
+ "provider": LLM_PROVIDER.upper(),
220
+ "model": get_llm_config()["model"]
221
  }
222
 
223
  elif "COMPANY" in route:
 
228
  return {
229
  "answer": "I could not find any relevant information to answer your question.",
230
  "source": "Company Documents",
231
+ "route": "COMPANY",
232
+ "provider": LLM_PROVIDER.upper(),
233
+ "model": get_llm_config()["model"]
234
  }
235
 
236
  # Combine document content for context
 
244
  "answer": answer,
245
  "source": "Company Documents",
246
  "documents": sources,
247
+ "route": "COMPANY",
248
+ "provider": LLM_PROVIDER.upper(),
249
+ "model": get_llm_config()["model"]
250
  }
251
 
252
  else: # OFF_TOPIC
253
  return {
254
  "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
255
  "source": "N/A",
256
+ "route": "OFF_TOPIC",
257
+ "provider": LLM_PROVIDER.upper(),
258
+ "model": get_llm_config()["model"]
259
  }
260
 
261
  except Exception as e:
 
263
  return {
264
  "answer": "I encountered an error while processing your query. Please try again.",
265
  "source": "Error",
266
+ "route": None,
267
+ "documents": None,
268
+ "provider": LLM_PROVIDER.upper(),
269
  "error": str(e)
270
  }
271
 
 
287
 
288
  return {
289
  "status": "healthy" if all(components.values()) else "unhealthy",
290
+ "components": components,
291
+ "provider": LLM_PROVIDER.upper(),
292
+ "model": get_llm_config()["model"] if llm else "Not initialized"
293
  }
294
 
295
  except Exception as e:
296
  return {
297
  "status": "unhealthy",
298
+ "error": str(e),
299
+ "provider": LLM_PROVIDER.upper()
300
  }
301
 
302
+ def test_llm_connection():
303
+ """Test the LLM API connection."""
 
304
  try:
305
  if not llm:
306
  initialize_pipelines()
307
 
308
  # Simple test query
309
+ test_response = llm("Say 'Hello, LLM is working!'")
310
+ return {
311
+ "success": True,
312
+ "provider": LLM_PROVIDER.upper(),
313
+ "model": get_llm_config()["model"],
314
+ "response": str(test_response)
315
+ }
316
  except Exception as e:
317
+ return {
318
+ "success": False,
319
+ "provider": LLM_PROVIDER.upper(),
320
+ "error": str(e)
321
+ }
322
 
323
  # Initialize pipelines on module import
324
  try: