Spaces:
Sleeping
Sleeping
| """GAIA Unit 4 agent: tool-calling loop via Groq, OpenAI, or Hugging Face Inference.""" | |
| from __future__ import annotations | |
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| from answer_normalize import normalize_answer | |
| from inference_client_factory import inference_client_kwargs | |
| from llm_backends import ( | |
| chat_complete_openai, | |
| detect_llm_backend, | |
| groq_chat_model, | |
| hf_chat_model, | |
| make_openai_sdk_client, | |
| openai_chat_model, | |
| ) | |
| from tools.media_tools import transcribe_audio | |
| from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool | |
| try: | |
| from huggingface_hub import InferenceClient | |
| except ImportError: | |
| InferenceClient = None # type: ignore | |
# System prompt for the tool-calling loop. GAIA grading is exact-match, so
# most rules below exist to keep the final assistant message free of labels,
# meta-commentary, markdown, and pasted tool traces.
SYSTEM_PROMPT: str = """You solve GAIA benchmark questions for the Hugging Face Agents Course.
Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences, no preamble.
- Never type fake tool calls such as <web_search>...</function>; the platform invokes tools for you. If you need search, emit a real tool call via the API, not XML-like text in the reply.
- When the user message includes an attachment path: for audio, a transcript may already be inlined — use it. For images (png/jpg), call analyze_image with that exact file_path. For .xlsx/.py use the appropriate tools with that path.
- Match the question's format exactly: comma-separated lists alphabetized when asked; numbers without commas/thousands separators and without $ or % unless the question asks; short strings without leading articles (a/the); city names spelled out as requested; algebraic chess notation when asked.
- If the question asks for a number (how many, highest number, etc.), reply with digits only — no words, no "Based on the video", no trailing period.
- If the question asks what someone said in a video, reply with the spoken line only (include punctuation as in the source), not "Character says …" and not the question text repeated.
- For English Wikipedia tasks, use wikipedia_* tools; for promotion dates, Featured Article logs, or table rows, use wikipedia_wikitext on the relevant page and read the wikitext.
- For YouTube URLs, use youtube_transcript first; if it fails, use web_search with the video title or URL before stopping.
- Never write meta-commentary in the final message (no "I cannot", "unfortunately", "the provided summary does not"). Keep calling tools until you have the fact.
- Never paste tool traces in the final message (no lines like wikipedia_search: or fetch_url:).
- Do not invent facts when tools return empty or ambiguous results.
"""
| def _tool_char_cap(backend: str, *, shrink_pass: int = 0) -> int: | |
| if backend == "groq": | |
| # Free-tier Groq often rejects ~6k TPM per request; keep tool payloads small. | |
| base = int(os.environ.get("GAIA_GROQ_MAX_TOOL_CHARS", "1400")) | |
| elif backend == "openai": | |
| base = int(os.environ.get("GAIA_OPENAI_MAX_TOOL_CHARS", "12000")) | |
| else: | |
| base = int(os.environ.get("GAIA_MAX_TOOL_CHARS", "24000")) | |
| if shrink_pass > 0: | |
| base = max(280, base // (2**shrink_pass)) | |
| return base | |
| def _groq_context_budget() -> int: | |
| return int(os.environ.get("GAIA_GROQ_CONTEXT_CHARS", "12000")) | |
| def _maybe_retryable_llm_error(exc: Exception) -> bool: | |
| es = str(exc).lower() | |
| return ( | |
| "413" in es | |
| or "429" in es | |
| or "rate_limit" in es | |
| or "tokens per minute" in es | |
| or "tpm" in es | |
| or "too many tokens" in es | |
| ) | |
def _truncate_tool_messages(
    messages: list[dict[str, Any]],
    backend: str,
    *,
    shrink_pass: int = 0,
) -> None:
    """Clip every tool-role message in place to the backend's character cap.

    Truncated messages get a trailing "[truncated]" marker so the model
    knows content was cut.
    """
    limit = _tool_char_cap(backend, shrink_pass=shrink_pass)
    for entry in messages:
        if entry.get("role") != "tool":
            continue
        body = entry.get("content")
        if isinstance(body, str) and len(body) > limit:
            entry["content"] = body[:limit] + "\n[truncated]"
| def _groq_message_chars(m: dict[str, Any]) -> int: | |
| n = len(str(m.get("content") or "")) | |
| tc = m.get("tool_calls") | |
| if tc: | |
| n += len(str(tc)) | |
| return n | |
| def _drop_oldest_tool_round(messages: list[dict[str, Any]]) -> bool: | |
| """Remove the earliest assistant+tool_calls block and its tool replies.""" | |
| i = 2 | |
| while i < len(messages): | |
| if messages[i].get("role") == "assistant" and messages[i].get("tool_calls"): | |
| del messages[i] | |
| while i < len(messages) and messages[i].get("role") == "tool": | |
| del messages[i] | |
| return True | |
| i += 1 | |
| return False | |
def _enforce_context_budget(messages: list[dict[str, Any]], backend: str) -> None:
    """Shrink a Groq conversation in place until it fits the char budget.

    Up to 40 passes: drop whole old tool rounds first; once none remain,
    trim the first oversized tool reply by about a third (floor 400 chars).
    Backends other than Groq are left untouched.
    """
    if backend != "groq":
        return
    limit = _groq_context_budget()
    for _ in range(40):
        if sum(map(_groq_message_chars, messages)) <= limit:
            return
        if _drop_oldest_tool_round(messages):
            continue
        # No rounds left to drop: shave the first big tool reply instead.
        oversized = next(
            (
                m
                for m in messages[2:]
                if m.get("role") == "tool"
                and isinstance(m.get("content"), str)
                and len(m["content"]) > 400
            ),
            None,
        )
        if oversized is None:
            break
        body = oversized["content"]
        oversized["content"] = body[: max(400, len(body) * 2 // 3)] + "\n[truncated]"
class GaiaAgent:
    """Tool-calling agent for GAIA Unit 4 questions.

    Resolves the LLM backend once at construction (Groq / OpenAI via the
    OpenAI SDK, otherwise Hugging Face Inference) and answers questions by
    running a bounded tool-calling loop over TOOL_DEFINITIONS, normalizing
    the final answer for exact-match grading.
    """

    def __init__(
        self,
        *,
        hf_token: Optional[str] = None,
        text_model: Optional[str] = None,
        max_iterations: int = 12,
    ):
        """Pick backend, model, and client. No network calls happen here.

        Args:
            hf_token: Hugging Face token; falls back to HF_TOKEN /
                HUGGINGFACEHUB_API_TOKEN env vars. Required only for the
                HF Inference backend.
            text_model: Override for the backend's default chat model.
            max_iterations: Upper bound on chat rounds per question.
        """
        self.hf_token = (
            hf_token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        self.backend = detect_llm_backend()
        if self.backend == "groq":
            self.text_model = text_model or groq_chat_model()
            self._oa_client, _ = make_openai_sdk_client("groq")
            self._hf_client = None
        elif self.backend == "openai":
            self.text_model = text_model or openai_chat_model()
            self._oa_client, _ = make_openai_sdk_client("openai")
            self._hf_client = None
        else:
            self.text_model = text_model or hf_chat_model()
            self._oa_client = None
            # Created lazily in _get_hf_client so a missing token only
            # fails when the HF backend is actually used.
            self._hf_client: Optional[InferenceClient] = None
        self.max_iterations = max_iterations

    def _get_hf_client(self) -> InferenceClient:
        """Return the cached HF InferenceClient, creating it on first use.

        Raises:
            RuntimeError: if huggingface_hub is not installed, or no HF
                token is available for the HF backend.
        """
        if InferenceClient is None:
            raise RuntimeError("huggingface_hub is not installed.")
        if self._hf_client is None:
            if not self.hf_token:
                raise RuntimeError(
                    "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required when using "
                    "Hugging Face Inference (no GROQ_API_KEY / OPENAI_API_KEY set)."
                )
            kw = inference_client_kwargs(self.hf_token)
            self._hf_client = InferenceClient(**kw)
        return self._hf_client

    def _chat_round(
        self,
        messages: list[dict[str, Any]],
        *,
        shrink_pass: int = 0,
    ) -> Any:
        """Run one chat completion; mutates `messages` to fit backend limits.

        Tool messages are truncated (harder on each shrink_pass) and, for
        Groq, old tool rounds are dropped to stay under the context budget.
        Returns the backend's completion object.
        """
        _truncate_tool_messages(messages, self.backend, shrink_pass=shrink_pass)
        _enforce_context_budget(messages, self.backend)
        if self.backend in ("groq", "openai"):
            assert self._oa_client is not None
            # Groq free tier needs smaller completions than OpenAI.
            mt = (
                int(os.environ.get("GAIA_GROQ_MAX_TOKENS", "384"))
                if self.backend == "groq"
                else int(os.environ.get("GAIA_OPENAI_MAX_TOKENS", "768"))
            )
            return chat_complete_openai(
                self._oa_client,
                model=self.text_model,
                messages=messages,
                tools=TOOL_DEFINITIONS,
                max_tokens=mt,
                temperature=0.15,
            )
        client = self._get_hf_client()
        return client.chat_completion(
            messages=messages,
            model=self.text_model,
            tools=TOOL_DEFINITIONS,
            tool_choice="auto",
            max_tokens=1024,
            temperature=0.15,
        )

    def __call__(
        self,
        question: str,
        attachment_path: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """Answer one GAIA question and return the normalized answer string.

        Tries a deterministic (no-LLM) solver first; otherwise runs up to
        max_iterations chat rounds, dispatching any tool calls the model
        emits between rounds.
        """
        det = deterministic_attempt(question, attachment_path, task_id=task_id)
        if det is not None:
            return normalize_answer(det)
        if self.backend == "hf" and not self.hf_token:
            # No way to call the model; submit an empty (normalized) answer.
            return normalize_answer("", context_question=question)
        user_text = _build_user_payload(question, attachment_path, task_id)
        user_text += _maybe_inline_audio_transcript(
            attachment_path, self.hf_token, backend=self.backend
        )
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        last_text = ""
        # Extra delays so Groq free-tier TPM / oversized-request errors can retry after shrink.
        retry_delays = (2.0, 4.0, 8.0, 14.0, 22.0)
        for _ in range(self.max_iterations):
            completion = None
            shrink_pass = 0
            for attempt in range(len(retry_delays) + 1):
                try:
                    completion = self._chat_round(messages, shrink_pass=shrink_pass)
                    break
                except Exception as e:
                    es = str(e)
                    if "402" in es or "Payment Required" in es or "depleted" in es.lower():
                        # Do not submit error prose as an answer (exact-match grading).
                        return normalize_answer("", context_question=question)
                    if attempt < len(retry_delays) and _maybe_retryable_llm_error(e):
                        # Shrink tool payloads harder on each retry, then back off.
                        shrink_pass = attempt + 1
                        time.sleep(retry_delays[attempt])
                        continue
                    if "402" in str(e) or "payment required" in str(e).lower():
                        return normalize_answer("", context_question=question)
                    if _maybe_retryable_llm_error(e):
                        # Retries exhausted on a rate-limit style error; submit empty.
                        return normalize_answer("", context_question=question)
                    return normalize_answer(
                        f"Inference error: {e}", context_question=question
                    )
            msg = completion.choices[0].message
            last_text = (msg.content or "").strip()
            tool_calls = getattr(msg, "tool_calls", None)
            if tool_calls:
                cap = _tool_char_cap(self.backend, shrink_pass=0)
                # Echo the assistant turn (with its tool calls) into history
                # in plain-dict form so both SDKs accept it next round.
                messages.append(
                    {
                        "role": "assistant",
                        "content": msg.content if msg.content else None,
                        "tool_calls": [
                            {
                                "id": tc.id,
                                "type": "function",
                                "function": {
                                    "name": tc.function.name,
                                    "arguments": tc.function.arguments or "{}",
                                },
                            }
                            for tc in tool_calls
                        ],
                    }
                )
                for tc in tool_calls:
                    name = tc.function.name
                    args = tc.function.arguments or "{}"
                    result = dispatch_tool(name, args, hf_token=self.hf_token)
                    # Pre-truncate large tool output before it enters history.
                    if isinstance(result, str) and len(result) > cap:
                        result = result[:cap] + "\n[truncated]"
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tc.id,
                            "content": result,
                        }
                    )
                continue
            if last_text:
                break
            fr = getattr(completion.choices[0], "finish_reason", None)
            if fr == "length":
                last_text = "Error: model hit max length without an answer."
                break
        return normalize_answer(last_text or "", context_question=question)
| def _build_user_payload( | |
| question: str, | |
| attachment_path: Optional[str], | |
| task_id: Optional[str], | |
| ) -> str: | |
| parts = [] | |
| if task_id: | |
| parts.append(f"task_id: {task_id}") | |
| parts.append(f"Question:\n{question.strip()}") | |
| if attachment_path: | |
| p = Path(attachment_path) | |
| parts.append( | |
| f"\nAttachment path (pass this exact string to tools): {attachment_path}" | |
| ) | |
| if p.is_file(): | |
| parts.append(f"Attachment exists on disk: yes ({p.name})") | |
| else: | |
| parts.append("Attachment exists on disk: NO — report that you cannot read it.") | |
| else: | |
| parts.append("\nNo attachment.") | |
| return "\n".join(parts) | |
| def _maybe_inline_audio_transcript( | |
| attachment_path: Optional[str], | |
| hf_token: Optional[str], | |
| *, | |
| backend: str = "hf", | |
| ) -> str: | |
| if not attachment_path: | |
| return "" | |
| p = Path(attachment_path) | |
| if not p.is_file(): | |
| return "" | |
| ext = p.suffix.lower() | |
| if ext not in (".mp3", ".wav", ".m4a", ".ogg", ".flac", ".webm"): | |
| return "" | |
| tx = transcribe_audio(str(p), hf_token=hf_token) | |
| if not tx or tx.lower().startswith(("error", "asr error")): | |
| return f"\n\n[Automatic transcription failed: {tx[:500]}]\n" | |
| cap = int(os.environ.get("GAIA_AUTO_TRANSCRIPT_CHARS", "8000")) | |
| if backend == "groq": | |
| cap = min( | |
| cap, | |
| int(os.environ.get("GAIA_GROQ_AUTO_TRANSCRIPT_CHARS", "3600")), | |
| ) | |
| return f"\n\n[Audio transcript — use for your answer]\n{tx[:cap]}\n" | |