Raj
Upload 6 files
08fe068 verified
Raw
History Blame Contribute Delete
14.8 kB
"""Custom tools for the GAIA agent.
Each tool is a @tool-decorated function that smolagents can call from a CodeAgent.
Keep tool docstrings precise — the LLM reads them to decide when to call.
"""
from __future__ import annotations
import io
import os
import re
import tempfile
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import requests
from smolagents import tool
import config
# Desktop-Chrome User-Agent string sent with the outbound HTTP requests made in
# this module (page fetches and Wikipedia calls); presumably chosen so sites that
# reject the default python-requests agent still serve content.
USER_AGENT = (
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/124.0 Safari/537.36"
)
# ---------------------------------------------------------------------------
# Web search
# ---------------------------------------------------------------------------
@tool
def web_search(query: str, num_results: int = config.SEARCH_RESULTS) -> str:
    """Search the web with Serper (Google results) and return the top hits.

    When SERPER_API_KEY is not set, DuckDuckGo is used instead.

    Args:
        query: The search query.
        num_results: How many results to return (1-10).

    Returns:
        A text block of results: title, link, snippet. Use this to find URLs
        worth reading with `read_webpage`.
    """
    api_key = os.getenv("SERPER_API_KEY")
    # Clamp to the advertised 1-10 range (int() also accepts numeric strings).
    num_results = max(1, min(int(num_results), 10))

    if not api_key:
        # Fallback to DuckDuckGo if no Serper key.
        try:
            from duckduckgo_search import DDGS

            with DDGS() as ddgs:
                hits = list(ddgs.text(query, max_results=num_results))
            if not hits:
                return "No results."
            return "\n\n".join(
                f"[{i}] {h.get('title', '')}\n{h.get('href', '')}\n{h.get('body', '')}"
                for i, h in enumerate(hits, 1)
            )
        except Exception as e:  # pragma: no cover
            return f"Search failed (no SERPER_API_KEY, DDG fallback errored): {e}"

    try:
        resp = requests.post(
            "https://google.serper.dev/search",
            headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
            json={"q": query, "num": num_results},
            timeout=20,
        )
        resp.raise_for_status()
        data = resp.json()
    except Exception as e:
        return f"Serper search failed: {e}"
    if not isinstance(data, dict):
        # Unexpected payload shape; nothing we know how to render.
        return "No results."

    parts: list[str] = []

    ab = data.get("answerBox") or {}
    answer = str(ab.get("answer") or ab.get("snippet") or ab.get("title") or "").strip()
    if answer:  # don't emit a bare "ANSWER BOX:" header with no content
        parts.append(f"ANSWER BOX:\n{answer}")

    kg = data.get("knowledgeGraph") or {}
    # Separate title and description so e.g. "Paris" + "Capital of France"
    # doesn't render as "ParisCapital of France".
    kg_text = " — ".join(
        filter(None, (str(kg.get(key) or "").strip() for key in ("title", "description")))
    )
    if kg_text:
        parts.append(f"KNOWLEDGE GRAPH: {kg_text}")

    for i, item in enumerate((data.get("organic") or [])[:num_results], 1):
        parts.append(
            f"[{i}] {item.get('title', '')}\n{item.get('link', '')}\n"
            f"{item.get('snippet', '')}"
        )
    return "\n\n".join(parts) if parts else "No results."
# ---------------------------------------------------------------------------
# Web page reader
# ---------------------------------------------------------------------------
@tool
def read_webpage(url: str, max_chars: int = config.PAGE_MAX_CHARS) -> str:
    """Fetch a URL and return its main text content as Markdown.

    PDF responses (by Content-Type or a ``.pdf`` path) are returned as
    extracted plain text instead.

    Args:
        url: The full URL to fetch (http or https).
        max_chars: Maximum characters to return (truncated tail dropped).

    Returns:
        Markdown text. Use after `web_search` to actually read a page.
    """
    try:
        from bs4 import BeautifulSoup
        from markdownify import markdownify
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    url = url.strip()
    if not url.startswith(("http://", "https://")):
        return f"Invalid URL: {url}"
    # Coerce numeric strings; a negative limit would otherwise slice from the end.
    max_chars = max(0, int(max_chars))
    try:
        resp = requests.get(url, headers={"User-Agent": USER_AGENT}, timeout=25)
        resp.raise_for_status()
    except Exception as e:
        return f"Fetch failed for {url}: {e}"

    ctype = resp.headers.get("Content-Type", "").lower()
    # Also check the parsed path so ".../paper.pdf?download=1" is recognised.
    is_pdf = (
        "pdf" in ctype
        or url.lower().endswith(".pdf")
        or urlparse(url).path.lower().endswith(".pdf")
    )
    if is_pdf:
        return _pdf_to_text(resp.content, max_chars)

    # requests falls back to ISO-8859-1 for text/* responses that declare no
    # charset, which garbles UTF-8 pages. Only trust resp.text when the server
    # named a charset; otherwise give BeautifulSoup the raw bytes so it can
    # honour <meta charset> / sniff the encoding itself.
    markup = resp.text if "charset=" in ctype else resp.content
    soup = BeautifulSoup(markup, "html.parser")
    # Drop non-content elements before conversion.
    for tag in soup(["script", "style", "noscript", "header", "footer", "nav"]):
        tag.decompose()
    md = markdownify(str(soup), heading_style="ATX")
    md = re.sub(r"\n{3,}", "\n\n", md).strip()
    if len(md) > max_chars:
        md = md[:max_chars] + "\n\n[...truncated...]"
    return md
