Shaukat39 commited on
Commit
0e1911a
·
verified ·
1 Parent(s): 80daa14

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +391 -177
agent.py CHANGED
@@ -1,265 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition, ToolNode
 
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
11
  from langchain_community.vectorstores import SupabaseVectorStore
12
  from langchain_core.messages import SystemMessage, HumanMessage
13
  from langchain_core.tools import tool
14
  from langchain.tools.retriever import create_retriever_tool
15
- from supabase.client import create_client
16
- import re
17
- import traceback
18
 
19
  load_dotenv()
20
 
21
- # ------------------ Arithmetic Tools ------------------
22
-
23
  @tool
24
- def multiply(a: int, b: int) -> str:
25
- """
26
- Multiply two integers and return the result as a string.
27
-
28
  Args:
29
- a (int): The first integer.
30
- b (int): The second integer.
31
-
32
- Returns:
33
- str: The product of a and b, as a string.
34
  """
35
- return str(a * b)
36
-
37
 
38
  @tool
39
- def add(a: int, b: int) -> str:
40
- """
41
- Add two integers and return the result as a string.
42
-
43
  Args:
44
- a (int): The first integer.
45
- b (int): The second integer.
46
-
47
- Returns:
48
- str: The sum of a and b, as a string.
49
  """
50
- return str(a + b)
51
-
52
 
53
  @tool
54
- def subtract(a: int, b: int) -> str:
55
- """
56
- Subtract one integer from another and return the result as a string.
57
-
58
  Args:
59
- a (int): The minuend.
60
- b (int): The subtrahend.
61
-
62
- Returns:
63
- str: The difference (a - b), as a string.
64
  """
65
- return str(a - b)
66
-
67
 
68
  @tool
69
- def divide(a: int, b: int) -> str:
70
- """
71
- Divide one integer by another and return the result as a string.
72
-
73
  Args:
74
- a (int): The numerator.
75
- b (int): The denominator. Must not be zero.
76
-
77
- Returns:
78
- str: The result of the division (a / b), as a string. Returns an error message if b is zero.
79
  """
80
  if b == 0:
81
- return "Error: Cannot divide by zero."
82
- return str(a / b)
83
-
84
 
85
  @tool
86
- def modulus(a: int, b: int) -> str:
87
- """
88
- Compute the modulus (remainder) of two integers and return the result as a string.
89
-
90
  Args:
91
- a (int): The numerator.
92
- b (int): The denominator.
93
-
94
- Returns:
95
- str: The remainder when a is divided by b, as a string.
96
  """
97
- return str(a % b)
98
-
99
-
100
- # ------------------ Retrieval Tools ------------------
101
 
102
  @tool
103
  def wiki_search(query: str) -> str:
104
- """
105
- Search Wikipedia for a given query and return text from up to two matching articles.
106
-
107
  Args:
108
- query (str): A string query to search on Wikipedia.
109
-
110
- Returns:
111
- str: Combined content from up to two relevant articles, separated by dividers.
112
- """
113
- docs = WikipediaLoader(query=query, load_max_docs=2).load()
114
- return "\n\n---\n\n".join(doc.page_content for doc in docs)
115
-
116
 
117
  @tool
118
  def web_search(query: str) -> str:
119
- """
120
- Perform a web search using Tavily and return content from the top three results.
121
-
122
  Args:
123
- query (str): A string representing the web search topic.
124
-
125
- Returns:
126
- str: Combined content from up to three top results, separated by dividers.
127
- """
128
- docs = TavilySearchResults(max_results=3).invoke(query)
129
- return "\n\n---\n\n".join(doc.page_content for doc in docs)
130
-
131
 
132
  @tool
133
  def arvix_search(query: str) -> str:
134
- """
135
- Search arXiv for academic papers related to the query and return excerpts.
136
-
137
  Args:
138
- query (str): The search query string.
 
 
 
 
 
 
 
139
 
140
- Returns:
141
- str: Excerpts (up to 1000 characters each) from up to three relevant arXiv papers, separated by dividers.
142
- """
143
- docs = ArxivLoader(query=query, load_max_docs=3).load()
144
- return "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)
145
 
