junaid17 commited on
Commit
ec4731a
·
verified ·
1 Parent(s): 7a65abf

Update RAG.py

Browse files
Files changed (1) hide show
  1. RAG.py +309 -309
RAG.py CHANGED
@@ -1,310 +1,310 @@
1
- from langgraph.graph import StateGraph, START, END
2
- from typing import TypedDict, Annotated
3
- from langchain_groq import ChatGroq
4
- from langchain_openai import OpenAIEmbeddings, ChatOpenAI
5
- from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
6
- from langgraph.graph.message import add_messages
7
- from langchain_core.tools import tool
8
- from dotenv import load_dotenv
9
- from langgraph.checkpoint.memory import MemorySaver
10
- import os
11
- from langchain_community.vectorstores import FAISS
12
- from langchain_community.tools.tavily_search import TavilySearchResults
13
-
14
- load_dotenv()
15
-
16
-
17
- #===========================================
18
- # Load FAISS DB & Reload Logic [FEATURE ADDED]
19
- #===========================================
20
-
21
- FAISS_DB_PATH = "vectorstore/db_faiss"
22
- embeddings = OpenAIEmbeddings(model='text-embedding-3-small')
23
-
24
- # Global variable for the database
25
- db = None
26
-
27
- def reload_vector_store():
28
- """
29
- Reloads the FAISS index from disk.
30
- Call this function after a new file is ingested.
31
- """
32
- global db
33
- if os.path.exists(FAISS_DB_PATH):
34
- print(f"Loading FAISS from {FAISS_DB_PATH}...")
35
- try:
36
- db = FAISS.load_local(
37
- FAISS_DB_PATH,
38
- embeddings,
39
- allow_dangerous_deserialization=True
40
- )
41
- print("Vector store loaded successfully.")
42
- except Exception as e:
43
- print(f"Error loading vector store: {e}")
44
- db = None
45
- else:
46
- print("Warning: No Vector DB found. Please run ingestion first.")
47
- db = None
48
-
49
- # Initial Load
50
- reload_vector_store()
51
-
52
-
53
- #===========================================
54
- # Class Schema
55
- #===========================================
56
-
57
- class Ragbot_State(TypedDict):
58
- query : str
59
- context : list[str]
60
- metadata : list[dict]
61
- RAG : bool
62
- web_search : bool
63
- model_name : str
64
- web_context : str
65
- response : Annotated[list[BaseMessage], add_messages]
66
-
67
- #===========================================
68
- # LLM'S
69
- #===========================================
70
-
71
-
72
- llm_kimi2 = ChatGroq(model='moonshotai/kimi-k2-instruct-0905', streaming=True, temperature=0.4)
73
- llm_gpt = ChatOpenAI(model='gpt-4.1-nano', streaming=True, temperature=0.2)
74
- llm_gpt_oss = ChatGroq(model='openai/gpt-oss-120b', streaming=True, temperature=0.3)
75
- llm_lamma4 = ChatGroq(model='meta-llama/llama-4-scout-17b-16e-instruct', streaming=True, temperature=0.5)
76
- llm_qwen3 = ChatGroq(model='qwen/qwen3-32b', streaming=True, temperature=0.5)
77
-
78
- def get_llm(model_name: str):
79
- if model_name == "kimi2":
80
- return llm_kimi2
81
- elif model_name == "gpt":
82
- return llm_gpt
83
- elif model_name == "gpt_oss":
84
- return llm_gpt_oss
85
- elif model_name == "lamma4":
86
- return llm_lamma4
87
- elif model_name == "qwen3":
88
- return llm_qwen3
89
- else:
90
- return llm_gpt # fallback if no match
91
-
92
- #===========================================
93
- # Search tool
94
- #===========================================
95
-
96
- @tool
97
- def tavily_search(query: str) -> dict:
98
- """
99
- Perform a real-time web search using Tavily.
100
- """
101
- try:
102
- search = TavilySearchResults(max_results=2)
103
- results = search.run(query)
104
- return {"query": query, "results": results}
105
- except Exception as e:
106
- return {"error": str(e)}
107
-
108
- #===========================================
109
- # fetching web context
110
- #===========================================
111
-
112
- def fetch_web_context(state: Ragbot_State):
113
- user_query = state["query"]
114
-
115
- enriched_query = f"""
116
- Fetch the latest, accurate, and up-to-date information about:
117
- {user_query}
118
-
119
- Focus on:
120
- - recent news
121
- - official announcements
122
- - verified sources
123
- - factual data
124
- """
125
-
126
- web_result = tavily_search.run(enriched_query)
127
-
128
- return {
129
- "web_context": str(web_result)
130
- }
131
-
132
- #===========================================
133
- # db search
134
- #===========================================
135
-
136
- @tool
137
- def faiss_search(query: str) -> str:
138
- """Search the FAISS vectorstore and return relevant documents."""
139
- # Check global db variable
140
- if db is None:
141
- return "No documents have been uploaded yet.", []
142
-
143
- try:
144
- results = db.similarity_search(query, k=3)
145
- context = "\n\n".join([doc.page_content for doc in results])
146
- metadata = [doc.metadata for doc in results]
147
- return context, metadata
148
- except Exception as e:
149
- return f"Error searching vector store: {str(e)}", []
150
-
151
- #===========================================
152
- # router
153
- #===========================================
154
-
155
-
156
- def router(state: Ragbot_State):
157
- if state["RAG"]:
158
- return "fetch_context"
159
-
160
- if state["web_search"]:
161
- return "fetch_web_context"
162
-
163
- return "chat"
164
-
165
- #===========================================
166
- # fetching context
167
- #===========================================
168
-
169
- def fetch_context(state: Ragbot_State):
170
- query = state["query"]
171
- context, metadata = faiss_search.invoke({"query": query})
172
- return {"context": [context], "metadata": [metadata]}
173
-
174
-
175
- #===========================================
176
- # system prompt
177
- #===========================================
178
-
179
-
180
- SYSTEM_PROMPT = SystemMessage(
181
- content="""
182
- You are an intelligent conversational assistant and retrieval-augmented AI system built by Junaid.
183
-
184
- Your role is to:
185
- - Engage naturally in conversation like a friendly, helpful chatbot.
186
- - Answer general questions using your own knowledge when no external context is provided.
187
- - When relevant context is provided, use it accurately to answer user questions.
188
- - Seamlessly switch between casual conversation and knowledge-based answering.
189
-
190
- Guidelines:
191
- - If context is provided and relevant, use it as the primary source of truth.
192
- - If context is not provided or not relevant, respond using your general knowledge.
193
- - Do not hallucinate or invent information.
194
- - If you are unsure or the information is not available, clearly state that.
195
- - Be clear, concise, and helpful in all responses.
196
- - Maintain a natural, human-like conversational tone.
197
- - Never mention internal implementation details such as embeddings, vector databases, or system architecture.
198
-
199
- You are designed to provide reliable, accurate, and engaging assistance.
200
- """
201
- )
202
-
203
- #===========================================
204
- # Chat function
205
- #===========================================
206
-
207
- def chat(state:Ragbot_State):
208
- query = state['query']
209
- context = state['context']
210
- metadata = state['metadata']
211
- web_context = state['web_context']
212
- model_name = state.get('model_name', 'gpt')
213
-
214
- history = state.get("response", [])
215
-
216
- # [CHANGED] Updated Prompt to include History so it remembers your name
217
- prompt = f"""
218
- You are an expert assistant designed to answer user questions using multiple information sources.
219
-
220
- Source Priority Rules (STRICT):
221
- 1. **Conversation History**: Check if the answer was provided in previous messages (e.g., user's name, previous topics).
222
- 2. If the provided Context contains the answer, use ONLY the Context.
223
- 3. If the Context does not contain the answer and Web Context is available, use the Web Context.
224
- 4. If neither Context nor Web Context contains the answer, use your general knowledge.
225
- 5. Do NOT invent or hallucinate facts.
226
- 6. If the answer cannot be determined, clearly say so.
227
-
228
- User Question:
229
- {query}
230
-
231
- Retrieved Context (Vector Database):
232
- {context}
233
-
234
- Metadata:
235
- {metadata}
236
-
237
- Web Context (Real-time Search):
238
- {web_context}
239
-
240
- Final Answer:
241
- """
242
-
243
- selected_llm = get_llm(model_name)
244
- messages = [SYSTEM_PROMPT] + history + [HumanMessage(content=prompt)]
245
- response = selected_llm.invoke(messages)
246
- return {
247
- 'response': [
248
- HumanMessage(content=query),
249
- response
250
- ]
251
- }
252
-
253
- #===========================================
254
- # Graph Declaration
255
- #===========================================
256
-
257
- # Keeping MemorySaver as requested (Note: RAM only, wipes on restart)
258
- memory = MemorySaver()
259
- graph = StateGraph(Ragbot_State)
260
-
261
- graph.add_node("fetch_context", fetch_context)
262
- graph.add_node("fetch_web_context", fetch_web_context)
263
- graph.add_node("chat", chat)
264
-
265
- graph.add_conditional_edges(
266
- START,
267
- router,
268
- {
269
- "fetch_context": "fetch_context",
270
- "fetch_web_context": "fetch_web_context",
271
- "chat": "chat"
272
- }
273
- )
274
-
275
- graph.add_edge("fetch_context", "chat")
276
- graph.add_edge("fetch_web_context", "chat")
277
- graph.add_edge("chat", END)
278
-
279
- app = graph.compile(checkpointer=memory)
280
-
281
-
282
- #===========================================
283
- # Helper Function
284
- #===========================================
285
-
286
- def ask_bot(query: str, use_rag: bool = False, use_web: bool = False, thread_id: str = "1"):
287
- config = {"configurable": {"thread_id": thread_id}}
288
- inputs = {
289
- "query": query,
290
- "RAG": use_rag,
291
- "web_search": use_web,
292
- "context": [],
293
- "metadata": [],
294
- "web_context": "",
295
- }
296
-
297
- result = app.invoke(inputs, config=config)
298
- last_message = result['response'][-1]
299
-
300
- return last_message.content
301
-
302
-
303
- """print("--- Conversation 1 ---")
304
- # User says hello and gives name
305
- response = ask_bot("Hi, my name is Junaid", thread_id="session_A")
306
- print(f"Bot: {response}")
307
-
308
- # User asks for name (RAG and Web are OFF)
309
- response = ask_bot("What is my name?", thread_id="session_A")
310
  print(f"Bot: {response}")"""
 
