Shaukat39 committed on
Commit
dca7345
·
verified ·
1 Parent(s): c3e7c42

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +402 -389
agent.py CHANGED
@@ -1,479 +1,492 @@
1
- # """LangGraph Agent"""
2
- # import os
3
- # from dotenv import load_dotenv
4
- # from langgraph.graph import START, StateGraph, MessagesState
5
- # from langgraph.prebuilt import tools_condition, ToolNode
6
- # from langchain_google_genai import ChatGoogleGenerativeAI
7
- # from langchain_groq import ChatGroq
8
- # from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
- # from langchain_community.tools.tavily_search import TavilySearchResults
10
- # from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
- # from langchain_community.vectorstores import SupabaseVectorStore
12
- # from langchain_core.messages import SystemMessage, HumanMessage
13
- # from langchain_core.tools import tool
14
- # from langchain.tools.retriever import create_retriever_tool
15
- # from supabase.client import create_client
16
- # import re
17
- # import traceback
18
 
19
- # load_dotenv()
20
 
21
- # # ------------------ Arithmetic Tools ------------------
22
 
23
- # @tool
24
- # def multiply(a: int, b: int) -> str:
25
- # """
26
- # Multiply two integers and return the result as a string.
27
 
28
- # Args:
29
- # a (int): The first integer.
30
- # b (int): The second integer.
31
 
32
- # Returns:
33
- # str: The product of a and b, as a string.
34
- # """
35
- # return str(a * b)
36
 
37
 
38
- # @tool
39
- # def add(a: int, b: int) -> str:
40
- # """
41
- # Add two integers and return the result as a string.
42
 
43
- # Args:
44
- # a (int): The first integer.
45
- # b (int): The second integer.
46
 
47
- # Returns:
48
- # str: The sum of a and b, as a string.
49
- # """
50
- # return str(a + b)
51
 
52
 
53
- # @tool
54
- # def subtract(a: int, b: int) -> str:
55
- # """
56
- # Subtract one integer from another and return the result as a string.
57
 
58
- # Args:
59
- # a (int): The minuend.
60
- # b (int): The subtrahend.
61
 
62
- # Returns:
63
- # str: The difference (a - b), as a string.
64
- # """
65
- # return str(a - b)
66
 
67
 
68
- # @tool
69
- # def divide(a: int, b: int) -> str:
70
- # """
71
- # Divide one integer by another and return the result as a string.
72
 
73
- # Args:
74
- # a (int): The numerator.
75
- # b (int): The denominator. Must not be zero.
76
 
77
- # Returns:
78
- # str: The result of the division (a / b), as a string. Returns an error message if b is zero.
79
- # """
80
- # if b == 0:
81
- # return "Error: Cannot divide by zero."
82
- # return str(a / b)
83
 
84
 
85
- # @tool
86
- # def modulus(a: int, b: int) -> str:
87
- # """
88
- # Compute the modulus (remainder) of two integers and return the result as a string.
89
 
90
- # Args:
91
- # a (int): The numerator.
92
- # b (int): The denominator.
93
 
94
- # Returns:
95
- # str: The remainder when a is divided by b, as a string.
96
- # """
97
- # return str(a % b)
98
 
99
 
100
- # # ------------------ Retrieval Tools ------------------
101
 
102
- # @tool
103
- # def wiki_search(query: str) -> str:
104
- # """
105
- # Search Wikipedia for a given query and return text from up to two matching articles.
106
 
107
- # Args:
108
- # query (str): A string query to search on Wikipedia.
109
 
110
- # Returns:
111
- # str: Combined content from up to two relevant articles, separated by dividers.
112
- # """
113
- # docs = WikipediaLoader(query=query, load_max_docs=2).load()
114
- # return "\n\n---\n\n".join(doc.page_content for doc in docs)
115
 
116
 
117
- # @tool
118
- # def web_search(query: str) -> str:
119
- # """
120
- # Perform a web search using Tavily and return content from the top three results.
121
 
122
- # Args:
123
- # query (str): A string representing the web search topic.
124
 
