bwilkie commited on
Commit
139eea1
·
verified ·
1 Parent(s): a45e612

Create agent_simple.py

Browse files
Files changed (1) hide show
  1. agent_simple.py +190 -0
agent_simple.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader
12
+ from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_core.messages import SystemMessage, HumanMessage
14
+ from langchain_core.tools import tool
15
+
16
+ load_dotenv()
17
+
18
+ from langchain_core.rate_limiters import InMemoryRateLimiter
19
+
20
+ # Create a rate limiter
21
+ rate_limiter = InMemoryRateLimiter(
22
+ requests_per_second=0.1, # Once every 10 seconds
23
+ check_every_n_seconds=0.1,
24
+ max_bucket_size=10,
25
+ )
26
+
27
+
28
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
36
+
37
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
46
+
47
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
56
+
57
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Returns:
        The true quotient ``a / b`` — a float even for integer inputs,
        so the return annotation is ``float`` (the original ``int``
        annotation was wrong for the ``/`` operator).

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
68
+
69
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
78
+
79
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"wiki_results"`` mapping to the documents
        formatted as ``<Document ...>...</Document>`` blocks separated by
        ``---``. (The function always returned a dict; the original
        ``-> str`` annotation was incorrect.)
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # The opening tag was previously self-closing ("/>") while a closing
    # </Document> tag followed — malformed markup; use a plain ">" here.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n'
        f"{doc.page_content}\n</Document>"
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
92
+
93
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"web_results"`` mapping to the results formatted
        as ``<Document ...>...</Document>`` blocks separated by ``---``.
        (The function always returned a dict; the original ``-> str``
        annotation was incorrect.)
    """
    # Runnable.invoke takes the input positionally; the original
    # `invoke(query=query)` passed an invalid keyword. TavilySearchResults
    # returns a list of dicts with "url"/"content" keys, not Document
    # objects, so the original `doc.metadata[...]` access would fail.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{result.get("url", "")}">\n'
        f'{result.get("content", "")}\n</Document>'
        for result in search_results
    )
    return {"web_results": formatted_search_docs}
106
+
107
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Note: the public name keeps the original "arvix" spelling because the
    module-level ``tools`` list references it.

    Args:
        query: The search query.

    Returns:
        A dict with key ``"arvix_results"`` mapping to the documents
        (page content truncated to 1000 characters each) formatted as
        ``<Document ...>...</Document>`` blocks. (The function always
        returned a dict; the original ``-> str`` annotation was incorrect.)
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # ArxivLoader document metadata exposes keys like "Published"/"Title"
    # and may not contain "source"; use .get() so a missing key cannot
    # raise KeyError mid-formatting. TODO(review): confirm metadata keys
    # against the installed langchain_community version.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" '
        f'page="{doc.metadata.get("page", "")}">\n'
        f"{doc.page_content[:1000]}\n</Document>"
        for doc in search_docs
    )
    return {"arvix_results": formatted_search_docs}
120
+
121
+
122
+
123
# Load the agent's system prompt from a sibling text file so the prompt
# can be edited without touching code. Runs at import time: the file must
# exist in the current working directory or importing this module raises.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message built from the prompt.
# NOTE(review): `sys_msg` is never referenced anywhere in this file —
# confirm it is meant to be prepended to the messages inside build_graph.
sys_msg = SystemMessage(content=system_prompt)

# Tools exposed to the LLM; bound via llm.bind_tools(tools) in build_graph
# and executed by the graph's ToolNode.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
141
+
142
# Build graph function
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph for the chosen LLM provider.

    Args:
        provider: One of "google", "groq" (default) or "huggingface".

    Returns:
        A compiled LangGraph that loops assistant -> tools until the
        assistant stops emitting tool calls.

    Raises:
        ValueError: if ``provider`` is not one of the supported names.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash", temperature=0, rate_limiter=rate_limiter
        )
    elif provider == "groq":
        # Groq model IDs are slash-separated; "qwen-qwen3-32b" is not a
        # valid model name — "qwen/qwen3-32b" is.
        llm = ChatGroq(model="qwen/qwen3-32b", temperature=0, rate_limiter=rate_limiter)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # HuggingFaceEndpoint's parameter is `endpoint_url`, not
                # `url` (the original keyword was rejected/ignored).
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
                rate_limiter=rate_limiter,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    # Bind tools to LLM so it can emit structured tool calls.
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: prepend the system prompt and call the LLM."""
        # Fix: the system prompt loaded into `sys_msg` at module level was
        # never sent to the model; prepend it so the agent actually
        # follows its instructions.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # Route to "tools" when the last message carries tool calls, else END.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
180
+
181
# Ad-hoc smoke test: build the default (groq) graph and run one question.
if __name__ == "__main__":
    question = (
        "When was a picture of St. Thomas Aquinas first added to the "
        "Wikipedia page on the Principle of double effect?"
    )
    graph = build_graph(provider="groq")
    initial_state = {"messages": [HumanMessage(content=question)]}
    result = graph.invoke(initial_state)
    for message in result["messages"]:
        message.pretty_print()