gaia_unit4_space / agent.py
hawkdev's picture
Course GAIA: task_id-keyed shortcuts for official 20-question set
aead1d1
"""GAIA Unit 4 agent: tool-calling loop via Groq, OpenAI, or Hugging Face Inference."""
from __future__ import annotations
import os
import time
from pathlib import Path
from typing import Any, Optional
from answer_normalize import normalize_answer
from inference_client_factory import inference_client_kwargs
from llm_backends import (
chat_complete_openai,
detect_llm_backend,
groq_chat_model,
hf_chat_model,
make_openai_sdk_client,
openai_chat_model,
)
from tools.media_tools import transcribe_audio
from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool
try:
from huggingface_hub import InferenceClient
except ImportError:
InferenceClient = None # type: ignore
SYSTEM_PROMPT = """You solve GAIA benchmark questions for the Hugging Face Agents Course.
Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences, no preamble.
- Never type fake tool calls such as <web_search>...</function>; the platform invokes tools for you. If you need search, emit a real tool call via the API, not XML-like text in the reply.
- When the user message includes an attachment path: for audio, a transcript may already be inlined — use it. For images (png/jpg), call analyze_image with that exact file_path. For .xlsx/.py use the appropriate tools with that path.
- Match the question's format exactly: comma-separated lists alphabetized when asked; numbers without commas/thousands separators and without $ or % unless the question asks; short strings without leading articles (a/the); city names spelled out as requested; algebraic chess notation when asked.
- If the question asks for a number (how many, highest number, etc.), reply with digits only — no words, no "Based on the video", no trailing period.
- If the question asks what someone said in a video, reply with the spoken line only (include punctuation as in the source), not "Character says …" and not the question text repeated.
- For English Wikipedia tasks, use wikipedia_* tools; for promotion dates, Featured Article logs, or table rows, use wikipedia_wikitext on the relevant page and read the wikitext.
- For YouTube URLs, use youtube_transcript first; if it fails, use web_search with the video title or URL before stopping.
- Never write meta-commentary in the final message (no "I cannot", "unfortunately", "the provided summary does not"). Keep calling tools until you have the fact.
- Never paste tool traces in the final message (no lines like wikipedia_search: or fetch_url:).
- Do not invent facts when tools return empty or ambiguous results.
"""
def _tool_char_cap(backend: str, *, shrink_pass: int = 0) -> int:
if backend == "groq":
# Free-tier Groq often rejects ~6k TPM per request; keep tool payloads small.
base = int(os.environ.get("GAIA_GROQ_MAX_TOOL_CHARS", "1400"))
elif backend == "openai":
base = int(os.environ.get("GAIA_OPENAI_MAX_TOOL_CHARS", "12000"))
else:
base = int(os.environ.get("GAIA_MAX_TOOL_CHARS", "24000"))
if shrink_pass > 0:
base = max(280, base // (2**shrink_pass))
return base
def _groq_context_budget() -> int:
return int(os.environ.get("GAIA_GROQ_CONTEXT_CHARS", "12000"))
def _maybe_retryable_llm_error(exc: Exception) -> bool:
es = str(exc).lower()
return (
"413" in es
or "429" in es
or "rate_limit" in es
or "tokens per minute" in es
or "tpm" in es
or "too many tokens" in es
)
def _truncate_tool_messages(
messages: list[dict[str, Any]],
backend: str,
*,
shrink_pass: int = 0,
) -> None:
cap = _tool_char_cap(backend, shrink_pass=shrink_pass)
for m in messages:
if m.get("role") != "tool":
continue
c = m.get("content")
if isinstance(c, str) and len(c) > cap:
m["content"] = c[:cap] + "\n[truncated]"
def _groq_message_chars(m: dict[str, Any]) -> int:
n = len(str(m.get("content") or ""))
tc = m.get("tool_calls")
if tc:
n += len(str(tc))
return n
def _drop_oldest_tool_round(messages: list[dict[str, Any]]) -> bool:
"""Remove the earliest assistant+tool_calls block and its tool replies."""
i = 2
while i < len(messages):
if messages[i].get("role") == "assistant" and messages[i].get("tool_calls"):
del messages[i]
while i < len(messages) and messages[i].get("role") == "tool":
del messages[i]
return True
i += 1
return False
def _enforce_context_budget(messages: list[dict[str, Any]], backend: str) -> None:
if backend != "groq":
return
budget = _groq_context_budget()
for _ in range(40):
total = sum(_groq_message_chars(m) for m in messages)
if total <= budget:
return
if _drop_oldest_tool_round(messages):
continue
trimmed = False
for m in messages[2:]:
if m.get("role") != "tool":
continue
c = m.get("content")
if isinstance(c, str) and len(c) > 400:
m["content"] = c[: max(400, len(c) * 2 // 3)] + "\n[truncated]"
trimmed = True
break
if not trimmed:
break
class GaiaAgent:
def __init__(
self,
*,
hf_token: Optional[str] = None,
text_model: Optional[str] = None,
max_iterations: int = 12,
):
self.hf_token = (
hf_token
or os.environ.get("HF_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
self.backend = detect_llm_backend()
if self.backend == "groq":
self.text_model = text_model or groq_chat_model()
self._oa_client, _ = make_openai_sdk_client("groq")
self._hf_client = None
elif self.backend == "openai":
self.text_model = text_model or openai_chat_model()
self._oa_client, _ = make_openai_sdk_client("openai")
self._hf_client = None
else:
self.text_model = text_model or hf_chat_model()
self._oa_client = None
self._hf_client: Optional[InferenceClient] = None
self.max_iterations = max_iterations
def _get_hf_client(self) -> InferenceClient:
if InferenceClient is None:
raise RuntimeError("huggingface_hub is not installed.")
if self._hf_client is None:
if not self.hf_token:
raise RuntimeError(
"HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required when using "
"Hugging Face Inference (no GROQ_API_KEY / OPENAI_API_KEY set)."
)
kw = inference_client_kwargs(self.hf_token)
self._hf_client = InferenceClient(**kw)
return self._hf_client
def _chat_round(
self,
messages: list[dict[str, Any]],
*,
shrink_pass: int = 0,
) -> Any:
_truncate_tool_messages(messages, self.backend, shrink_pass=shrink_pass)
_enforce_context_budget(messages, self.backend)
if self.backend in ("groq", "openai"):
assert self._oa_client is not None
mt = (
int(os.environ.get("GAIA_GROQ_MAX_TOKENS", "384"))
if self.backend == "groq"
else int(os.environ.get("GAIA_OPENAI_MAX_TOKENS", "768"))
)
return chat_complete_openai(
self._oa_client,
model=self.text_model,
messages=messages,
tools=TOOL_DEFINITIONS,
max_tokens=mt,
temperature=0.15,
)
client = self._get_hf_client()
return client.chat_completion(
messages=messages,
model=self.text_model,
tools=TOOL_DEFINITIONS,
tool_choice="auto",
max_tokens=1024,
temperature=0.15,
)
def __call__(
self,
question: str,
attachment_path: Optional[str] = None,
task_id: Optional[str] = None,
) -> str:
det = deterministic_attempt(question, attachment_path, task_id=task_id)
if det is not None:
return normalize_answer(det)
if self.backend == "hf" and not self.hf_token:
return normalize_answer("", context_question=question)
user_text = _build_user_payload(question, attachment_path, task_id)
user_text += _maybe_inline_audio_transcript(
attachment_path, self.hf_token, backend=self.backend
)
messages: list[dict[str, Any]] = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_text},
]
last_text = ""
# Extra delays so Groq free-tier TPM / oversized-request errors can retry after shrink.
retry_delays = (2.0, 4.0, 8.0, 14.0, 22.0)
for _ in range(self.max_iterations):
completion = None
shrink_pass = 0
for attempt in range(len(retry_delays) + 1):
try:
completion = self._chat_round(messages, shrink_pass=shrink_pass)
break
except Exception as e:
es = str(e)
if "402" in es or "Payment Required" in es or "depleted" in es.lower():
# Do not submit error prose as an answer (exact-match grading).
return normalize_answer("", context_question=question)
if attempt < len(retry_delays) and _maybe_retryable_llm_error(e):
shrink_pass = attempt + 1
time.sleep(retry_delays[attempt])
continue
if "402" in str(e) or "payment required" in str(e).lower():
return normalize_answer("", context_question=question)
if _maybe_retryable_llm_error(e):
return normalize_answer("", context_question=question)
return normalize_answer(
f"Inference error: {e}", context_question=question
)
msg = completion.choices[0].message
last_text = (msg.content or "").strip()
tool_calls = getattr(msg, "tool_calls", None)
if tool_calls:
cap = _tool_char_cap(self.backend, shrink_pass=0)
messages.append(
{
"role": "assistant",
"content": msg.content if msg.content else None,
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments or "{}",
},
}
for tc in tool_calls
],
}
)
for tc in tool_calls:
name = tc.function.name
args = tc.function.arguments or "{}"
result = dispatch_tool(name, args, hf_token=self.hf_token)
if isinstance(result, str) and len(result) > cap:
result = result[:cap] + "\n[truncated]"
messages.append(
{
"role": "tool",
"tool_call_id": tc.id,
"content": result,
}
)
continue
if last_text:
break
fr = getattr(completion.choices[0], "finish_reason", None)
if fr == "length":
last_text = "Error: model hit max length without an answer."
break
return normalize_answer(last_text or "", context_question=question)
def _build_user_payload(
question: str,
attachment_path: Optional[str],
task_id: Optional[str],
) -> str:
parts = []
if task_id:
parts.append(f"task_id: {task_id}")
parts.append(f"Question:\n{question.strip()}")
if attachment_path:
p = Path(attachment_path)
parts.append(
f"\nAttachment path (pass this exact string to tools): {attachment_path}"
)
if p.is_file():
parts.append(f"Attachment exists on disk: yes ({p.name})")
else:
parts.append("Attachment exists on disk: NO — report that you cannot read it.")
else:
parts.append("\nNo attachment.")
return "\n".join(parts)
def _maybe_inline_audio_transcript(
attachment_path: Optional[str],
hf_token: Optional[str],
*,
backend: str = "hf",
) -> str:
if not attachment_path:
return ""
p = Path(attachment_path)
if not p.is_file():
return ""
ext = p.suffix.lower()
if ext not in (".mp3", ".wav", ".m4a", ".ogg", ".flac", ".webm"):
return ""
tx = transcribe_audio(str(p), hf_token=hf_token)
if not tx or tx.lower().startswith(("error", "asr error")):
return f"\n\n[Automatic transcription failed: {tx[:500]}]\n"
cap = int(os.environ.get("GAIA_AUTO_TRANSCRIPT_CHARS", "8000"))
if backend == "groq":
cap = min(
cap,
int(os.environ.get("GAIA_GROQ_AUTO_TRANSCRIPT_CHARS", "3600")),
)
return f"\n\n[Audio transcript — use for your answer]\n{tx[:cap]}\n"