Unit_3_Agentic_RAG / tools.py
renwei2024's picture
Fix the AssertionError about count
e26ac7d verified
from smolagents import Tool
import random
from huggingface_hub import list_models
# For the soccer news tool:
import datetime
import email.utils
import xml.etree.ElementTree as ET
from typing import List, Tuple
import requests
class WeatherInfoTool(Tool):
    """Tool that reports made-up weather for any location.

    No external service is contacted: each call picks one of three
    hard-coded (condition, temperature) samples at random.
    """

    name = "weather_info"
    description = "Fetches dummy weather information for a given location."
    inputs = {
        "location": {
            "type": "string",
            "description": "The location to get weather information for."
        }
    }
    output_type = "string"

    def forward(self, location: str):
        """Return a one-line dummy weather report for *location*.

        The location is only echoed back into the message; it does not
        influence which sample is chosen.
        """
        # (condition, temperature in °C) samples to draw from
        samples = (
            ("Rainy", 15),
            ("Clear", 25),
            ("Windy", 20),
        )
        condition, temp_c = random.choice(samples)
        return f"Weather in {location}: {condition}, {temp_c}°C"
class HubStatsTool(Tool):
    """Tool that looks up an author's most downloaded model on the Hub."""

    name = "hub_stats"
    description = "Fetches the most downloaded model from a specific author on the Hugging Face Hub."
    inputs = {
        "author": {
            "type": "string",
            "description": "The username of the model author/organization to find models from."
        }
    }
    output_type = "string"

    def forward(self, author: str):
        """Return a sentence naming *author*'s top model by download count.

        Any exception raised while querying or formatting is converted into
        an error message string rather than propagated, so the calling agent
        always receives text.
        """
        try:
            # Ask the Hub for a single result, ordered by downloads (descending).
            top_models = list(
                list_models(author=author, sort="downloads", direction=-1, limit=1)
            )
            if not top_models:
                return f"No models found for author {author}."
            best = top_models[0]
            return f"The most downloaded model by {author} is {best.id} with {best.downloads:,} downloads."
        except Exception as err:
            return f"Error fetching models for {author}: {str(err)}"
class SoccerNewsTool(Tool):
    """Tool that aggregates recent soccer headlines from several RSS feeds.

    Each feed in ``_RSS_FEEDS`` is downloaded and parsed; the items are merged,
    sorted newest-first by their ``pubDate`` and the top *count* are returned
    as one text line each. Feeds that cannot be fetched or parsed are skipped.
    """

    # ------------------------------------------------------------------ #
    # Metadata consumed by CodeAgent                                     #
    # ------------------------------------------------------------------ #
    name = "soccer_news"
    description = "Fetches the latest worldwide soccer news headlines."
    inputs = {
        "count": {
            "type": "integer",
            "description": "How many headlines to return (1-20, default 5).",
            "nullable": True,
            "default": 5,
        }
    }
    output_type = "string"

    # RSS feeds to aggregate (swap in league-specific feeds if desired)
    _RSS_FEEDS: List[str] = [
        "https://www.espn.com/espn/rss/soccer/news",
        "https://feeds.bbci.co.uk/sport/football/rss.xml",
        "https://www.goal.com/feeds/en/news",
    ]

    # ------------------------------------------------------------------ #
    # Core implementation                                                #
    # ------------------------------------------------------------------ #
    def forward(self, count: int = 5) -> str:
        """Return the *count* most-recent soccer headlines.

        Args:
            count: Number of headlines wanted. Clamped to the range 1-20.
                ``None`` or a value that cannot be converted to an integer
                falls back to the default of 5.

        Returns:
            One line per headline in the form
            ``YYYY-MM-DD HH:MM - <title> (<link>)`` (local time), newest
            first, or a warning message when no feed produced any headline.
        """
        # The input schema marks `count` as nullable, so None can reach this
        # method; treat it (and non-numeric input) as "use the default".
        try:
            count = int(count)
        except (TypeError, ValueError, OverflowError):
            count = 5
        count = max(1, min(count, 20))  # enforce bounds

        headlines: List[Tuple[datetime.datetime, str, str]] = []
        for feed_url in self._RSS_FEEDS:
            try:
                resp = requests.get(feed_url, timeout=10)
                resp.raise_for_status()
            except requests.RequestException:
                continue  # best effort: skip unreachable feeds

            try:
                root = ET.fromstring(resp.content)
            except ET.ParseError:
                continue  # skip malformed feeds

            for item in root.findall("./channel/item"):
                title = (item.findtext("title") or "").strip()
                link = (item.findtext("link") or "").strip()
                if not (title and link):
                    continue  # an entry without title or link is not usable

                raw_date = (item.findtext("pubDate") or "").strip()
                # Parse RFC-822 date → UTC; fall back to "now" on failure.
                # parsedate_tz() returns None for unparseable input, which
                # makes mktime_tz() raise TypeError; out-of-range values
                # surface as ValueError/OverflowError/OSError.
                try:
                    pub_dt = datetime.datetime.fromtimestamp(
                        email.utils.mktime_tz(email.utils.parsedate_tz(raw_date)),
                        tz=datetime.timezone.utc,
                    )
                except (TypeError, ValueError, OverflowError, OSError):
                    pub_dt = datetime.datetime.now(tz=datetime.timezone.utc)

                headlines.append((pub_dt, title, link))

        if not headlines:
            return "⚠️ Unable to fetch soccer news right now."

        # newest-first, take the requested slice
        headlines.sort(key=lambda entry: entry[0], reverse=True)
        selected = headlines[:count]

        # Separate the timestamp from the title so the two do not run together.
        return "\n".join(
            f"{dt.astimezone().strftime('%Y-%m-%d %H:%M')} - {title} ({link})"
            for dt, title, link in selected
        )