"""GAIA Unit 4 agent: tool-calling loop via Groq, OpenAI, or Hugging Face Inference."""

from __future__ import annotations

import os
import time
from pathlib import Path
from typing import Any, Optional

from answer_normalize import normalize_answer
from inference_client_factory import inference_client_kwargs
from llm_backends import (
    chat_complete_openai,
    detect_llm_backend,
    groq_chat_model,
    hf_chat_model,
    make_openai_sdk_client,
    openai_chat_model,
)
from tools.media_tools import transcribe_audio
from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool

try:
    from huggingface_hub import InferenceClient
except ImportError:
    # huggingface_hub is optional; only required for the "hf" backend.
    InferenceClient = None  # type: ignore


SYSTEM_PROMPT = """You solve GAIA benchmark questions for the Hugging Face Agents Course.

Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences, no preamble.
- Never type fake tool calls such as ...; the platform invokes tools for you. If you need search, emit a real tool call via the API, not XML-like text in the reply.
- When the user message includes an attachment path: for audio, a transcript may already be inlined — use it. For images (png/jpg), call analyze_image with that exact file_path. For .xlsx/.py use the appropriate tools with that path.
- Match the question's format exactly: comma-separated lists alphabetized when asked; numbers without commas/thousands separators and without $ or % unless the question asks; short strings without leading articles (a/the); city names spelled out as requested; algebraic chess notation when asked.
- If the question asks for a number (how many, highest number, etc.), reply with digits only — no words, no "Based on the video", no trailing period.
- If the question asks what someone said in a video, reply with the spoken line only (include punctuation as in the source), not "Character says …" and not the question text repeated.
- For English Wikipedia tasks, use wikipedia_* tools; for promotion dates, Featured Article logs, or table rows, use wikipedia_wikitext on the relevant page and read the wikitext.
- For YouTube URLs, use youtube_transcript first; if it fails, use web_search with the video title or URL before stopping.
- Never write meta-commentary in the final message (no "I cannot", "unfortunately", "the provided summary does not"). Keep calling tools until you have the fact.
- Never paste tool traces in the final message (no lines like wikipedia_search: or fetch_url:).
- Do not invent facts when tools return empty or ambiguous results.
"""


