File size: 3,640 Bytes
f660630
 
9eb4567
 
 
 
f660630
 
b1194ab
f660630
 
9eb4567
 
f660630
 
 
 
 
 
b1194ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f660630
 
 
9eb4567
 
f660630
 
 
 
 
 
b1194ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f660630
 
7272cb4
9705726
f660630
9705726
f660630
 
9705726
 
f660630
9705726
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import base64
import requests
try:
    from smolagents import Tool
except ImportError:
    Tool = object

# Default base URL of the course scoring API. The reader tools below fall back
# to "{api_url}/files/{task_id}" on this host when the dataset URL fails.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# GAIA 2023 test-split folder on the Hugging Face Hub. The reader tools try
# "{HF_DATASET_BASE}/{file_name}" first when a file name is known.
HF_DATASET_BASE = "https://huggingface.co/datasets/gaia-benchmark/GAIA/resolve/main/2023/test"

# Tool to download and read text file attachments for a task.
# Plain class (no Tool inheritance) — called directly, not via CodeAgent.
class FileReaderTool:
    """Download a task's attached file and return its contents as text.

    The file is looked up first in the GAIA dataset on the Hugging Face Hub
    (by ``file_name``) and then, as a fallback, on the scoring API at
    ``{api_url}/files/{task_id}``.
    """

    name = "file_reader"

    def __init__(self, api_url: str = DEFAULT_API_URL):
        """Create the tool.

        Args:
            api_url: Base URL of the scoring API used for the fallback download.
        """
        self.api_url = api_url
        print("FileReaderTool initialized.")

    def _candidate_urls(self, task_id: str, file_name: str) -> list:
        """Return the download URLs to try, in priority order."""
        urls = []
        if file_name:
            # The dataset URL is where the attachment actually lives, so try it first.
            urls.append(f"{HF_DATASET_BASE}/{file_name}")
        urls.append(f"{self.api_url}/files/{task_id}")
        return urls

    @staticmethod
    def _decode_body(resp) -> str:
        """Decode a response body to text, preferring UTF-8.

        requests falls back to ISO-8859-1 for ``text/*`` responses without a
        declared charset, which garbles UTF-8 files. So an explicit charset is
        honoured. Otherwise UTF-8 is tried first ("utf-8-sig" also strips a
        leading BOM), and requests' own decoding is used only when the bytes
        are not valid UTF-8.
        """
        content_type = resp.headers.get("Content-Type", "")
        if "charset=" in content_type.lower():
            return resp.text
        try:
            return resp.content.decode("utf-8-sig")
        except UnicodeDecodeError:
            return resp.text

    def __call__(self, task_id: str, file_name: str = "") -> str:
        """Download the attachment for *task_id* and return it as text.

        Args:
            task_id: Task identifier used to build the scoring-API fallback URL.
            file_name: Attachment file name in the GAIA dataset. When empty,
                only the scoring-API URL is tried.

        Returns:
            The decoded file contents, or
            ``"Failed to download file for task {task_id}"`` if every URL fails.
        """
        for url in self._candidate_urls(task_id, file_name):
            print(f"FileReaderTool trying: {url}")
            try:
                resp = requests.get(url, timeout=30)
            except requests.exceptions.RequestException as e:
                print(f"FileReaderTool request failed for {url}: {e}")
                continue
            if resp.status_code != 200:
                print(f"FileReaderTool got HTTP {resp.status_code} from {url}")
                continue
            content = self._decode_body(resp)
            print(f"FileReaderTool downloaded {len(content)} chars.")
            return content
        return f"Failed to download file for task {task_id}"