def _pdf_to_text(data: bytes, max_chars: int) -> str:
try:
from pypdf import PdfReader
except Exception:
try:
from PyPDF2 import PdfReader # type: ignore
except Exception as e:
return f"PDF read failed (install pypdf): {e}"
try:
reader = PdfReader(io.BytesIO(data))
text = "\n\n".join((p.extract_text() or "") for p in reader.pages)
except Exception as e:
return f"PDF parse failed: {e}"
if len(text) > max_chars:
text = text[:max_chars] + "\n\n[...truncated...]"
return text
# ---------------------------------------------------------------------------
# Wikipedia
# ---------------------------------------------------------------------------
@tool
def wikipedia_search(query: str, sentences: int = config.WIKI_SENTENCES) -> str:
    """Look up a topic on English Wikipedia.

    Tries the query as an exact page title first; if no such page exists,
    falls back to the top hit of Wikipedia's full-text search API.

    Args:
        query: The page title or topic.
        sentences: Sentences of summary to return.

    Returns:
        A summary block with the page URL, or an error message.
    """
    try:
        import wikipediaapi
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    query = query.strip()
    if not query:
        return "No Wikipedia page found for an empty query."
    not_found = f"No Wikipedia page found for: {query}"

    wiki = wikipediaapi.Wikipedia(user_agent=USER_AGENT, language="en")
    # page.exists() / .summary / .fullurl may trigger network requests, so the
    # whole lookup sits inside one error boundary: the tool reports failures as
    # text instead of raising into the agent.
    try:
        page = wiki.page(query)
        if not page.exists():
            # Try a search-then-fetch with the search API.
            resp = requests.get(
                "https://en.wikipedia.org/w/api.php",
                params={
                    "action": "query",
                    "list": "search",
                    "srsearch": query,
                    "format": "json",
                    "srlimit": 1,
                },
                headers={"User-Agent": USER_AGENT},
                timeout=15,
            )
            resp.raise_for_status()
            hits = resp.json().get("query", {}).get("search", [])
            if not hits:
                return not_found
            page = wiki.page(hits[0]["title"])
            if not page.exists():
                return not_found
        title, page_url, summary = page.title, page.fullurl, page.summary
    except Exception as e:
        return f"Wikipedia lookup failed: {e}"

    # Naive sentence split on terminal punctuation followed by whitespace.
    parts = re.split(r"(?<=[.!?])\s+", summary)
    out = " ".join(parts[: max(1, int(sentences))])
    return f"{title}\n{page_url}\n\n{out}"
# ---------------------------------------------------------------------------
# YouTube transcript
# ---------------------------------------------------------------------------
@tool
def youtube_transcript(url_or_id: str) -> str:
    """Fetch the transcript of a YouTube video.

    Args:
        url_or_id: A full YouTube URL or just the 11-char video ID.

    Returns:
        Plain text transcript, or an error message.
    """
    vid = _yt_id(url_or_id)
    if not vid:
        return f"Could not parse YouTube id from: {url_or_id}"
    try:
        from youtube_transcript_api import YouTubeTranscriptApi
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    try:
        if hasattr(YouTubeTranscriptApi, "fetch"):
            # youtube-transcript-api >= 1.0: instance API returning snippet
            # objects; the static get_transcript() was deprecated there and
            # removed in later releases.
            chunks = YouTubeTranscriptApi().fetch(vid)
        else:
            # Pre-1.0 releases: static API returning a list of dicts.
            chunks = YouTubeTranscriptApi.get_transcript(vid)
        texts = [c["text"] if isinstance(c, dict) else c.text for c in chunks]
    except Exception as e:
        return f"Transcript fetch failed: {e}"
    return " ".join(texts)
