rakesh-dvg committed on
Commit fef27d1 · verified · 1 Parent(s): 9e52214

Update agent.py

Files changed (1)
  1. agent.py +17 -93
agent.py CHANGED
@@ -16,134 +16,75 @@ from langchain_core.tools import tool
 from langchain.tools.retriever import create_retriever_tool
 from supabase.client import Client, create_client
 
-from supabase import create_client
-
-supabase = None
-
-def init_supabase_client(url, key):
-    global supabase
-    supabase = create_client(url, key)
-
-# Now you can use `supabase` in your agent code safely.
-
-
-
-#end
-
-
 load_dotenv()
 
 @tool
 def multiply(a: int, b: int) -> int:
-    """Multiply two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
     return a * b
 
 @tool
 def add(a: int, b: int) -> int:
-    """Add two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
     return a + b
 
 @tool
 def subtract(a: int, b: int) -> int:
-    """Subtract two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
     return a - b
 
 @tool
-def divide(a: int, b: int) -> int:
-    """Divide two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
+def divide(a: int, b: int) -> float:
     if b == 0:
         raise ValueError("Cannot divide by zero.")
     return a / b
 
 @tool
 def modulus(a: int, b: int) -> int:
-    """Get the modulus of two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
     return a % b
 
 @tool
 def wiki_search(query: str) -> str:
-    """Search Wikipedia for a query and return maximum 2 results.
-
-    Args:
-        query: The search query."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-    formatted_search_docs = "\n\n---\n\n".join(
+    formatted = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
         ])
-    return {"wiki_results": formatted_search_docs}
+    return {"wiki_results": formatted}
 
 @tool
 def web_search(query: str) -> str:
-    """Search Tavily for a query and return maximum 3 results.
-
-    Args:
-        query: The search query."""
     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-    formatted_search_docs = "\n\n---\n\n".join(
+    formatted = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
         ])
-    return {"web_results": formatted_search_docs}
+    return {"web_results": formatted}
 
 @tool
 def arvix_search(query: str) -> str:
-    """Search Arxiv for a query and return maximum 3 result.
-
-    Args:
-        query: The search query."""
     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-    formatted_search_docs = "\n\n---\n\n".join(
+    formatted = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
-    return {"arvix_results": formatted_search_docs}
-
+    return {"arvix_results": formatted}
 
 
-# load the system prompt from the file
+# Load system prompt
 with open("system_prompt.txt", "r", encoding="utf-8") as f:
     system_prompt = f.read()
 
-# System message
 sys_msg = SystemMessage(content=system_prompt)
 
-# build a retriever
-embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
+# Setup vector store retriever
+embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 supabase: Client = create_client(
     os.environ.get("SUPABASE_URL"),
     os.environ.get("SUPABASE_SERVICE_KEY"))
 vector_store = SupabaseVectorStore(
     client=supabase,
-    embedding= embeddings,
+    embedding=embeddings,
     table_name="documents",
     query_name="match_documents_langchain",
 )
@@ -153,8 +94,6 @@ create_retriever_tool = create_retriever_tool(
     description="A tool to retrieve similar questions from a vector store.",
 )
 
-
-
 tools = [
     multiply,
     add,
@@ -166,18 +105,12 @@ tools = [
     arvix_search,
 ]
 
-# Build graph function
 def build_graph(provider: str = "groq"):
-    """Build the graph"""
-    # Load environment variables from .env file
     if provider == "google":
-        # Google Gemini
         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
     elif provider == "groq":
-        # Groq https://console.groq.com/docs/models
-        llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
+        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
     elif provider == "huggingface":
-        # TODO: Add huggingface endpoint
         llm = ChatHuggingFace(
             llm=HuggingFaceEndpoint(
                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -186,16 +119,13 @@ def build_graph(provider: str = "groq"):
         )
     else:
         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-    # Bind tools to LLM
+
     llm_with_tools = llm.bind_tools(tools)
 
-    # Node
     def assistant(state: MessagesState):
-        """Assistant node"""
         return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
+
     def retriever(state: MessagesState):
-        """Retriever node"""
         similar_question = vector_store.similarity_search(state["messages"][0].content)
         example_msg = HumanMessage(
             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
@@ -208,22 +138,16 @@
     builder.add_node("tools", ToolNode(tools))
     builder.add_edge(START, "retriever")
     builder.add_edge("retriever", "assistant")
-    builder.add_conditional_edges(
-        "assistant",
-        tools_condition,
-    )
+    builder.add_conditional_edges("assistant", tools_condition)
     builder.add_edge("tools", "assistant")
 
-    # Compile graph
     return builder.compile()
 
-# test
+
 if __name__ == "__main__":
     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-    # Build the graph
     graph = build_graph(provider="groq")
-    # Run the graph
     messages = [HumanMessage(content=question)]
     messages = graph.invoke({"messages": messages})
     for m in messages["messages"]:
-        m.pretty_print()
+        m.pretty_print()