Raj
Upload 2 files
91f0524 verified
Raw
History Blame Contribute Delete
14.9 kB
"""Custom tools for the GAIA agent.
Each tool is a @tool-decorated function that smolagents can call from a CodeAgent.
Keep tool docstrings precise — the LLM reads them to decide when to call.
"""
from __future__ import annotations
import io
import os
import re
import tempfile
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import requests
from smolagents import tool
# Base URL of the course scoring service. `download_task_file` uses it when the
# GAIA_API_URL environment variable is not set.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Browser-like User-Agent sent by `read_webpage` and `wikipedia_search`;
# presumably chosen so sites that block the default python-requests agent
# still serve content.
USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
    "(KHTML, like Gecko) Chrome/124.0 Safari/537.36"
)
# ---------------------------------------------------------------------------
# Web search
# ---------------------------------------------------------------------------
@tool
def web_search(query: str, num_results: int = 10) -> str:
    """Search the web with Serper (Google results) and return the top hits.

    Uses the Serper API when SERPER_API_KEY is set; otherwise falls back to
    DuckDuckGo through the optional `duckduckgo_search` package.

    Args:
        query: The search query.
        num_results: How many results to return (1-10).

    Returns:
        A text block of results: title, link, snippet. Use this to find URLs
        worth reading with `read_webpage`.
    """
    api_key = os.getenv("SERPER_API_KEY")
    # Clamp to 1-10. The argument is optional, so the LLM may pass None (or a
    # non-numeric string); use the default instead of crashing the tool.
    try:
        num_results = max(1, min(int(num_results), 10))
    except (TypeError, ValueError):
        num_results = 10

    if not api_key:
        # Fallback to DuckDuckGo if no Serper key.
        try:
            from duckduckgo_search import DDGS

            with DDGS() as ddgs:
                hits = list(ddgs.text(query, max_results=num_results))
            if not hits:
                return "No results."
            return "\n\n".join(
                f"[{i}] {h.get('title', '')}\n{h.get('href', '')}\n{h.get('body', '')}"
                for i, h in enumerate(hits, 1)
            )
        except Exception as e:  # pragma: no cover
            return f"Search failed (no SERPER_API_KEY, DDG fallback errored): {e}"

    try:
        resp = requests.post(
            "https://google.serper.dev/search",
            headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
            json={"q": query, "num": num_results},
            timeout=20,
        )
        resp.raise_for_status()
        data = resp.json()
    except Exception as e:
        return f"Serper search failed: {e}"

    parts: list[str] = []

    answer_box = data.get("answerBox") or {}
    answer = str(
        answer_box.get("answer")
        or answer_box.get("snippet")
        or answer_box.get("title")
        or ""
    ).strip()
    if answer:  # don't emit a bare "ANSWER BOX:" header for an empty box
        parts.append("ANSWER BOX:\n" + answer)

    kg = data.get("knowledgeGraph") or {}
    # Separate title and description so they don't run together
    # (e.g. "Eiffel TowerWrought-iron lattice tower...").
    kg_text = " - ".join(
        str(field) for field in (kg.get("title"), kg.get("description")) if field
    )
    if kg_text:
        parts.append(f"KNOWLEDGE GRAPH: {kg_text}")

    for i, item in enumerate((data.get("organic") or [])[:num_results], 1):
        parts.append(
            f"[{i}] {item.get('title', '')}\n{item.get('link', '')}\n"
            f"{item.get('snippet', '')}"
        )
    return "\n\n".join(parts) if parts else "No results."
