vandenn committed on
Commit
655d939
·
1 Parent(s): 170a428

Separate out the file parsing and add image handling

Browse files
Files changed (2) hide show
  1. src/agent.py +78 -6
  2. src/tools.py +0 -62
src/agent.py CHANGED
@@ -1,7 +1,12 @@
 
1
  import random
2
  import time
3
- from typing import Any
 
4
 
 
 
 
5
  from smolagents import (
6
  DuckDuckGoSearchTool,
7
  LiteLLMModel,
@@ -12,10 +17,15 @@ from smolagents import (
12
  from smolagents.agents import FinalAnswerStep
13
 
14
  from src.settings import settings
15
- from src.tools import DownloadAndParseFileTool, FinalAnswerTool
16
  from src.utils import BaseAgent, InputTokenRateLimiter
17
 
18
 
 
 
 
 
 
19
  class GaiaAgent(BaseAgent):
20
  def __init__(self):
21
  self.model = LiteLLMModel(
@@ -25,7 +35,6 @@ class GaiaAgent(BaseAgent):
25
  tools=[
26
  DuckDuckGoSearchTool(max_results=3),
27
  VisitWebpageTool(max_output_length=20000),
28
- DownloadAndParseFileTool(),
29
  PythonInterpreterTool(),
30
  FinalAnswerTool(),
31
  # TODO: Image interpretation, MP3 interpretation
@@ -43,18 +52,24 @@ class GaiaAgent(BaseAgent):
43
  final_answer = None
44
  retry_count = 0
45
 
 
46
  input = f"""
47
- Answer the following QUESTION as concisely as possible. A necessary FILE may be provided as part of the context of the QUESTION.
 
48
  Make the shortest possible execution plan to answer this QUESTION.
49
 
50
  QUESTION: {question}
51
  FILE NAME: {file_name if file_name else "N/A"}
52
- FILE URL: {file_url if file_url else "N/A"}
53
  """
 
 
 
 
 
54
 
55
  while True:
56
  try:
57
- for step in self.agent.run(input, stream=True):
58
  self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
59
  token_usage_info = getattr(step, "token_usage", None)
60
  tokens_used = 0
@@ -86,6 +101,63 @@ class GaiaAgent(BaseAgent):
86
 
87
  return final_answer
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  if __name__ == "__main__":
91
  agent = GaiaAgent()
 
1
+ import mimetypes
2
  import random
3
  import time
4
+ from io import BytesIO
5
+ from typing import Any, TypedDict
6
 
7
+ import pandas as pd
8
+ import requests
9
+ from PIL import Image
10
  from smolagents import (
11
  DuckDuckGoSearchTool,
12
  LiteLLMModel,
 
17
  from smolagents.agents import FinalAnswerStep
18
 
19
  from src.settings import settings
20
+ from src.tools import FinalAnswerTool
21
  from src.utils import BaseAgent, InputTokenRateLimiter
22
 
23
 
24
class ParsedFile(TypedDict):
    """Result of parsing a question attachment.

    Both keys are always present; a key holds ``None`` when the attachment
    did not produce that kind of content.
    """

    # The annotations are quoted so they are not evaluated when the class is
    # created, which keeps the ``X | None`` spelling usable on interpreters
    # that predate that syntax.

    # Decoded content of a text-like attachment, or None.
    text: "str | None"
    # Opened picture of an image attachment, or None. ``Image`` is the PIL
    # module, so the instance type is ``Image.Image``.
    image: "Image.Image | None"
27
+
28
+
29
  class GaiaAgent(BaseAgent):
30
  def __init__(self):
31
  self.model = LiteLLMModel(
 
35
  tools=[
36
  DuckDuckGoSearchTool(max_results=3),
37
  VisitWebpageTool(max_output_length=20000),
 
38
  PythonInterpreterTool(),
39
  FinalAnswerTool(),
40
  # TODO: Image interpretation, MP3 interpretation
 
52
  final_answer = None
53
  retry_count = 0
54
 
55
+ parsed_file = self._parse_file(file_name, file_url)
56
  input = f"""
57
+ Answer the following QUESTION as concisely as possible.
58
+ If available, a FILE NAME and the actual FILE is attached for your reference.
59
  Make the shortest possible execution plan to answer this QUESTION.
60
 
61
  QUESTION: {question}
62
  FILE NAME: {file_name if file_name else "N/A"}
 
63
  """
64
+ if parsed_file["text"]:
65
+ input = input + f"\nFILE CONTENT: {parsed_file['text']}"
66
+ input_images = None
67
+ if parsed_file["image"]:
68
+ input_images = [parsed_file["image"]]
69
 
70
  while True:
71
  try:
72
+ for step in self.agent.run(input, images=input_images, stream=True):
73
  self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
74
  token_usage_info = getattr(step, "token_usage", None)
75
  tokens_used = 0
 
101
 
102
  return final_answer
103
 
104
def _parse_file(self, file_name: str, file_url: str) -> ParsedFile:
    """Download a question attachment and extract its text or image.

    Every failure path (missing arguments, download error, "no file" reply,
    undecodable or unsupported content) is reported with ``print`` and
    yields a ``ParsedFile`` whose fields are both ``None``. The caller can
    therefore always index the result with ``["text"]`` and ``["image"]``.

    Args:
        file_name: Name of the attachment; it selects which parser is used.
        file_url: URL the attachment is downloaded from.

    Returns:
        A ``ParsedFile`` with ``text`` set for text, ``.py`` and ``.xlsx``
        files, ``image`` set for image files, and both ``None`` otherwise.
    """
    result = ParsedFile(text=None, image=None)
    if not file_name or not file_url:
        return result

    try:
        # A timeout keeps an unresponsive server from hanging the whole run.
        response = requests.get(file_url, timeout=30)
        response.raise_for_status()
    except Exception as e:
        print(f"Failed to download file: {e}")
        return result

    # Try to handle the 'no file' JSON case.
    try:
        file_data = response.json()
    except ValueError:
        file_data = None  # Not JSON, so it's probably the file content
    if isinstance(file_data, dict):
        detail = file_data.get("detail")
        if isinstance(detail, str) and "No file path associated" in detail:
            print(f"No file found for {file_name} at {file_url}")
            return result

    content = response.content
    lowered_name = file_name.lower()
    file_type, _ = mimetypes.guess_type(file_name)
    is_text = bool(file_type) and file_type.startswith("text")
    is_image = bool(file_type) and file_type.startswith("image")

    try:
        # The explicit ".py" check covers systems whose MIME table does not
        # classify Python sources as text.
        if is_text or lowered_name.endswith(".py"):
            result["text"] = content.decode("utf-8")
        elif lowered_name.endswith(".xlsx"):
            result["text"] = pd.read_excel(BytesIO(content)).to_string()
        elif is_image:
            result["image"] = Image.open(BytesIO(content))
        else:
            print(
                f"[{file_name} is a binary file of type {file_type or 'unknown'} and cannot be parsed as text.]"
            )
    except Exception as e:
        # A failed parser leaves both fields as None; report the failure and
        # still hand back a ParsedFile, never an error string.
        print(f"Failed to parse {file_name}: {e}")
    return result
160
+
161
 
162
  if __name__ == "__main__":
163
  agent = GaiaAgent()
src/tools.py CHANGED
@@ -1,8 +1,3 @@
1
- import mimetypes
2
- from io import BytesIO
3
-
4
- import pandas as pd
5
- import requests
6
  from smolagents import LiteLLMModel
7
  from smolagents.tools import Tool
8
 
@@ -62,60 +57,3 @@ class FinalAnswerTool(Tool):
62
  tokens_used = token_usage_info.input_tokens
63
  self.token_rate_limiter.add_tokens(tokens_used)
64
  return response.content
65
-
66
-
67
- class DownloadAndParseFileTool(Tool):
68
- name = "download_and_parse_file"
69
- description = "Downloads a file from a given URL and parses it based on the file name. Returns the file content as text if possible, or nothing if image, etc."
70
- inputs = {
71
- "file_name": {
72
- "type": "string",
73
- "description": "The name of the file (used to determine type).",
74
- },
75
- "file_url": {
76
- "type": "string",
77
- "description": "The URL of the file to download.",
78
- },
79
- }
80
- output_type = "string"
81
-
82
- def __init__(self):
83
- self.is_initialized = True
84
-
85
- def forward(self, file_name: str, file_url: str) -> str:
86
- try:
87
- response = requests.get(file_url)
88
- response.raise_for_status()
89
- except Exception as e:
90
- return f"Failed to download file: {e}"
91
-
92
- # Try to handle the 'no file' JSON case
93
- try:
94
- file_data = response.json()
95
- if (
96
- "detail" in file_data
97
- and "No file path associated" in file_data["detail"]
98
- ):
99
- return f"No file found for {file_name} at {file_url}"
100
- except Exception:
101
- pass # Not JSON, so it's probably the file content
102
-
103
- file_type, _ = mimetypes.guess_type(file_name)
104
- if file_type and file_type.startswith("text"):
105
- try:
106
- return response.content.decode("utf-8")
107
- except Exception:
108
- return "Failed to decode text file as utf-8."
109
- elif file_name.endswith(".py"):
110
- try:
111
- return response.content.decode("utf-8")
112
- except Exception:
113
- return "Failed to decode Python file as utf-8."
114
- elif file_name.endswith(".xlsx"):
115
- try:
116
- df = pd.read_excel(BytesIO(response.content))
117
- return df.to_string()
118
- except Exception as e:
119
- return f"Failed to parse Excel file: {e}"
120
- else:
121
- return f"[{file_name} is a binary file of type {file_type or 'unknown'} and cannot be parsed as text.]"
 
 
 
 
 
 
1
  from smolagents import LiteLLMModel
2
  from smolagents.tools import Tool
3
 
 
57
  tokens_used = token_usage_info.input_tokens
58
  self.token_rate_limiter.add_tokens(tokens_used)
59
  return response.content