Spaces:
Sleeping
Sleeping
File size: 13,907 Bytes
f11f984 ac299d5 772f123 ac299d5 f11f984 772f123 ac299d5 f11f984 ac299d5 a46999a 772f123 a46999a 9428cf6 a46999a ac299d5 772f123 2bf50d9 772f123 2bf50d9 772f123 2bf50d9 772f123 2bf50d9 772f123 2bf50d9 772f123 2bf50d9 772f123 2bf50d9 772f123 ac299d5 772f123 ac299d5 f11f984 ac299d5 f11f984 ac299d5 f11f984 ac299d5 f11f984 ac299d5 772f123 2bf50d9 772f123 2bf50d9 772f123 ac299d5 aead1d1 ac299d5 f11f984 9428cf6 ac299d5 2bf50d9 772f123 ac299d5 2bf50d9 ac299d5 772f123 9428cf6 772f123 9428cf6 2bf50d9 9428cf6 ac299d5 772f123 ac299d5 f11f984 ac299d5 f11f984 772f123 ac299d5 f11f984 ac299d5 f11f984 ac299d5 f11f984 ac299d5 772f123 ac299d5 772f123 ac299d5 f11f984 ac299d5 9428cf6 ac299d5 772f123 ac299d5 772f123 2bf50d9 772f123 2bf50d9 772f123 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 
343 344 345 346 347 348 349 350 351 352 353 354 355 356 | """GAIA Unit 4 agent: tool-calling loop via Groq, OpenAI, or Hugging Face Inference."""
from __future__ import annotations
import os
import time
from pathlib import Path
from typing import Any, Optional
from answer_normalize import normalize_answer
from inference_client_factory import inference_client_kwargs
from llm_backends import (
chat_complete_openai,
detect_llm_backend,
groq_chat_model,
hf_chat_model,
make_openai_sdk_client,
openai_chat_model,
)
from tools.media_tools import transcribe_audio
from tools.registry import TOOL_DEFINITIONS, deterministic_attempt, dispatch_tool
try:
from huggingface_hub import InferenceClient
except ImportError:
InferenceClient = None # type: ignore
# System prompt sent with every question. The answer-formatting rules mirror
# GAIA's exact-match grading (see normalize_answer); the anti-meta-commentary
# rules keep error prose out of the final submitted answer.
SYSTEM_PROMPT = """You solve GAIA benchmark questions for the Hugging Face Agents Course.
Hard rules:
- Call tools as needed (search, Wikipedia, fetch URL, Python, audio, image, Excel).
- Your final assistant message must contain ONLY the answer text required by the question — no labels like "FINAL ANSWER", no markdown fences, no extra sentences, no preamble.
- Never type fake tool calls such as <web_search>...</function>; the platform invokes tools for you. If you need search, emit a real tool call via the API, not XML-like text in the reply.
- When the user message includes an attachment path: for audio, a transcript may already be inlined — use it. For images (png/jpg), call analyze_image with that exact file_path. For .xlsx/.py use the appropriate tools with that path.
- Match the question's format exactly: comma-separated lists alphabetized when asked; numbers without commas/thousands separators and without $ or % unless the question asks; short strings without leading articles (a/the); city names spelled out as requested; algebraic chess notation when asked.
- If the question asks for a number (how many, highest number, etc.), reply with digits only — no words, no "Based on the video", no trailing period.
- If the question asks what someone said in a video, reply with the spoken line only (include punctuation as in the source), not "Character says …" and not the question text repeated.
- For English Wikipedia tasks, use wikipedia_* tools; for promotion dates, Featured Article logs, or table rows, use wikipedia_wikitext on the relevant page and read the wikitext.
- For YouTube URLs, use youtube_transcript first; if it fails, use web_search with the video title or URL before stopping.
- Never write meta-commentary in the final message (no "I cannot", "unfortunately", "the provided summary does not"). Keep calling tools until you have the fact.
- Never paste tool traces in the final message (no lines like wikipedia_search: or fetch_url:).
- Do not invent facts when tools return empty or ambiguous results.
"""
def _tool_char_cap(backend: str, *, shrink_pass: int = 0) -> int:
if backend == "groq":
# Free-tier Groq often rejects ~6k TPM per request; keep tool payloads small.
base = int(os.environ.get("GAIA_GROQ_MAX_TOOL_CHARS", "1400"))
elif backend == "openai":
base = int(os.environ.get("GAIA_OPENAI_MAX_TOOL_CHARS", "12000"))
else:
base = int(os.environ.get("GAIA_MAX_TOOL_CHARS", "24000"))
if shrink_pass > 0:
base = max(280, base // (2**shrink_pass))
return base
def _groq_context_budget() -> int:
return int(os.environ.get("GAIA_GROQ_CONTEXT_CHARS", "12000"))
def _maybe_retryable_llm_error(exc: Exception) -> bool:
es = str(exc).lower()
return (
"413" in es
or "429" in es
or "rate_limit" in es
or "tokens per minute" in es
or "tpm" in es
or "too many tokens" in es
)
def _truncate_tool_messages(
    messages: list[dict[str, Any]],
    backend: str,
    *,
    shrink_pass: int = 0,
) -> None:
    """Clip oversized tool-role message contents in place.

    The character cap comes from _tool_char_cap for *backend* and
    *shrink_pass*; clipped payloads get a trailing "[truncated]" marker.
    """
    limit = _tool_char_cap(backend, shrink_pass=shrink_pass)
    tool_replies = (m for m in messages if m.get("role") == "tool")
    for reply in tool_replies:
        content = reply.get("content")
        if isinstance(content, str) and len(content) > limit:
            reply["content"] = content[:limit] + "\n[truncated]"
def _groq_message_chars(m: dict[str, Any]) -> int:
n = len(str(m.get("content") or ""))
tc = m.get("tool_calls")
if tc:
n += len(str(tc))
return n
def _drop_oldest_tool_round(messages: list[dict[str, Any]]) -> bool:
"""Remove the earliest assistant+tool_calls block and its tool replies."""
i = 2
while i < len(messages):
if messages[i].get("role") == "assistant" and messages[i].get("tool_calls"):
del messages[i]
while i < len(messages) and messages[i].get("role") == "tool":
del messages[i]
return True
i += 1
return False
def _enforce_context_budget(messages: list[dict[str, Any]], backend: str) -> None:
    """Shrink *messages* in place until they fit the Groq character budget.

    No-op for non-Groq backends. Strategy per pass: first drop whole old
    assistant+tool rounds, then trim the first long tool payload by about
    a third; gives up after a bounded number of passes.
    """
    if backend != "groq":
        return
    budget = _groq_context_budget()
    for _ in range(40):
        if sum(_groq_message_chars(m) for m in messages) <= budget:
            return
        if _drop_oldest_tool_round(messages):
            continue
        # No whole rounds left to drop: trim one long tool payload instead.
        trimmed = False
        for msg in messages[2:]:
            content = msg.get("content")
            if (
                msg.get("role") == "tool"
                and isinstance(content, str)
                and len(content) > 400
            ):
                keep = max(400, len(content) * 2 // 3)
                msg["content"] = content[:keep] + "\n[truncated]"
                trimmed = True
                break
        if not trimmed:
            break
class GaiaAgent:
    """Tool-calling agent for GAIA Unit 4 questions.

    Picks an LLM backend (Groq / OpenAI / Hugging Face Inference) at
    construction time and answers each question with a bounded
    assistant/tool loop, returning a normalized answer string.
    """

    def __init__(
        self,
        *,
        hf_token: Optional[str] = None,
        text_model: Optional[str] = None,
        max_iterations: int = 12,
    ):
        """Configure backend, model id, and clients.

        Args:
            hf_token: Hugging Face token; falls back to the HF_TOKEN /
                HUGGINGFACEHUB_API_TOKEN environment variables.
            text_model: Chat model id; defaults to the backend's default model.
            max_iterations: Upper bound on assistant/tool rounds per question.
        """
        self.hf_token = (
            hf_token
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        )
        self.backend = detect_llm_backend()
        if self.backend == "groq":
            self.text_model = text_model or groq_chat_model()
            self._oa_client, _ = make_openai_sdk_client("groq")
            self._hf_client = None
        elif self.backend == "openai":
            self.text_model = text_model or openai_chat_model()
            self._oa_client, _ = make_openai_sdk_client("openai")
            self._hf_client = None
        else:
            # HF backend: the InferenceClient is created lazily in _get_hf_client.
            self.text_model = text_model or hf_chat_model()
            self._oa_client = None
            self._hf_client: Optional[InferenceClient] = None
        self.max_iterations = max_iterations

    def _get_hf_client(self) -> InferenceClient:
        """Return the Hugging Face InferenceClient, creating it on first use.

        Raises:
            RuntimeError: if huggingface_hub is not installed or no HF token
                is available.
        """
        if InferenceClient is None:
            raise RuntimeError("huggingface_hub is not installed.")
        if self._hf_client is None:
            if not self.hf_token:
                raise RuntimeError(
                    "HF_TOKEN or HUGGINGFACEHUB_API_TOKEN is required when using "
                    "Hugging Face Inference (no GROQ_API_KEY / OPENAI_API_KEY set)."
                )
            kw = inference_client_kwargs(self.hf_token)
            self._hf_client = InferenceClient(**kw)
        return self._hf_client

    def _chat_round(
        self,
        messages: list[dict[str, Any]],
        *,
        shrink_pass: int = 0,
    ) -> Any:
        """Run one chat completion after trimming messages to backend limits."""
        _truncate_tool_messages(messages, self.backend, shrink_pass=shrink_pass)
        _enforce_context_budget(messages, self.backend)
        if self.backend in ("groq", "openai"):
            assert self._oa_client is not None
            # Per-backend completion token caps; Groq free tier is tighter.
            mt = (
                int(os.environ.get("GAIA_GROQ_MAX_TOKENS", "384"))
                if self.backend == "groq"
                else int(os.environ.get("GAIA_OPENAI_MAX_TOKENS", "768"))
            )
            return chat_complete_openai(
                self._oa_client,
                model=self.text_model,
                messages=messages,
                tools=TOOL_DEFINITIONS,
                max_tokens=mt,
                temperature=0.15,
            )
        client = self._get_hf_client()
        return client.chat_completion(
            messages=messages,
            model=self.text_model,
            tools=TOOL_DEFINITIONS,
            tool_choice="auto",
            max_tokens=1024,
            temperature=0.15,
        )

    def __call__(
        self,
        question: str,
        attachment_path: Optional[str] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """Answer one GAIA question and return the normalized answer string."""
        # Cheap deterministic solvers first — no LLM call needed when one hits.
        det = deterministic_attempt(question, attachment_path, task_id=task_id)
        if det is not None:
            return normalize_answer(det)
        if self.backend == "hf" and not self.hf_token:
            return normalize_answer("", context_question=question)
        user_text = _build_user_payload(question, attachment_path, task_id)
        user_text += _maybe_inline_audio_transcript(
            attachment_path, self.hf_token, backend=self.backend
        )
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_text},
        ]
        last_text = ""
        # Extra delays so Groq free-tier TPM / oversized-request errors can retry after shrink.
        retry_delays = (2.0, 4.0, 8.0, 14.0, 22.0)
        for _ in range(self.max_iterations):
            completion = None
            shrink_pass = 0
            # Inner retry ladder: each retry shrinks tool payloads harder.
            for attempt in range(len(retry_delays) + 1):
                try:
                    completion = self._chat_round(messages, shrink_pass=shrink_pass)
                    break
                except Exception as e:
                    es = str(e)
                    if "402" in es or "Payment Required" in es or "depleted" in es.lower():
                        # Do not submit error prose as an answer (exact-match grading).
                        return normalize_answer("", context_question=question)
                    if attempt < len(retry_delays) and _maybe_retryable_llm_error(e):
                        shrink_pass = attempt + 1
                        time.sleep(retry_delays[attempt])
                        continue
                    # Retries exhausted (or non-retryable): bail out quietly for
                    # billing / rate-limit errors, otherwise surface the error.
                    if "402" in str(e) or "payment required" in str(e).lower():
                        return normalize_answer("", context_question=question)
                    if _maybe_retryable_llm_error(e):
                        return normalize_answer("", context_question=question)
                    return normalize_answer(
                        f"Inference error: {e}", context_question=question
                    )
            msg = completion.choices[0].message
            last_text = (msg.content or "").strip()
            tool_calls = getattr(msg, "tool_calls", None)
            if tool_calls:
                cap = _tool_char_cap(self.backend, shrink_pass=0)
                # Echo the assistant's tool-call turn back into the transcript
                # in plain-dict form so any backend can replay it.
                messages.append(
                    {
                        "role": "assistant",
                        "content": msg.content if msg.content else None,
                        "tool_calls": [
                            {
                                "id": tc.id,
                                "type": "function",
                                "function": {
                                    "name": tc.function.name,
                                    "arguments": tc.function.arguments or "{}",
                                },
                            }
                            for tc in tool_calls
                        ],
                    }
                )
                for tc in tool_calls:
                    name = tc.function.name
                    args = tc.function.arguments or "{}"
                    result = dispatch_tool(name, args, hf_token=self.hf_token)
                    if isinstance(result, str) and len(result) > cap:
                        result = result[:cap] + "\n[truncated]"
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tc.id,
                            "content": result,
                        }
                    )
                continue
            if last_text:
                break
            fr = getattr(completion.choices[0], "finish_reason", None)
            if fr == "length":
                last_text = "Error: model hit max length without an answer."
                break
        return normalize_answer(last_text or "", context_question=question)