125
- # Returns:
126
- # str: Combined content from up to three top results, separated by dividers.
127
- # """
128
- # docs = TavilySearchResults(max_results=3).invoke(query)
129
- # return "\n\n---\n\n".join(doc.page_content for doc in docs)
130
 
131
 
132
- # @tool
133
- # def arvix_search(query: str) -> str:
134
- # """
135
- # Search arXiv for academic papers related to the query and return excerpts.
136
 
137
- # Args:
138
- # query (str): The search query string.
139
 
140
- # Returns:
141
- # str: Excerpts (up to 1000 characters each) from up to three relevant arXiv papers, separated by dividers.
142
- # """
143
- # docs = ArxivLoader(query=query, load_max_docs=3).load()
144
- # return "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)
145
 
146
 
147
 
148
- # # ------------------ System Prompt ------------------
149
- # with open("system_prompt.txt", "r", encoding="utf-8") as f:
150
- # system_prompt = f.read().strip()
151
 
152
- # # ------------------ Supabase Setup ------------------
153
- # url = os.environ["SUPABASE_URL"].strip()
154
- # key = os.environ["SUPABASE_SERVICE_KEY"].strip()
155
- # client = create_client(url, key)
156
 
157
- # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
158
 
159
- # # Embed improved QA docs
160
- # qa_examples = [
161
- # {"content": "Q: What is the capital of Vietnam?\nA: FINAL ANSWER: Hanoi"},
162
- # {"content": "Q: Alphabetize: lettuce, broccoli, basil\nA: FINAL ANSWER: basil,broccoli,lettuce"},
163
- # {"content": "Q: What is 42 multiplied by 8?\nA: FINAL ANSWER: three hundred thirty six"},
164
- # ]
165
- # vector_store = SupabaseVectorStore(
166
- # client=client,
167
- # embedding=embeddings,
168
- # table_name="documents",
169
- # query_name="match_documents_langchain"
170
- # )
171
- # vector_store.add_texts([doc["content"] for doc in qa_examples])
172
- # print("✅ QA documents embedded into Supabase.")
173
 
174
- # retriever_tool = create_retriever_tool(
175
- # retriever=vector_store.as_retriever(),
176
- # name="Question Search",
177
- # description="Retrieve similar questions from vector DB."
178
- # )
179
 
180
- # tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
181
 
182
- # # ------------------ Build Agent Graph ------------------
183
- # class VerboseToolNode(ToolNode):
184
- # def invoke(self, state):
185
- # print("🔧 ToolNode evaluating:", [m.content for m in state["messages"]])
186
- # return super().invoke(state)
187
 
188
- # def build_graph(provider: str = "groq"):
189
- # if provider == "google":
190
- # llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
191
- # elif provider == "groq":
192
- # llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
193
- # elif provider == "huggingface":
194
- # llm = ChatHuggingFace(
195
- # llm=HuggingFaceEndpoint(
196
- # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
197
- # temperature=0.3
198
- # )
199
- # )
200
- # else:
201
- # raise ValueError("Invalid provider.")
202
 
203
- # llm_with_tools = llm.bind_tools(tools)
204
 
205
- # def retriever(state: MessagesState):
206
- # query = state["messages"][0].content
207
- # similar = vector_store.similarity_search_with_score(query)
208
- # threshold = 0.7
209
- # examples = [
210
- # HumanMessage(content=f"Similar QA:\n{doc.page_content}")
211
- # for doc, score in similar if score >= threshold
212
- # ]
213
- # return {"messages": state["messages"] + examples}
214
 
215
 
216
 
217
- # def assistant(state: MessagesState):
218
- # try:
219
- # messages = [SystemMessage(content=system_prompt.strip())] + state["messages"]
220
- # result = llm_with_tools.invoke(messages)
221
 
222
- # # Handle different return types gracefully
223
- # if hasattr(result, "content"):
224
- # raw_output = result.content.strip()
225
- # elif isinstance(result, dict) and "content" in result:
226
- # raw_output = result["content"].strip()
227
- # else:
228
- # raise ValueError(f"Unexpected result format: {repr(result)}")
229
 
