Gabandino's picture
Fix quotes typo in get_attachments_tool for final submission
e552e13 verified
import base64
import mimetypes
import os
import re
import tempfile
from urllib.parse import urljoin

import requests
from smolagents import Tool
class GetAttachmentTool(Tool):
    """smolagents tool that fetches the file attached to an evaluation task.

    The task whose attachment is fetched is ``self.task_id``. Set it at
    construction time or later via ``attachment_for``. The file is requested
    from ``<agent_evaluation_api>/files/<task_id>``.
    """

    name = "get_attachment"
    # Shown to the LLM as the tool's purpose. The per-format details (and the
    # preferred format) live in the ``fmt`` input description below.
    description = (
        "Retrieves the file attached to the current task. Use the `fmt` argument "
        "to choose how it is returned: URL returns the URL of the file, DATA_URL "
        "returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file "
        "path to the downloaded file, and TEXT returns the content of the file as text."
    )
    inputs = {
        "fmt": {
            "type": "string",
            "description": """Format to retrieve attachment. Options are: URL, DATA_URL, LOCAL_FILE_PATH (preferred for current testing environment), TEXT. URL returns the URL of the file, DATA_URL returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file path to the downloaded file, and TEXT returns the content of the file as text.""",
            "nullable": True,
            "default": "URL",
        }
    }
    output_type = "string"

    # Formats accepted by ``forward`` (compared after upper-casing).
    _SUPPORTED_FORMATS = ("URL", "DATA_URL", "LOCAL_FILE_PATH", "TEXT")

    # Seconds passed to ``requests.get(timeout=...)``. Without a timeout the
    # request can block the agent indefinitely if the server stops responding.
    request_timeout = 30

    def __init__(
        self,
        agent_evaluation_api: str | None = None,
        task_id: str | None = None,
        **kwargs,
    ):
        """Create the tool.

        Args:
            agent_evaluation_api: Base URL of the evaluation API. When None,
                the Hugging Face GAIA scoring space is used.
            task_id: Task whose attachment to fetch. It may also be set later
                via ``attachment_for``.
            **kwargs: Forwarded to ``smolagents.Tool.__init__``.
        """
        # Default to Hugging Face GAIA testing space
        self.agent_evaluation_api = (
            agent_evaluation_api
            if agent_evaluation_api is not None
            else "https://agents-course-unit4-scoring.hf.space/"
        )
        self.task_id = task_id
        super().__init__(**kwargs)

    def attachment_for(self, task_id: str | None):
        """Select the task whose attachment subsequent ``forward`` calls fetch."""
        self.task_id = task_id

    def forward(self, fmt: str = "URL") -> str:
        """Return the current task's attachment in the requested format.

        Args:
            fmt: One of URL, DATA_URL, LOCAL_FILE_PATH or TEXT (case-insensitive).
                None or an empty string falls back to URL.

        Returns:
            The file URL, a base64 data URL, the path of a downloaded temp
            file, or the file's text. If no task_id is set, an explanatory
            message is returned instead.

        Raises:
            ValueError: Unsupported ``fmt``, a 4xx response, or TEXT
                requested for a non-``text/*`` file.
            requests.HTTPError: Any other non-success HTTP status.
        """
        # ``fmt`` is declared nullable, so the agent may pass None explicitly.
        fmt = (fmt or "URL").upper()
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if fmt not in self._SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format: {fmt}. Supported formats are URL, DATA_URL, LOCAL_FILE_PATH, and TEXT."
            )
        if not self.task_id:
            return "No task_id provided to retrieve attachment."

        file_url = self._file_url()
        if fmt == "URL":
            return file_url

        response = self._download(file_url)
        mime = response.headers.get("content-type", "text/plain")
        if fmt == "TEXT":
            if mime.startswith("text/"):
                return response.text
            raise ValueError(f"Content of file type {mime} cannot be retrieved as TEXT")
        if fmt == "DATA_URL":
            return self._data_url(mime, response.content)
        # fmt == "LOCAL_FILE_PATH"
        suffix = self._guess_suffix(response.headers.get("content-disposition", ""), mime)
        return self._write_temp_file(response.content, suffix)

    def _file_url(self) -> str:
        """Return the download URL of the current task's attachment."""
        base = self.agent_evaluation_api
        # urljoin replaces the last path segment of a base without a trailing
        # slash ("https://host/api" + "files/x" -> "https://host/files/x").
        if not base.endswith("/"):
            base += "/"
        return urljoin(base, f"files/{self.task_id}")

    def _download(self, file_url: str) -> requests.Response:
        """GET ``file_url``; raise ValueError on 4xx and HTTPError on other failures."""
        # A GET has no body, so no Content-Type is sent. The attachment can be
        # any file type, so accept anything rather than requesting JSON.
        response = requests.get(
            file_url,
            headers={"Accept": "*/*"},
            timeout=self.request_timeout,
        )
        if 400 <= response.status_code < 500:
            raise ValueError(f"Error fetching file: {response.status_code} {response.reason}")
        response.raise_for_status()
        return response

    @staticmethod
    def _data_url(mime: str, content: bytes) -> str:
        """Encode ``content`` as an RFC 2397 base64 data URL of type ``mime``."""
        # HTTP headers often read "text/plain; charset=utf-8". Data-URL media
        # types must not contain that whitespace, so rejoin the parameters
        # with a bare ";".
        media_type = ";".join(part.strip() for part in mime.split(";"))
        encoded = base64.b64encode(content).decode("ascii")
        return f"data:{media_type};base64,{encoded}"

    @staticmethod
    def _guess_suffix(content_disposition: str, mime: str) -> str:
        """Return a file extension such as ".xlsx" for the download, or "" if unknown.

        The extension of a filename given in Content-Disposition wins.
        Otherwise the extension is guessed from the MIME type.
        """
        # Matches both filename="x.ext" and RFC 5987 filename*=UTF-8''x.ext.
        match = re.search(
            r"filename\*?=(?:[\w-]+'[\w-]*')?\"?([^\";]+)",
            content_disposition or "",
            re.IGNORECASE,
        )
        if match:
            # splitext never returns a path separator in the extension, so a
            # server-supplied name cannot steer the temp file out of its dir.
            suffix = os.path.splitext(match.group(1).strip())[1]
            if suffix:
                return suffix
        return mimetypes.guess_extension(mime.split(";", 1)[0].strip()) or ""

    @staticmethod
    def _write_temp_file(content: bytes, suffix: str) -> str:
        """Write ``content`` to a new temp file ending in ``suffix``; return its path."""
        # delete=False: the path is handed back to the agent, so the file must
        # outlive this call. Nothing in this class deletes it afterwards.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(content)
        return tmp_file.name