def _tool_char_cap(backend: str, *, shrink_pass: int = 0) -> int:
    """Max characters allowed in a single tool-result message for *backend*.

    A ``shrink_pass`` > 0 halves the cap once per pass (floor 280 chars) so
    a request rejected for size can be retried with smaller tool payloads.
    """
    if backend == "groq":
        # Free-tier Groq often rejects ~6k TPM per request; keep tool payloads small.
        base = int(os.environ.get("GAIA_GROQ_MAX_TOOL_CHARS", "1400"))
    elif backend == "openai":
        base = int(os.environ.get("GAIA_OPENAI_MAX_TOOL_CHARS", "12000"))
    else:
        base = int(os.environ.get("GAIA_MAX_TOOL_CHARS", "24000"))
    if shrink_pass > 0:
        base = max(280, base // (2**shrink_pass))
    return base


def _groq_context_budget() -> int:
    """Total character budget for the whole conversation on the Groq backend."""
    return int(os.environ.get("GAIA_GROQ_CONTEXT_CHARS", "12000"))


def _maybe_retryable_llm_error(exc: Exception) -> bool:
    """True when *exc* looks like a rate-limit / oversized-request error worth retrying."""
    es = str(exc).lower()
    return (
        "413" in es
        or "429" in es
        or "rate_limit" in es
        or "tokens per minute" in es
        or "tpm" in es
        or "too many tokens" in es
    )


def _truncate_tool_messages(
    messages: list[dict[str, Any]],
    backend: str,
    *,
    shrink_pass: int = 0,
) -> None:
    """Clip every tool-role message in *messages* (in place) to the backend's cap."""
    cap = _tool_char_cap(backend, shrink_pass=shrink_pass)
    for m in messages:
        if m.get("role") != "tool":
            continue
        c = m.get("content")
        if isinstance(c, str) and len(c) > cap:
            m["content"] = c[:cap] + "\n[truncated]"


def _groq_message_chars(m: dict[str, Any]) -> int:
    """Rough character size of one chat message (content plus serialized tool calls)."""
    n = len(str(m.get("content") or ""))
    tc = m.get("tool_calls")
    if tc:
        n += len(str(tc))
    return n


def _drop_oldest_tool_round(messages: list[dict[str, Any]]) -> bool:
    """Remove the earliest assistant+tool_calls block and its tool replies.

    Indices 0/1 (system + user) are never touched. Returns True when a
    round was removed, False when none remained.
    """
    i = 2
    while i < len(messages):
        if messages[i].get("role") == "assistant" and messages[i].get("tool_calls"):
            del messages[i]
            # After the delete, the round's tool replies have shifted into slot i.
            while i < len(messages) and messages[i].get("role") == "tool":
                del messages[i]
            return True
        i += 1
    return False


def _enforce_context_budget(messages: list[dict[str, Any]], backend: str) -> None:
    """Shrink *messages* in place until the Groq character budget is met.

    No-op for other backends. Strategy: drop whole old tool rounds first;
    when none remain, trim the earliest long tool reply to ~2/3 of its size.
    """
    if backend != "groq":
        return
    budget = _groq_context_budget()
    for _ in range(40):  # hard bound so pathological inputs cannot loop forever
        total = sum(_groq_message_chars(m) for m in messages)
        if total <= budget:
            return
        if _drop_oldest_tool_round(messages):
            continue
        trimmed = False
        for m in messages[2:]:
            if m.get("role") != "tool":
                continue
            c = m.get("content")
            if isinstance(c, str) and len(c) > 400:
                m["content"] = c[: max(400, len(c) * 2 // 3)] + "\n[truncated]"
                trimmed = True
                break
        if not trimmed:
            break


class GaiaAgent:
    """Minimal tool-calling agent for GAIA questions.

    Selects a backend (groq / openai / hf) at construction time, then loops
    chat-completion -> tool dispatch until the model emits a final answer.
    """

    def __init__(
        self,
        *,
        hf_token: Optional[str] = None,
        text_model: Optional[str] = None,
        max_iterations: int = 12,
    ):
        self.hf_token = (
            hf_token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        self.backend = detect_llm_backend()
        # Exactly one of these two clients is populated, depending on backend.
        self._oa_client = None
        self._hf_client: Optional[InferenceClient] = None
        if self.backend == "groq":
            self.text_model = text_model or groq_chat_model()
            self._oa_client, _ = make_openai_sdk_client("groq")
        elif self.backend == "openai":
            self.text_model = text_model or openai_chat_model()
            self._oa_client, _ = make_openai_sdk_client("openai")
        else:
            # "hf" backend: the InferenceClient is created lazily on first use.
            self.text_model = text_model or hf_chat_model()
        self.max_iterations = max_iterations

    def _get_hf_client(self) -> InferenceClient:
        """Lazily build (and cache) the Hugging Face InferenceClient.

        Raises RuntimeError when huggingface_hub is missing or no token is set.
        """
        if InferenceClient is None:
            raise RuntimeError("huggingface_hub is not installed.")
        if self._hf_client is None:
            if not self.hf_token:
                raise RuntimeError(
                    "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required when using "
                    "Hugging Face Inference (no GROQ_API_KEY / OPENAI_API_KEY set)."
                )
            kw = inference_client_kwargs(self.hf_token)
            self._hf_client = InferenceClient(**kw)
        return self._hf_client

    def _chat_round(
        self,
        messages: list[dict[str, Any]],
        *,
        shrink_pass: int = 0,
    ) -> Any:
        """Run one chat completion, trimming *messages* in place first."""
        _truncate_tool_messages(messages, self.backend, shrink_pass=shrink_pass)
        _enforce_context_budget(messages, self.backend)
        if self.backend in ("groq", "openai"):
            assert self._oa_client is not None
            mt = (
                int(os.environ.get("GAIA_GROQ_MAX_TOKENS", "384"))
                if self.backend == "groq"
                else int(os.environ.get("GAIA_OPENAI_MAX_TOKENS", "768"))
            )
            return chat_complete_openai(
                self._oa_client,
                model=self.text_model,
                messages=messages,
                tools=TOOL_DEFINITIONS,
                max_tokens=mt,
                temperature=0.15,
            )
        client = self._get_hf_client()
        return client.chat_completion(
            messages=messages,
            model=self.text_model,
            tools=TOOL_DEFINITIONS,
            tool_choice="auto",
            max_tokens=1024,
            temperature=0.15,
        )

    def __call__(
        self,
        question: str,
        attachment_path: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """Answer *question*, optionally using a local attachment file.

        Returns the normalized answer string (possibly empty when inference
        is unavailable); inference errors are converted to strings, never raised.
        """
        # Cheap deterministic shortcuts first — no LLM call needed.
        det = deterministic_attempt(question, attachment_path, task_id=task_id)
        if det is not None:
            return normalize_answer(det)
        if self.backend == "hf" and not self.hf_token:
            return normalize_answer("", context_question=question)

        user_text = _build_user_payload(question, attachment_path, task_id)
        user_text += _maybe_inline_audio_transcript(
            attachment_path, self.hf_token, backend=self.backend
        )
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        last_text = ""
        # Extra delays so Groq free-tier TPM / oversized-request errors can retry after shrink.
        retry_delays = (2.0, 4.0, 8.0, 14.0, 22.0)
        for _ in range(self.max_iterations):
            completion = None
            shrink_pass = 0
            for attempt in range(len(retry_delays) + 1):
                try:
                    completion = self._chat_round(messages, shrink_pass=shrink_pass)
                    break
                except Exception as e:
                    es = str(e)
                    esl = es.lower()
                    # Out-of-credit errors are terminal. Do not submit error
                    # prose as an answer (exact-match grading). This single
                    # case-insensitive check replaces the previous duplicated
                    # pre/post-retry 402 checks.
                    if "402" in es or "payment required" in esl or "depleted" in esl:
                        return normalize_answer("", context_question=question)
                    if attempt < len(retry_delays) and _maybe_retryable_llm_error(e):
                        # Shrink tool payloads harder on every retry.
                        shrink_pass = attempt + 1
                        time.sleep(retry_delays[attempt])
                        continue
                    if _maybe_retryable_llm_error(e):
                        # Retries exhausted; an empty answer beats error prose.
                        return normalize_answer("", context_question=question)
                    return normalize_answer(
                        f"Inference error: {e}", context_question=question
                    )
            msg = completion.choices[0].message
            last_text = (msg.content or "").strip()
            tool_calls = getattr(msg, "tool_calls", None)
            if tool_calls:
                # Echo the assistant's tool calls back, then run each tool and
                # append its (capped) result so the next round can use it.
                cap = _tool_char_cap(self.backend)
                messages.append(
                    {
                        "role": "assistant",
                        "content": msg.content if msg.content else None,
                        "tool_calls": [
                            {
                                "id": tc.id,
                                "type": "function",
                                "function": {
                                    "name": tc.function.name,
                                    "arguments": tc.function.arguments or "{}",
                                },
                            }
                            for tc in tool_calls
                        ],
                    }
                )
                for tc in tool_calls:
                    name = tc.function.name
                    args = tc.function.arguments or "{}"
                    result = dispatch_tool(name, args, hf_token=self.hf_token)
                    if isinstance(result, str) and len(result) > cap:
                        result = result[:cap] + "\n[truncated]"
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tc.id,
                            "content": result,
                        }
                    )
                continue
            if last_text:
                break
            fr = getattr(completion.choices[0], "finish_reason", None)
            if fr == "length":
                # NOTE(review): this error prose is passed to normalize_answer;
                # presumably it strips/empties it before submission — confirm.
                last_text = "Error: model hit max length without an answer."
                break
        return normalize_answer(last_text or "", context_question=question)


def _build_user_payload(
    question: str,
    attachment_path: Optional[str],
    task_id: Optional[str],
) -> str:
    """Compose the user message: task id, question text, and attachment status."""
    parts = []
    if task_id:
        parts.append(f"task_id: {task_id}")
    parts.append(f"Question:\n{question.strip()}")
    if attachment_path:
        p = Path(attachment_path)
        parts.append(
            f"\nAttachment path (pass this exact string to tools): {attachment_path}"
        )
        if p.is_file():
            parts.append(f"Attachment exists on disk: yes ({p.name})")
        else:
            parts.append("Attachment exists on disk: NO — report that you cannot read it.")
    else:
        parts.append("\nNo attachment.")
    return "\n".join(parts)


def _maybe_inline_audio_transcript(
    attachment_path: Optional[str],
    hf_token: Optional[str],
    *,
    backend: str = "hf",
) -> str:
    """Transcribe an audio attachment and return inline prompt text.

    Returns "" when there is no audio attachment; on ASR failure a short
    failure note is inlined so the model knows transcription was attempted.
    """
    if not attachment_path:
        return ""
    p = Path(attachment_path)
    if not p.is_file():
        return ""
    ext = p.suffix.lower()
    if ext not in (".mp3", ".wav", ".m4a", ".ogg", ".flac", ".webm"):
        return ""
    tx = transcribe_audio(str(p), hf_token=hf_token)
    if not tx or tx.lower().startswith(("error", "asr error")):
        # (tx or "") guards the slice: the `not tx` branch also admits None.
        return f"\n\n[Automatic transcription failed: {(tx or '')[:500]}]\n"
    cap = int(os.environ.get("GAIA_AUTO_TRANSCRIPT_CHARS", "8000"))
    if backend == "groq":
        # Groq's small context budget: keep auto-inlined transcripts shorter.
        cap = min(
            cap,
            int(os.environ.get("GAIA_GROQ_AUTO_TRANSCRIPT_CHARS", "3600")),
        )
    return f"\n\n[Audio transcript — use for your answer]\n{tx[:cap]}\n"