146
 
147
-
148
- # ------------------ System Prompt ------------------
149
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
150
- system_prompt = f.read().strip()
151
-
152
- # ------------------ Supabase Setup ------------------
153
- url = os.environ["SUPABASE_URL"].strip()
154
- key = os.environ["SUPABASE_SERVICE_KEY"].strip()
155
- client = create_client(url, key)
156
 
157
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 
158
 
159
- # Embed improved QA docs
160
- qa_examples = [
161
- {"content": "Q: What is the capital of Vietnam?\nA: FINAL ANSWER: Hanoi"},
162
- {"content": "Q: Alphabetize: lettuce, broccoli, basil\nA: FINAL ANSWER: basil,broccoli,lettuce"},
163
- {"content": "Q: What is 42 multiplied by 8?\nA: FINAL ANSWER: three hundred thirty six"},
164
- ]
165
  vector_store = SupabaseVectorStore(
166
- client=client,
167
- embedding=embeddings,
168
  table_name="documents",
169
- query_name="match_documents_langchain"
170
  )
171
- vector_store.add_texts([doc["content"] for doc in qa_examples])
172
- print("✅ QA documents embedded into Supabase.")
173
-
174
- retriever_tool = create_retriever_tool(
175
  retriever=vector_store.as_retriever(),
176
  name="Question Search",
177
- description="Retrieve similar questions from vector DB."
178
  )
179
 
180
- tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
181
 
182
- # ------------------ Build Agent Graph ------------------
183
- class VerboseToolNode(ToolNode):
184
- def invoke(self, state):
185
- print("🔧 ToolNode evaluating:", [m.content for m in state["messages"]])
186
- return super().invoke(state)
187
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  def build_graph(provider: str = "groq"):
 
 
189
  if provider == "google":
190
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
 
191
  elif provider == "groq":
192
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
 
193
  elif provider == "huggingface":
 
194
  llm = ChatHuggingFace(
195
  llm=HuggingFaceEndpoint(
196
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
197
- temperature=0.3
198
- )
199
  )
200
  else:
201
- raise ValueError("Invalid provider.")
202
-
203
  llm_with_tools = llm.bind_tools(tools)
204
 
205
- def retriever(state: MessagesState):
206
- query = state["messages"][0].content
207
- similar = vector_store.similarity_search_with_score(query)
208
- threshold = 0.7
209
- examples = [
210
- HumanMessage(content=f"Similar QA:\n{doc.page_content}")
211
- for doc, score in similar if score >= threshold
212
- ]
213
- return {"messages": state["messages"] + examples}
214
-
215
-
216
-
217
  def assistant(state: MessagesState):
218
- try:
219
- messages = [SystemMessage(content=system_prompt.strip())] + state["messages"]
220
- result = llm_with_tools.invoke(messages)
221
-
222
- # Handle different return types gracefully
223
- if hasattr(result, "content"):
224
- raw_output = result.content.strip()
225
- elif isinstance(result, dict) and "content" in result:
226
- raw_output = result["content"].strip()
227
- else:
228
- raise ValueError(f"Unexpected result format: {repr(result)}")
229
-
230
- print("🤖 Raw LLM output:", repr(raw_output))
231
-
232
- match = re.search(r"FINAL ANSWER:\s*(.+)", raw_output, re.IGNORECASE)
233
- if match:
234
- final_output = f"FINAL ANSWER: {match.group(1).strip()}"
235
- else:
236
- print("⚠️ 'FINAL ANSWER:' not found. Raw content will be used as fallback.")
237
- final_output = f"FINAL ANSWER: {raw_output or 'Unable to determine answer'}"
238
-
239
- return {"messages": [HumanMessage(content=final_output)]}
240
-
241
- except Exception as e:
242
- print(f"🔥 Exception: {e}")
243
- traceback.print_exc()
244
- return {"messages": [HumanMessage(content=f"FINAL ANSWER: AGENT ERROR: {type(e).__name__}: {e}")]}
245
-
246
 