# Tool to download image attachments and return them as base64 data URIs.
# Plain class (no Tool inheritance) — called directly, not via CodeAgent.
class ImageReaderTool:
    """Download a task's attached image and return it as a ``data:`` URI.

    The image is looked up first in the GAIA dataset on the Hugging Face Hub
    (by ``file_name``) and then, as a fallback, on the scoring API at
    ``{api_url}/files/{task_id}``.
    """

    name = "image_reader"

    def __init__(self, api_url: str = DEFAULT_API_URL):
        """Create the tool.

        Args:
            api_url: Base URL of the scoring API used for the fallback download.
        """
        self.api_url = api_url
        print("ImageReaderTool initialized.")

    def _candidate_urls(self, task_id: str, file_name: str) -> list:
        """Return the download URLs to try, in priority order."""
        urls = []
        if file_name:
            # The dataset URL is where the attachment actually lives, so try it first.
            urls.append(f"{HF_DATASET_BASE}/{file_name}")
        urls.append(f"{self.api_url}/files/{task_id}")
        return urls

    @staticmethod
    def _mime_type(content_type_header: str, file_name: str) -> str:
        """Choose the MIME type to embed in the data URI.

        Parameters such as "; charset=..." are stripped, because a data URI
        of the form "data:<mime>;base64,..." cannot carry them verbatim. When
        the server sends no type or only a generic binary type, the type is
        guessed from the file name, with a final fallback to "image/png".
        """
        import mimetypes

        mime = content_type_header.split(";", 1)[0].strip().lower()
        if mime and mime not in ("application/octet-stream", "binary/octet-stream"):
            return mime
        guessed = mimetypes.guess_type(file_name)[0] if file_name else None
        if guessed and guessed.startswith("image/"):
            return guessed
        return "image/png"

    def __call__(self, task_id: str, file_name: str = "") -> str:
        """Download the image for *task_id* and return it base64-encoded.

        Args:
            task_id: Task identifier used to build the scoring-API fallback URL.
            file_name: Attachment file name in the GAIA dataset. When empty,
                only the scoring-API URL is tried.

        Returns:
            A ``data:<mime>;base64,<payload>`` string, or
            ``"Failed to download image for task {task_id}"`` if every URL fails.
        """
        for url in self._candidate_urls(task_id, file_name):
            print(f"ImageReaderTool trying: {url}")
            try:
                resp = requests.get(url, timeout=30)
            except requests.exceptions.RequestException as e:
                print(f"ImageReaderTool request failed for {url}: {e}")
                continue
            if resp.status_code != 200:
                print(f"ImageReaderTool got HTTP {resp.status_code} from {url}")
                continue
            if not resp.content:
                # An empty body can never be a valid image, so try the next source.
                print(f"ImageReaderTool got an empty body from {url}")
                continue
            content_type = self._mime_type(resp.headers.get("Content-Type", ""), file_name)
            image_b64 = base64.b64encode(resp.content).decode("utf-8")
            print(f"ImageReaderTool downloaded image ({len(resp.content)} bytes, {content_type}).")
            return f"data:{content_type};base64,{image_b64}"
        return f"Failed to download image for task {task_id}"


# Web search tool — uses ddgs directly to avoid smolagents DuckDuckGoSearchTool package check
class WebSearchTool:
    """Run a text web search through the ``ddgs`` package and format the hits."""

    name = "web_search"

    def __init__(self, max_results: int = 5):
        """Create the tool.

        Args:
            max_results: Maximum number of hits requested per query.
        """
        self.max_results = max_results
        print("WebSearchTool initialized.")

    @staticmethod
    def _format_results(results) -> str:
        """Render hits as blank-line-separated "title / url / snippet" entries.

        Missing or ``None`` fields render as empty strings rather than "None".
        """
        return "\n\n".join(
            f"{r.get('title') or ''}\n{r.get('href') or ''}\n{r.get('body') or ''}"
            for r in results
        )

    def __call__(self, query: str) -> str:
        """Search for *query* and return the formatted hits.

        Returns:
            The formatted hits, "No results found." when the search returns
            nothing, or "Search error: <exception>" if the search fails.
        """
        # str() keeps the log line from crashing on a non-string query.
        print(f"WebSearchTool received query (first 50 chars): {str(query)[:50]}...")
        try:
            # Imported lazily so this module loads even when ddgs is not installed.
            from ddgs import DDGS
            with DDGS() as ddgs:
                results = list(ddgs.text(query, max_results=self.max_results))
        except Exception as e:
            # Deliberately broad: the tool is best-effort and reports any
            # failure (missing package, network, rate limit) as text.
            output = f"Search error: {e}"
        else:
            if not results:
                return "No results found."
            output = self._format_results(results)
        print(f"WebSearchTool returning result (first 100 chars): {output[:100]}...")
        return output