tsrrus committed
Commit 92a5106 · verified · 1 Parent(s): 28c2754

Update agent.py

Files changed (1)
  1. agent.py +37 -34
agent.py CHANGED
@@ -1,12 +1,19 @@
 """LangGraph Agent"""
+
 import os
 from dotenv import load_dotenv
 from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition
 from langgraph.prebuilt import ToolNode
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_openai import ChatOpenAI
+from langchain.agents import initialize_agent, Tool
 from langchain_groq import ChatGroq
-from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+from langchain_huggingface import (
+    ChatHuggingFace,
+    HuggingFaceEndpoint,
+    HuggingFaceEmbeddings,
+)
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 from langchain_community.document_loaders import ArxivLoader
@@ -16,43 +23,42 @@ from langchain_core.tools import tool
 from langchain.tools.retriever import create_retriever_tool
 from supabase.client import Client, create_client
 
-
 load_dotenv()
 
+
 @tool
 def multiply(a: int, b: int) -> int:
     """Multiply two numbers.
-
     Args:
         a: first int
         b: second int
     """
     return a * b
 
+
 @tool
 def add(a: int, b: int) -> int:
     """Add two numbers.
-
     Args:
         a: first int
         b: second int
     """
     return a + b
 
+
 @tool
 def subtract(a: int, b: int) -> int:
     """Subtract two numbers.
-
     Args:
         a: first int
         b: second int
     """
     return a - b
 
+
 @tool
 def divide(a: int, b: int) -> int:
     """Divide two numbers.
-
     Args:
         a: first int
         b: second int
@@ -61,20 +67,20 @@ def divide(a: int, b: int) -> int:
     raise ValueError("Cannot divide by zero.")
     return a / b
 
+
 @tool
 def modulus(a: int, b: int) -> int:
     """Get the modulus of two numbers.
-
     Args:
         a: first int
         b: second int
     """
     return a % b
 
+
 @tool
 def wiki_search(query: str) -> str:
     """Search Wikipedia for a query and return maximum 2 results.
-
     Args:
         query: The search query."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
@@ -82,13 +88,14 @@ def wiki_search(query: str) -> str:
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
-        ])
+        ]
+    )
     return {"wiki_results": formatted_search_docs}
 
+
 @tool
 def web_search(query: str) -> str:
     """Search Tavily for a query and return maximum 3 results.
-
     Args:
         query: The search query."""
     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
@@ -96,13 +103,14 @@ def web_search(query: str) -> str:
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
-        ])
+        ]
+    )
     return {"web_results": formatted_search_docs}
 
+
 @tool
 def arvix_search(query: str) -> str:
     """Search Arxiv for a query and return maximum 3 result.
-
     Args:
         query: The search query."""
     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
@@ -110,11 +118,11 @@ def arvix_search(query: str) -> str:
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
             for doc in search_docs
-        ])
+        ]
+    )
     return {"arvix_results": formatted_search_docs}
 
 
-
 # load the system prompt from the file
 with open("system_prompt.txt", "r", encoding="utf-8") as f:
     system_prompt = f.read()
@@ -123,13 +131,15 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
 sys_msg = SystemMessage(content=system_prompt)
 
 # build a retriever
-embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
+embeddings = HuggingFaceEmbeddings(
+    model_name="sentence-transformers/all-mpnet-base-v2"
+)  # dim=768
 supabase: Client = create_client(
-    os.environ.get("SUPABASE_URL"),
-    os.environ.get("SUPABASE_SERVICE_KEY"))
+    os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY")
+)
 vector_store = SupabaseVectorStore(
     client=supabase,
-    embedding= embeddings,
+    embedding=embeddings,
     table_name="documents",
     query_name="match_documents_langchain",
 )
@@ -140,7 +150,6 @@ create_retriever_tool = create_retriever_tool(
 )
 
 
-
 tools = [
     multiply,
     add,
@@ -152,6 +161,7 @@ tools = [
     arvix_search,
 ]
 
+
 # Build graph function
 def build_graph(provider: str = "groq"):
     """Build the graph"""
@@ -161,9 +171,13 @@ def build_graph(provider: str = "groq"):
         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
     elif provider == "groq":
         # Groq https://console.groq.com/docs/models
-        llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
+        llm = ChatGroq(
+            model="qwen-qwq-32b", temperature=0
+        )  # optional : qwen-qwq-32b gemma2-9b-it
+    elif provider == "openai":
+        # OpenAI
+        llm = ChatOpenAI(model="gpt-4", temperature=0)
     elif provider == "huggingface":
-        # TODO: Add huggingface endpoint
        llm = ChatHuggingFace(
             llm=HuggingFaceEndpoint(
                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -179,7 +193,7 @@ def build_graph(provider: str = "groq"):
     def assistant(state: MessagesState):
         """Assistant node"""
         return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
+
     def retriever(state: MessagesState):
         """Retriever node"""
         similar_question = vector_store.similarity_search(state["messages"][0].content)
@@ -201,15 +215,4 @@ def build_graph(provider: str = "groq"):
     builder.add_edge("tools", "assistant")
 
     # Compile graph
-    return builder.compile()
-
-# test
-if __name__ == "__main__":
-    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-    # Build the graph
-    graph = build_graph(provider="groq")
-    # Run the graph
-    messages = [HumanMessage(content=question)]
-    messages = graph.invoke({"messages": messages})
-    for m in messages["messages"]:
-        m.pretty_print()
+    return builder.compile()
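
This commit drops the in-file `if __name__ == "__main__":` smoke test, so the compiled graph now has to be exercised from outside agent.py. A minimal sketch of that, mirroring the removed test block and assuming agent.py is importable as `agent` with the default "groq" provider configured (the question string is only a placeholder):

from langchain_core.messages import HumanMessage
from agent import build_graph  # assumption: agent.py is on the import path

# Build the compiled LangGraph and run one question through it
graph = build_graph(provider="groq")
result = graph.invoke({"messages": [HumanMessage(content="What is 25 * 17?")]})
for m in result["messages"]:
    m.pretty_print()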