247
  builder = StateGraph(MessagesState)
248
  builder.add_node("retriever", retriever)
249
  builder.add_node("assistant", assistant)
250
- builder.add_node("tools", VerboseToolNode(tools))
251
  builder.add_edge(START, "retriever")
252
  builder.add_edge("retriever", "assistant")
253
- builder.add_conditional_edges("assistant", tools_condition)
 
 
 
254
  builder.add_edge("tools", "assistant")
255
 
 
256
  return builder.compile()
257
 
258
- # ------------------ Local Test Harness ------------------
259
  if __name__ == "__main__":
260
- graph = build_graph(provider="groq")
261
  question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
 
 
 
262
  messages = [HumanMessage(content=question)]
263
- result = graph.invoke({"messages": messages})
264
- print(result["messages"][-1].content)
 
265
 
 
1
+ # """LangGraph Agent"""
2
+ # import os
3
+ # from dotenv import load_dotenv
4
+ # from langgraph.graph import START, StateGraph, MessagesState
5
+ # from langgraph.prebuilt import tools_condition, ToolNode
6
+ # from langchain_google_genai import ChatGoogleGenerativeAI
7
+ # from langchain_groq import ChatGroq
8
+ # from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
+ # from langchain_community.tools.tavily_search import TavilySearchResults
10
+ # from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
+ # from langchain_community.vectorstores import SupabaseVectorStore
12
+ # from langchain_core.messages import SystemMessage, HumanMessage
13
+ # from langchain_core.tools import tool
14
+ # from langchain.tools.retriever import create_retriever_tool
15
+ # from supabase.client import create_client
16
+ # import re
17
+ # import traceback
18
+
19
+ # load_dotenv()
20
+
21
+ # # ------------------ Arithmetic Tools ------------------
22
+
23
+ # @tool
24
+ # def multiply(a: int, b: int) -> str:
25
+ # """
26
+ # Multiply two integers and return the result as a string.
27
+
28
+ # Args:
29
+ # a (int): The first integer.
30
+ # b (int): The second integer.
31
+
32
+ # Returns:
33
+ # str: The product of a and b, as a string.
34
+ # """
35
+ # return str(a * b)
36
+
37
+
38
+ # @tool
39
+ # def add(a: int, b: int) -> str:
40
+ # """
41
+ # Add two integers and return the result as a string.
42
+
43
+ # Args:
44
+ # a (int): The first integer.
45
+ # b (int): The second integer.
46
+
47
+ # Returns:
48
+ # str: The sum of a and b, as a string.
49
+ # """
50
+ # return str(a + b)
51
+
52
+
53
+ # @tool
54
+ # def subtract(a: int, b: int) -> str:
55
+ # """
56
+ # Subtract one integer from another and return the result as a string.
57
+
58
+ # Args:
59
+ # a (int): The minuend.
60
+ # b (int): The subtrahend.
61
+
62
+ # Returns:
63
+ # str: The difference (a - b), as a string.
64
+ # """
65
+ # return str(a - b)
66
+
67
+
68
+ # @tool
69
+ # def divide(a: int, b: int) -> str:
70
+ # """
71
+ # Divide one integer by another and return the result as a string.
72
+
73
+ # Args:
74
+ # a (int): The numerator.
75
+ # b (int): The denominator. Must not be zero.
76
+
77
+ # Returns:
78
+ # str: The result of the division (a / b), as a string. Returns an error message if b is zero.
79
+ # """
80
+ # if b == 0:
81
+ # return "Error: Cannot divide by zero."
82
+ # return str(a / b)
83
+
84
+
85
+ # @tool
86
+ # def modulus(a: int, b: int) -> str:
87
+ # """
88
+ # Compute the modulus (remainder) of two integers and return the result as a string.
89
+
90
+ # Args:
91
+ # a (int): The numerator.
92
+ # b (int): The denominator.
93
+
94
+ # Returns:
95
+ # str: The remainder when a is divided by b, as a string.
96
+ # """
97
+ # return str(a % b)
98
+
99
+
100
+ # # ------------------ Retrieval Tools ------------------
101
+
102
+ # @tool
103
+ # def wiki_search(query: str) -> str:
104
+ # """
105
+ # Search Wikipedia for a given query and return text from up to two matching articles.
106
+
107
+ # Args:
108
+ # query (str): A string query to search on Wikipedia.
109
+
110
+ # Returns:
111
+ # str: Combined content from up to two relevant articles, separated by dividers.
112
+ # """
113
+ # docs = WikipediaLoader(query=query, load_max_docs=2).load()
114
+ # return "\n\n---\n\n".join(doc.page_content for doc in docs)
115
+
116
+
117
+ # @tool
118
+ # def web_search(query: str) -> str:
119
+ # """
120
+ # Perform a web search using Tavily and return content from the top three results.
121
+
122
+ # Args:
123
+ # query (str): A string representing the web search topic.
124
+
125
+ # Returns:
126
+ # str: Combined content from up to three top results, separated by dividers.
127
+ # """
128
+ # docs = TavilySearchResults(max_results=3).invoke(query)
129
+ # return "\n\n---\n\n".join(doc.page_content for doc in docs)
130
+
131
+
132
+ # @tool
133
+ # def arvix_search(query: str) -> str:
134
+ # """
135
+ # Search arXiv for academic papers related to the query and return excerpts.
136
+
137
+ # Args:
138
+ # query (str): The search query string.
139
+
140
+ # Returns:
141
+ # str: Excerpts (up to 1000 characters each) from up to three relevant arXiv papers, separated by dividers.
142
+ # """
143
+ # docs = ArxivLoader(query=query, load_max_docs=3).load()
144
+ # return "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)
145
+
146
+
147
+
148
+ # # ------------------ System Prompt ------------------
149
+ # with open("system_prompt.txt", "r", encoding="utf-8") as f:
150
+ # system_prompt = f.read().strip()
151
+
152
+ # # ------------------ Supabase Setup ------------------
153
+ # url = os.environ["SUPABASE_URL"].strip()
154
+ # key = os.environ["SUPABASE_SERVICE_KEY"].strip()
155
+ # client = create_client(url, key)
156
+
157
+ # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
158
+
159
+ # # Embed improved QA docs
160
+ # qa_examples = [
161
+ # {"content": "Q: What is the capital of Vietnam?\nA: FINAL ANSWER: Hanoi"},
162
+ # {"content": "Q: Alphabetize: lettuce, broccoli, basil\nA: FINAL ANSWER: basil,broccoli,lettuce"},
163
+ # {"content": "Q: What is 42 multiplied by 8?\nA: FINAL ANSWER: three hundred thirty six"},
164
+ # ]
165
+ # vector_store = SupabaseVectorStore(
166
+ # client=client,
167
+ # embedding=embeddings,
168
+ # table_name="documents",
169
+ # query_name="match_documents_langchain"
170
+ # )
171
+ # vector_store.add_texts([doc["content"] for doc in qa_examples])
172
+ # print("✅ QA documents embedded into Supabase.")
173
+
174
+ # retriever_tool = create_retriever_tool(
175
+ # retriever=vector_store.as_retriever(),
176
+ # name="Question Search",
177
+ # description="Retrieve similar questions from vector DB."
178
+ # )
179
+
180
+ # tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
181
+
182
+ # # ------------------ Build Agent Graph ------------------
183
+ # class VerboseToolNode(ToolNode):
184
+ # def invoke(self, state):
185
+ # print("🔧 ToolNode evaluating:", [m.content for m in state["messages"]])
186
+ # return super().invoke(state)
187
+
188
+ # def build_graph(provider: str = "groq"):
189
+ # if provider == "google":
190
+ # llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
191
+ # elif provider == "groq":
192
+ # llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
193
+ # elif provider == "huggingface":
194
+ # llm = ChatHuggingFace(
195
+ # llm=HuggingFaceEndpoint(
196
+ # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
197
+ # temperature=0.3
198
+ # )
199
+ # )
200
+ # else:
201
+ # raise ValueError("Invalid provider.")
202
+
203
+ # llm_with_tools = llm.bind_tools(tools)
204
+
205
+ # def retriever(state: MessagesState):
206
+ # query = state["messages"][0].content
207
+ # similar = vector_store.similarity_search_with_score(query)
208
+ # threshold = 0.7
209
+ # examples = [
210
+ # HumanMessage(content=f"Similar QA:\n{doc.page_content}")
211
+ # for doc, score in similar if score >= threshold
212
+ # ]
213
+ # return {"messages": state["messages"] + examples}
214
+
215
+
216
+
217
+ # def assistant(state: MessagesState):
218
+ # try:
219
+ # messages = [SystemMessage(content=system_prompt.strip())] + state["messages"]
220
+ # result = llm_with_tools.invoke(messages)
221
+
222
+ # # Handle different return types gracefully
223
+ # if hasattr(result, "content"):
224
+ # raw_output = result.content.strip()
225
+ # elif isinstance(result, dict) and "content" in result:
226
+ # raw_output = result["content"].strip()
227
+ # else:
228
+ # raise ValueError(f"Unexpected result format: {repr(result)}")
229
+
230
+ # print("🤖 Raw LLM output:", repr(raw_output))
231
+
232
+ # match = re.search(r"FINAL ANSWER:\s*(.+)", raw_output, re.IGNORECASE)
233
+ # if match:
234
+ # final_output = f"FINAL ANSWER: {match.group(1).strip()}"
235
+ # else:
236
+ # print("⚠️ 'FINAL ANSWER:' not found. Raw content will be used as fallback.")
237
+ # final_output = f"FINAL ANSWER: {raw_output or 'Unable to determine answer'}"
238
+
239
+ # return {"messages": [HumanMessage(content=final_output)]}
240
+
241
+ # except Exception as e:
242
+ # print(f"🔥 Exception: {e}")
243
+ # traceback.print_exc()
244
+ # return {"messages": [HumanMessage(content=f"FINAL ANSWER: AGENT ERROR: {type(e).__name__}: {e}")]}
245
+
246
+
247
+ # builder = StateGraph(MessagesState)
248
+ # builder.add_node("retriever", retriever)
249
+ # builder.add_node("assistant", assistant)
250
+ # builder.add_node("tools", VerboseToolNode(tools))
251
+ # builder.add_edge(START, "retriever")
252
+ # builder.add_edge("retriever", "assistant")
253
+ # builder.add_conditional_edges("assistant", tools_condition)
254
+ # builder.add_edge("tools", "assistant")
255
+
256
+ # return builder.compile()
257
+
258
+ # # ------------------ Local Test Harness ------------------
259
+ # if __name__ == "__main__":
260
+ # graph = build_graph(provider="groq")
261
+ # question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
262
+ # messages = [HumanMessage(content=question)]
263
+ # result = graph.invoke({"messages": messages})
264
+ # print(result["messages"][-1].content)
265
+
266
  """LangGraph Agent"""