def _yt_id(s: str) -> Optional[str]:
s = s.strip()
if re.fullmatch(r"[A-Za-z0-9_-]{11}", s):
return s
try:
u = urlparse(s)
except Exception:
return None
if u.hostname in ("youtu.be",):
return u.path.lstrip("/")[:11] or None
if u.hostname and "youtube" in u.hostname:
from urllib.parse import parse_qs
qs = parse_qs(u.query)
v = qs.get("v", [None])[0]
if v:
return v[:11]
m = re.search(r"/(embed|shorts)/([A-Za-z0-9_-]{11})", u.path)
if m:
return m.group(2)
m = re.search(r"([A-Za-z0-9_-]{11})", s)
return m.group(1) if m else None
# ---------------------------------------------------------------------------
# GAIA file attachment
# ---------------------------------------------------------------------------
@tool
def download_task_file(task_id: str) -> str:
    """Download the file attachment for a GAIA task (if one exists).

    Files are saved under ``<system temp dir>/gaia_files``.

    Args:
        task_id: The task id of the current question.

    Returns:
        Absolute local path of the downloaded file, or a message saying
        no file is attached. Read the file with normal Python after.
    """
    task_id = task_id.strip()
    base = config.GAIA_API_URL.rstrip("/")
    url = f"{base}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=30)
    except Exception as e:
        return f"Download error: {e}"
    if resp.status_code == 404:
        return "NO_FILE: this task has no attachment."
    if resp.status_code != 200:
        return f"Download failed: HTTP {resp.status_code}"

    # The name comes from a server header or from the (LLM-supplied) task id.
    # Keep only its final path component: joining an absolute path or "../x"
    # onto out_dir would otherwise write outside the download directory.
    name = Path(_filename_from_response(resp, task_id)).name
    if name in ("", ".", ".."):
        name = "attachment"
    out_dir = Path(tempfile.gettempdir()) / "gaia_files"
    path = out_dir / name
    try:
        out_dir.mkdir(parents=True, exist_ok=True)
        path.write_bytes(resp.content)
    except OSError as e:
        return f"Download failed: could not write {path}: {e}"
    return str(path.resolve())
def _filename_from_response(resp: requests.Response, task_id: str) -> str:
cd = resp.headers.get("Content-Disposition", "")
m = re.search(r'filename\*?=(?:UTF-\d\'\')?"?([^";]+)"?', cd)
if m:
return m.group(1).strip()
ctype = resp.headers.get("Content-Type", "").split(";")[0].strip()
ext = {
"text/plain": ".txt",
"text/csv": ".csv",
"application/pdf": ".pdf",
"application/json": ".json",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.ms-excel": ".xls",
"application/x-python": ".py",
"image/png": ".png",
"image/jpeg": ".jpg",
"audio/mpeg": ".mp3",
"audio/wav": ".wav",
"audio/x-wav": ".wav",
"audio/mp4": ".m4a",
"video/mp4": ".mp4",
}.get(ctype, "")
return f"{task_id}{ext}"
# ---------------------------------------------------------------------------
# Excel / CSV reader (deterministic helper so the LLM doesn't have to handcraft pandas loading code)
# ---------------------------------------------------------------------------
@tool
def read_table(file_path: str, sheet: Optional[str] = None, max_rows: int = 200) -> str:
    """Read an Excel/CSV file and return a textual preview.

    Args:
        file_path: Absolute path to .xlsx / .xls / .csv / .tsv.
        sheet: Optional sheet name (Excel only). Default: first sheet.
        max_rows: Max rows to include in the preview.

    Returns:
        Column dtypes + a CSV-style preview. For Excel files the available
        sheet names are listed as well. For deeper analysis, load it with
        pandas yourself in a code block.
    """
    try:
        import pandas as pd
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    p = Path(file_path)
    if not p.exists():
        return f"File not found: {file_path}"
    # df.head(-n) returns all rows *except* the last n, so clamp at zero.
    max_rows = max(0, int(max_rows))
    suffix = p.suffix.lower()
    sheet_names: list[str] = []
    try:
        if suffix in (".xlsx", ".xls"):
            with pd.ExcelFile(p) as xl:
                sheet_names = [str(name) for name in xl.sheet_names]
                df = xl.parse(sheet_name=sheet or 0)
        elif suffix == ".tsv":
            df = pd.read_csv(p, sep="\t")
        else:
            df = pd.read_csv(p)
    except Exception as e:
        # Listing the sheets lets the agent recover from a wrong sheet name.
        hint = f" (available sheets: {', '.join(sheet_names)})" if sheet_names else ""
        return f"Read failed: {e}{hint}"

    head = df.head(max_rows)
    info = [f"shape: {df.shape}"]
    if sheet_names:
        info.append(f"sheets: {', '.join(sheet_names)}")
    info += [
        "dtypes:",
        df.dtypes.astype(str).to_string(),
        "",
        "preview:",
        head.to_csv(index=False),
    ]
    if len(head) < len(df):
        info.append(f"[preview truncated: showing {len(head)} of {len(df)} rows]")
    return "\n".join(info)