230
- # print("🤖 Raw LLM output:", repr(raw_output))
231
 
232
- # match = re.search(r"FINAL ANSWER:\s*(.+)", raw_output, re.IGNORECASE)
233
- # if match:
234
- # final_output = f"FINAL ANSWER: {match.group(1).strip()}"
235
- # else:
236
- # print("⚠️ 'FINAL ANSWER:' not found. Raw content will be used as fallback.")
237
- # final_output = f"FINAL ANSWER: {raw_output or 'Unable to determine answer'}"
238
 
239
- # return {"messages": [HumanMessage(content=final_output)]}
240
 
241
- # except Exception as e:
242
- # print(f"🔥 Exception: {e}")
243
- # traceback.print_exc()
244
- # return {"messages": [HumanMessage(content=f"FINAL ANSWER: AGENT ERROR: {type(e).__name__}: {e}")]}
245
 
246
 
247
- # builder = StateGraph(MessagesState)
248
- # builder.add_node("retriever", retriever)
249
- # builder.add_node("assistant", assistant)
250
- # builder.add_node("tools", VerboseToolNode(tools))
251
- # builder.add_edge(START, "retriever")
252
- # builder.add_edge("retriever", "assistant")
253
- # builder.add_conditional_edges("assistant", tools_condition)
254
- # builder.add_edge("tools", "assistant")
255
 
256
- # return builder.compile()
257
 
258
- # # ------------------ Local Test Harness ------------------
259
- # if __name__ == "__main__":
260
- # graph = build_graph(provider="groq")
261
- # question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
262
- # messages = [HumanMessage(content=question)]
263
- # result = graph.invoke({"messages": messages})
264
- # print(result["messages"][-1].content)
265
 
266
- """LangGraph Agent"""
267
- import os
268
- from dotenv import load_dotenv
269
- from langgraph.graph import START, StateGraph, MessagesState
270
- from langgraph.prebuilt import tools_condition
271
- from langgraph.prebuilt import ToolNode
272
- from langchain_google_genai import ChatGoogleGenerativeAI
273
- from langchain_groq import ChatGroq
274
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
275
- from langchain_community.tools.tavily_search import TavilySearchResults
276
- from langchain_community.document_loaders import WikipediaLoader
277
- from langchain_community.document_loaders import ArxivLoader
278
- from langchain_community.vectorstores import SupabaseVectorStore
279
- from langchain_core.messages import SystemMessage, HumanMessage
280
- from langchain_core.tools import tool
281
- from langchain.tools.retriever import create_retriever_tool
282
- from supabase.client import Client, create_client
283
 
284
- load_dotenv()
285
 
286
- @tool
287
- def multiply(a: int, b: int) -> int:
288
- """Multiply two numbers.
289
- Args:
290
- a: first int
291
- b: second int
292
- """
293
- return a * b
294
 
295
- @tool
296
- def add(a: int, b: int) -> int:
297
- """Add two numbers.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
- Args:
300
- a: first int
301
- b: second int
302
- """
303
- return a + b
304
 
305
- @tool
306
- def subtract(a: int, b: int) -> int:
307
- """Subtract two numbers.
308
 
309
- Args:
310
- a: first int
311
- b: second int
312
- """
313
- return a - b
314
 
315
- @tool
316
- def divide(a: int, b: int) -> int:
317
- """Divide two numbers.
318
 
319
- Args:
320
- a: first int
321
- b: second int
322
- """
323
- if b == 0:
324
- raise ValueError("Cannot divide by zero.")
325
- return a / b
326
 
327
- @tool
328
- def modulus(a: int, b: int) -> int:
329
- """Get the modulus of two numbers.
330
 
331
- Args:
332
- a: first int
333
- b: second int
334
- """
335
- return a % b
336
 
337
- @tool
338
- def wiki_search(query: str) -> str:
339
- """Search Wikipedia for a query and return maximum 2 results.
340
 
341
- Args:
342
- query: The search query."""
343
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
344
- formatted_search_docs = "\n\n---\n\n".join(
345
- [
346
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
347
- for doc in search_docs
348
- ])
349
- return {"wiki_results": formatted_search_docs}
350
 
