Spaces:
Sleeping
Sleeping
| from tools.base_tool import BaseTool | |
| import requests | |
| import base64 | |
| import tempfile | |
| from urllib.parse import urljoin | |
class GetAttachmentTool(BaseTool):
    """Tool that fetches the attachment of a GAIA task from the evaluation API.

    The attachment is addressed as ``<evaluation_api>/files/<task_id>`` and can
    be returned as that URL, as a base64 ``data:`` URL, as the path of a local
    temporary copy, or as decoded text.
    """

    name = "get_attachment"
    description = """
    Downloads an attachment linked to the current GAIA task.
    Supported formats for response are: URL, DATA_URL, LOCAL_FILE_PATH, or TEXT.
    """
    inputs = {
        "fmt": {
            "type": "string",
            "description": "Choose response format: URL, DATA_URL, LOCAL_FILE_PATH, or TEXT.",
            "nullable": True,
            "default": "URL",
        }
    }
    output_type = "string"

    # Formats accepted by ``forward`` (compared after upper-casing).
    _FORMATS = frozenset({"URL", "DATA_URL", "LOCAL_FILE_PATH", "TEXT"})
    # Media types outside ``text/*`` whose body can still be returned as text.
    _TEXTUAL_TYPES = frozenset({"application/json", "application/xml"})
    # Seconds to wait for the download; ``None`` disables the timeout.
    request_timeout: float | None = 30.0

    def __init__(
        self,
        evaluation_api: str | None = None,
        task_id: str | None = None,
        *,
        request_timeout: float | None = 30.0,
        **kwargs,
    ):
        """Configure the tool.

        Args:
            evaluation_api: Base URL of the evaluation API; defaults to the
                public GAIA scoring space.
            task_id: Task whose attachment ``forward`` fetches; can also be
                set later with ``attachment_for``.
            request_timeout: Download timeout in seconds (``None`` = no limit).
            **kwargs: Forwarded unchanged to ``BaseTool.__init__``.
        """
        # Default GAIA evaluation API endpoint
        self.evaluation_api = evaluation_api or "https://agents-course-unit4-scoring.hf.space/"
        self.task_id = task_id
        self.request_timeout = request_timeout
        super().__init__(**kwargs)

    def attachment_for(self, task_id: str | None):
        """Set the current GAIA task ID."""
        self.task_id = task_id

    @classmethod
    def _is_textual(cls, media_type: str) -> bool:
        """Return True if a body of the given (lower-cased) media type is text."""
        return (
            media_type.startswith("text/")
            or media_type in cls._TEXTUAL_TYPES
            or media_type.endswith(("+json", "+xml"))
        )

    def forward(self, fmt: str = "URL") -> str:
        """Return the current task's attachment in the requested format.

        Args:
            fmt: One of ``URL``, ``DATA_URL``, ``LOCAL_FILE_PATH`` or ``TEXT``
                (case-insensitive). ``None`` or an empty string falls back to
                ``URL``, matching the declared input default.

        Returns:
            The attachment URL, a ``data:`` URL with the base64-encoded body,
            the path of a temporary file holding the downloaded bytes (not
            deleted automatically), or the decoded response text. When no task
            id is set, a short explanatory message is returned instead.

        Raises:
            ValueError: If ``fmt`` is unsupported, the server answers with a
                4xx status, or ``TEXT`` is requested for a non-textual type.
            requests.HTTPError: For other unsuccessful responses (via
                ``raise_for_status``).
        """
        # The input is declared nullable, so tolerate None and stray whitespace.
        fmt = (fmt or "URL").strip().upper()
        # Validate explicitly: an ``assert`` disappears under ``python -O``.
        if fmt not in self._FORMATS:
            raise ValueError(f"Unsupported format: {fmt}")
        if not self.task_id:
            return "No task_id provided to fetch the attachment."

        # urljoin replaces the last path segment of a base without a trailing
        # slash, so ensure an API root such as ".../api" keeps its path.
        base = self.evaluation_api
        if not base.endswith("/"):
            base += "/"
        file_url = urljoin(base, f"files/{self.task_id}")
        if fmt == "URL":
            return file_url

        # Download the file; the timeout keeps a stalled server from hanging the agent.
        resp = requests.get(file_url, timeout=self.request_timeout)
        if 400 <= resp.status_code < 500:
            raise ValueError(f"Failed to retrieve attachment: {resp.status_code} {resp.reason}")
        resp.raise_for_status()
        content_type = resp.headers.get("content-type", "text/plain")
        # Bare media type without parameters, e.g. "text/plain; charset=utf-8" -> "text/plain".
        media_type = content_type.split(";", 1)[0].strip().lower()

        if fmt == "TEXT":
            if not self._is_textual(media_type):
                raise ValueError(f"Cannot extract text from content-type: {content_type}")
            return resp.text

        if fmt == "DATA_URL":
            # RFC 2397 media types contain no whitespace around ';'.
            data_type = ";".join(part.strip() for part in content_type.split(";") if part.strip())
            b64 = base64.b64encode(resp.content).decode("ascii")
            return f"data:{data_type};base64,{b64}"

        # fmt == "LOCAL_FILE_PATH": give the temp file an extension matching the
        # media type (when one is known) so downstream readers can detect the type.
        import mimetypes

        suffix = mimetypes.guess_extension(media_type) or ""
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(resp.content)
        return tmp_file.name