aelyazid committed on
Commit
fc8feae
·
verified ·
1 Parent(s): 9f72fd2

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +71 -145
agent.py CHANGED
@@ -1,225 +1,151 @@
1
- """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
- from langchain_google_genai import ChatGoogleGenerativeAI
8
- from langchain_groq import ChatGroq
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
- from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
  from langchain.tools.retriever import create_retriever_tool
17
- from supabase.client import Client, create_client
18
-
19
- class BasicAgent:
20
- def __init__(self, provider="groq"):
21
- self.graph = build_graph(provider=provider)
22
-
23
- def run(self, question: str) -> str:
24
- messages = [HumanMessage(content=question)]
25
- response = self.graph.invoke({"messages": messages})
26
- # Return the last message from the assistant node (the answer)
27
- return response["messages"][-1].content
28
 
 
29
 
 
 
 
30
 
31
- load_dotenv()
32
 
 
33
  @tool
34
  def multiply(a: int, b: int) -> int:
35
- """Multiply two numbers.
36
- Args:
37
- a: first int
38
- b: second int
39
- """
40
  return a * b
41
 
42
  @tool
43
  def add(a: int, b: int) -> int:
44
- """Add two numbers.
45
-
46
- Args:
47
- a: first int
48
- b: second int
49
- """
50
  return a + b
51
 
52
  @tool
53
  def subtract(a: int, b: int) -> int:
54
- """Subtract two numbers.
55
-
56
- Args:
57
- a: first int
58
- b: second int
59
- """
60
  return a - b
61
 
62
  @tool
63
- def divide(a: int, b: int) -> int:
64
- """Divide two numbers.
65
-
66
- Args:
67
- a: first int
68
- b: second int
69
- """
70
  if b == 0:
71
  raise ValueError("Cannot divide by zero.")
72
  return a / b
73
 
74
  @tool
75
  def modulus(a: int, b: int) -> int:
76
- """Get the modulus of two numbers.
77
-
78
- Args:
79
- a: first int
80
- b: second int
81
- """
82
  return a % b
83
 
84
  @tool
85
  def wiki_search(query: str) -> str:
86
- """Search Wikipedia for a query and return maximum 2 results.
87
-
88
- Args:
89
- query: The search query."""
90
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
91
- formatted_search_docs = "\n\n---\n\n".join(
92
- [
93
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
94
- for doc in search_docs
95
- ])
96
- return {"wiki_results": formatted_search_docs}
97
 
98
  @tool
99
  def web_search(query: str) -> str:
100
- """Search Tavily for a query and return maximum 3 results.
101
-
102
- Args:
103
- query: The search query."""
104
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
105
- formatted_search_docs = "\n\n---\n\n".join(
106
- [
107
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
108
- for doc in search_docs
109
- ])
110
- return {"web_results": formatted_search_docs}
111
 
112
  @tool
113
- def arvix_search(query: str) -> str:
114
- """Search Arxiv for a query and return maximum 3 result.
115
-
116
- Args:
117
- query: The search query."""
118
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
119
- formatted_search_docs = "\n\n---\n\n".join(
120
- [
121
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
122
- for doc in search_docs
123
- ])
124
- return {"arvix_results": formatted_search_docs}
125
-
126
 
 
127
 
128
- # load the system prompt from the file
129
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
130
- system_prompt = f.read()
131
 
132
- # System message
133
- sys_msg = SystemMessage(content=system_prompt)
 
 
134
 
135
- # build a retriever
136
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
137
- supabase: Client = create_client(
138
- os.environ.get("SUPABASE_URL"),
139
- os.environ.get("SUPABASE_SERVICE_KEY"))
140
  vector_store = SupabaseVectorStore(
141
  client=supabase,
142
- embedding= embeddings,
143
  table_name="documents",
144
  query_name="match_documents_langchain",
145
  )
146
- create_retriever_tool = create_retriever_tool(
 
147
  retriever=vector_store.as_retriever(),
148
  name="Question Search",
149
- description="A tool to retrieve similar questions from a vector store.",
150
  )
151
 
 
 
152
 
153
-
154
- tools = [
155
- multiply,
156
- add,
157
- subtract,
158
- divide,
159
- modulus,
160
- wiki_search,
161
- web_search,
162
- arvix_search,
163
- ]
164
-
165
- # Build graph function
166
- def build_graph(provider: str = "groq"):
167
- """Build the graph"""
168
- # Load environment variables from .env file
169
  if provider == "google":
170
- # Google Gemini
171
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
172
  elif provider == "groq":
173
- # Groq https://console.groq.com/docs/models
174
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
175
  elif provider == "huggingface":
176
- # TODO: Add huggingface endpoint
177
  llm = ChatHuggingFace(
178
  llm=HuggingFaceEndpoint(
179
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
180
  temperature=0,
181
- ),
182
  )
183
  else:
184
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
185
- # Bind tools to LLM
186
  llm_with_tools = llm.bind_tools(tools)
187
 
188
- # Node
189
- def assistant(state: MessagesState):
190
- """Assistant node"""
191
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
192
-
193
- def retriever(state: MessagesState):
194
- """Retriever node"""
195
- similar_question = vector_store.similarity_search(state["messages"][0].content)
196
  example_msg = HumanMessage(
197
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
198
  )
199
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
200
 
 
 
 
 
201
  builder = StateGraph(MessagesState)
202
- builder.add_node("retriever", retriever)
203
- builder.add_node("assistant", assistant)
204
  builder.add_node("tools", ToolNode(tools))
 
205
  builder.add_edge(START, "retriever")
206
  builder.add_edge("retriever", "assistant")
207
- builder.add_conditional_edges(
208
- "assistant",
209
- tools_condition,
210
- )
211
  builder.add_edge("tools", "assistant")
212
 
213
- # Compile graph
214
  return builder.compile()
215
 
216
- # test
 
 
 
 
 
 
 
 
217
  if __name__ == "__main__":
218
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
219
- # Build the graph
220
- graph = build_graph(provider="groq")
221
- # Run the graph
222
- messages = [HumanMessage(content=question)]
223
- messages = graph.invoke({"messages": messages})
224
- for m in messages["messages"]:
225
- m.pretty_print()
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
 
 
 
 
 
 
 
6
  from langchain_core.messages import SystemMessage, HumanMessage
7
  from langchain_core.tools import tool
8
  from langchain.tools.retriever import create_retriever_tool
9
+ from supabase.client import create_client
10
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
11
+ from langchain_community.tools.tavily_search import TavilySearchResults
12
+ from langchain_community.vectorstores import SupabaseVectorStore
13
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
14
+ from langchain_groq import ChatGroq
15
+ from langchain_google_genai import ChatGoogleGenerativeAI
 
 
 
 
16
 
17
+ load_dotenv()
18
 
19
+ # Load system prompt from file
20
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
21
+ system_prompt_text = f.read()
22
 
23
+ sys_msg = SystemMessage(content=system_prompt_text)
24
 
25
+ # Define simple math tools as example
26
  @tool
27
  def multiply(a: int, b: int) -> int:
 
 
 
 
 
28
  return a * b
29
 
30
  @tool
31
  def add(a: int, b: int) -> int:
 
 
 
 
 
 
32
  return a + b
33
 
34
  @tool
35
  def subtract(a: int, b: int) -> int:
 
 
 
 
 
 
36
  return a - b
37
 
38
  @tool
39
+ def divide(a: int, b: int) -> float:
 
 
 
 
 
 
40
  if b == 0:
41
  raise ValueError("Cannot divide by zero.")
42
  return a / b
43
 
44
  @tool
45
  def modulus(a: int, b: int) -> int:
 
 
 
 
 
 
46
  return a % b
47
 
48
  @tool
49
  def wiki_search(query: str) -> str:
50
+ docs = WikipediaLoader(query=query, load_max_docs=2).load()
51
+ return "\n\n---\n\n".join(
52
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
53
+ for doc in docs
54
+ )
 
 
 
 
 
 
55
 
56
  @tool
57
  def web_search(query: str) -> str:
58
+ docs = TavilySearchResults(max_results=3).invoke(query=query)
59
+ return "\n\n---\n\n".join(
60
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
61
+ for doc in docs
62
+ )
 
 
 
 
 
 
63
 
64
  @tool
65
+ def arxiv_search(query: str) -> str:
66
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
67
+ return "\n\n---\n\n".join(
68
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
69
+ for doc in docs
70
+ )
 
 
 
 
 
 
 
71
 
72
+ tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arxiv_search]
73
 
74
+ # Setup Supabase vector store retriever
75
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
 
76
 
77
+ supabase = create_client(
78
+ os.environ.get("SUPABASE_URL"),
79
+ os.environ.get("SUPABASE_SERVICE_KEY")
80
+ )
81
 
 
 
 
 
 
82
  vector_store = SupabaseVectorStore(
83
  client=supabase,
84
+ embedding=embeddings,
85
  table_name="documents",
86
  query_name="match_documents_langchain",
87
  )
88
+
89
+ retriever_tool = create_retriever_tool(
90
  retriever=vector_store.as_retriever(),
91
  name="Question Search",
92
+ description="Retrieve similar questions from vector store.",
93
  )
94
 
95
+ # Add retriever tool if you want
96
+ tools.append(retriever_tool)
97
 
98
+ def build_graph(provider="groq"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  if provider == "google":
 
100
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
101
  elif provider == "groq":
102
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
 
103
  elif provider == "huggingface":
 
104
  llm = ChatHuggingFace(
105
  llm=HuggingFaceEndpoint(
106
  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
107
  temperature=0,
108
+ )
109
  )
110
  else:
111
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
112
+
113
  llm_with_tools = llm.bind_tools(tools)
114
 
115
+ def retriever_node(state: MessagesState):
116
+ similar = vector_store.similarity_search(state["messages"][0].content)
 
 
 
 
 
 
117
  example_msg = HumanMessage(
118
+ content=f"Here is a similar question and answer for reference:\n\n{similar[0].page_content}"
119
  )
120
  return {"messages": [sys_msg] + state["messages"] + [example_msg]}
121
 
122
+ def assistant_node(state: MessagesState):
123
+ # This will prompt the model with system prompt + question + context, expecting reasoning + FINAL ANSWER
124
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
125
+
126
  builder = StateGraph(MessagesState)
127
+ builder.add_node("retriever", retriever_node)
128
+ builder.add_node("assistant", assistant_node)
129
  builder.add_node("tools", ToolNode(tools))
130
+
131
  builder.add_edge(START, "retriever")
132
  builder.add_edge("retriever", "assistant")
133
+ builder.add_conditional_edges("assistant", tools_condition)
 
 
 
134
  builder.add_edge("tools", "assistant")
135
 
 
136
  return builder.compile()
137
 
138
+ class BasicAgent:
139
+ def __init__(self, provider="groq"):
140
+ self.graph = build_graph(provider=provider)
141
+ def run(self, question: str) -> str:
142
+ messages = [HumanMessage(content=question)]
143
+ response = self.graph.invoke({"messages": messages})
144
+ # Return the last message (should start with FINAL ANSWER)
145
+ return response["messages"][-1].content
146
+
147
  if __name__ == "__main__":
148
+ agent = BasicAgent(provider="groq")
149
+ q = "When was a picture of St Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
150
+ print(agent.run(q))
151
+