DrekFretson commited on
Commit
3872031
·
verified ·
1 Parent(s): 03caaac

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +53 -350
agent.py CHANGED
@@ -1,353 +1,56 @@
1
 
2
- import os
3
- import time
4
- from langchain.tools import Tool, tool
5
- from typing import Tuple, List
6
- from typing_extensions import TypedDict, Annotated, Optional
7
- from langgraph.graph.message import add_messages
8
- from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
9
- from langgraph.graph import START, StateGraph, END
10
- from langgraph.prebuilt import ToolNode, tools_condition
11
- from langchain_litellm import ChatLiteLLM
12
- from IPython.display import Image, display
13
- import asyncio
14
- from tools import (search_tool,
15
- download_tool,
16
- get_web_page,
17
- add,
18
- subtract,
19
- multiply,
20
- divide,
21
- power,
22
- square_root,
23
- get_information_from_wikipedia,
24
- get_information_from_arxiv,
25
- get_information_from_youtube,
26
- python_tool,
27
- get_information_from_json,
28
- get_information_from_audio,
29
- get_information_from_xml,
30
- get_information_from_docx,
31
- get_information_from_txt,
32
- get_information_from_pdf,
33
- get_information_from_csv,
34
- get_information_from_excel,
35
- get_information_from_pdb,
36
- get_information_from_image,
37
- get_information_from_pptx,
38
- get_all_files_from_zip,
39
- get_information_from_python)
40
-
41
- DELAY = 5
42
- TIME_SLEEP = 60/15 + DELAY
43
-
44
- GEMINI_API_KEY_1 = os.getenv("GOOGLE_API_KEY_1")
45
- GEMINI_API_KEY_2 = os.getenv("GOOGLE_API_KEY_2")
46
- GEMINI_API_KEY_3 = os.getenv("GOOGLE_API_KEY_3")
47
-
48
- chat_model_1 = ChatLiteLLM(model="gemini/gemini-2.0-flash",
49
- temperature=0,
50
- api_key=GEMINI_API_KEY_1,
51
- max_retries=10,
52
- verbose=True)
53
-
54
- chat_model_2 = ChatLiteLLM(model="gemini/gemini-2.0-flash",
55
- temperature=0,
56
- api_key=GEMINI_API_KEY_2,
57
- max_retries=10,
58
- verbose=True)
59
-
60
- chat_model_3 = ChatLiteLLM(model="gemini/gemini-2.0-flash",
61
- temperature=0,
62
- api_key=GEMINI_API_KEY_3,
63
- max_retries=10,
64
- verbose=True)
65
-
66
- class AgentState(TypedDict):
67
- messages: Annotated[list[AnyMessage], add_messages]
68
- question: Optional[str]
69
- file_path: Optional[str]
70
- task_id: Optional[str]
71
- new_messages: Optional[int]
72
- final_answer: Optional[str]
73
- attempt: Optional[int]
74
- chat_model: Optional[int]
75
-
76
- class MyAgent:
77
- def __init__(self, web_tools=None):
78
- print("MyAgent initialized.")
79
-
80
- self.chat_1 = chat_model_1
81
- self.chat_2 = chat_model_2
82
- self.chat_3 = chat_model_3
83
-
84
- self.tools = [search_tool,
85
- download_tool,
86
- get_web_page,
87
- add,
88
- subtract,
89
- multiply,
90
- divide,
91
- power,
92
- square_root,
93
- get_information_from_wikipedia,
94
- get_information_from_arxiv,
95
- get_information_from_youtube,
96
- python_tool,
97
- get_information_from_json,
98
- get_information_from_audio,
99
- get_information_from_xml,
100
- get_information_from_docx,
101
- get_information_from_txt,
102
- get_information_from_pdf,
103
- get_information_from_csv,
104
- get_information_from_excel,
105
- get_information_from_pdb,
106
- get_information_from_image,
107
- get_information_from_pptx,
108
- get_all_files_from_zip] + web_tools
109
-
110
- self.chat_with_tools_1 = self.chat_1.bind_tools(self.tools, verbose=True)
111
- self.chat_with_tools_2 = self.chat_2.bind_tools(self.tools, verbose=True)
112
- self.chat_with_tools_3 = self.chat_3.bind_tools(self.tools, verbose=True)
113
- self.chats = [self.chat_with_tools_1, self.chat_with_tools_2, self.chat_with_tools_3]
114
-
115
- self.builder = StateGraph(AgentState)
116
- self.builder.add_node("assistant", self.assistant)
117
- self.builder.add_node("tools", ToolNode(self.tools))
118
- self.builder.add_node("extract_data_from_file", self.extract_data_from_file)
119
- self.builder.add_node("postprocess", self.postprocess)
120
-
121
- self.builder.add_edge(START, "extract_data_from_file")
122
- self.builder.add_edge("extract_data_from_file", "assistant")
123
- self.builder.add_conditional_edges(
124
- "assistant",
125
- self.assistant_router,
126
- {
127
- "tools": "tools",
128
- "postprocess": "postprocess"
129
- }
130
  )