1
+ from langgraph.graph import StateGraph, START, END
2
+ from typing import TypedDict, Annotated
3
+ from langchain_groq import ChatGroq
4
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
5
+ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
6
+ from langgraph.graph.message import add_messages
7
+ from langchain_core.tools import tool
8
+ from dotenv import load_dotenv
9
+ from langgraph.checkpoint.memory import MemorySaver
10
+ import os
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_community.tools.tavily_search import TavilySearchResults
13
+
14
load_dotenv()


#===========================================
# Load FAISS DB & Reload Logic [FEATURE ADDED]
#===========================================

FAISS_DB_PATH = "vectorstore/db_faiss"
embeddings = OpenAIEmbeddings(model='text-embedding-3-small')

# Module-global handle to the FAISS index; stays None until a store loads.
db = None

def reload_vector_store():
    """
    Reload the FAISS index from disk into the module-global `db`.

    Call this after a new file is ingested so queries see fresh documents.
    Leaves `db` as None when the index directory is missing or loading fails.
    """
    global db
    if not os.path.exists(FAISS_DB_PATH):
        print("Warning: No Vector DB found. Please run ingestion first.")
        db = None
        return
    print(f"Loading FAISS from {FAISS_DB_PATH}...")
    try:
        # allow_dangerous_deserialization is needed for FAISS's pickled
        # docstore; acceptable here because we created the index ourselves.
        db = FAISS.load_local(
            FAISS_DB_PATH,
            embeddings,
            allow_dangerous_deserialization=True,
        )
    except Exception as e:
        print(f"Error loading vector store: {e}")
        db = None
    else:
        print("Vector store loaded successfully.")

