"""GAIA Unit 4 agent: tool-calling loop via Hugging Face Inference API.""" from __future__ import annotations import os from typing import Any, Optional from huggingface_hub import InferenceClient from answer_normalize import normalize_answer from inference_client_factory import inference_client_kwargs from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool SYSTEM_PROMPT = """You solve GAIA benchmark questions for the Hugging Face Agents Course. Hard rules: - Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel). - Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences. - Match the question's format exactly (comma-separated, alphabetical order, IOC codes, algebraic notation, two-decimal USD, first name only, etc.). - When a local attachment path is given, use the appropriate tool with that exact path. - For English Wikipedia tasks, use wikipedia_* tools; cross-check with web_search if needed. - For YouTube URLs in the question, try youtube_transcript first. """ class GaiaAgent: def __init__( self, *, hf_token: Optional[str] = None, text_model: Optional[str] = None, max_iterations: int = 14, ): self.hf_token = ( hf_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN") ) self.text_model = text_model or os.environ.get( "GAIA_TEXT_MODEL", "Qwen/Qwen2.5-7B-Instruct" ) self.max_iterations = max_iterations self._client: Optional[InferenceClient] = None def _get_client(self) -> InferenceClient: if self._client is None: if not self.hf_token: raise RuntimeError( "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required for GaiaAgent." ) kw = inference_client_kwargs(self.hf_token) self._client = InferenceClient(**kw) return self._client def __call__( self, question: str, attachment_path: Optional[str] = None, task_id: Optional[str] = None, ) -> str: det = deterministic_attempt(question, attachment_path) if det is not None: return normalize_answer(det) if not self.hf_token: return normalize_answer( "Error: missing HF_TOKEN; cannot run LLM tools for this question." ) user_text = _build_user_payload(question, attachment_path, task_id) messages: list[dict[str, Any]] = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_text}, ] client = self._get_client() last_text = "" for _ in range(self.max_iterations): try: completion = client.chat_completion( messages=messages, model=self.text_model, tools=TOOL_DEFINITIONS, tool_choice="auto", max_tokens=1024, temperature=0.15, ) except Exception as e: last_text = f"Inference error: {e}" break choice = completion.choices[0] msg = choice.message last_text = (msg.content or "").strip() if msg.tool_calls: messages.append( { "role": "assistant", "content": msg.content if msg.content else None, "tool_calls": [ { "id": tc.id, "type": "function", "function": { "name": tc.function.name, "arguments": tc.function.arguments, }, } for tc in msg.tool_calls ], } ) for tc in msg.tool_calls: name = tc.function.name args = tc.function.arguments or "{}" result = dispatch_tool(name, args, hf_token=self.hf_token) messages.append( { "role": "tool", "tool_call_id": tc.id, "content": result[:24_000], } ) continue if last_text: break if choice.finish_reason == "length": last_text = "Error: model hit max length without an answer." break return normalize_answer(last_text or "Error: empty response.") def _build_user_payload( question: str, attachment_path: Optional[str], task_id: Optional[str], ) -> str: parts = [] if task_id: parts.append(f"task_id: {task_id}") parts.append(f"Question:\n{question.strip()}") if attachment_path: parts.append(f"\nAttachment path (use with tools): {attachment_path}") else: parts.append("\nNo attachment.") return "\n".join(parts)