351
- @tool
352
- def web_search(query: str) -> str:
353
- """Search Tavily for a query and return maximum 3 results.
354
 
355
- Args:
356
- query: The search query."""
357
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
358
- formatted_search_docs = "\n\n---\n\n".join(
359
- [
360
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
361
- for doc in search_docs
362
- ])
363
- return {"web_results": formatted_search_docs}
364
 
365
- @tool
366
- def arvix_search(query: str) -> str:
367
- """Search Arxiv for a query and return maximum 3 result.
368
 
369
- Args:
370
- query: The search query."""
371
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
372
- formatted_search_docs = "\n\n---\n\n".join(
373
- [
374
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
375
- for doc in search_docs
376
- ])
377
- return {"arvix_results": formatted_search_docs}
378
 
379
 
380
 
381
- # load the system prompt from the file
382
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
383
- system_prompt = f.read()
384
 
385
- # System message
386
- sys_msg = SystemMessage(content=system_prompt)
387
 
388
- # build a retriever
389
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
390
- supabase: Client = create_client(
391
- os.environ.get("SUPABASE_URL"),
392
- os.environ.get("SUPABASE_SERVICE_KEY"))
393
- vector_store = SupabaseVectorStore(
394
- client=supabase,
395
- embedding= embeddings,
396
- table_name="documents",
397
- query_name="match_documents_langchain",
398
- )
399
- create_retriever_tool = create_retriever_tool(
400
- retriever=vector_store.as_retriever(),
401
- name="Question Search",
402
- description="A tool to retrieve similar questions from a vector store.",
403
- )
404
 
405
 
406
 
407
- tools = [
408
- multiply,
409
- add,
410
- subtract,
411
- divide,
412
- modulus,
413
- wiki_search,
414
- web_search,
415
- arvix_search,
416
- ]
417
 
418
- # Build graph function
419
- def build_graph(provider: str = "groq"):
420
- """Build the graph"""
421
- # Load environment variables from .env file
422
- if provider == "google":
423
- # Google Gemini
424
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
425
- elif provider == "groq":
426
- # Groq https://console.groq.com/docs/models
427
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
428
- elif provider == "huggingface":
429
- # TODO: Add huggingface endpoint
430
- llm = ChatHuggingFace(
431
- llm=HuggingFaceEndpoint(
432
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
433
- temperature=0,
434
- ),
435
- )
436
- else:
437
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
438
- # Bind tools to LLM
439
- llm_with_tools = llm.bind_tools(tools)
440
 
441
- # Node
442
- def assistant(state: MessagesState):
443
- """Assistant node"""
444
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
445
 
446
- def retriever(state: MessagesState):
447
- """Retriever node"""
448
- similar_question = vector_store.similarity_search(state["messages"][0].content)
449
- example_msg = HumanMessage(
450
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
451
- )
452
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
453
 
454
- builder = StateGraph(MessagesState)
455
- builder.add_node("retriever", retriever)
456
- builder.add_node("assistant", assistant)
457
- builder.add_node("tools", ToolNode(tools))
458
- builder.add_edge(START, "retriever")
459
- builder.add_edge("retriever", "assistant")
460
- builder.add_conditional_edges(
461
- "assistant",
462
- tools_condition,
463
- )
464
- builder.add_edge("tools", "assistant")
465
 
466
- # Compile graph
467
- return builder.compile()
468
 
469
- # test
470
- if __name__ == "__main__":
471
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
472
- # Build the graph
473
- graph = build_graph(provider="groq")
474
- # Run the graph
475
- messages = [HumanMessage(content=question)]
476
- messages = graph.invoke({"messages": messages})
477
- for m in messages["messages"]:
478
- m.pretty_print()
479
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition, ToolNode
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_groq import ChatGroq
8
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
+ from langchain_community.vectorstores import SupabaseVectorStore
12
+ from langchain_core.messages import SystemMessage, HumanMessage
13
+ from langchain_core.tools import tool
14
+ from langchain.tools.retriever import create_retriever_tool
15
+ from supabase.client import create_client
16
+ import re
17
+ import traceback
18
 
