hawkdev committed on
Commit
2bf50d9
·
1 Parent(s): 9035818

Groq: smaller requests and retries for free-tier 413/TPM

Browse files

- Lower default tool/context caps; count tool_calls in context budget
- Drop oldest tool rounds when over Groq char budget
- Shorter inlined audio for groq (GAIA_GROQ_AUTO_TRANSCRIPT_CHARS)
- GAIA_GROQ_MAX_TOKENS default 384; more retry sleeps with stronger shrink
- Return empty after exhausted rate/size errors; strip any "Inference error:" answers

Made-with: Cursor

Files changed (3) hide show
  1. README.md +1 -1
  2. agent.py +52 -11
  3. answer_normalize.py +4 -1
README.md CHANGED
@@ -37,7 +37,7 @@ This folder is a **drop-in replacement** for the course Space
37
  - `GAIA_VISION_MODEL` — HF-only default `meta-llama/Llama-3.2-11B-Vision-Instruct`
38
  - `GAIA_API_URL` — default `https://agents-course-unit4-scoring.hf.space`
39
  - `GAIA_USE_CACHE` — `1` (default) or `0` to disable `gaia_answers_cache.json` (set **`0`** once after changing the agent so old wrong answers are not resubmitted).
40
- - **Groq free-tier TPM (413 “request too large”)**: the agent truncates tool outputs and total context. Tune with `GAIA_GROQ_MAX_TOOL_CHARS` (default `3200`), `GAIA_GROQ_CONTEXT_CHARS` (default `26000`), and `GAIA_AUTO_TRANSCRIPT_CHARS` (default `12000` for inlined MP3 transcripts).
41
 
42
  Keep the Space **public** so `agent_code` (`…/tree/main`) verifies for the leaderboard.
43
 
 
37
  - `GAIA_VISION_MODEL` — HF-only default `meta-llama/Llama-3.2-11B-Vision-Instruct`
38
  - `GAIA_API_URL` — default `https://agents-course-unit4-scoring.hf.space`
39
  - `GAIA_USE_CACHE` — `1` (default) or `0` to disable `gaia_answers_cache.json` (set **`0`** once after changing the agent so old wrong answers are not resubmitted).
40
+ - **Groq free-tier TPM / 413 “request too large”**: defaults are conservative (`GAIA_GROQ_MAX_TOOL_CHARS` `1400`, `GAIA_GROQ_CONTEXT_CHARS` `12000`, `GAIA_GROQ_MAX_TOKENS` `384`, `GAIA_AUTO_TRANSCRIPT_CHARS` `8000`, `GAIA_GROQ_AUTO_TRANSCRIPT_CHARS` `3600` for inlined MP3 text). Increase only if you have higher Groq limits. After changing the agent, set `GAIA_USE_CACHE=0` once so cached **Inference error** strings are not resubmitted.
41
 
42
  Keep the Space **public** so `agent_code` (`…/tree/main`) verifies for the leaderboard.
43
 
agent.py CHANGED
@@ -45,18 +45,19 @@ Hard rules:
45
 
46
  def _tool_char_cap(backend: str, *, shrink_pass: int = 0) -> int:
47
  if backend == "groq":
48
- base = int(os.environ.get("GAIA_GROQ_MAX_TOOL_CHARS", "3200"))
 
49
  elif backend == "openai":
50
  base = int(os.environ.get("GAIA_OPENAI_MAX_TOOL_CHARS", "12000"))
51
  else:
52
  base = int(os.environ.get("GAIA_MAX_TOOL_CHARS", "24000"))
53
  if shrink_pass > 0:
