FD900 commited on
Commit
ccbae19
·
verified ·
1 Parent(s): de1b9f0

Update tools/get_attachments_tool.py

Browse files
Files changed (1) hide show
  1. tools/get_attachments_tool.py +70 -0
tools/get_attachments_tool.py CHANGED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ import requests
3
+ import base64
4
+ import tempfile
5
+ from urllib.parse import urljoin
6
+
7
+
8
+ class GetAttachmentTool(Tool):
9
+ name = "get_attachment"
10
+ description = """
11
+ Downloads an attachment linked to the current GAIA task.
12
+ Supported formats for response are: URL, DATA_URL, LOCAL_FILE_PATH, or TEXT.
13
+ """
14
+
15
+ inputs = {
16
+ "fmt": {
17
+ "type": "string",
18
+ "description": "Choose response format: URL, DATA_URL, LOCAL_FILE_PATH, or TEXT.",
19
+ "nullable": True,
20
+ "default": "URL",
21
+ }
22
+ }
23
+ output_type = "string"
24
+
25
+ def __init__(self, evaluation_api: str | None = None, task_id: str | None = None, **kwargs):
26
+ # Default GAIA evaluation API endpoint
27
+ self.evaluation_api = evaluation_api or "https://agents-course-unit4-scoring.hf.space/"
28
+ self.task_id = task_id
29
+ super().__init__(**kwargs)
30
+
31
+ def attachment_for(self, task_id: str | None):
32
+ """Set the current GAIA task ID."""
33
+ self.task_id = task_id
34
+
35
+ def forward(self, fmt: str = "URL") -> str:
36
+ fmt = fmt.upper()
37
+ assert fmt in {"URL", "DATA_URL", "LOCAL_FILE_PATH", "TEXT"}
38
+
39
+ if not self.task_id:
40
+ return "No task_id provided to fetch the attachment."
41
+
42
+ file_url = urljoin(self.evaluation_api, f"files/{self.task_id}")
43
+
44
+ if fmt == "URL":
45
+ return file_url
46
+
47
+ # Download the file
48
+ resp = requests.get(file_url)
49
+ if 400 <= resp.status_code < 500:
50
+ raise ValueError(f"Failed to retrieve attachment: {resp.status_code} {resp.reason}")
51
+
52
+ resp.raise_for_status()
53
+ content_type = resp.headers.get("content-type", "text/plain")
54
+
55
+ if fmt == "TEXT":
56
+ if content_type.startswith("text/"):
57
+ return resp.text
58
+ else:
59
+ raise ValueError(f"Cannot extract text from content-type: {content_type}")
60
+
61
+ if fmt == "DATA_URL":
62
+ b64 = base64.b64encode(resp.content).decode("utf-8")
63
+ return f"data:{content_type};base64,{b64}"
64
+
65
+ if fmt == "LOCAL_FILE_PATH":
66
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
67
+ tmp_file.write(resp.content)
68
+ return tmp_file.name
69
+
70
+ raise ValueError(f"Unsupported format: {fmt}")