# Initial Load
reload_vector_store()
51
+
52
+
53
+ #===========================================
54
+ # Class Schema
55
+ #===========================================
56
+
57
class Ragbot_State(TypedDict):
    """Shared state carried between nodes of the LangGraph pipeline."""
    query: str          # current user question
    context: list[str]  # chunks retrieved from the FAISS store
    metadata: list[dict]  # per-document metadata for the retrieved chunks
    RAG: bool           # route through vector-store retrieval first
    web_search: bool    # route through Tavily web search
    model_name: str     # key understood by get_llm()
    web_context: str    # stringified web-search results
    # Conversation history; add_messages appends deltas instead of replacing.
    response: Annotated[list[BaseMessage], add_messages]
66
+
67
+ #===========================================
68
+ # LLM'S
69
+ #===========================================
70
+
71
+
72
# One pre-built, streaming chat model per supported key.
llm_kimi2 = ChatGroq(model='moonshotai/kimi-k2-instruct-0905', streaming=True, temperature=0.4)
llm_gpt = ChatOpenAI(model='gpt-4.1-nano', streaming=True, temperature=0.2)
llm_gpt_oss = ChatGroq(model='openai/gpt-oss-120b', streaming=True, temperature=0.3)
llm_lamma4 = ChatGroq(model='meta-llama/llama-4-scout-17b-16e-instruct', streaming=True, temperature=0.5)
llm_qwen3 = ChatGroq(model='qwen/qwen3-32b', streaming=True, temperature=0.5)