# ---------------------------------------------------------------------------
# Web page reader
# ---------------------------------------------------------------------------
@tool
def read_webpage(url: str, max_chars: int = 15000) -> str:
    """Fetch a URL and return its main text content as Markdown.

    HTML pages have their script/style/noscript/header/footer/nav elements
    removed before conversion; PDF responses are returned as extracted text.

    Args:
        url: The full URL to fetch (http or https).
        max_chars: Maximum characters to return (truncated tail dropped).

    Returns:
        Markdown text. Use after `web_search` to actually read a page.
    """
    try:
        from bs4 import BeautifulSoup
        from markdownify import markdownify
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    url = (url or "").strip()
    if not url.startswith(("http://", "https://")):
        return f"Invalid URL: {url}"
    # Optional argument: the LLM may pass None or a string.
    try:
        max_chars = int(max_chars)
    except (TypeError, ValueError):
        max_chars = 15000
    try:
        resp = requests.get(url, headers={"User-Agent": USER_AGENT}, timeout=25)
        resp.raise_for_status()
    except Exception as e:
        return f"Fetch failed for {url}: {e}"

    ctype = resp.headers.get("Content-Type", "").lower()
    is_pdf = (
        "pdf" in ctype
        or url.lower().endswith(".pdf")
        # Also catch ".../file.pdf?download=1".
        or urlparse(url).path.lower().endswith(".pdf")
        # And PDFs served with a generic type such as application/octet-stream.
        or resp.content[:5] == b"%PDF-"
    )
    if is_pdf:
        return _pdf_to_text(resp.content, max_chars)

    # Parse the raw bytes rather than resp.text. For text/* responses without
    # a charset in Content-Type, requests decodes as ISO-8859-1, which garbles
    # UTF-8 pages. From bytes, BeautifulSoup can honour the page's own
    # <meta charset>. An explicit charset in the header is still passed on.
    declared = resp.encoding if "charset=" in ctype else None
    soup = BeautifulSoup(resp.content, "html.parser", from_encoding=declared)
    for tag in soup(["script", "style", "noscript", "header", "footer", "nav"]):
        tag.decompose()
    md = markdownify(str(soup), heading_style="ATX")
    md = re.sub(r"\n{3,}", "\n\n", md).strip()
    if len(md) > max_chars:
        md = md[:max_chars] + "\n\n[...truncated...]"
    return md
def _pdf_to_text(data: bytes, max_chars: int) -> str:
try:
from pypdf import PdfReader
except Exception:
try:
from PyPDF2 import PdfReader # type: ignore
except Exception as e:
return f"PDF read failed (install pypdf): {e}"
try:
reader = PdfReader(io.BytesIO(data))
text = "\n\n".join((p.extract_text() or "") for p in reader.pages)
except Exception as e:
return f"PDF parse failed: {e}"
if len(text) > max_chars:
text = text[:max_chars] + "\n\n[...truncated...]"
return text
# ---------------------------------------------------------------------------
# Wikipedia
# ---------------------------------------------------------------------------
@tool
def wikipedia_search(query: str, sentences: int = 8) -> str:
    """Look up a topic on English Wikipedia.

    Tries `query` as an exact page title first. If no such page exists, it
    uses the top hit of the MediaWiki full-text search API instead.

    Args:
        query: The page title or topic.
        sentences: Sentences of summary to return.

    Returns:
        A summary block with the page URL, or an error message.
    """
    try:
        import wikipediaapi
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    # Optional argument: tolerate None / strings from the LLM.
    try:
        sentences = max(1, int(sentences))
    except (TypeError, ValueError):
        sentences = 8
    wiki = wikipediaapi.Wikipedia(user_agent=USER_AGENT, language="en")
    # wikipediaapi loads page data lazily, so exists()/summary/fullurl may
    # perform HTTP requests. Keep them all inside the try so a network failure
    # becomes an error message rather than an exception raised out of the tool.
    try:
        page = wiki.page(query)
        if not page.exists():
            # No exact title match: fall back to the search API's top hit.
            resp = requests.get(
                "https://en.wikipedia.org/w/api.php",
                params={
                    "action": "query",
                    "list": "search",
                    "srsearch": query,
                    "format": "json",
                    "srlimit": 1,
                },
                headers={"User-Agent": USER_AGENT},
                timeout=15,
            )
            resp.raise_for_status()
            hits = resp.json().get("query", {}).get("search", [])
            if not hits:
                return f"No Wikipedia page found for: {query}"
            page = wiki.page(hits[0]["title"])
            if not page.exists():
                return f"No Wikipedia page found for: {query}"
        title, page_url, summary = page.title, page.fullurl, page.summary or ""
    except Exception as e:
        return f"Wikipedia lookup failed: {e}"
    # Naive sentence split on terminal punctuation followed by whitespace.
    parts = re.split(r"(?<=[.!?])\s+", summary.strip())
    out = " ".join(parts[:sentences])
    return f"{title}\n{page_url}\n\n{out}"