19
+ load_dotenv()
20
 
21
+ # ------------------ Arithmetic Tools ------------------
22
 
23
@tool
def multiply(a: int, b: int) -> str:
    """
    Multiply two integers and return the result as a string.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        str: The product of a and b, as a string.
    """
    product = a * b
    return str(product)
36
 
37
 
38
@tool
def add(a: int, b: int) -> str:
    """
    Add two integers and return the result as a string.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        str: The sum of a and b, as a string.
    """
    total = a + b
    return str(total)
51
 
52
 
53
@tool
def subtract(a: int, b: int) -> str:
    """
    Subtract one integer from another and return the result as a string.

    Args:
        a (int): The minuend.
        b (int): The subtrahend.

    Returns:
        str: The difference (a - b), as a string.
    """
    difference = a - b
    return str(difference)
66
 
67
 
68
@tool
def divide(a: int, b: int) -> str:
    """
    Divide one integer by another and return the result as a string.

    Args:
        a (int): The numerator.
        b (int): The denominator. Must not be zero.

    Returns:
        str: The quotient a / b as a string, or an error message when b is zero.
    """
    # Guard clause first: report the zero-denominator case as a string
    # (tools return strings to the model rather than raising).
    if b == 0:
        return "Error: Cannot divide by zero."
    quotient = a / b
    return str(quotient)
83
 
84
 
85
@tool
def modulus(a: int, b: int) -> str:
    """
    Compute the modulus (remainder) of two integers and return the result as a string.

    Args:
        a (int): The numerator.
        b (int): The denominator.

    Returns:
        str: The remainder when a is divided by b, as a string.
    """
    remainder = a % b
    return str(remainder)
98
 
99
 
100
+ # ------------------ Retrieval Tools ------------------
101
 
102
@tool
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for a given query and return text from up to two matching articles.

    Args:
        query (str): A string query to search on Wikipedia.

    Returns:
        str: Combined content from up to two relevant articles, separated by dividers.
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    pages = [document.page_content for document in loader.load()]
    return "\n\n---\n\n".join(pages)
115
 
116
 
117
@tool
def web_search(query: str) -> str:
    """
    Perform a web search using Tavily and return content from the top three results.

    Args:
        query (str): A string representing the web search topic.

    Returns:
        str: Combined content from up to three top results, separated by dividers.
    """
    # Bug fix: TavilySearchResults.invoke returns a list of dicts shaped like
    # {"url": ..., "content": ...}, not Document objects, so the original
    # `doc.page_content` raised AttributeError on every call.
    results = TavilySearchResults(max_results=3).invoke(query)
    parts = []
    for item in results:
        if isinstance(item, dict):
            parts.append(str(item.get("content", item)))
        else:
            # Defensive: tolerate Document-like objects if the API changes.
            parts.append(getattr(item, "page_content", str(item)))
    return "\n\n---\n\n".join(parts)
130
 
131
 
132
@tool
def arvix_search(query: str) -> str:
    """
    Search arXiv for academic papers related to the query and return excerpts.

    Args:
        query (str): The search query string.

    Returns:
        str: Excerpts (up to 1000 characters each) from up to three relevant
        arXiv papers, separated by dividers.
    """
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    excerpts = [paper.page_content[:1000] for paper in papers]
    return "\n\n---\n\n".join(excerpts)
145
 
146
 
147
 
148
# ------------------ System Prompt ------------------
# NOTE: everything in this section runs at import time (file I/O, network
# calls to Supabase) — importing this module has side effects.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read().strip()

# ------------------ Supabase Setup ------------------
# Fails fast with KeyError if either env var is missing; .strip() guards
# against trailing whitespace/newlines in .env values.
url = os.environ["SUPABASE_URL"].strip()
key = os.environ["SUPABASE_SERVICE_KEY"].strip()
client = create_client(url, key)