def get_llm(model_name: str):
    """Return the chat model for `model_name`; unknown keys fall back to gpt."""
    dispatch = {
        "kimi2": llm_kimi2,
        "gpt": llm_gpt,
        "gpt_oss": llm_gpt_oss,
        "lamma4": llm_lamma4,
        "qwen3": llm_qwen3,
    }
    return dispatch.get(model_name, llm_gpt)
91
+
92
+ #===========================================
93
+ # Search tool
94
+ #===========================================
95
+
96
@tool
def tavily_search(query: str) -> dict:
    """
    Perform a real-time web search using Tavily.

    Returns {"query": ..., "results": ...} on success, or
    {"error": <message>} when the search raises.
    """
    try:
        results = TavilySearchResults(max_results=2).run(query)
    except Exception as e:
        return {"error": str(e)}
    return {"query": query, "results": results}
107
+
108
+ #===========================================
109
+ # fetching web context
110
+ #===========================================
111
+
112
def fetch_web_context(state: Ragbot_State):
    """Run a Tavily search for the current query and stash it in `web_context`."""
    user_query = state["query"]

    # Wrap the raw question so the search engine favours fresh, factual sources.
    enriched_query = f"""
    Fetch the latest, accurate, and up-to-date information about:
    {user_query}

    Focus on:
    - recent news
    - official announcements
    - verified sources
    - factual data
    """

    search_payload = tavily_search.run(enriched_query)
    return {"web_context": str(search_payload)}
131
+
132
+ #===========================================
133
+ # db search
134
+ #===========================================
135
+
136
@tool
def faiss_search(query: str) -> tuple:
    """Search the FAISS vectorstore and return relevant documents.

    Returns a (context, metadata) pair:
    - context: the top-3 matching chunks joined by blank lines, or a
      human-readable message when no store is loaded / search fails.
    - metadata: list of per-document metadata dicts ([] on error).

    BUG FIX: the annotation previously said `-> str`, but every path
    returns a 2-tuple (and fetch_context unpacks two values).
    """
    # `db` is the module-global store; None until ingestion has run.
    if db is None:
        return "No documents have been uploaded yet.", []

    try:
        results = db.similarity_search(query, k=3)
        context = "\n\n".join(doc.page_content for doc in results)
        metadata = [doc.metadata for doc in results]
        return context, metadata
    except Exception as e:
        return f"Error searching vector store: {str(e)}", []
150
+
151
+ #===========================================
152
+ # router
153
+ #===========================================
154
+
155
+
156
def router(state: Ragbot_State):
    """Choose the entry node: RAG retrieval wins over web search over plain chat."""
    # Priority-ordered flag -> node table; first truthy flag decides the route.
    routes = (
        ("RAG", "fetch_context"),
        ("web_search", "fetch_web_context"),
    )
    for flag, node in routes:
        if state[flag]:
            return node
    return "chat"
164
+
165
+ #===========================================
166
+ # fetching context
167
+ #===========================================
168
+
169
def fetch_context(state: Ragbot_State):
    """Retrieve top FAISS matches for the query and store them in the state."""
    docs_text, docs_meta = faiss_search.invoke({"query": state["query"]})
    return {"context": [docs_text], "metadata": [docs_meta]}