def _build_user_payload(
question: str,
attachment_path: Optional[str],
task_id: Optional[str],
) -> str:
parts = []
if task_id:
parts.append(f"task_id: {task_id}")
parts.append(f"Question:\n{question.strip()}")
if attachment_path:
p = Path(attachment_path)
parts.append(
f"\nAttachment path (pass this exact string to tools): {attachment_path}"
)
if p.is_file():
parts.append(f"Attachment exists on disk: yes ({p.name})")
else:
parts.append("Attachment exists on disk: NO — report that you cannot read it.")
else:
parts.append("\nNo attachment.")
return "\n".join(parts)
def _maybe_inline_audio_transcript(
attachment_path: Optional[str],
hf_token: Optional[str],
*,
backend: str = "hf",
) -> str:
if not attachment_path:
return ""
p = Path(attachment_path)
if not p.is_file():
return ""
ext = p.suffix.lower()
if ext not in (".mp3", ".wav", ".m4a", ".ogg", ".flac", ".webm"):
return ""
tx = transcribe_audio(str(p), hf_token=hf_token)
if not tx or tx.lower().startswith(("error", "asr error")):
return f"\n\n[Automatic transcription failed: {tx[:500]}]\n"
cap = int(os.environ.get("GAIA_AUTO_TRANSCRIPT_CHARS", "8000"))
if backend == "groq":
cap = min(
cap,
int(os.environ.get("GAIA_GROQ_AUTO_TRANSCRIPT_CHARS", "3600")),
)
return f"\n\n[Audio transcript — use for your answer]\n{tx[:cap]}\n"
|