Sangyog10 commited on
Commit
53001af
·
1 Parent(s): 273204e

configured with openrouter

Browse files
features/rag_chatbot/rag_pipeline.py CHANGED
@@ -3,38 +3,60 @@ import chromadb
3
  from dotenv import load_dotenv
4
  from langchain_core.documents import Document
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_openai import OpenAIEmbeddings, OpenAI
 
7
  from langchain.chains.question_answering import load_qa_chain
8
  from langchain_community.vectorstores import Chroma
9
  from langchain.chains import LLMChain
10
  from langchain.prompts import PromptTemplate
 
 
11
 
12
  load_dotenv()
13
 
14
  CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
15
  COLLECTION_NAME = "company_docs_collection"
16
 
 
 
 
17
  vector_store = None
18
  company_qa_chain = None
19
  query_router_chain = None
20
  cybersecurity_chain = None
21
- llm = OpenAI(temperature=0)
22
 
23
  def initialize_pipelines():
24
  """Initializes all required models, chains, and the vector store."""
25
  global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm
26
 
27
  try:
28
- embeddings = OpenAIEmbeddings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  # Initialize ChromaDB client
31
  try:
32
  chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
33
  chroma_client.heartbeat() # Heartbeat check to confirm the connection
34
- print("Successfully connected to ChromaDB.")
35
  except Exception as e:
36
- print(f"FATAL: Could not connect to ChromaDB at {CHROMA_HOST}:8000. Please ensure the ChromaDB server is running.")
37
- print(f"Error details: {e}")
38
  raise ConnectionError("Failed to connect to ChromaDB.") from e
39
 
40
  # Initialize vector store
@@ -45,16 +67,14 @@ def initialize_pipelines():
45
  )
46
 
47
  # Query Router Chain
48
- router_template = """
49
- You are a query classifier. Classify the following query into one of these categories:
50
- - COMPANY: Questions about company policies, procedures, documents, or internal information
51
- - CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
52
- - OFF_TOPIC: Questions that don't fit the above categories
53
-
54
- Query: {query}
55
-
56
- Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):
57
- """
58
 
59
  router_prompt = PromptTemplate(
60
  input_variables=["query"],
@@ -70,13 +90,11 @@ def initialize_pipelines():
70
  company_qa_chain = load_qa_chain(llm, chain_type="stuff")
71
 
72
  # Cybersecurity Chain
73
- cybersecurity_template = """
74
- You are a cybersecurity expert. Answer the following cybersecurity question based on your knowledge:
75
-
76
- Question: {question}
77
-
78
- Provide a comprehensive and accurate answer about cybersecurity:
79
- """
80
 
81
  cybersecurity_prompt = PromptTemplate(
82
  input_variables=["question"],
@@ -88,8 +106,7 @@ def initialize_pipelines():
88
  prompt=cybersecurity_prompt
89
  )
90
 
91
- print("All pipelines initialized successfully!")
92
-
93
  except Exception as e:
94
  print(f"Error initializing pipelines: {e}")
95
  raise
@@ -112,7 +129,6 @@ def add_document_to_rag(text: str, metadata: dict):
112
  print("Document was empty after splitting, not adding to ChromaDB.")
113
  return False
114
 
115
- print(f"Adding {len(docs)} document chunks to ChromaDB...")
116
  vector_store.add_documents(docs)
117
  print("Successfully added documents.")
118
  return True
@@ -133,7 +149,6 @@ def route_and_process_query(query: str):
133
  route_result = query_router_chain.run(query)
134
  route = route_result.strip().upper()
135
 
136
- print(f"Query routed to: {route}")
137
 
138
  # 2. Route to appropriate logic
139
  if "CYBERSECURITY" in route:
@@ -147,8 +162,6 @@ def route_and_process_query(query: str):
147
  elif "COMPANY" in route:
148
  # Perform similarity search on ChromaDB
149
  docs = vector_store.similarity_search(query, k=3)
150
- print(f"Found {len(docs)} relevant documents.")
151
- print(f"Documents: {[doc.metadata.get('source', 'Unknown') for doc in docs]}")
152
 
153
  if not docs:
154
  return {
@@ -195,7 +208,8 @@ def check_system_health():
195
  "vector_store": vector_store is not None,
196
  "company_qa_chain": company_qa_chain is not None,
197
  "query_router_chain": query_router_chain is not None,
198
- "cybersecurity_chain": cybersecurity_chain is not None
 
199
  }
200
 
201
  return {
@@ -209,6 +223,20 @@ def check_system_health():
209
  "error": str(e)
210
  }
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  # Initialize pipelines on module import
213
  try:
214
  initialize_pipelines()
 
3
  from dotenv import load_dotenv
4
  from langchain_core.documents import Document
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.llms import OpenAI
8
  from langchain.chains.question_answering import load_qa_chain
9
  from langchain_community.vectorstores import Chroma
10
  from langchain.chains import LLMChain
11
  from langchain.prompts import PromptTemplate
12
+ from langchain.chat_models import ChatOpenAI
13
+
14
 
15
  load_dotenv()
16
 
17
  CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")
18
  COLLECTION_NAME = "company_docs_collection"
19
 
20
+ # OpenRouter configuration
21
+ OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
22
+
23
  vector_store = None
24
  company_qa_chain = None
25
  query_router_chain = None
26
  cybersecurity_chain = None
27
+ llm = None
28
 
29
  def initialize_pipelines():
30
  """Initializes all required models, chains, and the vector store."""
31
  global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm
32
 
33
  try:
34
+ # Check for required API keys
35
+ if not OPENROUTER_API_KEY:
36
+ raise ValueError("OPENROUTER_API_KEY environment variable is required")
37
+
38
+
39
+ # Initialize LLM with OpenRouter
40
+ llm = ChatOpenAI(
41
+ model="meta-llama/llama-3.3-70b-instruct:free",
42
+ openai_api_key=OPENROUTER_API_KEY,
43
+ openai_api_base="https://openrouter.ai/api/v1",
44
+ temperature=0,
45
+ max_tokens=2048,
46
+ )
47
+
48
+ embeddings = HuggingFaceEmbeddings(
49
+ model_name="all-MiniLM-L6-v2",
50
+ model_kwargs={'device': 'cpu'},
51
+ encode_kwargs={'normalize_embeddings': True} # Normalize embeddings for better similarity search
52
+ )
53
+
54
 
55
  # Initialize ChromaDB client
56
  try:
57
  chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
58
  chroma_client.heartbeat() # Heartbeat check to confirm the connection
 
59
  except Exception as e:
 
 
60
  raise ConnectionError("Failed to connect to ChromaDB.") from e
61
 
62
  # Initialize vector store
 
67
  )