173
+
174
+
175
+ #===========================================
176
+ # system prompt
177
+ #===========================================
178
+
179
+
180
# Persistent system message prepended to every LLM call in chat().
SYSTEM_PROMPT = SystemMessage(
    content="""
You are an intelligent conversational assistant and retrieval-augmented AI system built by Junaid.

Your role is to:
- Engage naturally in conversation like a friendly, helpful chatbot.
- Answer general questions using your own knowledge when no external context is provided.
- When relevant context is provided, use it accurately to answer user questions.
- Seamlessly switch between casual conversation and knowledge-based answering.

Guidelines:
- If context is provided and relevant, use it as the primary source of truth.
- If context is not provided or not relevant, respond using your general knowledge.
- Do not hallucinate or invent information.
- If you are unsure or the information is not available, clearly state that.
- Be clear, concise, and helpful in all responses.
- Maintain a natural, human-like conversational tone.
- Never mention internal implementation details such as embeddings, vector databases, or system architecture.

You are designed to provide reliable, accurate, and engaging assistance.
"""
)
202
+
203
+ #===========================================
204
+ # Chat function
205
+ #===========================================
206
+
207
async def chat(state: Ragbot_State):
    """
    Core answer node: assemble a grounded prompt from the retrieved
    context, metadata, web results and conversation history, then query
    the selected LLM.

    Returns a `response` delta holding the user turn plus the model reply;
    add_messages appends both to the stored history.
    """
    query = state['query']
    context = state['context']
    metadata = state['metadata']
    web_context = state['web_context']
    model_name = state.get('model_name', 'gpt')

    history = state.get("response", [])

    # Prompt spells out the source-priority rules; history is passed as
    # separate messages below so the model remembers earlier facts.
    prompt = f"""
    You are an expert assistant designed to answer user questions using multiple information sources.

    Source Priority Rules (STRICT):
    1. **Conversation History**: Check if the answer was provided in previous messages (e.g., user's name, previous topics).
    2. If the provided Context contains the answer, use ONLY the Context.
    3. If the Context does not contain the answer and Web Context is available, use the Web Context.
    4. If neither Context nor Web Context contains the answer, use your general knowledge.
    5. Do NOT invent or hallucinate facts.
    6. If the answer cannot be determined, clearly say so.

    User Question:
    {query}

    Retrieved Context (Vector Database):
    {context}

    Metadata:
    {metadata}

    Web Context (Real-time Search):
    {web_context}

    Final Answer:
    """

    # BUG FIX: get_llm is a plain synchronous function returning a chat
    # model instance; `await get_llm(...)` raised TypeError because the
    # model object is not awaitable.
    selected_llm = get_llm(model_name)
    messages = [SYSTEM_PROMPT] + history + [HumanMessage(content=prompt)]
    # Use the async client so this async node doesn't block the event loop.
    response = await selected_llm.ainvoke(messages)
    return {
        'response': [
            HumanMessage(content=query),
            response,
        ]
    }
252
+
253
+ #===========================================
254
+ # Graph Declaration
255
+ #===========================================
256
+
257
# Keeping MemorySaver as requested (Note: RAM only, wipes on restart)
memory = MemorySaver()
graph = StateGraph(Ragbot_State)

graph.add_node("fetch_context", fetch_context)
graph.add_node("fetch_web_context", fetch_web_context)
graph.add_node("chat", chat)

# START fans out to one of the three nodes; router returns the node name,
# and each name maps to itself.
graph.add_conditional_edges(
    START,
    router,
    {name: name for name in ("fetch_context", "fetch_web_context", "chat")},
)

# Both retrieval paths funnel into chat, which ends the run.
graph.add_edge("fetch_context", "chat")
graph.add_edge("fetch_web_context", "chat")
graph.add_edge("chat", END)

app = graph.compile(checkpointer=memory)
280
+
281
+
282
+ #===========================================
283
+ # Helper Function
284
+ #===========================================
285
+
286
def ask_bot(query: str, use_rag: bool = False, use_web: bool = False, thread_id: str = "1", model_name: str = "gpt"):
    """
    One-shot helper: run the graph for `query` and return the reply text.

    thread_id selects the MemorySaver conversation; model_name picks the
    LLM key understood by get_llm() (new optional parameter, defaults to
    the previous implicit "gpt" fallback, so existing callers are unaffected).
    """
    import asyncio  # local import: only needed to drive the async graph

    config = {"configurable": {"thread_id": thread_id}}
    inputs = {
        "query": query,
        "RAG": use_rag,
        "web_search": use_web,
        "model_name": model_name,
        "context": [],
        "metadata": [],
        "web_context": "",
    }

    # BUG FIX: the "chat" node is async, and LangGraph's synchronous
    # invoke() cannot run async-only nodes; drive the graph with ainvoke.
    # NOTE(review): asyncio.run raises inside an already-running event loop
    # (e.g. Jupyter/Streamlit); call `await app.ainvoke(...)` directly there.
    result = asyncio.run(app.ainvoke(inputs, config=config))
    last_message = result['response'][-1]

    return last_message.content
301
+
302
+
303
+ """print("--- Conversation 1 ---")
304
+ # User says hello and gives name
305
+ response = ask_bot("Hi, my name is Junaid", thread_id="session_A")
306
+ print(f"Bot: {response}")
307
+
308
+ # User asks for name (RAG and Web are OFF)
309
+ response = ask_bot("What is my name?", thread_id="session_A")
310
  print(f"Bot: {response}")"""