54
- base = max(600, base // (2**shrink_pass))
55
  return base
56
 
57
 
58
  def _groq_context_budget() -> int:
59
- return int(os.environ.get("GAIA_GROQ_CONTEXT_CHARS", "26000"))
60
 
61
 
62
  def _maybe_retryable_llm_error(exc: Exception) -> bool:
@@ -86,21 +87,44 @@ def _truncate_tool_messages(
86
  m["content"] = c[:cap] + "\n[truncated]"
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def _enforce_context_budget(messages: list[dict[str, Any]], backend: str) -> None:
90
  if backend != "groq":
91
  return
92
  budget = _groq_context_budget()
93
- for _ in range(24):
94
- total = sum(len(str(m.get("content") or "")) for m in messages)
95
  if total <= budget:
96
  return
 
 
97
  trimmed = False
98
  for m in messages[2:]:
99
  if m.get("role") != "tool":
100
  continue
101
  c = m.get("content")
102
- if isinstance(c, str) and len(c) > 800:
103
- m["content"] = c[: max(600, len(c) * 2 // 3)] + "\n[truncated]"
104
  trimmed = True
105
  break
106
  if not trimmed:
@@ -159,12 +183,17 @@ class GaiaAgent:
159
  _enforce_context_budget(messages, self.backend)
160
  if self.backend in ("groq", "openai"):
161
  assert self._oa_client is not None
 
 
 
 
 
162
  return chat_complete_openai(
163
  self._oa_client,
164
  model=self.text_model,
165
  messages=messages,
166
  tools=TOOL_DEFINITIONS,
167
- max_tokens=768,
168
  temperature=0.15,
169
  )
170
  client = self._get_hf_client()
@@ -191,7 +220,9 @@ class GaiaAgent:
191
  return normalize_answer("", context_question=question)
192
 
193
  user_text = _build_user_payload(question, attachment_path, task_id)
194
- user_text += _maybe_inline_audio_transcript(attachment_path, self.hf_token)
 
 
195
 
196
  messages: list[dict[str, Any]] = [
197
  {"role": "system", "content": SYSTEM_PROMPT},
@@ -199,7 +230,8 @@ class GaiaAgent:
199
  ]
200
 
201
  last_text = ""
202
- retry_delays = (2.0, 6.0, 14.0)
 
203
 
204
  for _ in range(self.max_iterations):
205
  completion = None
@@ -219,6 +251,8 @@ class GaiaAgent:
219
  continue
220
  if "402" in str(e) or "payment required" in str(e).lower():
221
  return normalize_answer("", context_question=question)
 
 
222
  return normalize_answer(
223
  f"Inference error: {e}", context_question=question
224
  )
@@ -298,6 +332,8 @@ def _build_user_payload(
298
  def _maybe_inline_audio_transcript(
299
  attachment_path: Optional[str],
300
  hf_token: Optional[str],
 
 
301
  ) -> str:
302
  if not attachment_path:
303
  return ""
@@ -310,5 +346,10 @@ def _maybe_inline_audio_transcript(
310
  tx = transcribe_audio(str(p), hf_token=hf_token)
311
  if not tx or tx.lower().startswith(("error", "asr error")):
312
  return f"\n\n[Automatic transcription failed: {tx[:500]}]\n"
313
- cap = int(os.environ.get("GAIA_AUTO_TRANSCRIPT_CHARS", "12000"))
 
 
 
 
 
314
  return f"\n\n[Audio transcript — use for your answer]\n{tx[:cap]}\n"
 
45
 
46
  def _tool_char_cap(backend: str, *, shrink_pass: int = 0) -> int:
47
  if backend == "groq":
48
+ # Free-tier Groq often rejects ~6k TPM per request; keep tool payloads small.
49
+ base = int(os.environ.get("GAIA_GROQ_MAX_TOOL_CHARS", "1400"))
50
  elif backend == "openai":
51
  base = int(os.environ.get("GAIA_OPENAI_MAX_TOOL_CHARS", "12000"))
52
  else:
53
  base = int(os.environ.get("GAIA_MAX_TOOL_CHARS", "24000"))
54
  if shrink_pass > 0:
55
+ base = max(280, base // (2**shrink_pass))
56
  return base
57
 
58
 
59
  def _groq_context_budget() -> int:
60
+ return int(os.environ.get("GAIA_GROQ_CONTEXT_CHARS", "12000"))
61
 
62
 
63
  def _maybe_retryable_llm_error(exc: Exception) -> bool:
 
87
  m["content"] = c[:cap] + "\n[truncated]"
88
 
89
 
90
+ def _groq_message_chars(m: dict[str, Any]) -> int:
91
+ n = len(str(m.get("content") or ""))
92
+ tc = m.get("tool_calls")
93
+ if tc:
94
+ n += len(str(tc))
95
+ return n
96
+
97
+
98
+ def _drop_oldest_tool_round(messages: list[dict[str, Any]]) -> bool:
99
+ """Remove the earliest assistant+tool_calls block and its tool replies."""
100
+ i = 2
101
+ while i < len(messages):
102
+ if messages[i].get("role") == "assistant" and messages[i].get("tool_calls"):
103
+ del messages[i]
104
+ while i < len(messages) and messages[i].get("role") == "tool":
105
+ del messages[i]
106
+ return True
107
+ i += 1
108
+ return False
109
+
110
+
111
  def _enforce_context_budget(messages: list[dict[str, Any]], backend: str) -> None:
112
  if backend != "groq":
113
  return
114
  budget = _groq_context_budget()
115
+ for _ in range(40):
116
+ total = sum(_groq_message_chars(m) for m in messages)
117
  if total <= budget:
118
  return
119
+ if _drop_oldest_tool_round(messages):
120
+ continue
121
  trimmed = False
122
  for m in messages[2:]:
123
  if m.get("role") != "tool":
124
  continue
125
  c = m.get("content")
126
+ if isinstance(c, str) and len(c) > 400:
127
+ m["content"] = c[: max(400, len(c) * 2 // 3)] + "\n[truncated]"
128
  trimmed = True
129
  break
130
  if not trimmed:
 
183
  _enforce_context_budget(messages, self.backend)
184
  if self.backend in ("groq", "openai"):
185
  assert self._oa_client is not None
186
+ mt = (
187
+ int(os.environ.get("GAIA_GROQ_MAX_TOKENS", "384"))
188
+ if self.backend == "groq"
189
+ else int(os.environ.get("GAIA_OPENAI_MAX_TOKENS", "768"))
190
+ )
191
  return chat_complete_openai(
192
  self._oa_client,
193
  model=self.text_model,
194
  messages=messages,
195
  tools=TOOL_DEFINITIONS,
196
+ max_tokens=mt,
197
  temperature=0.15,
198
  )
199
  client = self._get_hf_client()
 
220
  return normalize_answer("", context_question=question)
221
 
222
  user_text = _build_user_payload(question, attachment_path, task_id)
223
+ user_text += _maybe_inline_audio_transcript(
224
+ attachment_path, self.hf_token, backend=self.backend
225
+ )
226
 
227
  messages: list[dict[str, Any]] = [
228
  {"role": "system", "content": SYSTEM_PROMPT},
 
230
  ]
231
 
232
  last_text = ""
233
+ # Extra delays so Groq free-tier TPM / oversized-request errors can retry after shrink.
234
+ retry_delays = (2.0, 4.0, 8.0, 14.0, 22.0)
235
 
236
  for _ in range(self.max_iterations):
237
  completion = None
 
251
  continue
252
  if "402" in str(e) or "payment required" in str(e).lower():
253
  return normalize_answer("", context_question=question)
254
+ if _maybe_retryable_llm_error(e):
255
+ return normalize_answer("", context_question=question)
256
  return normalize_answer(
257
  f"Inference error: {e}", context_question=question
258
  )
 
332
  def _maybe_inline_audio_transcript(
333
  attachment_path: Optional[str],
334
  hf_token: Optional[str],
335
+ *,
336
+ backend: str = "hf",
337
  ) -> str:
338
  if not attachment_path:
339
  return ""
 
346
  tx = transcribe_audio(str(p), hf_token=hf_token)
347
  if not tx or tx.lower().startswith(("error", "asr error")):
348
  return f"\n\n[Automatic transcription failed: {tx[:500]}]\n"
349
+ cap = int(os.environ.get("GAIA_AUTO_TRANSCRIPT_CHARS", "8000"))
350
+ if backend == "groq":
351
+ cap = min(
352
+ cap,
353
+ int(os.environ.get("GAIA_GROQ_AUTO_TRANSCRIPT_CHARS", "3600")),
354
+ )
355
  return f"\n\n[Audio transcript — use for your answer]\n{tx[:cap]}\n"
answer_normalize.py CHANGED
@@ -73,10 +73,13 @@ def normalize_answer(
73
  if not text:
74
  return ""
75
  low = text.lower()
 
 
76
  if (
77
  "hugging face inference credits exhausted" in low
78
  or "inference credits exhausted" in low
79
- or ("inference error:" in low and "402" in text)
 
80
  ):
81
  return ""
82
  if "wikipedia_search:" in low and low.count("wikipedia_search:") >= 4:
 
73
  if not text:
74
  return ""
75
  low = text.lower()
76
+ if low.startswith("inference error:"):
77
+ return ""
78
  if (
79
  "hugging face inference credits exhausted" in low
80
  or "inference credits exhausted" in low
81
+ or "error code: 413" in low
82
+ or ("rate_limit_exceeded" in low and "413" in text)
83
  ):
84
  return ""
85
  if "wikipedia_search:" in low and low.count("wikipedia_search:") >= 4: