Abdenour Chaoui commited on
Commit
ead3819
Β·
1 Parent(s): 81917a3

add agent

Browse files
Files changed (1) hide show
  1. agent.py +359 -0
agent.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from langchain_core.messages import HumanMessage, SystemMessage
3
+ from langchain.tools import tool
4
+
5
+ from langchain_community.document_loaders import WikipediaLoader,ArxivLoader
6
+
7
+ from tavily import TavilyClient
8
+
9
+ from openai import OpenAI
10
+ import base64
11
+ import re
12
+ import os
13
+
14
+
15
+ from typing import TypedDict, Annotated, Literal
16
+
17
+ from langchain_core.messages import (
18
+ AnyMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage
19
+ )
20
+
21
+ from langgraph.graph.message import add_messages
22
+ from langgraph.graph import StateGraph, END
23
+
24
+
25
+
26
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
27
+ TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
28
+
29
+
30
+ tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
31
+ openai_client = OpenAI(api_key=OPENAI_API_KEY)
32
+
33
+ MAX_STEPS = 15
34
+
35
+ @tool
36
+ def search_wikipedia(query: str, max_docs: int = 3) -> str:
37
+ """Search Wikipedia for general knowledge and return summarized content.
38
+
39
+ Args:
40
+ query: Topic to search (e.g., 'Artificial Intelligence', 'France history')
41
+ max_docs: Maximum number of Wikipedia pages to retrieve
42
+ """
43
+ loader = WikipediaLoader(query=query, load_max_docs=max_docs)
44
+ docs = loader.load()
45
+ return "\n\n".join(doc.page_content[:3000] for doc in docs)
46
+
47
+
48
+ @tool
49
+ def search_arxiv(query: str, max_docs: int = 3) -> str:
50
+ """Search arXiv for scientific papers and return summaries.
51
+
52
+ Args:
53
+ query: Research topic or keywords (e.g., 'transformer attention')
54
+ max_docs: Maximum number of papers to retrieve
55
+ """
56
+ loader = ArxivLoader(query=query, load_max_docs=max_docs)
57
+ docs = loader.load()
58
+ return "\n\n".join(doc.page_content[:3000] for doc in docs)
59
+
60
+
61
+ @tool
62
+ def search_web(query: str, max_results: int = 5) -> str:
63
+ """Search the web for up-to-date information.
64
+
65
+ Args:
66
+ query: Search query (e.g., 'latest OpenAI model 2025')
67
+ max_results: Number of results to return
68
+ """
69
+ response = tavily_client.search(query=query, max_results=max_results)
70
+ results = [f"{r['title']}\n{r['content']}" for r in response["results"]]
71
+ return "\n\n".join(results)
72
+
73
+
74
+
75
+ @tool
76
+ def transcribe_audio(file_path: str) -> str:
77
+ """Transcribe an audio file (mp3, wav) into text.
78
+
79
+ Args:
80
+ file_path: Path to the audio file on disk
81
+ """
82
+ with open(file_path, "rb") as f:
83
+ transcript = openai_client.audio.transcriptions.create(
84
+ model="whisper-1",
85
+ file=f,
86
+ )
87
+ return transcript.text
88
+
89
+
90
+ @tool
91
+ def read_image(file_path: str) -> str:
92
+ """Read an image file and return a description via GPT-4o vision.
93
+
94
+ Args:
95
+ file_path: Path to the image file on disk
96
+ """
97
+ with open(file_path, "rb") as f:
98
+ b64 = base64.b64encode(f.read()).decode("utf-8")
99
+ ext = file_path.rsplit(".", 1)[-1].lower()
100
+ mime = {"jpg": "image/jpeg", "jpeg": "image/jpeg",
101
+ "png": "image/png", "gif": "image/gif",
102
+ "webp": "image/webp"}.get(ext, "image/png")
103
+ response = openai_client.chat.completions.create(
104
+ model="gpt-4o",
105
+ messages=[
106
+ {
107
+ "role": "user",
108
+ "content": [
109
+ {"type": "image_url",
110
+ "image_url": {"url": f"data:{mime};base64,{b64}"}},
111
+ {"type": "text",
112
+ "text": "Describe this image in detail. Extract any text, data, or key information visible."},
113
+ ],
114
+ }
115
+ ],
116
+ max_tokens=1024,
117
+ )
118
+ return response.choices[0].message.content
119
+
120
+ @tool
121
+ def read_file(file_path: str) -> str:
122
+ """Read a file and return its contents."""
123
+ with open(file_path, "r", encoding="utf-8") as f:
124
+ return f.read()
125
+
126
+ @tool
127
+ def python_repl(code: str) -> str:
128
+ """Execute Python code and return stdout + the value of the last expression.
129
+ Useful for arithmetic, data manipulation, and logic tasks.
130
+
131
+ Args:
132
+ code: Valid Python code string
133
+ """
134
+ import io, sys, traceback
135
+ stdout_capture = io.StringIO()
136
+ local_vars: dict = {}
137
+ try:
138
+ sys.stdout = stdout_capture
139
+ exec(code, {}, local_vars) # run all lines
140
+ # try to eval last line as expression
141
+ lines = [l for l in code.strip().splitlines() if l.strip()]
142
+ last_val = ""
143
+ if lines:
144
+ try:
145
+ last_val = repr(eval(lines[-1], {}, local_vars))
146
+ except Exception:
147
+ pass
148
+ except Exception:
149
+ return traceback.format_exc()
150
+ finally:
151
+ sys.stdout = sys.__stdout__
152
+ out = stdout_capture.getvalue()
153
+ return "\n".join(filter(None, [out, last_val])) or "Code executed successfully (no output)."
154
+
155
+
156
+
157
+
158
+ TOOLS = [
159
+ search_wikipedia,
160
+ search_arxiv,
161
+ search_web,
162
+ transcribe_audio,
163
+ read_image,
164
+ read_file,
165
+ python_repl,
166
+ ]
167
+
168
+ TOOL_MAP = {t.name: t for t in TOOLS}
169
+
170
+
171
+ SYSTEM_PROMPT = f"""You are a highly capable AI assistant solving tasks from the GAIA benchmark.
172
+
173
+ ## Core rules (MUST follow)
174
+ 1. THINK before acting: decompose the question and plan which tool(s) you need.
175
+ 2. NEVER call the same tool with the exact same arguments twice.
176
+ If the result was insufficient, use a DIFFERENT query or a DIFFERENT tool.
177
+ 3. If search_wikipedia returns a biography page instead of a discography/list,
178
+ immediately switch to search_web with a more specific query.
179
+ 4. For calculations / counting, always use python_repl β€” never guess numbers.
180
+ 5. Once you have enough information, STOP calling tools and give the final answer.
181
+ 6. You have at most {MAX_STEPS} tool-call rounds total. Budget them wisely.
182
+
183
+ ## Tool selection guide
184
+ - General facts / biography β†’ search_wikipedia (vary query if first try fails)
185
+ - Discographies, filmographies, lists β†’ search_web (Wikipedia tool may miss these)
186
+ - Current events / live data β†’ search_web
187
+ - Scientific papers β†’ search_arxiv
188
+ - Arithmetic / logic β†’ python_repl
189
+ - Provided image file β†’ read_image
190
+ - Provided audio file β†’ transcribe_audio
191
+ - Provided text/csv/json β†’ read_file
192
+
193
+ ## Answer format
194
+ End your FINAL response with exactly:
195
+ FINAL ANSWER: <your answer>
196
+
197
+ Keep it concise β€” no units unless asked, lists comma-separated.
198
+ """
199
+
200
+
201
+ class AgentState(TypedDict):
202
+ messages: Annotated[list[AnyMessage], add_messages]
203
+ step_count: int # counts agent_node invocations
204
+
205
+
206
+ def make_llm(model: str = "gpt-5.4-mini") -> ChatOpenAI:
207
+ return ChatOpenAI(
208
+ model=model,
209
+ temperature=0,
210
+ api_key=OPENAI_API_KEY,
211
+ ).bind_tools(TOOLS)
212
+
213
+
214
+ llm_with_tools = make_llm()
215
+
216
+ _step = 0 # console display counter
217
+
218
+ CYAN = "\033[96m"
219
+ GREEN = "\033[92m"
220
+ YELLOW = "\033[93m"
221
+ RED = "\033[91m"
222
+ BOLD = "\033[1m"
223
+ RESET = "\033[0m"
224
+
225
+
226
+ def _log(label: str, text: str, color: str = RESET) -> None:
227
+ print(f"{color}{'─'*60}{RESET}")
228
+ print(f"{color}[Step {_step}] {label}{RESET}")
229
+ if text.strip():
230
+ print(f"{color}{text.strip()}{RESET}")
231
+
232
+
233
+ def agent_node(state: AgentState) -> AgentState:
234
+ global _step
235
+ _step += 1
236
+ step_count = state.get("step_count", 0) + 1
237
+
238
+ messages = state["messages"]
239
+
240
+ # Inject system prompt on first turn
241
+ if not any(isinstance(m, SystemMessage) for m in messages):
242
+ messages = [SystemMessage(content=SYSTEM_PROMPT)] + messages
243
+
244
+ # Warn model to wrap up when approaching the limit
245
+ if step_count >= MAX_STEPS - 2:
246
+ messages = list(messages) + [HumanMessage(
247
+ content=f"⚠️ You have used {step_count}/{MAX_STEPS} steps. "
248
+ "Do NOT call any more tools. Synthesise what you have and give FINAL ANSWER now."
249
+ )]
250
+
251
+ _log("πŸ€– AGENT THINKING …", "", CYAN)
252
+ response = llm_with_tools.invoke(messages)
253
+
254
+ if response.content:
255
+ _log("πŸ€– AGENT RESPONSE", str(response.content)[:600], CYAN)
256
+
257
+ if response.tool_calls:
258
+ calls_summary = "\n".join(
259
+ f" β€’ {tc['name']}({', '.join(f'{k}={repr(v)}' for k, v in tc['args'].items())})"
260
+ for tc in response.tool_calls
261
+ )
262
+ _log("πŸ”§ TOOL CALLS PLANNED", calls_summary, YELLOW)
263
+ else:
264
+ _log("βœ… AGENT FINISHED (no more tool calls)", "", GREEN)
265
+
266
+ return {"messages": [response], "step_count": step_count}
267
+
268
+
269
+
270
+ def tool_node(state: AgentState) -> AgentState:
271
+ global _step
272
+ last_msg: AIMessage = state["messages"][-1]
273
+ tool_results: list[ToolMessage] = []
274
+
275
+ for tc in last_msg.tool_calls:
276
+ _step += 1
277
+ tool_fn = TOOL_MAP.get(tc["name"])
278
+ _log(f"βš™οΈ RUNNING: {tc['name']}",
279
+ "\n".join(f" {k}: {repr(v)}" for k, v in tc["args"].items()),
280
+ YELLOW)
281
+
282
+ if tool_fn is None:
283
+ result = f"ERROR: unknown tool '{tc['name']}'"
284
+ _log("❌ TOOL ERROR", result, RED)
285
+ else:
286
+ try:
287
+ result = tool_fn.invoke(tc["args"])
288
+ preview = str(result)[:500] + ("…" if len(str(result)) > 500 else "")
289
+ _log(f"πŸ“₯ RESULT: {tc['name']}", preview, GREEN)
290
+ except Exception as exc:
291
+ result = f"ERROR calling {tc['name']}: {exc}"
292
+ _log(f"❌ TOOL ERROR: {tc['name']}", result, RED)
293
+
294
+ tool_results.append(
295
+ ToolMessage(content=str(result), tool_call_id=tc["id"])
296
+ )
297
+
298
+ return {"messages": tool_results}
299
+
300
+
301
+
302
+ def should_continue(state: AgentState) -> Literal["tools", "end"]:
303
+ step_count = state.get("step_count", 0)
304
+
305
+ if step_count >= MAX_STEPS:
306
+ print(f"{RED}{'─'*60}")
307
+ print(f"β›” MAX_STEPS ({MAX_STEPS}) reached β€” forcing end.{RESET}")
308
+ return "end"
309
+
310
+ last = state["messages"][-1]
311
+ if isinstance(last, AIMessage) and last.tool_calls:
312
+ return "tools"
313
+
314
+ return "end"
315
+
316
+
317
+ def build_graph() -> StateGraph:
318
+ g = StateGraph(AgentState)
319
+ g.add_node("agent", agent_node)
320
+ g.add_node("tools", tool_node)
321
+ g.set_entry_point("agent")
322
+ g.add_conditional_edges("agent", should_continue, {"tools": "tools", "end": END})
323
+ g.add_edge("tools", "agent") # always return to agent after tool use
324
+ return g.compile()
325
+
326
+
327
+ graph = build_graph()
328
+
329
+
330
+ def run_agent(question: str, file_path: str | None = None) -> str:
331
+ """Run the agent on a GAIA question and return the extracted final answer."""
332
+ global _step
333
+ _step = 0
334
+
335
+ print(f"\n{BOLD}{'═'*60}{RESET}")
336
+ print(f"{BOLD}❓ QUESTION: {question}{RESET}")
337
+ if file_path:
338
+ print(f"{BOLD}πŸ“Ž FILE: {file_path}{RESET}")
339
+ print(f"{BOLD}{'═'*60}{RESET}\n")
340
+
341
+ content = question
342
+ if file_path:
343
+ content += f"\n\n[Attached file available at: {file_path}]"
344
+
345
+ result = graph.invoke({
346
+ "messages": [HumanMessage(content=content)],
347
+ "step_count": 0,
348
+ })
349
+
350
+ last_msg = result["messages"][-1]
351
+ text = last_msg.content if isinstance(last_msg, AIMessage) else str(last_msg)
352
+
353
+ match = re.search(r"FINAL ANSWER:\s*(.+)", text, re.IGNORECASE | re.DOTALL)
354
+ answer = match.group(1).strip() if match else text.strip()
355
+
356
+ print(f"\n{BOLD}{GREEN}{'═'*60}{RESET}")
357
+ print(f"{BOLD}{GREEN}🏁 FINAL ANSWER: {answer}{RESET}")
358
+ print(f"{BOLD}{GREEN}{'═'*60}{RESET}\n")
359
+ return answer