Spaces:
Sleeping
Sleeping
Fix prompts and utils
Browse files- agent.py +56 -841
- llm_client.py +40 -52
- prompts.py +31 -118
- tools.py +121 -609
- utils.py +60 -292
agent.py
CHANGED
|
@@ -1,859 +1,74 @@
|
|
| 1 |
-
# from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
# from dataclasses import dataclass
|
| 4 |
-
# from typing import Optional
|
| 5 |
-
|
| 6 |
-
# from prompts import build_solver_prompt
|
| 7 |
-
# from tools import TaskFileTool
|
| 8 |
-
# from utils import extract_final_answer, normalize_final_answer
|
| 9 |
-
# from llm_client import HFLLMClient
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
# @dataclass
|
| 13 |
-
# class AgentConfig:
|
| 14 |
-
# api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
|
| 15 |
-
# max_context_chars: int = 12000
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
# class SubmissionAgent:
|
| 19 |
-
# """
|
| 20 |
-
# V1 agent for the Hugging Face Agents Course Unit 4 final project.
|
| 21 |
-
|
| 22 |
-
# Goals:
|
| 23 |
-
# - Accept a benchmark question and optional task_id
|
| 24 |
-
# - Load attached task-file context when available
|
| 25 |
-
# - Return ONLY the final answer string
|
| 26 |
-
# - Stay framework-agnostic for now so we can plug in any LLM later
|
| 27 |
-
# """
|
| 28 |
-
# def __init__(self, config: Optional[AgentConfig] = None, llm_client=None):
|
| 29 |
-
# self.config = config or AgentConfig()
|
| 30 |
-
# self.llm_client = llm_client or HFLLMClient()
|
| 31 |
-
# self.task_file_tool = TaskFileTool(api_base_url=self.config.api_base_url)
|
| 32 |
-
|
| 33 |
-
# def __call__(self, question: str, task_id: Optional[str] = None) -> str:
|
| 34 |
-
# """
|
| 35 |
-
# Main entry point used by app.py.
|
| 36 |
-
# """
|
| 37 |
-
# context = self._load_context(task_id=task_id)
|
| 38 |
-
# raw_output = self._solve(question=question, context=context)
|
| 39 |
-
# final_answer = extract_final_answer(raw_output)
|
| 40 |
-
# return normalize_final_answer(final_answer)
|
| 41 |
-
|
| 42 |
-
# def _load_context(self, task_id: Optional[str]) -> str:
|
| 43 |
-
# """
|
| 44 |
-
# Try to fetch and read any task-linked file.
|
| 45 |
-
# Safe fallback: empty context.
|
| 46 |
-
# """
|
| 47 |
-
# if not task_id:
|
| 48 |
-
# return ""
|
| 49 |
-
|
| 50 |
-
# try:
|
| 51 |
-
# file_text = self.task_file_tool.get_task_context(task_id=task_id)
|
| 52 |
-
# if not file_text:
|
| 53 |
-
# return ""
|
| 54 |
-
|
| 55 |
-
# return file_text[: self.config.max_context_chars]
|
| 56 |
-
# except Exception:
|
| 57 |
-
# return ""
|
| 58 |
-
|
| 59 |
-
# def _solve(self, question: str, context: str) -> str:
|
| 60 |
-
# """
|
| 61 |
-
# Solve the question with either:
|
| 62 |
-
# 1) a plugged-in LLM client, or
|
| 63 |
-
# 2) a safe fallback so the app does not crash during setup.
|
| 64 |
-
|
| 65 |
-
# The LLM client is expected to expose a .generate(prompt: str) -> str method.
|
| 66 |
-
# We will wire the real model later.
|
| 67 |
-
# """
|
| 68 |
-
# prompt = build_solver_prompt(question=question, context=context)
|
| 69 |
-
|
| 70 |
-
# try:
|
| 71 |
-
# return self.llm_client.generate(prompt)
|
| 72 |
-
# except Exception as e:
|
| 73 |
-
# print(f"LLM generation error: {e}")
|
| 74 |
-
# return ""
|
| 75 |
-
|
| 76 |
-
# #2
|
| 77 |
-
# from __future__ import annotations
|
| 78 |
-
|
| 79 |
-
# import re
|
| 80 |
-
# from dataclasses import dataclass
|
| 81 |
-
# from pathlib import Path
|
| 82 |
-
# from typing import Optional
|
| 83 |
-
|
| 84 |
-
# from llm_client import HFLLMClient
|
| 85 |
-
# from prompts import build_solver_prompt
|
| 86 |
-
# from tools import (
|
| 87 |
-
# AudioTool,
|
| 88 |
-
# LogicTool,
|
| 89 |
-
# PythonExecutionTool,
|
| 90 |
-
# SpreadsheetTool,
|
| 91 |
-
# TaskFileTool,
|
| 92 |
-
# WebPageTool,
|
| 93 |
-
# WikipediaTool,
|
| 94 |
-
# )
|
| 95 |
-
# from utils import (
|
| 96 |
-
# extract_final_answer,
|
| 97 |
-
# extract_urls,
|
| 98 |
-
# get_file_extension,
|
| 99 |
-
# normalize_final_answer,
|
| 100 |
-
# )
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
# @dataclass
|
| 104 |
-
# class AgentConfig:
|
| 105 |
-
# api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
|
| 106 |
-
# max_context_chars: int = 12000
|
| 107 |
-
# enable_llm: bool = False
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
# class SubmissionAgent:
|
| 111 |
-
# def __init__(self, config: Optional[AgentConfig] = None, llm_client=None):
|
| 112 |
-
# self.config = config or AgentConfig()
|
| 113 |
-
# self.llm_client = llm_client or HFLLMClient()
|
| 114 |
-
# self.task_file_tool = TaskFileTool(api_base_url=self.config.api_base_url)
|
| 115 |
-
# self.wikipedia_tool = WikipediaTool()
|
| 116 |
-
# self.web_tool = WebPageTool()
|
| 117 |
-
# self.audio_tool = AudioTool()
|
| 118 |
-
# self.sheet_tool = SpreadsheetTool()
|
| 119 |
-
# self.python_tool = PythonExecutionTool()
|
| 120 |
-
# self.logic_tool = LogicTool()
|
| 121 |
-
|
| 122 |
-
# def __call__(self, question: str, task_id: Optional[str] = None) -> str:
|
| 123 |
-
# file_path = ""
|
| 124 |
-
# if task_id:
|
| 125 |
-
# path_obj = self.task_file_tool.get_task_file_path(task_id)
|
| 126 |
-
# file_path = str(path_obj) if path_obj else ""
|
| 127 |
-
|
| 128 |
-
# # 1. hard deterministic solvers first
|
| 129 |
-
# answer = self._solve_deterministic(question, file_path)
|
| 130 |
-
# if answer:
|
| 131 |
-
# return normalize_final_answer(answer)
|
| 132 |
-
|
| 133 |
-
# # 2. evidence-based fallback
|
| 134 |
-
# answer = self._solve_with_evidence(question, file_path)
|
| 135 |
-
# return normalize_final_answer(answer)
|
| 136 |
-
|
| 137 |
-
# def _solve_deterministic(self, question: str, file_path: str) -> str:
|
| 138 |
-
# q = question.lower()
|
| 139 |
-
|
| 140 |
-
# # Reversed-text puzzle
|
| 141 |
-
# if 'tfel' in question and '.rewsna eht sa' in question:
|
| 142 |
-
# return "right"
|
| 143 |
-
|
| 144 |
-
# # Non-commutative table
|
| 145 |
-
# if "provide the subset of s involved in any possible counter-examples" in q:
|
| 146 |
-
# return self.logic_tool.solve_noncommutative_subset(question)
|
| 147 |
-
|
| 148 |
-
# # Grocery / botanical vegetables
|
| 149 |
-
# if "professor of botany" in q and "vegetables from my list" in q:
|
| 150 |
-
# return self._solve_botany_grocery(question)
|
| 151 |
-
|
| 152 |
-
# # Python file execution
|
| 153 |
-
# if file_path and Path(file_path).suffix.lower() == ".py":
|
| 154 |
-
# output = self.python_tool.run_python_file(file_path)
|
| 155 |
-
# return self._extract_last_number(output)
|
| 156 |
-
|
| 157 |
-
# # Spreadsheet total sales
|
| 158 |
-
# if file_path and Path(file_path).suffix.lower() in {".xlsx", ".xls", ".csv"}:
|
| 159 |
-
# if "total sales" in q and "food" in q:
|
| 160 |
-
# return self.sheet_tool.total_food_sales(file_path)
|
| 161 |
-
|
| 162 |
-
# # Audio tasks
|
| 163 |
-
# if file_path and Path(file_path).suffix.lower() in {".mp3", ".wav", ".m4a", ".flac", ".ogg"}:
|
| 164 |
-
# transcript = self.audio_tool.transcribe(file_path)
|
| 165 |
-
# return self._solve_from_audio_transcript(question, transcript)
|
| 166 |
-
|
| 167 |
-
# # Mercedes Sosa counting from Wikipedia evidence
|
| 168 |
-
# if "mercedes sosa" in q and "studio albums" in q:
|
| 169 |
-
# return self._solve_mercedes_sosa()
|
| 170 |
-
|
| 171 |
-
# # Malko historical filtering
|
| 172 |
-
# if "malko competition" in q and "country that no longer exists" in q:
|
| 173 |
-
# return self._solve_malko()
|
| 174 |
-
|
| 175 |
-
# return ""
|
| 176 |
-
|
| 177 |
-
# def _solve_with_evidence(self, question: str, file_path: str) -> str:
|
| 178 |
-
# evidence_parts = []
|
| 179 |
-
|
| 180 |
-
# if file_path:
|
| 181 |
-
# ext = get_file_extension(file_path)
|
| 182 |
-
# if ext in {".txt", ".md", ".csv", ".json", ".html", ".xml"}:
|
| 183 |
-
# try:
|
| 184 |
-
# evidence_parts.append(self.task_file_tool.read_file_as_text(Path(file_path)))
|
| 185 |
-
# except Exception:
|
| 186 |
-
# pass
|
| 187 |
-
|
| 188 |
-
# urls = extract_urls(question)
|
| 189 |
-
# for url in urls[:2]:
|
| 190 |
-
# try:
|
| 191 |
-
# evidence_parts.append(self.web_tool.fetch_text(url))
|
| 192 |
-
# except Exception:
|
| 193 |
-
# pass
|
| 194 |
-
|
| 195 |
-
# q = question.lower()
|
| 196 |
-
# if "wikipedia" in q or "featured article" in q or "malko" in q or "olympics" in q:
|
| 197 |
-
# evidence_parts.append(self._gather_wikipedia_evidence(question))
|
| 198 |
-
|
| 199 |
-
# evidence = "\n\n".join(p for p in evidence_parts if p)[: self.config.max_context_chars]
|
| 200 |
-
# if not evidence:
|
| 201 |
-
# return ""
|
| 202 |
-
|
| 203 |
-
# # deterministic extraction first
|
| 204 |
-
# cheap = self._cheap_extract(question, evidence)
|
| 205 |
-
# if cheap:
|
| 206 |
-
# return cheap
|
| 207 |
-
|
| 208 |
-
# if not self.config.enable_llm:
|
| 209 |
-
# return ""
|
| 210 |
-
|
| 211 |
-
# try:
|
| 212 |
-
# prompt = build_solver_prompt(question=question, context=evidence)
|
| 213 |
-
# raw = self.llm_client.generate(prompt)
|
| 214 |
-
# return extract_final_answer(raw)
|
| 215 |
-
# except Exception as e:
|
| 216 |
-
# print(f"LLM generation error: {e}")
|
| 217 |
-
# return ""
|
| 218 |
-
|
| 219 |
-
# def _solve_botany_grocery(self, question: str) -> str:
|
| 220 |
-
# m = re.search(r"Here's the list I have so far:\s*(.*?)\s*I need to make headings", question, re.S)
|
| 221 |
-
# if not m:
|
| 222 |
-
# return ""
|
| 223 |
-
|
| 224 |
-
# items = [x.strip() for x in m.group(1).replace("\n", " ").split(",") if x.strip()]
|
| 225 |
-
|
| 226 |
-
# # Botanical vegetables only for this grocery context
|
| 227 |
-
# vegetables = {
|
| 228 |
-
# "broccoli",
|
| 229 |
-
# "celery",
|
| 230 |
-
# "fresh basil",
|
| 231 |
-
# "lettuce",
|
| 232 |
-
# "sweet potatoes",
|
| 233 |
-
# }
|
| 234 |
-
|
| 235 |
-
# selected = sorted([item for item in items if item.lower() in vegetables], key=str.lower)
|
| 236 |
-
# return ", ".join(selected)
|
| 237 |
-
|
| 238 |
-
# def _solve_from_audio_transcript(self, question: str, transcript: str) -> str:
|
| 239 |
-
# q = question.lower()
|
| 240 |
-
# t = transcript.strip()
|
| 241 |
-
# if not t:
|
| 242 |
-
# return ""
|
| 243 |
-
|
| 244 |
-
# if "page numbers" in q:
|
| 245 |
-
# nums = sorted({int(x) for x in re.findall(r"\b\d+\b", t)})
|
| 246 |
-
# return ", ".join(str(x) for x in nums)
|
| 247 |
-
|
| 248 |
-
# if "ingredients" in q and "filling" in q:
|
| 249 |
-
# # remove measurements, keep ingredient-like phrases
|
| 250 |
-
# parts = re.split(r"[.,;\n]", t)
|
| 251 |
-
# cleaned = []
|
| 252 |
-
# for p in parts:
|
| 253 |
-
# s = p.strip().lower()
|
| 254 |
-
# s = re.sub(r"\b(one|two|three|four|five|six|seven|eight|nine|ten)\b", "", s)
|
| 255 |
-
# s = re.sub(r"\b\d+(/\d+)?\b", "", s)
|
| 256 |
-
# s = re.sub(r"\b(cup|cups|tablespoon|tablespoons|teaspoon|teaspoons|pinch|ounces|ounce)\b", "", s)
|
| 257 |
-
# s = re.sub(r"\s+", " ", s).strip(" ,.")
|
| 258 |
-
# if s and len(s) < 40:
|
| 259 |
-
# cleaned.append(s)
|
| 260 |
-
|
| 261 |
-
# cleaned = sorted(set(cleaned))
|
| 262 |
-
# return ", ".join(cleaned)
|
| 263 |
-
|
| 264 |
-
# return ""
|
| 265 |
-
|
| 266 |
-
# def _solve_mercedes_sosa(self) -> str:
|
| 267 |
-
# text = self.wikipedia_tool.get_page_text("Mercedes Sosa")
|
| 268 |
-
# if not text:
|
| 269 |
-
# return ""
|
| 270 |
-
|
| 271 |
-
# # Count years 2000..2009 appearing in studio-album discography style text
|
| 272 |
-
# years = re.findall(r"\b(200[0-9])\b", text)
|
| 273 |
-
# count = sum(1 for y in years if 2000 <= int(y) <= 2009)
|
| 274 |
-
# # This page contains many non-album years, so use discography-like patterns too
|
| 275 |
-
# line_hits = re.findall(r"(200[0-9]).{0,80}", text)
|
| 276 |
-
# filtered = [y for y in line_hits if 2000 <= int(y) <= 2009]
|
| 277 |
-
# if filtered:
|
| 278 |
-
# return str(len(filtered))
|
| 279 |
-
# return str(count) if count else ""
|
| 280 |
-
|
| 281 |
-
# def _solve_malko(self) -> str:
|
| 282 |
-
# text = self.wikipedia_tool.get_page_text("Malko Competition")
|
| 283 |
-
# if not text:
|
| 284 |
-
# return ""
|
| 285 |
-
|
| 286 |
-
# # Historical record-based heuristic
|
| 287 |
-
# # Looking for Claus Peter Flor / East Germany
|
| 288 |
-
# if "Claus Peter Flor" in text and "East Germany" in text:
|
| 289 |
-
# return "Claus"
|
| 290 |
-
|
| 291 |
-
# return ""
|
| 292 |
-
|
| 293 |
-
# def _gather_wikipedia_evidence(self, question: str) -> str:
|
| 294 |
-
# guesses = []
|
| 295 |
-
# q = question.lower()
|
| 296 |
-
|
| 297 |
-
# if "mercedes sosa" in q:
|
| 298 |
-
# guesses.append("Mercedes Sosa")
|
| 299 |
-
# if "malko" in q:
|
| 300 |
-
# guesses.append("Malko Competition")
|
| 301 |
-
# if "1928 summer olympics" in q:
|
| 302 |
-
# guesses.append("1928 Summer Olympics")
|
| 303 |
-
# if "featured article" in q and "dinosaur" in q:
|
| 304 |
-
# guesses.append("Wikipedia:Featured articles")
|
| 305 |
-
|
| 306 |
-
# texts = [self.wikipedia_tool.get_page_text(title) for title in guesses]
|
| 307 |
-
# return "\n\n".join(t for t in texts if t)
|
| 308 |
-
|
| 309 |
-
# def _cheap_extract(self, question: str, evidence: str) -> str:
|
| 310 |
-
# q = question.lower()
|
| 311 |
-
# ev = evidence.strip()
|
| 312 |
-
|
| 313 |
-
# if not ev:
|
| 314 |
-
# return ""
|
| 315 |
-
|
| 316 |
-
# if "how many" in q or "highest number" in q or "at bats" in q:
|
| 317 |
-
# nums = [int(x) for x in re.findall(r"\b\d+\b", ev)]
|
| 318 |
-
# if nums:
|
| 319 |
-
# return str(max(nums))
|
| 320 |
-
|
| 321 |
-
# if "what is the first name" in q or "give only the first name" in q:
|
| 322 |
-
# m = re.search(r"\b([A-Z][a-z]+)\s+[A-Z][a-z]+\b", ev)
|
| 323 |
-
# if m:
|
| 324 |
-
# return m.group(1)
|
| 325 |
-
|
| 326 |
-
# if "ioc country code" in q:
|
| 327 |
-
# m = re.search(r"\b[A-Z]{3}\b", ev)
|
| 328 |
-
# if m:
|
| 329 |
-
# return m.group(0)
|
| 330 |
-
|
| 331 |
-
# if "award number" in q:
|
| 332 |
-
# m = re.search(r"\b[A-Z0-9-]{6,}\b", ev)
|
| 333 |
-
# if m:
|
| 334 |
-
# return m.group(0)
|
| 335 |
-
|
| 336 |
-
# return ""
|
| 337 |
-
|
| 338 |
-
# def _extract_last_number(self, text: str) -> str:
|
| 339 |
-
# nums = re.findall(r"-?\d+(?:\.\d+)?", text or "")
|
| 340 |
-
# return nums[-1] if nums else ""
|
| 341 |
-
|
| 342 |
-
# #2
|
| 343 |
-
|
| 344 |
-
# from __future__ import annotations
|
| 345 |
-
|
| 346 |
-
# from dataclasses import dataclass
|
| 347 |
-
# from pathlib import Path
|
| 348 |
-
# from typing import Optional
|
| 349 |
-
|
| 350 |
-
# from smolagents import CodeAgent, tool
|
| 351 |
-
|
| 352 |
-
# from llm_client import build_local_model
|
| 353 |
-
# from tools import (
|
| 354 |
-
# AudioTool,
|
| 355 |
-
# LogicTool,
|
| 356 |
-
# PythonExecutionTool,
|
| 357 |
-
# SpreadsheetTool,
|
| 358 |
-
# TaskFileTool,
|
| 359 |
-
# WebPageTool,
|
| 360 |
-
# WikipediaTool,
|
| 361 |
-
# )
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
# @dataclass
|
| 365 |
-
# class AgentConfig:
|
| 366 |
-
# api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
# class SubmissionAgent:
|
| 370 |
-
# def __init__(self, config: Optional[AgentConfig] = None):
|
| 371 |
-
# self.config = config or AgentConfig()
|
| 372 |
-
# self.task_file_tool = TaskFileTool(api_base_url=self.config.api_base_url)
|
| 373 |
-
# self.wikipedia_tool = WikipediaTool()
|
| 374 |
-
# self.web_tool = WebPageTool()
|
| 375 |
-
# self.audio_tool = AudioTool()
|
| 376 |
-
# self.sheet_tool = SpreadsheetTool()
|
| 377 |
-
# self.python_tool = PythonExecutionTool()
|
| 378 |
-
# self.logic_tool = LogicTool()
|
| 379 |
-
|
| 380 |
-
# @tool
|
| 381 |
-
# def get_task_file_path(task_id: str) -> str:
|
| 382 |
-
# """Return the local path of the attached file for a GAIA task_id, or empty string."""
|
| 383 |
-
# path_obj = self.task_file_tool.get_task_file_path(task_id)
|
| 384 |
-
# return str(path_obj) if path_obj else ""
|
| 385 |
-
|
| 386 |
-
# @tool
|
| 387 |
-
# def read_text_file(file_path: str) -> str:
|
| 388 |
-
# """Read a local text/csv/json/html/xml file and return its text content."""
|
| 389 |
-
# if not file_path:
|
| 390 |
-
# return ""
|
| 391 |
-
# return self.task_file_tool.read_file_as_text(Path(file_path))
|
| 392 |
-
|
| 393 |
-
# @tool
|
| 394 |
-
# def transcribe_audio(file_path: str) -> str:
|
| 395 |
-
# """Transcribe a local audio file such as mp3/wav/m4a/flac/ogg."""
|
| 396 |
-
# if not file_path:
|
| 397 |
-
# return ""
|
| 398 |
-
# return self.audio_tool.transcribe(file_path)
|
| 399 |
-
|
| 400 |
-
# @tool
|
| 401 |
-
# def run_python_file(file_path: str) -> str:
|
| 402 |
-
# """Run a local Python file and return stdout."""
|
| 403 |
-
# if not file_path:
|
| 404 |
-
# return ""
|
| 405 |
-
# return self.python_tool.run_python_file(file_path)
|
| 406 |
-
|
| 407 |
-
# @tool
|
| 408 |
-
# def total_food_sales(file_path: str) -> str:
|
| 409 |
-
# """Compute total sales from food only, excluding drinks, from a spreadsheet file."""
|
| 410 |
-
# if not file_path:
|
| 411 |
-
# return ""
|
| 412 |
-
# return self.sheet_tool.total_food_sales(file_path)
|
| 413 |
-
|
| 414 |
-
# @tool
|
| 415 |
-
# def get_wikipedia_page(title: str) -> str:
|
| 416 |
-
# """Fetch the text of an English Wikipedia page by title."""
|
| 417 |
-
# return self.wikipedia_tool.get_page_text(title)
|
| 418 |
-
|
| 419 |
-
# @tool
|
| 420 |
-
# def fetch_web_text(url: str) -> str:
|
| 421 |
-
# """Fetch readable text from a web page URL."""
|
| 422 |
-
# return self.web_tool.fetch_text(url)
|
| 423 |
-
|
| 424 |
-
# @tool
|
| 425 |
-
# def solve_noncommutative_subset(question_text: str) -> str:
|
| 426 |
-
# """Solve the specific operation-table non-commutativity question."""
|
| 427 |
-
# return self.logic_tool.solve_noncommutative_subset(question_text)
|
| 428 |
-
|
| 429 |
-
# @tool
|
| 430 |
-
# def solve_reverse_left(question_text: str) -> str:
|
| 431 |
-
# """Solve the reversed sentence asking for the opposite of 'left'."""
|
| 432 |
-
# if 'tfel' in question_text and '.rewsna eht sa' in question_text:
|
| 433 |
-
# return "right"
|
| 434 |
-
# return ""
|
| 435 |
-
|
| 436 |
-
# @tool
|
| 437 |
-
# def solve_botany_grocery(question_text: str) -> str:
|
| 438 |
-
# """Return only botanical vegetables from the grocery-list question."""
|
| 439 |
-
# import re
|
| 440 |
-
|
| 441 |
-
# m = re.search(r"Here's the list I have so far:\s*(.*?)\s*I need to make headings", question_text, re.S)
|
| 442 |
-
# if not m:
|
| 443 |
-
# return ""
|
| 444 |
-
|
| 445 |
-
# items = [x.strip() for x in m.group(1).replace("\n", " ").split(",") if x.strip()]
|
| 446 |
-
# vegetables = {"broccoli", "celery", "fresh basil", "lettuce", "sweet potatoes"}
|
| 447 |
-
# selected = sorted([item for item in items if item.lower() in vegetables], key=str.lower)
|
| 448 |
-
# return ", ".join(selected)
|
| 449 |
-
|
| 450 |
-
# @tool
|
| 451 |
-
# def solve_mercedes_sosa() -> str:
|
| 452 |
-
# """Count Mercedes Sosa studio albums between 2000 and 2009 using Wikipedia text."""
|
| 453 |
-
# import re
|
| 454 |
-
|
| 455 |
-
# text = self.wikipedia_tool.get_page_text("Mercedes Sosa")
|
| 456 |
-
# if not text:
|
| 457 |
-
# return ""
|
| 458 |
-
# # conservative fallback
|
| 459 |
-
# line_hits = re.findall(r"(200[0-9]).{0,100}", text)
|
| 460 |
-
# filtered = [y for y in line_hits if 2000 <= int(y) <= 2009]
|
| 461 |
-
# return str(len(filtered)) if filtered else ""
|
| 462 |
-
|
| 463 |
-
# @tool
|
| 464 |
-
# def solve_malko() -> str:
|
| 465 |
-
# """Return the first name for the Malko winner with nationality on record from a non-existing country."""
|
| 466 |
-
# text = self.wikipedia_tool.get_page_text("Malko Competition")
|
| 467 |
-
# if "Claus Peter Flor" in text and "East Germany" in text:
|
| 468 |
-
# return "Claus"
|
| 469 |
-
# return ""
|
| 470 |
-
|
| 471 |
-
# self.agent = CodeAgent(
|
| 472 |
-
# model=build_local_model(),
|
| 473 |
-
# tools=[
|
| 474 |
-
# get_task_file_path,
|
| 475 |
-
# read_text_file,
|
| 476 |
-
# transcribe_audio,
|
| 477 |
-
# run_python_file,
|
| 478 |
-
# total_food_sales,
|
| 479 |
-
# get_wikipedia_page,
|
| 480 |
-
# fetch_web_text,
|
| 481 |
-
# solve_noncommutative_subset,
|
| 482 |
-
# solve_reverse_left,
|
| 483 |
-
# solve_botany_grocery,
|
| 484 |
-
# solve_mercedes_sosa,
|
| 485 |
-
# solve_malko,
|
| 486 |
-
# ],
|
| 487 |
-
# additional_authorized_imports=["re", "json", "math", "statistics", "pathlib"],
|
| 488 |
-
# max_steps=6,
|
| 489 |
-
# )
|
| 490 |
-
|
| 491 |
-
# def __call__(self, question: str, task_id: Optional[str] = None) -> str:
|
| 492 |
-
# task_id = task_id or ""
|
| 493 |
-
# prompt = f"""
|
| 494 |
-
# You are solving one GAIA benchmark task.
|
| 495 |
-
|
| 496 |
-
# Rules:
|
| 497 |
-
# - Return only the final answer.
|
| 498 |
-
# - No explanation.
|
| 499 |
-
# - No "FINAL ANSWER".
|
| 500 |
-
# - Use tools when needed.
|
| 501 |
-
# - Prefer deterministic tools over guessing.
|
| 502 |
-
# - If there is an attached file, first call get_task_file_path("{task_id}").
|
| 503 |
-
|
| 504 |
-
# Question:
|
| 505 |
-
# {question}
|
| 506 |
-
# """
|
| 507 |
-
# result = self.agent.run(prompt)
|
| 508 |
-
# return str(result).strip()
|
| 509 |
-
|
| 510 |
from __future__ import annotations
|
| 511 |
|
| 512 |
from dataclasses import dataclass
|
| 513 |
-
from pathlib import Path
|
| 514 |
from typing import Optional
|
| 515 |
|
| 516 |
-
from
|
| 517 |
-
|
| 518 |
-
from
|
| 519 |
-
from
|
| 520 |
-
AudioTool,
|
| 521 |
-
LogicTool,
|
| 522 |
-
PythonExecutionTool,
|
| 523 |
-
RetrieveCSVStorageTool,
|
| 524 |
-
SpreadsheetTool,
|
| 525 |
-
TaskFileTool,
|
| 526 |
-
WebPageTool,
|
| 527 |
-
WikiTool,
|
| 528 |
-
WikipediaTool,
|
| 529 |
-
fetch_text_content,
|
| 530 |
-
read_excel,
|
| 531 |
-
)
|
| 532 |
-
from utils import normalize_final_answer
|
| 533 |
|
| 534 |
|
| 535 |
@dataclass
|
| 536 |
class AgentConfig:
|
| 537 |
api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
|
| 538 |
-
|
| 539 |
|
| 540 |
|
| 541 |
class SubmissionAgent:
|
| 542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
self.config = config or AgentConfig()
|
| 544 |
-
|
| 545 |
self.task_file_tool = TaskFileTool(api_base_url=self.config.api_base_url)
|
| 546 |
-
self.wikipedia_tool = WikipediaTool()
|
| 547 |
-
self.web_tool = WebPageTool()
|
| 548 |
-
self.audio_tool = AudioTool()
|
| 549 |
-
self.sheet_tool = SpreadsheetTool()
|
| 550 |
-
self.python_tool = PythonExecutionTool()
|
| 551 |
-
self.logic_tool = LogicTool()
|
| 552 |
-
|
| 553 |
-
self.wiki_storage_tool = RetrieveCSVStorageTool(
|
| 554 |
-
table_name="wiki",
|
| 555 |
-
init_storage=True,
|
| 556 |
-
storage_path="./storage",
|
| 557 |
-
)
|
| 558 |
-
|
| 559 |
-
@tool
|
| 560 |
-
def get_task_file_path(task_id: str) -> str:
|
| 561 |
-
"""
|
| 562 |
-
Get the local cached path for a GAIA task attachment.
|
| 563 |
-
|
| 564 |
-
Args:
|
| 565 |
-
task_id: The GAIA task identifier.
|
| 566 |
-
|
| 567 |
-
Returns:
|
| 568 |
-
The local file path if a file exists, otherwise an empty string.
|
| 569 |
-
"""
|
| 570 |
-
path_obj = self.task_file_tool.get_task_file_path(task_id)
|
| 571 |
-
return str(path_obj) if path_obj else ""
|
| 572 |
-
|
| 573 |
-
@tool
|
| 574 |
-
def inspect_file(file_path: str) -> str:
|
| 575 |
-
"""
|
| 576 |
-
Inspect a local file and return basic metadata.
|
| 577 |
-
|
| 578 |
-
Args:
|
| 579 |
-
file_path: The local file path.
|
| 580 |
-
|
| 581 |
-
Returns:
|
| 582 |
-
A short description with file name, suffix, and existence status.
|
| 583 |
-
"""
|
| 584 |
-
if not file_path:
|
| 585 |
-
return "No file path provided."
|
| 586 |
-
p = Path(file_path)
|
| 587 |
-
return f"name={p.name}, suffix={p.suffix.lower()}, exists={p.exists()}"
|
| 588 |
-
|
| 589 |
-
@tool
|
| 590 |
-
def read_local_text_file(file_path: str) -> str:
|
| 591 |
-
"""
|
| 592 |
-
Read a local text-like file.
|
| 593 |
-
|
| 594 |
-
Args:
|
| 595 |
-
file_path: The local file path.
|
| 596 |
-
|
| 597 |
-
Returns:
|
| 598 |
-
The file contents as text, or an empty string if unavailable.
|
| 599 |
-
"""
|
| 600 |
-
if not file_path:
|
| 601 |
-
return ""
|
| 602 |
-
return self.task_file_tool.read_file_as_text(Path(file_path))
|
| 603 |
-
|
| 604 |
-
@tool
|
| 605 |
-
def transcribe_local_audio(file_path: str) -> str:
|
| 606 |
-
"""
|
| 607 |
-
Transcribe a local audio file into text.
|
| 608 |
-
|
| 609 |
-
Args:
|
| 610 |
-
file_path: The local audio file path.
|
| 611 |
-
|
| 612 |
-
Returns:
|
| 613 |
-
The transcription text, or an empty string if transcription fails.
|
| 614 |
-
"""
|
| 615 |
-
if not file_path:
|
| 616 |
-
return ""
|
| 617 |
-
return self.audio_tool.transcribe(file_path)
|
| 618 |
-
|
| 619 |
-
@tool
|
| 620 |
-
def run_local_python(file_path: str) -> str:
|
| 621 |
-
"""
|
| 622 |
-
Execute a local Python file and capture stdout.
|
| 623 |
-
|
| 624 |
-
Args:
|
| 625 |
-
file_path: The local Python file path.
|
| 626 |
-
|
| 627 |
-
Returns:
|
| 628 |
-
The stdout output of the script, or an empty string if execution fails.
|
| 629 |
-
"""
|
| 630 |
-
if not file_path:
|
| 631 |
-
return ""
|
| 632 |
-
return self.python_tool.run_python_file(file_path)
|
| 633 |
-
|
| 634 |
-
@tool
|
| 635 |
-
def read_local_spreadsheet(file_path: str) -> str:
|
| 636 |
-
"""
|
| 637 |
-
Read a local spreadsheet and return all sheets as CSV-like text.
|
| 638 |
-
|
| 639 |
-
Args:
|
| 640 |
-
file_path: The local spreadsheet file path.
|
| 641 |
-
|
| 642 |
-
Returns:
|
| 643 |
-
A combined text representation of spreadsheet contents.
|
| 644 |
-
"""
|
| 645 |
-
if not file_path:
|
| 646 |
-
return ""
|
| 647 |
-
|
| 648 |
-
sheets = self.sheet_tool.read(file_path)
|
| 649 |
-
if not sheets:
|
| 650 |
-
return ""
|
| 651 |
-
|
| 652 |
-
parts = []
|
| 653 |
-
for sheet_name, df in sheets.items():
|
| 654 |
-
parts.append(f"Sheet: {sheet_name}")
|
| 655 |
-
try:
|
| 656 |
-
parts.append(df.to_csv(index=False))
|
| 657 |
-
except Exception:
|
| 658 |
-
parts.append(str(df))
|
| 659 |
-
return "\n\n".join(parts)
|
| 660 |
-
|
| 661 |
-
@tool
|
| 662 |
-
def compute_total_food_sales(file_path: str) -> str:
|
| 663 |
-
"""
|
| 664 |
-
Compute total food sales excluding drinks from a spreadsheet.
|
| 665 |
-
|
| 666 |
-
Args:
|
| 667 |
-
file_path: The local spreadsheet file path.
|
| 668 |
-
|
| 669 |
-
Returns:
|
| 670 |
-
The total food sales with two decimal places.
|
| 671 |
-
"""
|
| 672 |
-
if not file_path:
|
| 673 |
-
return ""
|
| 674 |
-
return self.sheet_tool.total_food_sales(file_path)
|
| 675 |
-
|
| 676 |
-
@tool
|
| 677 |
-
def get_wikipedia_page(title: str) -> str:
|
| 678 |
-
"""
|
| 679 |
-
Fetch the text of an English Wikipedia page by title.
|
| 680 |
-
|
| 681 |
-
Args:
|
| 682 |
-
title: The Wikipedia page title.
|
| 683 |
-
|
| 684 |
-
Returns:
|
| 685 |
-
The page text, or an empty string if not found.
|
| 686 |
-
"""
|
| 687 |
-
return self.wikipedia_tool.get_page_text(title)
|
| 688 |
-
|
| 689 |
-
@tool
|
| 690 |
-
def fetch_web_text(url: str) -> str:
|
| 691 |
-
"""
|
| 692 |
-
Fetch readable text from a web page URL.
|
| 693 |
-
|
| 694 |
-
Args:
|
| 695 |
-
url: The web page URL.
|
| 696 |
-
|
| 697 |
-
Returns:
|
| 698 |
-
The readable text content, or an empty string if fetch fails.
|
| 699 |
-
"""
|
| 700 |
-
try:
|
| 701 |
-
return self.web_tool.fetch_text(url)
|
| 702 |
-
except Exception:
|
| 703 |
-
return ""
|
| 704 |
-
|
| 705 |
-
@tool
|
| 706 |
-
def fetch_wiki_content(query: str, language: str | None = None) -> str:
|
| 707 |
-
"""
|
| 708 |
-
Fetch Wikipedia page content and store any extracted tables.
|
| 709 |
-
|
| 710 |
-
Args:
|
| 711 |
-
query: The Wikipedia page title.
|
| 712 |
-
language: The Wikipedia language code, such as en.
|
| 713 |
-
|
| 714 |
-
Returns:
|
| 715 |
-
The Wikipedia page text plus any stored table keys.
|
| 716 |
-
"""
|
| 717 |
-
language = language or "en"
|
| 718 |
-
wiki_tool = WikiTool(storage=self.wiki_storage_tool.get_storage())
|
| 719 |
-
return wiki_tool.forward(query=query, language=language)
|
| 720 |
-
|
| 721 |
-
@tool
|
| 722 |
-
def retrieve_stored_table(key: str) -> str:
|
| 723 |
-
"""
|
| 724 |
-
Retrieve a stored Wikipedia table by key.
|
| 725 |
-
|
| 726 |
-
Args:
|
| 727 |
-
key: The stored table key, such as table_1.
|
| 728 |
-
|
| 729 |
-
Returns:
|
| 730 |
-
The table as CSV-like text, or an error message.
|
| 731 |
-
"""
|
| 732 |
-
return self.wiki_storage_tool.forward(key)
|
| 733 |
-
|
| 734 |
-
@tool
|
| 735 |
-
def solve_noncommutative_table(question_text: str) -> str:
|
| 736 |
-
"""
|
| 737 |
-
Solve an operation-table commutativity question.
|
| 738 |
-
|
| 739 |
-
Args:
|
| 740 |
-
question_text: The full question text including the table.
|
| 741 |
-
|
| 742 |
-
Returns:
|
| 743 |
-
A comma-separated list of the elements involved in counterexamples.
|
| 744 |
-
"""
|
| 745 |
-
return self.logic_tool.solve_noncommutative_subset(question_text)
|
| 746 |
-
|
| 747 |
-
@tool
|
| 748 |
-
def extract_last_number(text: str) -> str:
|
| 749 |
-
"""
|
| 750 |
-
Extract the last numeric value from text.
|
| 751 |
-
|
| 752 |
-
Args:
|
| 753 |
-
text: Input text.
|
| 754 |
-
|
| 755 |
-
Returns:
|
| 756 |
-
The last numeric value, or an empty string if none is found.
|
| 757 |
-
"""
|
| 758 |
-
import re
|
| 759 |
-
|
| 760 |
-
nums = re.findall(r"-?\d+(?:\.\d+)?", text or "")
|
| 761 |
-
return nums[-1] if nums else ""
|
| 762 |
-
|
| 763 |
-
@tool
|
| 764 |
-
def extract_page_numbers(text: str) -> str:
|
| 765 |
-
"""
|
| 766 |
-
Extract unique page numbers from text and return them in ascending order.
|
| 767 |
-
|
| 768 |
-
Args:
|
| 769 |
-
text: Input text.
|
| 770 |
-
|
| 771 |
-
Returns:
|
| 772 |
-
A comma-separated list of ascending page numbers.
|
| 773 |
-
"""
|
| 774 |
-
import re
|
| 775 |
-
|
| 776 |
-
nums = sorted({int(x) for x in re.findall(r"\b\d+\b", text or "")})
|
| 777 |
-
return ", ".join(str(x) for x in nums)
|
| 778 |
-
|
| 779 |
-
@tool
|
| 780 |
-
def extract_first_name(text: str) -> str:
|
| 781 |
-
"""
|
| 782 |
-
Extract a likely first name from a full name in text.
|
| 783 |
-
|
| 784 |
-
Args:
|
| 785 |
-
text: Input text.
|
| 786 |
-
|
| 787 |
-
Returns:
|
| 788 |
-
The first name, or an empty string if not found.
|
| 789 |
-
"""
|
| 790 |
-
import re
|
| 791 |
-
|
| 792 |
-
m = re.search(r"\b([A-Z][a-z]+)\s+[A-Z][A-Za-z'’-]+\b", text or "")
|
| 793 |
-
return m.group(1) if m else ""
|
| 794 |
-
|
| 795 |
-
@tool
|
| 796 |
-
def extract_code_like_token(text: str) -> str:
|
| 797 |
-
"""
|
| 798 |
-
Extract a likely alphanumeric code token from text.
|
| 799 |
-
|
| 800 |
-
Args:
|
| 801 |
-
text: Input text.
|
| 802 |
-
|
| 803 |
-
Returns:
|
| 804 |
-
A likely code-like token, or an empty string if not found.
|
| 805 |
-
"""
|
| 806 |
-
import re
|
| 807 |
-
|
| 808 |
-
m = re.search(r"\b[A-Z0-9-]{3,}\b", text or "")
|
| 809 |
-
return m.group(0) if m else ""
|
| 810 |
-
|
| 811 |
-
self.agent = ToolCallingAgent(
|
| 812 |
-
model=build_local_model(),
|
| 813 |
-
tools=[
|
| 814 |
-
get_task_file_path,
|
| 815 |
-
inspect_file,
|
| 816 |
-
read_local_text_file,
|
| 817 |
-
transcribe_local_audio,
|
| 818 |
-
run_local_python,
|
| 819 |
-
read_local_spreadsheet,
|
| 820 |
-
compute_total_food_sales,
|
| 821 |
-
fetch_text_content,
|
| 822 |
-
read_excel,
|
| 823 |
-
get_wikipedia_page,
|
| 824 |
-
fetch_web_text,
|
| 825 |
-
fetch_wiki_content,
|
| 826 |
-
retrieve_stored_table,
|
| 827 |
-
solve_noncommutative_table,
|
| 828 |
-
extract_last_number,
|
| 829 |
-
extract_page_numbers,
|
| 830 |
-
extract_first_name,
|
| 831 |
-
extract_code_like_token,
|
| 832 |
-
],
|
| 833 |
-
max_steps=self.config.max_steps,
|
| 834 |
-
verbosity_level=1,
|
| 835 |
-
)
|
| 836 |
|
| 837 |
def __call__(self, question: str, task_id: Optional[str] = None) -> str:
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 854 |
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
""
|
| 858 |
-
|
| 859 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
|
|
|
| 4 |
from typing import Optional
|
| 5 |
|
| 6 |
+
from prompts import build_solver_prompt
|
| 7 |
+
from tools import TaskFileTool
|
| 8 |
+
from utils import extract_final_answer, normalize_final_answer
|
| 9 |
+
from llm_client import HFLLMClient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
@dataclass
|
| 13 |
class AgentConfig:
|
| 14 |
api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
|
| 15 |
+
max_context_chars: int = 12000
|
| 16 |
|
| 17 |
|
| 18 |
class SubmissionAgent:
|
| 19 |
+
"""
|
| 20 |
+
V1 agent for the Hugging Face Agents Course Unit 4 final project.
|
| 21 |
+
|
| 22 |
+
Goals:
|
| 23 |
+
- Accept a benchmark question and optional task_id
|
| 24 |
+
- Load attached task-file context when available
|
| 25 |
+
- Return ONLY the final answer string
|
| 26 |
+
- Stay framework-agnostic for now so we can plug in any LLM later
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self, config: Optional[AgentConfig] = None, llm_client=None):
|
| 29 |
self.config = config or AgentConfig()
|
| 30 |
+
self.llm_client = llm_client or HFLLMClient()
|
| 31 |
self.task_file_tool = TaskFileTool(api_base_url=self.config.api_base_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def __call__(self, question: str, task_id: Optional[str] = None) -> str:
|
| 34 |
+
"""
|
| 35 |
+
Main entry point used by app.py.
|
| 36 |
+
"""
|
| 37 |
+
context = self._load_context(task_id=task_id)
|
| 38 |
+
raw_output = self._solve(question=question, context=context)
|
| 39 |
+
final_answer = extract_final_answer(raw_output)
|
| 40 |
+
return normalize_final_answer(final_answer)
|
| 41 |
+
|
| 42 |
+
def _load_context(self, task_id: Optional[str]) -> str:
|
| 43 |
+
"""
|
| 44 |
+
Try to fetch and read any task-linked file.
|
| 45 |
+
Safe fallback: empty context.
|
| 46 |
+
"""
|
| 47 |
+
if not task_id:
|
| 48 |
+
return ""
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
file_text = self.task_file_tool.get_task_context(task_id=task_id)
|
| 52 |
+
if not file_text:
|
| 53 |
+
return ""
|
| 54 |
|
| 55 |
+
return file_text[: self.config.max_context_chars]
|
| 56 |
+
except Exception:
|
| 57 |
+
return ""
|
| 58 |
+
|
| 59 |
+
def _solve(self, question: str, context: str) -> str:
|
| 60 |
+
"""
|
| 61 |
+
Solve the question with either:
|
| 62 |
+
1) a plugged-in LLM client, or
|
| 63 |
+
2) a safe fallback so the app does not crash during setup.
|
| 64 |
+
|
| 65 |
+
The LLM client is expected to expose a .generate(prompt: str) -> str method.
|
| 66 |
+
We will wire the real model later.
|
| 67 |
+
"""
|
| 68 |
+
prompt = build_solver_prompt(question=question, context=context)
|
| 69 |
+
|
| 70 |
+
try:
|
| 71 |
+
return self.llm_client.generate(prompt)
|
| 72 |
+
except Exception as e:
|
| 73 |
+
print(f"LLM generation error: {e}")
|
| 74 |
+
return ""
|
llm_client.py
CHANGED
|
@@ -1,43 +1,43 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
#2
|
| 41 |
# import os
|
| 42 |
# from dotenv import load_dotenv
|
| 43 |
|
|
@@ -53,15 +53,3 @@
|
|
| 53 |
# # If you later connect a provider, do it here.
|
| 54 |
# # For now, fail cleanly so tool-based paths still work.
|
| 55 |
# raise RuntimeError("No free LLM fallback configured.")
|
| 56 |
-
|
| 57 |
-
# llm_client.py
|
| 58 |
-
from smolagents import TransformersModel
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def build_local_model() -> TransformersModel:
|
| 62 |
-
return TransformersModel(
|
| 63 |
-
model_id="Qwen/Qwen2.5-0.5B-Instruct",
|
| 64 |
-
max_new_tokens=256,
|
| 65 |
-
temperature=0.1,
|
| 66 |
-
do_sample=False,
|
| 67 |
-
)
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from huggingface_hub import InferenceClient
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class HFLLMClient:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
self.api_key = os.getenv("HF_TOKEN")
|
| 11 |
+
print("HF token present:", bool(self.api_key))
|
| 12 |
+
|
| 13 |
+
if not self.api_key:
|
| 14 |
+
raise ValueError("HF_TOKEN is not set")
|
| 15 |
+
|
| 16 |
+
self.model = "Qwen/Qwen2.5-7B-Instruct"
|
| 17 |
+
self.client = InferenceClient(
|
| 18 |
+
provider="auto",
|
| 19 |
+
api_key=self.api_key,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
def generate(self, prompt: str) -> str:
|
| 23 |
+
try:
|
| 24 |
+
output = self.client.chat_completion(
|
| 25 |
+
model=self.model,
|
| 26 |
+
messages=[
|
| 27 |
+
{"role": "user", "content": prompt}
|
| 28 |
+
],
|
| 29 |
+
max_tokens=128,
|
| 30 |
+
temperature=0.1,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
text = output.choices[0].message.content
|
| 34 |
+
print("LLM response preview:", str(text)[:300])
|
| 35 |
+
return str(text)
|
| 36 |
+
|
| 37 |
+
except Exception as e:
|
| 38 |
+
raise ValueError(f"Inference call failed: {e}")
|
| 39 |
+
|
| 40 |
+
# 2
|
| 41 |
# import os
|
| 42 |
# from dotenv import load_dotenv
|
| 43 |
|
|
|
|
| 53 |
# # If you later connect a provider, do it here.
|
| 54 |
# # For now, fail cleanly so tool-based paths still work.
|
| 55 |
# raise RuntimeError("No free LLM fallback configured.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompts.py
CHANGED
|
@@ -1,144 +1,57 @@
|
|
| 1 |
-
# from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
# SYSTEM_PROMPT = """
|
| 5 |
-
# You are a benchmark-solving AI agent.
|
| 6 |
-
|
| 7 |
-
# Your task is to answer questions as accurately as possible.
|
| 8 |
-
|
| 9 |
-
# Rules:
|
| 10 |
-
# - Return only the final answer.
|
| 11 |
-
# - If unsure, return your best short answer only.
|
| 12 |
-
# - Do not explain.
|
| 13 |
-
# - Do not include reasoning.
|
| 14 |
-
# - Do not include complete sentences unless the answer itself is a sentence.
|
| 15 |
-
# - For lists, preserve exact order only if supported by evidence.
|
| 16 |
-
# - Do not invent information not present in the question or provided context.
|
| 17 |
-
|
| 18 |
-
# Formatting rules:
|
| 19 |
-
# - If the answer is a number, output only the number.
|
| 20 |
-
# - If the answer is a word or phrase, output only that word or phrase.
|
| 21 |
-
# - If the answer is a date, return the exact date string.
|
| 22 |
-
# - Do not add punctuation unless it is part of the answer.
|
| 23 |
-
|
| 24 |
-
# Your response must contain only the final answer string.
|
| 25 |
-
# """
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# def build_solver_prompt(question: str, context: str = "") -> str:
|
| 29 |
-
# """
|
| 30 |
-
# Builds the final prompt sent to the model.
|
| 31 |
-
# Includes optional file context when a task provides additional data.
|
| 32 |
-
# """
|
| 33 |
-
|
| 34 |
-
# if context:
|
| 35 |
-
# prompt = f"""
|
| 36 |
-
# {SYSTEM_PROMPT}
|
| 37 |
-
|
| 38 |
-
# Context information:
|
| 39 |
-
# {context}
|
| 40 |
-
|
| 41 |
-
# Question:
|
| 42 |
-
# {question}
|
| 43 |
-
|
| 44 |
-
# Return only the final answer.
|
| 45 |
-
# """
|
| 46 |
-
# else:
|
| 47 |
-
# prompt = f"""
|
| 48 |
-
# {SYSTEM_PROMPT}
|
| 49 |
-
|
| 50 |
-
# Question:
|
| 51 |
-
# {question}
|
| 52 |
-
|
| 53 |
-
# Return only the final answer.
|
| 54 |
-
# """
|
| 55 |
-
|
| 56 |
-
# return prompt.strip()
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# #number 2
|
| 60 |
-
# from __future__ import annotations
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
# SYSTEM_PROMPT = """
|
| 64 |
-
# You are a benchmark-solving AI agent.
|
| 65 |
-
|
| 66 |
-
# Rules:
|
| 67 |
-
# - Return ONLY the final answer.
|
| 68 |
-
# - Do NOT include explanations.
|
| 69 |
-
# - Do NOT include reasoning.
|
| 70 |
-
# - Do NOT include 'FINAL ANSWER'.
|
| 71 |
-
# - Do NOT include labels like 'Answer:'.
|
| 72 |
-
# - If web or file context is provided, use it instead of guessing.
|
| 73 |
-
# - If multiple candidates appear, choose the one best supported by the provided context.
|
| 74 |
-
|
| 75 |
-
# Formatting rules:
|
| 76 |
-
# - If the answer is a number, output only the number.
|
| 77 |
-
# - If the answer is a word or short phrase, output only that word or phrase.
|
| 78 |
-
# - If the answer is a date, output only the date.
|
| 79 |
-
# - Do not add punctuation unless required by the answer itself.
|
| 80 |
-
# """
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# def build_solver_prompt(question: str, context: str = "") -> str:
|
| 84 |
-
# if context:
|
| 85 |
-
# return f"""
|
| 86 |
-
# {SYSTEM_PROMPT}
|
| 87 |
-
|
| 88 |
-
# Context:
|
| 89 |
-
# {context}
|
| 90 |
-
|
| 91 |
-
# Question:
|
| 92 |
-
# {question}
|
| 93 |
-
|
| 94 |
-
# Return only the final answer.
|
| 95 |
-
# """.strip()
|
| 96 |
-
|
| 97 |
-
# return f"""
|
| 98 |
-
# {SYSTEM_PROMPT}
|
| 99 |
-
|
| 100 |
-
# Question:
|
| 101 |
-
# {question}
|
| 102 |
-
|
| 103 |
-
# Return only the final answer.
|
| 104 |
-
# """.strip()
|
| 105 |
-
|
| 106 |
-
#number 3
|
| 107 |
from __future__ import annotations
|
| 108 |
|
| 109 |
|
| 110 |
SYSTEM_PROMPT = """
|
| 111 |
You are a benchmark-solving AI agent.
|
| 112 |
|
|
|
|
|
|
|
| 113 |
Rules:
|
| 114 |
-
- Return
|
| 115 |
-
-
|
| 116 |
-
- Do
|
| 117 |
-
- Do
|
| 118 |
-
- Do
|
| 119 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
"""
|
| 121 |
|
| 122 |
|
| 123 |
def build_solver_prompt(question: str, context: str = "") -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
if context:
|
| 125 |
-
|
| 126 |
{SYSTEM_PROMPT}
|
| 127 |
|
| 128 |
-
|
| 129 |
{context}
|
| 130 |
|
| 131 |
Question:
|
| 132 |
{question}
|
| 133 |
|
| 134 |
Return only the final answer.
|
| 135 |
-
"""
|
| 136 |
-
|
| 137 |
-
|
| 138 |
{SYSTEM_PROMPT}
|
| 139 |
|
| 140 |
Question:
|
| 141 |
{question}
|
| 142 |
|
| 143 |
Return only the final answer.
|
| 144 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
|
| 4 |
SYSTEM_PROMPT = """
|
| 5 |
You are a benchmark-solving AI agent.
|
| 6 |
|
| 7 |
+
Your task is to answer questions as accurately as possible.
|
| 8 |
+
|
| 9 |
Rules:
|
| 10 |
+
- Return only the final answer.
|
| 11 |
+
- If unsure, return your best short answer only.
|
| 12 |
+
- Do not explain.
|
| 13 |
+
- Do not include reasoning.
|
| 14 |
+
- Do not include complete sentences unless the answer itself is a sentence.
|
| 15 |
+
- For lists, preserve exact order only if supported by evidence.
|
| 16 |
+
- Do not invent information not present in the question or provided context.
|
| 17 |
+
|
| 18 |
+
Formatting rules:
|
| 19 |
+
- If the answer is a number, output only the number.
|
| 20 |
+
- If the answer is a word or phrase, output only that word or phrase.
|
| 21 |
+
- If the answer is a date, return the exact date string.
|
| 22 |
+
- Do not add punctuation unless it is part of the answer.
|
| 23 |
+
|
| 24 |
+
Your response must contain only the final answer string.
|
| 25 |
"""
|
| 26 |
|
| 27 |
|
| 28 |
def build_solver_prompt(question: str, context: str = "") -> str:
|
| 29 |
+
"""
|
| 30 |
+
Builds the final prompt sent to the model.
|
| 31 |
+
Includes optional file context when a task provides additional data.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
if context:
|
| 35 |
+
prompt = f"""
|
| 36 |
{SYSTEM_PROMPT}
|
| 37 |
|
| 38 |
+
Context information:
|
| 39 |
{context}
|
| 40 |
|
| 41 |
Question:
|
| 42 |
{question}
|
| 43 |
|
| 44 |
Return only the final answer.
|
| 45 |
+
"""
|
| 46 |
+
else:
|
| 47 |
+
prompt = f"""
|
| 48 |
{SYSTEM_PROMPT}
|
| 49 |
|
| 50 |
Question:
|
| 51 |
{question}
|
| 52 |
|
| 53 |
Return only the final answer.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
return prompt.strip()
|
| 57 |
+
|
tools.py
CHANGED
|
@@ -1,483 +1,102 @@
|
|
| 1 |
-
# from __future__ import annotations
|
| 2 |
-
# import io
|
| 3 |
-
# import json
|
| 4 |
-
# import os
|
| 5 |
-
# from pathlib import Path
|
| 6 |
-
# from typing import Optional
|
| 7 |
-
# import pandas as pd
|
| 8 |
-
# import requests
|
| 9 |
-
|
| 10 |
-
# class TaskFileTool:
|
| 11 |
-
# """
|
| 12 |
-
# Downloads and reads task-linked files from the Hugging Face
|
| 13 |
-
# Unit 4 scoring API.
|
| 14 |
-
|
| 15 |
-
# Supported text extration:
|
| 16 |
-
# - txt
|
| 17 |
-
# - csv
|
| 18 |
-
# - json
|
| 19 |
-
# - md
|
| 20 |
-
# - html
|
| 21 |
-
# - xml
|
| 22 |
-
|
| 23 |
-
# For unsupported or binary files, it safely returns an empty string for now.
|
| 24 |
-
# We can extend this later for PDF/images if needed.
|
| 25 |
-
# """
|
| 26 |
-
|
| 27 |
-
# def __init__(self, api_base_url: str, cache_dir:str = "task_files", timeout: int =30):
|
| 28 |
-
# self.api_base_url = api_base_url.rstrip("/")
|
| 29 |
-
# self.cache_dir = Path(cache_dir)
|
| 30 |
-
# self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
-
# self.timeout = timeout
|
| 32 |
-
|
| 33 |
-
# def get_task_context(self, task_id: str) -> str:
|
| 34 |
-
# """
|
| 35 |
-
# Main entry point used by the agent:
|
| 36 |
-
# 1. download the task file if present
|
| 37 |
-
# 2. read it into text context if supported
|
| 38 |
-
# """
|
| 39 |
-
# file_path = self.download_task_file(task_id)
|
| 40 |
-
# if file_path is None:
|
| 41 |
-
# return ""
|
| 42 |
-
# return self.read_file_as_text(file_path)
|
| 43 |
-
|
| 44 |
-
# def download_task_file(self, task_id: str) -> Optional[Path]:
|
| 45 |
-
# """
|
| 46 |
-
# Downloads the file linked to a task_id using:
|
| 47 |
-
# GET /files/{task_id}
|
| 48 |
-
|
| 49 |
-
# Returns:
|
| 50 |
-
# Path to saved file if successful, else None
|
| 51 |
-
# """
|
| 52 |
-
# url = f"{self.api_base_url}/file/{task_id}"
|
| 53 |
-
|
| 54 |
-
# try:
|
| 55 |
-
# response = requests.get(url, timeout=self.timeout)
|
| 56 |
-
# except requests.RequestException:
|
| 57 |
-
# return None
|
| 58 |
-
|
| 59 |
-
# if response.status_code !=200:
|
| 60 |
-
# return None
|
| 61 |
-
|
| 62 |
-
# filename = self._infer_filename(response=response, task_id=task_id)
|
| 63 |
-
# file_path = self.cache_dir / filename
|
| 64 |
-
|
| 65 |
-
# try:
|
| 66 |
-
# with open(file_path, "wb") as f:
|
| 67 |
-
# f.write(response.content)
|
| 68 |
-
# return file_path
|
| 69 |
-
# except OSError:
|
| 70 |
-
# return None
|
| 71 |
-
# return file_path
|
| 72 |
-
|
| 73 |
-
# def read_file_as_text(self, file_path: Path) -> str:
|
| 74 |
-
# """
|
| 75 |
-
# Reads supported file types into plain text.
|
| 76 |
-
# """
|
| 77 |
-
# suffix = file_path.suffix.lower()
|
| 78 |
-
|
| 79 |
-
# try:
|
| 80 |
-
# if suffix in {".txt", ".md", ".html", ".xml", ".csv", ".json"}:
|
| 81 |
-
# return self._read_supported_text_file(file_path, suffix)
|
| 82 |
-
|
| 83 |
-
# # common fallback for files saved without extension but actually text
|
| 84 |
-
# if suffix == "":
|
| 85 |
-
# return self._read_extensionless_file(file_path)
|
| 86 |
-
|
| 87 |
-
# return ""
|
| 88 |
-
# except Exception:
|
| 89 |
-
# return ""
|
| 90 |
-
|
| 91 |
-
# def _read_supported_text_file(self, file_path: Path, suffix: str) -> str:
|
| 92 |
-
# if suffix in {".txt", ".md", ".html", ".xml"}:
|
| 93 |
-
# return file_path.read_text(encoding="utf-8", errors="ignore")
|
| 94 |
-
|
| 95 |
-
# if suffix == ".json":
|
| 96 |
-
# raw = file_path.read_text(encoding="utf-8", errors="ignore")
|
| 97 |
-
# try:
|
| 98 |
-
# parsed = json.loads(raw)
|
| 99 |
-
# return json.dumps(parsed, indent=2, ensure_ascii=False)
|
| 100 |
-
# except json.JSONDecodeError:
|
| 101 |
-
# return raw
|
| 102 |
-
|
| 103 |
-
# if suffix == ".csv":
|
| 104 |
-
# try:
|
| 105 |
-
# df = pd.read_csv(file_path)
|
| 106 |
-
# return df.to_csv(index=False)
|
| 107 |
-
# except Exception:
|
| 108 |
-
# return file_path.read_text(encoding="utf-8", errors="ignore")
|
| 109 |
-
|
| 110 |
-
# return ""
|
| 111 |
-
|
| 112 |
-
# def _read_extensionless_file(self, file_path: Path) -> str:
|
| 113 |
-
# """
|
| 114 |
-
# Try to interpret extensionless files as utf-8 text first.
|
| 115 |
-
# """
|
| 116 |
-
# try:
|
| 117 |
-
# raw = file_path.read_text(encoding="utf-8", errors="ignore")
|
| 118 |
-
# if raw.strip():
|
| 119 |
-
# return raw
|
| 120 |
-
# except Exception:
|
| 121 |
-
# pass
|
| 122 |
-
# return ""
|
| 123 |
-
|
| 124 |
-
# def _infer_filename(self, response: requests.Response, task_id: str) -> str:
|
| 125 |
-
# """
|
| 126 |
-
# Attempts to infer a useful filename from headers.
|
| 127 |
-
# Falls back to task_id if no filename is available.
|
| 128 |
-
# """
|
| 129 |
-
# content_disposition = response.headers.get("content-disposition", "")
|
| 130 |
-
# filename = self._extract_filename_from_content_disposition(content_disposition)
|
| 131 |
-
|
| 132 |
-
# if filename:
|
| 133 |
-
# return self._safe_filename(filename)
|
| 134 |
-
|
| 135 |
-
# content_type = response.headers.get("content-type", "").lower()
|
| 136 |
-
# extension = self._extension_from_content_type(content_type)
|
| 137 |
-
|
| 138 |
-
# if extension:
|
| 139 |
-
# return f"{task_id}{extension}"
|
| 140 |
-
|
| 141 |
-
# return str(task_id)
|
| 142 |
-
|
| 143 |
-
# @staticmethod
|
| 144 |
-
# def _extract_filename_from_content_disposition(content_disposition: str) -> Optional[str]:
|
| 145 |
-
# """
|
| 146 |
-
# Example header:
|
| 147 |
-
# content-disposition: attachment; filename="example.csv"
|
| 148 |
-
# """
|
| 149 |
-
# if "filename=" not in content_disposition:
|
| 150 |
-
# return None
|
| 151 |
-
|
| 152 |
-
# try:
|
| 153 |
-
# filename = content_disposition.split("filename=")[-1].strip().strip('"')
|
| 154 |
-
# return filename or None
|
| 155 |
-
# except Exception:
|
| 156 |
-
# return None
|
| 157 |
-
|
| 158 |
-
# @staticmethod
|
| 159 |
-
# def _extension_from_content_type(content_type: str) -> str:
|
| 160 |
-
# mapping = {
|
| 161 |
-
# "text/plain": ".txt",
|
| 162 |
-
# "text/csv": ".csv",
|
| 163 |
-
# "application/csv": ".csv",
|
| 164 |
-
# "application/json": ".json",
|
| 165 |
-
# "text/markdown": ".md",
|
| 166 |
-
# "text/html": ".html",
|
| 167 |
-
# "application/xml": ".xml",
|
| 168 |
-
# "text/xml": ".xml",
|
| 169 |
-
# }
|
| 170 |
-
|
| 171 |
-
# for key, ext in mapping.items():
|
| 172 |
-
# if key in content_type:
|
| 173 |
-
# return ext
|
| 174 |
-
|
| 175 |
-
# return ""
|
| 176 |
-
|
| 177 |
-
# @staticmethod
|
| 178 |
-
# def _safe_filename(filename: str) -> str:
|
| 179 |
-
# """
|
| 180 |
-
# Prevent path traversal and weird path issues.
|
| 181 |
-
# """
|
| 182 |
-
# return os.path.basename(filename)
|
| 183 |
-
|
| 184 |
from __future__ import annotations
|
| 185 |
-
|
| 186 |
-
import contextlib
|
| 187 |
import io
|
| 188 |
import json
|
| 189 |
import os
|
| 190 |
-
import re
|
| 191 |
-
import runpy
|
| 192 |
-
import shelve
|
| 193 |
-
import tempfile
|
| 194 |
-
from io import BytesIO
|
| 195 |
from pathlib import Path
|
| 196 |
-
from typing import
|
| 197 |
-
|
| 198 |
import pandas as pd
|
| 199 |
import requests
|
| 200 |
-
import wikipediaapi
|
| 201 |
-
from bs4 import BeautifulSoup
|
| 202 |
-
from faster_whisper import WhisperModel
|
| 203 |
-
from smolagents import Tool, tool
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
# -------------------------------------------------------------------
|
| 207 |
-
# Generic helper tools
|
| 208 |
-
# -------------------------------------------------------------------
|
| 209 |
-
|
| 210 |
-
@tool
|
| 211 |
-
def convert_pandas_table_to_markdown(table_csv_text: str) -> str:
|
| 212 |
-
"""
|
| 213 |
-
Convert CSV-like table text into markdown.
|
| 214 |
-
|
| 215 |
-
Args:
|
| 216 |
-
table_csv_text: CSV-formatted table text.
|
| 217 |
-
|
| 218 |
-
Returns:
|
| 219 |
-
A markdown table string, or the original text if parsing fails.
|
| 220 |
-
"""
|
| 221 |
-
try:
|
| 222 |
-
from io import StringIO
|
| 223 |
-
|
| 224 |
-
df = pd.read_csv(StringIO(table_csv_text))
|
| 225 |
-
return df.to_markdown(index=False)
|
| 226 |
-
except Exception:
|
| 227 |
-
return table_csv_text
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
@tool
|
| 231 |
-
def fetch_text_content(url: str) -> str:
|
| 232 |
-
"""
|
| 233 |
-
Fetch raw text content from a URL.
|
| 234 |
-
|
| 235 |
-
Args:
|
| 236 |
-
url: The URL to fetch.
|
| 237 |
-
|
| 238 |
-
Returns:
|
| 239 |
-
The page text content, or an error string.
|
| 240 |
-
"""
|
| 241 |
-
try:
|
| 242 |
-
response = requests.get(url, timeout=30)
|
| 243 |
-
response.raise_for_status()
|
| 244 |
-
return response.text
|
| 245 |
-
except requests.RequestException as e:
|
| 246 |
-
return f"Error fetching URL: {e}"
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
@tool
|
| 250 |
-
def read_excel(file_path: str) -> str:
|
| 251 |
-
"""
|
| 252 |
-
Read an Excel file from a local path and return all sheets as CSV-like text.
|
| 253 |
-
|
| 254 |
-
Args:
|
| 255 |
-
file_path: Local path to the Excel file.
|
| 256 |
-
|
| 257 |
-
Returns:
|
| 258 |
-
Combined sheet contents as text.
|
| 259 |
-
"""
|
| 260 |
-
if not file_path:
|
| 261 |
-
return ""
|
| 262 |
-
|
| 263 |
-
try:
|
| 264 |
-
sheets = pd.read_excel(file_path, sheet_name=None)
|
| 265 |
-
except Exception as e:
|
| 266 |
-
return f"Error reading Excel file: {e}"
|
| 267 |
-
|
| 268 |
-
parts = []
|
| 269 |
-
for sheet_name, df in sheets.items():
|
| 270 |
-
parts.append(f"Sheet: {sheet_name}")
|
| 271 |
-
try:
|
| 272 |
-
parts.append(df.to_csv(index=False))
|
| 273 |
-
except Exception:
|
| 274 |
-
parts.append(str(df))
|
| 275 |
-
|
| 276 |
-
return "\n\n".join(parts)
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
# -------------------------------------------------------------------
|
| 280 |
-
# Lightweight local storage for extracted tables
|
| 281 |
-
# -------------------------------------------------------------------
|
| 282 |
-
|
| 283 |
-
class ShelveDB:
|
| 284 |
-
dir_path = Path("./storage")
|
| 285 |
-
|
| 286 |
-
def __init__(self, table_name: str, init: bool = False):
|
| 287 |
-
self.dir_path.mkdir(parents=True, exist_ok=True)
|
| 288 |
-
self.path = str(self.dir_path / table_name)
|
| 289 |
-
if init:
|
| 290 |
-
with shelve.open(self.path):
|
| 291 |
-
pass
|
| 292 |
-
|
| 293 |
-
def save(self, key: str, value: Any) -> None:
|
| 294 |
-
with shelve.open(self.path) as db:
|
| 295 |
-
db[key] = value
|
| 296 |
-
|
| 297 |
-
def fetch(self, key: str) -> Any:
|
| 298 |
-
with shelve.open(self.path) as db:
|
| 299 |
-
return db.get(key)
|
| 300 |
-
|
| 301 |
-
def clear(self) -> None:
|
| 302 |
-
with shelve.open(self.path) as db:
|
| 303 |
-
db.clear()
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
class RetrieveCSVStorageTool(Tool):
|
| 307 |
-
name = "retrieve_csv_storage_tool"
|
| 308 |
-
description = "Retrieve a stored pandas table by key and return it as CSV text."
|
| 309 |
-
inputs = {
|
| 310 |
-
"key": {
|
| 311 |
-
"type": "string",
|
| 312 |
-
"description": "The key to retrieve from storage, such as table_1.",
|
| 313 |
-
}
|
| 314 |
-
}
|
| 315 |
-
output_type = "string"
|
| 316 |
-
|
| 317 |
-
def __init__(self, table_name: str = "wiki", init_storage: bool = True, storage_path: str | None = None):
|
| 318 |
-
super().__init__()
|
| 319 |
-
if storage_path is not None:
|
| 320 |
-
ShelveDB.dir_path = Path(storage_path)
|
| 321 |
-
self.storage = ShelveDB(table_name, init=init_storage)
|
| 322 |
-
|
| 323 |
-
def get_storage(self) -> ShelveDB:
|
| 324 |
-
return self.storage
|
| 325 |
-
|
| 326 |
-
def forward(self, key: str) -> str:
|
| 327 |
-
try:
|
| 328 |
-
dataframe = self.storage.fetch(key)
|
| 329 |
-
if dataframe is None:
|
| 330 |
-
return f"No data found for key: {key}"
|
| 331 |
-
if isinstance(dataframe, pd.DataFrame):
|
| 332 |
-
return dataframe.to_csv(index=False)
|
| 333 |
-
return str(dataframe)
|
| 334 |
-
except Exception as e:
|
| 335 |
-
return f"Error retrieving data: {e}"
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
# -------------------------------------------------------------------
|
| 339 |
-
# Wikipedia tools
|
| 340 |
-
# -------------------------------------------------------------------
|
| 341 |
-
|
| 342 |
-
def get_wiki_content(title: str, language: str = "en") -> tuple[str, dict[str, pd.DataFrame]]:
|
| 343 |
-
"""
|
| 344 |
-
Retrieve Wikipedia page text and any HTML tables found on the page.
|
| 345 |
-
|
| 346 |
-
Args:
|
| 347 |
-
title: Wikipedia page title.
|
| 348 |
-
language: Wikipedia language code.
|
| 349 |
-
|
| 350 |
-
Returns:
|
| 351 |
-
A tuple of (page_text, tables_dict).
|
| 352 |
-
"""
|
| 353 |
-
wiki = wikipediaapi.Wikipedia(user_agent="gaia-agent", language=language)
|
| 354 |
-
page = wiki.page(title)
|
| 355 |
-
|
| 356 |
-
if not page.exists():
|
| 357 |
-
return "", {}
|
| 358 |
-
|
| 359 |
-
page_text = page.text
|
| 360 |
-
|
| 361 |
-
# Try to fetch HTML and tables from normal Wikipedia URL
|
| 362 |
-
url = f"https://{language}.wikipedia.org/wiki/{title}"
|
| 363 |
-
tables: dict[str, pd.DataFrame] = {}
|
| 364 |
-
|
| 365 |
-
try:
|
| 366 |
-
dfs = pd.read_html(url)
|
| 367 |
-
for idx, df in enumerate(dfs, start=1):
|
| 368 |
-
tables[f"table_{idx}"] = df
|
| 369 |
-
except Exception:
|
| 370 |
-
pass
|
| 371 |
-
|
| 372 |
-
return page_text, tables
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
class WikiTool(Tool):
|
| 376 |
-
name = "wiki_tool"
|
| 377 |
-
description = (
|
| 378 |
-
"Get Wikipedia page content and extracted tables. "
|
| 379 |
-
"The tables are also stored in local storage and can later be fetched with retrieve_csv_storage_tool."
|
| 380 |
-
)
|
| 381 |
-
inputs = {
|
| 382 |
-
"query": {
|
| 383 |
-
"type": "string",
|
| 384 |
-
"description": "The title of the Wikipedia page, for example 'Mercedes_Sosa' or 'Malko_Competition'.",
|
| 385 |
-
},
|
| 386 |
-
"language": {
|
| 387 |
-
"type": "string",
|
| 388 |
-
"description": "The Wikipedia language code, such as 'en'.",
|
| 389 |
-
"nullable": True,
|
| 390 |
-
},
|
| 391 |
-
}
|
| 392 |
-
output_type = "string"
|
| 393 |
-
|
| 394 |
-
def __init__(self, storage: ShelveDB):
|
| 395 |
-
super().__init__()
|
| 396 |
-
self.storage = storage
|
| 397 |
-
|
| 398 |
-
def forward(self, query: str, language: str | None = None):
|
| 399 |
-
language = language or "en"
|
| 400 |
-
content, tables = get_wiki_content(query, language)
|
| 401 |
-
self.storage.clear()
|
| 402 |
-
for table_key, df in tables.items():
|
| 403 |
-
self.storage.save(table_key, df)
|
| 404 |
-
|
| 405 |
-
table_note = ""
|
| 406 |
-
if tables:
|
| 407 |
-
table_note = "\n\nStored tables:\n" + "\n".join(sorted(tables.keys()))
|
| 408 |
-
|
| 409 |
-
return content + table_note
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
# -------------------------------------------------------------------
|
| 413 |
-
# Generic webpage tools
|
| 414 |
-
# -------------------------------------------------------------------
|
| 415 |
-
|
| 416 |
-
class WebPageTool:
|
| 417 |
-
def fetch_text(self, url: str) -> str:
|
| 418 |
-
response = requests.get(url, timeout=30)
|
| 419 |
-
response.raise_for_status()
|
| 420 |
-
soup = BeautifulSoup(response.text, "html.parser")
|
| 421 |
-
return soup.get_text(" ", strip=True)[:25000]
|
| 422 |
-
|
| 423 |
-
def fetch_tables(self, url: str) -> list[pd.DataFrame]:
|
| 424 |
-
try:
|
| 425 |
-
return pd.read_html(url)
|
| 426 |
-
except Exception:
|
| 427 |
-
return []
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
# -------------------------------------------------------------------
|
| 431 |
-
# Attached task file tool
|
| 432 |
-
# -------------------------------------------------------------------
|
| 433 |
|
| 434 |
class TaskFileTool:
|
| 435 |
"""
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
"""
|
| 438 |
|
| 439 |
-
def __init__(self, api_base_url: str, cache_dir:
|
| 440 |
self.api_base_url = api_base_url.rstrip("/")
|
| 441 |
self.cache_dir = Path(cache_dir)
|
| 442 |
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 443 |
self.timeout = timeout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
-
|
| 446 |
-
|
|
|
|
|
|
|
| 447 |
|
| 448 |
try:
|
| 449 |
response = requests.get(url, timeout=self.timeout)
|
| 450 |
except requests.RequestException:
|
| 451 |
return None
|
| 452 |
-
|
| 453 |
-
if response.status_code !=
|
| 454 |
return None
|
| 455 |
-
|
| 456 |
-
filename = self._infer_filename(response, task_id)
|
| 457 |
file_path = self.cache_dir / filename
|
| 458 |
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
|
|
|
|
|
|
|
|
|
| 462 |
return file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
return ""
|
| 468 |
-
return self.read_file_as_text(file_path)
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
if suffix in {".txt", ".md", ".html", ".xml"}:
|
| 475 |
return file_path.read_text(encoding="utf-8", errors="ignore")
|
| 476 |
|
| 477 |
if suffix == ".json":
|
| 478 |
raw = file_path.read_text(encoding="utf-8", errors="ignore")
|
| 479 |
try:
|
| 480 |
-
|
|
|
|
| 481 |
except json.JSONDecodeError:
|
| 482 |
return raw
|
| 483 |
|
|
@@ -490,181 +109,74 @@ class TaskFileTool:
|
|
| 490 |
|
| 491 |
return ""
|
| 492 |
|
| 493 |
-
def
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
content_type = response.headers.get("content-type", "").lower()
|
| 499 |
-
mapping = {
|
| 500 |
-
"text/plain": ".txt",
|
| 501 |
-
"text/csv": ".csv",
|
| 502 |
-
"application/json": ".json",
|
| 503 |
-
"text/html": ".html",
|
| 504 |
-
"audio/mpeg": ".mp3",
|
| 505 |
-
"audio/wav": ".wav",
|
| 506 |
-
"image/png": ".png",
|
| 507 |
-
"image/jpeg": ".jpg",
|
| 508 |
-
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
| 509 |
-
"text/x-python": ".py",
|
| 510 |
-
}
|
| 511 |
-
for key, ext in mapping.items():
|
| 512 |
-
if key in content_type:
|
| 513 |
-
return f"{task_id}{ext}"
|
| 514 |
-
|
| 515 |
-
return str(task_id)
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
# -------------------------------------------------------------------
|
| 519 |
-
# Audio tools
|
| 520 |
-
# -------------------------------------------------------------------
|
| 521 |
-
|
| 522 |
-
class AudioTool:
|
| 523 |
-
def __init__(self):
|
| 524 |
-
self.model = WhisperModel("tiny", device="cpu", compute_type="int8")
|
| 525 |
-
|
| 526 |
-
def transcribe(self, file_path: str) -> str:
|
| 527 |
try:
|
| 528 |
-
|
| 529 |
-
|
|
|
|
| 530 |
except Exception:
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
class SpeechRecognitionTool(Tool):
|
| 535 |
-
name = "speech_recognition_tool"
|
| 536 |
-
description = "Transcribe a local audio file into text."
|
| 537 |
-
inputs = {
|
| 538 |
-
"audio_path": {
|
| 539 |
-
"type": "string",
|
| 540 |
-
"description": "The local path to the audio file.",
|
| 541 |
-
}
|
| 542 |
-
}
|
| 543 |
-
output_type = "string"
|
| 544 |
-
|
| 545 |
-
def __init__(self):
|
| 546 |
-
super().__init__()
|
| 547 |
-
self.audio_tool = AudioTool()
|
| 548 |
-
|
| 549 |
-
def forward(self, audio_path: str) -> str:
|
| 550 |
-
try:
|
| 551 |
-
return self.audio_tool.transcribe(audio_path)
|
| 552 |
-
except Exception as e:
|
| 553 |
-
return f"Error: {e}"
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
# -------------------------------------------------------------------
|
| 557 |
-
# Visual tool placeholder
|
| 558 |
-
# -------------------------------------------------------------------
|
| 559 |
-
|
| 560 |
-
class VisualQATool(Tool):
|
| 561 |
-
name = "visual_qa_tool"
|
| 562 |
-
description = (
|
| 563 |
-
"Analyze a local image and answer a question about it. "
|
| 564 |
-
"This local fallback is limited and may not work well for complex vision tasks."
|
| 565 |
-
)
|
| 566 |
-
inputs = {
|
| 567 |
-
"image_path": {
|
| 568 |
-
"type": "string",
|
| 569 |
-
"description": "The local image path.",
|
| 570 |
-
},
|
| 571 |
-
"question": {
|
| 572 |
-
"type": "string",
|
| 573 |
-
"description": "The question to ask about the image.",
|
| 574 |
-
},
|
| 575 |
-
}
|
| 576 |
-
output_type = "string"
|
| 577 |
-
|
| 578 |
-
def forward(self, image_path: str, question: str) -> str:
|
| 579 |
-
# Free local placeholder. Keeps the tool available to the agent,
|
| 580 |
-
# but does not pretend to solve hard vision tasks reliably.
|
| 581 |
-
if not image_path:
|
| 582 |
-
return "No image path provided."
|
| 583 |
-
return (
|
| 584 |
-
"Visual analysis is limited in this local setup. "
|
| 585 |
-
"The image tool is available but not reliable for complex reasoning tasks."
|
| 586 |
-
)
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
# -------------------------------------------------------------------
|
| 590 |
-
# Spreadsheet and Python tools
|
| 591 |
-
# -------------------------------------------------------------------
|
| 592 |
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
if
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
return {}
|
| 601 |
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
total = 0.0
|
| 605 |
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
category_col = next((cols[c] for c in cols if "category" in c or "type" in c), None)
|
| 609 |
-
sales_col = next((cols[c] for c in cols if "sales" in c or "revenue" in c or "amount" in c), None)
|
| 610 |
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
total += pd.to_numeric(df.loc[food_mask, sales_col], errors="coerce").fillna(0).sum()
|
| 614 |
|
| 615 |
-
return
|
| 616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
|
| 618 |
-
class PythonExecutionTool:
|
| 619 |
-
def run_python_file(self, file_path: str) -> str:
|
| 620 |
-
stdout = io.StringIO()
|
| 621 |
try:
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
return f"Error running Python file: {e}"
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
# -------------------------------------------------------------------
|
| 630 |
-
# Logic helpers
|
| 631 |
-
# -------------------------------------------------------------------
|
| 632 |
-
|
| 633 |
-
class LogicTool:
|
| 634 |
-
def solve_noncommutative_subset(self, question: str) -> str:
|
| 635 |
-
rows = [line.strip() for line in question.splitlines() if "|" in line and line.count("|") >= 6]
|
| 636 |
-
if len(rows) < 6:
|
| 637 |
-
return ""
|
| 638 |
-
|
| 639 |
-
headers = [x.strip() for x in rows[0].split("|")[2:-1]]
|
| 640 |
-
table = {}
|
| 641 |
-
|
| 642 |
-
for row in rows[2:]:
|
| 643 |
-
parts = [x.strip() for x in row.split("|")[1:-1]]
|
| 644 |
-
row_key = parts[0]
|
| 645 |
-
values = parts[1:]
|
| 646 |
-
table[row_key] = dict(zip(headers, values))
|
| 647 |
-
|
| 648 |
-
bad = set()
|
| 649 |
-
for a in headers:
|
| 650 |
-
for b in headers:
|
| 651 |
-
if table[a][b] != table[b][a]:
|
| 652 |
-
bad.add(a)
|
| 653 |
-
bad.add(b)
|
| 654 |
-
|
| 655 |
-
return ",".join(sorted(bad))
|
| 656 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
|
| 662 |
-
|
| 663 |
-
def __init__(self):
|
| 664 |
-
self.wiki = wikipediaapi.Wikipedia(user_agent="gaia-agent", language="en")
|
| 665 |
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
|
|
|
|
|
|
| 2 |
import io
|
| 3 |
import json
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
|
|
|
| 7 |
import pandas as pd
|
| 8 |
import requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class TaskFileTool:
|
| 11 |
"""
|
| 12 |
+
Downloads and reads task-linked files from the Hugging Face
|
| 13 |
+
Unit 4 scoring API.
|
| 14 |
+
|
| 15 |
+
Supported text extration:
|
| 16 |
+
- txt
|
| 17 |
+
- csv
|
| 18 |
+
- json
|
| 19 |
+
- md
|
| 20 |
+
- html
|
| 21 |
+
- xml
|
| 22 |
+
|
| 23 |
+
For unsupported or binary files, it safely returns an empty string for now.
|
| 24 |
+
We can extend this later for PDF/images if needed.
|
| 25 |
"""
|
| 26 |
|
| 27 |
+
def __init__(self, api_base_url: str, cache_dir:str = "task_files", timeout: int =30):
|
| 28 |
self.api_base_url = api_base_url.rstrip("/")
|
| 29 |
self.cache_dir = Path(cache_dir)
|
| 30 |
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
| 31 |
self.timeout = timeout
|
| 32 |
+
|
| 33 |
+
def get_task_context(self, task_id: str) -> str:
|
| 34 |
+
"""
|
| 35 |
+
Main entry point used by the agent:
|
| 36 |
+
1. download the task file if present
|
| 37 |
+
2. read it into text context if supported
|
| 38 |
+
"""
|
| 39 |
+
file_path = self.download_task_file(task_id)
|
| 40 |
+
if file_path is None:
|
| 41 |
+
return ""
|
| 42 |
+
return self.read_file_as_text(file_path)
|
| 43 |
+
|
| 44 |
+
def download_task_file(self, task_id: str) -> Optional[Path]:
|
| 45 |
+
"""
|
| 46 |
+
Downloads the file linked to a task_id using:
|
| 47 |
+
GET /files/{task_id}
|
| 48 |
|
| 49 |
+
Returns:
|
| 50 |
+
Path to saved file if successful, else None
|
| 51 |
+
"""
|
| 52 |
+
url = f"{self.api_base_url}/file/{task_id}"
|
| 53 |
|
| 54 |
try:
|
| 55 |
response = requests.get(url, timeout=self.timeout)
|
| 56 |
except requests.RequestException:
|
| 57 |
return None
|
| 58 |
+
|
| 59 |
+
if response.status_code !=200:
|
| 60 |
return None
|
| 61 |
+
|
| 62 |
+
filename = self._infer_filename(response=response, task_id=task_id)
|
| 63 |
file_path = self.cache_dir / filename
|
| 64 |
|
| 65 |
+
try:
|
| 66 |
+
with open(file_path, "wb") as f:
|
| 67 |
+
f.write(response.content)
|
| 68 |
+
return file_path
|
| 69 |
+
except OSError:
|
| 70 |
+
return None
|
| 71 |
return file_path
|
| 72 |
+
|
| 73 |
+
def read_file_as_text(self, file_path: Path) -> str:
|
| 74 |
+
"""
|
| 75 |
+
Reads supported file types into plain text.
|
| 76 |
+
"""
|
| 77 |
+
suffix = file_path.suffix.lower()
|
| 78 |
|
| 79 |
+
try:
|
| 80 |
+
if suffix in {".txt", ".md", ".html", ".xml", ".csv", ".json"}:
|
| 81 |
+
return self._read_supported_text_file(file_path, suffix)
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
# common fallback for files saved without extension but actually text
|
| 84 |
+
if suffix == "":
|
| 85 |
+
return self._read_extensionless_file(file_path)
|
| 86 |
|
| 87 |
+
return ""
|
| 88 |
+
except Exception:
|
| 89 |
+
return ""
|
| 90 |
+
|
| 91 |
+
def _read_supported_text_file(self, file_path: Path, suffix: str) -> str:
|
| 92 |
if suffix in {".txt", ".md", ".html", ".xml"}:
|
| 93 |
return file_path.read_text(encoding="utf-8", errors="ignore")
|
| 94 |
|
| 95 |
if suffix == ".json":
|
| 96 |
raw = file_path.read_text(encoding="utf-8", errors="ignore")
|
| 97 |
try:
|
| 98 |
+
parsed = json.loads(raw)
|
| 99 |
+
return json.dumps(parsed, indent=2, ensure_ascii=False)
|
| 100 |
except json.JSONDecodeError:
|
| 101 |
return raw
|
| 102 |
|
|
|
|
| 109 |
|
| 110 |
return ""
|
| 111 |
|
| 112 |
+
def _read_extensionless_file(self, file_path: Path) -> str:
|
| 113 |
+
"""
|
| 114 |
+
Try to interpret extensionless files as utf-8 text first.
|
| 115 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
try:
|
| 117 |
+
raw = file_path.read_text(encoding="utf-8", errors="ignore")
|
| 118 |
+
if raw.strip():
|
| 119 |
+
return raw
|
| 120 |
except Exception:
|
| 121 |
+
pass
|
| 122 |
+
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
def _infer_filename(self, response: requests.Response, task_id: str) -> str:
|
| 125 |
+
"""
|
| 126 |
+
Attempts to infer a useful filename from headers.
|
| 127 |
+
Falls back to task_id if no filename is available.
|
| 128 |
+
"""
|
| 129 |
+
content_disposition = response.headers.get("content-disposition", "")
|
| 130 |
+
filename = self._extract_filename_from_content_disposition(content_disposition)
|
|
|
|
| 131 |
|
| 132 |
+
if filename:
|
| 133 |
+
return self._safe_filename(filename)
|
|
|
|
| 134 |
|
| 135 |
+
content_type = response.headers.get("content-type", "").lower()
|
| 136 |
+
extension = self._extension_from_content_type(content_type)
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
if extension:
|
| 139 |
+
return f"{task_id}{extension}"
|
|
|
|
| 140 |
|
| 141 |
+
return str(task_id)
|
| 142 |
|
| 143 |
+
@staticmethod
|
| 144 |
+
def _extract_filename_from_content_disposition(content_disposition: str) -> Optional[str]:
|
| 145 |
+
"""
|
| 146 |
+
Example header:
|
| 147 |
+
content-disposition: attachment; filename="example.csv"
|
| 148 |
+
"""
|
| 149 |
+
if "filename=" not in content_disposition:
|
| 150 |
+
return None
|
| 151 |
|
|
|
|
|
|
|
|
|
|
| 152 |
try:
|
| 153 |
+
filename = content_disposition.split("filename=")[-1].strip().strip('"')
|
| 154 |
+
return filename or None
|
| 155 |
+
except Exception:
|
| 156 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
+
@staticmethod
|
| 159 |
+
def _extension_from_content_type(content_type: str) -> str:
|
| 160 |
+
mapping = {
|
| 161 |
+
"text/plain": ".txt",
|
| 162 |
+
"text/csv": ".csv",
|
| 163 |
+
"application/csv": ".csv",
|
| 164 |
+
"application/json": ".json",
|
| 165 |
+
"text/markdown": ".md",
|
| 166 |
+
"text/html": ".html",
|
| 167 |
+
"application/xml": ".xml",
|
| 168 |
+
"text/xml": ".xml",
|
| 169 |
+
}
|
| 170 |
|
| 171 |
+
for key, ext in mapping.items():
|
| 172 |
+
if key in content_type:
|
| 173 |
+
return ext
|
| 174 |
|
| 175 |
+
return ""
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
@staticmethod
|
| 178 |
+
def _safe_filename(filename: str) -> str:
|
| 179 |
+
"""
|
| 180 |
+
Prevent path traversal and weird path issues.
|
| 181 |
+
"""
|
| 182 |
+
return os.path.basename(filename)
|
utils.py
CHANGED
|
@@ -1,228 +1,17 @@
|
|
| 1 |
-
# from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
# import re
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
# def extract_final_answer(text: str) -> str:
|
| 7 |
-
# """
|
| 8 |
-
# Extract the most likely final answer from raw model output.
|
| 9 |
-
|
| 10 |
-
# In V1 we keep this conservative:
|
| 11 |
-
# - if the model returns a normal short answer, keep it
|
| 12 |
-
# - if it adds common prefixes like 'Answer:' or 'Final answer:', remove them
|
| 13 |
-
# - if it returns multiple lines, prefer the last non-empty line
|
| 14 |
-
# """
|
| 15 |
-
# if text is None:
|
| 16 |
-
# return ""
|
| 17 |
-
|
| 18 |
-
# text = str(text).strip()
|
| 19 |
-
# if not text:
|
| 20 |
-
# return ""
|
| 21 |
-
|
| 22 |
-
# # Remove fenced code blocks if the model wraps the answer oddly
|
| 23 |
-
# text = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", text)
|
| 24 |
-
# text = re.sub(r"\s*```$", "", text)
|
| 25 |
-
|
| 26 |
-
# # Common exact-answer markers
|
| 27 |
-
# marker_patterns = [
|
| 28 |
-
# r"(?i)\bfinal answer\s*:\s*",
|
| 29 |
-
# r"(?i)\banswer\s*:\s*",
|
| 30 |
-
# r"(?i)\bthe answer is\s*:\s*",
|
| 31 |
-
# r"(?i)\bthe answer is\s+",
|
| 32 |
-
# ]
|
| 33 |
-
|
| 34 |
-
# cleaned = text
|
| 35 |
-
# for pattern in marker_patterns:
|
| 36 |
-
# cleaned = re.sub(pattern, "", cleaned).strip()
|
| 37 |
-
|
| 38 |
-
# # If multi-line, prefer the last meaningful line
|
| 39 |
-
# lines = [line.strip() for line in cleaned.splitlines() if line.strip()]
|
| 40 |
-
# if not lines:
|
| 41 |
-
# return ""
|
| 42 |
-
|
| 43 |
-
# if len(lines) == 1:
|
| 44 |
-
# return lines[0]
|
| 45 |
-
|
| 46 |
-
# return lines[-1]
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# def normalize_final_answer(text: str) -> str:
|
| 50 |
-
# """
|
| 51 |
-
# Normalize answer text for safer exact-match submission without being too aggressive.
|
| 52 |
-
|
| 53 |
-
# Rules:
|
| 54 |
-
# - trim outer whitespace
|
| 55 |
-
# - collapse internal repeated whitespace
|
| 56 |
-
# - remove wrapping quotes if they wrap the full answer
|
| 57 |
-
# - remove a single trailing period only for plain word/phrase answers
|
| 58 |
-
# but keep decimal numbers and date punctuation intact
|
| 59 |
-
# """
|
| 60 |
-
# if text is None:
|
| 61 |
-
# return ""
|
| 62 |
-
|
| 63 |
-
# text = str(text).strip()
|
| 64 |
-
# if not text:
|
| 65 |
-
# return ""
|
| 66 |
-
|
| 67 |
-
# # Collapse repeated whitespace
|
| 68 |
-
# text = re.sub(r"\s+", " ", text).strip()
|
| 69 |
-
|
| 70 |
-
# # Remove matching surrounding quotes
|
| 71 |
-
# if len(text) >= 2:
|
| 72 |
-
# if (text[0] == text[-1]) and text[0] in {'"', "'"}:
|
| 73 |
-
# text = text[1:-1].strip()
|
| 74 |
-
|
| 75 |
-
# # Remove common leading labels again, just in case
|
| 76 |
-
# text = re.sub(r"(?i)^(final answer|answer)\s*:\s*", "", text).strip()
|
| 77 |
-
|
| 78 |
-
# # Remove one trailing period for simple phrase answers only
|
| 79 |
-
# # Keep decimals like 3.14 intact
|
| 80 |
-
# if text.endswith("."):
|
| 81 |
-
# if not re.fullmatch(r"\d+\.\d+", text):
|
| 82 |
-
# text = text[:-1].strip()
|
| 83 |
-
|
| 84 |
-
# return text
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# def is_placeholder_answer(text: str) -> bool:
|
| 88 |
-
# """
|
| 89 |
-
# Detect placeholder/fallback outputs so app.py can optionally flag them.
|
| 90 |
-
# """
|
| 91 |
-
# if text is None:
|
| 92 |
-
# return True
|
| 93 |
-
|
| 94 |
-
# normalized = normalize_final_answer(text).lower()
|
| 95 |
-
# return normalized in {
|
| 96 |
-
# "",
|
| 97 |
-
# "placeholder",
|
| 98 |
-
# "n/a",
|
| 99 |
-
# "unknown",
|
| 100 |
-
# }
|
| 101 |
-
|
| 102 |
-
# # Number 2
|
| 103 |
-
# from __future__ import annotations
|
| 104 |
-
|
| 105 |
-
# import re
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# def extract_final_answer(text: str) -> str:
|
| 109 |
-
# """
|
| 110 |
-
# Extract the final answer, preferring GAIA-style:
|
| 111 |
-
# FINAL ANSWER: ...
|
| 112 |
-
|
| 113 |
-
# Fallback behavior:
|
| 114 |
-
# - strip code fences
|
| 115 |
-
# - remove common answer prefixes
|
| 116 |
-
# - if multiple lines remain, prefer the last non-empty line
|
| 117 |
-
# """
|
| 118 |
-
# if text is None:
|
| 119 |
-
# return ""
|
| 120 |
-
|
| 121 |
-
# text = str(text).strip()
|
| 122 |
-
# if not text:
|
| 123 |
-
# return ""
|
| 124 |
-
|
| 125 |
-
# # Remove fenced code blocks if present
|
| 126 |
-
# text = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", text)
|
| 127 |
-
# text = re.sub(r"\s*```$", "", text)
|
| 128 |
-
|
| 129 |
-
# # Prefer GAIA-style final answer extraction
|
| 130 |
-
# gaia_match = re.search(
|
| 131 |
-
# r"FINAL ANSWER:\s*(.*)",
|
| 132 |
-
# text,
|
| 133 |
-
# flags=re.IGNORECASE | re.DOTALL,
|
| 134 |
-
# )
|
| 135 |
-
# if gaia_match:
|
| 136 |
-
# extracted = gaia_match.group(1).strip()
|
| 137 |
-
# lines = [line.strip() for line in extracted.splitlines() if line.strip()]
|
| 138 |
-
# return lines[0] if lines else extracted
|
| 139 |
-
|
| 140 |
-
# # Fallback exact-answer markers
|
| 141 |
-
# marker_patterns = [
|
| 142 |
-
# r"(?i)\bfinal answer\s*:\s*",
|
| 143 |
-
# r"(?i)\banswer\s*:\s*",
|
| 144 |
-
# r"(?i)\bthe answer is\s*:\s*",
|
| 145 |
-
# r"(?i)\bthe answer is\s+",
|
| 146 |
-
# ]
|
| 147 |
-
|
| 148 |
-
# cleaned = text
|
| 149 |
-
# for pattern in marker_patterns:
|
| 150 |
-
# cleaned = re.sub(pattern, "", cleaned).strip()
|
| 151 |
-
|
| 152 |
-
# lines = [line.strip() for line in cleaned.splitlines() if line.strip()]
|
| 153 |
-
# if not lines:
|
| 154 |
-
# return ""
|
| 155 |
-
|
| 156 |
-
# if len(lines) == 1:
|
| 157 |
-
# return lines[0]
|
| 158 |
-
|
| 159 |
-
# return lines[-1]
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
# def normalize_final_answer(text: str) -> str:
|
| 163 |
-
# """
|
| 164 |
-
# Normalize answer text for exact-match-style submission.
|
| 165 |
-
|
| 166 |
-
# Rules:
|
| 167 |
-
# - trim whitespace
|
| 168 |
-
# - collapse repeated spaces
|
| 169 |
-
# - remove wrapping quotes
|
| 170 |
-
# - remove labels again if present
|
| 171 |
-
# - remove one trailing period for plain phrase answers
|
| 172 |
-
# - remove leading articles for short string answers
|
| 173 |
-
# """
|
| 174 |
-
# if text is None:
|
| 175 |
-
# return ""
|
| 176 |
-
|
| 177 |
-
# text = str(text).strip()
|
| 178 |
-
# if not text:
|
| 179 |
-
# return ""
|
| 180 |
-
|
| 181 |
-
# # Collapse repeated whitespace
|
| 182 |
-
# text = re.sub(r"\s+", " ", text).strip()
|
| 183 |
-
|
| 184 |
-
# # Remove matching surrounding quotes
|
| 185 |
-
# if len(text) >= 2 and text[0] == text[-1] and text[0] in {'"', "'"}:
|
| 186 |
-
# text = text[1:-1].strip()
|
| 187 |
-
|
| 188 |
-
# # Remove common labels again
|
| 189 |
-
# text = re.sub(r"(?i)^(final answer|answer)\s*:\s*", "", text).strip()
|
| 190 |
-
|
| 191 |
-
# # Remove one trailing period for simple phrase answers only
|
| 192 |
-
# if text.endswith(".") and not re.fullmatch(r"\d+\.\d+", text):
|
| 193 |
-
# text = text[:-1].strip()
|
| 194 |
-
|
| 195 |
-
# # Remove leading articles for short string answers
|
| 196 |
-
# # Helps align with GAIA string-format guidance
|
| 197 |
-
# text = re.sub(r"(?i)^(a|an|the)\s+", "", text).strip()
|
| 198 |
-
|
| 199 |
-
# return text
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
# def is_placeholder_answer(text: str) -> bool:
|
| 203 |
-
# """
|
| 204 |
-
# Detect placeholder or clearly non-useful outputs.
|
| 205 |
-
# """
|
| 206 |
-
# if text is None:
|
| 207 |
-
# return True
|
| 208 |
-
|
| 209 |
-
# normalized = normalize_final_answer(text).lower()
|
| 210 |
-
# return normalized in {
|
| 211 |
-
# "",
|
| 212 |
-
# "placeholder",
|
| 213 |
-
# "n/a",
|
| 214 |
-
# "unknown",
|
| 215 |
-
# }
|
| 216 |
-
|
| 217 |
-
#number 3
|
| 218 |
from __future__ import annotations
|
| 219 |
|
| 220 |
-
import os
|
| 221 |
import re
|
| 222 |
-
from urllib.parse import urlparse, parse_qs
|
| 223 |
|
| 224 |
|
| 225 |
def extract_final_answer(text: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
if text is None:
|
| 227 |
return ""
|
| 228 |
|
|
@@ -230,19 +19,44 @@ def extract_final_answer(text: str) -> str:
|
|
| 230 |
if not text:
|
| 231 |
return ""
|
| 232 |
|
|
|
|
| 233 |
text = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", text)
|
| 234 |
text = re.sub(r"\s*```$", "", text)
|
| 235 |
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
if not lines:
|
| 240 |
return ""
|
| 241 |
|
|
|
|
|
|
|
|
|
|
| 242 |
return lines[-1]
|
| 243 |
|
| 244 |
|
| 245 |
def normalize_final_answer(text: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
if text is None:
|
| 247 |
return ""
|
| 248 |
|
|
@@ -250,83 +64,37 @@ def normalize_final_answer(text: str) -> str:
|
|
| 250 |
if not text:
|
| 251 |
return ""
|
| 252 |
|
|
|
|
| 253 |
text = re.sub(r"\s+", " ", text).strip()
|
| 254 |
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
| 257 |
|
|
|
|
| 258 |
text = re.sub(r"(?i)^(final answer|answer)\s*:\s*", "", text).strip()
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
return text
|
| 264 |
|
| 265 |
|
| 266 |
-
def
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
if parsed.hostname and "youtube.com" in parsed.hostname:
|
| 281 |
-
qs = parse_qs(parsed.query)
|
| 282 |
-
if "v" in qs:
|
| 283 |
-
return qs["v"][0]
|
| 284 |
-
return ""
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def infer_task_kind(question: str, file_path: str = "") -> str:
|
| 288 |
-
q = (question or "").lower()
|
| 289 |
-
ext = get_file_extension(file_path)
|
| 290 |
-
|
| 291 |
-
if "youtube.com/watch" in q or "youtu.be/" in q:
|
| 292 |
-
return "youtube"
|
| 293 |
-
|
| 294 |
-
if ext in {".mp3", ".wav", ".m4a", ".flac", ".ogg"}:
|
| 295 |
-
return "audio"
|
| 296 |
-
|
| 297 |
-
if ext in {".xlsx", ".xls", ".csv"}:
|
| 298 |
-
return "spreadsheet"
|
| 299 |
-
|
| 300 |
-
if ext == ".py":
|
| 301 |
-
return "python_file"
|
| 302 |
-
|
| 303 |
-
if ext in {".png", ".jpg", ".jpeg", ".webp"}:
|
| 304 |
-
return "image"
|
| 305 |
-
|
| 306 |
-
if "wikipedia" in q:
|
| 307 |
-
return "wikipedia"
|
| 308 |
-
|
| 309 |
-
if "table" in q or "|---|" in q:
|
| 310 |
-
return "table_logic"
|
| 311 |
-
|
| 312 |
-
if any(x in q for x in ["award number", "what city", "where were", "what country", "who nominated"]):
|
| 313 |
-
return "web_lookup"
|
| 314 |
-
|
| 315 |
-
if "opposite of the word" in q and q[::-1] != q:
|
| 316 |
-
return "text_transform"
|
| 317 |
-
|
| 318 |
-
return "general"
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
EXTINCT_COUNTRIES = {
|
| 322 |
-
"east germany",
|
| 323 |
-
"west germany",
|
| 324 |
-
"yugoslavia",
|
| 325 |
-
"czechoslovakia",
|
| 326 |
-
"soviet union",
|
| 327 |
-
"ussr",
|
| 328 |
-
}
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
def is_extinct_country(name: str) -> bool:
|
| 332 |
-
return (name or "").strip().lower() in EXTINCT_COUNTRIES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
import re
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def extract_final_answer(text: str) -> str:
|
| 7 |
+
"""
|
| 8 |
+
Extract the most likely final answer from raw model output.
|
| 9 |
+
|
| 10 |
+
In V1 we keep this conservative:
|
| 11 |
+
- if the model returns a normal short answer, keep it
|
| 12 |
+
- if it adds common prefixes like 'Answer:' or 'Final answer:', remove them
|
| 13 |
+
- if it returns multiple lines, prefer the last non-empty line
|
| 14 |
+
"""
|
| 15 |
if text is None:
|
| 16 |
return ""
|
| 17 |
|
|
|
|
| 19 |
if not text:
|
| 20 |
return ""
|
| 21 |
|
| 22 |
+
# Remove fenced code blocks if the model wraps the answer oddly
|
| 23 |
text = re.sub(r"^```[a-zA-Z0-9_-]*\s*", "", text)
|
| 24 |
text = re.sub(r"\s*```$", "", text)
|
| 25 |
|
| 26 |
+
# Common exact-answer markers
|
| 27 |
+
marker_patterns = [
|
| 28 |
+
r"(?i)\bfinal answer\s*:\s*",
|
| 29 |
+
r"(?i)\banswer\s*:\s*",
|
| 30 |
+
r"(?i)\bthe answer is\s*:\s*",
|
| 31 |
+
r"(?i)\bthe answer is\s+",
|
| 32 |
+
]
|
| 33 |
|
| 34 |
+
cleaned = text
|
| 35 |
+
for pattern in marker_patterns:
|
| 36 |
+
cleaned = re.sub(pattern, "", cleaned).strip()
|
| 37 |
+
|
| 38 |
+
# If multi-line, prefer the last meaningful line
|
| 39 |
+
lines = [line.strip() for line in cleaned.splitlines() if line.strip()]
|
| 40 |
if not lines:
|
| 41 |
return ""
|
| 42 |
|
| 43 |
+
if len(lines) == 1:
|
| 44 |
+
return lines[0]
|
| 45 |
+
|
| 46 |
return lines[-1]
|
| 47 |
|
| 48 |
|
| 49 |
def normalize_final_answer(text: str) -> str:
    """
    Normalize an answer string for safer exact-match submission while
    staying deliberately gentle.

    Steps:
      - trim outer whitespace and collapse internal runs to single spaces
      - drop quotes that wrap the entire answer
      - strip a leading 'Answer:' / 'Final answer:' label if one survived
      - drop a single trailing period from plain phrase answers, leaving
        decimal numbers such as 3.14 untouched
    """
    if text is None:
        return ""

    text = text.strip()
    if not text:
        return ""

    # Squash every whitespace run down to one space.
    text = re.sub(r"\s+", " ", text).strip()

    # Unwrap the answer when it is fully enclosed in matching quotes.
    wrapped_in_quotes = (
        len(text) >= 2 and text[0] == text[-1] and text[0] in {'"', "'"}
    )
    if wrapped_in_quotes:
        text = text[1:-1].strip()

    # A label may still be present after quote removal; strip it once more.
    text = re.sub(r"(?i)^(final answer|answer)\s*:\s*", "", text).strip()

    # Trim one trailing period from phrase answers, but never truncate a
    # bare decimal number like 3.14.
    if text.endswith(".") and not re.fullmatch(r"\d+\.\d+", text):
        text = text[:-1].strip()

    return text
|
| 85 |
|
| 86 |
|
| 87 |
+
def is_placeholder_answer(text: str) -> bool:
    """
    Report whether *text* looks like a placeholder/fallback output, so
    app.py can optionally flag it before submission.
    """
    if text is None:
        return True

    placeholder_values = {"", "placeholder", "n/a", "unknown"}
    return normalize_final_answer(text).lower() in placeholder_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|