# ---------------------------------------------------------------------------
# YouTube transcript
# ---------------------------------------------------------------------------
@tool
def youtube_transcript(url_or_id: str) -> str:
    """Fetch the transcript of a YouTube video.

    Args:
        url_or_id: A full YouTube URL or just the 11-char video ID.

    Returns:
        Plain text transcript, or an error message.
    """
    vid = _yt_id(url_or_id or "")
    if not vid:
        return f"Could not parse YouTube id from: {url_or_id}"
    try:
        from youtube_transcript_api import YouTubeTranscriptApi
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    try:
        if hasattr(YouTubeTranscriptApi, "fetch"):
            # youtube-transcript-api >= 1.0: instance API. The static
            # get_transcript() is deprecated there and later removed.
            chunks = YouTubeTranscriptApi().fetch(vid)
        else:
            # Older releases only expose the static helper.
            chunks = YouTubeTranscriptApi.get_transcript(vid)
        # The old API yields dicts; the new one yields snippet objects with .text.
        texts = [
            c["text"] if isinstance(c, dict) else getattr(c, "text", "")
            for c in chunks
        ]
    except Exception as e:
        return f"Transcript fetch failed: {e}"
    # Snippets often contain embedded newlines; normalise to single spaces.
    text = re.sub(r"\s+", " ", " ".join(t or "" for t in texts)).strip()
    return text or f"Transcript for {vid} is empty."
def _yt_id(s: str) -> Optional[str]:
s = s.strip()
if re.fullmatch(r"[A-Za-z0-9_-]{11}", s):
return s
try:
u = urlparse(s)
except Exception:
return None
if u.hostname in ("youtu.be",):
return u.path.lstrip("/")[:11] or None
if u.hostname and "youtube" in u.hostname:
from urllib.parse import parse_qs
qs = parse_qs(u.query)
v = qs.get("v", [None])[0]
if v:
return v[:11]
m = re.search(r"/(embed|shorts)/([A-Za-z0-9_-]{11})", u.path)
if m:
return m.group(2)
m = re.search(r"([A-Za-z0-9_-]{11})", s)
return m.group(1) if m else None
# ---------------------------------------------------------------------------
# GAIA file attachment
# ---------------------------------------------------------------------------
@tool
def download_task_file(task_id: str) -> str:
    """Download the file attachment for a GAIA task (if one exists).

    Files are saved under ``<system temp dir>/gaia_files/``. The scoring API
    base URL comes from the GAIA_API_URL environment variable and defaults to
    DEFAULT_API_URL.

    Args:
        task_id: The task id of the current question.

    Returns:
        Absolute local path of the downloaded file, or a message saying
        no file is attached. Read the file with normal Python after.
    """
    from urllib.parse import quote

    task_id = (task_id or "").strip()
    if not task_id:
        return "Download failed: empty task_id"
    # Percent-encode so characters like "/" or "?" in the id cannot change the
    # request path.
    safe_id = quote(task_id, safe="")
    base = os.getenv("GAIA_API_URL", DEFAULT_API_URL).rstrip("/")
    url = f"{base}/files/{safe_id}"
    try:
        resp = requests.get(url, timeout=30)
    except Exception as e:
        return f"Download error: {e}"
    if resp.status_code == 404:
        return "NO_FILE: this task has no attachment."
    if resp.status_code != 200:
        return f"Download failed: HTTP {resp.status_code}"

    # The name may come from a server-supplied Content-Disposition header.
    # Keep only its final path component so the write cannot escape out_dir.
    name = Path(_filename_from_response(resp, task_id).replace("\\", "/")).name
    if name in ("", ".", ".."):
        name = safe_id
    out_dir = Path(tempfile.gettempdir()) / "gaia_files"
    path = out_dir / name
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
        path.write_bytes(resp.content)
    except OSError as e:
        return f"Download failed: could not save file: {e}"
    return str(path.resolve())
