"""GAIA Unit 4 agent: tool-calling loop via Hugging Face Inference API."""
from __future__ import annotations
import os
from typing import Any, Optional
from huggingface_hub import InferenceClient
from answer_normalize import normalize_answer
from inference_client_factory import inference_client_kwargs
from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool
# System prompt for the tool-calling loop. GAIA answers are scored by exact
# match, so the prompt forces the model's final message to be the bare answer
# string in the exact format the question demands. Runtime text: do not edit
# without re-validating answer formatting.
SYSTEM_PROMPT = """You solve GAIA benchmark questions for the Hugging Face Agents Course.
Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences.
- Match the question's format exactly (comma-separated, alphabetical order, IOC codes, algebraic notation, two-decimal USD, first name only, etc.).
- When a local attachment path is given, use the appropriate tool with that exact path.
- For English Wikipedia tasks, use wikipedia_* tools; cross-check with web_search if needed.
- For YouTube URLs in the question, try youtube_transcript first.
"""
class GaiaAgent:
    """Tool-calling agent for GAIA Unit 4 questions.

    Tries a deterministic solver first; otherwise runs a bounded
    chat-completion loop against the Hugging Face Inference API, letting the
    model invoke the tools declared in ``TOOL_DEFINITIONS`` until it emits a
    plain-text answer, which is then passed through ``normalize_answer``.
    """

    def __init__(
        self,
        *,
        hf_token: Optional[str] = None,
        text_model: Optional[str] = None,
        max_iterations: int = 14,
        max_tokens: int = 1024,
        temperature: float = 0.15,
    ):
        """Configure the agent.

        Args:
            hf_token: Hugging Face API token; falls back to the ``HF_TOKEN``
                then ``HUGGINGFACEHUB_API_TOKEN`` environment variables.
            text_model: Chat model id; falls back to the ``GAIA_TEXT_MODEL``
                environment variable, then a Qwen 2.5 7B default.
            max_iterations: Upper bound on model round-trips (tool-call
                rounds included) before giving up.
            max_tokens: Per-completion token budget (previously hard-coded).
            temperature: Sampling temperature (previously hard-coded).
        """
        self.hf_token = (
            hf_token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        self.text_model = text_model or os.environ.get(
            "GAIA_TEXT_MODEL", "Qwen/Qwen2.5-7B-Instruct"
        )
        self.max_iterations = max_iterations
        self.max_tokens = max_tokens
        self.temperature = temperature
        # Created lazily so instantiation needs no credentials or network.
        self._client: Optional[InferenceClient] = None

    def _get_client(self) -> InferenceClient:
        """Return a cached InferenceClient, creating it on first use.

        Raises:
            RuntimeError: If no API token is available.
        """
        if self._client is None:
            if not self.hf_token:
                raise RuntimeError(
                    "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required for GaiaAgent."
                )
            kw = inference_client_kwargs(self.hf_token)
            self._client = InferenceClient(**kw)
        return self._client

    def __call__(
        self,
        question: str,
        attachment_path: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """Answer one GAIA question and return the normalized answer text.

        Args:
            question: The GAIA question text.
            attachment_path: Optional local path to the question's attachment.
            task_id: Optional GAIA task identifier, included in the prompt.
        """
        # Cheap deterministic solvers first — no LLM round-trip needed.
        det = deterministic_attempt(question, attachment_path)
        if det is not None:
            return normalize_answer(det)
        if not self.hf_token:
            return normalize_answer(
                "Error: missing HF_TOKEN; cannot run LLM tools for this question."
            )
        user_text = _build_user_payload(question, attachment_path, task_id)
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        client = self._get_client()
        last_text = ""
        for _ in range(self.max_iterations):
            try:
                completion = client.chat_completion(
                    messages=messages,
                    model=self.text_model,
                    tools=TOOL_DEFINITIONS,
                    tool_choice="auto",
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                )
            except Exception as e:  # surface API failures as the answer text
                last_text = f"Inference error: {e}"
                break
            choice = completion.choices[0]
            msg = choice.message
            last_text = (msg.content or "").strip()
            if msg.tool_calls:
                # Echo the assistant turn (including its tool calls) into the
                # transcript so the follow-up completion has full context.
                messages.append(
                    {
                        "role": "assistant",
                        "content": msg.content if msg.content else None,
                        "tool_calls": [
                            {
                                "id": tc.id,
                                "type": "function",
                                "function": {
                                    "name": tc.function.name,
                                    "arguments": tc.function.arguments,
                                },
                            }
                            for tc in msg.tool_calls
                        ],
                    }
                )
                for tc in msg.tool_calls:
                    name = tc.function.name
                    args = tc.function.arguments or "{}"
                    try:
                        result = dispatch_tool(name, args, hf_token=self.hf_token)
                    except Exception as e:
                        # A failing tool must not abort the whole question;
                        # report the error so the model can recover or retry.
                        result = f"Tool error ({name}): {e}"
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tc.id,
                            # Truncate to keep the transcript within context.
                            "content": result[:24_000],
                        }
                    )
                continue
            if last_text:
                break
            if choice.finish_reason == "length":
                last_text = "Error: model hit max length without an answer."
                break
        return normalize_answer(last_text or "Error: empty response.")
def _build_user_payload(
question: str,
attachment_path: Optional[str],
task_id: Optional[str],
) -> str:
parts = []
if task_id:
parts.append(f"task_id: {task_id}")
parts.append(f"Question:\n{question.strip()}")
if attachment_path:
parts.append(f"\nAttachment path (use with tools): {attachment_path}")
else:
parts.append("\nNo attachment.")
return "\n".join(parts)