# 768-dimensional sentence embeddings used for QA similarity search.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Embed improved QA docs
# Few-shot examples demonstrating the required "FINAL ANSWER:" reply format.
qa_examples = [
    {"content": "Q: What is the capital of Vietnam?\nA: FINAL ANSWER: Hanoi"},
    {"content": "Q: Alphabetize: lettuce, broccoli, basil\nA: FINAL ANSWER: basil,broccoli,lettuce"},
    {"content": "Q: What is 42 multiplied by 8?\nA: FINAL ANSWER: three hundred thirty six"},
]
vector_store = SupabaseVectorStore(
    client=client,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain"
)
# NOTE(review): add_texts runs on *every* import, so repeated runs insert
# duplicate rows into the "documents" table — consider seeding once or
# upserting; confirm intended behavior.
vector_store.add_texts([doc["content"] for doc in qa_examples])
print("✅ QA documents embedded into Supabase.")

# Retriever wrapped as a tool; defined but not included in `tools` below,
# so the LLM cannot call it directly — retrieval happens in the graph's
# retriever node instead.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="Retrieve similar questions from vector DB."
)

# Tools exposed to the LLM via bind_tools in build_graph().
tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
181
 
182
+ # ------------------ Build Agent Graph ------------------
183
class VerboseToolNode(ToolNode):
    """ToolNode subclass that logs the current message contents before delegating."""

    def invoke(self, state):
        contents = [message.content for message in state["messages"]]
        print("🔧 ToolNode evaluating:", contents)
        return super().invoke(state)
187
 
188
def build_graph(provider: str = "groq"):
    """Build and compile the retriever → assistant → tools agent graph.

    Args:
        provider (str): LLM backend to use: "google", "groq", or "huggingface".

    Returns:
        The compiled LangGraph graph.

    Raises:
        ValueError: If provider is not one of the supported names.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
    elif provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0.3
            )
        )
    else:
        raise ValueError("Invalid provider.")

    llm_with_tools = llm.bind_tools(tools)

    def retriever(state: MessagesState):
        """Append similar QA examples from the vector store to the conversation."""
        query = state["messages"][0].content
        similar = vector_store.similarity_search_with_score(query)
        # NOTE(review): score semantics (similarity vs. distance) depend on the
        # vector store backend — confirm that ">= 0.7" selects *close* matches.
        threshold = 0.7
        examples = [
            HumanMessage(content=f"Similar QA:\n{doc.page_content}")
            for doc, score in similar if score >= threshold
        ]
        return {"messages": state["messages"] + examples}

    def assistant(state: MessagesState):
        """Run the LLM; pass tool-call replies through, otherwise emit a FINAL ANSWER."""
        try:
            messages = [SystemMessage(content=system_prompt.strip())] + state["messages"]
            result = llm_with_tools.invoke(messages)

            # Bug fix: the original rewrote *every* reply into a HumanMessage,
            # which discarded tool_calls, so tools_condition could never route
            # to the "tools" node and the tool loop was dead. Tool-call replies
            # must be returned unchanged for the ToolNode to execute them.
            if getattr(result, "tool_calls", None):
                return {"messages": [result]}

            # Handle different return types gracefully
            if hasattr(result, "content"):
                raw_output = result.content.strip()
            elif isinstance(result, dict) and "content" in result:
                raw_output = result["content"].strip()
            else:
                raise ValueError(f"Unexpected result format: {repr(result)}")

            print("🤖 Raw LLM output:", repr(raw_output))

            # Normalize the reply to the canonical "FINAL ANSWER: ..." format.
            match = re.search(r"FINAL ANSWER:\s*(.+)", raw_output, re.IGNORECASE)
            if match:
                final_output = f"FINAL ANSWER: {match.group(1).strip()}"
            else:
                print("⚠️ 'FINAL ANSWER:' not found. Raw content will be used as fallback.")
                final_output = f"FINAL ANSWER: {raw_output or 'Unable to determine answer'}"

            return {"messages": [HumanMessage(content=final_output)]}

        except Exception as e:
            # Surface the failure as a FINAL ANSWER so the graph terminates
            # cleanly instead of crashing the caller.
            print(f"🔥 Exception: {e}")
            traceback.print_exc()
            return {"messages": [HumanMessage(content=f"FINAL ANSWER: AGENT ERROR: {type(e).__name__}: {e}")]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", VerboseToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
257
 
258
# ------------------ Local Test Harness ------------------
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    agent_graph = build_graph(provider="groq")
    outcome = agent_graph.invoke({"messages": [HumanMessage(content=question)]})
    print(outcome["messages"][-1].content)
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
 
267
 
 
 
 
 
 
 
 
 
268
 
269
+
270
+
271
+
272
+
273
+
274
+
275
+
276
+
277
+
278
+
279
+ # """LangGraph Agent"""
280
+ # import os
281
+ # from dotenv import load_dotenv
282
+ # from langgraph.graph import START, StateGraph, MessagesState
283
+ # from langgraph.prebuilt import tools_condition
284
+ # from langgraph.prebuilt import ToolNode
285
+ # from langchain_google_genai import ChatGoogleGenerativeAI
286
+ # from langchain_groq import ChatGroq
287
+ # from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
288
+ # from langchain_community.tools.tavily_search import TavilySearchResults
289
+ # from langchain_community.document_loaders import WikipediaLoader
290
+ # from langchain_community.document_loaders import ArxivLoader
291
+ # from langchain_community.vectorstores import SupabaseVectorStore
292
+ # from langchain_core.messages import SystemMessage, HumanMessage
293
+ # from langchain_core.tools import tool
294
+ # from langchain.tools.retriever import create_retriever_tool
295
+ # from supabase.client import Client, create_client
296
+
297
+ # load_dotenv()
298
+
299
+ # @tool
300
+ # def multiply(a: int, b: int) -> int:
301
+ # """Multiply two numbers.
302
+ # Args:
303
+ # a: first int
304
+ # b: second int
305
+ # """
306
+ # return a * b
307
+
308
+ # @tool
309
+ # def add(a: int, b: int) -> int:
310
+ # """Add two numbers.
311
 
