# Unit_3_Agentic_RAG/tools.py
# Last commit 92e2e56 ("Update tools.py") by dnguyen47 (verified).
from smolagents import DuckDuckGoSearchTool
from smolagents import Tool
import random
import requests
from huggingface_hub import list_models
# Initialize the DuckDuckGo search tool
#search_tool = DuckDuckGoSearchTool()
class WeatherInfoTool(Tool):
    """Agent tool that reports made-up weather for any location.

    Nothing is fetched from a real weather service: each call picks one of a
    few canned readings at random, which makes it handy for demos.
    """

    name = "weather_info"
    description = "Fetches dummy weather information for a given location."
    inputs = {
        "location": {
            "type": "string",
            "description": "The location to get weather information for."
        }
    }
    output_type = "string"

    def forward(self, location: str):
        """Return a one-line fake weather report for ``location``.

        The reading is drawn with ``random.choice`` from a fixed list of
        (condition, temperature in °C) pairs, so repeated calls may differ.
        """
        canned_readings = (
            ("Rainy", 15),
            ("Clear", 25),
            ("Windy", 20),
        )
        condition, celsius = random.choice(canned_readings)
        return f"Weather in {location}: {condition}, {celsius}°C"
class HubStatsTool(Tool):
    """Agent tool reporting an author's most-downloaded model on the Hugging Face Hub."""

    name = "hub_stats"
    description = "Fetches the most downloaded model from a specific author on the Hugging Face Hub."
    inputs = {
        "author": {
            "type": "string",
            "description": "The username of the model author/organization to find models from."
        }
    }
    output_type = "string"

    def forward(self, author: str) -> str:
        """Return a sentence naming the most-downloaded model published by ``author``.

        Args:
            author: Hub username or organization. Leading/trailing whitespace
                is ignored.

        Returns:
            A human-readable answer. Failures (blank author, no models, any
            exception from the Hub client) are reported as text rather than
            raised, so the calling agent can read and react to them.
        """
        if isinstance(author, str):
            author = author.strip()
        # A blank/None author would not restrict the query below
        # (NOTE(review): huggingface_hub appears to only send ``author`` when
        # it is truthy), so the Hub-wide top model would be misattributed.
        if not author:
            return "No author was given; please provide a Hugging Face Hub username or organization."
        try:
            # Sorted by downloads with direction=-1 (descending); only the
            # first result is needed, so stop after one item.
            top = next(
                iter(list_models(author=author, sort="downloads", direction=-1, limit=1)),
                None,
            )
            if top is None:
                return f"No models found for author {author}."
            if top.downloads is None:
                # The download count may be missing on the returned model info;
                # formatting None with ``:,`` would raise a TypeError.
                return f"The most downloaded model by {author} is {top.id} (download count unavailable)."
            return f"The most downloaded model by {author} is {top.id} with {top.downloads:,} downloads."
        except Exception as e:
            # Tool-boundary catch-all: network/API errors become a readable
            # message for the agent instead of aborting its step.
            return f"Error fetching models for {author}: {str(e)}"
class CountryLanguageTool(Tool):
    """Agent tool mapping a country name to its main language(s).

    Answers come from a small hard-coded table (the "offline database"); no
    network request is made, so unknown countries get a polite "no data" reply.
    """

    name = "country_language"
    description = "Returns the main language(s) spoken in a given country (offline database)."
    inputs = {
        "country": {
            "type": "string",
            "description": "The name of the country (e.g., France, Japan, Brazil)."
        }
    }
    output_type = "string"

    # Normalized country name or alias -> main language(s). Keys must be
    # lower-case with single spaces and no leading "the ", matching the key
    # that ``forward`` builds. Defined once here instead of on every call.
    _COUNTRY_LANGUAGES = {
        "united states": "English",
        "united states of america": "English",
        "usa": "English",
        "us": "English",
        "canada": "English, French",
        "france": "French",
        "germany": "German",
        "spain": "Spanish",
        "mexico": "Spanish",
        "brazil": "Portuguese",
        "japan": "Japanese",
        "china": "Mandarin Chinese",
        "india": "Hindi, English",
        "russia": "Russian",
        "italy": "Italian",
        "south korea": "Korean",
        "korea": "Korean",
        "united kingdom": "English",
        "uk": "English",
        "australia": "English",
        "nigeria": "English",
        "egypt": "Arabic",
        "saudi arabia": "Arabic",
        "argentina": "Spanish",
        "switzerland": "German, French, Italian, Romansh",
        "south africa": "Zulu, Xhosa, Afrikaans, English",
    }

    def forward(self, country: str) -> str:
        """Return a sentence listing the main language(s) of ``country``.

        Args:
            country: Country name as typed by the user/agent. Matching ignores
                case, extra whitespace and a leading "the".

        Returns:
            The answer sentence, a "no data" sentence for countries missing
            from the table, or an error message (never raised).
        """
        try:
            # Case-fold and collapse whitespace runs so " South  Korea " and
            # "SOUTH KOREA" both become "south korea".
            key = " ".join(country.split()).casefold()
            # Accept a leading article, e.g. "The United Kingdom".
            if key.startswith("the "):
                key = key[len("the "):]
            languages = self._COUNTRY_LANGUAGES.get(key)
            if languages is not None:
                return (
                    f"The main language(s) spoken in {country} "
                    f"is: {languages}."
                )
            return f"Sorry, I don't have language data for {country}."
        except Exception as e:
            # Tool-boundary catch-all: e.g. a non-string ``country`` has no
            # ``split``; report it as text instead of crashing the agent step.
            return f"Error retrieving language info: {str(e)}"
# class CountryLanguageTool(Tool):
# name = "country_language"
# description = "Gives the language(s) spoken in a requested country."
# inputs = {
# "country":{
# "type": "string",
# "description": "The country to get the language from."
# }
# }
# output_type = "string"
# def forward(self, country: str):
# try:
# url = f"https://restcountries.com/v3.1/name/{country}"
# response = requests.get(url, timeout=10)
# response.raise_for_status()
# data = response.json()
# if not data:
# return f"No data found for {country}."
# country_data = data[0]
# languages = country_data.get("languages", {})
# if not languages:
# return f"No language information available for {country}."
# language_list = ", ".join(languages.values())
# return (
# f"The main language(s) spoken in {country} "
# f"is: {language_list}."
# )
# except requests.exceptions.HTTPError:
# return f"Country '{country}' not found."
# except requests.exceptions.RequestException as e:
# return f"Network error while fetching data: {str(e)}"
# except Exception as e:
# return f"Unexpected error: {str(e)}"