File size: 3,761 Bytes
18eddf0
 
f9a89db
18eddf0
 
 
f9a89db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18eddf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f0aae0
18eddf0
 
 
9f0aae0
18eddf0
 
 
 
 
 
 
 
9f0aae0
18eddf0
f9a89db
18eddf0
 
f9a89db
18eddf0
 
 
 
 
 
 
 
 
f9a89db
 
85c5d51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18eddf0
 
85c5d51
 
18eddf0
85c5d51
18eddf0
 
 
 
f9a89db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Tool factory for the agent."""

import os

from exa_py import Exa
from smolagents import Tool
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    SimpleTextBrowser,
    VisitTool,
)
from scripts.visual_qa import visualizer

from config import BROWSER_CONFIG, TEXT_LIMIT


class ExaSearchTool(Tool):
    """smolagents ``web_search`` tool backed by the Exa search API.

    Each call requests up to 10 results with highlight excerpts and renders
    them as a numbered Markdown list: title/URL link, the publication date
    when Exa reports one, and the highlight excerpts as bullet points.
    """

    name = "web_search"
    description = (
        "Perform a web search and return relevant results with key excerpts. "
        "Use a natural language query. Optionally filter by year."
    )
    inputs = {
        "query": {"type": "string", "description": "The natural language search query."},
        "filter_year": {
            "type": "string",
            "description": "[Optional]: Filter results to a specific year, e.g. '2024'.",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self, api_key: str):
        """Create the tool with an Exa client authenticated by ``api_key``."""
        super().__init__()
        self.exa = Exa(api_key=api_key)

    def forward(self, query: str, filter_year: str = None) -> str:
        """Search Exa for ``query`` and format the results as text.

        Args:
            query: Natural-language search query.
            filter_year: Optional four-digit year. When given, results are
                restricted to items published during that calendar year.

        Returns:
            The formatted result list, a "no results" notice, or a message
            explaining that ``filter_year`` is not a valid four-digit year.
        """
        kwargs = {
            "num_results": 10,
            "contents": {"highlights": True},
        }
        if filter_year:
            year = str(filter_year).strip()
            # A malformed year would be spliced into an invalid ISO-8601 date.
            # Report it to the agent instead of sending a request that the API
            # would reject. isascii() rules out non-ASCII digits such as '２'.
            if not (len(year) == 4 and year.isascii() and year.isdigit()):
                return (
                    f"Invalid filter_year '{filter_year}': expected a four-digit "
                    "year such as '2024'."
                )
            kwargs["start_published_date"] = f"{year}-01-01"
            # Exa documents end_published_date as "published before this". A bare
            # "YYYY-12-31" means midnight, which would exclude Dec 31 itself, so
            # extend the bound to the last instant of the year.
            kwargs["end_published_date"] = f"{year}-12-31T23:59:59.999Z"

        response = self.exa.search(query, **kwargs)
        hits = response.results

        if not hits:
            return f"No results found for '{query}'."

        lines = [f"Exa search for '{query}' returned {len(hits)} results:\n"]
        for i, r in enumerate(hits, 1):
            # Some pages come back without a title; use the URL as the link
            # text so the entry never reads "[None](...)".
            title = r.title or r.url
            date = f"\nDate published: {r.published_date}" if r.published_date else ""
            excerpts = getattr(r, "highlights", None)
            highlights = ""
            if excerpts:
                highlights = "\n" + "\n".join(f"  • {h}" for h in excerpts)
            lines.append(f"{i}. [{title}]({r.url}){date}{highlights}")

        return "\n\n".join(lines)


class UnavailableWebSearchTool(Tool):
    """Stand-in ``web_search`` tool used when no Exa API key is configured.

    It declares the same tool name and inputs as ``ExaSearchTool``, so the
    agent's tool list keeps the same shape. Every call returns an explanatory
    notice instead of raising.
    """

    name = "web_search"
    output_type = "string"
    description = (
        "Web search is currently unavailable because EXA_API_KEY is not set. "
        "This tool returns a helpful message instead of raising an error."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The natural language search query.",
        },
        "filter_year": {
            "type": "string",
            "description": "[Optional]: Filter results to a specific year, e.g. '2024'.",
            "nullable": True,
        },
    }

    def forward(self, query: str, filter_year: str = None) -> str:
        """Return the "search disabled" notice; ``query`` is not used.

        The year, when supplied, is echoed in the notice.
        """
        scope = "" if not filter_year else f" for {filter_year}"
        notice = (
            "Web search{} is unavailable because EXA_API_KEY is not set. "
            "Set EXA_API_KEY in your environment or Space secrets to enable search."
        )
        return notice.format(scope)


def _create_web_search_tool() -> Tool:
    """Build the ``web_search`` tool based on the ``EXA_API_KEY`` environment variable.

    Returns an ``ExaSearchTool`` when the variable is set to a non-empty value.
    Otherwise it prints a warning and returns an ``UnavailableWebSearchTool``.
    """
    exa_key = os.getenv("EXA_API_KEY")
    if exa_key:
        return ExaSearchTool(api_key=exa_key)

    # No key (unset or empty): degrade gracefully rather than failing at startup.
    print("Warning: EXA_API_KEY not set. Web search will be disabled until it is configured.")
    return UnavailableWebSearchTool()


def build_tools(model):
    """Construct the agent's tools for the given ``model``.

    Returns a 3-tuple:
      * the web tool list: the search tool, the browser navigation tools (all
        sharing one ``SimpleTextBrowser``), and a ``TextInspectorTool``;
      * a second, separately constructed ``TextInspectorTool`` instance;
      * the ``visualizer`` object imported from ``scripts.visual_qa``.
    """
    search_tool = _create_web_search_tool()
    text_browser = SimpleTextBrowser(**BROWSER_CONFIG)
    web_inspector = TextInspectorTool(model, TEXT_LIMIT)

    # Every navigation tool wraps the same browser instance, so they share
    # whatever page state the browser holds.
    navigation_classes = (
        VisitTool,
        PageUpTool,
        PageDownTool,
        FinderTool,
        FindNextTool,
        ArchiveSearchTool,
    )
    navigation_tools = [tool_cls(text_browser) for tool_cls in navigation_classes]

    web_tools = [search_tool, *navigation_tools, web_inspector]

    return web_tools, TextInspectorTool(model, TEXT_LIMIT), visualizer