Shaukat39 committed on
Commit
5be40a0
·
verified ·
1 Parent(s): ae700e1

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +85 -64
agent.py CHANGED
@@ -2,51 +2,79 @@
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
- from langchain_community.embeddings import HuggingFaceEmbeddings
15
  from langchain_core.messages import SystemMessage, HumanMessage
16
  from langchain_core.tools import tool
17
  from langchain.tools.retriever import create_retriever_tool
18
- from supabase.client import Client, create_client
 
19
 
20
  load_dotenv()
21
 
22
  # ------------------ Arithmetic Tools ------------------
 
23
  @tool
24
  def multiply(a: int, b: int) -> str:
25
- """Multiply two integers and return the product as a string."""
 
 
 
 
 
 
 
 
 
26
  return str(a * b)
27
 
28
 
29
  @tool
30
  def add(a: int, b: int) -> str:
31
- """Add two integers and return the sum as a string."""
 
 
 
 
 
 
 
 
 
32
  return str(a + b)
33
 
 
34
  @tool
35
  def subtract(a: int, b: int) -> str:
36
- """Subtract two integers and return the difference as a string."""
 
 
 
 
 
 
 
 
 
37
  return str(a - b)
38
 
 
39
  @tool
40
  def divide(a: int, b: int) -> str:
41
  """
42
- Divide two integers and return the result as a string.
43
 
44
  Args:
45
- a: The numerator (dividend).
46
- b: The denominator (divisor). Must not be zero.
47
 
48
  Returns:
49
- A string representation of the division result, or an error message if b is zero.
50
  """
51
  if b == 0:
52
  return "Error: Cannot divide by zero."
@@ -59,26 +87,27 @@ def modulus(a: int, b: int) -> str:
59
  Compute the modulus (remainder) of two integers and return the result as a string.
60
 
61
  Args:
62
- a: The numerator.
63
- b: The denominator.
64
 
65
  Returns:
66
- A string representation of the remainder when a is divided by b.
67
  """
68
  return str(a % b)
69
 
70
 
71
  # ------------------ Retrieval Tools ------------------
 
72
  @tool
73
  def wiki_search(query: str) -> str:
74
  """
75
- Search Wikipedia for a given query and return the content of up to two matching articles.
76
 
77
  Args:
78
- query: A string query to search on Wikipedia.
79
 
80
  Returns:
81
- A string containing the content from up to two relevant Wikipedia articles, separated by dividers.
82
  """
83
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
84
  return "\n\n---\n\n".join(doc.page_content for doc in docs)
@@ -90,10 +119,10 @@ def web_search(query: str) -> str:
90
  Perform a web search using Tavily and return content from the top three results.
91
 
92
  Args:
93
- query: A string query representing the web search topic.
94
 
95
  Returns:
96
- A string of up to three relevant result contents, separated by dividers.
97
  """
98
  docs = TavilySearchResults(max_results=3).invoke(query)
99
  return "\n\n---\n\n".join(doc.page_content for doc in docs)
@@ -102,37 +131,35 @@ def web_search(query: str) -> str:
102
  @tool
103
  def arvix_search(query: str) -> str:
104
  """
105
- Search arXiv for academic papers matching the query and return excerpts from up to three results.
106
 
107
  Args:
108
- query: A string query to search on arXiv.
109
 
110
  Returns:
111
- A string containing excerpts (first 1000 characters) from up to three relevant papers.
112
  """
113
  docs = ArxivLoader(query=query, load_max_docs=3).load()
114
  return "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)
115
 
116
 
117
- # ------------------ System Setup ------------------
 
118
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
119
  system_prompt = f.read().strip()
120
- sys_msg = SystemMessage(content=system_prompt)
121
-
122
 
123
- # Load environment variables
124
- url = os.environ["SUPABASE_URL"]
125
- key = os.environ["SUPABASE_SERVICE_KEY"]
126
  client = create_client(url, key)