68
 
69
  # Query Router Chain
70
+ router_template = """You are a query classifier. Classify the following query into one of these categories:
71
+ - COMPANY: Questions about company policies, procedures, documents, or internal information
72
+ - CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
73
+ - OFF_TOPIC: Questions that don't fit the above categories
74
+
75
+ Query: {query}
76
+
77
+ Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):"""
 
 
78
 
79
  router_prompt = PromptTemplate(
80
  input_variables=["query"],
 
90
  company_qa_chain = load_qa_chain(llm, chain_type="stuff")
91
 
92
  # Cybersecurity Chain
93
+ cybersecurity_template = """You are a cybersecurity expert. Answer the following cybersecurity question based on your knowledge without claiming yourself as expert:
94
+
95
+ Question: {question}
96
+
97
+ Provide a comprehensive and accurate answer about cybersecurity:"""
 
 
98
 
99
  cybersecurity_prompt = PromptTemplate(
100
  input_variables=["question"],
 
106
  prompt=cybersecurity_prompt
107
  )
108
 
109
+
 
110
  except Exception as e:
111
  print(f"Error initializing pipelines: {e}")
112
  raise
 
129
  print("Document was empty after splitting, not adding to ChromaDB.")
130
  return False
131
 
 
132
  vector_store.add_documents(docs)
133
  print("Successfully added documents.")
134
  return True
 
149
  route_result = query_router_chain.run(query)
150
  route = route_result.strip().upper()
151
 
 
152
 
153
  # 2. Route to appropriate logic
154
  if "CYBERSECURITY" in route:
 
162
  elif "COMPANY" in route:
163
  # Perform similarity search on ChromaDB
164
  docs = vector_store.similarity_search(query, k=3)
 
 
165
 
166
  if not docs:
167
  return {
 
208
  "vector_store": vector_store is not None,
209
  "company_qa_chain": company_qa_chain is not None,
210
  "query_router_chain": query_router_chain is not None,
211
+ "cybersecurity_chain": cybersecurity_chain is not None,
212
+ "llm": llm is not None
213
  }
214
 
215
  return {
 
223
  "error": str(e)
224
  }
225
 
226
+ # Test function to verify OpenRouter connection
227
+ def test_openrouter_connection():
228
+ """Test the OpenRouter API connection."""
229
+ try:
230
+ if not llm:
231
+ initialize_pipelines()
232
+
233
+ # Simple test query
234
+ test_response = llm("Say 'Hello, OpenRouter is working!'")
235
+ return True
236
+ except Exception as e:
237
+ print(f"OpenRouter connection test failed: {e}")
238
+ return False
239
+
240
  # Initialize pipelines on module import
241
  try:
242
  initialize_pipelines()
requirements.txt CHANGED
@@ -28,4 +28,6 @@ faiss-cpu
28
  PyPDF2
29
  tiktoken
30
  chromadb
31
- langchain_chroma
 
 
 
28
  PyPDF2
29
  tiktoken
30
  chromadb
31
+ langchain_chroma
32
+ sentence-transformers
33
+ tf-keras