| from smolagents import Tool |
| import requests |
| from urllib.parse import urljoin |
| import base64 |
| import tempfile |
|
|
|
|
| class GetAttachmentTool(Tool): |
| name = "get_attachment" |
| description = """Retrieves attachment for current task in specified format.""" |
| inputs = { |
| "fmt": { |
| "type": "string", |
| "description": "Format to retrieve attachment. Options are: URL (preferred), DATA_URL, LOCAL_FILE_PATH, TEXT. URL returns the URL of the file, DATA_URL returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file path to the downloaded file, and TEXT returns the content of the file as text.", |
| "nullable": True, |
| "default": "URL", |
| } |
| } |
| output_type = "string" |
|
|
| def __init__( |
| self, |
| agent_evaluation_api: str | None = None, |
| task_id: str | None = None, |
| **kwargs, |
| ): |
| self.agent_evaluation_api = ( |
| agent_evaluation_api |
| if agent_evaluation_api is not None |
| else "https://agents-course-unit4-scoring.hf.space/" |
| ) |
| self.task_id = task_id |
| super().__init__(**kwargs) |
|
|
| def attachment_for(self, task_id: str | None): |
| self.task_id = task_id |
|
|
| def forward(self, fmt: str = "URL") -> str: |
| fmt = fmt.upper() |
| assert fmt in ["URL", "DATA_URL", "LOCAL_FILE_PATH", "TEXT"] |
|
|
| if not self.task_id: |
| return "" |
|
|
| file_url = urljoin(self.agent_evaluation_api, f"files/{self.task_id}") |
| if fmt == "URL": |
| return file_url |
|
|
| response = requests.get( |
| file_url, |
| headers={ |
| "Content-Type": "application/json", |
| "Accept": "application/json", |
| }, |
| ) |
| if 400 <= response.status_code < 500: |
| return "" |
|
|
| response.raise_for_status() |
| mime = response.headers.get("content-type", "text/plain") |
| if fmt == "TEXT": |
| if mime.startswith("text/"): |
| return response.text |
| else: |
| raise ValueError( |
| f"Content of file type {mime} cannot be retrieved as TEXT." |
| ) |
| elif fmt == "DATA_URL": |
| return f"data:{mime};base64,{base64.b64encode(response.content).decode('utf-8')}" |
| elif fmt == "LOCAL_FILE_PATH": |
| with tempfile.NamedTemporaryFile(delete=False) as tmp_file: |
| tmp_file.write(response.content) |
| return tmp_file.name |
| else: |
| raise ValueError( |
| f"Unsupported format: {fmt}. Supported formats are URL, DATA_URL, LOCAL_FILE_PATH, and TEXT." |
| ) |
|
|