267
  import os
268
  from dotenv import load_dotenv
269
  from langgraph.graph import START, StateGraph, MessagesState
270
+ from langgraph.prebuilt import tools_condition
271
+ from langgraph.prebuilt import ToolNode
272
  from langchain_google_genai import ChatGoogleGenerativeAI
273
  from langchain_groq import ChatGroq
274
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
275
  from langchain_community.tools.tavily_search import TavilySearchResults
276
+ from langchain_community.document_loaders import WikipediaLoader
277
+ from langchain_community.document_loaders import ArxivLoader
278
  from langchain_community.vectorstores import SupabaseVectorStore
279
  from langchain_core.messages import SystemMessage, HumanMessage
280
  from langchain_core.tools import tool
281
  from langchain.tools.retriever import create_retriever_tool
282
+ from supabase.client import Client, create_client
 
 
283
 
284
  load_dotenv()
285
 
 
 
286
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: first int
        b: second int
    """
    product = b * a
    return product
 
294
 
295
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: first int
        b: second int
    """
    total = b + a
    return total
 
304
 
305
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference a - b of two integers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
 
314
 
315
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int (numerator)
        b: second int (denominator); must be non-zero

    Returns:
        The true-division quotient a / b. Note: annotation fixed to
        float — Python's `/` always yields a float, so the previous
        `-> int` annotation was wrong.

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
 