# ---------------------------------------------------------------------------
# Audio transcription via HF Inference (Whisper)
# ---------------------------------------------------------------------------
@tool
def transcribe_audio(file_path: str) -> str:
    """Transcribe an audio file (mp3/wav/m4a) using Whisper via HF Inference.

    The model id comes from ``config.ASR_MODEL_ID`` and the API token from the
    ``HF_TOKEN`` environment variable.

    Args:
        file_path: Absolute path to the audio file.

    Returns:
        The transcript text, or an error message.
    """
    # Guard the optional dependency like the other tools do.
    try:
        from huggingface_hub import InferenceClient
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    token = os.getenv("HF_TOKEN")
    if not token:
        return "Missing HF_TOKEN for HF Inference."
    p = Path(file_path)
    if not p.exists():
        return f"File not found: {file_path}"
    model_id = config.ASR_MODEL_ID
    try:
        client = InferenceClient(token=token)
        out = client.automatic_speech_recognition(p.read_bytes(), model=model_id)
    except Exception as e:
        return f"ASR failed: {e}"
    # Depending on the client version the result is a dict or an object with
    # a .text attribute; normalise a missing/None text to "" so we always
    # return a str.
    if isinstance(out, dict):
        text = out.get("text")
    else:
        text = getattr(out, "text", str(out))
    return text or ""
# ---------------------------------------------------------------------------
# Image VQA via HF Inference
# ---------------------------------------------------------------------------
def _image_mime_subtype(data: bytes, suffix: str) -> str:
    """Return the image MIME subtype (e.g. ``"png"``) for the given file bytes.

    Magic bytes take precedence over the extension, because attachments can be
    saved as a bare ``<task_id>`` with no extension when their Content-Type is
    unknown. The extension is the fallback and ``"png"`` the last resort.
    """
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    if data.startswith(b"\xff\xd8\xff"):
        return "jpeg"
    if data.startswith((b"GIF87a", b"GIF89a")):
        return "gif"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "webp"
    ext = suffix.lower().lstrip(".")
    # Map extensions whose MIME subtype differs from the extension itself.
    return {"jpg": "jpeg", "tif": "tiff", "svg": "svg+xml"}.get(ext, ext) or "png"


@tool
def analyze_image(file_path: str, question: str = "Describe this image in detail.") -> str:
    """Ask a vision-language model about an image file.

    The model and provider come from ``config.VLM_MODEL_ID`` and
    ``config.VLM_PROVIDER``. The image is sent inline as a base64 data URL.

    Args:
        file_path: Absolute path to a .png / .jpg / .jpeg / .webp file.
        question: The question to ask about the image. Default: detailed description.

    Returns:
        The model's answer text.
    """
    import base64

    try:
        from huggingface_hub import InferenceClient
    except Exception as e:  # pragma: no cover
        return f"Missing deps: {e}"
    token = os.getenv("HF_TOKEN")
    if not token:
        return "Missing HF_TOKEN for HF Inference."
    p = Path(file_path)
    if not p.exists():
        return f"File not found: {file_path}"
    try:
        raw = p.read_bytes()
    except OSError as e:
        return f"Could not read image {file_path}: {e}"
    mime = _image_mime_subtype(raw, p.suffix)
    b64 = base64.b64encode(raw).decode("ascii")
    data_url = f"data:image/{mime};base64,{b64}"
    try:
        client = InferenceClient(token=token, provider=config.VLM_PROVIDER)
        resp = client.chat.completions.create(
            model=config.VLM_MODEL_ID,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            max_tokens=512,
        )
        return resp.choices[0].message.content or ""
    except Exception as e:
        return f"VLM call failed: {e}"
# Public API of this module: the @tool-decorated functions defined above.
__all__ = [
    # Open-web lookup and reading
    "web_search",
    "read_webpage",
    "wikipedia_search",
    "youtube_transcript",
    # GAIA attachments and readers for common attachment types
    "download_task_file",
    "read_table",
    "transcribe_audio",
    "analyze_image",
]