312
+ # Args:
313
+ # a: first int
314
+ # b: second int
315
+ # """
316
+ # return a + b
317
 
318
+ # @tool
319
+ # def subtract(a: int, b: int) -> int:
320
+ # """Subtract two numbers.
321
 
322
+ # Args:
323
+ # a: first int
324
+ # b: second int
325
+ # """
326
+ # return a - b
327
 
328
+ # @tool
329
+ # def divide(a: int, b: int) -> int:
330
+ # """Divide two numbers.
331
 
332
+ # Args:
333
+ # a: first int
334
+ # b: second int
335
+ # """
336
+ # if b == 0:
337
+ # raise ValueError("Cannot divide by zero.")
338
+ # return a / b
339
 
340
+ # @tool
341
+ # def modulus(a: int, b: int) -> int:
342
+ # """Get the modulus of two numbers.
343
 
344
+ # Args:
345
+ # a: first int
346
+ # b: second int
347
+ # """
348
+ # return a % b
349
 
350
+ # @tool
351
+ # def wiki_search(query: str) -> str:
352
+ # """Search Wikipedia for a query and return maximum 2 results.
353
 
354
+ # Args:
355
+ # query: The search query."""
356
+ # search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
357
+ # formatted_search_docs = "\n\n---\n\n".join(
358
+ # [
359
+ # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
360
+ # for doc in search_docs
361
+ # ])
362
+ # return {"wiki_results": formatted_search_docs}
363
 
364
+ # @tool
365
+ # def web_search(query: str) -> str:
366
+ # """Search Tavily for a query and return maximum 3 results.
367
 
368
+ # Args:
369
+ # query: The search query."""
370
+ # search_docs = TavilySearchResults(max_results=3).invoke(query=query)
371
+ # formatted_search_docs = "\n\n---\n\n".join(
372
+ # [
373
+ # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
374
+ # for doc in search_docs
375
+ # ])
376
+ # return {"web_results": formatted_search_docs}
377
 
