Frazer2810 commited on
Commit
d4f42f4
·
verified ·
1 Parent(s): c0f663a

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +236 -0
agent.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent with OpenAI"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition, ToolNode
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_community.document_loaders import WikipediaLoader
8
+ from langchain_community.document_loaders import ArxivLoader
9
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
10
+ from langchain_core.tools import tool
11
+
12
+ load_dotenv()
13
+
14
+ # Tools definition
15
+ @tool
16
+ def multiply(a: int, b: int) -> int:
17
+ """Multiply two numbers.
18
+
19
+ Args:
20
+ a: first int
21
+ b: second int
22
+ """
23
+ return a * b
24
+
25
+ @tool
26
+ def add(a: int, b: int) -> int:
27
+ """Add two numbers.
28
+
29
+ Args:
30
+ a: first int
31
+ b: second int
32
+ """
33
+ return a + b
34
+
35
+ @tool
36
+ def subtract(a: int, b: int) -> int:
37
+ """Subtract two numbers.
38
+
39
+ Args:
40
+ a: first int
41
+ b: second int
42
+ """
43
+ return a - b
44
+
45
+ @tool
46
+ def divide(a: int, b: int) -> float:
47
+ """Divide two numbers.
48
+
49
+ Args:
50
+ a: first int
51
+ b: second int
52
+ """
53
+ if b == 0:
54
+ raise ValueError("Cannot divide by zero.")
55
+ return a / b
56
+
57
+ @tool
58
+ def modulus(a: int, b: int) -> int:
59
+ """Get the modulus of two numbers.
60
+
61
+ Args:
62
+ a: first int
63
+ b: second int
64
+ """
65
+ return a % b
66
+
67
+ @tool
68
+ def wiki_search(query: str) -> str:
69
+ """Search Wikipedia for a query and return maximum 2 results.
70
+
71
+ Args:
72
+ query: The search query.
73
+ """
74
+ try:
75
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
76
+ formatted_search_docs = "\n\n---\n\n".join(
77
+ [
78
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:2000]}\n</Document>'
79
+ for doc in search_docs
80
+ ])
81
+ return formatted_search_docs
82
+ except Exception as e:
83
+ return f"Error searching Wikipedia: {str(e)}"
84
+
85
+
86
+
87
+ @tool
88
+ def arxiv_search(query: str) -> str:
89
+ """Search Arxiv for a query and return maximum 3 results.
90
+
91
+ Args:
92
+ query: The search query.
93
+ """
94
+ try:
95
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
96
+ formatted_search_docs = "\n\n---\n\n".join(
97
+ [
98
+ f'<Document title="{doc.metadata.get("Title", "")}" authors="{doc.metadata.get("Authors", "")}"/>\n{doc.page_content[:1500]}\n</Document>'
99
+ for doc in search_docs
100
+ ])
101
+ return formatted_search_docs
102
+ except Exception as e:
103
+ return f"Error searching Arxiv: {str(e)}"
104
+
105
+ # System prompt
106
+ system_prompt = """You are a helpful AI assistant with access to various tools for calculations and information retrieval.
107
+ You can perform mathematical operations, search Wikipedia, and search academic papers on Arxiv.
108
+ Always try to provide accurate, concise, and helpful responses based on the tools available to you.
109
+ When searching for information, be thorough but concise in your final answer.
110
+ If a question requires multiple steps or tools, break it down and use the appropriate tools in sequence."""
111
+
112
+ # Tools list
113
+ tools = [
114
+ multiply,
115
+ add,
116
+ subtract,
117
+ divide,
118
+ modulus,
119
+ wiki_search,
120
+ arxiv_search,
121
+ ]
122
+
123
+ class LangGraphAgent:
124
+ """LangGraph Agent with OpenAI that can be used in HuggingFace Space evaluation"""
125
+
126
+ def __init__(self):
127
+ """Initialize the agent with OpenAI LLM and tools"""
128
+ print("Initializing LangGraphAgent...")
129
+
130
+ # Get API key from environment
131
+ self.api_key = os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI_API_KEY")
132
+ if not self.api_key:
133
+ raise ValueError("OPENAI_KEY environment variable is required")
134
+
135
+ # Initialize the graph
136
+ self.graph = self._build_graph()
137
+ print("LangGraphAgent initialized successfully.")
138
+
139
+ def _build_graph(self):
140
+ """Build the LangGraph workflow"""
141
+ # Initialize OpenAI LLM
142
+ llm = ChatOpenAI(
143
+ model="gpt-4.1",
144
+ temperature=0.0,
145
+ api_key=self.api_key
146
+ )
147
+
148
+ # Bind tools to LLM
149
+ llm_with_tools = llm.bind_tools(tools)
150
+
151
+ # System message
152
+ sys_msg = SystemMessage(content=system_prompt)
153
+
154
+ # Node functions
155
+ def assistant(state: MessagesState):
156
+ """Assistant node"""
157
+ # Ensure system message is included
158
+ messages = state["messages"]
159
+ if not any(isinstance(msg, SystemMessage) for msg in messages):
160
+ messages = [sys_msg] + messages
161
+
162
+ response = llm_with_tools.invoke(messages)
163
+ return {"messages": [response]}
164
+
165
+ # Build the graph
166
+ builder = StateGraph(MessagesState)
167
+
168
+ # Add nodes
169
+ builder.add_node("assistant", assistant)
170
+ builder.add_node("tools", ToolNode(tools))
171
+
172
+ # Add edges
173
+ builder.add_edge(START, "assistant")
174
+ builder.add_conditional_edges(
175
+ "assistant",
176
+ tools_condition,
177
+ )
178
+ builder.add_edge("tools", "assistant")
179
+
180
+ # Compile and return
181
+ return builder.compile()
182
+
183
+ def __call__(self, question: str) -> str:
184
+ """
185
+ Process a question and return an answer.
186
+
187
+ Args:
188
+ question: The question to answer
189
+
190
+ Returns:
191
+ str: The answer to the question
192
+ """
193
+ print(f"Agent received question (first 100 chars): {question[:100]}...")
194
+
195
+ try:
196
+ # Create message
197
+ messages = [HumanMessage(content=question)]
198
+
199
+ # Invoke the graph
200
+ result = self.graph.invoke({"messages": messages})
201
+
202
+ # Extract the final answer
203
+ ai_messages = [msg for msg in result["messages"] if isinstance(msg, AIMessage)]
204
+
205
+ if ai_messages:
206
+ answer = ai_messages[-1].content
207
+ print(f"Agent returning answer (first 100 chars): {answer[:100]}...")
208
+ return answer
209
+ else:
210
+ return "I couldn't generate a response. Please try again."
211
+
212
+ except Exception as e:
213
+ print(f"Error processing question: {e}")
214
+ return f"Error: {str(e)}"
215
+
216
+ # For backwards compatibility and testing
217
+ BasicAgent = LangGraphAgent
218
+
219
+ if __name__ == "__main__":
220
+ # Test the agent
221
+ print("Testing LangGraphAgent...")
222
+ try:
223
+ agent = LangGraphAgent()
224
+ test_questions = [
225
+ "What is 15 * 23?",
226
+ "Search Wikipedia for information about quantum computing",
227
+ "What are the latest developments in AI according to recent papers on Arxiv?",
228
+ ]
229
+
230
+ for question in test_questions:
231
+ print(f"\nQuestion: {question}")
232
+ answer = agent(question)
233
+ print(f"Answer: {answer}")
234
+
235
+ except Exception as e:
236
+ print(f"Error during testing: {e}")