File size: 5,405 Bytes
524e3cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""GAIA Unit 4 agent: tool-calling loop via Hugging Face Inference API."""

from __future__ import annotations

import os
from typing import Any, Optional

from huggingface_hub import InferenceClient

from answer_normalize import normalize_answer
from inference_client_factory import inference_client_kwargs
from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool

# System prompt for the tool-calling loop. NOTE: the "answer text only" rule
# exists because the GAIA scorer does exact-match on the final message; any
# label or markdown around the answer causes a miss. Keep wording in sync with
# normalize_answer's expectations.
SYSTEM_PROMPT = """You solve GAIA benchmark questions for the Hugging Face Agents Course.

Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences.
- Match the question's format exactly (comma-separated, alphabetical order, IOC codes, algebraic notation, two-decimal USD, first name only, etc.).
- When a local attachment path is given, use the appropriate tool with that exact path.
- For English Wikipedia tasks, use wikipedia_* tools; cross-check with web_search if needed.
- For YouTube URLs in the question, try youtube_transcript first.
"""


class GaiaAgent:
    """Agent that answers GAIA Unit 4 questions.

    Strategy: try a cheap deterministic shortcut first; if none applies,
    run an LLM tool-calling loop against the Hugging Face Inference API
    until the model produces a plain-text answer or the iteration budget
    runs out.
    """

    def __init__(
        self,
        *,
        hf_token: Optional[str] = None,
        text_model: Optional[str] = None,
        max_iterations: int = 14,
    ):
        """Configure the agent.

        Args:
            hf_token: HF API token; falls back to HF_TOKEN then
                HUGGINGFACEHUB_API_TOKEN environment variables.
            text_model: chat model id; falls back to GAIA_TEXT_MODEL env var,
                then a Qwen 2.5 7B default.
            max_iterations: upper bound on model round-trips per question.
        """
        env = os.environ
        self.hf_token = (
            hf_token or env.get("HF_TOKEN") or env.get("HUGGINGFACEHUB_API_TOKEN")
        )
        self.text_model = text_model or env.get(
            "GAIA_TEXT_MODEL", "Qwen/Qwen2.5-7B-Instruct"
        )
        self.max_iterations = max_iterations
        # Created lazily on first use — see _get_client().
        self._client: Optional[InferenceClient] = None

    def _get_client(self) -> InferenceClient:
        """Return the cached InferenceClient, building it on first call.

        Raises:
            RuntimeError: if no HF token could be resolved.
        """
        if self._client is not None:
            return self._client
        if not self.hf_token:
            raise RuntimeError(
                "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required for GaiaAgent."
            )
        self._client = InferenceClient(**inference_client_kwargs(self.hf_token))
        return self._client

    def __call__(
        self,
        question: str,
        attachment_path: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """Answer one GAIA question and return the normalized answer text."""
        # Deterministic path first: avoids any LLM round-trip when possible.
        shortcut = deterministic_attempt(question, attachment_path)
        if shortcut is not None:
            return normalize_answer(shortcut)

        # Without a token we cannot reach the inference API at all; return
        # a normalized error string rather than raising, so the caller's
        # scoring loop keeps going.
        if not self.hf_token:
            return normalize_answer(
                "Error: missing HF_TOKEN; cannot run LLM tools for this question."
            )

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": _build_user_payload(question, attachment_path, task_id),
            },
        ]

        client = self._get_client()
        reply = ""

        for _ in range(self.max_iterations):
            try:
                completion = client.chat_completion(
                    messages=messages,
                    model=self.text_model,
                    tools=TOOL_DEFINITIONS,
                    tool_choice="auto",
                    max_tokens=1024,
                    temperature=0.15,
                )
            except Exception as exc:  # boundary: surface API failure as answer text
                reply = f"Inference error: {exc}"
                break

            choice = completion.choices[0]
            message = choice.message
            reply = (message.content or "").strip()

            # Tool round: record the assistant turn plus each tool result,
            # then let the model see the outputs on the next iteration.
            if message.tool_calls:
                self._run_tool_round(messages, message)
                continue

            if reply:
                break
            if choice.finish_reason == "length":
                reply = "Error: model hit max length without an answer."
                break

        return normalize_answer(reply or "Error: empty response.")

    def _run_tool_round(self, messages: list[dict[str, Any]], message: Any) -> None:
        """Dispatch every tool call in *message* and append the transcript to *messages*."""
        messages.append(
            {
                "role": "assistant",
                "content": message.content if message.content else None,
                "tool_calls": [
                    {
                        "id": call.id,
                        "type": "function",
                        "function": {
                            "name": call.function.name,
                            "arguments": call.function.arguments,
                        },
                    }
                    for call in message.tool_calls
                ],
            }
        )
        for call in message.tool_calls:
            output = dispatch_tool(
                call.function.name,
                call.function.arguments or "{}",
                hf_token=self.hf_token,
            )
            # Cap tool output so a single result cannot blow the context window.
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "content": output[:24_000],
                }
            )


def _build_user_payload(
    question: str,
    attachment_path: Optional[str],
    task_id: Optional[str],
) -> str:
    parts = []
    if task_id:
        parts.append(f"task_id: {task_id}")
    parts.append(f"Question:\n{question.strip()}")
    if attachment_path:
        parts.append(f"\nAttachment path (use with tools): {attachment_path}")
    else:
        parts.append("\nNo attachment.")
    return "\n".join(parts)