def _filename_from_response(resp: requests.Response, task_id: str) -> str:
cd = resp.headers.get("Content-Disposition", "")
m = re.search(r'filename\*?=(?:UTF-\d\'\')?"?([^";]+)"?', cd)
if m:
return m.group(1).strip()
ctype = resp.headers.get("Content-Type", "").split(";")[0].strip()
ext = {
"text/plain": ".txt",
"text/csv": ".csv",
"application/pdf": ".pdf",
"application/json": ".json",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.ms-excel": ".xls",
"application/x-python": ".py",
"image/png": ".png",
"image/jpeg": ".jpg",
"audio/mpeg": ".mp3",
"audio/wav": ".wav",
"audio/x-wav": ".wav",
"audio/mp4": ".m4a",
"video/mp4": ".mp4",
}.get(ctype, "")
return f"{task_id}{ext}"
# ---------------------------------------------------------------------------
# Excel / CSV reader (deterministic helper so the LLM doesn't have to handcraft)
# ---------------------------------------------------------------------------
@tool
def read_table(file_path: str, sheet: Optional[str] = None, max_rows: int = 200) -> str:
    """Read an Excel/CSV file and return a textual preview.

    Args:
        file_path: Absolute path to .xlsx / .xls / .csv / .tsv.
        sheet: Optional sheet name (Excel only). Default: first sheet.
        max_rows: Max rows to include in the preview.

    Returns:
        Column dtypes + a CSV-style preview. For deeper analysis, load it with
        pandas yourself in a code block.
    """
    import pandas as pd

    p = Path(file_path)
    if not p.exists():
        return f"File not found: {file_path}"
    # Optional argument: tolerate None / strings. Negative values would make
    # DataFrame.head() drop rows from the end, so clamp at 0.
    try:
        max_rows = max(0, int(max_rows))
    except (TypeError, ValueError):
        max_rows = 200
    suffix = p.suffix.lower()
    sheet_names = None
    try:
        if suffix in (".xlsx", ".xls"):
            with pd.ExcelFile(p) as xls:
                sheet_names = list(xls.sheet_names)
                df = pd.read_excel(xls, sheet_name=sheet or 0)
        elif suffix == ".tsv":
            df = pd.read_csv(p, sep="\t")
        else:
            df = pd.read_csv(p)
    except Exception as e:
        return f"Read failed: {e}"

    info = [f"shape: {df.shape}"]
    if sheet_names is not None:
        # List every sheet so the agent knows there is more than the one shown.
        info.append(f"sheets: {sheet_names}")
    info += [
        "dtypes:",
        df.dtypes.astype(str).to_string(),
        "",
        f"preview (first {min(max_rows, len(df))} of {len(df)} rows):",
        df.head(max_rows).to_csv(index=False),
    ]
    return "\n".join(info)