131
- self.builder.add_edge("tools", "assistant")
132
- self.builder.add_conditional_edges(
133
- "postprocess",
134
- self.answer_evaluation,
135
- {
136
- "RETRY": "assistant",
137
- "END": END
138
- }
139
- )
140
-
141
- self.agent = self.builder.compile()
142
-
143
- async def __call__(self, question: str, file_path: str, task_id: str) -> str:
144
- print("\033[1m\033[93m"+"="*150+"\033[0m")
145
- print(f"QUESTION: {question}")
146
- print(f"File: {file_path}")
147
- prompt = f"""You are a general AI assistant. You will receive a user question and extracted data from associated files.
148
- Follow this process:
149
- 1. Identify the required output type (e.g., number, string, list) and key concepts in the question.
150
- 2. Before using any tools, check if the answer can be deduced or recalled directly. If yes, answer immediately. Never guess.
151
- 3. If tools are needed:
152
- - Create a plan with:
153
- - The reasoning approach and tool sequence.
154
- - A rephrased version of the question optimized for search engines (DuckDuckGo or Google).
155
- - Search queries must:
156
- - Be keyword-focused (avoid full sentences).
157
- - Use advanced operators if helpful: `site:` for domains, `inurl:` for internal paths, `filetype:` for formats.
158
- - Avoid punctuation, commas, quotes, or special characters.
159
- - Cover multiple query angles if needed.
160
- 4. Do not run any tool until the plan is complete.
161
- 5. If a tool fails or returns no useful result:
162
- - Reformulate the query with synonyms or tighter context.
163
- - Retry or use a fallback tool.
164
- 6. Analyze tool results carefully. If multiple source links appear, use `navigate_browser` to explore and extract relevant information from each.
165
- Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
166
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
167
- If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
168
- If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
169
- If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
170
-
171
- user_message = f"""Question: {question}
172
- Filepath: {file_path}"""
173
-
174
- messages = [SystemMessage(content=prompt, name="SYSTEM"),
175
- HumanMessage(content=user_message, name="USER")]
176
-
177
- response = await self.agent.ainvoke({"messages": messages,
178
- "question": question,
179
- "file_path": file_path,
180
- "task_id": task_id,
181
- "new_messages": 1,
182
- "chat_model": 0,
183
- "final_answer": "",
184
- "attempt": 0},
185
- {"recursion_limit": 100})
186
-
187
- print("\033[1m\033[93m"+"="*150+"\033[0m")
188
- return response['final_answer']
189
-
190
- async def call_chat(self, chat, state: AgentState, max_retries=5):
191
- from google.api_core.exceptions import GoogleAPICallError
192
- for i in range(max_retries):
193
- try:
194
- return await chat.ainvoke(state["messages"])
195
- except GoogleAPICallError as e:
196
- if "503" in str(e) or "UNAVAILABLE" in str(e):
197
- wait = 2 ** i
198
- print(f"[Gemini] Overloaded (attempt {i+1}), retrying in {wait:.1f}s...")
199
- await asyncio.sleep(wait)
200
- else:
201
- raise e
202
- raise RuntimeError("Gemini failed after multiple retries")
203
-
204
- async def assistant(self, state: AgentState):
205
- new_messages = state["new_messages"]
206
-
207
- for i in reversed(range(1, new_messages+1)):
208
- print("\033[1m\033[92m"+"+"*150+"\033[0m")
209
- name = state["messages"][-i].name
210
- content = state["messages"][-i].content
211
- print(f'\033[1m\033[96m{name}\033[0m: {content if len(content) < 5000 else content[:5000]}')
212
-
213
- chat = self.chats[state["chat_model"]]
214
- result = await self.call_chat(chat=chat, state=state)
215
-
216
- state["chat_model"] += 1
217
- state["chat_model"] %= len(self.chats)
218
-
219
- result.name="ASSISTANT"
220
- await asyncio.sleep(TIME_SLEEP)
221
-
222
- print("\033[1m\033[92m"+"+"*150+"\033[0m")
223
- content = result.content[:-2] if result.content[-2:] == '\n\n' else result.content
224
- print(f'\033[1m\033[96m{result.name}\033[0m: {content}')
225
- state["new_messages"] = 1
226
- state["messages"].append(result)
227
-
228
- return state
229
-
230
- def extract_data_from_file(self, state: AgentState) -> str:
231
- path = state["file_path"]
232
- new_messages = state["new_messages"]
233
- prompt = ""
234
- messages = []
235
-
236
- if path and "." in path:
237
- ext = path.strip().split(".")[-1].lower()
238
- print(f"Extension detected: {ext}")
239
-
240
- if ext == "zip":
241
- files, prompt = get_all_files_from_zip(path)
242
- name = "get_all_file_from_zip"
243
- messages.append(AIMessage(content=prompt, name=name))
244
- else:
245
- files = [path]
246
-
247
- for file_path in files:
248
- ext = file_path.strip().split(".")[-1].lower()
249
- print(f"Extension detected: {ext}")
250
-
251
- prompt = f"Information extracted from {file_path}.\n\n"
252
- match ext:
253
- case "csv":
254
- content = get_information_from_csv.invoke(file_path)
255
- name = "get_information_from_csv"
256
- case "txt":
257
- content = get_information_from_txt.invoke(file_path)
258
- name = "get_information_from_txt"
259
- case "pdf":
260
- content = get_information_from_pdf.invoke(file_path)
261
- name = "get_information_from_pdf"
262
- case "json":
263
- content = get_information_from_json.invoke(file_path)
264
- name = "get_information_from_json"
265
- case "jsonld":
266
- content = get_information_from_json.invoke(file_path)
267
- name = "get_information_from_json"
268
- case "xml":
269
- content = get_information_from_xml.invoke(file_path)
270
- name = "get_information_from_xml"
271
- case "pdb":
272
- content = get_information_from_pdb.invoke(file_path)
273
- name = "get_information_from_pdb"
274
- case "mp3":
275
- content = get_information_from_audio.invoke(file_path)
276
- name = "get_information_from_audio"
277
- case "m4a":
278
- content = get_information_from_audio.invoke(file_path)
279
- name = "get_information_from_audio"
280
- case "docx":
281
- content = get_information_from_docx.invoke(file_path)
282
- name = "get_information_from_docx"
283
- case "xlsx":
284
- content = get_information_from_excel.invoke(file_path)
285
- name = "get_information_from_excel"
286
- case "xls":
287
- content = get_information_from_excel.invoke(file_path)
288
- name = "get_information_from_excel"
289
- case "png":
290
- content = get_information_from_image.invoke({"file_path": file_path, "question": state["question"]})
291
- name = "get_information_from_image"
292
- case "jpg":
293
- content = get_information_from_image.invoke({"file_path": file_path, "question": state["question"]})
294
- name = "get_information_from_image"
295
- case "py":
296
- content = get_information_from_python.invoke(file_path)
297
- name = "get_information_from_python"
298
- case "pptx":
299
- content = get_information_from_pptx.invoke(file_path)
300
- name = "get_information_from_pptx"
301
- case _:
302
- content = "Try to use some available tool to answer the user question."
303
- name = "handle_no_file"
304
- prompt += f"{content}"
305
- messages.append(AIMessage(content=prompt, name=name))
306
- new_messages += 1
307
- else:
308
- prompt = "The question doesn't have an attached file."
309
- name = "handle_no_file"
310
-
311
- return {"messages": messages, "new_messages": new_messages}
312
-
313
- def assistant_router(self, state: AgentState) -> str:
314
- tool_decision = tools_condition(state)
315
- if tool_decision == "tools":
316
- return "tools"
317
- else:
318
- return "postprocess"
319
-
320
- def postprocess(self, state: AgentState) -> AgentState:
321
- last_msg = state["messages"][-1]
322
- content = last_msg.content
323
- index = content.find("FINAL ANSWER: ")
324
- if index != -1:
325
- content = content[index+len("FINAL ANSWER: "):].replace("\n", "")
326
- state["final_answer"] = content
327
- return state
328
- else:
329
- state["attempt"] += 1
330
- prompt = f"""You were unable to find a satisfactory answer to the user's question.
331
- Now, try again, but use a different approach. You may:
332
- - Focus on a different angle of the question,
333
- - Reformulate it using alternative terminology,
334
- - Search for related concepts,
335
- - Or use a different reasoning path.
336
- Be creative and precise. Your goal is to uncover useful information that may have been missed previously.
337
- Original question:
338
- {state["question"]}"""
339
-
340
- state["messages"].append(AIMessage(content=prompt, name="ASSISTANT"))
341
- return state
342
- def answer_evaluation(self, state: AgentState):
343
- if state["final_answer"] != "":
344
- return "END"
345
- elif state["attempt"] >= 3:
346
- state["final_answer"] = "Unable to find the answer."
347
- return "END"
348
- else:
349
- return "RETRY"
350
 
351
- def draw_graph(self):
352
- display(Image(self.agent.get_graph().draw_mermaid_png()))
353
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ from typing import Optional
3
+
4
+ from smolagents import (
5
+ CodeAgent,
6
+ DuckDuckGoSearchTool,
7
+ InferenceClientModel,
8
+ VisitWebpageTool,
9
+ )
10
+
11
+ from tools import describe_image, transcribe_mp3, extract_data
12
+
13
+
14
+ class SmolAgent:
15
+ def __init__(self):
16
+ print("SmolAgent initialized.")
17
+ # Initialize a simple CodeAgent with a DuckDuckGo search tool
18
+ self.search_tool = DuckDuckGoSearchTool()
19
+ self.visit_web_tool = VisitWebpageTool()
20
+
21
+ self.agent = CodeAgent(
22
+ name="SmolAgent",
23
+ description="An agent that can solve GAIA challenges using web search and code execution.",
24
+ tools=[
25
+ self.search_tool,
26
+ self.visit_web_tool,
27
+ describe_image,
28
+ transcribe_mp3,
29
+ extract_data,
30
+ ],
31
+ add_base_tools=True,
32
+ model=InferenceClientModel(), # or another available model
33
+ additional_authorized_imports=["requests", "json", "pandas", "numpy"],
34
+ max_steps=5,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def run(self, question: str, file_path: Optional[str] = None) -> str:
38
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
39
+ # Use the CodeAgent to answer the question
40
+
41
+ file_prompt = ""
42
+ if file_path:
43
+ file_prompt = f"You can find the provided fiel at {file_path}"
44
+
45
+ prompt = f"""
46
+ You are a general AI assistant. I will ask you a question. And I want you to reply with just your final answer.
47
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
48
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
49
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
50
+ If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
51
+ Question: {question}
52
+ {file_prompt}
53
+ """
54
+ answer = self.agent.run(prompt)
55
+ print(f"Agent returning answer: {answer}")
56
+ return answer