127
 
128
- # Create embedding model
129
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
130
 
131
- # Sample documents to insert
132
- docs = [
133
- {"content": "Newton's First Law states that an object in motion stays in motion unless acted upon."},
134
- {"content": "LangChain enables developers to build context-aware agents using LLMs and tools."},
135
- {"content": "Supabase is an open-source alternative to Firebase built on PostgreSQL."}
136
  ]
137
  vector_store = SupabaseVectorStore(
138
  client=client,
@@ -140,10 +167,8 @@ vector_store = SupabaseVectorStore(
140
  table_name="documents",
141
  query_name="match_documents_langchain"
142
  )
143
- texts = [doc["content"] for doc in docs]
144
- vector_store.add_texts(texts)
145
-
146
- print("✅ Documents successfully embedded and stored.")
147
 
148
  retriever_tool = create_retriever_tool(
149
  retriever=vector_store.as_retriever(),
@@ -154,16 +179,21 @@ retriever_tool = create_retriever_tool(
154
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
155
 
156
  # ------------------ Build Agent Graph ------------------
 
 
 
 
 
157
  def build_graph(provider: str = "groq"):
158
  if provider == "google":
159
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
  elif provider == "groq":
161
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
162
  elif provider == "huggingface":
163
  llm = ChatHuggingFace(
164
  llm=HuggingFaceEndpoint(
165
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
166
- temperature=0
167
  )
168
  )
169
  else:
@@ -172,48 +202,38 @@ def build_graph(provider: str = "groq"):
172
  llm_with_tools = llm.bind_tools(tools)
173
 
174
  def retriever(state: MessagesState):
175
- similar = vector_store.similarity_search(state["messages"][0].content)
 
 
176
  examples = [
177
  HumanMessage(content=f"Similar QA:\n{doc.page_content}")
178
- for doc in similar[:2]
179
  ]
180
- return {"messages": [sys_msg] + state["messages"] + examples}
181
-
182
 
183
  def assistant(state: MessagesState):
184
  try:
185
- # Always prepend the system message
186
- system_msg = SystemMessage(content=system_prompt.strip())
187
- messages = [system_msg] + state["messages"]
188
-
189
  result = llm_with_tools.invoke(messages)
190
- print("🤖 Raw LLM result:", repr(result))
191
-
192
  raw_output = result.content.strip()
 
193
 
194
- # Extract FINAL ANSWER using regex
195
- import re
196
  match = re.search(r"FINAL ANSWER:\s*(.+)", raw_output, re.IGNORECASE)
197
  if match:
198
- final_answer = match.group(1).strip()
199
- final_output = f"FINAL ANSWER: {final_answer}"
200
  else:
201
- print("⚠️ 'FINAL ANSWER:' prefix not found. Using fallback.")
202
- final_output = "FINAL ANSWER: Unable to determine answer"
203
 
204
  return {"messages": [HumanMessage(content=final_output)]}
205
-
206
  except Exception as e:
207
  print(f"🔥 Error in assistant node: {e}")
208
  return {"messages": [HumanMessage(content=f"FINAL ANSWER: AGENT ERROR: {e}")]}
209
 
210
-
211
-
212
-
213
  builder = StateGraph(MessagesState)
214
  builder.add_node("retriever", retriever)
215
  builder.add_node("assistant", assistant)
216
- builder.add_node("tools", ToolNode(tools))
217
  builder.add_edge(START, "retriever")
218
  builder.add_edge("retriever", "assistant")
219
  builder.add_conditional_edges("assistant", tools_condition)
@@ -228,3 +248,4 @@ if __name__ == "__main__":
228
  messages = [HumanMessage(content=question)]
229
  result = graph.invoke({"messages": messages})
230
  print(result["messages"][-1].content)
 
 
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition, ToolNode
 
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
  from langchain_groq import ChatGroq
8
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
9
  from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
11
  from langchain_community.vectorstores import SupabaseVectorStore
 
12
  from langchain_core.messages import SystemMessage, HumanMessage
13
  from langchain_core.tools import tool
14
  from langchain.tools.retriever import create_retriever_tool
15
+ from supabase.client import create_client
16
+ import re
17
 
18
  load_dotenv()
19
 
20
  # ------------------ Arithmetic Tools ------------------
21
+
22
  @tool
23
  def multiply(a: int, b: int) -> str:
24
+ """
25
+ Multiply two integers and return the result as a string.
26
+
27
+ Args:
28
+ a (int): The first integer.
29
+ b (int): The second integer.
30
+
31
+ Returns:
32
+ str: The product of a and b, as a string.
33
+ """
34
  return str(a * b)
35
 
36
 
37
  @tool
38
  def add(a: int, b: int) -> str:
39
+ """
40
+ Add two integers and return the result as a string.
41
+
42
+ Args:
43
+ a (int): The first integer.
44
+ b (int): The second integer.
45
+
46
+ Returns:
47
+ str: The sum of a and b, as a string.
48
+ """
49
  return str(a + b)
50
 
51
+
52
  @tool
53
  def subtract(a: int, b: int) -> str:
54
+ """
55
+ Subtract one integer from another and return the result as a string.
56
+
57
+ Args:
58
+ a (int): The minuend.
59
+ b (int): The subtrahend.
60
+
61
+ Returns:
62
+ str: The difference (a - b), as a string.
63
+ """
64
  return str(a - b)
65
 
66
+
67
  @tool
68
  def divide(a: int, b: int) -> str:
69
  """
70
+ Divide one integer by another and return the result as a string.
71
 
72
  Args:
73
+ a (int): The numerator.
74
+ b (int): The denominator. Must not be zero.
75
 
76
  Returns:
77
+ str: The result of the division (a / b), as a string. Returns an error message if b is zero.
78
  """
79
  if b == 0:
80
  return "Error: Cannot divide by zero."
 
87
  Compute the modulus (remainder) of two integers and return the result as a string.
88
 
89
  Args:
90
+ a (int): The numerator.
91
+ b (int): The denominator.
92
 
93
  Returns:
94
+ str: The remainder when a is divided by b, as a string.
95
  """
96
  return str(a % b)
97
 
98
 
99
  # ------------------ Retrieval Tools ------------------
100
+
101
  @tool
102
  def wiki_search(query: str) -> str:
103
  """
104
+ Search Wikipedia for a given query and return text from up to two matching articles.
105
 
106
  Args:
107
+ query (str): A string query to search on Wikipedia.
108
 
109
  Returns:
110
+ str: Combined content from up to two relevant articles, separated by dividers.
111
  """
112
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
113
  return "\n\n---\n\n".join(doc.page_content for doc in docs)
 
119
  Perform a web search using Tavily and return content from the top three results.
120
 
121
  Args:
122
+ query (str): A string representing the web search topic.
123
 
124
  Returns:
125
+ str: Combined content from up to three top results, separated by dividers.
126
  """
127
  docs = TavilySearchResults(max_results=3).invoke(query)
128
  return "\n\n---\n\n".join(doc.page_content for doc in docs)
 
131
  @tool
132
  def arvix_search(query: str) -> str:
133
  """
134
+ Search arXiv for academic papers related to the query and return excerpts.
135
 
136
  Args:
137
+ query (str): The search query string.
138
 
139
  Returns:
140
+ str: Excerpts (up to 1000 characters each) from up to three relevant arXiv papers, separated by dividers.
141
  """
142
  docs = ArxivLoader(query=query, load_max_docs=3).load()
143
  return "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)
144
 
145
 
146
+
147
+ # ------------------ System Prompt ------------------
148
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
149
  system_prompt = f.read().strip()
 
 
150
 
151
+ # ------------------ Supabase Setup ------------------
152
+ url = os.environ["SUPABASE_URL"].strip()
153
+ key = os.environ["SUPABASE_SERVICE_KEY"].strip()
154
  client = create_client(url, key)
155
 
 
156
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
157
 
158
+ # Embed improved QA docs
159
+ qa_examples = [
160
+ {"content": "Q: What is the capital of Vietnam?\nA: FINAL ANSWER: Hanoi"},
161
+ {"content": "Q: Alphabetize: lettuce, broccoli, basil\nA: FINAL ANSWER: basil,broccoli,lettuce"},
162
+ {"content": "Q: What is 42 multiplied by 8?\nA: FINAL ANSWER: three hundred thirty six"},
163
  ]
164
  vector_store = SupabaseVectorStore(
165
  client=client,
 
167
  table_name="documents",
168
  query_name="match_documents_langchain"
169
  )
170
+ vector_store.add_texts([doc["content"] for doc in qa_examples])
171
+ print("✅ QA documents embedded into Supabase.")
 
 
172
 
173
  retriever_tool = create_retriever_tool(
174
  retriever=vector_store.as_retriever(),
 
179
  tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
180
 
181
  # ------------------ Build Agent Graph ------------------
182
+ class VerboseToolNode(ToolNode):
183
+ def invoke(self, state):
184
+ print("🔧 ToolNode evaluating:", [m.content for m in state["messages"]])
185
+ return super().invoke(state)
186
+
187
  def build_graph(provider: str = "groq"):
188
  if provider == "google":
189
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.3)
190
  elif provider == "groq":
191
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3)
192
  elif provider == "huggingface":
193
  llm = ChatHuggingFace(
194
  llm=HuggingFaceEndpoint(
195
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
196
+ temperature=0.3
197
  )
198
  )
199
  else:
 
202
  llm_with_tools = llm.bind_tools(tools)
203
 
204
  def retriever(state: MessagesState):
205
+ query = state["messages"][0].content
206
+ similar = vector_store.similarity_search_with_score(query)
207
+ threshold = 0.7
208
  examples = [
209
  HumanMessage(content=f"Similar QA:\n{doc.page_content}")
210
+ for doc, score in similar if score >= threshold
211
  ]
212
+ return {"messages": state["messages"] + examples}
 
213
 
214
  def assistant(state: MessagesState):
215
  try:
216
+ messages = [SystemMessage(content=system_prompt.strip())] + state["messages"]
 
 
 
217
  result = llm_with_tools.invoke(messages)
 
 
218
  raw_output = result.content.strip()
219
+ print("🤖 Raw LLM output:", repr(raw_output))
220
 
 
 
221
  match = re.search(r"FINAL ANSWER:\s*(.+)", raw_output, re.IGNORECASE)
222
  if match:
223
+ final_output = f"FINAL ANSWER: {match.group(1).strip()}"
 
224
  else:
225
+ print("⚠️ 'FINAL ANSWER:' not found. Raw content will be used as fallback.")
226
+ final_output = f"FINAL ANSWER: {raw_output or 'Unable to determine answer'}"
227
 
228
  return {"messages": [HumanMessage(content=final_output)]}
 
229
  except Exception as e:
230
  print(f"🔥 Error in assistant node: {e}")
231
  return {"messages": [HumanMessage(content=f"FINAL ANSWER: AGENT ERROR: {e}")]}
232
 
 
 
 
233
  builder = StateGraph(MessagesState)
234
  builder.add_node("retriever", retriever)
235
  builder.add_node("assistant", assistant)
236
+ builder.add_node("tools", VerboseToolNode(tools))
237
  builder.add_edge(START, "retriever")
238
  builder.add_edge("retriever", "assistant")
239
  builder.add_conditional_edges("assistant", tools_condition)
 
248
  messages = [HumanMessage(content=question)]
249
  result = graph.invoke({"messages": messages})
250
  print(result["messages"][-1].content)
251
+