326
 
327
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of a divided by b.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
 
 
 
336
 
337
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        Up to two matching articles, each wrapped in a <Document> tag and
        separated by "---" dividers. Returned as a plain string to match
        the declared -> str annotation (the previous dict wrapper
        contradicted the signature).
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n'
        f'{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
350
 
351
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        Up to three results, each wrapped in a <Document> tag and separated
        by "---" dividers, as a plain string matching the -> str annotation.
    """
    # Runnable.invoke takes the input positionally; the previous
    # `.invoke(query=query)` keyword raised a TypeError. TavilySearchResults
    # also returns a list of dicts (with "url"/"content" keys), not Document
    # objects, so `doc.metadata` / `doc.page_content` did not exist.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{res.get("url", "")}"/>\n{res.get("content", "")}\n</Document>'
        for res in search_results
    )
    return formatted_search_docs
364
 
365
@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        Excerpts (first 1000 characters) of up to three papers, each wrapped
        in a <Document> tag and separated by "---" dividers, as a plain
        string matching the -> str annotation (the previous dict wrapper
        contradicted the signature).
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader metadata uses keys such as "Title"/"Published" rather than
    # "source", so the previous metadata["source"] lookup could raise
    # KeyError; .get(...) with a Title fallback is safe either way.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" '
        f'page="{doc.metadata.get("page", "")}"/>\n'
        f'{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
378
 
 
 
 
 
 
379
 
380
 
381
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message prepended to every conversation by the retriever node.
sys_msg = SystemMessage(content=system_prompt)

# build a retriever
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768

# Fail fast with a clear message instead of passing None into create_client.
_supabase_url = os.environ.get("SUPABASE_URL")
_supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
if not _supabase_url or not _supabase_key:
    raise RuntimeError("SUPABASE_URL and SUPABASE_SERVICE_KEY environment variables must be set.")

supabase: Client = create_client(_supabase_url, _supabase_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Named retriever_tool so we do not shadow the imported
# create_retriever_tool factory (the previous code rebound that name).
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Tools bound to the LLM inside build_graph().
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
417
+
418
# Build graph function
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph (retriever -> assistant <-> tools).

    Args:
        provider: LLM backend to use: "google", "groq" or "huggingface".

    Returns:
        The compiled LangGraph runnable.

    Raises:
        ValueError: if provider is not one of the supported names.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        # TODO: Add huggingface endpoint
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: one LLM step over the accumulated messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: prepend the system message and, when the vector
        store has a match, a similar Q/A example for few-shot grounding."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        example_msgs = []
        # Guard: similarity_search can return an empty list; the previous
        # unconditional similar_question[0] raised IndexError in that case.
        if similar_question:
            example_msgs.append(HumanMessage(
                content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
            ))
        return {"messages": [sys_msg] + state["messages"] + example_msgs}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the assistant emitted tool calls, else END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
468
 
469
# test
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    # Build the graph, run it on the single question, and print the trace.
    agent_graph = build_graph(provider="groq")
    initial_messages = [HumanMessage(content=question)]
    result = agent_graph.invoke({"messages": initial_messages})
    for msg in result["messages"]:
        msg.pretty_print()
479