378
+ # @tool
379
+ # def arvix_search(query: str) -> str:
380
+ # """Search Arxiv for a query and return maximum 3 result.
381
 
382
+ # Args:
383
+ # query: The search query."""
384
+ # search_docs = ArxivLoader(query=query, load_max_docs=3).load()
385
+ # formatted_search_docs = "\n\n---\n\n".join(
386
+ # [
387
+ # f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
388
+ # for doc in search_docs
389
+ # ])
390
+ # return {"arvix_results": formatted_search_docs}
391
 
392
 
393
 
394
+ # # load the system prompt from the file
395
+ # with open("system_prompt.txt", "r", encoding="utf-8") as f:
396
+ # system_prompt = f.read()
397
 
398
+ # # System message
399
+ # sys_msg = SystemMessage(content=system_prompt)
400
 
401
+ # # build a retriever
402
+ # embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
403
+ # supabase: Client = create_client(
404
+ # os.environ.get("SUPABASE_URL"),
405
+ # os.environ.get("SUPABASE_SERVICE_KEY"))
406
+ # vector_store = SupabaseVectorStore(
407
+ # client=supabase,
408
+ # embedding= embeddings,
409
+ # table_name="documents",
410
+ # query_name="match_documents_langchain",
411
+ # )
412
+ # create_retriever_tool = create_retriever_tool(
413
+ # retriever=vector_store.as_retriever(),
414
+ # name="Question Search",
415
+ # description="A tool to retrieve similar questions from a vector store.",
416
+ # )
417
 
418
 
419
 
420
+ # tools = [
421
+ # multiply,
422
+ # add,
423
+ # subtract,
424
+ # divide,
425
+ # modulus,
426
+ # wiki_search,
427
+ # web_search,
428
+ # arvix_search,
429
+ # ]
430
 
431
+ # # Build graph function
432
+ # def build_graph(provider: str = "groq"):
433
+ # """Build the graph"""
434
+ # # Load environment variables from .env file
435
+ # if provider == "google":
436
+ # # Google Gemini
437
+ # llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
438
+ # elif provider == "groq":
439
+ # # Groq https://console.groq.com/docs/models
440
+ # llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
441
+ # elif provider == "huggingface":
442
+ # # TODO: Add huggingface endpoint
443
+ # llm = ChatHuggingFace(
444
+ # llm=HuggingFaceEndpoint(
445
+ # url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
446
+ # temperature=0,
447
+ # ),
448
+ # )
449
+ # else:
450
+ # raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
451
+ # # Bind tools to LLM
452
+ # llm_with_tools = llm.bind_tools(tools)
453
 
454
+ # # Node
455
+ # def assistant(state: MessagesState):
456
+ # """Assistant node"""
457
+ # return {"messages": [llm_with_tools.invoke(state["messages"])]}
458
 
459
+ # def retriever(state: MessagesState):
460
+ # """Retriever node"""
461
+ # similar_question = vector_store.similarity_search(state["messages"][0].content)
462
+ # example_msg = HumanMessage(
463
+ # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
464
+ # )
465
+ # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
466
 
467
+ # builder = StateGraph(MessagesState)
468
+ # builder.add_node("retriever", retriever)
469
+ # builder.add_node("assistant", assistant)
470
+ # builder.add_node("tools", ToolNode(tools))
471
+ # builder.add_edge(START, "retriever")
472
+ # builder.add_edge("retriever", "assistant")
473
+ # builder.add_conditional_edges(
474
+ # "assistant",
475
+ # tools_condition,
476
+ # )
477
+ # builder.add_edge("tools", "assistant")
478
 
479
+ # # Compile graph
480
+ # return builder.compile()
481
 
482
+ # # test
483
+ # if __name__ == "__main__":
484
+ # question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
485
+ # # Build the graph
486
+ # graph = build_graph(provider="groq")
487
+ # # Run the graph
488
+ # messages = [HumanMessage(content=question)]
489
+ # messages = graph.invoke({"messages": messages})
490
+ # for m in messages["messages"]:
491
+ # m.pretty_print()
492