# Last commit by kokluch (a6ba120): "Add VirusTotal tools"
import datetime
import os
import pytz
import requests
import yaml
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, tool
from Gradio_UI import GradioUI
from classes.get_url_report import GetURLReportResponse, Data, Attributes, Stats
from classes.ip_address_report import IPAddressReport, TotalVotes, AnalysisStats
from classes.scan_url import DataAnalysis, Links, ScanResponse
from tools.final_answer import FinalAnswerTool
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.

    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').

    Returns:
        A sentence with the current local time, or an error message naming the
        timezone when the lookup fails.
    """
    try:
        zone = pytz.timezone(timezone)
        stamp = datetime.datetime.now(zone).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as err:
        # Failures (such as an unrecognised zone name) are reported as text
        # instead of raised, so the agent gets a readable observation back.
        return f"Error fetching time for timezone '{timezone}': {str(err)}"
    return f"The current local time in {timezone} is: {stamp}"
@tool
def get_my_ip_address() -> str:
    """
    Retrieves the public IP address of the machine running this code.

    Returns:
        str: The public IP address.

    Raises:
        Exception: If the service answers with a non-200 status. Connection
            errors and timeouts propagate as the underlying requests exceptions.
    """
    # ipify echoes back the caller's public address; format=json wraps it as {"ip": ...}.
    url = "https://api.ipify.org?format=json"
    # requests has no default timeout; without one a stalled endpoint would
    # block the agent step indefinitely.
    response = requests.get(url, timeout=10)
    if response.status_code != 200:
        raise Exception(f"Failed to retrieve IP address: {response.status_code} - {response.text}")
    return response.json()["ip"]
@tool
def get_ip_address_report(ip_address: str) -> IPAddressReport:
    """
    Fetches the IP address report from the VirusTotal API and returns it as an IPAddressReport object.

    Args:
        ip_address: The IP address to fetch the report for.

    Returns:
        IPAddressReport: An object containing the IP address report.

    Raises:
        Exception: If the request to the VirusTotal API fails.
    """
    url = f"https://www.virustotal.com/api/v3/ip_addresses/{ip_address}"
    headers = {
        "accept": "application/json",
        # The API key is read from the environment on every call.
        "x-apikey": os.getenv('VT_API_KEY'),
    }
    # Explicit timeout so a slow VirusTotal response cannot hang the agent forever.
    response = requests.get(url, headers=headers, timeout=30)
    if response.status_code != 200:
        raise Exception(f"Failed to retrieve data: {response.status_code} - {response.text}")

    data = response.json()["data"]
    attributes = data["attributes"]
    votes = attributes["total_votes"]
    stats = attributes["last_analysis_stats"]
    return IPAddressReport(
        id=data["id"],
        type=data["type"],
        reputation=attributes["reputation"],
        # NOTE(review): VirusTotal does not populate the geolocation/ownership
        # fields for every address (presumably private/reserved ranges; confirm
        # against the VT docs), so a missing key yields None instead of a KeyError.
        continent=attributes.get("continent"),
        as_owner=attributes.get("as_owner"),
        country=attributes.get("country"),
        tags=attributes.get("tags", []),
        total_votes=TotalVotes(
            harmless=votes["harmless"],
            malicious=votes["malicious"],
        ),
        network=attributes.get("network"),
        last_analysis_stats=AnalysisStats(
            malicious=stats["malicious"],
            suspicious=stats["suspicious"],
            undetected=stats["undetected"],
            harmless=stats["harmless"],
            timeout=stats["timeout"],
        ),
    )
@tool
def scan_url(url: str) -> ScanResponse:
    """
    Request a scan of a given URL using the VirusTotal API.

    Args:
        url: The URL to scan.

    Returns:
        ScanResponse: The response from the VirusTotal API.

    Raises:
        Exception: If the request to the external service fails.
    """
    endpoint = "https://www.virustotal.com/api/v3/urls"
    payload = {"url": url}
    headers = {
        "accept": "application/json",
        "x-apikey": os.getenv('VT_API_KEY'),
        "content-type": "application/x-www-form-urlencoded",
    }
    # Submit the URL as a form-encoded POST. The explicit timeout prevents an
    # unresponsive API from hanging the agent indefinitely.
    response = requests.post(endpoint, headers=headers, data=payload, timeout=30)
    try:
        # Non-2xx statuses raise HTTPError; a non-JSON body raises a requests
        # JSON decode error. Both derive from RequestException.
        response.raise_for_status()
        response_json = response.json()
    except requests.exceptions.RequestException as e:
        # Chain the cause so the original requests error stays in the traceback.
        raise Exception(f"Failed to retrieve data: {response.status_code} - {response.text}") from e

    analysis = response_json["data"]
    # links.self is the URL that get_scan_report later fetches to obtain results.
    return ScanResponse(
        data=DataAnalysis(
            type=analysis["type"],
            id=analysis["id"],
            links=Links(self_url=analysis["links"]["self"]),
        )
    )
@tool
def get_scan_report(scan: ScanResponse) -> GetURLReportResponse:
    """
    Fetch a report of a scan of a given URL using the VirusTotal API.

    Args:
        scan: The ScanResponse object returned by calling scan_url tool.

    Returns:
        GetURLReportResponse: The response from the VirusTotal API.

    Raises:
        Exception: If the request to the external service fails, or if the
            scan's self link is not a VirusTotal API URL.
    """
    report_url = scan.data.links.self_url
    # The VirusTotal API key is attached to this request. The scan object is
    # handed over by agent-written code, so refuse any link outside the VT API
    # rather than leak the key to an arbitrary host.
    if not isinstance(report_url, str) or not report_url.startswith("https://www.virustotal.com/api/v3/"):
        raise Exception(f"Refusing to fetch non-VirusTotal report URL: {report_url!r}")

    headers = {
        "accept": "application/json",
        "x-apikey": os.getenv('VT_API_KEY'),
    }
    # Explicit timeout so a slow API cannot hang the agent indefinitely.
    response = requests.get(report_url, headers=headers, timeout=30)
    try:
        # Non-2xx statuses and non-JSON bodies both surface as RequestException subclasses.
        response.raise_for_status()
        response_json = response.json()
    except requests.exceptions.RequestException as e:
        raise Exception(f"Failed to retrieve data: {response.status_code} - {response.text}") from e

    data = response_json["data"]
    attributes = data["attributes"]
    stats = attributes["stats"]
    return GetURLReportResponse(
        data=Data(
            id=data["id"],
            type=data["type"],
            attributes=Attributes(
                date=attributes["date"],
                status=attributes["status"],
                stats=Stats(
                    malicious=stats["malicious"],
                    suspicious=stats["suspicious"],
                    undetected=stats["undetected"],
                    harmless=stats["harmless"],
                    timeout=stats["timeout"],
                ),
            ),
        )
    )
# Tool the agent calls to return its final result to the UI.
final_answer = FinalAnswerTool()

# Model credentials and the model id are read from the environment
# (HF_TOKEN and MODEL).
model = HfApiModel(
    token=os.getenv('HF_TOKEN'),
    max_tokens=2096,
    temperature=0.5,
    model_id=os.getenv('MODEL'),
    custom_role_conversions=None,
)

# NOTE(review): this search tool, like get_current_time_in_timezone, is built
# but not registered in the agent's tools list below. Add it there if the
# agent should be able to search the web.
web_search_tool = DuckDuckGoSearchTool()

# Prompt templates live next to the app. The path is resolved against the
# current working directory, and the explicit encoding keeps parsing
# independent of the host locale.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[final_answer, get_my_ip_address, get_ip_address_report, scan_url, get_scan_report],  # keep final_answer in this list
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)

# Start the Gradio web UI for the agent (runs on import/execution of this module).
GradioUI(agent).launch()