# ---------------------------------------------------------------------------
# Audio transcription via HF Inference (Whisper)
# ---------------------------------------------------------------------------
@tool
def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file (mp3/wav/m4a) using Whisper via HF Inference.

    The model defaults to openai/whisper-large-v3 and can be overridden with
    the ASR_MODEL_ID environment variable. ASR_PROVIDER optionally selects an
    inference provider; when it is unset, the huggingface_hub default is used.

    Args:
        file_path: Absolute path to the audio file.

    Returns:
        The transcript text, or an error message.
    """
    try:
        from huggingface_hub import InferenceClient
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    token = os.getenv("HF_TOKEN")
    if not token:
        return "Missing HF_TOKEN for HF Inference."
    p = Path(file_path)
    if not p.exists():
        return f"File not found: {file_path}"
    model_id = os.getenv("ASR_MODEL_ID", "openai/whisper-large-v3")
    # Only pass a provider when one is configured, so the library default
    # applies otherwise.
    provider = os.getenv("ASR_PROVIDER")
    client_kwargs = {"provider": provider} if provider else {}
    try:
        client = InferenceClient(token=token, **client_kwargs)
        out = client.automatic_speech_recognition(p.read_bytes(), model=model_id)
    except Exception as e:
        return f"ASR failed: {e}"
    # The result may be a plain dict or an output object with a .text attribute.
    if isinstance(out, dict):
        text = out.get("text", "")
    else:
        text = getattr(out, "text", str(out))
    return str(text or "").strip()
# ---------------------------------------------------------------------------
# Image VQA via HF Inference
# ---------------------------------------------------------------------------
@tool
def analyze_image(file_path: str, question: str = "Describe this image in detail.") -> str:
    """Ask a vision-language model about an image file.

    The model comes from VLM_MODEL_ID (default Qwen/Qwen2.5-VL-7B-Instruct)
    and the inference provider from VLM_PROVIDER (default "auto"). The image
    is sent inline as a base64 data URL.

    Args:
        file_path: Absolute path to a .png / .jpg / .jpeg / .webp file.
        question: The question to ask about the image. Default: detailed description.

    Returns:
        The model's answer text.
    """
    import base64

    try:
        from huggingface_hub import InferenceClient
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    token = os.getenv("HF_TOKEN")
    if not token:
        return "Missing HF_TOKEN for HF Inference."
    p = Path(file_path)
    if not p.exists():
        return f"File not found: {file_path}"
    try:
        raw = p.read_bytes()
    except OSError as e:
        return f"Could not read image: {e}"
    # Optional argument: the LLM may pass None or an empty string.
    question = question or "Describe this image in detail."
    model_id = os.getenv("VLM_MODEL_ID", "Qwen/Qwen2.5-VL-7B-Instruct")
    provider = os.getenv("VLM_PROVIDER", "auto")

    # Prefer the file's magic bytes over its extension: downloaded task
    # attachments can lack a suffix when the Content-Type is unrecognised.
    if raw.startswith(b"\x89PNG\r\n\x1a\n"):
        subtype = "png"
    elif raw.startswith(b"\xff\xd8\xff"):
        subtype = "jpeg"
    elif raw[:6] in (b"GIF87a", b"GIF89a"):
        subtype = "gif"
    elif raw[:4] == b"RIFF" and raw[8:12] == b"WEBP":
        subtype = "webp"
    else:
        suffix = p.suffix.lower().lstrip(".")
        subtype = {"jpg": "jpeg"}.get(suffix, suffix) or "png"
    b64 = base64.b64encode(raw).decode("ascii")
    data_url = f"data:image/{subtype};base64,{b64}"
    try:
        client = InferenceClient(token=token, provider=provider)
        resp = client.chat.completions.create(
            model=model_id,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            max_tokens=512,
        )
        return resp.choices[0].message.content or ""
    except Exception as e:
        return f"VLM call failed: {e}"
# Public API of this module: the @tool-decorated functions intended to be
# handed to the agent. Underscore-prefixed helpers are deliberately omitted.
__all__ = [
    "web_search",
    "read_webpage",
    "wikipedia_search",
    "youtube_transcript",
    "download_task_file",
    "read_table",
    "transcribe_audio",
    "analyze_image",
]