Gabandino commited on
Commit
e552e13
·
verified ·
1 Parent(s): 15c7bda

Fix quotes tpyo in get_attachments_tool for final submission

Browse files
Files changed (1) hide show
  1. tools/get_attachments_tool.py +75 -75
tools/get_attachments_tool.py CHANGED
@@ -1,76 +1,76 @@
1
- from smolagents import Tool
2
- import requests
3
- from urllib.parse import urljoin
4
- import base64
5
- import tempfile
6
-
7
- class GetAttachmentTool(Tool):
8
- name = "get_attachment"
9
- description = """
10
- Format to retrieve attachment. Options are: URL (preferred), DATA_URL, LOCAL_FILE_PATH, TEXT. URL returns the URL of the file, DATA_URL returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file path to the downloaded file, and TEXT returns the content of the file as text.
11
- """
12
- inputs = {
13
- "fmt": {
14
- "type": "string",
15
- "description": """Format to retrieve attachment. Options are: URL, DATA_URL, LOCAL_FILE_PATH (preferred for current testing environment), TEXT. URL returns the URL of the file, DATA_URL returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file path to the downloaded file, and TEXT returns the content of the file as text.""",
16
- "nullable": True,
17
- "default": "URL",
18
- }
19
- }
20
- output_type = "string"
21
-
22
- def __init__(
23
- self,
24
- agent_evaluation_api: str | None = None,
25
- task_id: str | None = None,
26
- **kwargs,
27
- ):
28
- # Default to Hugging Face GAIA testing space
29
- self.agent_evaluation_api = (
30
- agent_evaluation_api
31
- if agent_evaluation_api is not None
32
- else "https://agents-course-unit4-scoring.hf.space/"
33
- )
34
- self.task_id = task_id
35
- super().__init__(**kwargs)
36
-
37
- def attachment_for(self, task_id: str| None):
38
- self.task_id = task_id
39
-
40
- def forward(self, fmt: str = "URL") -> str:
41
- # Ensure the format is uppercase for comparison
42
- fmt = fmt.upper()
43
- assert fmt in ["URL", "DATA_URL", "LOCAL_FILE_PATH", "TEXT"]
44
-
45
- if not self.task_id:
46
- return "No task_id provided to retrieve attachment."
47
-
48
- file_url = urljoin(self.agent_evaluation_api, f"files/{self.task_id}")
49
- if fmt == "URL":
50
- return file_url
51
-
52
- response = requests.get(
53
- file_url,
54
- headers={
55
- "Content-Type": "application/json",
56
- "Accept": "application/json",
57
- },
58
- )
59
- if 400 <= response.status_code < 500:
60
- raise ValueError(f"Error fetching file: {response.status_code} {response.reason}")
61
-
62
- response.raise_for_status()
63
- mime = response.headers.get("content-type", "text/plain")
64
- if fmt == "TEXT":
65
- if mime.startswith("text/"):
66
- return response.text
67
- else:
68
- raise ValueError(f"Content of file type {mime} cannot be retrieved as TEXT")
69
- elif fmt == "DATA_URL":
70
- return f"data:{mime};base64,{base64.b64encode(response.content).decode("utf-8")}"
71
- elif fmt == "LOCAL_FILE_PATH":
72
- with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
73
- tmp_file.write(response.content)
74
- return tmp_file.name
75
- else:
76
  raise ValueError(f"Unsupported format: {fmt}. Supported formats are URL, DATA_URL, LOCAL_FILEPATH, and TEXT.")
 
1
+ from smolagents import Tool
2
+ import requests
3
+ from urllib.parse import urljoin
4
+ import base64
5
+ import tempfile
6
+
7
+ class GetAttachmentTool(Tool):
8
+ name = "get_attachment"
9
+ description = """
10
+ Format to retrieve attachment. Options are: URL (preferred), DATA_URL, LOCAL_FILE_PATH, TEXT. URL returns the URL of the file, DATA_URL returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file path to the downloaded file, and TEXT returns the content of the file as text.
11
+ """
12
+ inputs = {
13
+ "fmt": {
14
+ "type": "string",
15
+ "description": """Format to retrieve attachment. Options are: URL, DATA_URL, LOCAL_FILE_PATH (preferred for current testing environment), TEXT. URL returns the URL of the file, DATA_URL returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file path to the downloaded file, and TEXT returns the content of the file as text.""",
16
+ "nullable": True,
17
+ "default": "URL",
18
+ }
19
+ }
20
+ output_type = "string"
21
+
22
+ def __init__(
23
+ self,
24
+ agent_evaluation_api: str | None = None,
25
+ task_id: str | None = None,
26
+ **kwargs,
27
+ ):
28
+ # Default to Hugging Face GAIA testing space
29
+ self.agent_evaluation_api = (
30
+ agent_evaluation_api
31
+ if agent_evaluation_api is not None
32
+ else "https://agents-course-unit4-scoring.hf.space/"
33
+ )
34
+ self.task_id = task_id
35
+ super().__init__(**kwargs)
36
+
37
+ def attachment_for(self, task_id: str| None):
38
+ self.task_id = task_id
39
+
40
+ def forward(self, fmt: str = "URL") -> str:
41
+ # Ensure the format is uppercase for comparison
42
+ fmt = fmt.upper()
43
+ assert fmt in ["URL", "DATA_URL", "LOCAL_FILE_PATH", "TEXT"]
44
+
45
+ if not self.task_id:
46
+ return "No task_id provided to retrieve attachment."
47
+
48
+ file_url = urljoin(self.agent_evaluation_api, f"files/{self.task_id}")
49
+ if fmt == "URL":
50
+ return file_url
51
+
52
+ response = requests.get(
53
+ file_url,
54
+ headers={
55
+ "Content-Type": "application/json",
56
+ "Accept": "application/json",
57
+ },
58
+ )
59
+ if 400 <= response.status_code < 500:
60
+ raise ValueError(f"Error fetching file: {response.status_code} {response.reason}")
61
+
62
+ response.raise_for_status()
63
+ mime = response.headers.get("content-type", "text/plain")
64
+ if fmt == "TEXT":
65
+ if mime.startswith("text/"):
66
+ return response.text
67
+ else:
68
+ raise ValueError(f"Content of file type {mime} cannot be retrieved as TEXT")
69
+ elif fmt == "DATA_URL":
70
+ return f"data:{mime};base64,{base64.b64encode(response.content).decode('utf-8')}"
71
+ elif fmt == "LOCAL_FILE_PATH":
72
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
73
+ tmp_file.write(response.content)
74
+ return tmp_file.name
75
+ else:
76
  raise ValueError(f"Unsupported format: {fmt}. Supported formats are URL, DATA_URL, LOCAL_FILEPATH, and TEXT.")